diff common-channel.c @ 362:1c7bf9cec6c8 channel-fix

Rearranged some more bits, marked some areas that need work. * send_msg_channel_data() no longer allocates a separate buffer * getchannel() handles unknown channels so callers don't have to
author Matt Johnston <matt@ucc.asn.au>
date Mon, 02 Oct 2006 16:34:06 +0000
parents 78518751cb82
children 6ba2894ec8d5
line wrap: on
line diff
--- a/common-channel.c	Sun Oct 01 16:35:13 2006 +0000
+++ b/common-channel.c	Mon Oct 02 16:34:06 2006 +0000
@@ -164,24 +164,33 @@
 }
 
 /* Returns the channel structure corresponding to the channel in the current
- * data packet (ses.payload must be positioned appropriately) */
-struct Channel* getchannel() {
+ * data packet (ses.payload must be positioned appropriately).
+ * A valid channel is always returns, it will fail fatally with an unknown
+ * channel */
+static struct Channel* getchannel_msg(const char* kind) {
 
 	unsigned int chan;
 
 	chan = buf_getint(ses.payload);
 	if (chan >= ses.chansize || ses.channels[chan] == NULL) {
-		return NULL;
+		if (kind) {
+			dropbear_exit("%s for unknown channel %d", kind, chan);
+		} else {
+			dropbear_exit("Unknown channel %d", chan);
+		}
 	}
 	return ses.channels[chan];
 }
 
+struct Channel* getchannel() {
+	return getchannel_msg(NULL);
+}
+
 /* Iterate through the channels, performing IO if available */
 void channelio(fd_set *readfds, fd_set *writefds) {
 
 	struct Channel *channel;
 	unsigned int i;
-	int ret;
 
 	/* iterate through all the possible channels */
 	for (i = 0; i < ses.chansize; i++) {
@@ -218,6 +227,7 @@
 		/* write to program/pipe stdin */
 		if (channel->writefd >= 0 && FD_ISSET(channel->writefd, writefds)) {
 			if (channel->initconn) {
+				/* XXX could this go somewhere cleaner? */
 				check_in_progress(channel);
 				continue; /* Important not to use the channel after
 							 check_in_progress(), as it may be NULL */
@@ -254,6 +264,16 @@
 				channel->writebuf,
 				channel->writebuf ? 0 : cbuf_getused(channel->extrabuf)))
 
+	/* XXX not good, doesn't flush out */
+	if (channel->recv_close) {
+		if (! channel->sent_close) {
+			TRACE(("Sending MSG_CHANNEL_CLOSE in response to same."))
+			send_msg_channel_close(channel);
+		}
+		remove_channel(channel);
+		return;
+	}
+
 	/* server chansession channels are special, since readfd mightn't
 	 * close in the case of "sleep 4 & echo blah" until the sleep is up */
 	if (channel->type->check_close) {
@@ -277,6 +297,14 @@
 		send_msg_channel_close(channel);
 	}
 
+	/* XXX blah */
+	if (channel->recv_eof &&
+		(cbuf_getused(channel->writebuf) == 0
+			&& (channel->extrabuf == NULL 
+					|| cbuf_getused(channel->extrabuf) == 0))) {
+		close_write_fd(channel);
+	}
+
 	/* When either party wishes to terminate the channel, it sends
 	 * SSH_MSG_CHANNEL_CLOSE.  Upon receiving this message, a party MUST
 	 * send back a SSH_MSG_CHANNEL_CLOSE unless it has already sent this
@@ -287,13 +315,6 @@
 	 * SSH_MSG_CHANNEL_EOF. 
 	 * (from draft-ietf-secsh-connect)
 	 */
-	if (channel->recv_close) {
-		if (! channel->sent_close) {
-			TRACE(("Sending MSG_CHANNEL_CLOSE in response to same."))
-			send_msg_channel_close(channel);
-		}
-		remove_channel(channel);
-	}
 }
 
 
@@ -325,7 +346,6 @@
 }
 
 
-
 /* Send the close message and set the channel as closed */
 static void send_msg_channel_close(struct Channel *channel) {
 
@@ -341,6 +361,7 @@
 
 	encrypt_packet();
 
+	/* XXX is setting sent_eof required? */
 	channel->sent_eof = 1;
 	channel->sent_close = 1;
 	TRACE(("leave send_msg_channel_close"))
@@ -388,13 +409,6 @@
 	cbuf_incrread(cbuf, len);
 	channel->recvdonelen += len;
 
-	if (fd == channel->writefd && cbuf_getused(cbuf) == 0 && channel->recv_eof) { 
-		/* Check if we're closing up */
-		close_write_fd(channel);
-		TRACE(("leave writechannel: recv_eof set"))
-		return;
-	}
-
 	/* Window adjust handling */
 	if (channel->recvdonelen >= RECV_WINDOWEXTEND) {
 		/* Set it back to max window */
@@ -408,7 +422,6 @@
 	dropbear_assert(channel->extrabuf == NULL ||
 			channel->recvwindow <= cbuf_getavail(channel->extrabuf));
 	
-	
 	TRACE(("leave writechannel"))
 }
 
@@ -466,18 +479,11 @@
 
 	TRACE(("enter recv_msg_channel_eof"))
 
-	channel = getchannel();
-	if (channel == NULL) {
-		dropbear_exit("EOF for unknown channel");
-	}
+	channel = getchannel_msg("EOF");
 
 	channel->recv_eof = 1;
-	if (cbuf_getused(channel->writebuf) == 0
-			&& (channel->extrabuf == NULL 
-					|| cbuf_getused(channel->extrabuf) == 0)) {
-		close_write_fd(channel);
-	}
 
+	check_close(channel);
 	TRACE(("leave recv_msg_channel_eof"))
 }
 
@@ -489,19 +495,13 @@
 
 	TRACE(("enter recv_msg_channel_close"))
 
-	channel = getchannel();
-	if (channel == NULL) {
-		/* disconnect ? */
-		dropbear_exit("Close for unknown channel");
-	}
+	channel = getchannel_msg("Close");
 
+	/* XXX eof required? */
 	channel->recv_eof = 1;
 	channel->recv_close = 1;
 
-	if (channel->sent_close) {
-		remove_channel(channel);
-	}
-
+	check_close(channel);
 	TRACE(("leave recv_msg_channel_close"))
 }
 
@@ -512,6 +512,9 @@
 	TRACE(("enter remove_channel"))
 	TRACE(("channel index is %d", channel->index))
 
+	/* XXX shuold we assert for sent_closed and recv_closed?
+	 * but we also cleanup manually, maybe we need a flag. */
+
 	cbuf_free(channel->writebuf);
 	channel->writebuf = NULL;
 
@@ -522,7 +525,7 @@
 
 
 	/* close the FDs in case they haven't been done
-	 * yet (ie they were shutdown etc */
+	 * yet (they might have been shutdown etc) */
 	close(channel->writefd);
 	close(channel->readfd);
 	close(channel->errfd);
@@ -553,10 +556,6 @@
 	TRACE(("enter recv_msg_channel_request"))
 	
 	channel = getchannel();
-	if (channel == NULL) {
-		/* disconnect ? */
-		dropbear_exit("Unknown channel");
-	}
 
 	if (channel->type->reqhandler) {
 		channel->type->reqhandler(channel);
@@ -576,9 +575,8 @@
 static void send_msg_channel_data(struct Channel *channel, int isextended,
 		unsigned int exttype) {
 
-	buffer *buf;
 	int len;
-	unsigned int maxlen;
+	size_t maxlen, size_pos;
 	int fd;
 
 /*	TRACE(("enter send_msg_channel_data"))
@@ -600,40 +598,37 @@
 	 * exttype if is extended */
 	maxlen = MIN(maxlen, 
 			ses.writepayload->size - 1 - 4 - 4 - (isextended ? 4 : 0));
+	TRACE(("maxlen %d", maxlen))
 	if (maxlen == 0) {
 		TRACE(("leave send_msg_channel_data: no window"))
-		return; /* the data will get written later */
+		return;
 	}
 
+	buf_putbyte(ses.writepayload, 
+			isextended ? SSH_MSG_CHANNEL_EXTENDED_DATA : SSH_MSG_CHANNEL_DATA);
+	buf_putint(ses.writepayload, channel->remotechan);
+	if (isextended) {
+		buf_putint(ses.writepayload, exttype);
+	}
+	/* a dummy size first ...*/
+	size_pos = ses.writepayload->pos;
+	buf_putint(ses.writepayload, 0);
+
 	/* read the data */
-	TRACE(("maxlen %d", maxlen))
-	buf = buf_new(maxlen);
-	TRACE(("buf pos %d data %x", buf->pos, buf->data))
-	len = read(fd, buf_getwriteptr(buf, maxlen), maxlen);
+	len = read(fd, buf_getwriteptr(ses.writepayload, maxlen), maxlen);
 	if (len <= 0) {
-		/* on error/eof, send eof */
 		if (len == 0 || errno != EINTR) {
 			close_read_fd(channel, fd);
 		}
-		buf_free(buf);
-		buf = NULL;
+		ses.writepayload->len = ses.writepayload->pos = 0;
 		TRACE(("leave send_msg_channel_data: read err or EOF for fd %d", 
 					channel->index));
 		return;
 	}
-	buf_incrlen(buf, len);
-
-	buf_putbyte(ses.writepayload, 
-			isextended ? SSH_MSG_CHANNEL_EXTENDED_DATA : SSH_MSG_CHANNEL_DATA);
-	buf_putint(ses.writepayload, channel->remotechan);
-
-	if (isextended) {
-		buf_putint(ses.writepayload, exttype);
-	}
-
-	buf_putstring(ses.writepayload, buf_getptr(buf, len), len);
-	buf_free(buf);
-	buf = NULL;
+	buf_incrwritepos(ses.writepayload, len);
+	/* ... real size here */
+	buf_setpos(ses.writepayload, size_pos);
+	buf_putint(ses.writepayload, len);
 
 	channel->transwindow -= len;
 
@@ -647,9 +642,6 @@
 	struct Channel *channel;
 
 	channel = getchannel();
-	if (channel == NULL) {
-		dropbear_exit("Unknown channel");
-	}
 
 	common_recv_msg_channel_data(channel, channel->writefd, channel->writebuf);
 }
@@ -726,9 +718,6 @@
 	unsigned int incr;
 	
 	channel = getchannel();
-	if (channel == NULL) {
-		dropbear_exit("Unknown channel");
-	}
 	
 	incr = buf_getint(ses.payload);
 	TRACE(("received window increment %d", incr))
@@ -786,6 +775,7 @@
 	/* Get the channel type. Client and server style invokation will set up a
 	 * different list for ses.chantypes at startup. We just iterate through
 	 * this list and find the matching name */
+	/* XXX fugly */
 	for (cp = &ses.chantypes[0], chantype = (*cp); 
 			chantype != NULL;
 			cp++, chantype = (*cp)) {
@@ -811,11 +801,11 @@
 
 	if (channel->type->inithandler) {
 		ret = channel->type->inithandler(channel);
+		if (ret == SSH_OPEN_IN_PROGRESS) {
+			/* We'll send the confirmation later */
+			goto cleanup;
+		}
 		if (ret > 0) {
-			if (ret == SSH_OPEN_IN_PROGRESS) {
-				/* We'll send the confirmation later */
-				goto cleanup;
-			}
 			errtype = ret;
 			delete_channel(channel);
 			TRACE(("inithandler returned failure %d", ret))
@@ -949,9 +939,6 @@
 	TRACE(("enter recv_msg_channel_open_confirmation"))
 
 	channel = getchannel();
-	if (channel == NULL) {
-		dropbear_exit("Unknown channel");
-	}
 
 	if (!channel->await_open) {
 		dropbear_exit("unexpected channel reply");
@@ -984,9 +971,6 @@
 	struct Channel * channel;
 
 	channel = getchannel();
-	if (channel == NULL) {
-		dropbear_exit("Unknown channel");
-	}
 
 	if (!channel->await_open) {
 		dropbear_exit("unexpected channel reply");