diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/libssl_internals.c')
-rw-r--r-- | security/nss/gtests/ssl_gtest/libssl_internals.c | 139 |
1 files changed, 84 insertions, 55 deletions
diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.c b/security/nss/gtests/ssl_gtest/libssl_internals.c index 887d85278..97b8354ae 100644 --- a/security/nss/gtests/ssl_gtest/libssl_internals.c +++ b/security/nss/gtests/ssl_gtest/libssl_internals.c @@ -34,17 +34,18 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, return SECFailure; } + ssl3_InitState(ss); ssl3_RestartHandshakeHashes(ss); // Ensure we don't overrun hs.client_random. rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len); - // Zero the client_random. - PORT_Memset(ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH); + // Zero the client_random struct. + PORT_Memset(&ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH); // Copy over the challenge bytes. size_t offset = SSL3_RANDOM_LENGTH - rnd_len; - PORT_Memcpy(ss->ssl3.hs.client_random + offset, rnd, rnd_len); + PORT_Memcpy(&ss->ssl3.hs.client_random.rand[offset], rnd, rnd_len); // Rehash the SSLv2 client hello message. return ssl3_UpdateHandshakeHashes(ss, msg, msg_len); @@ -72,11 +73,10 @@ SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) { return SECFailure; } ss->ssl3.mtu = mtu; - ss->ssl3.hs.rtRetries = 0; /* Avoid DTLS shrinking the MTU any more. */ return SECSuccess; } -PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd) { +PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) { PRCList *cur_p; PRInt32 ct = 0; @@ -92,7 +92,7 @@ PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd) { return ct; } -void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd) { +void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) { PRCList *cur_p; sslSocket *ss = ssl_FindSocket(fd); @@ -100,31 +100,27 @@ void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd) { return; } - fprintf(stderr, "Cipher specs for %s\n", label); + fprintf(stderr, "Cipher specs\n"); for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs); cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) { ssl3CipherSpec *spec = (ssl3CipherSpec *)cur_p; - fprintf(stderr, " %s spec epoch=%d (%s) refct=%d\n", SPEC_DIR(spec), - spec->epoch, spec->phase, spec->refCt); + fprintf(stderr, " %s\n", spec->phase); } } -/* Force a timer expiry by backdating when all active timers were started. We - * could set the remaining time to 0 but then backoff would not work properly if - * we decide to test it. */ -SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift) { - size_t i; +/* Force a timer expiry by backdating when the timer was started. + * We could set the remaining time to 0 but then backoff would not + * work properly if we decide to test it. */ +void SSLInt_ForceTimerExpiry(PRFileDesc *fd) { sslSocket *ss = ssl_FindSocket(fd); if (!ss) { - return SECFailure; + return; } - for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) { - if (ss->ssl3.hs.timers[i].cb) { - ss->ssl3.hs.timers[i].started -= shift; - } - } - return SECSuccess; + if (!ss->ssl3.hs.rtTimerCb) return; + + ss->ssl3.hs.rtTimerStarted = + PR_IntervalNow() - PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs + 1); } #define CHECK_SECRET(secret) \ @@ -140,6 +136,7 @@ PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) { } CHECK_SECRET(currentSecret); + CHECK_SECRET(resumptionMasterSecret); CHECK_SECRET(dheSecret); CHECK_SECRET(clientEarlyTrafficSecret); CHECK_SECRET(clientHsTrafficSecret); @@ -229,7 +226,28 @@ PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) { return PR_TRUE; } +PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return PR_FALSE; + } + + ssl_GetSSL3HandshakeLock(ss); + ssl_GetXmitBufLock(ss); + + SECStatus rv = tls13_SendNewSessionTicket(ss); + if (rv == SECSuccess) { + rv = ssl3_FlushHandshake(ss, 0); + } + + ssl_ReleaseXmitBufLock(ss); + ssl_ReleaseSSL3HandshakeLock(ss); + + return rv == SECSuccess; +} + SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { + PRUint64 epoch; sslSocket *ss; ssl3CipherSpec *spec; @@ -237,40 +255,43 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { if (!ss) { return SECFailure; } - if (to >= RECORD_SEQ_MAX) { + if (to >= (1ULL << 48)) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } ssl_GetSpecWriteLock(ss); spec = ss->ssl3.crSpec; - spec->seqNum = to; + epoch = spec->read_seq_num >> 48; + spec->read_seq_num = (epoch << 48) | to; /* For DTLS, we need to fix the record sequence number. For this, we can just * scrub the entire structure on the assumption that the new sequence number * is far enough past the last received sequence number. */ - if (spec->seqNum <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { + if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } - dtls_RecordSetRecvd(&spec->recvdRecords, spec->seqNum); + dtls_RecordSetRecvd(&spec->recvdRecords, to); ssl_ReleaseSpecWriteLock(ss); return SECSuccess; } SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) { + PRUint64 epoch; sslSocket *ss; ss = ssl_FindSocket(fd); if (!ss) { return SECFailure; } - if (to >= RECORD_SEQ_MAX) { + if (to >= (1ULL << 48)) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } ssl_GetSpecWriteLock(ss); - ss->ssl3.cwSpec->seqNum = to; + epoch = ss->ssl3.cwSpec->write_seq_num >> 48; + ss->ssl3.cwSpec->write_seq_num = (epoch << 48) | to; ssl_ReleaseSpecWriteLock(ss); return SECSuccess; } @@ -284,9 +305,9 @@ SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) { return SECFailure; } ssl_GetSpecReadLock(ss); - to = ss->ssl3.cwSpec->seqNum + DTLS_RECVD_RECORDS_WINDOW + extra; + to = ss->ssl3.cwSpec->write_seq_num + DTLS_RECVD_RECORDS_WINDOW + extra; ssl_ReleaseSpecReadLock(ss); - return SSLInt_AdvanceWriteSeqNum(fd, to); + return SSLInt_AdvanceWriteSeqNum(fd, to & RECORD_SEQ_MAX); } SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) { @@ -312,20 +333,46 @@ SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd, return SECSuccess; } -PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec) { - return spec->keyMaterial.key; +static ssl3KeyMaterial *GetKeyingMaterial(PRBool isServer, + ssl3CipherSpec *spec) { + return isServer ? &spec->server : &spec->client; } -SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec) { - return spec->cipherDef->calg; +PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec) { + return GetKeyingMaterial(isServer, spec)->write_key; } -const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec) { - return spec->keyMaterial.iv; +SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer, + ssl3CipherSpec *spec) { + return spec->cipher_def->calg; } -PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec) { - return spec->epoch; +unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec) { + return GetKeyingMaterial(isServer, spec)->write_iv; +} + +SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd) { + sslSocket *ss; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + ss->opt.enableShortHeaders = PR_TRUE; + return SECSuccess; +} + +SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result) { + sslSocket *ss; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + *result = ss->ssl3.hs.shortHeaders; + return SECSuccess; } void SSLInt_SetTicketLifetime(uint32_t lifetime) { @@ -358,21 +405,3 @@ SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) { return SECSuccess; } - -void SSLInt_RolloverAntiReplay(void) { - tls13_AntiReplayRollover(ssl_TimeUsec()); -} - -SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch, - PRUint16 *writeEpoch) { - sslSocket *ss = ssl_FindSocket(fd); - if (!ss || !readEpoch || !writeEpoch) { - return SECFailure; - } - - ssl_GetSpecReadLock(ss); - *readEpoch = ss->ssl3.crSpec->epoch; - *writeEpoch = ss->ssl3.cwSpec->epoch; - ssl_ReleaseSpecReadLock(ss); - return SECSuccess; -} |