diff common-channel.c @ 1069:2fa71c3b2827 pam

merge pam branch up to date
author Matt Johnston <matt@ucc.asn.au>
date Mon, 16 Mar 2015 21:34:05 +0800
parents c71df09bc610
children 696205e3dc99
line wrap: on
line diff
--- a/common-channel.c	Fri Jan 23 22:32:49 2015 +0800
+++ b/common-channel.c	Mon Mar 16 21:34:05 2015 +0800
@@ -35,20 +35,21 @@
 #include "ssh.h"
 #include "listener.h"
 #include "runopts.h"
+#include "netio.h"
 
 static void send_msg_channel_open_failure(unsigned int remotechan, int reason,
 		const unsigned char *text, const unsigned char *lang);
 static void send_msg_channel_open_confirmation(struct Channel* channel,
 		unsigned int recvwindow, 
 		unsigned int recvmaxpacket);
-static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf);
+static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf,
+	const unsigned char *moredata, unsigned int *morelen);
 static void send_msg_channel_window_adjust(struct Channel *channel, 
 		unsigned int incr);
 static void send_msg_channel_data(struct Channel *channel, int isextended);
 static void send_msg_channel_eof(struct Channel *channel);
 static void send_msg_channel_close(struct Channel *channel);
 static void remove_channel(struct Channel *channel);
-static void check_in_progress(struct Channel *channel);
 static unsigned int write_pending(struct Channel * channel);
 static void check_close(struct Channel *channel);
 static void close_chan_fd(struct Channel *channel, int fd, int how);
@@ -163,7 +164,6 @@
 	newchan->writefd = FD_UNINIT;
 	newchan->readfd = FD_UNINIT;
 	newchan->errfd = FD_CLOSED; /* this isn't always set to start with */
-	newchan->initconn = 0;
 	newchan->await_open = 0;
 	newchan->flushing = 0;
 
@@ -242,20 +242,14 @@
 
 		/* write to program/pipe stdin */
 		if (channel->writefd >= 0 && FD_ISSET(channel->writefd, writefds)) {
-			if (channel->initconn) {
-				/* XXX should this go somewhere cleaner? */
-				check_in_progress(channel);
-				continue; /* Important not to use the channel after
-							 check_in_progress(), as it may be NULL */
-			}
-			writechannel(channel, channel->writefd, channel->writebuf);
+			writechannel(channel, channel->writefd, channel->writebuf, NULL, NULL);
 			do_check_close = 1;
 		}
 		
 		/* stderr for client mode */
 		if (ERRFD_IS_WRITE(channel)
 				&& channel->errfd >= 0 && FD_ISSET(channel->errfd, writefds)) {
-			writechannel(channel, channel->errfd, channel->extrabuf);
+			writechannel(channel, channel->errfd, channel->extrabuf, NULL, NULL);
 			do_check_close = 1;
 		}
 
@@ -374,27 +368,27 @@
  * if so, set up the channel properly. Otherwise, the channel is cleaned up, so
  * it is important that the channel reference isn't used after a call to this
  * function */
-static void check_in_progress(struct Channel *channel) {
+void channel_connect_done(int result, int sock, void* user_data, const char* UNUSED(errstring)) {
 
-	int val;
-	socklen_t vallen = sizeof(val);
-
-	TRACE(("enter check_in_progress"))
+	struct Channel *channel = user_data;
 
-	if (getsockopt(channel->writefd, SOL_SOCKET, SO_ERROR, &val, &vallen)
-			|| val != 0) {
-		send_msg_channel_open_failure(channel->remotechan,
-				SSH_OPEN_CONNECT_FAILED, "", "");
-		close(channel->writefd);
-		remove_channel(channel);
-		TRACE(("leave check_in_progress: fail"))
-	} else {
+	TRACE(("enter channel_connect_done"))
+
+	if (result == DROPBEAR_SUCCESS)
+	{
+		channel->readfd = channel->writefd = sock;
+		channel->conn_pending = NULL;
 		chan_initwritebuf(channel);
 		send_msg_channel_open_confirmation(channel, channel->recvwindow,
 				channel->recvmaxpacket);
-		channel->readfd = channel->writefd;
-		channel->initconn = 0;
-		TRACE(("leave check_in_progress: success"))
+		TRACE(("leave channel_connect_done: success"))
+	}
+	else
+	{
+		send_msg_channel_open_failure(channel->remotechan,
+				SSH_OPEN_CONNECT_FAILED, "", "");
+		remove_channel(channel);
+		TRACE(("leave check_in_progress: fail"))
 	}
 }
 
@@ -402,7 +396,7 @@
 /* Send the close message and set the channel as closed */
 static void send_msg_channel_close(struct Channel *channel) {
 
-	TRACE(("enter send_msg_channel_close %p", channel))
+	TRACE(("enter send_msg_channel_close %p", (void*)channel))
 	if (channel->type->closehandler 
 			&& !channel->close_handler_done) {
 		channel->type->closehandler(channel);
@@ -441,14 +435,67 @@
 }
 
 /* Called to write data out to the local side of the channel. 
- * Only called when we know we can write to a channel, writes as much as
- * possible */
-static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf) {
+   Writes the circular buffer contents and also the "moredata" buffer
+   if not null. Will ignore EAGAIN */
+static void writechannel(struct Channel* channel, int fd, circbuffer *cbuf,
+	const unsigned char *moredata, unsigned int *morelen) {
 
-	int len, maxlen;
+	struct iovec iov[3];
+	unsigned char *circ_p1, *circ_p2;
+	unsigned int circ_len1, circ_len2;
+	int io_count = 0;
+
+	int written;
 
 	TRACE(("enter writechannel fd %d", fd))
 
+	cbuf_readptrs(cbuf, &circ_p1, &circ_len1, &circ_p2, &circ_len2);
+
+	if (circ_len1 > 0) {
+		TRACE(("circ1 %d", circ_len1))
+		iov[io_count].iov_base = circ_p1;
+		iov[io_count].iov_len = circ_len1;
+		io_count++;
+	}
+
+	if (circ_len2 > 0) {
+		TRACE(("circ2 %d", circ_len2))
+		iov[io_count].iov_base = circ_p2;
+		iov[io_count].iov_len = circ_len2;
+		io_count++;
+	}
+
+	if (morelen) {
+		assert(moredata);
+		TRACE(("more %d", *morelen))
+		iov[io_count].iov_base = (void*)moredata;
+		iov[io_count].iov_len  = *morelen;
+		io_count++;
+	}
+
+	if (morelen) {
+		/* Default return value, none consumed */
+		*morelen = 0;
+	}
+
+	written = writev(fd, iov, io_count);
+
+	if (written < 0) {
+		if (errno != EINTR && errno != EAGAIN) {
+			TRACE(("errno %d len %d", errno, len))
+			close_chan_fd(channel, fd, SHUT_WR);
+		}
+	} else {
+		int cbuf_written = MIN(circ_len1+circ_len2, (unsigned int)written);
+		cbuf_incrread(cbuf, cbuf_written);
+		if (morelen) {
+			*morelen = written - cbuf_written;
+		}
+		channel->recvdonelen += written;
+	}
+
+#if 0
+
 	maxlen = cbuf_readlen(cbuf);
 
 	/* Write the data out */
@@ -465,10 +512,10 @@
 
 	cbuf_incrread(cbuf, len);
 	channel->recvdonelen += len;
+#endif
 
 	/* Window adjust handling */
 	if (channel->recvdonelen >= RECV_WINDOWEXTEND) {
-		/* Set it back to max window */
 		send_msg_channel_window_adjust(channel, channel->recvdonelen);
 		channel->recvwindow += channel->recvdonelen;
 		channel->recvdonelen = 0;
@@ -514,8 +561,7 @@
 		}
 
 		/* Stuff from the wire */
-		if (channel->initconn
-			||(channel->writefd >= 0 && cbuf_getused(channel->writebuf) > 0)) {
+		if (channel->writefd >= 0 && cbuf_getused(channel->writebuf) > 0) {
 				FD_SET(channel->writefd, writefds);
 		}
 
@@ -586,11 +632,11 @@
 		/* close the FDs in case they haven't been done
 		 * yet (they might have been shutdown etc) */
 		TRACE(("CLOSE writefd %d", channel->writefd))
-		close(channel->writefd);
+		m_close(channel->writefd);
 		TRACE(("CLOSE readfd %d", channel->readfd))
-		close(channel->readfd);
+		m_close(channel->readfd);
 		TRACE(("CLOSE errfd %d", channel->errfd))
-		close(channel->errfd);
+		m_close(channel->errfd);
 	}
 
 	if (!channel->close_handler_done
@@ -599,6 +645,10 @@
 		channel->close_handler_done = 1;
 	}
 
+	if (channel->conn_pending) {
+		cancel_connect(channel->conn_pending);
+	}
+
 	ses.channels[channel->index] = NULL;
 	m_free(channel);
 	ses.chancount--;
@@ -616,7 +666,7 @@
 
 	channel = getchannel();
 
-	TRACE(("enter recv_msg_channel_request %p", channel))
+	TRACE(("enter recv_msg_channel_request %p", (void*)channel))
 
 	if (channel->sent_close) {
 		TRACE(("leave recv_msg_channel_request: already closed channel"))
@@ -749,6 +799,7 @@
 	unsigned int maxdata;
 	unsigned int buflen;
 	unsigned int len;
+	unsigned int consumed;
 
 	TRACE(("enter recv_msg_channel_data"))
 
@@ -775,6 +826,17 @@
 		dropbear_exit("Oversized packet");
 	}
 
+	dropbear_assert(channel->recvwindow >= datalen);
+	channel->recvwindow -= datalen;
+	dropbear_assert(channel->recvwindow <= opts.recv_window);
+
+	consumed = datalen;
+	writechannel(channel, fd, cbuf, buf_getptr(ses.payload, datalen), &consumed);
+
+	datalen -= consumed;
+	buf_incrpos(ses.payload, consumed);
+
+
 	/* We may have to run throught twice, if the buffer wraps around. Can't
 	 * just "leave it for next time" like with writechannel, since this
 	 * is payload data */
@@ -790,10 +852,6 @@
 		len -= buflen;
 	}
 
-	dropbear_assert(channel->recvwindow >= datalen);
-	channel->recvwindow -= datalen;
-	dropbear_assert(channel->recvwindow <= opts.recv_window);
-
 	TRACE(("leave recv_msg_channel_data"))
 }
 
@@ -1001,7 +1059,7 @@
 		}
 	} else {
 		TRACE(("CLOSE some fd %d", fd))
-		close(fd);
+		m_close(fd);
 		closein = closeout = 1;
 	}
 
@@ -1024,7 +1082,7 @@
 	if (channel->type->sepfds && channel->readfd == FD_CLOSED 
 		&& channel->writefd == FD_CLOSED && channel->errfd == FD_CLOSED) {
 		TRACE(("CLOSE (finally) of %d", fd))
-		close(fd);
+		m_close(fd);
 	}
 }
 
@@ -1141,15 +1199,15 @@
 }
 
 struct Channel* get_any_ready_channel() {
+	size_t i;
 	if (ses.chancount == 0) {
 		return NULL;
 	}
-	size_t i;
 	for (i = 0; i < ses.chansize; i++) {
 		struct Channel *chan = ses.channels[i];
 		if (chan
 				&& !(chan->sent_eof || chan->recv_eof)
-				&& !(chan->await_open || chan->initconn)) {
+				&& !(chan->await_open)) {
 			return chan;
 		}
 	}