diff common-session.c @ 726:78eda530c000

send out our kexinit packet before blocking to read the SSH version string
author Matt Johnston <matt@ucc.asn.au>
date Sun, 31 Mar 2013 00:40:00 +0800
parents 2e573f39b88e
children 70811267715c
line wrap: on
line diff
--- a/common-session.c	Sun Mar 24 00:02:20 2013 +0800
+++ b/common-session.c	Sun Mar 31 00:40:00 2013 +0800
@@ -39,6 +39,7 @@
 static void checktimeouts();
 static long select_timeout();
 static int ident_readln(int fd, char* buf, int count);
+static void read_session_identification();
 
 struct sshsession ses; /* GLOBAL */
 
@@ -141,7 +142,10 @@
 		FD_ZERO(&writefd);
 		FD_ZERO(&readfd);
 		dropbear_assert(ses.payload == NULL);
-		if (ses.sock_in != -1) {
+
+		/* during initial setup we flush out the KEXINIT packet before
+		 * attempting to read the remote version string, which might block */
+		if (ses.sock_in != -1 && (ses.remoteident || isempty(&ses.writequeue))) {
 			FD_SET(ses.sock_in, &readfd);
 		}
 		if (ses.sock_out != -1 && !isempty(&ses.writequeue)) {
@@ -195,7 +199,12 @@
 
 		if (ses.sock_in != -1) {
 			if (FD_ISSET(ses.sock_in, &readfd)) {
-				read_packet();
+				if (!ses.remoteident) {
+					/* blocking read of the version string */
+					read_session_identification();
+				} else {
+					read_packet();
+				}
 			}
 			
 			/* Process the decrypted packet. After this, the read buffer
@@ -245,20 +254,20 @@
 }
 
 
-void session_identification() {
+void send_session_identification() {
+	/* write our version string, this blocks */
+	if (atomicio(write, ses.sock_out, LOCAL_IDENT "\r\n",
+				strlen(LOCAL_IDENT "\r\n")) == DROPBEAR_FAILURE) {
+		ses.remoteclosed();
+	}
+}
 
+static void read_session_identification() {
 	/* max length of 255 chars */
 	char linebuf[256];
 	int len = 0;
 	char done = 0;
 	int i;
-
-	/* write our version string, this blocks */
-	if (atomicio(write, ses.sock_out, LOCAL_IDENT "\r\n",
-				strlen(LOCAL_IDENT "\r\n")) == DROPBEAR_FAILURE) {
-		ses.remoteclosed();
-	}
-
 	/* If they send more than 50 lines, something is wrong */
 	for (i = 0; i < 50; i++) {
 		len = ident_readln(ses.sock_in, linebuf, sizeof(linebuf));