diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/libssl_internals.c')
-rw-r--r-- | security/nss/gtests/ssl_gtest/libssl_internals.c | 139 |
1 files changed, 55 insertions, 84 deletions
diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.c b/security/nss/gtests/ssl_gtest/libssl_internals.c index 97b8354ae..887d85278 100644 --- a/security/nss/gtests/ssl_gtest/libssl_internals.c +++ b/security/nss/gtests/ssl_gtest/libssl_internals.c @@ -34,18 +34,17 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, return SECFailure; } - ssl3_InitState(ss); ssl3_RestartHandshakeHashes(ss); // Ensure we don't overrun hs.client_random. rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len); - // Zero the client_random struct. - PORT_Memset(&ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH); + // Zero the client_random. + PORT_Memset(ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH); // Copy over the challenge bytes. size_t offset = SSL3_RANDOM_LENGTH - rnd_len; - PORT_Memcpy(&ss->ssl3.hs.client_random.rand[offset], rnd, rnd_len); + PORT_Memcpy(ss->ssl3.hs.client_random + offset, rnd, rnd_len); // Rehash the SSLv2 client hello message. return ssl3_UpdateHandshakeHashes(ss, msg, msg_len); @@ -73,10 +72,11 @@ SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) { return SECFailure; } ss->ssl3.mtu = mtu; + ss->ssl3.hs.rtRetries = 0; /* Avoid DTLS shrinking the MTU any more. */ return SECSuccess; } -PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) { +PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd) { PRCList *cur_p; PRInt32 ct = 0; @@ -92,7 +92,7 @@ PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) { return ct; } -void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) { +void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd) { PRCList *cur_p; sslSocket *ss = ssl_FindSocket(fd); @@ -100,27 +100,31 @@ void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) { return; } - fprintf(stderr, "Cipher specs\n"); + fprintf(stderr, "Cipher specs for %s\n", label); for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs); cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) { ssl3CipherSpec *spec = (ssl3CipherSpec *)cur_p; - fprintf(stderr, " %s\n", spec->phase); + fprintf(stderr, " %s spec epoch=%d (%s) refct=%d\n", SPEC_DIR(spec), + spec->epoch, spec->phase, spec->refCt); } } -/* Force a timer expiry by backdating when the timer was started. - * We could set the remaining time to 0 but then backoff would not - * work properly if we decide to test it. */ -void SSLInt_ForceTimerExpiry(PRFileDesc *fd) { +/* Force a timer expiry by backdating when all active timers were started. We + * could set the remaining time to 0 but then backoff would not work properly if + * we decide to test it. */ +SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift) { + size_t i; sslSocket *ss = ssl_FindSocket(fd); if (!ss) { - return; + return SECFailure; } - if (!ss->ssl3.hs.rtTimerCb) return; - - ss->ssl3.hs.rtTimerStarted = - PR_IntervalNow() - PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs + 1); + for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) { + if (ss->ssl3.hs.timers[i].cb) { + ss->ssl3.hs.timers[i].started -= shift; + } + } + return SECSuccess; } #define CHECK_SECRET(secret) \ @@ -136,7 +140,6 @@ PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) { } CHECK_SECRET(currentSecret); - CHECK_SECRET(resumptionMasterSecret); CHECK_SECRET(dheSecret); CHECK_SECRET(clientEarlyTrafficSecret); CHECK_SECRET(clientHsTrafficSecret); @@ -226,28 +229,7 @@ PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) { return PR_TRUE; } -PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd) { - sslSocket *ss = ssl_FindSocket(fd); - if (!ss) { - return PR_FALSE; - } - - ssl_GetSSL3HandshakeLock(ss); - ssl_GetXmitBufLock(ss); - - SECStatus rv = tls13_SendNewSessionTicket(ss); - if (rv == SECSuccess) { - rv = ssl3_FlushHandshake(ss, 0); - } - - ssl_ReleaseXmitBufLock(ss); - ssl_ReleaseSSL3HandshakeLock(ss); - - return rv == SECSuccess; -} - SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { - PRUint64 epoch; sslSocket *ss; ssl3CipherSpec *spec; @@ -255,43 +237,40 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { if (!ss) { return SECFailure; } - if (to >= (1ULL << 48)) { + if (to >= RECORD_SEQ_MAX) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } ssl_GetSpecWriteLock(ss); spec = ss->ssl3.crSpec; - epoch = spec->read_seq_num >> 48; - spec->read_seq_num = (epoch << 48) | to; + spec->seqNum = to; /* For DTLS, we need to fix the record sequence number. For this, we can just * scrub the entire structure on the assumption that the new sequence number * is far enough past the last received sequence number. */ - if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { + if (spec->seqNum <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } - dtls_RecordSetRecvd(&spec->recvdRecords, to); + dtls_RecordSetRecvd(&spec->recvdRecords, spec->seqNum); ssl_ReleaseSpecWriteLock(ss); return SECSuccess; } SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) { - PRUint64 epoch; sslSocket *ss; ss = ssl_FindSocket(fd); if (!ss) { return SECFailure; } - if (to >= (1ULL << 48)) { + if (to >= RECORD_SEQ_MAX) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } ssl_GetSpecWriteLock(ss); - epoch = ss->ssl3.cwSpec->write_seq_num >> 48; - ss->ssl3.cwSpec->write_seq_num = (epoch << 48) | to; + ss->ssl3.cwSpec->seqNum = to; ssl_ReleaseSpecWriteLock(ss); return SECSuccess; } @@ -305,9 +284,9 @@ SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) { return SECFailure; } ssl_GetSpecReadLock(ss); - to = ss->ssl3.cwSpec->write_seq_num + DTLS_RECVD_RECORDS_WINDOW + extra; + to = ss->ssl3.cwSpec->seqNum + DTLS_RECVD_RECORDS_WINDOW + extra; ssl_ReleaseSpecReadLock(ss); - return SSLInt_AdvanceWriteSeqNum(fd, to & RECORD_SEQ_MAX); + return SSLInt_AdvanceWriteSeqNum(fd, to); } SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) { @@ -333,46 +312,20 @@ SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd, return SECSuccess; } -static ssl3KeyMaterial *GetKeyingMaterial(PRBool isServer, - ssl3CipherSpec *spec) { - return isServer ? &spec->server : &spec->client; +PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec) { + return spec->keyMaterial.key; } -PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec) { - return GetKeyingMaterial(isServer, spec)->write_key; +SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec) { + return spec->cipherDef->calg; } -SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer, - ssl3CipherSpec *spec) { - return spec->cipher_def->calg; +const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec) { + return spec->keyMaterial.iv; } -unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec) { - return GetKeyingMaterial(isServer, spec)->write_iv; -} - -SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd) { - sslSocket *ss; - - ss = ssl_FindSocket(fd); - if (!ss) { - return SECFailure; - } - - ss->opt.enableShortHeaders = PR_TRUE; - return SECSuccess; -} - -SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result) { - sslSocket *ss; - - ss = ssl_FindSocket(fd); - if (!ss) { - return SECFailure; - } - - *result = ss->ssl3.hs.shortHeaders; - return SECSuccess; +PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec) { + return spec->epoch; } void SSLInt_SetTicketLifetime(uint32_t lifetime) { @@ -405,3 +358,21 @@ SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) { return SECSuccess; } + +void SSLInt_RolloverAntiReplay(void) { + tls13_AntiReplayRollover(ssl_TimeUsec()); +} + +SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch, + PRUint16 *writeEpoch) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss || !readEpoch || !writeEpoch) { + return SECFailure; + } + + ssl_GetSpecReadLock(ss); + *readEpoch = ss->ssl3.crSpec->epoch; + *writeEpoch = ss->ssl3.cwSpec->epoch; + ssl_ReleaseSpecReadLock(ss); + return SECSuccess; +} |