/*
 * pipe it, pipe out !
 * 
 * transforms a new connection into a socket: relays one connection between
 * a unix-domain socket and a TCP/UDP network socket
 *
 * ./pipe -u -s /tmp/my.socket -h <host> <port>
 *
 * -h : hostname.
 * -s : unix socket (default: /tmp/socket)
 * -u : udp socket
 * -r : reverse mode: listen on a network port and relay it to the unix socket
 *
 */
/*
 * ./pipe -s /tmp/paf -h localhost 22
 * ** connects to localhost:22 when a client connects to /tmp/paf
 *
 * ./pipe -r -s /tmp/paf 5050
 * ** connects to /tmp/paf when someone connects to port 5050/tcp
 *    (the listener binds all local addresses, not only localhost)
 */
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/un.h>

#include <netdb.h>

/* Uncomment to compile in the tracing printf()s guarded by #ifdef DEBUG. */
//#define DEBUG
/* Unix socket path used when -s is not given. */
#define DEFAULT_SOCKET 	"/tmp/socket"
/* Remote host used when -h is not given. */
#define DEFAULT_HOST	"localhost"

/*
 * Print command-line help to stdout.
 *
 * name: program name shown in the synopsis (normally argv[0]).
 * Always returns 0.
 */
int
Usage(char *name)
{
	printf("pipe it, pipe out !\n");
	printf("%s [-r] [-u] [-s /tmp/my.socket] [-h <host>] <port>\n", name);
	printf("-h : remote host (default: localhost)\n");
	printf("-s : unix socket (default: /tmp/socket)\n");
	printf("-u : udp socket (default: off, using tcp socket)\n");
	printf("-r : reverse mode: accept one connection on local <port> "
	       "and relay it to the unix socket\n");
	return 0;
}

/*
 * Resolve hostname to an IPv4 address (network byte order).
 *
 * A dotted-quad literal is parsed directly with inet_pton(); any other
 * string is looked up with gethostbyname(). If no IPv4 address can be
 * obtained, an error is printed to stderr and the program exits with
 * status -1.
 */
struct in_addr
resolv(char *hostname)
{
	struct in_addr in;
	struct hostent *hp;

	/*
	 * inet_pton() reports success separately from the address.
	 * inet_addr() returns the same value for a parse error and for the
	 * valid address "255.255.255.255".
	 */
	if (inet_pton(AF_INET, hostname, &in) == 1)
		return in;

	hp = gethostbyname(hostname);
	/* Only accept an IPv4 result whose length fits exactly in s_addr. */
	if (hp == NULL || hp->h_addrtype != AF_INET ||
	    hp->h_length != (int)sizeof(in.s_addr) ||
	    hp->h_addr_list[0] == NULL) {
		fprintf(stderr, "Can't resolv hostname ?!\n");
		exit(-1);
	}
	memcpy(&in.s_addr, hp->h_addr_list[0], sizeof(in.s_addr));

	return in;
}


int
main(int argc, char **argv)
{
	char *nsocket = NULL;
	char *host = NULL;
#define SIZEBUF	2048
	char buffer[SIZEBUF];
	unsigned short port;
#define TCP_MODE 1
#define UDP_MODE 2
	int mode = TCP_MODE;
	int smode = 0;
	int ret, slen;

	int fds = -1, fdn = -1;
	int rfds = -1, rfdn = -1;
	struct sockaddr_in sai;
	struct sockaddr_un sua; 

	int ch;
	char *p;

	fd_set ens, save_ens;

	if(argc <= 2) {
		Usage(*argv);
		exit(-1);
	}

	while ((ch = getopt (argc, argv, "h:rs:u")) != -1) {
		p = argv[optind - 1];
		switch(ch) {
		case 'h':
			host = strdup(p);
			break;
		case 'r':
			smode = 1;
			break;
		case 's':
			nsocket = strdup(p);
			break;
		case 'u':
			mode = UDP_MODE;
			break;
		}
	}

	if (nsocket == NULL)
		nsocket = strdup(DEFAULT_SOCKET);

	if (host == NULL)
		host = strdup(DEFAULT_HOST);

	if ((argc - optind) != 1) {
		Usage(*argv);
		exit(-1);
	}

	port = strtol(argv[argc - 1], NULL, 10);

	if (smode == 1)
		printf("%s -> %s:%d\n", nsocket, host, port);
	else
		printf("localhost:%d -> %s\n", port, nsocket);

	if ((fds = socket(PF_LOCAL, SOCK_STREAM, 0)) == -1) {
		perror("unix socket");
		goto fin;
	}

	sua.sun_family = PF_LOCAL;
	strcpy(sua.sun_path, nsocket);
	if (smode == 0) {
		unlink(nsocket);
		ret = bind(fds, (struct sockaddr *)&sua, sizeof(sua));
		if (ret == -1) {
			perror("bind/unix socket");
			goto fin;
		}	
		ret = listen(fds, 128);
		if (ret == -1) {
			perror("listen/unix socket");
			goto fin;
		}
		slen = sizeof(sua);
		rfds = accept(fds, (struct sockaddr *)&sua, &slen);
		if (rfds == -1) {
			perror("accept/unix socket");
			goto fin;
		}
#ifdef DEBUG
		else {
			printf("unix socket connection !!\n");
		}
#endif
		/* ok, i go connecting to network ! */
	}

	sai.sin_family = PF_INET;
	sai.sin_addr = resolv(host);
	sai.sin_port = htons(port);

	if (mode == TCP_MODE)
		fdn = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
	else
		fdn = socket(PF_INET, SOCK_DGRAM, IPPROTO_UDP);

	if (fdn == -1) {
		perror("net socket");
		goto fin;
	}

	if (smode == 1) {
		sai.sin_addr.s_addr = INADDR_ANY;
		/* bind tcp socket */
		ret = bind(fdn, (struct sockaddr*)&sai, 
		           sizeof(struct sockaddr));
		if (ret == -1) {
			perror("bind/net socket");
			goto fin;
		}
		ret = listen(fdn, 128);
		if (ret == -1){
			perror("listen/net socket");
			goto fin;
		}

		slen = sizeof(struct sockaddr);
		rfdn = accept(fdn, (struct sockaddr*)&sai, &slen);
		if (rfdn == -1) {
			perror("accept/net socket");
			goto fin;
		}
#ifdef DEBUG
		else {
			printf("net connection !!!\n");
		}
#endif
		/* ok, only now i'll open socket connection. */

		ret = connect(fds, (struct sockaddr *)&sua, sizeof(sua));
		if (ret == -1) {
			perror("connect/unix socket");
			goto fin;
		}
	} else {
		ret = connect(fdn, (struct sockaddr*)&sai, 
		              sizeof (struct sockaddr));
		if (ret == -1) {
			perror("connect/net socket");
			goto fin;
		}
	}

#ifdef DEBUG
	printf("fdn: %d, rfdn: %d, fds: %d, rfds: %d\n", fdn, rfdn, fds, rfds);
#endif
	FD_ZERO(&ens);
	if (rfds != -1) {
		close(fds);
		fds = rfds;
	}
	FD_SET(fds, &ens);

	if (rfdn != -1) {
		close(fdn);
		fdn = rfdn;
	}
	FD_SET(fdn, &ens);

#ifdef DEBUG
	printf("fds: %d, fdn: %d (max: %d)\n", fds, fdn, 1+((fdn>fds)?fdn:fds));
#endif

	save_ens = ens;

	while((ret = select(1+((fdn>fds)?fdn:fds), &ens, NULL, NULL,0)) != -1){
#ifdef DEBUG
		printf("select returned %d\n", ret);
#endif

		memset(buffer, 0, sizeof(buffer));
		if(FD_ISSET(fds, &ens)){
			if ((ret = read(fds, buffer, sizeof(buffer))) <= 0) {
				/* hugh ! */
				perror("read unix socket");
				goto fin;
			}
#ifdef DEBUG
			else {
				printf("got from unix socket: %d bytes\n", ret);
			}
#endif
			if ((ret = write(fdn, buffer, ret)) == -1) {
				perror("write network socket");
				goto fin;
			}
#ifdef DEBUG
			else {
				printf("write to network: %d bytes\n", ret);
			}
#endif
		}

		if(FD_ISSET(fdn, &ens)){
			if ((ret = read(fdn, buffer, sizeof(buffer))) <= 0){ 
				perror("read network socket");
				goto fin;
			}
#ifdef DEBUG
			else {
				printf("got from network socket: %d bytes.\n", ret);
			}
#endif
			if ((ret = write(fds, buffer, ret)) == -1) {
				perror("write unix socket");
				goto fin;
			}
#ifdef DEBUG
			else {
				printf("write to unix: (%d) %d bytes\n", fds, ret);
			}
#endif
		}
		ens = save_ens;
	}
	
fin:
	if (fds != -1)
		close(fds);
	if (fdn != -1)
		close(fdn);

	free(nsocket);
	free(host);

	return 0;
}

