Mercurial > dropbear
diff netio.c @ 1032:0da8ba489c23 fastopen
Move generic network routines to netio.c
author | Matt Johnston <matt@ucc.asn.au> |
---|---|
date | Fri, 20 Feb 2015 23:16:38 +0800 |
parents | |
children | ca71904cf3ee |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/netio.c Fri Feb 20 23:16:38 2015 +0800 @@ -0,0 +1,598 @@ +#include "netio.h" +#include "list.h" +#include "dbutil.h" +#include "session.h" +#include "debug.h" + +struct dropbear_progress_connection { + struct addrinfo *res; + struct addrinfo *res_iter; + + char *remotehost, *remoteport; /* For error reporting */ + + connect_callback cb; + void *cb_data; + + struct Queue *writequeue; /* A queue of encrypted packets to send with TCP fastopen, + or NULL. */ + + int sock; + + char* errstring; +}; + +#if defined(__linux__) && defined(TCP_DEFER_ACCEPT) +static void set_piggyback_ack(int sock) { + /* Undocumented Linux feature - set TCP_DEFER_ACCEPT and data will be piggybacked + on the 3rd packet (ack) of the TCP handshake. Saves a IP packet. + http://thread.gmane.org/gmane.linux.network/224627/focus=224727 + "Piggyback the final ACK of the three way TCP connection establishment with the data" */ + int val = 1; + /* No error checking, this is opportunistic */ + int err = setsockopt(sock, IPPROTO_TCP, TCP_DEFER_ACCEPT, (void*)&val, sizeof(val)); + if (err) + { + TRACE(("Failed setsockopt TCP_DEFER_ACCEPT: %s", strerror(errno))) + } +} +#endif + + +/* Deallocate a progress connection. Removes from the pending list if iter!=NULL. +Does not close sockets */ +static void remove_connect(struct dropbear_progress_connection *c, m_list_elem *iter) { + if (c->res) { + freeaddrinfo(c->res); + } + m_free(c->remotehost); + m_free(c->remoteport); + m_free(c->errstring); + m_free(c); + + if (iter) { + list_remove(iter); + } +} + +static void cancel_callback(int result, int sock, void* UNUSED(data), const char* UNUSED(errstring)) { + if (result == DROPBEAR_SUCCESS) + { + m_close(sock); + } +} + +void cancel_connect(struct dropbear_progress_connection *c) { + c->cb = cancel_callback; + c->cb_data = NULL; +} + +static void connect_try_next(struct dropbear_progress_connection *c) { + struct addrinfo *r; + + if (!c->res_iter) { + return; + } + + for (r = c->res_iter; r; r = r->ai_next) + { + assert(c->sock == -1); + + c->sock = socket(c->res_iter->ai_family, c->res_iter->ai_socktype, c->res_iter->ai_protocol); + if (c->sock < 0) { + continue; + } + + ses.maxfd = MAX(ses.maxfd, c->sock); + setnonblocking(c->sock); + +#if defined(__linux__) && defined(TCP_DEFER_ACCEPT) + set_piggyback_ack(c->sock); +#endif + +#ifdef PROGRESS_CONNECT_FALLBACK +#if 0 + if (connect(c->sock, r->ai_addr, r->ai_addrlen) < 0) { + if (errno == EINPROGRESS) { + TRACE(("Connect in progress")) + break; + } else { + close(c->sock); + c->sock = -1; + continue; + } + } + + break; /* Success. Treated the same as EINPROGRESS */ +#endif +#else + { + struct msghdr message; + int res = 0; + memset(&message, 0x0, sizeof(message)); + message.msg_name = r->ai_addr; + message.msg_namelen = r->ai_addrlen; + + if (c->writequeue) { + int iovlen; /* Linux msg_iovlen is a size_t */ + message.msg_iov = packet_queue_to_iovec(c->writequeue, &iovlen); + message.msg_iovlen = iovlen; + res = sendmsg(c->sock, &message, MSG_FASTOPEN); + if (res < 0 && errno == EOPNOTSUPP) { + TRACE(("Fastopen not supported")); + /* No kernel MSG_FASTOPEN support. Fall back below */ + c->writequeue = NULL; + } + m_free(message.msg_iov); + if (res > 0) { + packet_queue_consume(c->writequeue, res); + } + } + + if (!c->writequeue) { + res = connect(c->sock, r->ai_addr, r->ai_addrlen); + } + if (res < 0 && errno != EINPROGRESS) { + close(c->sock); + c->sock = -1; + continue; + } else { + break; + } + } +#endif + } + + + if (r) { + c->res_iter = r->ai_next; + } else { + c->res_iter = NULL; + } + + if (c->sock >= 0 || (errno == EINPROGRESS)) { + /* Success */ + set_sock_nodelay(c->sock); + return; + } else { + if (!c->res_iter) + { + + } + /* XXX - returning error message through */ +#if 0 + /* Failed */ + if (errstring != NULL && *errstring == NULL) { + int len; + len = 20 + strlen(strerror(err)); + *errstring = (char*)m_malloc(len); + snprintf(*errstring, len, "Error connecting: %s", strerror(err)); + } + TRACE(("Error connecting: %s", strerror(err))) +#endif + } +} + +/* Connect via TCP to a host. */ +struct dropbear_progress_connection *connect_remote(const char* remotehost, const char* remoteport, + connect_callback cb, void* cb_data) +{ + struct dropbear_progress_connection *c = NULL; + int err; + struct addrinfo hints; + + c = m_malloc(sizeof(*c)); + c->remotehost = m_strdup(remotehost); + c->remoteport = m_strdup(remoteport); + c->sock = -1; + c->cb = cb; + c->cb_data = cb_data; + + list_append(&ses.conn_pending, c); + + memset(&hints, 0, sizeof(hints)); + hints.ai_socktype = SOCK_STREAM; + hints.ai_family = PF_UNSPEC; + + err = getaddrinfo(remotehost, remoteport, &hints, &c->res); + if (err) { + int len; + len = 100 + strlen(gai_strerror(err)); + c->errstring = (char*)m_malloc(len); + snprintf(c->errstring, len, "Error resolving '%s' port '%s'. %s", + remotehost, remoteport, gai_strerror(err)); + TRACE(("Error resolving: %s", gai_strerror(err))) + return NULL; + } + + c->res_iter = c->res; + + return c; +} + + +void set_connect_fds(fd_set *writefd) { + m_list_elem *iter; + TRACE(("enter handle_connect_fds")) + for (iter = ses.conn_pending.first; iter; iter = iter->next) { + struct dropbear_progress_connection *c = iter->item; + /* Set one going */ + while (c->res_iter && c->sock < 0) + { + connect_try_next(c); + } + if (c->sock >= 0) { + FD_SET(c->sock, writefd); + } else { + m_list_elem *remove_iter; + /* Final failure */ + if (!c->errstring) { + c->errstring = m_strdup("unexpected failure"); + } + c->cb(DROPBEAR_FAILURE, -1, c->cb_data, c->errstring); + /* Safely remove without invalidating iter */ + remove_iter = iter; + iter = iter->prev; + remove_connect(c, remove_iter); + } + } +} + +void handle_connect_fds(fd_set *writefd) { + m_list_elem *iter; + TRACE(("enter handle_connect_fds")) + for (iter = ses.conn_pending.first; iter; iter = iter->next) { + int val; + socklen_t vallen = sizeof(val); + struct dropbear_progress_connection *c = iter->item; + + if (!FD_ISSET(c->sock, writefd)) { + continue; + } + + TRACE(("handling %s port %s socket %d", c->remotehost, c->remoteport, c->sock)); + + if (getsockopt(c->sock, SOL_SOCKET, SO_ERROR, &val, &vallen) != 0) { + TRACE(("handle_connect_fds getsockopt(%d) SO_ERROR failed: %s", c->sock, strerror(errno))) + /* This isn't expected to happen - Unix has surprises though, continue gracefully. */ + m_close(c->sock); + c->sock = -1; + } else if (val != 0) { + /* Connect failed */ + TRACE(("connect to %s port %s failed.", c->remotehost, c->remoteport)) + m_close(c->sock); + c->sock = -1; + + m_free(c->errstring); + c->errstring = strerror(val); + } else { + /* New connection has been established */ + c->cb(DROPBEAR_SUCCESS, c->sock, c->cb_data, NULL); + remove_connect(c, iter); + TRACE(("leave handle_connect_fds - success")) + /* Must return here - remove_connect() invalidates iter */ + return; + } + } + TRACE(("leave handle_connect_fds - end iter")) +} + +void connect_set_writequeue(struct dropbear_progress_connection *c, struct Queue *writequeue) { + c->writequeue = writequeue; +} + +struct iovec * packet_queue_to_iovec(struct Queue *queue, int *ret_iov_count) { + struct iovec *iov = NULL; + struct Link *l; + unsigned int i, packet_type; + int len; + buffer *writebuf; + + #ifndef IOV_MAX + #define IOV_MAX UIO_MAXIOV + #endif + + *ret_iov_count = MIN(queue->count, IOV_MAX); + + iov = m_malloc(sizeof(*iov) * *ret_iov_count); + for (l = queue->head, i = 0; l; l = l->link, i++) + { + writebuf = (buffer*)l->item; + packet_type = writebuf->data[writebuf->len-1]; + len = writebuf->len - 1 - writebuf->pos; + dropbear_assert(len > 0); + TRACE2(("write_packet writev #%d type %d len %d/%d", i, packet_type, + len, writebuf->len-1)) + iov[i].iov_base = buf_getptr(writebuf, len); + iov[i].iov_len = len; + } + + return iov; +} + +void packet_queue_consume(struct Queue *queue, ssize_t written) { + buffer *writebuf; + int len; + while (written > 0) { + writebuf = (buffer*)examine(queue); + len = writebuf->len - 1 - writebuf->pos; + if (len > written) { + /* partial buffer write */ + buf_incrpos(writebuf, written); + written = 0; + } else { + written -= len; + dequeue(queue); + buf_free(writebuf); + } + } +} + +void set_sock_nodelay(int sock) { + int val; + + /* disable nagle */ + val = 1; + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, (void*)&val, sizeof(val)); +} + +#ifdef DROPBEAR_TCP_FAST_OPEN +void set_listen_fast_open(int sock) { + int qlen = MAX(MAX_UNAUTH_PER_IP, 5); + if (setsockopt(sock, SOL_TCP, TCP_FASTOPEN, &qlen, sizeof(qlen)) != 0) { + TRACE(("set_listen_fast_open failed for socket %d: %s", sock, strerror(errno))) + } +} + +#endif + +void set_sock_priority(int sock, enum dropbear_prio prio) { + + int iptos_val = 0, so_prio_val = 0, rc; + + /* Don't log ENOTSOCK errors so that this can harmlessly be called + * on a client '-J' proxy pipe */ + + /* set the TOS bit for either ipv4 or ipv6 */ +#ifdef IPTOS_LOWDELAY + if (prio == DROPBEAR_PRIO_LOWDELAY) { + iptos_val = IPTOS_LOWDELAY; + } else if (prio == DROPBEAR_PRIO_BULK) { + iptos_val = IPTOS_THROUGHPUT; + } +#if defined(IPPROTO_IPV6) && defined(IPV6_TCLASS) + rc = setsockopt(sock, IPPROTO_IPV6, IPV6_TCLASS, (void*)&iptos_val, sizeof(iptos_val)); + if (rc < 0 && errno != ENOTSOCK) { + TRACE(("Couldn't set IPV6_TCLASS (%s)", strerror(errno))); + } +#endif + rc = setsockopt(sock, IPPROTO_IP, IP_TOS, (void*)&iptos_val, sizeof(iptos_val)); + if (rc < 0 && errno != ENOTSOCK) { + TRACE(("Couldn't set IP_TOS (%s)", strerror(errno))); + } +#endif + +#ifdef SO_PRIORITY + if (prio == DROPBEAR_PRIO_LOWDELAY) { + so_prio_val = TC_PRIO_INTERACTIVE; + } else if (prio == DROPBEAR_PRIO_BULK) { + so_prio_val = TC_PRIO_BULK; + } + /* linux specific, sets QoS class. see tc-prio(8) */ + rc = setsockopt(sock, SOL_SOCKET, SO_PRIORITY, (void*) &so_prio_val, sizeof(so_prio_val)); + if (rc < 0 && errno != ENOTSOCK) + dropbear_log(LOG_WARNING, "Couldn't set SO_PRIORITY (%s)", + strerror(errno)); +#endif + +} + +/* Listen on address:port. + * Special cases are address of "" listening on everything, + * and address of NULL listening on localhost only. + * Returns the number of sockets bound on success, or -1 on failure. On + * failure, if errstring wasn't NULL, it'll be a newly malloced error + * string.*/ +int dropbear_listen(const char* address, const char* port, + int *socks, unsigned int sockcount, char **errstring, int *maxfd) { + + struct addrinfo hints, *res = NULL, *res0 = NULL; + int err; + unsigned int nsock; + struct linger linger; + int val; + int sock; + + TRACE(("enter dropbear_listen")) + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; /* TODO: let them flag v4 only etc */ + hints.ai_socktype = SOCK_STREAM; + + /* for calling getaddrinfo: + address == NULL and !AI_PASSIVE: local loopback + address == NULL and AI_PASSIVE: all interfaces + address != NULL: whatever the address says */ + if (!address) { + TRACE(("dropbear_listen: local loopback")) + } else { + if (address[0] == '\0') { + TRACE(("dropbear_listen: all interfaces")) + address = NULL; + } + hints.ai_flags = AI_PASSIVE; + } + err = getaddrinfo(address, port, &hints, &res0); + + if (err) { + if (errstring != NULL && *errstring == NULL) { + int len; + len = 20 + strlen(gai_strerror(err)); + *errstring = (char*)m_malloc(len); + snprintf(*errstring, len, "Error resolving: %s", gai_strerror(err)); + } + if (res0) { + freeaddrinfo(res0); + res0 = NULL; + } + TRACE(("leave dropbear_listen: failed resolving")) + return -1; + } + + + nsock = 0; + for (res = res0; res != NULL && nsock < sockcount; + res = res->ai_next) { + + /* Get a socket */ + socks[nsock] = socket(res->ai_family, res->ai_socktype, + res->ai_protocol); + + sock = socks[nsock]; /* For clarity */ + + if (sock < 0) { + err = errno; + TRACE(("socket() failed")) + continue; + } + + /* Various useful socket options */ + val = 1; + /* set to reuse, quick timeout */ + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (void*) &val, sizeof(val)); + linger.l_onoff = 1; + linger.l_linger = 5; + setsockopt(sock, SOL_SOCKET, SO_LINGER, (void*)&linger, sizeof(linger)); + +#if defined(IPPROTO_IPV6) && defined(IPV6_V6ONLY) + if (res->ai_family == AF_INET6) { + int on = 1; + if (setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + &on, sizeof(on)) == -1) { + dropbear_log(LOG_WARNING, "Couldn't set IPV6_V6ONLY"); + } + } +#endif + + set_sock_nodelay(sock); + + if (bind(sock, res->ai_addr, res->ai_addrlen) < 0) { + err = errno; + close(sock); + TRACE(("bind(%s) failed", port)) + continue; + } + + if (listen(sock, DROPBEAR_LISTEN_BACKLOG) < 0) { + err = errno; + close(sock); + TRACE(("listen() failed")) + continue; + } + + *maxfd = MAX(*maxfd, sock); + + nsock++; + } + + if (res0) { + freeaddrinfo(res0); + res0 = NULL; + } + + if (nsock == 0) { + if (errstring != NULL && *errstring == NULL) { + int len; + len = 20 + strlen(strerror(err)); + *errstring = (char*)m_malloc(len); + snprintf(*errstring, len, "Error listening: %s", strerror(err)); + } + TRACE(("leave dropbear_listen: failure, %s", strerror(err))) + return -1; + } + + TRACE(("leave dropbear_listen: success, %d socks bound", nsock)) + return nsock; +} + +void get_socket_address(int fd, char **local_host, char **local_port, + char **remote_host, char **remote_port, int host_lookup) +{ + struct sockaddr_storage addr; + socklen_t addrlen; + + if (local_host || local_port) { + addrlen = sizeof(addr); + if (getsockname(fd, (struct sockaddr*)&addr, &addrlen) < 0) { + dropbear_exit("Failed socket address: %s", strerror(errno)); + } + getaddrstring(&addr, local_host, local_port, host_lookup); + } + if (remote_host || remote_port) { + addrlen = sizeof(addr); + if (getpeername(fd, (struct sockaddr*)&addr, &addrlen) < 0) { + dropbear_exit("Failed socket address: %s", strerror(errno)); + } + getaddrstring(&addr, remote_host, remote_port, host_lookup); + } +} + +/* Return a string representation of the socket address passed. The return + * value is allocated with malloc() */ +void getaddrstring(struct sockaddr_storage* addr, + char **ret_host, char **ret_port, + int host_lookup) { + + char host[NI_MAXHOST+1], serv[NI_MAXSERV+1]; + unsigned int len; + int ret; + + int flags = NI_NUMERICSERV | NI_NUMERICHOST; + +#ifndef DO_HOST_LOOKUP + host_lookup = 0; +#endif + + if (host_lookup) { + flags = NI_NUMERICSERV; + } + + len = sizeof(struct sockaddr_storage); + /* Some platforms such as Solaris 8 require that len is the length + * of the specific structure. Some older linux systems (glibc 2.1.3 + * such as debian potato) have sockaddr_storage.__ss_family instead + * but we'll ignore them */ +#ifdef HAVE_STRUCT_SOCKADDR_STORAGE_SS_FAMILY + if (addr->ss_family == AF_INET) { + len = sizeof(struct sockaddr_in); + } +#ifdef AF_INET6 + if (addr->ss_family == AF_INET6) { + len = sizeof(struct sockaddr_in6); + } +#endif +#endif + + ret = getnameinfo((struct sockaddr*)addr, len, host, sizeof(host)-1, + serv, sizeof(serv)-1, flags); + + if (ret != 0) { + if (host_lookup) { + /* On some systems (Darwin does it) we get EINTR from getnameinfo + * somehow. Eew. So we'll just return the IP, since that doesn't seem + * to exhibit that behaviour. */ + getaddrstring(addr, ret_host, ret_port, 0); + return; + } else { + /* if we can't do a numeric lookup, something's gone terribly wrong */ + dropbear_exit("Failed lookup: %s", gai_strerror(ret)); + } + } + + if (ret_host) { + *ret_host = m_strdup(host); + } + if (ret_port) { + *ret_port = m_strdup(serv); + } +} +