/*
 * Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
 * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
 * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
 * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
 * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 * PERFORMANCE OF THIS SOFTWARE.
 */

/*
 * Initial timeout between connection attempts.  The smaller this is, the
 * more embryonic (still in-progress) connection attempts will be made.  On
 * each subsequent connection attempt the timeout is halved, so all
 * connection attempts are initiated within 2 * TIMEOUT ms.
 *
 * 100 ms will let most intra-continental connections succeed without an
 *	  additional embryonic connection.
 * 500 ms will let most intercontinental connections succeed without an
 *	  additional embryonic connection.
 */
/* Initial wait in ms; may be overridden at build time, e.g. -DTIMEOUT=100. */
#ifndef TIMEOUT
#define TIMEOUT 500	/* 500 ms */
#endif

/*
 * When nonzero, build the standalone main() test driver at the end of this
 * file.  Build with -DTESTING=0 to use connect_to_host() as a library routine.
 */
#ifndef TESTING
#define TESTING 1
#endif

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/poll.h>
#include <sys/time.h>

#include <netinet/in.h>

#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/*
 * Connect to the first reachable address in the list 'res0' (as returned
 * by getaddrinfo()).
 *
 * A non-blocking connect() is started for each address in turn.  After each
 * attempt, all outstanding attempts are polled for up to 'timeout' ms.  The
 * timeout starts at TIMEOUT and is halved every time it expires.  For the
 * last address the wait is unbounded; it continues until one attempt
 * succeeds or every attempt has failed.
 *
 * An attempt only counts as successful when poll() reports activity without
 * POLLHUP AND the socket's pending error (SO_ERROR) is 0.  Checking SO_ERROR
 * is necessary because a failed non-blocking connect() may be signalled
 * with POLLERR/POLLOUT alone on some systems.
 *
 * Returns the connected descriptor, switched back to blocking mode, or -1
 * if no connection could be made.  Every other descriptor created here is
 * closed.  Diagnostics are written to stderr with perror().
 */
int
connect_to_host(struct addrinfo *res0) {
	struct addrinfo *res;
	struct pollfd *fds;
	int fd = -1, n, i, j, flags, count, soerr;
	int timeout = TIMEOUT;
	socklen_t len;

	/*
	 * Work out how many possible descriptors we could use.
	 */
	for (res = res0, count = 0; res != NULL; res = res->ai_next)
		count++;
	if (count == 0)
		return (-1);	/* Nothing to try; also avoids calloc(0, ...). */
	fds = calloc(count, sizeof(*fds));
	if (fds == NULL) {
		perror("calloc");
		return (-1);
	}

	/* From here on 'count' is the number of attempts still in progress. */
	for (res = res0, i = 0, count = 0; res != NULL; res = res->ai_next) {
		fd = socket(res->ai_family, res->ai_socktype, res->ai_protocol);
		if (fd == -1) {
			/*
			 * If AI_ADDRCONFIG is not supported we will get
			 * EAFNOSUPPORT returned.  Behave as if the address
			 * was not there, unless it is the last one, in which
			 * case we still wait for the outstanding attempts.
			 */
			if (errno != EAFNOSUPPORT)
				perror("socket");
			else if (res->ai_next != NULL)
				continue;
		} else if ((flags = fcntl(fd, F_GETFL)) == -1) {
			perror("fcntl");
			close(fd);
		} else if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
			perror("fcntl");
			close(fd);
		} else if (connect(fd, res->ai_addr, res->ai_addrlen) == -1) {
			if (errno != EINPROGRESS) {
				perror("connect");
				close(fd);
			} else {
				/*
				 * Record the information for this descriptor.
				 */
				fds[i].fd = fd;
				fds[i].events = POLLERR | POLLHUP |
						POLLIN | POLLOUT;
				count++;
				i++;
			}
		} else {
			/*
			 * We connected without blocking.
			 */
			goto done;
		}

		if (count == 0)
			continue;

		/* For the last address, wait until there is an outcome. */
		if (res->ai_next == NULL)
			timeout = -1;

		for (;;) {
			n = poll(fds, i, timeout);
			if (n < 0) {
				/*
				 * Retry an interrupted wait.  (A 'continue'
				 * inside a do/while would jump to the loop
				 * condition and abandon a finite wait early.)
				 */
				if (errno == EAGAIN || errno == EINTR)
					continue;
				perror("poll");
				fd = -1;
				goto done;
			}
			if (n == 0) {
				/* Timed out: wait less after the next attempt. */
				timeout >>= 1;
				break;
			}
			for (j = 0; j < i; j++) {
				if (fds[j].fd == -1 || fds[j].revents == 0)
					continue;
				fd = fds[j].fd;
				/*
				 * Readiness alone does not mean success; the
				 * connect() result is reported via SO_ERROR.
				 */
				soerr = 0;
				len = sizeof(soerr);
				if (getsockopt(fd, SOL_SOCKET, SO_ERROR,
				    &soerr, &len) == -1)
					soerr = errno;
				if ((fds[j].revents & POLLHUP) == 0 &&
				    soerr == 0)
					goto done;	/* Connect succeeded. */
				/* This attempt failed; stop watching it. */
				close(fd);
				fds[j].fd = -1;
				fds[j].events = 0;
				count--;
			}
			/* Only the final, unbounded wait loops back. */
			if (timeout != -1 || count == 0)
				break;
		}
	}

	/* We failed to connect. */
	fd = -1;

 done:
	/* Close all other descriptors we have created. */
	for (j = 0; j < i; j++)
		if (fds[j].fd != fd && fds[j].fd != -1)
			close(fds[j].fd);

	if (fd != -1) {
		/* Restore default blocking behaviour. */
		if ((flags = fcntl(fd, F_GETFL)) == -1)
			perror("fcntl");
		else if (fcntl(fd, F_SETFL, flags & ~O_NONBLOCK) == -1)
			perror("fcntl");
	}

	free(fds);
	return (fd);
}

#if TESTING
/*
 * Test driver.  Resolves host argv[1] (default "localhost") and service
 * argv[2] (default "http") for TCP, then calls connect_to_host().  Reports
 * the resulting descriptor and the elapsed time on stderr.
 *
 * Exit status: 0 if a connection was made, 1 on resolution or connection
 * failure.
 */
int
main(int argc, char **argv) {
	int fd, n;
	long ms;
	struct timeval then, now;
	struct addrinfo hints, *res0;
	const char *hostname, *servname;

	/* Test argc, not argv[1]: argv[1] does not exist when argc == 0. */
	hostname = (argc > 1) ? argv[1] : "localhost";
	servname = (argc > 2) ? argv[2] : "http";

	/*
	 * Not all getaddrinfo() implementations support AI_ADDRCONFIG
	 * even if it is defined.  Retry without it on EAI_BADFLAGS.
	 */
	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;
#ifdef AI_ADDRCONFIG
	hints.ai_flags = AI_ADDRCONFIG;
#endif

#ifdef AI_ADDRCONFIG
 again:
#endif
	n = getaddrinfo(hostname, servname, &hints, &res0);
	if (n != 0) {
#ifdef AI_ADDRCONFIG
		if (n == EAI_BADFLAGS &&
		    (hints.ai_flags & AI_ADDRCONFIG) != 0) {
			hints.ai_flags &= ~AI_ADDRCONFIG;
			goto again;
		}
#endif
		fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(n));
		exit(1);
	}

	gettimeofday(&then, NULL);
	fd = connect_to_host(res0);
	gettimeofday(&now, NULL);
	freeaddrinfo(res0);

	/* Elapsed wall-clock time in milliseconds. */
	ms = (long)(now.tv_sec - then.tv_sec) * 1000L +
	    (long)(now.tv_usec - then.tv_usec) / 1000L;
	fprintf(stderr, "connect_to_host(%s) -> %d in %ld ms\n", hostname, fd,
		ms);

	if (fd == -1)
		exit(1);	/* Nothing to close; report failure. */
	close(fd);
	exit(0);
}
#endif
