/*
 * Copyright (C) 2011  Internet Systems Consortium, Inc. ("ISC")
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
 * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
 * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
 * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
 * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 * PERFORMANCE OF THIS SOFTWARE.
 */

/*
 * Initial timeout between connection attempts.  The smaller this is, the
 * more embryonic connection attempts will be made.  On each subsequent
 * connection attempt the timeout will be halved, leading to all connection
 * attempts being initiated within 2 * TIMEOUT ms.
 *
 * 100 ms will let most intra-continental connections succeed without an
 *	  embryonic connection.
 * 500 ms will let most intercontinental connections succeed without an
 *	  embryonic connection.
 */
#define TIMEOUT 500	/* 500 ms */

#define TESTING 1

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>

#include <netinet/in.h>

#include <assert.h>
#include <errno.h>
#include <netdb.h>
#include <pthread.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

/*
 * State shared between connect_to_host() and every connection thread it
 * starts.  All members point at objects owned by connect_to_host().
 */
struct common {
	pthread_mutex_t *mutex;	/* protects *count and *fd */
	pthread_cond_t *cond;	/* signalled by each thread as it finishes */
	int *count;		/* started threads that have not yet reported */
	int *fd;		/* first connected socket, or -1 if none yet */
};

/*
 * Per-thread argument passed to connect_to_address().
 */
struct state {
	struct addrinfo *addrinfo;	/* the address this thread connects to */
	struct common *common;		/* state shared by all the threads */
};

/*
 * Report an unrecoverable error and terminate the process.
 *
 * 'format' and the following arguments are interpreted as for printf();
 * the resulting message is written to stderr and abort() is then called,
 * so this function never returns.  No newline is appended: callers supply
 * one in 'format' if they want it.
 */
static void
fatal(const char *format, ...) {
	va_list ap;

	va_start(ap, format);
	vfprintf(stderr, format, ap);
	va_end(ap);
	abort();
}

/*
 * Thread cleanup handler for connect_to_address().
 *
 * 'arg' points at the thread's socket descriptor.  A value of -1 means
 * there is nothing open; any other value is closed.  The pointed-to value
 * itself is left unchanged.
 */
static void
connect_to_address_cleanup(void *arg) {
	const int *descriptor = arg;

	if (*descriptor == -1)
		return;
	close(*descriptor);
}

/*
 * Thread entry point: make one blocking connection attempt.
 *
 * 'arg' is a 'struct state' naming the address to try and the state
 * shared with connect_to_host().  The thread may be cancelled while it is
 * creating or connecting the socket; the cleanup handler then closes the
 * descriptor.  Once the attempt has finished (either way) cancellation is
 * disabled so that the result is always reported:
 *
 *  - if the connection succeeded and no other thread has won yet, the
 *    descriptor is handed over through *common->fd;
 *  - otherwise the descriptor (if any) is closed;
 *  - *common->count is decremented and common->cond is signalled.
 *
 * The thread always terminates via pthread_exit(NULL).
 */
static void *
connect_to_address(void *arg) {
	struct state *state = arg;
	struct addrinfo *addrinfo = state->addrinfo;
	struct common *common = state->common;
	int fd = -1, n;
	bool connected = false;

	/* Ensure that fd is closed if we are canceled. */
	pthread_cleanup_push(connect_to_address_cleanup, &fd);
	fd = socket(addrinfo->ai_family, addrinfo->ai_socktype,
		    addrinfo->ai_protocol);
	if (fd < 0) {
		/*
		 * If AI_ADDRCONFIG is not supported we will get EAFNOSUPPORT
		 * returned.  Silently ignore it.
		 */
		if (errno != EAFNOSUPPORT)
			perror("socket");
	} else if (connect(fd, addrinfo->ai_addr, addrinfo->ai_addrlen) == -1) {
		perror("connect");
	} else {
		connected = true;
	}

	/* If we get here we want the rest of the thread to complete. */
	n = pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, NULL);
	if (n != 0)
		fatal("pthread_setcancelstate: %s", strerror(n));

	/*
	 * Discard a socket that failed to connect only now that cancellation
	 * is disabled.  close() is a cancellation point: being cancelled
	 * inside it while fd still named the descriptor would make the
	 * cleanup handler close the same descriptor a second time.
	 */
	if (fd != -1 && !connected) {
		close(fd);
		fd = -1;
	}

	n = pthread_mutex_lock(common->mutex);
	if (n != 0)
		fatal("pthread_mutex_lock: %s", strerror(n));
	if (fd != -1 && *common->fd == -1) {
		/* Record success; fd now belongs to connect_to_host(). */
		*common->fd = fd;
		fd = -1;
	}
	*common->count -= 1;
	n = pthread_mutex_unlock(common->mutex);
	if (n != 0)
		fatal("pthread_mutex_unlock: %s", strerror(n));

	/*
	 * Run the cleanup handler: this closes a connected socket that lost
	 * the race to another thread (fd is -1 in every other case).
	 */
	pthread_cleanup_pop(1);

	/* Signal that we are done. */
	n = pthread_cond_signal(common->cond);
	if (n != 0)
		fatal("pthread_cond_signal: %s", strerror(n));
	pthread_exit(NULL);
}

/*
 * Connect to the first address on the 'res0' list that accepts a
 * connection, starting the attempts in list order.
 *
 * One thread running connect_to_address() is started per address.  After
 * starting a thread for any address but the last we wait at most 'timeout'
 * microseconds (initially TIMEOUT ms) for a thread to report before
 * starting the next one; every wait that times out halves 'timeout'.
 * After the last address has been started we wait until a connection has
 * succeeded or every started thread has reported.  All started threads are
 * then cancelled and joined.
 *
 * Returns the connected socket descriptor, which the caller is expected
 * to close(), or -1 if no connection was made (this includes an empty
 * list and failure to allocate or initialise the resources needed).
 * 'res0' is not freed.  Failures of the pthread primitives once the
 * threads are running are treated as fatal and abort the process.
 */
int
connect_to_host(struct addrinfo *res0) {
	struct addrinfo *res;
	int fd = -1, n, i, j, count;
	pthread_t *threads;
	struct state *state;
	int timeout = TIMEOUT * 1000;	/* microseconds */
	struct timespec timespec;
	pthread_cond_t cond;
	pthread_mutex_t mutex;
	struct common common;

	/*
	 * Work out how many possible descriptors we could use.
	 */
	for (res = res0, count = 0; res; res = res->ai_next)
		count++;

	/*
	 * Nothing to connect to.  Returning here also avoids calloc(0, ...),
	 * which may return NULL and would then be misreported as an
	 * allocation failure.
	 */
	if (count == 0)
		return (-1);

	threads = calloc(count, sizeof(*threads));
	state = calloc(count, sizeof(*state));
	if (threads == NULL || state == NULL) {
		perror("calloc");
		goto cleanup;
	}

	n = pthread_mutex_init(&mutex, NULL);
	if (n != 0) {
		fprintf(stderr, "pthread_mutex_init: %s\n", strerror(n));
		goto cleanup;
	}

	n = pthread_cond_init(&cond, NULL);
	if (n != 0) {
		fprintf(stderr, "pthread_cond_init: %s\n", strerror(n));
		goto cleanup_mutex;
	}

	common.fd = &fd;
	common.cond = &cond;
	common.count = &count;
	common.mutex = &mutex;

	/*
	 * fd and count are protected by mutex.  From here on 'count' is the
	 * number of started threads that have not yet reported, and 'i' is
	 * the number of threads successfully started.
	 */
	for (res = res0, i = 0, count = 0; res; res = res->ai_next) {
		bool done;

		state[i].common = &common;
		state[i].addrinfo = res;

		n = pthread_create(&threads[i], NULL, connect_to_address,
				   &state[i]);
		if (n != 0)
			/* Not fatal: carry on with the remaining addresses. */
			fprintf(stderr, "pthread_create: %s\n", strerror(n));
		else {
			i++;
			n = pthread_mutex_lock(&mutex);
			if (n != 0)
				fatal("pthread_mutex_lock: %s", strerror(n));
			count++;
			n = pthread_mutex_unlock(&mutex);
			if (n != 0)
				fatal("pthread_mutex_unlock: %s", strerror(n));
		}

		/*
		 * Wait for a result.  For every address but the last this
		 * body runs once (a bounded wait); for the last address it
		 * repeats until we are done.
		 */
		done = false;
		do {
			n = pthread_mutex_lock(&mutex);
			if (n != 0)
				fatal("pthread_mutex_lock: %s", strerror(n));
			/* No outstanding threads? */
			if (count == 0) {
				/* Are we done? */
				if (fd != -1 || res->ai_next == NULL)
					done = true;
				n = pthread_mutex_unlock(&mutex);
				if (n != 0)
					fatal("pthread_mutex_unlock: %s",
					      strerror(n));
				break;
			}
			if (res->ai_next != NULL) {
				struct timeval tv;

				/*
				 * More addresses remain: wait until now +
				 * timeout, expressed as an absolute time.
				 */
				n = gettimeofday(&tv, NULL);
				if (n != 0)
					fatal("gettimeofday: %s\n",
					      strerror(errno));
				timespec.tv_sec = tv.tv_sec;
				timespec.tv_nsec = (tv.tv_usec + timeout) *
						   1000;
				while (timespec.tv_nsec >= 1000000000) {
					timespec.tv_nsec -= 1000000000;
					timespec.tv_sec += 1;
				}
				n = pthread_cond_timedwait(&cond, &mutex,
							   &timespec);
			} else
				/* Last address: nothing left to start. */
				n = pthread_cond_wait(&cond, &mutex);

			if (n == ETIMEDOUT)
				timeout >>= 1;
			else if (n != 0)
				fatal("pthread_cond_%swait: %s\n",
				      res->ai_next != NULL ? "timed" : "",
				      strerror(n));
			if (fd != -1 || (count == 0 && res->ai_next == NULL))
				done = true;
			n = pthread_mutex_unlock(&mutex);
			if (n != 0)
				fatal("pthread_mutex_unlock: %s", strerror(n));
		} while (!done && res->ai_next == NULL);
		if (done)
			break;
	}

	/*
	 * Shutdown and tidy up all the threads we started.  Threads that
	 * have already finished are still joinable, so both calls are valid
	 * for them.
	 */
	for (j = 0; j < i; j++) {
		n = pthread_cancel(threads[j]);
		if (n != 0 && n != ESRCH)
			fatal("pthread_cancel: %s\n", strerror(n));
		n = pthread_join(threads[j], NULL);
		if (n != 0)
			fatal("pthread_join: %s\n", strerror(n));
	}

	/* Cleanup the resources we used. */
	n = pthread_cond_destroy(&cond);
	if (n != 0)
		fatal("pthread_cond_destroy: %s", strerror(n));

 cleanup_mutex:
	n = pthread_mutex_destroy(&mutex);
	if (n != 0)
		fatal("pthread_mutex_destroy: %s", strerror(n));

 cleanup:
	/* Free everything; free(NULL) is a no-op. */
	free(threads);
	free(state);

	return (fd);
}

#if TESTING
/*
 * Test driver: resolve the host named by the first command line argument
 * (default "localhost") for the "http" service, call connect_to_host() on
 * the result, and report the returned descriptor and the elapsed time in
 * milliseconds on stderr.
 *
 * Exits with status 1 if the name cannot be resolved and 0 otherwise.
 */
int
main(int argc, char **argv) {
	int fd, n;
	long long elapsed_us;
	struct timeval then, now;
	struct addrinfo hints, *res0;
	const char *hostname, *servname;

	hostname  = "localhost";
	/* Check argc: argv[1] does not exist when argc is 0. */
	if (argc > 1)
		hostname = argv[1];
	servname = "http";

	/*
	 * Not all getaddrinfo() implementations support AI_ADDRCONFIG
	 * even if it is defined.  Retry without it on EAI_BADFLAGS.
	 */
	memset(&hints, 0, sizeof(hints));
	hints.ai_family = PF_UNSPEC;
	hints.ai_socktype = SOCK_STREAM;
	hints.ai_protocol = IPPROTO_TCP;
#ifdef AI_ADDRCONFIG
	hints.ai_flags = AI_ADDRCONFIG;
#endif

#ifdef AI_ADDRCONFIG
 again:
#endif
	n = getaddrinfo(hostname, servname, &hints, &res0);
	if (n != 0) {
#ifdef AI_ADDRCONFIG
		if (n == EAI_BADFLAGS && (hints.ai_flags & AI_ADDRCONFIG)) {
			hints.ai_flags &= ~AI_ADDRCONFIG;
			goto again;
		}
#endif
		fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(n));
		exit(1);
	}

	gettimeofday(&then, NULL);
	fd = connect_to_host(res0);
	gettimeofday(&now, NULL);
	freeaddrinfo(res0);

	/* Elapsed wall-clock time in microseconds. */
	elapsed_us = (long long)(now.tv_sec - then.tv_sec) * 1000000 +
		     (long long)(now.tv_usec - then.tv_usec);
	fprintf(stderr, "connect_to_host(%s) -> %d in %lld ms\n", hostname, fd,
		elapsed_us / 1000);

	/* Only close a descriptor we actually got. */
	if (fd != -1)
		close(fd);
	exit(0);
}
#endif
