diff options
Diffstat (limited to 'security/nss/cmd/tstclnt')
-rw-r--r-- | security/nss/cmd/tstclnt/tstclnt.c | 301 |
1 files changed, 236 insertions, 65 deletions
diff --git a/security/nss/cmd/tstclnt/tstclnt.c b/security/nss/cmd/tstclnt/tstclnt.c index 959afec59..1ad99502b 100644 --- a/security/nss/cmd/tstclnt/tstclnt.c +++ b/security/nss/cmd/tstclnt/tstclnt.c @@ -31,6 +31,7 @@ #include "ocsp.h" #include "ssl.h" #include "sslproto.h" +#include "sslexp.h" #include "pk11func.h" #include "secmod.h" #include "plgetopt.h" @@ -95,6 +96,7 @@ PRBool verbose; int dumpServerChain = 0; int renegotiationsToDo = 0; int renegotiationsDone = 0; +PRBool initializedServerSessionCache = PR_FALSE; static char *progName; @@ -178,7 +180,7 @@ PrintUsageHeader(const char *progName) "[-n nickname] [-Bafosvx] [-c ciphers] [-Y] [-Z]\n" "[-V [min-version]:[max-version]] [-K] [-T] [-U]\n" "[-r N] [-w passwd] [-W pwfile] [-q [-t seconds]] [-I groups]\n" - "[-A requestfile] [-L totalconnections]\n" + "[-A requestfile] [-L totalconnections] [-P {client,server}] [-Q]\n" "\n", progName); } @@ -202,7 +204,7 @@ PrintParameterUsage(void) fprintf(stderr, "%-20s Print certificate chain information\n", "-C"); fprintf(stderr, "%-20s (use -C twice to print more certificate details)\n", ""); fprintf(stderr, "%-20s (use -C three times to include PEM format certificate dumps)\n", ""); - fprintf(stderr, "%-20s Nickname of key and cert for client auth\n", + fprintf(stderr, "%-20s Nickname of key and cert\n", "-n nickname"); fprintf(stderr, "%-20s Restricts the set of enabled SSL/TLS protocols versions.\n" @@ -251,6 +253,9 @@ PrintParameterUsage(void) "%-20s The following values are valid:\n" "%-20s P256, P384, P521, x25519, FF2048, FF3072, FF4096, FF6144, FF8192\n", "-I", "", ""); + fprintf(stderr, "%-20s Enable alternative TLS 1.3 handshake\n", "-X alt-server-hello"); + fprintf(stderr, "%-20s Use DTLS\n", "-P {client, server}"); + fprintf(stderr, "%-20s Exit after handshake\n", "-Q"); } static void @@ -914,6 +919,12 @@ char *requestString = NULL; PRInt32 requestStringLen = 0; PRBool requestSent = PR_FALSE; PRBool enableZeroRtt = PR_FALSE; +PRBool enableAltServerHello = PR_FALSE; +PRBool useDTLS = PR_FALSE; +PRBool actAsServer = PR_FALSE; +PRBool stopAfterHandshake = PR_FALSE; +PRBool requestToExit = PR_FALSE; +char *versionString = NULL; static int writeBytesToServer(PRFileDesc *s, const char *buf, int nb) @@ -996,12 +1007,129 @@ handshakeCallback(PRFileDesc *fd, void *client_data) writeBytesToServer(fd, requestString, requestStringLen); } } + if (stopAfterHandshake) { + requestToExit = PR_TRUE; + } } #define REQUEST_WAITING (requestString && !requestSent) +static SECStatus +installServerCertificate(PRFileDesc *s, char *nickname) +{ + CERTCertificate *cert; + SECKEYPrivateKey *privKey = NULL; + + if (!nickname) { + PORT_SetError(SEC_ERROR_INVALID_ARGS); + return SECFailure; + } + + cert = PK11_FindCertFromNickname(nickname, &pwdata); + if (cert == NULL) { + return SECFailure; + } + + privKey = PK11_FindKeyByAnyCert(cert, &pwdata); + if (privKey == NULL) { + return SECFailure; + } + if (SSL_ConfigServerCert(s, cert, privKey, NULL, 0) != SECSuccess) { + return SECFailure; + } + SECKEY_DestroyPrivateKey(privKey); + CERT_DestroyCertificate(cert); + + return SECSuccess; +} + +static SECStatus +bindToClient(PRFileDesc *s) +{ + PRStatus status; + status = PR_Bind(s, &addr); + if (status != PR_SUCCESS) { + return SECFailure; + } + + for (;;) { + /* Bind the remote address on first packet. This must happen + * before we SSL-ize the socket because we need to get the + * peer's address before SSLizing. Recvfrom gives us that + * while not consuming any data. */ + unsigned char tmp; + PRNetAddr remote; + int nb; + + nb = PR_RecvFrom(s, &tmp, 1, PR_MSG_PEEK, + &remote, PR_INTERVAL_NO_TIMEOUT); + if (nb != 1) + continue; + + status = PR_Connect(s, &remote, PR_INTERVAL_NO_TIMEOUT); + if (status != PR_SUCCESS) { + SECU_PrintError(progName, "server bind to remote end failed"); + return SECFailure; + } + return SECSuccess; + } + + /* Unreachable. */ +} + +static SECStatus +connectToServer(PRFileDesc *s, PRPollDesc *pollset) +{ + PRStatus status; + PRInt32 filesReady; + + status = PR_Connect(s, &addr, PR_INTERVAL_NO_TIMEOUT); + if (status != PR_SUCCESS) { + if (PR_GetError() == PR_IN_PROGRESS_ERROR) { + if (verbose) + SECU_PrintError(progName, "connect"); + milliPause(50 * multiplier); + pollset[SSOCK_FD].in_flags = PR_POLL_WRITE | PR_POLL_EXCEPT; + pollset[SSOCK_FD].out_flags = 0; + pollset[SSOCK_FD].fd = s; + while (1) { + FPRINTF(stderr, + "%s: about to call PR_Poll for connect completion!\n", + progName); + filesReady = PR_Poll(pollset, 1, PR_INTERVAL_NO_TIMEOUT); + if (filesReady < 0) { + SECU_PrintError(progName, "unable to connect (poll)"); + return SECFailure; + } + FPRINTF(stderr, + "%s: PR_Poll returned 0x%02x for socket out_flags.\n", + progName, pollset[SSOCK_FD].out_flags); + if (filesReady == 0) { /* shouldn't happen! */ + SECU_PrintError(progName, "%s: PR_Poll returned zero!\n"); + return SECFailure; + } + status = PR_GetConnectStatus(pollset); + if (status == PR_SUCCESS) { + break; + } + if (PR_GetError() != PR_IN_PROGRESS_ERROR) { + SECU_PrintError(progName, "unable to connect (poll)"); + return SECFailure; + } + SECU_PrintError(progName, "poll"); + milliPause(50 * multiplier); + } + } else { + SECU_PrintError(progName, "unable to connect"); + return SECFailure; + } + } + + return SECSuccess; +} + static int -run_client(void) +run(void) { int headerSeparatorPtrnId = 0; int error = 0; @@ -1017,13 +1145,23 @@ run_client(void) requestSent = PR_FALSE; /* Create socket */ - s = PR_OpenTCPSocket(addr.raw.family); + if (useDTLS) { + s = PR_OpenUDPSocket(addr.raw.family); + } else { + s = PR_OpenTCPSocket(addr.raw.family); + } + if (s == NULL) { SECU_PrintError(progName, "error creating socket"); error = 1; goto done; } + if (actAsServer) { + if (bindToClient(s) != SECSuccess) { + return 1; + } + } opt.option = PR_SockOpt_Nonblocking; opt.value.non_blocking = PR_TRUE; /* default */ if (serverCertAuth.testFreshStatusFromSideChannel) { @@ -1036,13 +1174,16 @@ run_client(void) goto done; } - s = SSL_ImportFD(NULL, s); + if (useDTLS) { + s = DTLS_ImportFD(NULL, s); + } else { + s = SSL_ImportFD(NULL, s); + } if (s == NULL) { SECU_PrintError(progName, "error importing socket"); error = 1; goto done; } - SSL_SetPKCS11PinArg(s, &pwdata); rv = SSL_OptionSet(s, SSL_SECURITY, 1); @@ -1052,7 +1193,7 @@ run_client(void) goto done; } - rv = SSL_OptionSet(s, SSL_HANDSHAKE_AS_CLIENT, 1); + rv = SSL_OptionSet(s, actAsServer ? SSL_HANDSHAKE_AS_SERVER : SSL_HANDSHAKE_AS_CLIENT, 1); if (rv != SECSuccess) { SECU_PrintError(progName, "error enabling client handshake"); error = 1; @@ -1178,6 +1319,16 @@ run_client(void) } } + /* Alternate ServerHello content type (TLS 1.3 only) */ + if (enableAltServerHello) { + rv = SSL_UseAltServerHelloType(s, PR_TRUE); + if (rv != SECSuccess) { + SECU_PrintError(progName, "error enabling alternate ServerHello type"); + error = 1; + goto done; + } + } + /* require the use of fixed finite-field DH groups */ if (requireDHNamedGroups) { rv = SSL_OptionSet(s, SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); @@ -1212,7 +1363,21 @@ run_client(void) if (override) { SSL_BadCertHook(s, ownBadCertHandler, NULL); } - SSL_GetClientAuthDataHook(s, own_GetClientAuthData, (void *)nickname); + if (actAsServer) { + rv = installServerCertificate(s, nickname); + if (rv != SECSuccess) { + SECU_PrintError(progName, "error installing server cert"); + return 1; + } + rv = SSL_ConfigServerSessionIDCache(1024, 0, 0, "."); + if (rv != SECSuccess) { + SECU_PrintError(progName, "error configuring session cache"); + return 1; + } + initializedServerSessionCache = PR_TRUE; + } else { + SSL_GetClientAuthDataHook(s, own_GetClientAuthData, (void *)nickname); + } SSL_HandshakeCallback(s, handshakeCallback, hs2SniHostName); if (hs1SniHostName) { SSL_SetURL(s, hs1SniHostName); @@ -1220,56 +1385,27 @@ run_client(void) SSL_SetURL(s, host); } - /* Try to connect to the server */ - status = PR_Connect(s, &addr, PR_INTERVAL_NO_TIMEOUT); - if (status != PR_SUCCESS) { - if (PR_GetError() == PR_IN_PROGRESS_ERROR) { - if (verbose) - SECU_PrintError(progName, "connect"); - milliPause(50 * multiplier); - pollset[SSOCK_FD].in_flags = PR_POLL_WRITE | PR_POLL_EXCEPT; - pollset[SSOCK_FD].out_flags = 0; - pollset[SSOCK_FD].fd = s; - while (1) { - FPRINTF(stderr, - "%s: about to call PR_Poll for connect completion!\n", - progName); - filesReady = PR_Poll(pollset, 1, PR_INTERVAL_NO_TIMEOUT); - if (filesReady < 0) { - SECU_PrintError(progName, "unable to connect (poll)"); - error = 1; - goto done; - } - FPRINTF(stderr, - "%s: PR_Poll returned 0x%02x for socket out_flags.\n", - progName, pollset[SSOCK_FD].out_flags); - if (filesReady == 0) { /* shouldn't happen! */ - FPRINTF(stderr, "%s: PR_Poll returned zero!\n", progName); - error = 1; - goto done; - } - status = PR_GetConnectStatus(pollset); - if (status == PR_SUCCESS) { - break; - } - if (PR_GetError() != PR_IN_PROGRESS_ERROR) { - SECU_PrintError(progName, "unable to connect (poll)"); - error = 1; - goto done; - } - SECU_PrintError(progName, "poll"); - milliPause(50 * multiplier); - } - } else { - SECU_PrintError(progName, "unable to connect"); + if (actAsServer) { + rv = SSL_ResetHandshake(s, PR_TRUE /* server */); + if (rv != SECSuccess) { + return 1; + } + } else { + /* Try to connect to the server */ + rv = connectToServer(s, pollset); + if (rv != SECSuccess) { + ; error = 1; goto done; } } pollset[SSOCK_FD].fd = s; - pollset[SSOCK_FD].in_flags = PR_POLL_EXCEPT | - (clientSpeaksFirst ? 0 : PR_POLL_READ); + pollset[SSOCK_FD].in_flags = PR_POLL_EXCEPT; + if (!actAsServer) + pollset[SSOCK_FD].in_flags |= (clientSpeaksFirst ? 0 : PR_POLL_READ); + else + pollset[SSOCK_FD].in_flags |= PR_POLL_READ; pollset[STDIN_FD].fd = PR_GetSpecialFD(PR_StandardInput); if (!REQUEST_WAITING) { pollset[STDIN_FD].in_flags = PR_POLL_READ; @@ -1319,9 +1455,11 @@ run_client(void) ** Select on stdin and on the socket. Write data from stdin to ** socket, read data from socket and write to stdout. */ + requestToExit = PR_FALSE; FPRINTF(stderr, "%s: ready...\n", progName); - while ((pollset[SSOCK_FD].in_flags | pollset[STDIN_FD].in_flags) || - REQUEST_WAITING) { + while (!requestToExit && + ((pollset[SSOCK_FD].in_flags | pollset[STDIN_FD].in_flags) || + REQUEST_WAITING)) { char buf[4000]; /* buffer for stdin */ int nb; /* num bytes read from stdin. */ @@ -1507,12 +1645,10 @@ main(int argc, char **argv) } } - SSL_VersionRangeGetSupported(ssl_variant_stream, &enabledVersions); - /* XXX: 'B' was used in the past but removed in 3.28, * please leave some time before resuing it. */ optstate = PL_CreateOptState(argc, argv, - "46A:CDFGHI:KL:M:OR:STUV:W:YZa:bc:d:fgh:m:n:op:qr:st:uvw:z"); + "46A:CDFGHI:KL:M:OP:QR:STUV:W:X:YZa:bc:d:fgh:m:n:op:qr:st:uvw:z"); while ((optstatus = PL_GetNextOpt(optstate)) == PL_OPT_OK) { switch (optstate->option) { case '?': @@ -1593,6 +1729,21 @@ main(int argc, char **argv) }; break; + case 'P': + useDTLS = PR_TRUE; + if (!strcmp(optstate->value, "server")) { + actAsServer = 1; + } else { + if (strcmp(optstate->value, "client")) { + Usage(progName); + } + } + break; + + case 'Q': + stopAfterHandshake = PR_TRUE; + break; + case 'R': rootModule = PORT_Strdup(optstate->value); break; @@ -1610,14 +1761,16 @@ main(int argc, char **argv) break; case 'V': - if (SECU_ParseSSLVersionRangeString(optstate->value, - enabledVersions, &enabledVersions) != - SECSuccess) { - fprintf(stderr, "Bad version specified.\n"); + versionString = PORT_Strdup(optstate->value); + break; + + case 'X': + if (!strcmp(optstate->value, "alt-server-hello")) { + enableAltServerHello = PR_TRUE; + } else { Usage(progName); } break; - case 'Y': PrintCipherUsage(progName); exit(0); @@ -1727,9 +1880,20 @@ main(int argc, char **argv) break; } } - PL_DestroyOptState(optstate); + SSL_VersionRangeGetSupported(useDTLS ? ssl_variant_datagram : ssl_variant_stream, &enabledVersions); + + if (versionString) { + if (SECU_ParseSSLVersionRangeString(versionString, + enabledVersions, &enabledVersions) != + SECSuccess) { + fprintf(stderr, "Bad version specified.\n"); + Usage(progName); + } + PORT_Free(versionString); + } + if (optstatus == PL_OPT_BAD) { Usage(progName); } @@ -1758,7 +1922,7 @@ main(int argc, char **argv) PR_Init(PR_SYSTEM_THREAD, PR_PRIORITY_NORMAL, 1); PK11_SetPasswordFunc(SECU_GetModulePassword); - + memset(&addr, 0, sizeof(addr)); status = PR_StringToNetAddr(host, &addr); if (status == PR_SUCCESS) { addr.inet.port = PR_htons(portno); @@ -1770,6 +1934,7 @@ main(int argc, char **argv) addrInfo = PR_GetAddrInfoByName(host, PR_AF_UNSPEC, PR_AI_ADDRCONFIG | PR_AI_NOCANONNAME); if (!addrInfo) { + fprintf(stderr, "HOSTNAME=%s\n", host); SECU_PrintError(progName, "error looking up host"); error = 1; goto done; @@ -1884,7 +2049,7 @@ main(int argc, char **argv) } while (numConnections--) { - error = run_client(); + error = run(); if (error) { goto done; } @@ -1915,6 +2080,12 @@ done: } if (NSS_IsInitialized()) { SSL_ClearSessionCache(); + if (initializedServerSessionCache) { + if (SSL_ShutdownServerSessionIDCache() != SECSuccess) { + error = 1; + } + } + if (NSS_Shutdown() != SECSuccess) { error = 1; } |