diff common-channel.c @ 363:6ba2894ec8d5 channel-fix

Rearranged (and hopefully simplified) channel close/eof handling
author Matt Johnston <matt@ucc.asn.au>
date Sat, 07 Oct 2006 17:48:55 +0000
parents 1c7bf9cec6c8
children 90cb290836de
line wrap: on
line diff
--- a/common-channel.c	Mon Oct 02 16:34:06 2006 +0000
+++ b/common-channel.c	Sat Oct 07 17:48:55 2006 +0000
@@ -43,17 +43,14 @@
 static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf);
 static void send_msg_channel_window_adjust(struct Channel *channel, 
 		unsigned int incr);
-static void send_msg_channel_data(struct Channel *channel, int isextended,
-		unsigned int exttype);
+static void send_msg_channel_data(struct Channel *channel, int isextended);
 static void send_msg_channel_eof(struct Channel *channel);
 static void send_msg_channel_close(struct Channel *channel);
 static void remove_channel(struct Channel *channel);
 static void delete_channel(struct Channel *channel);
 static void check_in_progress(struct Channel *channel);
+static unsigned int write_pending(struct Channel * channel);
 static void check_close(struct Channel *channel);
-
-static void close_write_fd(struct Channel * channel);
-static void close_read_fd(struct Channel * channel, int fd);
 static void close_chan_fd(struct Channel *channel, int fd, int how);
 
 #define FD_UNINIT (-2)
@@ -192,7 +189,7 @@
 	struct Channel *channel;
 	unsigned int i;
 
-	/* iterate through all the possible channels */
+	/* foreach channel */
 	for (i = 0; i < ses.chansize; i++) {
 
 		channel = ses.channels[i];
@@ -203,31 +200,19 @@
 
 		/* read data and send it over the wire */
 		if (channel->readfd >= 0 && FD_ISSET(channel->readfd, readfds)) {
-			send_msg_channel_data(channel, 0, 0);
+			send_msg_channel_data(channel, 0);
 		}
 
 		/* read stderr data and send it over the wire */
 		if (channel->extrabuf == NULL &&
 				channel->errfd >= 0 && FD_ISSET(channel->errfd, readfds)) {
-				send_msg_channel_data(channel, 1, SSH_EXTENDED_DATA_STDERR);
+				send_msg_channel_data(channel, 1);
 		}
 
-#if 0
-		/* XXX where is this required? */
-			if (channel->initconn) {
-				/* Handling for "in progress" connection - this is needed
-				 * to avoid spinning 100% CPU when we connect to a server
-				 * which doesn't send anything (tcpfwding) */
-				check_in_progress(channel);
-				continue; /* Important not to use the channel after 
-							 check_in_progress(), as it may be NULL */
-			}
-#endif
-
 		/* write to program/pipe stdin */
 		if (channel->writefd >= 0 && FD_ISSET(channel->writefd, writefds)) {
 			if (channel->initconn) {
-				/* XXX could this go somewhere cleaner? */
+				/* XXX should this go somewhere cleaner? */
 				check_in_progress(channel);
 				continue; /* Important not to use the channel after
 							 check_in_progress(), as it may be NULL */
@@ -241,10 +226,10 @@
 			writechannel(channel, channel->errfd, channel->extrabuf);
 		}
 	
-		/* now handle any of the channel-closing type stuff */
+		/* handle any channel closing etc */
 		check_close(channel);
 
-	} /* foreach channel */
+	}
 
 	/* Listeners such as TCP, X11, agent-auth */
 #ifdef USING_LISTENERS
@@ -253,6 +238,20 @@
 }
 
 
+/* Returns true if there is data remaining to be written to stdin or
+ * stderr of a channel's endpoint. */
+static unsigned int write_pending(struct Channel * channel) {
+
+	if (channel->writefd >= 0 && cbuf_getused(channel->writebuf) > 0) {
+		return 1;
+	} else if (channel->errfd >= 0 && channel->extrabuf && 
+			cbuf_getused(channel->writebuf) > 0) {
+		return 1;
+	}
+	return 0;
+}
+
+
 /* EOF/close handling */
 static void check_close(struct Channel *channel) {
 
@@ -264,8 +263,13 @@
 				channel->writebuf,
 				channel->writebuf ? 0 : cbuf_getused(channel->extrabuf)))
 
-	/* XXX not good, doesn't flush out */
-	if (channel->recv_close) {
+	if (!channel->sent_close
+			&& channel->writefd == FD_CLOSED
+			&& (channel->errfd == FD_CLOSED || channel->extrabuf == NULL)) {
+		send_msg_channel_close(channel);
+	}
+
+	if (channel->recv_close && !write_pending(channel)) {
 		if (! channel->sent_close) {
 			TRACE(("Sending MSG_CHANNEL_CLOSE in response to same."))
 			send_msg_channel_close(channel);
@@ -274,8 +278,10 @@
 		return;
 	}
 
-	/* server chansession channels are special, since readfd mightn't
-	 * close in the case of "sleep 4 & echo blah" until the sleep is up */
+#if 0
+	// The only use of check_close is "return channel->writefd == -1;" for a server
+	// chansession. Should be able to handle that with just the general
+	// socket close handling...?
 	if (channel->type->check_close) {
 		if (channel->type->check_close(channel)) {
 			close_write_fd(channel);
@@ -283,6 +289,7 @@
 			close_read_fd(channel, channel->errfd);
 		}
 	}
+#endif
 
 	if (!channel->sent_eof
 		&& channel->readfd == FD_CLOSED 
@@ -296,25 +303,6 @@
 		&& (channel->extrabuf != NULL || channel->errfd == FD_CLOSED)) {
 		send_msg_channel_close(channel);
 	}
-
-	/* XXX blah */
-	if (channel->recv_eof &&
-		(cbuf_getused(channel->writebuf) == 0
-			&& (channel->extrabuf == NULL 
-					|| cbuf_getused(channel->extrabuf) == 0))) {
-		close_write_fd(channel);
-	}
-
-	/* When either party wishes to terminate the channel, it sends
-	 * SSH_MSG_CHANNEL_CLOSE.  Upon receiving this message, a party MUST
-	 * send back a SSH_MSG_CHANNEL_CLOSE unless it has already sent this
-	 * message for the channel.  The channel is considered closed for a
-	 * party when it has both sent and received SSH_MSG_CHANNEL_CLOSE, and
-	 * the party may then reuse the channel number.  A party MAY send
-	 * SSH_MSG_CHANNEL_CLOSE without having sent or received
-	 * SSH_MSG_CHANNEL_EOF. 
-	 * (from draft-ietf-secsh-connect)
-	 */
 }
 
 
@@ -398,9 +386,7 @@
 	len = write(fd, cbuf_readptr(cbuf, maxlen), maxlen);
 	if (len <= 0) {
 		if (len < 0 && errno != EINTR) {
-			/* no more to write - we close it even if the fd was stderr, since
-			 * that's a nasty failure too */
-			close_write_fd(channel);
+			close_chan_fd(channel, fd, SHUT_WR);
 		}
 		TRACE(("leave writechannel: len <= 0"))
 		return;
@@ -409,6 +395,13 @@
 	cbuf_incrread(cbuf, len);
 	channel->recvdonelen += len;
 
+	/* We're closing out */
+	if (channel->recv_eof && cbuf_getused(cbuf) == 0) {
+	TRACE(("leave writechannel"))
+		close_chan_fd(channel, fd, SHUT_WR);
+		return;
+	}
+
 	/* Window adjust handling */
 	if (channel->recvdonelen >= RECV_WINDOWEXTEND) {
 		/* Set it back to max window */
@@ -572,16 +565,12 @@
  * chan is the remote channel, isextended is 0 if it is normal data, 1
  * if it is extended data. if it is extended, then the type is in
  * exttype */
-static void send_msg_channel_data(struct Channel *channel, int isextended,
-		unsigned int exttype) {
+static void send_msg_channel_data(struct Channel *channel, int isextended) {
 
 	int len;
 	size_t maxlen, size_pos;
 	int fd;
 
-/*	TRACE(("enter send_msg_channel_data"))
-	TRACE(("extended = %d type = %d", isextended, exttype))*/
-
 	CHECKCLEARTOWRITE();
 
 	dropbear_assert(!channel->sent_close);
@@ -608,7 +597,7 @@
 			isextended ? SSH_MSG_CHANNEL_EXTENDED_DATA : SSH_MSG_CHANNEL_DATA);
 	buf_putint(ses.writepayload, channel->remotechan);
 	if (isextended) {
-		buf_putint(ses.writepayload, exttype);
+		buf_putint(ses.writepayload, SSH_EXTENDED_DATA_STDERR);
 	}
 	/* a dummy size first ...*/
 	size_pos = ses.writepayload->pos;
@@ -618,7 +607,7 @@
 	len = read(fd, buf_getwriteptr(ses.writepayload, maxlen), maxlen);
 	if (len <= 0) {
 		if (len == 0 || errno != EINTR) {
-			close_read_fd(channel, fd);
+			close_chan_fd(channel, fd, SHUT_RD);
 		}
 		ses.writepayload->len = ses.writepayload->pos = 0;
 		TRACE(("leave send_msg_channel_data: read err or EOF for fd %d", 
@@ -891,6 +880,47 @@
 	TRACE(("leave send_msg_channel_open_confirmation"))
 }
 
+/* close a fd, how is SHUT_RD or SHUT_WR */
+static void close_chan_fd(struct Channel *channel, int fd, int how) {
+
+	int closein = 0, closeout = 0;
+
+	if (channel->type->sepfds) {
+		TRACE(("shutdown((%d), %d)", fd, how))
+		shutdown(fd, how);
+		if (how == 0) {
+			closeout = 1;
+		} else {
+			closein = 1;
+		}
+	} else {
+		close(fd);
+		closein = closeout = 1;
+	}
+
+	if (closeout && fd == channel->readfd) {
+		channel->readfd = FD_CLOSED;
+	}
+	if (closeout && (channel->extrabuf == NULL) && (fd == channel->errfd)) {
+		channel->errfd = FD_CLOSED;
+	}
+
+	if (closein && fd == channel->writefd) {
+		channel->writefd = FD_CLOSED;
+	}
+	if (closein && (channel->extrabuf != NULL) && (fd == channel->errfd)) {
+		channel->errfd = FD_CLOSED;
+	}
+
+	/* if we called shutdown on it and all references are gone, then we 
+	 * need to close() it to stop it lingering */
+	if (channel->type->sepfds && channel->readfd == FD_CLOSED 
+		&& channel->writefd == FD_CLOSED && channel->errfd == FD_CLOSED) {
+		close(fd);
+	}
+}
+
+
 #if defined(USING_LISTENERS) || defined(DROPBEAR_CLIENT)
 /* Create a new channel, and start the open request. This is intended
  * for X11, agent, tcp forwarding, and should be filled with channel-specific
@@ -980,61 +1010,3 @@
 	remove_channel(channel);
 }
 #endif /* USING_LISTENERS */
-
-/* close a stdout/stderr fd */
-static void close_read_fd(struct Channel * channel, int fd) {
-
-	/* don't close it if it is the same as writefd,
-	 * unless writefd is already set -1 */
-	TRACE(("enter close_read_fd"))
-	close_chan_fd(channel, fd, 0);
-	TRACE(("leave close_read_fd"))
-}
-
-/* close a stdin fd */
-static void close_write_fd(struct Channel * channel) {
-
-	TRACE(("enter close_write_fd"))
-	close_chan_fd(channel, channel->writefd, 1);
-	TRACE(("leave close_write_fd"))
-}
-
-/* close a fd, how is 0 for stdout/stderr, 1 for stdin */
-static void close_chan_fd(struct Channel *channel, int fd, int how) {
-
-	int closein = 0, closeout = 0;
-
-	if (channel->type->sepfds) {
-		TRACE(("shutdown((%d), %d)", fd, how))
-		shutdown(fd, how);
-		if (how == 0) {
-			closeout = 1;
-		} else {
-			closein = 1;
-		}
-	} else {
-		close(fd);
-		closein = closeout = 1;
-	}
-
-	if (closeout && fd == channel->readfd) {
-		channel->readfd = FD_CLOSED;
-	}
-	if (closeout && (channel->extrabuf == NULL) && (fd == channel->errfd)) {
-		channel->errfd = FD_CLOSED;
-	}
-
-	if (closein && fd == channel->writefd) {
-		channel->writefd = FD_CLOSED;
-	}
-	if (closein && (channel->extrabuf != NULL) && (fd == channel->errfd)) {
-		channel->errfd = FD_CLOSED;
-	}
-
-	/* if we called shutdown on it and all references are gone, then we 
-	 * need to close() it to stop it lingering */
-	if (channel->type->sepfds && channel->readfd == FD_CLOSED 
-		&& channel->writefd == FD_CLOSED && channel->errfd == FD_CLOSED) {
-		close(fd);
-	}
-}