#include <sys/types.h>
#include <sys/time.h>
#include <sys/socket.h>
#include <sys/signal.h>
#include <sys/syslog.h>
#include <sys/ioctl.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <ctype.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <unistd.h>
#include "socks.h"

/*
**  Byte-order helpers for single-byte values.  A byte has no byte order,
**  so unless the platform or socks.h already provides them these are
**  identity macros.  NOTE(review): not referenced directly in this file;
**  presumably used by the Read8/Write8 macros in socks.h -- confirm.
**  The argument is parenthesized so the expansion binds like a function
**  call (e.g. ntohb(a + b) * 2).
*/
#ifndef ntohb
#define ntohb(x) (x)
#endif
#ifndef htonb
#define htonb(x) (x)
#endif

/* Case-insensitive string equality, used for config-file keywords. */
#define STREQ(a, b)	(strcasecmp(a, b) == 0)

#ifdef sun
/* Sun-only scratch struct in_addr for inet_ntoa() calls. */
struct in_addr ipua;
#endif /* #ifdef sun */

/*
**  Protocol version of the request being served.  main() sets it from the
**  incoming request and the reply builders copy it into their responses.
*/
int		Version = 0;

/*
**  SIGALRM handler installed by main() to abort a request whose set-up
**  (bind/connect) takes too long.  Terminates the process with status 1.
**
**  _exit() is used instead of exit(): exit() runs atexit handlers and
**  flushes stdio, which is not async-signal-safe and can deadlock if the
**  alarm fires while the process is inside stdio or syslog code.
*/
void
die(int sig)
{
	(void)sig;
	_exit(1);
}

/*
**  sockd entry point.
**
**  Obtains the client connection, reads the SOCKS request with GetDst(),
**  checks the client/destination pair with Validate() and dispatches on
**  the request version:
**    1 (if SUPPORT_VERSION_1): CONNECT -> DoConnect, BIND -> DoBind
**    2 (if SUPPORT_VERSION_2): CONNECT -> DoConnect, BIND -> DoNewBind
**    3: reads a NUL-terminated user name, logs the request, then
**       CONNECT -> DoConnect, BIND -> DoNewBind
**  Any error is logged via syslog and ends the process with status 1.
**
**  NOTE(review): when INETD is defined this runs its own listener on
**  GetSocksPort() and serves one connection; otherwise it uses the socket
**  inherited on descriptor 0 (inetd-style).  The macro name looks
**  inverted relative to that behaviour -- confirm against the build.
*/
int
main(int argc, char **argv)
{
	char			user[40], c;
	size_t			ulen = 0;
	int			in;
	struct sockaddr_in	from, dstsin;
	socklen_t		fromlen = sizeof(from);
	struct in_addr		dstaddr;
	char			srcbuf[20];
	Socks_t			dst;

	(void)argc;
	(void)argv;

	/* openlog() returns void; there is no status to check. */
	openlog("sockd", LOG_PID, LOG_DAEMON);

#ifdef INETD
	{
		struct sockaddr_in	sin;
		socklen_t		len = sizeof(sin);
		int			inp;

		if ((inp = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
			syslog(LOG_ERR, "socket %m");
			exit(1);
		}

		memset(&sin, 0, sizeof(sin));
		sin.sin_family = AF_INET;
		sin.sin_port = htons(GetSocksPort());
		sin.sin_addr.s_addr = htonl(INADDR_ANY);

		if (bind(inp, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
			syslog(LOG_ERR, "bind %m");
			exit(1);
		}

		if (listen(inp, 5) < 0) {
			syslog(LOG_ERR, "listen %m");
			exit(1);
		}

		if ((in = accept(inp, (struct sockaddr *)&sin, &len)) < 0) {
			syslog(LOG_ERR, "accept %m");
			exit(1);
		}
	}
#else
	in = dup(0);
#endif

	if (getpeername(in, (struct sockaddr *)&from, &fromlen) < 0) {
		syslog(LOG_ERR, "Unable to get remote address");
		exit(1);
	}

	GetDst(in, &dst);

	memset(&dstsin, 0, sizeof(dstsin));
	dstsin.sin_family      = AF_INET;
	dstsin.sin_addr.s_addr = dst.host;
	dstsin.sin_port        = dst.port;
	if (!Validate(&from, &dstsin)) {
		/* inet_ntoa() takes a struct in_addr on every POSIX system. */
		syslog(LOG_CRIT, "Invalid connection attempted from %s",
				inet_ntoa(from.sin_addr));
		exit(1);
	}

	/*
	**  Kill a connection off if bind or connect takes too
	**    long to complete.
	*/
	signal(SIGALRM, die);
	alarm(60*5);		/* 5 minutes */

	switch (Version = dst.version) {
#ifdef SUPPORT_VERSION_1
	case 1:
		if (dst.cmd == SOCKS_CONNECT)
			DoConnect(in, &dst);
		if (dst.cmd == SOCKS_BIND)
			DoBind(in, &dst);
		break;
#endif
#ifdef SUPPORT_VERSION_2
	case 2:
		if (dst.cmd == SOCKS_CONNECT)
			DoConnect(in, &dst);
		if (dst.cmd == SOCKS_BIND)
			DoNewBind(in, &dst);
		break;
#endif
	case 3:
		/*
		** The request is followed by a NUL-terminated user name; keep
		** at most sizeof(user) - 1 bytes of it for logging.
		*/
		while (read(in, &c, 1) == 1 && c != '\0') {
			if (ulen < sizeof(user) - 1)
				user[ulen++] = c;
		}
		user[ulen] = '\0';

		/*
		** inet_ntoa() returns a static buffer, so copy the first result
		** before calling it again in the same syslog().
		*/
		snprintf(srcbuf, sizeof(srcbuf), "%s", inet_ntoa(from.sin_addr));
		dstaddr.s_addr = dst.host;
		syslog(LOG_INFO, "%s from %s by %s to %s",
				dst.cmd == SOCKS_CONNECT ? "Connect" : "Bind",
				srcbuf, user, inet_ntoa(dstaddr));

		if (dst.cmd == SOCKS_CONNECT)
			DoConnect(in, &dst);
		if (dst.cmd == SOCKS_BIND)
			DoNewBind(in, &dst);
		break;
	default:
		syslog(LOG_ERR, "Version missmatch got %d\n", dst.version);
		exit(1);
	}
	return 0;
}

/*
**  Write a SOCKS header to descriptor s.  The version and command go out
**  via Write8(), then the port and host via Write32(), in that order; the
**  statement order defines the wire layout.  Write8/Write32 come from
**  socks.h.  NOTE(review): their encoding and error handling are not
**  visible here -- confirm whether short writes are detected.
**  Declared int but returns no value; callers in this file ignore it.
*/
int SendDst(s, dst)
int     s;
Socks_t *dst;
{
	Write8(s,  dst->version);
	Write8(s,  dst->cmd);
	Write32(s, dst->port);
	Write32(s, dst->host);
}

/*
**  Read a SOCKS request header from descriptor s into *dst, using the same
**  field order SendDst() writes.  The version and command are read via
**  Read8(), then the port and host via Read32().  The fields are passed
**  as lvalues, so the socks.h macros presumably store into them.
**  NOTE(review): confirm how Read8/Read32 handle EOF or short reads;
**  main() uses the fields without any check.
**  Declared int but returns no value; the caller ignores it.
*/
int GetDst(s, dst)
int     s;
Socks_t *dst;
{
	Read8(s,  dst->version);
	Read8(s,  dst->cmd);
	Read32(s, dst->port);
	Read32(s, dst->host);
}

/*
**  Connect a socket to the outside world on the client's behalf.
**
**  Opens a TCP connection to dst->host:dst->port.  Both values are copied
**  verbatim into a sockaddr_in, so presumably they are already in network
**  byte order.  The client then gets a reply:
**    - SOCKS_RESULT on success, after which data is relayed with Pump();
**    - SOCKS_FAIL if connect() fails, after which the process exits(1).
**  If the socket cannot be created, the process exits without replying.
**  The reply's port and host fields are sent as zero.
*/
int
DoConnect(int in, Socks_t *dst)
{
	int			out;
	struct sockaddr_in	sin;
	Socks_t			ndst;

	/*
	** Zero both structures.  sin_zero must be clear for some stacks, and
	** ndst's port/host would otherwise reach the client as uninitialised
	** stack bytes.
	*/
	memset(&sin, 0, sizeof(sin));
	memset(&ndst, 0, sizeof(ndst));

	if ((out = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
		syslog(LOG_ERR, "out-socket %m\n");
		exit(1);
	}

	sin.sin_family = AF_INET;
	sin.sin_port = dst->port;
	sin.sin_addr.s_addr = dst->host;

	ndst.version = Version;
	ndst.cmd = SOCKS_RESULT;

	if (connect(out, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
		syslog(LOG_ERR, "DoConnect connect %m\n");
		ndst.cmd = SOCKS_FAIL;
		SendDst(in, &ndst);
		exit(1);
	}

	SendDst(in, &ndst);

	Pump(in, out);
	return 0;
}

#ifdef SUPPORT_VERSION_1
/*
**  Set up a socket to be connected to from the outside world (version 1).
**
**  Opens a TCP socket bound to an ephemeral port on INADDR_ANY and starts
**  listening.  The client gets a SOCKS_RESULT reply carrying the address
**  getsockname() reports for it.  The host part of that address is
**  whatever the stack returns for an INADDR_ANY socket.
**
**  The function then waits for one inbound connection from any host,
**  sends a second SOCKS_RESULT carrying that peer's address, and relays
**  data with Pump().
**
**  Failures after the socket is created send SOCKS_FAIL and exit(1).
*/
int
DoBind(int in, Socks_t *dst)
{
	int			peer, out;
	socklen_t		len;
	struct sockaddr_in	sin;
	Socks_t			ndst;

	(void)dst;	/* version 1 accepts the connection from any host */

	/* Zeroed so failure replies never carry uninitialised bytes. */
	memset(&ndst, 0, sizeof(ndst));
	ndst.version = Version;
	ndst.cmd  = SOCKS_RESULT;

	if ((out = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
		syslog(LOG_ERR, "out-socket %m\n");
		exit(1);
	}

	memset(&sin, 0, sizeof(sin));
	sin.sin_family = AF_INET;
	sin.sin_port = 0;
	sin.sin_addr.s_addr = htonl(INADDR_ANY);

	if (bind(out, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
		syslog(LOG_ERR, "DoBind bind %m\n");
		ndst.cmd = SOCKS_FAIL;
		SendDst(in, &ndst);
		exit(1);
	}
	len = sizeof(sin);
	if (getsockname(out, (struct sockaddr *)&sin, &len) < 0) {
		syslog(LOG_ERR, "DoBind getsockname %m\n");
		ndst.cmd = SOCKS_FAIL;
		SendDst(in, &ndst);
		exit(1);
	}

	ndst.port = sin.sin_port;
	ndst.host = sin.sin_addr.s_addr;

	if (listen(out, 1) < 0) {
		syslog(LOG_ERR, "DoBind listen %m\n");
		ndst.cmd = SOCKS_FAIL;
		SendDst(in, &ndst);
		exit(1);
	}

	SendDst(in, &ndst);

	len = sizeof(sin);
	if ((peer = accept(out, (struct sockaddr *)&sin, &len)) < 0) {
		syslog(LOG_ERR, "DoBind accept %m\n");
		ndst.cmd = SOCKS_FAIL;
		SendDst(in, &ndst);
		exit(1);
	}

	/* Only one inbound connection is served; stop listening. */
	close(out);

	ndst.port = sin.sin_port;
	ndst.host = sin.sin_addr.s_addr;
	SendDst(in, &ndst);

	Pump(in, peer);
	return 0;
}
#endif /* SUPPORT_VERSION_1 */

/*
**  Set up a socket to be connected to from the outside world.
**
**  This differs from the Version 1 protocol (DoBind) in one way: the
**  inbound connection must come from the specific host passed in
**  dst->host.
**
**  The function opens a TCP socket bound to an ephemeral port on
**  INADDR_ANY and listens on it.  The client gets a SOCKS_RESULT reply
**  carrying the getsockname() address.  After one connection is accepted,
**  its peer address is checked against dst->host:
**    - on a mismatch, SOCKS_FAIL is sent and the process exits(1);
**    - otherwise a second SOCKS_RESULT carrying the peer's address is
**      sent and data is relayed with Pump().
**
**  Failures after the socket is created send SOCKS_FAIL and exit(1).
*/
int
DoNewBind(int in, Socks_t *dst)
{
	int			peer, out;
	socklen_t		len;
	struct sockaddr_in	sin;
	Socks_t			ndst;

	/* Zeroed so failure replies never carry uninitialised bytes. */
	memset(&ndst, 0, sizeof(ndst));
	ndst.version = Version;
	ndst.cmd  = SOCKS_RESULT;

	if ((out = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
		syslog(LOG_ERR, "out-socket %m\n");
		exit(1);
	}

	memset(&sin, 0, sizeof(sin));
	sin.sin_family = AF_INET;
	sin.sin_port = 0;
	sin.sin_addr.s_addr = htonl(INADDR_ANY);

	if (bind(out, (struct sockaddr *)&sin, sizeof(sin)) < 0) {
		syslog(LOG_ERR, "DoBind bind %m\n");
		ndst.cmd = SOCKS_FAIL;
		SendDst(in, &ndst);
		exit(1);
	}
	len = sizeof(sin);
	if (getsockname(out, (struct sockaddr *)&sin, &len) < 0) {
		syslog(LOG_ERR, "DoBind getsockname %m\n");
		ndst.cmd = SOCKS_FAIL;
		SendDst(in, &ndst);
		exit(1);
	}

	ndst.port = sin.sin_port;
	ndst.host = sin.sin_addr.s_addr;

	if (listen(out, 1) < 0) {
		syslog(LOG_ERR, "DoBind listen %m\n");
		ndst.cmd = SOCKS_FAIL;
		SendDst(in, &ndst);
		exit(1);
	}

	SendDst(in, &ndst);

	len = sizeof(sin);
	if ((peer = accept(out, (struct sockaddr *)&sin, &len)) < 0) {
		syslog(LOG_ERR, "DoBind accept %m\n");
		ndst.cmd = SOCKS_FAIL;
		SendDst(in, &ndst);
		exit(1);
	}

	/* Only one inbound connection is served; stop listening. */
	close(out);

	if (sin.sin_addr.s_addr != dst->host) {
		/* Casts keep the %08x arguments unsigned int whatever host's type. */
		syslog(LOG_ERR, "DoBind Incorrect host 0x%08x 0x%08x  -- %m\n",
				(unsigned int)sin.sin_addr.s_addr,
				(unsigned int)dst->host);
		ndst.cmd = SOCKS_FAIL;
		SendDst(in, &ndst);
		exit(1);
	}

	ndst.port = sin.sin_port;
	ndst.host = sin.sin_addr.s_addr;
	SendDst(in, &ndst);

	Pump(in, peer);
	return 0;
}

/*
**  Write exactly n bytes from buf to fd.  Retries after short writes and
**  after EINTR.  Returns 0 on success, or -1 on any other write error or
**  a zero-byte write.
*/
static int
PumpWriteAll(int fd, const char *buf, int n)
{
	int	w;

	while (n > 0) {
		w = write(fd, buf, n);
		if (w < 0 && errno == EINTR)
			continue;
		if (w <= 0)
			return -1;
		buf += w;
		n -= w;
	}
	return 0;
}

/*
**  Now just pump the packets/characters through.
**
**  Relays data in both directions between descriptors in and out.  It
**  stops when either side reaches EOF or errors, or when neither side
**  shows traffic for SOCKS_TIMEOUT seconds.  A select() failure other
**  than EINTR is logged.  Any pending alarm() is cancelled first, so the
**  set-up timeout cannot kill an established session.
*/
int
Pump(int in, int out)
{
	static char	buf[4096];
	fd_set		fds;
	struct timeval	tout;
	int		n;
	/* select() only needs to scan up to the highest descriptor used. */
	int		nfds = (in > out ? in : out) + 1;

	alarm(0);

	for (;;) {
		/*
		** select() rewrites the set and, on some systems, the timeout,
		** so both are rebuilt on every pass.
		*/
		FD_ZERO(&fds);
		FD_SET(in, &fds);
		FD_SET(out, &fds);
		tout.tv_sec = SOCKS_TIMEOUT;
		tout.tv_usec = 0;

		n = select(nfds, &fds, NULL, NULL, &tout);
		if (n < 0) {
			if (errno == EINTR)
				continue;
			syslog(LOG_ERR, "select %m\n");
			break;
		}
		if (n == 0)
			break;		/* idle timeout */

		if (FD_ISSET(in, &fds)) {
			if ((n = read(in, buf, sizeof(buf))) <= 0)
				break;
			if (PumpWriteAll(out, buf, n) < 0)
				break;
		}
		if (FD_ISSET(out, &fds)) {
			if ((n = read(out, buf, sizeof(buf))) <= 0)
				break;
			if (PumpWriteAll(in, buf, n) < 0)
				break;
		}
	}
	return 0;
}

/*
**  Split cp in place into whitespace-separated words.
**
**  Each word's start is stored in argv[], and the whitespace after it is
**  overwritten with '\0'.  *argc receives the number of words stored.
**  No quoting or escaping (\, ", ') is recognised.
**
**  At most max words are stored.  When the limit is reached, the last
**  stored word is left unterminated and runs to the end of the string.
**  If max <= 0, nothing is stored.
**
**  Characters are passed to isspace() as unsigned char, since a negative
**  plain-char value is undefined behaviour for <ctype.h> functions.
*/
void
mkargs(char *cp, int *argc, char *argv[], int max)
{
	*argc = 0;
	if (max <= 0)
		return;

	while (isspace((unsigned char)*cp))
		cp++;

	while (*cp != '\0') {
		argv[(*argc)++] = cp;
		if (*argc >= max)
			return;

		while (*cp != '\0' && !isspace((unsigned char)*cp))
			cp++;
		while (isspace((unsigned char)*cp))
			*cp++ = '\0';
	}
}

/*
**  Resolve name to an IPv4 address or mask value.  Three lookups are
**  tried in order:
**    1. gethostbyname(): host names and, with most resolvers, numeric
**       dotted quads.  Only AF_INET results are used.
**    2. getnetbyname(): names from the networks database.  The raw n_net
**       value is stored.
**    3. inet_addr(): numeric forms, including hex.
**  The result is stored in *addr and also returned, converted to int.
**  If every lookup fails, *addr is whatever inet_addr() returned
**  (INADDR_NONE on most systems).
**
**  Values are assigned rather than byte-copied.  *addr is an unsigned
**  long, which is 8 bytes on LP64 systems.  Copying sizeof(*addr) bytes
**  from a 4-byte address would read past it and leave garbage in the
**  upper half, breaking the address/mask comparisons.
*/
int
GetAddr(char *name, unsigned long *addr)
{
	struct hostent	*hp;
	struct netent	*np;
	struct in_addr	ia;

	if ((hp = gethostbyname(name)) != NULL &&
	    hp->h_addrtype == AF_INET &&
	    hp->h_length == (int)sizeof(ia)) {
		memcpy(&ia, hp->h_addr_list[0], sizeof(ia));
		*addr = (unsigned long)ia.s_addr;
		return (int)*addr;
	}
	if ((np = getnetbyname(name)) != NULL) {
		*addr = (unsigned long)np->n_net;
		return (int)*addr;
	}
	*addr = (unsigned long)inet_addr(name);
	return (int)*addr;
}

/*
**  Translate a port specification into a port number in network byte
**  order, the form used by sin_port.
**
**  name may be a TCP service name from the services database, whose s_port
**  is already in network order.  It may also be a decimal number in the
**  range 0..65535, which is converted with htons() so both forms agree.
**  Returns -1 if name is neither, including when the number is out of
**  range or has trailing characters.
*/
int
GetPort(char *name)
{
	struct servent	*sp;
	char		*end;
	long		port;

	if ((sp = getservbyname(name, "tcp")) != NULL) {
		return sp->s_port;
	}
	if (!isdigit((unsigned char)*name))
		return -1;

	errno = 0;
	port = strtol(name, &end, 10);
	if (errno != 0 || *end != '\0' || port > 65535)
		return -1;
	return htons((unsigned short)port);
}

/*
**  Decide whether a client at src may reach dst by scanning SOCKS_CONF.
**
**  Each non-comment line has the form
**	permit|deny  saddr smask  [daddr dmask]  [op port]
**  where op is one of eq, neq, lt or gt.
**
**  Bits SET in a mask are ignored.  An address matches when it agrees
**  with saddr/daddr on every bit that is clear in the mask, so the
**  address may not share any bits with its mask.  When the destination
**  pair is omitted, any destination matches.
**
**  "op port" compares the requested destination port with port; for
**  example, "lt 1024" matches destination ports below 1024.
**
**  Text from '#' to end of line is a comment.  The first matching line
**  decides: 1 for permit, 0 for deny.  Returns 0 if no line matches or
**  the file cannot be opened.  Malformed or overlong lines are logged
**  and skipped.
*/
int
Validate(struct sockaddr_in *src, struct sockaddr_in *dst)
{
	FILE		*fd;
	static char	buf[1024];
	char		*bp;
	int		linenum = 0, permit, pos, matched, ch;
	char		*argv[10];
	int		argc, p;
	unsigned long	saddr, smask, daddr, dmask;
	unsigned short	dport, port;
	enum 		{ e_lt, e_gt, e_eq, e_neq, e_nil } tst;

#ifdef DEBUG
	syslog(LOG_ERR,"SRC: 0x%08x DST: 0x%08x",
			src->sin_addr.s_addr, dst->sin_addr.s_addr);
#endif

	if ((fd = fopen(SOCKS_CONF, "r")) == NULL) {
		syslog(LOG_ERR, "Unable to open config file (%s)", SOCKS_CONF);

		return 0;
	}

	/*
	** Port tests are done in host byte order, so lt/gt are true numeric
	** comparisons.  sin_port is in network order.  GetPort() results are
	** treated as network order too (getservbyname's s_port is).
	*/
	port = ntohs(dst->sin_port);

	while (fgets(buf, sizeof(buf) - 1, fd) != NULL) {
		linenum++;

		/*
		** A line that does not fit in buf would otherwise have its tail
		** parsed as a separate rule, so the whole line is skipped.
		*/
		if (strchr(buf, '\n') == NULL && !feof(fd)) {
			while ((ch = getc(fd)) != EOF && ch != '\n')
				;
			syslog(LOG_ERR, "Line %d too long, ignored", linenum);
			continue;
		}

		/* Comments start with a '#' anywhere on the line. */
		if ((bp = strchr(buf, '#')) != NULL)
			*bp = '\0';

		mkargs(buf, &argc, argv, 8);
		if (argc == 0)
			continue;
		if ((argc < 3) || (argc > 7)) {
			syslog(LOG_ERR, "Invalid entry at line %d", linenum);
			continue;
		}

		if (STREQ(argv[0], "permit")) {
			permit = 1;
		} else if (STREQ(argv[0], "deny")) {
			permit = 0;
		} else {
			syslog(LOG_ERR, "Invalid permit/deny field at line %d", linenum);
			continue;
		}

		/*
		** Addresses and masks are 32-bit values; drop anything above bit
		** 31 so the width of unsigned long cannot affect matching.
		*/
		GetAddr(argv[1], &saddr);
		GetAddr(argv[2], &smask);
		saddr &= 0xffffffffUL;
		smask &= 0xffffffffUL;

		if ((saddr & smask) != 0) {
			syslog(LOG_ERR, "Inavlid source address and mask pair at line %d", linenum);
			continue;
		}

		/* A 4th field that is not a port operator starts a dest pair. */
		if ((argc > 4) &&
			!(STREQ(argv[3], "eq") || STREQ(argv[3], "neq") ||
			  STREQ(argv[3], "lt") || STREQ(argv[3], "gt"))) {
			GetAddr(argv[3], &daddr);
			GetAddr(argv[4], &dmask);
			daddr &= 0xffffffffUL;
			dmask &= 0xffffffffUL;

			if ((daddr & dmask) != 0) {
				syslog(LOG_ERR, "Invalid destination address and mask pair at line %d", linenum);
				continue;
			}
			pos = 5;
		} else {
			daddr = 0;
			dmask = 0xffffffffUL;	/* match any destination */
			pos = 3;
		}

		tst = e_nil;
		dport = 0;
		if (argc > pos + 1) {
			if (STREQ(argv[pos], "eq"))
				tst = e_eq;
			else if (STREQ(argv[pos], "neq"))
				tst = e_neq;
			else if (STREQ(argv[pos], "lt"))
				tst = e_lt;
			else if (STREQ(argv[pos], "gt"))
				tst = e_gt;
			else {
				syslog(LOG_ERR, "Invalid comparison at line %d", linenum);
				continue;
			}

			if ((p = GetPort(argv[pos+1])) == -1) {
				syslog(LOG_ERR, "Invalid port number at line %d", linenum);
				continue;
			}
			dport = ntohs((unsigned short)p);
		}

#ifdef DEBUG
		syslog(LOG_ERR, "%s 0x%08lx 0x%08lx 0x%08lx 0x%08lx %s %d",
			permit ? "permit" : "deny",
			saddr, smask, daddr, dmask,
			tst == e_eq ? "==" :
			tst == e_neq ? "!=" :
			tst == e_lt ? "<" :
			tst == e_gt ? ">" : "NIL",
			dport);
#endif

		if ((saddr & ~smask) != (src->sin_addr.s_addr & ~smask) ||
		    (daddr & ~dmask) != (dst->sin_addr.s_addr & ~dmask))
			continue;

		/* "op N" means: requested port op N. */
		switch (tst) {
		case e_eq:
			matched = (port == dport);
			break;
		case e_neq:
			matched = (port != dport);
			break;
		case e_lt:
			matched = (port < dport);
			break;
		case e_gt:
			matched = (port > dport);
			break;
		default:	/* e_nil: no port restriction */
			matched = 1;
			break;
		}

		if (matched) {
			fclose(fd);
			return permit;
		}
	}

	fclose(fd);
	return 0;
}
