summaryrefslogtreecommitdiffstats
path: root/security/nss/cmd/tstclnt/tstclnt.c
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/cmd/tstclnt/tstclnt.c')
-rw-r--r--security/nss/cmd/tstclnt/tstclnt.c225
1 files changed, 101 insertions, 124 deletions
diff --git a/security/nss/cmd/tstclnt/tstclnt.c b/security/nss/cmd/tstclnt/tstclnt.c
index 1ad99502b..6f5a43146 100644
--- a/security/nss/cmd/tstclnt/tstclnt.c
+++ b/security/nss/cmd/tstclnt/tstclnt.c
@@ -51,6 +51,7 @@
#define MAX_WAIT_FOR_SERVER 600
#define WAIT_INTERVAL 100
+#define ZERO_RTT_MAX (2 << 16)
#define EXIT_CODE_HANDSHAKE_FAILED 254
@@ -99,6 +100,7 @@ int renegotiationsDone = 0;
PRBool initializedServerSessionCache = PR_FALSE;
static char *progName;
+static const char *requestFile;
secuPWData pwdata = { PW_NONE, 0 };
@@ -172,7 +174,7 @@ printSecurityInfo(PRFileDesc *fd)
}
static void
-PrintUsageHeader(const char *progName)
+PrintUsageHeader()
{
fprintf(stderr,
"Usage: %s -h host [-a 1st_hs_name ] [-a 2nd_hs_name ] [-p port]\n"
@@ -186,7 +188,7 @@ PrintUsageHeader(const char *progName)
}
static void
-PrintParameterUsage(void)
+PrintParameterUsage()
{
fprintf(stderr, "%-20s Send different SNI name. 1st_hs_name - at first\n"
"%-20s handshake, 2nd_hs_name - at second handshake.\n"
@@ -259,17 +261,17 @@ PrintParameterUsage(void)
}
static void
-Usage(const char *progName)
+Usage()
{
- PrintUsageHeader(progName);
+ PrintUsageHeader();
PrintParameterUsage();
exit(1);
}
static void
-PrintCipherUsage(const char *progName)
+PrintCipherUsage()
{
- PrintUsageHeader(progName);
+ PrintUsageHeader();
fprintf(stderr, "%-20s Letter(s) chosen from the following list\n",
"-c ciphers");
fprintf(stderr,
@@ -303,7 +305,7 @@ milliPause(PRUint32 milli)
}
void
-disableAllSSLCiphers(void)
+disableAllSSLCiphers()
{
const PRUint16 *cipherSuites = SSL_GetImplementedCiphers();
int i = SSL_GetNumImplementedCiphers();
@@ -711,12 +713,18 @@ void
thread_main(void *arg)
{
PRFileDesc *ps = (PRFileDesc *)arg;
- PRFileDesc *std_in = PR_GetSpecialFD(PR_StandardInput);
+ PRFileDesc *std_in;
int wc, rc;
char buf[256];
+ if (requestFile) {
+ std_in = PR_Open(requestFile, PR_RDONLY, 0);
+ } else {
+ std_in = PR_GetSpecialFD(PR_StandardInput);
+ }
+
#ifdef WIN32
- {
+ if (!requestFile) {
/* Put stdin into O_BINARY mode
** or else incoming \r\n's will become \n's.
*/
@@ -737,6 +745,9 @@ thread_main(void *arg)
wc = PR_Send(ps, buf, rc, 0, maxInterval);
} while (wc == rc);
PR_Close(ps);
+ if (requestFile) {
+ PR_Close(std_in);
+ }
}
#endif
@@ -844,7 +855,7 @@ separateReqHeader(const PRFileDesc *outFd, const char *buf, const int nb,
} else if (((c) >= 'A') && ((c) <= 'F')) { \
i = (c) - 'A' + 10; \
} else { \
- Usage(progName); \
+ Usage(); \
}
static SECStatus
@@ -915,22 +926,22 @@ char *hs1SniHostName = NULL;
char *hs2SniHostName = NULL;
PRUint16 portno = 443;
int override = 0;
-char *requestString = NULL;
-PRInt32 requestStringLen = 0;
-PRBool requestSent = PR_FALSE;
PRBool enableZeroRtt = PR_FALSE;
+PRUint8 *zeroRttData;
+unsigned int zeroRttLen = 0;
PRBool enableAltServerHello = PR_FALSE;
PRBool useDTLS = PR_FALSE;
PRBool actAsServer = PR_FALSE;
PRBool stopAfterHandshake = PR_FALSE;
PRBool requestToExit = PR_FALSE;
char *versionString = NULL;
+PRBool handshakeComplete = PR_FALSE;
static int
-writeBytesToServer(PRFileDesc *s, const char *buf, int nb)
+writeBytesToServer(PRFileDesc *s, const PRUint8 *buf, int nb)
{
SECStatus rv;
- const char *bufp = buf;
+ const PRUint8 *bufp = buf;
PRPollDesc pollDesc;
pollDesc.in_flags = PR_POLL_WRITE | PR_POLL_EXCEPT;
@@ -944,12 +955,20 @@ writeBytesToServer(PRFileDesc *s, const char *buf, int nb)
if (cc < 0) {
PRErrorCode err = PR_GetError();
if (err != PR_WOULD_BLOCK_ERROR) {
- SECU_PrintError(progName,
- "write to SSL socket failed");
+ SECU_PrintError(progName, "write to SSL socket failed");
return 254;
}
cc = 0;
}
+ FPRINTF(stderr, "%s: %d bytes written\n", progName, cc);
+ if (enableZeroRtt && !handshakeComplete) {
+ if (zeroRttLen + cc > ZERO_RTT_MAX) {
+ SECU_PrintError(progName, "too much early data to save");
+ return -1;
+ }
+ PORT_Memcpy(zeroRttData + zeroRttLen, bufp, cc);
+ zeroRttLen += cc;
+ }
bufp += cc;
nb -= cc;
if (nb <= 0)
@@ -969,8 +988,7 @@ writeBytesToServer(PRFileDesc *s, const char *buf, int nb)
progName);
cc = PR_Poll(&pollDesc, 1, PR_INTERVAL_NO_TIMEOUT);
if (cc < 0) {
- SECU_PrintError(progName,
- "PR_Poll failed");
+ SECU_PrintError(progName, "PR_Poll failed");
return -1;
}
FPRINTF(stderr,
@@ -993,7 +1011,7 @@ handshakeCallback(PRFileDesc *fd, void *client_data)
SSL_ReHandshake(fd, (renegotiationsToDo < 2));
++renegotiationsDone;
}
- if (requestString && requestSent) {
+ if (zeroRttLen) {
/* This data was sent in 0-RTT. */
SSLChannelInfo info;
SECStatus rv;
@@ -1003,29 +1021,30 @@ handshakeCallback(PRFileDesc *fd, void *client_data)
return;
if (!info.earlyDataAccepted) {
- FPRINTF(stderr, "Early data rejected. Re-sending\n");
- writeBytesToServer(fd, requestString, requestStringLen);
+ FPRINTF(stderr, "Early data rejected. Re-sending %d bytes\n",
+ zeroRttLen);
+ writeBytesToServer(fd, zeroRttData, zeroRttLen);
+ zeroRttLen = 0;
}
}
if (stopAfterHandshake) {
requestToExit = PR_TRUE;
}
+ handshakeComplete = PR_TRUE;
}
-#define REQUEST_WAITING (requestString && !requestSent)
-
static SECStatus
-installServerCertificate(PRFileDesc *s, char *nickname)
+installServerCertificate(PRFileDesc *s, char *nick)
{
CERTCertificate *cert;
SECKEYPrivateKey *privKey = NULL;
- if (!nickname) {
+ if (!nick) {
PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
- cert = PK11_FindCertFromNickname(nickname, &pwdata);
+ cert = PK11_FindCertFromNickname(nick, &pwdata);
if (cert == NULL) {
return SECFailure;
}
@@ -1129,20 +1148,19 @@ connectToServer(PRFileDesc *s, PRPollDesc *pollset)
}
static int
-run(void)
+run()
{
int headerSeparatorPtrnId = 0;
int error = 0;
SECStatus rv;
PRStatus status;
PRInt32 filesReady;
- int npds;
PRFileDesc *s = NULL;
PRFileDesc *std_out;
- PRPollDesc pollset[2];
+ PRPollDesc pollset[2] = { { 0 }, { 0 } };
PRBool wrStarted = PR_FALSE;
- requestSent = PR_FALSE;
+ handshakeComplete = PR_FALSE;
/* Create socket */
if (useDTLS) {
@@ -1225,19 +1243,18 @@ run(void)
cipherString++;
} else {
if (!isalpha(ndx))
- Usage(progName);
+ Usage();
ndx = tolower(ndx) - 'a';
if (ndx < PR_ARRAY_SIZE(ssl3CipherSuites)) {
cipher = ssl3CipherSuites[ndx];
}
}
if (cipher > 0) {
- SECStatus status;
- status = SSL_CipherPrefSet(s, cipher, SSL_ALLOWED);
- if (status != SECSuccess)
+ rv = SSL_CipherPrefSet(s, cipher, SSL_ALLOWED);
+ if (rv != SECSuccess)
SECU_PrintError(progName, "SSL_CipherPrefSet()");
} else {
- Usage(progName);
+ Usage();
}
}
PORT_Free(cstringSaved);
@@ -1394,7 +1411,6 @@ run(void)
/* Try to connect to the server */
rv = connectToServer(s, pollset);
if (rv != SECSuccess) {
- ;
error = 1;
goto done;
}
@@ -1406,13 +1422,18 @@ run(void)
pollset[SSOCK_FD].in_flags |= (clientSpeaksFirst ? 0 : PR_POLL_READ);
else
pollset[SSOCK_FD].in_flags |= PR_POLL_READ;
- pollset[STDIN_FD].fd = PR_GetSpecialFD(PR_StandardInput);
- if (!REQUEST_WAITING) {
- pollset[STDIN_FD].in_flags = PR_POLL_READ;
- npds = 2;
+ if (requestFile) {
+ pollset[STDIN_FD].fd = PR_Open(requestFile, PR_RDONLY, 0);
+ if (!pollset[STDIN_FD].fd) {
+ fprintf(stderr, "%s: unable to open input file: %s\n",
+ progName, requestFile);
+ error = 1;
+ goto done;
+ }
} else {
- npds = 1;
+ pollset[STDIN_FD].fd = PR_GetSpecialFD(PR_StandardInput);
}
+ pollset[STDIN_FD].in_flags = PR_POLL_READ;
std_out = PR_GetSpecialFD(PR_StandardOutput);
#if defined(WIN32) || defined(OS2)
@@ -1458,10 +1479,9 @@ run(void)
requestToExit = PR_FALSE;
FPRINTF(stderr, "%s: ready...\n", progName);
while (!requestToExit &&
- ((pollset[SSOCK_FD].in_flags | pollset[STDIN_FD].in_flags) ||
- REQUEST_WAITING)) {
- char buf[4000]; /* buffer for stdin */
- int nb; /* num bytes read from stdin. */
+ (pollset[SSOCK_FD].in_flags || pollset[STDIN_FD].in_flags)) {
+ PRUint8 buf[4000]; /* buffer for stdin */
+ int nb; /* num bytes read from stdin. */
rv = restartHandshakeAfterServerCertIfNeeded(s, &serverCertAuth,
override);
@@ -1475,7 +1495,8 @@ run(void)
pollset[STDIN_FD].out_flags = 0;
FPRINTF(stderr, "%s: about to call PR_Poll !\n", progName);
- filesReady = PR_Poll(pollset, npds, PR_INTERVAL_NO_TIMEOUT);
+ filesReady = PR_Poll(pollset, PR_ARRAY_SIZE(pollset),
+ PR_INTERVAL_NO_TIMEOUT);
if (filesReady < 0) {
SECU_PrintError(progName, "select failed");
error = 1;
@@ -1497,14 +1518,6 @@ run(void)
"%s: PR_Poll returned 0x%02x for socket out_flags.\n",
progName, pollset[SSOCK_FD].out_flags);
}
- if (REQUEST_WAITING) {
- error = writeBytesToServer(s, requestString, requestStringLen);
- if (error) {
- goto done;
- }
- requestSent = PR_TRUE;
- pollset[SSOCK_FD].in_flags = PR_POLL_READ;
- }
if (pollset[STDIN_FD].out_flags & PR_POLL_READ) {
/* Read from stdin and write to socket */
nb = PR_Read(pollset[STDIN_FD].fd, buf, sizeof(buf));
@@ -1518,6 +1531,8 @@ run(void)
} else if (nb == 0) {
/* EOF on stdin, stop polling stdin for read. */
pollset[STDIN_FD].in_flags = 0;
+ if (actAsServer)
+ requestToExit = PR_TRUE;
} else {
error = writeBytesToServer(s, buf, nb);
if (error) {
@@ -1532,12 +1547,12 @@ run(void)
"%s: PR_Poll returned 0x%02x for socket out_flags.\n",
progName, pollset[SSOCK_FD].out_flags);
}
- if ((pollset[SSOCK_FD].out_flags & PR_POLL_READ) ||
- (pollset[SSOCK_FD].out_flags & PR_POLL_ERR)
#ifdef PR_POLL_HUP
- || (pollset[SSOCK_FD].out_flags & PR_POLL_HUP)
+#define POLL_RECV_FLAGS (PR_POLL_READ | PR_POLL_ERR | PR_POLL_HUP)
+#else
+#define POLL_RECV_FLAGS (PR_POLL_READ | PR_POLL_ERR)
#endif
- ) {
+ if (pollset[SSOCK_FD].out_flags & POLL_RECV_FLAGS) {
/* Read from socket and write to stdout */
nb = PR_Recv(pollset[SSOCK_FD].fd, buf, sizeof buf, 0, maxInterval);
FPRINTF(stderr, "%s: Read from server %d bytes\n", progName, nb);
@@ -1554,7 +1569,7 @@ run(void)
if (skipProtoHeader != PR_TRUE || wrStarted == PR_TRUE) {
PR_Write(std_out, buf, nb);
} else {
- separateReqHeader(std_out, buf, nb, &wrStarted,
+ separateReqHeader(std_out, (char *)buf, nb, &wrStarted,
&headerSeparatorPtrnId);
}
if (verbose)
@@ -1568,42 +1583,10 @@ done:
if (s) {
PR_Close(s);
}
-
- return error;
-}
-
-PRInt32
-ReadFile(const char *filename, char **data)
-{
- char *ret = NULL;
- char buf[8192];
- unsigned int len = 0;
- PRStatus rv;
-
- PRFileDesc *fd = PR_Open(filename, PR_RDONLY, 0);
- if (!fd)
- return -1;
-
- for (;;) {
- rv = PR_Read(fd, buf, sizeof(buf));
- if (rv < 0) {
- PR_Free(ret);
- return rv;
- }
-
- if (!rv)
- break;
-
- ret = PR_Realloc(ret, len + rv);
- if (!ret) {
- return -1;
- }
- PORT_Memcpy(ret + len, buf, rv);
- len += rv;
+ if (requestFile && pollset[STDIN_FD].fd) {
+ PR_Close(pollset[STDIN_FD].fd);
}
-
- *data = ret;
- return len;
+ return error;
}
int
@@ -1653,26 +1636,22 @@ main(int argc, char **argv)
switch (optstate->option) {
case '?':
default:
- Usage(progName);
+ Usage();
break;
case '4':
allowIPv6 = PR_FALSE;
if (!allowIPv4)
- Usage(progName);
+ Usage();
break;
case '6':
allowIPv4 = PR_FALSE;
if (!allowIPv6)
- Usage(progName);
+ Usage();
break;
case 'A':
- requestStringLen = ReadFile(optstate->value, &requestString);
- if (requestStringLen < 0) {
- fprintf(stderr, "Couldn't read file %s\n", optstate->value);
- exit(1);
- }
+ requestFile = PORT_Strdup(optstate->value);
break;
case 'C':
@@ -1735,7 +1714,7 @@ main(int argc, char **argv)
actAsServer = 1;
} else {
if (strcmp(optstate->value, "client")) {
- Usage(progName);
+ Usage();
}
}
break;
@@ -1768,16 +1747,21 @@ main(int argc, char **argv)
if (!strcmp(optstate->value, "alt-server-hello")) {
enableAltServerHello = PR_TRUE;
} else {
- Usage(progName);
+ Usage();
}
break;
case 'Y':
- PrintCipherUsage(progName);
+ PrintCipherUsage();
exit(0);
break;
case 'Z':
enableZeroRtt = PR_TRUE;
+ zeroRttData = PORT_ZAlloc(ZERO_RTT_MAX);
+ if (!zeroRttData) {
+ fprintf(stderr, "Unable to allocate buffer for 0-RTT\n");
+ exit(1);
+ }
break;
case 'a':
@@ -1786,7 +1770,7 @@ main(int argc, char **argv)
} else if (!hs2SniHostName) {
hs2SniHostName = PORT_Strdup(optstate->value);
} else {
- Usage(progName);
+ Usage();
}
break;
@@ -1875,7 +1859,7 @@ main(int argc, char **argv)
if (rv != SECSuccess) {
PL_DestroyOptState(optstate);
fprintf(stderr, "Bad group specified.\n");
- Usage(progName);
+ Usage();
}
break;
}
@@ -1889,18 +1873,18 @@ main(int argc, char **argv)
enabledVersions, &enabledVersions) !=
SECSuccess) {
fprintf(stderr, "Bad version specified.\n");
- Usage(progName);
+ Usage();
}
PORT_Free(versionString);
}
if (optstatus == PL_OPT_BAD) {
- Usage(progName);
+ Usage();
}
if (!host || !portno) {
fprintf(stderr, "%s: parameters -h and -p are mandatory\n", progName);
- Usage(progName);
+ Usage();
}
if (serverCertAuth.testFreshStatusFromSideChannel &&
@@ -2060,20 +2044,13 @@ done:
PR_Close(s);
}
- if (hs1SniHostName) {
- PORT_Free(hs1SniHostName);
- }
- if (hs2SniHostName) {
- PORT_Free(hs2SniHostName);
- }
- if (nickname) {
- PORT_Free(nickname);
- }
- if (pwdata.data) {
- PORT_Free(pwdata.data);
- }
+ PORT_Free((void *)requestFile);
+ PORT_Free(hs1SniHostName);
+ PORT_Free(hs2SniHostName);
+ PORT_Free(nickname);
+ PORT_Free(pwdata.data);
PORT_Free(host);
- PORT_Free(requestString);
+ PORT_Free(zeroRttData);
if (enabledGroups) {
PORT_Free(enabledGroups);