summaryrefslogtreecommitdiffstats
path: root/testing/mochitest/ssltunnel/ssltunnel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'testing/mochitest/ssltunnel/ssltunnel.cpp')
-rw-r--r--testing/mochitest/ssltunnel/ssltunnel.cpp1635
1 files changed, 1635 insertions, 0 deletions
diff --git a/testing/mochitest/ssltunnel/ssltunnel.cpp b/testing/mochitest/ssltunnel/ssltunnel.cpp
new file mode 100644
index 000000000..a80fe624d
--- /dev/null
+++ b/testing/mochitest/ssltunnel/ssltunnel.cpp
@@ -0,0 +1,1635 @@
+/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+/*
+ * WARNING: DO NOT USE THIS CODE IN PRODUCTION SYSTEMS. It is highly likely to
+ * be plagued with the usual problems endemic to C (buffer overflows
+ * and the like). We don't especially care here (but would accept
+ * patches!) because this is only intended for use in our test
+ * harnesses in controlled situations where input is guaranteed not to
+ * be malicious.
+ */
+
+#include "ScopedNSSTypes.h"
+#include <assert.h>
+#include <stdio.h>
+#include <string>
+#include <vector>
+#include <algorithm>
+#include <stdarg.h>
+#include "prinit.h"
+#include "prerror.h"
+#include "prenv.h"
+#include "prnetdb.h"
+#include "prtpool.h"
+#include "nsAlgorithm.h"
+#include "nss.h"
+#include "key.h"
+#include "ssl.h"
+#include "sslproto.h"
+#include "plhash.h"
+#include "mozilla/Sprintf.h"
+
+using namespace mozilla;
+using namespace mozilla::psm;
+using std::string;
+using std::vector;
+
+#define IS_DELIM(m, c) ((m)[(c) >> 3] & (1 << ((c) & 7)))
+#define SET_DELIM(m, c) ((m)[(c) >> 3] |= (1 << ((c) & 7)))
+#define DELIM_TABLE_SIZE 32
+
+// You can set the level of logging by env var SSLTUNNEL_LOG_LEVEL=n, where n
+// is 0 through 3. The default is 1, INFO level logging.
+enum LogLevel {
+ LEVEL_DEBUG = 0,
+ LEVEL_INFO = 1,
+ LEVEL_ERROR = 2,
+ LEVEL_SILENT = 3
+} gLogLevel, gLastLogLevel;
+
+#define _LOG_OUTPUT(level, func, params) \
+PR_BEGIN_MACRO \
+ if (level >= gLogLevel) { \
+ gLastLogLevel = level; \
+ func params;\
+ } \
+PR_END_MACRO
+
+// The most verbose output
+#define LOG_DEBUG(params) \
+ _LOG_OUTPUT(LEVEL_DEBUG, printf, params)
+
+// Top level informative messages
+#define LOG_INFO(params) \
+ _LOG_OUTPUT(LEVEL_INFO, printf, params)
+
+// Serious errors that must be logged always until completely gag
+#define LOG_ERROR(params) \
+ _LOG_OUTPUT(LEVEL_ERROR, eprintf, params)
+
+// Same as LOG_ERROR, but when logging is set to LEVEL_DEBUG, the message
+// will be put to the stdout instead of stderr to keep continuity with other
+// LOG_DEBUG message output
+#define LOG_ERRORD(params) \
+PR_BEGIN_MACRO \
+ if (gLogLevel == LEVEL_DEBUG) \
+ _LOG_OUTPUT(LEVEL_ERROR, printf, params); \
+ else \
+ _LOG_OUTPUT(LEVEL_ERROR, eprintf, params); \
+PR_END_MACRO
+
+// If there is any output written between LOG_BEGIN_BLOCK() and
+// LOG_END_BLOCK() then a new line will be put to the proper output (out/err)
+#define LOG_BEGIN_BLOCK() \
+ gLastLogLevel = LEVEL_SILENT;
+
+#define LOG_END_BLOCK() \
+PR_BEGIN_MACRO \
+ if (gLastLogLevel == LEVEL_ERROR) \
+ LOG_ERROR(("\n")); \
+ if (gLastLogLevel < LEVEL_ERROR) \
+ _LOG_OUTPUT(gLastLogLevel, printf, ("\n")); \
+PR_END_MACRO
+
+int eprintf(const char* str, ...)
+{
+ va_list ap;
+ va_start(ap, str);
+ int result = vfprintf(stderr, str, ap);
+ va_end(ap);
+ return result;
+}
+
+// Copied from nsCRT
+char* strtok2(char* string, const char* delims, char* *newStr)
+{
+ PR_ASSERT(string);
+
+ char delimTable[DELIM_TABLE_SIZE];
+ uint32_t i;
+ char* result;
+ char* str = string;
+
+ for (i = 0; i < DELIM_TABLE_SIZE; i++)
+ delimTable[i] = '\0';
+
+ for (i = 0; delims[i]; i++) {
+ SET_DELIM(delimTable, static_cast<uint8_t>(delims[i]));
+ }
+
+ // skip to beginning
+ while (*str && IS_DELIM(delimTable, static_cast<uint8_t>(*str))) {
+ str++;
+ }
+ result = str;
+
+ // fix up the end of the token
+ while (*str) {
+ if (IS_DELIM(delimTable, static_cast<uint8_t>(*str))) {
+ *str++ = '\0';
+ break;
+ }
+ str++;
+ }
+ *newStr = str;
+
+ return str == result ? nullptr : result;
+}
+
+
+
+enum client_auth_option {
+ caNone = 0,
+ caRequire = 1,
+ caRequest = 2
+};
+
+// Structs for passing data into jobs on the thread pool
+typedef struct {
+ int32_t listen_port;
+ string cert_nickname;
+ PLHashTable* host_cert_table;
+ PLHashTable* host_clientauth_table;
+ PLHashTable* host_redir_table;
+ PLHashTable* host_ssl3_table;
+ PLHashTable* host_tls1_table;
+ PLHashTable* host_rc4_table;
+ PLHashTable* host_failhandshake_table;
+} server_info_t;
+
+typedef struct {
+ PRFileDesc* client_sock;
+ PRNetAddr client_addr;
+ server_info_t* server_info;
+ // the original host in the Host: header for this connection is
+ // stored here, for proxied connections
+ string original_host;
+ // true if no SSL should be used for this connection
+ bool http_proxy_only;
+ // true if this connection is for a WebSocket
+ bool iswebsocket;
+} connection_info_t;
+
+typedef struct {
+ string fullHost;
+ bool matched;
+} server_match_t;
+
+const int32_t BUF_SIZE = 16384;
+const int32_t BUF_MARGIN = 1024;
+const int32_t BUF_TOTAL = BUF_SIZE + BUF_MARGIN;
+
+struct relayBuffer
+{
+ char *buffer, *bufferhead, *buffertail, *bufferend;
+
+ relayBuffer()
+ {
+ // Leave 1024 bytes more for request line manipulations
+ bufferhead = buffertail = buffer = new char[BUF_TOTAL];
+ bufferend = buffer + BUF_SIZE;
+ }
+
+ ~relayBuffer()
+ {
+ delete [] buffer;
+ }
+
+ void compact() {
+ if (buffertail == bufferhead)
+ buffertail = bufferhead = buffer;
+ }
+
+ bool empty() { return bufferhead == buffertail; }
+ size_t areafree() { return bufferend - buffertail; }
+ size_t margin() { return areafree() + BUF_MARGIN; }
+ size_t present() { return buffertail - bufferhead; }
+};
+
+// These numbers are multiplied by the number of listening ports (actual
+// servers running). According the thread pool implementation there is no
+// need to limit the number of threads initially, threads are allocated
+// dynamically and stored in a linked list. Initial number of 2 is chosen
+// to allocate a thread for socket accept and preallocate one for the first
+// connection that is with high probability expected to come.
+const uint32_t INITIAL_THREADS = 2;
+const uint32_t MAX_THREADS = 100;
+const uint32_t DEFAULT_STACKSIZE = (512 * 1024);
+
+// global data
+string nssconfigdir;
+vector<server_info_t> servers;
+PRNetAddr remote_addr;
+PRNetAddr websocket_server;
+PRThreadPool* threads = nullptr;
+PRLock* shutdown_lock = nullptr;
+PRCondVar* shutdown_condvar = nullptr;
+// Not really used, unless something fails to start
+bool shutdown_server = false;
+bool do_http_proxy = false;
+bool any_host_spec_config = false;
+
+int ClientAuthValueComparator(const void *v1, const void *v2)
+{
+ int a = *static_cast<const client_auth_option*>(v1) -
+ *static_cast<const client_auth_option*>(v2);
+ if (a == 0)
+ return 0;
+ if (a > 0)
+ return 1;
+ else // (a < 0)
+ return -1;
+}
+
+static int match_hostname(PLHashEntry *he, int index, void* arg)
+{
+ server_match_t *match = (server_match_t*)arg;
+ if (match->fullHost.find((char*)he->key) != string::npos)
+ match->matched = true;
+ return HT_ENUMERATE_NEXT;
+}
+
+/*
+ * Signal the main thread that the application should shut down.
+ */
+void SignalShutdown()
+{
+ PR_Lock(shutdown_lock);
+ PR_NotifyCondVar(shutdown_condvar);
+ PR_Unlock(shutdown_lock);
+}
+
+// available flags
+enum {
+ USE_SSL3 = 1 << 0,
+ USE_RC4 = 1 << 1,
+ FAIL_HANDSHAKE = 1 << 2,
+ USE_TLS1 = 1 << 4
+};
+
+bool ReadConnectRequest(server_info_t* server_info,
+ relayBuffer& buffer, int32_t* result, string& certificate,
+ client_auth_option* clientauth, string& host, string& location,
+ int32_t* flags)
+{
+ if (buffer.present() < 4) {
+ LOG_DEBUG((" !! only %d bytes present in the buffer", (int)buffer.present()));
+ return false;
+ }
+ if (strncmp(buffer.buffertail-4, "\r\n\r\n", 4)) {
+ LOG_ERRORD((" !! request is not tailed with CRLFCRLF but with %x %x %x %x",
+ *(buffer.buffertail-4),
+ *(buffer.buffertail-3),
+ *(buffer.buffertail-2),
+ *(buffer.buffertail-1)));
+ return false;
+ }
+
+ LOG_DEBUG((" parsing initial connect request, dump:\n%.*s\n", (int)buffer.present(), buffer.bufferhead));
+
+ *result = 400;
+
+ char* token;
+ char* _caret;
+ token = strtok2(buffer.bufferhead, " ", &_caret);
+ if (!token) {
+ LOG_ERRORD((" no space found"));
+ return true;
+ }
+ if (strcmp(token, "CONNECT")) {
+ LOG_ERRORD((" not CONNECT request but %s", token));
+ return true;
+ }
+
+ token = strtok2(_caret, " ", &_caret);
+ void* c = PL_HashTableLookup(server_info->host_cert_table, token);
+ if (c)
+ certificate = static_cast<char*>(c);
+
+ host = "https://";
+ host += token;
+
+ c = PL_HashTableLookup(server_info->host_clientauth_table, token);
+ if (c)
+ *clientauth = *static_cast<client_auth_option*>(c);
+ else
+ *clientauth = caNone;
+
+ void *redir = PL_HashTableLookup(server_info->host_redir_table, token);
+ if (redir)
+ location = static_cast<char*>(redir);
+
+ if (PL_HashTableLookup(server_info->host_ssl3_table, token)) {
+ *flags |= USE_SSL3;
+ }
+
+ if (PL_HashTableLookup(server_info->host_rc4_table, token)) {
+ *flags |= USE_RC4;
+ }
+
+ if (PL_HashTableLookup(server_info->host_tls1_table, token)) {
+ *flags |= USE_TLS1;
+ }
+
+ if (PL_HashTableLookup(server_info->host_failhandshake_table, token)) {
+ *flags |= FAIL_HANDSHAKE;
+ }
+
+ token = strtok2(_caret, "/", &_caret);
+ if (strcmp(token, "HTTP")) {
+ LOG_ERRORD((" not tailed with HTTP but with %s", token));
+ return true;
+ }
+
+ *result = (redir) ? 302 : 200;
+ return true;
+}
+
+bool ConfigureSSLServerSocket(PRFileDesc* socket, server_info_t* si, const string &certificate,
+ const client_auth_option clientAuth, int32_t flags)
+{
+ const char* certnick = certificate.empty() ?
+ si->cert_nickname.c_str() : certificate.c_str();
+
+ UniqueCERTCertificate cert(PK11_FindCertFromNickname(certnick, nullptr));
+ if (!cert) {
+ LOG_ERROR(("Failed to find cert %s\n", certnick));
+ return false;
+ }
+
+ UniqueSECKEYPrivateKey privKey(PK11_FindKeyByAnyCert(cert.get(), nullptr));
+ if (!privKey) {
+ LOG_ERROR(("Failed to find private key\n"));
+ return false;
+ }
+
+ PRFileDesc* ssl_socket = SSL_ImportFD(nullptr, socket);
+ if (!ssl_socket) {
+ LOG_ERROR(("Error importing SSL socket\n"));
+ return false;
+ }
+
+ if (flags & FAIL_HANDSHAKE) {
+ // deliberately cause handshake to fail by sending the client a client hello
+ SSL_ResetHandshake(ssl_socket, false);
+ return true;
+ }
+
+ SSLKEAType certKEA = NSS_FindCertKEAType(cert.get());
+ if (SSL_ConfigSecureServer(ssl_socket, cert.get(), privKey.get(), certKEA)
+ != SECSuccess) {
+ LOG_ERROR(("Error configuring SSL server socket\n"));
+ return false;
+ }
+
+ SSL_OptionSet(ssl_socket, SSL_SECURITY, true);
+ SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_CLIENT, false);
+ SSL_OptionSet(ssl_socket, SSL_HANDSHAKE_AS_SERVER, true);
+
+ if (clientAuth != caNone)
+ {
+ SSL_OptionSet(ssl_socket, SSL_REQUEST_CERTIFICATE, true);
+ SSL_OptionSet(ssl_socket, SSL_REQUIRE_CERTIFICATE, clientAuth == caRequire);
+ }
+
+ if (flags & USE_SSL3) {
+ SSLVersionRange range = { SSL_LIBRARY_VERSION_3_0,
+ SSL_LIBRARY_VERSION_3_0 };
+ SSL_VersionRangeSet(ssl_socket, &range);
+ }
+
+ if (flags & USE_TLS1) {
+ SSLVersionRange range = { SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0 };
+ SSL_VersionRangeSet(ssl_socket, &range);
+ }
+
+ if (flags & USE_RC4) {
+ for (uint16_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ uint16_t cipher_id = SSL_ImplementedCiphers[i];
+ switch (cipher_id) {
+ case TLS_ECDHE_ECDSA_WITH_RC4_128_SHA:
+ case TLS_ECDHE_RSA_WITH_RC4_128_SHA:
+ case TLS_RSA_WITH_RC4_128_SHA:
+ case TLS_RSA_WITH_RC4_128_MD5:
+ SSL_CipherPrefSet(ssl_socket, cipher_id, true);
+ break;
+
+ default:
+ SSL_CipherPrefSet(ssl_socket, cipher_id, false);
+ break;
+ }
+ }
+ }
+
+ SSL_ResetHandshake(ssl_socket, true);
+
+ return true;
+}
+
+/**
+ * This function examines the buffer for a Sec-WebSocket-Location: field,
+ * and if it's present, it replaces the hostname in that field with the
+ * value in the server's original_host field. This function works
+ * in the reverse direction as AdjustWebSocketHost(), replacing the real
+ * hostname of a response with the potentially fake hostname that is expected
+ * by the browser (e.g., mochi.test).
+ *
+ * @return true if the header was adjusted successfully, or not found, false
+ * if the header is present but the url is not, which should indicate
+ * that more data needs to be read from the socket
+ */
+bool AdjustWebSocketLocation(relayBuffer& buffer, connection_info_t *ci)
+{
+ assert(buffer.margin());
+ buffer.buffertail[1] = '\0';
+
+ char* wsloc = strstr(buffer.bufferhead, "Sec-WebSocket-Location:");
+ if (!wsloc)
+ return true;
+ // advance pointer to the start of the hostname
+ wsloc = strstr(wsloc, "ws://");
+ if (!wsloc)
+ return false;
+ wsloc += 5;
+ // find the end of the hostname
+ char* wslocend = strchr(wsloc + 1, '/');
+ if (!wslocend)
+ return false;
+ char *crlf = strstr(wsloc, "\r\n");
+ if (!crlf)
+ return false;
+ if (ci->original_host.empty())
+ return true;
+
+ int diff = ci->original_host.length() - (wslocend-wsloc);
+ if (diff > 0)
+ assert(size_t(diff) <= buffer.margin());
+ memmove(wslocend + diff, wslocend, buffer.buffertail - wsloc - diff);
+ buffer.buffertail += diff;
+
+ memcpy(wsloc, ci->original_host.c_str(), ci->original_host.length());
+ return true;
+}
+
+/**
+ * This function examines the buffer for a Host: field, and if it's present,
+ * it replaces the hostname in that field with the hostname in the server's
+ * remote_addr field. This is needed because proxy requests may be coming
+ * from mochitest with fake hosts, like mochi.test, and these need to be
+ * replaced with the host that the destination server is actually running
+ * on.
+ */
+bool AdjustWebSocketHost(relayBuffer& buffer, connection_info_t *ci)
+{
+ const char HEADER_UPGRADE[] = "Upgrade:";
+ const char HEADER_HOST[] = "Host:";
+
+ PRNetAddr inet_addr = (websocket_server.inet.port ? websocket_server :
+ remote_addr);
+
+ assert(buffer.margin());
+
+ // Cannot use strnchr so add a null char at the end. There is always some
+ // space left because we preserve a margin.
+ buffer.buffertail[1] = '\0';
+
+ // Verify this is a WebSocket header.
+ char* h1 = strstr(buffer.bufferhead, HEADER_UPGRADE);
+ if (!h1)
+ return false;
+ h1 += strlen(HEADER_UPGRADE);
+ h1 += strspn(h1, " \t");
+ char* h2 = strstr(h1, "WebSocket\r\n");
+ if (!h2) h2 = strstr(h1, "websocket\r\n");
+ if (!h2) h2 = strstr(h1, "Websocket\r\n");
+ if (!h2)
+ return false;
+
+ char* host = strstr(buffer.bufferhead, HEADER_HOST);
+ if (!host)
+ return false;
+ // advance pointer to beginning of hostname
+ host += strlen(HEADER_HOST);
+ host += strspn(host, " \t");
+
+ char* endhost = strstr(host, "\r\n");
+ if (!endhost)
+ return false;
+
+ // Save the original host, so we can use it later on responses from the
+ // server.
+ ci->original_host.assign(host, endhost-host);
+
+ char newhost[40];
+ PR_NetAddrToString(&inet_addr, newhost, sizeof(newhost));
+ assert(strlen(newhost) < sizeof(newhost) - 7);
+ SprintfLiteral(newhost, "%s:%d", newhost, PR_ntohs(inet_addr.inet.port));
+
+ int diff = strlen(newhost) - (endhost-host);
+ if (diff > 0)
+ assert(size_t(diff) <= buffer.margin());
+ memmove(endhost + diff, endhost, buffer.buffertail - host - diff);
+ buffer.buffertail += diff;
+
+ memcpy(host, newhost, strlen(newhost));
+ return true;
+}
+
+/**
+ * This function prefixes Request-URI path with a full scheme-host-port
+ * string.
+ */
+bool AdjustRequestURI(relayBuffer& buffer, string *host)
+{
+ assert(buffer.margin());
+
+ // Cannot use strnchr so add a null char at the end. There is always some space left
+ // because we preserve a margin.
+ buffer.buffertail[1] = '\0';
+ LOG_DEBUG((" incoming request to adjust:\n%s\n", buffer.bufferhead));
+
+ char *token, *path;
+ path = strchr(buffer.bufferhead, ' ') + 1;
+ if (!path)
+ return false;
+
+ // If the path doesn't start with a slash don't change it, it is probably '*' or a full
+ // path already. Return true, we are done with this request adjustment.
+ if (*path != '/')
+ return true;
+
+ token = strchr(path, ' ') + 1;
+ if (!token)
+ return false;
+
+ if (strncmp(token, "HTTP/", 5))
+ return false;
+
+ size_t hostlength = host->length();
+ assert(hostlength <= buffer.margin());
+
+ memmove(path + hostlength, path, buffer.buffertail - path);
+ memcpy(path, host->c_str(), hostlength);
+ buffer.buffertail += hostlength;
+
+ return true;
+}
+
+bool ConnectSocket(UniquePRFileDesc& fd, const PRNetAddr* addr,
+ PRIntervalTime timeout)
+{
+ PRStatus stat = PR_Connect(fd.get(), addr, timeout);
+ if (stat != PR_SUCCESS)
+ return false;
+
+ PRSocketOptionData option;
+ option.option = PR_SockOpt_Nonblocking;
+ option.value.non_blocking = true;
+ PR_SetSocketOption(fd.get(), &option);
+
+ return true;
+}
+
+/*
+ * Handle an incoming client connection. The server thread has already
+ * accepted the connection, so we just need to connect to the remote
+ * port and then proxy data back and forth.
+ * The data parameter is a connection_info_t*, and must be deleted
+ * by this function.
+ */
+void HandleConnection(void* data)
+{
+ connection_info_t* ci = static_cast<connection_info_t*>(data);
+ PRIntervalTime connect_timeout = PR_SecondsToInterval(30);
+
+ UniquePRFileDesc other_sock(PR_NewTCPSocket());
+ bool client_done = false;
+ bool client_error = false;
+ bool connect_accepted = !do_http_proxy;
+ bool ssl_updated = !do_http_proxy;
+ bool expect_request_start = do_http_proxy;
+ string certificateToUse;
+ string locationHeader;
+ client_auth_option clientAuth;
+ string fullHost;
+ int32_t flags = 0;
+
+ LOG_DEBUG(("SSLTUNNEL(%p)): incoming connection csock(0)=%p, ssock(1)=%p\n",
+ static_cast<void*>(data),
+ static_cast<void*>(ci->client_sock),
+ static_cast<void*>(other_sock.get())));
+ if (other_sock)
+ {
+ int32_t numberOfSockets = 1;
+
+ relayBuffer buffers[2];
+
+ if (!do_http_proxy)
+ {
+ if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info,
+ certificateToUse, caNone, flags))
+ client_error = true;
+ else if (!ConnectSocket(other_sock, &remote_addr, connect_timeout))
+ client_error = true;
+ else
+ numberOfSockets = 2;
+ }
+
+ PRPollDesc sockets[2] =
+ {
+ {ci->client_sock, PR_POLL_READ, 0},
+ {other_sock.get(), PR_POLL_READ, 0}
+ };
+ bool socketErrorState[2] = {false, false};
+
+ while (!((client_error||client_done) && buffers[0].empty() && buffers[1].empty()))
+ {
+ sockets[0].in_flags |= PR_POLL_EXCEPT;
+ sockets[1].in_flags |= PR_POLL_EXCEPT;
+ LOG_DEBUG(("SSLTUNNEL(%p)): polling flags csock(0)=%c%c, ssock(1)=%c%c\n",
+ static_cast<void*>(data),
+ sockets[0].in_flags & PR_POLL_READ ? 'R' : '-',
+ sockets[0].in_flags & PR_POLL_WRITE ? 'W' : '-',
+ sockets[1].in_flags & PR_POLL_READ ? 'R' : '-',
+ sockets[1].in_flags & PR_POLL_WRITE ? 'W' : '-'));
+ int32_t pollStatus = PR_Poll(sockets, numberOfSockets, PR_MillisecondsToInterval(1000));
+ if (pollStatus < 0)
+ {
+ LOG_DEBUG(("SSLTUNNEL(%p)): pollStatus=%d, exiting\n",
+ static_cast<void*>(data), pollStatus));
+ client_error = true;
+ break;
+ }
+
+ if (pollStatus == 0)
+ {
+ // timeout
+ LOG_DEBUG(("SSLTUNNEL(%p)): poll timeout, looping\n",
+ static_cast<void*>(data)));
+ continue;
+ }
+
+ for (int32_t s = 0; s < numberOfSockets; ++s)
+ {
+ int32_t s2 = s == 1 ? 0 : 1;
+ int16_t out_flags = sockets[s].out_flags;
+ int16_t &in_flags = sockets[s].in_flags;
+ int16_t &in_flags2 = sockets[s2].in_flags;
+ sockets[s].out_flags = 0;
+
+ LOG_BEGIN_BLOCK();
+ LOG_DEBUG(("SSLTUNNEL(%p)): %csock(%d)=%p out_flags=%d",
+ static_cast<void*>(data),
+ s == 0 ? 'c' : 's',
+ s,
+ static_cast<void*>(sockets[s].fd),
+ out_flags));
+ if (out_flags & (PR_POLL_EXCEPT | PR_POLL_ERR | PR_POLL_HUP))
+ {
+ LOG_DEBUG((" :exception\n"));
+ client_error = true;
+ socketErrorState[s] = true;
+ // We got a fatal error state on the socket. Clear the output buffer
+ // for this socket to break the main loop, we will never more be able
+ // to send those data anyway.
+ buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
+ continue;
+ } // PR_POLL_EXCEPT, PR_POLL_ERR, PR_POLL_HUP handling
+
+ if (out_flags & PR_POLL_READ && !buffers[s].areafree())
+ {
+ LOG_DEBUG((" no place in read buffer but got read flag, dropping it now!"));
+ in_flags &= ~PR_POLL_READ;
+ }
+
+ if (out_flags & PR_POLL_READ && buffers[s].areafree())
+ {
+ LOG_DEBUG((" :reading"));
+ int32_t bytesRead = PR_Recv(sockets[s].fd, buffers[s].buffertail,
+ buffers[s].areafree(), 0, PR_INTERVAL_NO_TIMEOUT);
+
+ if (bytesRead == 0)
+ {
+ LOG_DEBUG((" socket gracefully closed"));
+ client_done = true;
+ in_flags &= ~PR_POLL_READ;
+ }
+ else if (bytesRead < 0)
+ {
+ if (PR_GetError() != PR_WOULD_BLOCK_ERROR)
+ {
+ LOG_DEBUG((" error=%d", PR_GetError()));
+ // We are in error state, indicate that the connection was
+ // not closed gracefully
+ client_error = true;
+ socketErrorState[s] = true;
+ // Wipe out our send buffer, we cannot send it anyway.
+ buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
+ }
+ else
+ LOG_DEBUG((" would block"));
+ }
+ else
+ {
+ // If the other socket is in error state (unable to send/receive)
+ // throw this data away and continue loop
+ if (socketErrorState[s2])
+ {
+ LOG_DEBUG((" have read but other socket is in error state\n"));
+ continue;
+ }
+
+ buffers[s].buffertail += bytesRead;
+ LOG_DEBUG((", read %d bytes", bytesRead));
+
+ // We have to accept and handle the initial CONNECT request here
+ int32_t response;
+ if (!connect_accepted && ReadConnectRequest(ci->server_info, buffers[s],
+ &response, certificateToUse, &clientAuth, fullHost, locationHeader,
+ &flags))
+ {
+ // Mark this as a proxy-only connection (no SSL) if the CONNECT
+ // request didn't come for port 443 or from any of the server's
+ // cert or clientauth hostnames.
+ if (fullHost.find(":443") == string::npos)
+ {
+ server_match_t match;
+ match.fullHost = fullHost;
+ match.matched = false;
+ PL_HashTableEnumerateEntries(ci->server_info->host_cert_table,
+ match_hostname,
+ &match);
+ PL_HashTableEnumerateEntries(ci->server_info->host_clientauth_table,
+ match_hostname,
+ &match);
+ PL_HashTableEnumerateEntries(ci->server_info->host_ssl3_table,
+ match_hostname,
+ &match);
+ PL_HashTableEnumerateEntries(ci->server_info->host_tls1_table,
+ match_hostname,
+ &match);
+ PL_HashTableEnumerateEntries(ci->server_info->host_rc4_table,
+ match_hostname,
+ &match);
+ PL_HashTableEnumerateEntries(ci->server_info->host_failhandshake_table,
+ match_hostname,
+ &match);
+ ci->http_proxy_only = !match.matched;
+ }
+ else
+ {
+ ci->http_proxy_only = false;
+ }
+
+ // Clean the request as it would be read
+ buffers[s].bufferhead = buffers[s].buffertail = buffers[s].buffer;
+ in_flags |= PR_POLL_WRITE;
+ connect_accepted = true;
+
+ // Store response to the oposite buffer
+ if (response == 200)
+ {
+ LOG_DEBUG((" accepted CONNECT request, connected to the server, sending OK to the client\n"));
+ strcpy(buffers[s2].buffer, "HTTP/1.1 200 Connected\r\nConnection: keep-alive\r\n\r\n");
+ }
+ else if (response == 302)
+ {
+ LOG_DEBUG((" accepted CONNECT request with redirection, "
+ "sending location and 302 to the client\n"));
+ client_done = true;
+ snprintf(buffers[s2].buffer,
+ buffers[s2].bufferend - buffers[s2].buffer,
+ "HTTP/1.1 302 Moved\r\n"
+ "Location: https://%s/\r\n"
+ "Connection: close\r\n\r\n",
+ locationHeader.c_str());
+ }
+ else
+ {
+ LOG_ERRORD((" could not read the connect request, closing connection with %d", response));
+ client_done = true;
+ snprintf(buffers[s2].buffer,
+ buffers[s2].bufferend - buffers[s2].buffer,
+ "HTTP/1.1 %d ERROR\r\nConnection: close\r\n\r\n", response);
+
+ break;
+ }
+
+ buffers[s2].buffertail = buffers[s2].buffer + strlen(buffers[s2].buffer);
+
+ // Send the response to the client socket
+ break;
+ } // end of CONNECT handling
+
+ if (!buffers[s].areafree())
+ {
+ // Do not poll for read when the buffer is full
+ LOG_DEBUG((" no place in our read buffer, stop reading"));
+ in_flags &= ~PR_POLL_READ;
+ }
+
+ if (ssl_updated)
+ {
+ if (s == 0 && expect_request_start)
+ {
+ if (!strstr(buffers[s].bufferhead, "\r\n\r\n"))
+ {
+ // We haven't received the complete header yet, so wait.
+ continue;
+ }
+ else
+ {
+ ci->iswebsocket = AdjustWebSocketHost(buffers[s], ci);
+ expect_request_start = !(ci->iswebsocket ||
+ AdjustRequestURI(buffers[s], &fullHost));
+ PRNetAddr* addr = &remote_addr;
+ if (ci->iswebsocket && websocket_server.inet.port)
+ addr = &websocket_server;
+ if (!ConnectSocket(other_sock, addr, connect_timeout))
+ {
+ LOG_ERRORD((" could not open connection to the real server\n"));
+ client_error = true;
+ break;
+ }
+ LOG_DEBUG(("\n connected to remote server\n"));
+ numberOfSockets = 2;
+ }
+ }
+ else if (s == 1 && ci->iswebsocket)
+ {
+ if (!AdjustWebSocketLocation(buffers[s], ci))
+ continue;
+ }
+
+ in_flags2 |= PR_POLL_WRITE;
+ LOG_DEBUG((" telling the other socket to write"));
+ }
+ else
+ LOG_DEBUG((" we have something for the other socket to write, but ssl has not been administered on it"));
+ }
+ } // PR_POLL_READ handling
+
+ if (out_flags & PR_POLL_WRITE)
+ {
+ LOG_DEBUG((" :writing"));
+ int32_t bytesWrite = PR_Send(sockets[s].fd, buffers[s2].bufferhead,
+ buffers[s2].present(), 0, PR_INTERVAL_NO_TIMEOUT);
+
+ if (bytesWrite < 0)
+ {
+ if (PR_GetError() != PR_WOULD_BLOCK_ERROR) {
+ LOG_DEBUG((" error=%d", PR_GetError()));
+ client_error = true;
+ socketErrorState[s] = true;
+ // We got a fatal error while writting the buffer. Clear it to break
+ // the main loop, we will never more be able to send it.
+ buffers[s2].bufferhead = buffers[s2].buffertail = buffers[s2].buffer;
+ }
+ else
+ LOG_DEBUG((" would block"));
+ }
+ else
+ {
+ LOG_DEBUG((", written %d bytes", bytesWrite));
+ buffers[s2].buffertail[1] = '\0';
+ LOG_DEBUG((" dump:\n%.*s\n", bytesWrite, buffers[s2].bufferhead));
+
+ buffers[s2].bufferhead += bytesWrite;
+ if (buffers[s2].present())
+ {
+ LOG_DEBUG((" still have to write %d bytes", (int)buffers[s2].present()));
+ in_flags |= PR_POLL_WRITE;
+ }
+ else
+ {
+ if (!ssl_updated)
+ {
+ LOG_DEBUG((" proxy response sent to the client"));
+ // Proxy response has just been writen, update to ssl
+ ssl_updated = true;
+ if (ci->http_proxy_only)
+ {
+ LOG_DEBUG((" not updating to SSL based on http_proxy_only for this socket"));
+ }
+ else if (!ConfigureSSLServerSocket(ci->client_sock, ci->server_info,
+ certificateToUse, clientAuth, flags))
+ {
+ LOG_ERRORD((" failed to config server socket\n"));
+ client_error = true;
+ break;
+ }
+ else
+ {
+ LOG_DEBUG((" client socket updated to SSL"));
+ }
+ } // sslUpdate
+
+ LOG_DEBUG((" dropping our write flag and setting other socket read flag"));
+ in_flags &= ~PR_POLL_WRITE;
+ in_flags2 |= PR_POLL_READ;
+ buffers[s2].compact();
+ }
+ }
+ } // PR_POLL_WRITE handling
+ LOG_END_BLOCK(); // end the log
+ } // for...
+ } // while, poll
+ }
+ else
+ client_error = true;
+
+ LOG_DEBUG(("SSLTUNNEL(%p)): exiting root function for csock=%p, ssock=%p\n",
+ static_cast<void*>(data),
+ static_cast<void*>(ci->client_sock),
+ static_cast<void*>(other_sock.get())));
+ if (!client_error)
+ PR_Shutdown(ci->client_sock, PR_SHUTDOWN_SEND);
+ PR_Close(ci->client_sock);
+
+ delete ci;
+}
+
+/*
+ * Start listening for SSL connections on a specified port, handing
+ * them off to client threads after accepting the connection.
+ * The data parameter is a server_info_t*, owned by the calling
+ * function.
+ */
+void StartServer(void* data)
+{
+ server_info_t* si = static_cast<server_info_t*>(data);
+
+ //TODO: select ciphers?
+ UniquePRFileDesc listen_socket(PR_NewTCPSocket());
+ if (!listen_socket) {
+ LOG_ERROR(("failed to create socket\n"));
+ SignalShutdown();
+ return;
+ }
+
+ // In case the socket is still open in the TIME_WAIT state from a previous
+ // instance of ssltunnel we ask to reuse the port.
+ PRSocketOptionData socket_option;
+ socket_option.option = PR_SockOpt_Reuseaddr;
+ socket_option.value.reuse_addr = true;
+ PR_SetSocketOption(listen_socket.get(), &socket_option);
+
+ PRNetAddr server_addr;
+ PR_InitializeNetAddr(PR_IpAddrAny, si->listen_port, &server_addr);
+ if (PR_Bind(listen_socket.get(), &server_addr) != PR_SUCCESS) {
+ LOG_ERROR(("failed to bind socket on port %d: error %d\n", si->listen_port, PR_GetError()));
+ SignalShutdown();
+ return;
+ }
+
+ if (PR_Listen(listen_socket.get(), 1) != PR_SUCCESS) {
+ LOG_ERROR(("failed to listen on socket\n"));
+ SignalShutdown();
+ return;
+ }
+
+ LOG_INFO(("Server listening on port %d with cert %s\n", si->listen_port,
+ si->cert_nickname.c_str()));
+
+ while (!shutdown_server) {
+ connection_info_t* ci = new connection_info_t();
+ ci->server_info = si;
+ ci->http_proxy_only = do_http_proxy;
+ // block waiting for connections
+ ci->client_sock = PR_Accept(listen_socket.get(), &ci->client_addr,
+ PR_INTERVAL_NO_TIMEOUT);
+
+ PRSocketOptionData option;
+ option.option = PR_SockOpt_Nonblocking;
+ option.value.non_blocking = true;
+ PR_SetSocketOption(ci->client_sock, &option);
+
+ if (ci->client_sock)
+ // Not actually using this PRJob*...
+ //PRJob* job =
+ PR_QueueJob(threads, HandleConnection, ci, true);
+ else
+ delete ci;
+ }
+}
+
+// bogus password func, just don't use passwords. :-P
+char* password_func(PK11SlotInfo* slot, PRBool retry, void* arg)
+{
+ if (retry)
+ return nullptr;
+
+ return PL_strdup("");
+}
+
+server_info_t* findServerInfo(int portnumber)
+{
+ for (vector<server_info_t>::iterator it = servers.begin();
+ it != servers.end(); it++)
+ {
+ if (it->listen_port == portnumber)
+ return &(*it);
+ }
+
+ return nullptr;
+}
+
+PLHashTable* get_ssl3_table(server_info_t* server)
+{
+ return server->host_ssl3_table;
+}
+
+PLHashTable* get_tls1_table(server_info_t* server)
+{
+ return server->host_tls1_table;
+}
+
+PLHashTable* get_rc4_table(server_info_t* server)
+{
+ return server->host_rc4_table;
+}
+
+PLHashTable* get_failhandshake_table(server_info_t* server)
+{
+ return server->host_failhandshake_table;
+}
+
+int parseWeakCryptoConfig(char* const& keyword, char*& _caret,
+ PLHashTable* (*get_table)(server_info_t*))
+{
+ char* hostname = strtok2(_caret, ":", &_caret);
+ char* hostportstring = strtok2(_caret, ":", &_caret);
+ char* serverportstring = strtok2(_caret, "\n", &_caret);
+
+ int port = atoi(serverportstring);
+ if (port <= 0) {
+ LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
+ return 1;
+ }
+
+ if (server_info_t* existingServer = findServerInfo(port))
+ {
+ any_host_spec_config = true;
+
+ char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2];
+ if (!hostname_copy) {
+ LOG_ERROR(("Out of memory"));
+ return 1;
+ }
+
+ strcpy(hostname_copy, hostname);
+ strcat(hostname_copy, ":");
+ strcat(hostname_copy, hostportstring);
+
+ PLHashEntry* entry = PL_HashTableAdd(get_table(existingServer), hostname_copy, keyword);
+ if (!entry) {
+ LOG_ERROR(("Out of memory"));
+ return 1;
+ }
+ }
+ else
+ {
+ LOG_ERROR(("Server on port %d for redirhost option is not defined, use 'listen' option first", port));
+ return 1;
+ }
+
+ return 0;
+}
+
+int processConfigLine(char* configLine)
+{
+ if (*configLine == 0 || *configLine == '#')
+ return 0;
+
+ char* _caret;
+ char* keyword = strtok2(configLine, ":", &_caret);
+
+ // Configure usage of http/ssl tunneling proxy behavior
+ if (!strcmp(keyword, "httpproxy"))
+ {
+ char* value = strtok2(_caret, ":", &_caret);
+ if (!strcmp(value, "1"))
+ do_http_proxy = true;
+
+ return 0;
+ }
+
+ if (!strcmp(keyword, "websocketserver"))
+ {
+ char* ipstring = strtok2(_caret, ":", &_caret);
+ if (PR_StringToNetAddr(ipstring, &websocket_server) != PR_SUCCESS) {
+ LOG_ERROR(("Invalid IP address in proxy config: %s\n", ipstring));
+ return 1;
+ }
+ char* remoteport = strtok2(_caret, ":", &_caret);
+ int port = atoi(remoteport);
+ if (port <= 0) {
+ LOG_ERROR(("Invalid remote port in proxy config: %s\n", remoteport));
+ return 1;
+ }
+ websocket_server.inet.port = PR_htons(port);
+ return 0;
+ }
+
+ // Configure the forward address of the target server
+ if (!strcmp(keyword, "forward"))
+ {
+ char* ipstring = strtok2(_caret, ":", &_caret);
+ if (PR_StringToNetAddr(ipstring, &remote_addr) != PR_SUCCESS) {
+ LOG_ERROR(("Invalid remote IP address: %s\n", ipstring));
+ return 1;
+ }
+ char* serverportstring = strtok2(_caret, ":", &_caret);
+ int port = atoi(serverportstring);
+ if (port <= 0) {
+ LOG_ERROR(("Invalid remote port: %s\n", serverportstring));
+ return 1;
+ }
+ remote_addr.inet.port = PR_htons(port);
+
+ return 0;
+ }
+
+ // Configure all listen sockets and port+certificate bindings
+ if (!strcmp(keyword, "listen"))
+ {
+ char* hostname = strtok2(_caret, ":", &_caret);
+ char* hostportstring = nullptr;
+ if (strcmp(hostname, "*"))
+ {
+ any_host_spec_config = true;
+ hostportstring = strtok2(_caret, ":", &_caret);
+ }
+
+ char* serverportstring = strtok2(_caret, ":", &_caret);
+ char* certnick = strtok2(_caret, ":", &_caret);
+
+ int port = atoi(serverportstring);
+ if (port <= 0) {
+ LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
+ return 1;
+ }
+
+ if (server_info_t* existingServer = findServerInfo(port))
+ {
+ char *certnick_copy = new char[strlen(certnick)+1];
+ char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2];
+
+ strcpy(hostname_copy, hostname);
+ strcat(hostname_copy, ":");
+ strcat(hostname_copy, hostportstring);
+ strcpy(certnick_copy, certnick);
+
+ PLHashEntry* entry = PL_HashTableAdd(existingServer->host_cert_table, hostname_copy, certnick_copy);
+ if (!entry) {
+ LOG_ERROR(("Out of memory"));
+ return 1;
+ }
+ }
+ else
+ {
+ server_info_t server;
+ server.cert_nickname = certnick;
+ server.listen_port = port;
+ server.host_cert_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
+ PL_CompareStrings, nullptr, nullptr);
+ if (!server.host_cert_table)
+ {
+ LOG_ERROR(("Internal, could not create hash table\n"));
+ return 1;
+ }
+ server.host_clientauth_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
+ ClientAuthValueComparator, nullptr, nullptr);
+ if (!server.host_clientauth_table)
+ {
+ LOG_ERROR(("Internal, could not create hash table\n"));
+ return 1;
+ }
+ server.host_redir_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
+ PL_CompareStrings, nullptr, nullptr);
+ if (!server.host_redir_table)
+ {
+ LOG_ERROR(("Internal, could not create hash table\n"));
+ return 1;
+ }
+
+ server.host_ssl3_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
+ PL_CompareStrings, nullptr, nullptr);;
+ if (!server.host_ssl3_table)
+ {
+ LOG_ERROR(("Internal, could not create hash table\n"));
+ return 1;
+ }
+
+ server.host_tls1_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
+ PL_CompareStrings, nullptr, nullptr);;
+ if (!server.host_tls1_table)
+ {
+ LOG_ERROR(("Internal, could not create hash table\n"));
+ return 1;
+ }
+
+ server.host_rc4_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
+ PL_CompareStrings, nullptr, nullptr);;
+ if (!server.host_rc4_table)
+ {
+ LOG_ERROR(("Internal, could not create hash table\n"));
+ return 1;
+ }
+
+ server.host_failhandshake_table = PL_NewHashTable(0, PL_HashString, PL_CompareStrings,
+ PL_CompareStrings, nullptr, nullptr);;
+ if (!server.host_failhandshake_table)
+ {
+ LOG_ERROR(("Internal, could not create hash table\n"));
+ return 1;
+ }
+
+ servers.push_back(server);
+ }
+
+ return 0;
+ }
+
+ if (!strcmp(keyword, "clientauth"))
+ {
+ char* hostname = strtok2(_caret, ":", &_caret);
+ char* hostportstring = strtok2(_caret, ":", &_caret);
+ char* serverportstring = strtok2(_caret, ":", &_caret);
+
+ int port = atoi(serverportstring);
+ if (port <= 0) {
+ LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
+ return 1;
+ }
+
+ if (server_info_t* existingServer = findServerInfo(port))
+ {
+ char* authoptionstring = strtok2(_caret, ":", &_caret);
+ client_auth_option* authoption = new client_auth_option;
+ if (!authoption) {
+ LOG_ERROR(("Out of memory"));
+ return 1;
+ }
+
+ if (!strcmp(authoptionstring, "require"))
+ *authoption = caRequire;
+ else if (!strcmp(authoptionstring, "request"))
+ *authoption = caRequest;
+ else if (!strcmp(authoptionstring, "none"))
+ *authoption = caNone;
+ else
+ {
+ LOG_ERROR(("Incorrect client auth option modifier for host '%s'", hostname));
+ delete authoption;
+ return 1;
+ }
+
+ any_host_spec_config = true;
+
+ char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2];
+ if (!hostname_copy) {
+ LOG_ERROR(("Out of memory"));
+ delete authoption;
+ return 1;
+ }
+
+ strcpy(hostname_copy, hostname);
+ strcat(hostname_copy, ":");
+ strcat(hostname_copy, hostportstring);
+
+ PLHashEntry* entry = PL_HashTableAdd(existingServer->host_clientauth_table, hostname_copy, authoption);
+ if (!entry) {
+ LOG_ERROR(("Out of memory"));
+ delete authoption;
+ return 1;
+ }
+ }
+ else
+ {
+ LOG_ERROR(("Server on port %d for client authentication option is not defined, use 'listen' option first", port));
+ return 1;
+ }
+
+ return 0;
+ }
+
+ if (!strcmp(keyword, "redirhost"))
+ {
+ char* hostname = strtok2(_caret, ":", &_caret);
+ char* hostportstring = strtok2(_caret, ":", &_caret);
+ char* serverportstring = strtok2(_caret, ":", &_caret);
+
+ int port = atoi(serverportstring);
+ if (port <= 0) {
+ LOG_ERROR(("Invalid port specified: %s\n", serverportstring));
+ return 1;
+ }
+
+ if (server_info_t* existingServer = findServerInfo(port))
+ {
+ char* redirhoststring = strtok2(_caret, ":", &_caret);
+
+ any_host_spec_config = true;
+
+ char *hostname_copy = new char[strlen(hostname)+strlen(hostportstring)+2];
+ if (!hostname_copy) {
+ LOG_ERROR(("Out of memory"));
+ return 1;
+ }
+
+ strcpy(hostname_copy, hostname);
+ strcat(hostname_copy, ":");
+ strcat(hostname_copy, hostportstring);
+
+ char *redir_copy = new char[strlen(redirhoststring)+1];
+ strcpy(redir_copy, redirhoststring);
+ PLHashEntry* entry = PL_HashTableAdd(existingServer->host_redir_table, hostname_copy, redir_copy);
+ if (!entry) {
+ LOG_ERROR(("Out of memory"));
+ delete[] hostname_copy;
+ delete[] redir_copy;
+ return 1;
+ }
+ }
+ else
+ {
+ LOG_ERROR(("Server on port %d for redirhost option is not defined, use 'listen' option first", port));
+ return 1;
+ }
+
+ return 0;
+ }
+
+ if (!strcmp(keyword, "ssl3")) {
+ return parseWeakCryptoConfig(keyword, _caret, get_ssl3_table);
+ }
+ if (!strcmp(keyword, "tls1")) {
+ return parseWeakCryptoConfig(keyword, _caret, get_tls1_table);
+ }
+
+ if (!strcmp(keyword, "rc4")) {
+ return parseWeakCryptoConfig(keyword, _caret, get_rc4_table);
+ }
+
+ if (!strcmp(keyword, "failHandshake")) {
+ return parseWeakCryptoConfig(keyword, _caret, get_failhandshake_table);
+ }
+
+ // Configure the NSS certificate database directory
+ if (!strcmp(keyword, "certdbdir"))
+ {
+ nssconfigdir = strtok2(_caret, "\n", &_caret);
+ return 0;
+ }
+
+ LOG_ERROR(("Error: keyword \"%s\" unexpected\n", keyword));
+ return 1;
+}
+
+int parseConfigFile(const char* filePath)
+{
+ FILE* f = fopen(filePath, "r");
+ if (!f)
+ return 1;
+
+ char buffer[1024], *b = buffer;
+ while (!feof(f))
+ {
+ char c;
+
+ if (fscanf(f, "%c", &c) != 1) {
+ break;
+ }
+
+ switch (c)
+ {
+ case '\n':
+ *b++ = 0;
+ if (processConfigLine(buffer))
+ {
+ fclose(f);
+ return 1;
+ }
+ b = buffer;
+ continue;
+
+ case '\r':
+ continue;
+
+ default:
+ *b++ = c;
+ }
+ }
+
+ fclose(f);
+
+ // Check mandatory items
+ if (nssconfigdir.empty())
+ {
+ LOG_ERROR(("Error: missing path to NSS certification database\n,use certdbdir:<path> in the config file\n"));
+ return 1;
+ }
+
+ if (any_host_spec_config && !do_http_proxy)
+ {
+ LOG_ERROR(("Warning: any host-specific configurations are ignored, add httpproxy:1 to allow them\n"));
+ }
+
+ return 0;
+}
+
+int freeHostCertHashItems(PLHashEntry *he, int i, void *arg)
+{
+ delete [] (char*)he->key;
+ delete [] (char*)he->value;
+ return HT_ENUMERATE_REMOVE;
+}
+
+int freeHostRedirHashItems(PLHashEntry *he, int i, void *arg)
+{
+ delete [] (char*)he->key;
+ delete [] (char*)he->value;
+ return HT_ENUMERATE_REMOVE;
+}
+
+int freeClientAuthHashItems(PLHashEntry *he, int i, void *arg)
+{
+ delete [] (char*)he->key;
+ delete (client_auth_option*)he->value;
+ return HT_ENUMERATE_REMOVE;
+}
+
+int freeSSL3HashItems(PLHashEntry *he, int i, void *arg)
+{
+ delete [] (char*)he->key;
+ return HT_ENUMERATE_REMOVE;
+}
+
+int freeTLS1HashItems(PLHashEntry *he, int i, void *arg)
+{
+ delete [] (char*)he->key;
+ return HT_ENUMERATE_REMOVE;
+}
+
+int freeRC4HashItems(PLHashEntry *he, int i, void *arg)
+{
+ delete [] (char*)he->key;
+ return HT_ENUMERATE_REMOVE;
+}
+
+int main(int argc, char** argv)
+{
+ const char* configFilePath;
+
+ const char* logLevelEnv = PR_GetEnv("SSLTUNNEL_LOG_LEVEL");
+ gLogLevel = logLevelEnv ? (LogLevel)atoi(logLevelEnv) : LEVEL_INFO;
+
+ if (argc == 1)
+ configFilePath = "ssltunnel.cfg";
+ else
+ configFilePath = argv[1];
+
+ memset(&websocket_server, 0, sizeof(PRNetAddr));
+
+ if (parseConfigFile(configFilePath)) {
+ LOG_ERROR(("Error: config file \"%s\" missing or formating incorrect\n"
+ "Specify path to the config file as parameter to ssltunnel or \n"
+ "create ssltunnel.cfg in the working directory.\n\n"
+ "Example format of the config file:\n\n"
+ " # Enable http/ssl tunneling proxy-like behavior.\n"
+ " # If not specified ssltunnel simply does direct forward.\n"
+ " httpproxy:1\n\n"
+ " # Specify path to the certification database used.\n"
+ " certdbdir:/path/to/certdb\n\n"
+ " # Forward/proxy all requests in raw to 127.0.0.1:8888.\n"
+ " forward:127.0.0.1:8888\n\n"
+ " # Accept connections on port 4443 or 5678 resp. and authenticate\n"
+ " # to any host ('*') using the 'server cert' or 'server cert 2' resp.\n"
+ " listen:*:4443:server cert\n"
+ " listen:*:5678:server cert 2\n\n"
+ " # Accept connections on port 4443 and authenticate using\n"
+ " # 'a different cert' when target host is 'my.host.name:443'.\n"
+ " # This only works in httpproxy mode and has higher priority\n"
+ " # than the previous option.\n"
+ " listen:my.host.name:443:4443:a different cert\n\n"
+ " # To make a specific host require or just request a client certificate\n"
+ " # to authenticate use the following options. This can only be used\n"
+ " # in httpproxy mode and only after the 'listen' option has been\n"
+ " # specified. You also have to specify the tunnel listen port.\n"
+ " clientauth:requesting-client-cert.host.com:443:4443:request\n"
+ " clientauth:requiring-client-cert.host.com:443:4443:require\n"
+ " # Proxy WebSocket traffic to the server at 127.0.0.1:9999,\n"
+ " # instead of the server specified in the 'forward' option.\n"
+ " websocketserver:127.0.0.1:9999\n",
+ configFilePath));
+ return 1;
+ }
+
+ // create a thread pool to handle connections
+ threads = PR_CreateThreadPool(INITIAL_THREADS * servers.size(),
+ MAX_THREADS * servers.size(),
+ DEFAULT_STACKSIZE);
+ if (!threads) {
+ LOG_ERROR(("Failed to create thread pool\n"));
+ return 1;
+ }
+
+ shutdown_lock = PR_NewLock();
+ if (!shutdown_lock) {
+ LOG_ERROR(("Failed to create lock\n"));
+ PR_ShutdownThreadPool(threads);
+ return 1;
+ }
+ shutdown_condvar = PR_NewCondVar(shutdown_lock);
+ if (!shutdown_condvar) {
+ LOG_ERROR(("Failed to create condvar\n"));
+ PR_ShutdownThreadPool(threads);
+ PR_DestroyLock(shutdown_lock);
+ return 1;
+ }
+
+ PK11_SetPasswordFunc(password_func);
+
+ // Initialize NSS
+ if (NSS_Init(nssconfigdir.c_str()) != SECSuccess) {
+ int32_t errorlen = PR_GetErrorTextLength();
+ if (errorlen) {
+ auto err = mozilla::MakeUnique<char[]>(errorlen + 1);
+ PR_GetErrorText(err.get());
+ LOG_ERROR(("Failed to init NSS: %s", err.get()));
+ } else {
+ LOG_ERROR(("Failed to init NSS: Cannot get error from NSPR."));
+ }
+ PR_ShutdownThreadPool(threads);
+ PR_DestroyCondVar(shutdown_condvar);
+ PR_DestroyLock(shutdown_lock);
+ return 1;
+ }
+
+ if (NSS_SetDomesticPolicy() != SECSuccess) {
+ LOG_ERROR(("NSS_SetDomesticPolicy failed\n"));
+ PR_ShutdownThreadPool(threads);
+ PR_DestroyCondVar(shutdown_condvar);
+ PR_DestroyLock(shutdown_lock);
+ NSS_Shutdown();
+ return 1;
+ }
+
+ // these values should make NSS use the defaults
+ if (SSL_ConfigServerSessionIDCache(0, 0, 0, nullptr) != SECSuccess) {
+ LOG_ERROR(("SSL_ConfigServerSessionIDCache failed\n"));
+ PR_ShutdownThreadPool(threads);
+ PR_DestroyCondVar(shutdown_condvar);
+ PR_DestroyLock(shutdown_lock);
+ NSS_Shutdown();
+ return 1;
+ }
+
+ for (vector<server_info_t>::iterator it = servers.begin();
+ it != servers.end(); it++) {
+ // Not actually using this PRJob*...
+ // PRJob* server_job =
+ PR_QueueJob(threads, StartServer, &(*it), true);
+ }
+ // now wait for someone to tell us to quit
+ PR_Lock(shutdown_lock);
+ PR_WaitCondVar(shutdown_condvar, PR_INTERVAL_NO_TIMEOUT);
+ PR_Unlock(shutdown_lock);
+ shutdown_server = true;
+ LOG_INFO(("Shutting down...\n"));
+ // cleanup
+ PR_ShutdownThreadPool(threads);
+ PR_JoinThreadPool(threads);
+ PR_DestroyCondVar(shutdown_condvar);
+ PR_DestroyLock(shutdown_lock);
+ if (NSS_Shutdown() == SECFailure) {
+ LOG_DEBUG(("Leaked NSS objects!\n"));
+ }
+
+ for (vector<server_info_t>::iterator it = servers.begin();
+ it != servers.end(); it++)
+ {
+ PL_HashTableEnumerateEntries(it->host_cert_table, freeHostCertHashItems, nullptr);
+ PL_HashTableEnumerateEntries(it->host_clientauth_table, freeClientAuthHashItems, nullptr);
+ PL_HashTableEnumerateEntries(it->host_redir_table, freeHostRedirHashItems, nullptr);
+ PL_HashTableEnumerateEntries(it->host_ssl3_table, freeSSL3HashItems, nullptr);
+ PL_HashTableEnumerateEntries(it->host_tls1_table, freeTLS1HashItems, nullptr);
+ PL_HashTableEnumerateEntries(it->host_rc4_table, freeRC4HashItems, nullptr);
+ PL_HashTableEnumerateEntries(it->host_failhandshake_table, freeRC4HashItems, nullptr);
+ PL_HashTableDestroy(it->host_cert_table);
+ PL_HashTableDestroy(it->host_clientauth_table);
+ PL_HashTableDestroy(it->host_redir_table);
+ PL_HashTableDestroy(it->host_ssl3_table);
+ PL_HashTableDestroy(it->host_tls1_table);
+ PL_HashTableDestroy(it->host_rc4_table);
+ PL_HashTableDestroy(it->host_failhandshake_table);
+ }
+
+ PR_Cleanup();
+ return 0;
+}