diff netio.c @ 1511:5916af64acd4 fuzz

merge from main
author Matt Johnston <matt@ucc.asn.au>
date Sat, 17 Feb 2018 19:29:51 +0800
parents 69862e8cc405 78d8c3ffdfe1
children 2f64cb3d3007
line wrap: on
line diff
--- a/netio.c	Tue Jan 23 23:27:40 2018 +0800
+++ b/netio.c	Sat Feb 17 19:29:51 2018 +0800
@@ -19,6 +19,7 @@
 	int sock;
 
 	char* errstring;
+	char *bind_address, *bind_port;
 };
 
 /* Deallocate a progress connection. Removes from the pending list if iter!=NULL.
@@ -30,6 +31,8 @@
 	m_free(c->remotehost);
 	m_free(c->remoteport);
 	m_free(c->errstring);
+	m_free(c->bind_address);
+	m_free(c->bind_port);
 	m_free(c);
 
 	if (iter) {
@@ -51,6 +54,7 @@
 
 static void connect_try_next(struct dropbear_progress_connection *c) {
 	struct addrinfo *r;
+	int err;
 	int res = 0;
 	int fastopen = 0;
 #if DROPBEAR_CLIENT_TCP_FAST_OPEN
@@ -66,6 +70,44 @@
 			continue;
 		}
 
+		if (c->bind_address || c->bind_port) {
+			/* bind to a source port/address */
+			struct addrinfo hints;
+			struct addrinfo *bindaddr = NULL;
+			memset(&hints, 0, sizeof(hints));
+			hints.ai_socktype = SOCK_STREAM;
+			hints.ai_family = r->ai_family;
+			hints.ai_flags = AI_PASSIVE;
+
+			err = getaddrinfo(c->bind_address, c->bind_port, &hints, &bindaddr);
+			if (err) {
+				int len = 100 + strlen(gai_strerror(err));
+				m_free(c->errstring);
+				c->errstring = (char*)m_malloc(len);
+				snprintf(c->errstring, len, "Error resolving bind address '%s' (port %s). %s", 
+						c->bind_address, c->bind_port, gai_strerror(err));
+				TRACE(("Error resolving bind: %s", gai_strerror(err)))
+				close(c->sock);
+				c->sock = -1;
+				continue;
+			}
+			res = bind(c->sock, bindaddr->ai_addr, bindaddr->ai_addrlen);
+			freeaddrinfo(bindaddr);
+			bindaddr = NULL;
+			if (res < 0) {
+				/* failure */
+				int keep_errno = errno;
+				int len = 300;
+				m_free(c->errstring);
+				c->errstring = m_malloc(len);
+				snprintf(c->errstring, len, "Error binding local address '%s' (port %s). %s", 
+						c->bind_address, c->bind_port, strerror(keep_errno));
+				close(c->sock);
+				c->sock = -1;
+				continue;
+			}
+		}
+
 		ses.maxfd = MAX(ses.maxfd, c->sock);
 		set_sock_nodelay(c->sock);
 		setnonblocking(c->sock);
@@ -130,7 +172,8 @@
 
 /* Connect via TCP to a host. */
 struct dropbear_progress_connection *connect_remote(const char* remotehost, const char* remoteport,
-	connect_callback cb, void* cb_data)
+	connect_callback cb, void* cb_data, 
+	const char* bind_address, const char* bind_port)
 {
 	struct dropbear_progress_connection *c = NULL;
 	int err;
@@ -160,6 +203,13 @@
 	} else {
 		c->res_iter = c->res;
 	}
+	
+	if (bind_address) {
+		c->bind_address = m_strdup(bind_address);
+	}
+	if (bind_port) {
+		c->bind_port = m_strdup(bind_port);
+	}
 
 	return c;
 }
@@ -198,7 +248,7 @@
 	TRACE(("leave set_connect_fds"))
 }
 
-void handle_connect_fds(fd_set *writefd) {
+void handle_connect_fds(const fd_set *writefd) {
 	m_list_elem *iter;
 	TRACE(("enter handle_connect_fds"))
 	for (iter = ses.conn_pending.first; iter; iter = iter->next) {
@@ -241,7 +291,7 @@
 	c->writequeue = writequeue;
 }
 
-void packet_queue_to_iovec(struct Queue *queue, struct iovec *iov, unsigned int *iov_count) {
+void packet_queue_to_iovec(const struct Queue *queue, struct iovec *iov, unsigned int *iov_count) {
 	struct Link *l;
 	unsigned int i;
 	int len;
@@ -355,6 +405,37 @@
 
 }
 
+/* from openssh/canohost.c avoid premature-optimization */
+int get_sock_port(int sock) {
+	struct sockaddr_storage from;
+	socklen_t fromlen;
+	char strport[NI_MAXSERV];
+	int r;
+
+	/* Get IP address of client. */
+	fromlen = sizeof(from);
+	memset(&from, 0, sizeof(from));
+	if (getsockname(sock, (struct sockaddr *)&from, &fromlen) < 0) {
+		TRACE(("getsockname failed: %d", errno))
+		return 0;
+	}
+
+	/* Work around Linux IPv6 weirdness */
+	if (from.ss_family == AF_INET6)
+		fromlen = sizeof(struct sockaddr_in6);
+
+	/* Non-inet sockets don't have a port number. */
+	if (from.ss_family != AF_INET && from.ss_family != AF_INET6)
+		return 0;
+
+	/* Return port number. */
+	if ((r = getnameinfo((struct sockaddr *)&from, fromlen, NULL, 0,
+	    strport, sizeof(strport), NI_NUMERICSERV)) != 0) {
+		TRACE(("netio.c/get_sock_port/getnameinfo NI_NUMERICSERV failed: %d", r))
+	}
+	return atoi(strport);
+}
+
 /* Listen on address:port. 
  * Special cases are address of "" listening on everything,
  * and address of NULL listening on localhost only.
@@ -407,11 +488,29 @@
 		return -1;
 	}
 
+	/*
+	 * when listening on server-assigned-port 0
+	 * the assigned ports may differ for address families (v4/v6)
+	 * causing problems for tcpip-forward
+	 * caller can do a get_socket_address to discover assigned-port
+	 * hence, use same port for all address families
+	 */
+	u_int16_t *allocated_lport_p = NULL;
+	int allocated_lport = 0;
 
 	nsock = 0;
 	for (res = res0; res != NULL && nsock < sockcount;
 			res = res->ai_next) {
 
+		if (allocated_lport > 0) {
+			if (AF_INET == res->ai_family) {
+				allocated_lport_p = &((struct sockaddr_in *)res->ai_addr)->sin_port;
+			} else if (AF_INET6 == res->ai_family) {
+				allocated_lport_p = &((struct sockaddr_in6 *)res->ai_addr)->sin6_port;
+			}
+			*allocated_lport_p = htons(allocated_lport);
+		}
+
 		/* Get a socket */
 		socks[nsock] = socket(res->ai_family, res->ai_socktype,
 				res->ai_protocol);
@@ -458,6 +557,10 @@
 			continue;
 		}
 
+		if (0 == allocated_lport) {
+			allocated_lport = get_sock_port(sock);
+		}
+
 		*maxfd = MAX(*maxfd, sock);
 
 		nsock++;
@@ -524,7 +627,7 @@
 	
 	int flags = NI_NUMERICSERV | NI_NUMERICHOST;
 
-#ifndef DO_HOST_LOOKUP
+#if !DO_HOST_LOOKUP
 	host_lookup = 0;
 #endif