From f4a12fc67689a830e9da1c87fd11afe5bc09deb3 Mon Sep 17 00:00:00 2001 From: wolfbeast Date: Thu, 2 Jan 2020 21:06:40 +0100 Subject: Issue #1338 - Part 2: Update NSS to 3.48-RTM --- security/nss/gtests/ssl_gtest/Makefile | 6 + security/nss/gtests/ssl_gtest/libssl_internals.c | 101 ++-- security/nss/gtests/ssl_gtest/libssl_internals.h | 18 +- security/nss/gtests/ssl_gtest/manifest.mn | 7 +- security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc | 410 +++++++++++++- security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc | 622 ++++++++++++++++++++- .../nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc | 14 +- .../gtests/ssl_gtest/ssl_cipherorder_unittest.cc | 241 ++++++++ .../gtests/ssl_gtest/ssl_ciphersuite_unittest.cc | 29 +- .../nss/gtests/ssl_gtest/ssl_damage_unittest.cc | 5 +- .../nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc | 53 ++ security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc | 96 ++++ security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc | 172 +++--- security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc | 74 +++ .../nss/gtests/ssl_gtest/ssl_extension_unittest.cc | 125 ++++- security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc | 63 ++- security/nss/gtests/ssl_gtest/ssl_gtest.gyp | 18 +- .../nss/gtests/ssl_gtest/ssl_keylog_unittest.cc | 112 ++-- .../nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc | 31 + .../nss/gtests/ssl_gtest/ssl_primitive_unittest.cc | 218 ++++++++ .../nss/gtests/ssl_gtest/ssl_record_unittest.cc | 36 ++ .../nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc | 577 +++++++++++++++++++ .../gtests/ssl_gtest/ssl_recordsize_unittest.cc | 57 +- .../gtests/ssl_gtest/ssl_renegotiation_unittest.cc | 23 + .../gtests/ssl_gtest/ssl_resumption_unittest.cc | 190 ++++++- .../nss/gtests/ssl_gtest/ssl_version_unittest.cc | 13 +- .../gtests/ssl_gtest/ssl_versionpolicy_unittest.cc | 10 - security/nss/gtests/ssl_gtest/test_io.cc | 14 +- security/nss/gtests/ssl_gtest/test_io.h | 6 +- security/nss/gtests/ssl_gtest/tls_agent.cc | 107 +++- security/nss/gtests/ssl_gtest/tls_agent.h | 41 +- security/nss/gtests/ssl_gtest/tls_connect.cc | 91 ++- security/nss/gtests/ssl_gtest/tls_connect.h | 20 + security/nss/gtests/ssl_gtest/tls_esni_unittest.cc | 174 +++--- security/nss/gtests/ssl_gtest/tls_filter.cc | 227 +++++--- security/nss/gtests/ssl_gtest/tls_filter.h | 112 +++- security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc | 206 ++++++- security/nss/gtests/ssl_gtest/tls_protect.cc | 179 +++--- security/nss/gtests/ssl_gtest/tls_protect.h | 77 +-- .../nss/gtests/ssl_gtest/tls_subcerts_unittest.cc | 568 +++++++++++++++++++ 40 files changed, 4448 insertions(+), 695 deletions(-) create mode 100644 security/nss/gtests/ssl_gtest/ssl_cipherorder_unittest.cc create mode 100644 security/nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc create mode 100644 security/nss/gtests/ssl_gtest/ssl_primitive_unittest.cc create mode 100644 security/nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc create mode 100644 security/nss/gtests/ssl_gtest/tls_subcerts_unittest.cc (limited to 'security/nss/gtests/ssl_gtest') diff --git a/security/nss/gtests/ssl_gtest/Makefile b/security/nss/gtests/ssl_gtest/Makefile index 95c111aeb..46f030357 100644 --- a/security/nss/gtests/ssl_gtest/Makefile +++ b/security/nss/gtests/ssl_gtest/Makefile @@ -36,6 +36,12 @@ CPPSRCS := $(filter-out $(shell grep -l '^TEST_F' $(CPPSRCS)), $(CPPSRCS)) CFLAGS += -DNSS_DISABLE_TLS_1_3 endif +ifdef NSS_ALLOW_SSLKEYLOGFILE +SSLKEYLOGFILE_FILES = ssl_keylog_unittest.cc +else +SSLKEYLOGFILE_FILES = $(NULL) +endif + ####################################################################### # (5) Execute "global" rules. (OPTIONAL) # ####################################################################### diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.c b/security/nss/gtests/ssl_gtest/libssl_internals.c index e43113de4..44eee9aa8 100644 --- a/security/nss/gtests/ssl_gtest/libssl_internals.c +++ b/security/nss/gtests/ssl_gtest/libssl_internals.c @@ -12,6 +12,48 @@ #include "seccomon.h" #include "selfencrypt.h" +SECStatus SSLInt_TweakChannelInfoForDC(PRFileDesc *fd, PRBool changeAuthKeyBits, + PRBool changeScheme) { + if (!fd) { + return SECFailure; + } + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + // Just toggle so we'll always have a valid value. + if (changeScheme) { + ss->sec.signatureScheme = (ss->sec.signatureScheme == ssl_sig_ed25519) + ? ssl_sig_ecdsa_secp256r1_sha256 + : ssl_sig_ed25519; + } + if (changeAuthKeyBits) { + ss->sec.authKeyBits = ss->sec.authKeyBits ? ss->sec.authKeyBits * 2 : 384; + } + + return SECSuccess; +} + +SECStatus SSLInt_GetHandshakeRandoms(PRFileDesc *fd, SSL3Random client_random, + SSL3Random server_random) { + if (!fd) { + return SECFailure; + } + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + if (client_random) { + memcpy(client_random, ss->ssl3.hs.client_random, sizeof(SSL3Random)); + } + if (server_random) { + memcpy(server_random, ss->ssl3.hs.server_random, sizeof(SSL3Random)); + } + return SECSuccess; +} + SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd) { sslSocket *ss = ssl_FindSocket(fd); if (!ss) { @@ -109,9 +151,10 @@ void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd) { } } -/* Force a timer expiry by backdating when all active timers were started. We - * could set the remaining time to 0 but then backoff would not work properly if - * we decide to test it. */ +/* DTLS timers are separate from the time that the rest of the stack uses. + * Force a timer expiry by backdating when all active timers were started. + * We could set the remaining time to 0 but then backoff would not work properly + * if we decide to test it. */ SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift) { size_t i; sslSocket *ss = ssl_FindSocket(fd); @@ -297,42 +340,6 @@ SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) { return groupDef->keaType; } -SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd, - sslCipherSpecChangedFunc func, - void *arg) { - sslSocket *ss; - - ss = ssl_FindSocket(fd); - if (!ss) { - return SECFailure; - } - - ss->ssl3.changedCipherSpecFunc = func; - ss->ssl3.changedCipherSpecArg = arg; - - return SECSuccess; -} - -PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec) { - return spec->keyMaterial.key; -} - -SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec) { - return spec->cipherDef->calg; -} - -const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec) { - return spec->keyMaterial.iv; -} - -PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec) { - return spec->epoch; -} - -void SSLInt_SetTicketLifetime(uint32_t lifetime) { - ssl_ticket_lifetime = lifetime; -} - SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) { sslSocket *ss; @@ -356,20 +363,14 @@ SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) { return SECSuccess; } -void SSLInt_RolloverAntiReplay(void) { - tls13_AntiReplayRollover(ssl_TimeUsec()); -} - -SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch, - PRUint16 *writeEpoch) { +SECStatus SSLInt_HasPendingHandshakeData(PRFileDesc *fd, PRBool *pending) { sslSocket *ss = ssl_FindSocket(fd); - if (!ss || !readEpoch || !writeEpoch) { + if (!ss) { return SECFailure; } - ssl_GetSpecReadLock(ss); - *readEpoch = ss->ssl3.crSpec->epoch; - *writeEpoch = ss->ssl3.cwSpec->epoch; - ssl_ReleaseSpecReadLock(ss); + ssl_GetSSL3HandshakeLock(ss); + *pending = ss->ssl3.hs.msg_body.len > 0; + ssl_ReleaseSSL3HandshakeLock(ss); return SECSuccess; } diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.h b/security/nss/gtests/ssl_gtest/libssl_internals.h index 3efb362c2..a908c9ab1 100644 --- a/security/nss/gtests/ssl_gtest/libssl_internals.h +++ b/security/nss/gtests/ssl_gtest/libssl_internals.h @@ -20,7 +20,8 @@ SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd); SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, size_t rnd_len, uint8_t *msg, size_t msg_len); - +SECStatus SSLInt_GetHandshakeRandoms(PRFileDesc *fd, SSL3Random client_random, + SSL3Random server_random); PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext); void SSLInt_ClearSelfEncryptKey(); void SSLInt_SetSelfEncryptMacKey(PK11SymKey *key); @@ -39,18 +40,9 @@ SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to); SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to); SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra); SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group); -SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch, - PRUint16 *writeEpoch); - -SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd, - sslCipherSpecChangedFunc func, - void *arg); -PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec); -PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec); -SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec); -const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec); -void SSLInt_SetTicketLifetime(uint32_t lifetime); +SECStatus SSLInt_HasPendingHandshakeData(PRFileDesc *fd, PRBool *pending); SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size); -void SSLInt_RolloverAntiReplay(void); +SECStatus SSLInt_TweakChannelInfoForDC(PRFileDesc *fd, PRBool changeAuthKeyBits, + PRBool changeScheme); #endif // ndef libssl_internals_h_ diff --git a/security/nss/gtests/ssl_gtest/manifest.mn b/security/nss/gtests/ssl_gtest/manifest.mn index 7f4ee7953..ed1128f7c 100644 --- a/security/nss/gtests/ssl_gtest/manifest.mn +++ b/security/nss/gtests/ssl_gtest/manifest.mn @@ -17,9 +17,11 @@ CPPSRCS = \ ssl_agent_unittest.cc \ ssl_auth_unittest.cc \ ssl_cert_ext_unittest.cc \ + ssl_cipherorder_unittest.cc \ ssl_ciphersuite_unittest.cc \ ssl_custext_unittest.cc \ ssl_damage_unittest.cc \ + ssl_debug_env_unittest.cc \ ssl_dhe_unittest.cc \ ssl_drop_unittest.cc \ ssl_ecdh_unittest.cc \ @@ -31,11 +33,12 @@ CPPSRCS = \ ssl_gather_unittest.cc \ ssl_gtest.cc \ ssl_hrr_unittest.cc \ - ssl_keylog_unittest.cc \ ssl_keyupdate_unittest.cc \ ssl_loopback_unittest.cc \ ssl_misc_unittest.cc \ + ssl_primitive_unittest.cc \ ssl_record_unittest.cc \ + ssl_recordsep_unittest.cc \ ssl_recordsize_unittest.cc \ ssl_resumption_unittest.cc \ ssl_renegotiation_unittest.cc \ @@ -52,7 +55,9 @@ CPPSRCS = \ tls_hkdf_unittest.cc \ tls_filter.cc \ tls_protect.cc \ + tls_subcerts_unittest.cc \ tls_esni_unittest.cc \ + $(SSLKEYLOGFILE_FILES) \ $(NULL) INCLUDES += -I$(CORE_DEPTH)/gtests/google_test/gtest/include \ diff --git a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc index 07eadfbd1..928515067 100644 --- a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc @@ -45,11 +45,40 @@ TEST_P(TlsConnectTls13, ZeroRttServerRejectByOption) { SendReceive(); } +TEST_P(TlsConnectTls13, ZeroRttApplicationReject) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + auto reject_0rtt = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast(arg); + *called = true; + + EXPECT_TRUE(firstHello); + EXPECT_EQ(0U, clientTokenLen); + return ssl_hello_retry_reject_0rtt; + }; + + bool cb_run = false; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + reject_0rtt, &cb_run)); + ZeroRttSendReceive(true, false); + Handshake(); + EXPECT_TRUE(cb_run); + CheckConnected(); + SendReceive(); +} + TEST_P(TlsConnectTls13, ZeroRttApparentReplayAfterRestart) { - // The test fixtures call SSL_SetupAntiReplay() in SetUp(). This results in - // 0-RTT being rejected until at least one window passes. SetupFor0Rtt() - // forces a rollover of the anti-replay filters, which clears this state. - // Here, we do the setup manually here without that forced rollover. + // The test fixtures enable anti-replay in SetUp(). This results in 0-RTT + // being rejected until at least one window passes. SetupFor0Rtt() forces a + // rollover of the anti-replay filters, which clears that state and allows + // 0-RTT to work. Make the first connection manually to avoid that rollover + // and cause 0-RTT to be rejected. ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); @@ -106,7 +135,7 @@ class TlsZeroRttReplayTest : public TlsConnectTls13 { SendReceive(); if (rollover) { - SSLInt_RolloverAntiReplay(); + RolloverAntiReplay(); } // Now replay that packet against the server. @@ -184,20 +213,21 @@ TEST_P(TlsConnectTls13, ZeroRttServerOnly) { CheckKeys(); } -// A small sleep after sending the ClientHello means that the ticket age that -// arrives at the server is too low. With a small tolerance for variation in -// ticket age (which is determined by the |window| parameter that is passed to -// SSL_SetupAntiReplay()), the server then rejects early data. +// Advancing time after sending the ClientHello means that the ticket age that +// arrives at the server is too low. The server then rejects early data if this +// delay exceeds half the anti-replay window. TEST_P(TlsConnectTls13, ZeroRttRejectOldTicket) { + static const PRTime kWindow = 10 * PR_USEC_PER_SEC; + ResetAntiReplay(kWindow); SetupForZeroRtt(); + + Reset(); + StartConnect(); client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); - EXPECT_EQ(SECSuccess, SSL_SetupAntiReplay(1, 1, 3)); - SSLInt_RolloverAntiReplay(); // Make sure to flush replay state. - SSLInt_RolloverAntiReplay(); ExpectResumption(RESUME_TICKET); - ZeroRttSendReceive(true, false, []() { - PR_Sleep(PR_MillisecondsToInterval(10)); + ZeroRttSendReceive(true, false, [this]() { + AdvanceTime(1 + kWindow / 2); return true; }); Handshake(); @@ -212,13 +242,15 @@ TEST_P(TlsConnectTls13, ZeroRttRejectOldTicket) { // small tolerance for variation in ticket age and the ticket will appear to // arrive prematurely, causing the server to reject early data. TEST_P(TlsConnectTls13, ZeroRttRejectPrematureTicket) { + static const PRTime kWindow = 10 * PR_USEC_PER_SEC; + ResetAntiReplay(kWindow); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); server_->Set0RttEnabled(true); StartConnect(); client_->Handshake(); // ClientHello server_->Handshake(); // ServerHello - PR_Sleep(PR_MillisecondsToInterval(10)); + AdvanceTime(1 + kWindow / 2); Handshake(); // Remainder of handshake CheckConnected(); SendReceive(); @@ -227,9 +259,6 @@ TEST_P(TlsConnectTls13, ZeroRttRejectPrematureTicket) { Reset(); client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); - EXPECT_EQ(SECSuccess, SSL_SetupAntiReplay(1, 1, 3)); - SSLInt_RolloverAntiReplay(); // Make sure to flush replay state. - SSLInt_RolloverAntiReplay(); ExpectResumption(RESUME_TICKET); ExpectEarlyDataAccepted(false); StartConnect(); @@ -649,6 +678,351 @@ TEST_P(TlsConnectTls13, ZeroRttOrdering) { EXPECT_EQ(2U, step); } +// Early data remains available after the handshake completes for TLS. +TEST_F(TlsConnectStreamTls13, ZeroRttLateReadTls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. + const uint8_t data[] = {1, 2, 3, 4, 5, 6, 7, 8}; + PRInt32 rv = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_EQ(static_cast(sizeof(data)), rv); + + // Consume the ClientHello and generate ServerHello..Finished. + server_->Handshake(); + + // Read some of the data. + std::vector small_buffer(1 + sizeof(data) / 2); + rv = PR_Read(server_->ssl_fd(), small_buffer.data(), small_buffer.size()); + EXPECT_EQ(static_cast(small_buffer.size()), rv); + EXPECT_EQ(0, memcmp(data, small_buffer.data(), small_buffer.size())); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); + + // After the handshake, it should be possible to read the remainder. + uint8_t big_buf[100]; + rv = PR_Read(server_->ssl_fd(), big_buf, sizeof(big_buf)); + EXPECT_EQ(static_cast(sizeof(data) - small_buffer.size()), rv); + EXPECT_EQ(0, memcmp(&data[small_buffer.size()], big_buf, + sizeof(data) - small_buffer.size())); + + // And that's all there is to read. + rv = PR_Read(server_->ssl_fd(), big_buf, sizeof(big_buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +// Early data that arrives before the handshake can be read after the handshake +// is complete. +TEST_F(TlsConnectDatagram13, ZeroRttLateReadDtls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. + const uint8_t data[] = {1, 2, 3}; + PRInt32 written = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_EQ(static_cast(sizeof(data)), written); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); + + // Reading at the server should return the early data, which was buffered. + uint8_t buf[sizeof(data) + 1] = {0}; + PRInt32 read = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(static_cast(sizeof(data)), read); + EXPECT_EQ(0, memcmp(data, buf, sizeof(data))); +} + +class PacketHolder : public PacketFilter { + public: + PacketHolder() = default; + + virtual Action Filter(const DataBuffer& input, DataBuffer* output) { + packet_ = input; + Disable(); + return DROP; + } + + const DataBuffer& packet() const { return packet_; } + + private: + DataBuffer packet_; +}; + +// Early data that arrives late is discarded for DTLS. +TEST_F(TlsConnectDatagram13, ZeroRttLateArrivalDtls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. Twice, so that we can read bits of it. + const uint8_t data[] = {1, 2, 3}; + PRInt32 written = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_EQ(static_cast(sizeof(data)), written); + + // Block and capture the next packet. + auto holder = std::make_shared(); + client_->SetFilter(holder); + written = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_EQ(static_cast(sizeof(data)), written); + EXPECT_FALSE(holder->enabled()) << "the filter should disable itself"; + + // Consume the ClientHello and generate ServerHello..Finished. + server_->Handshake(); + + // Read some of the data. + std::vector small_buffer(sizeof(data)); + PRInt32 read = + PR_Read(server_->ssl_fd(), small_buffer.data(), small_buffer.size()); + + EXPECT_EQ(static_cast(small_buffer.size()), read); + EXPECT_EQ(0, memcmp(data, small_buffer.data(), small_buffer.size())); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); + + server_->SendDirect(holder->packet()); + + // Reading now should return nothing, even though a valid packet was + // delivered. + read = PR_Read(server_->ssl_fd(), small_buffer.data(), small_buffer.size()); + EXPECT_GT(0, read); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +// Early data reads in TLS should be coalesced. +TEST_F(TlsConnectStreamTls13, ZeroRttCoalesceReadTls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. In two writes. + const uint8_t data[] = {1, 2, 3, 4, 5, 6}; + PRInt32 written = PR_Write(client_->ssl_fd(), data, 1); + EXPECT_EQ(1, written); + + written = PR_Write(client_->ssl_fd(), data + 1, sizeof(data) - 1); + EXPECT_EQ(static_cast(sizeof(data) - 1), written); + + // Consume the ClientHello and generate ServerHello..Finished. + server_->Handshake(); + + // Read all of the data. + std::vector buffer(sizeof(data)); + PRInt32 read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size()); + EXPECT_EQ(static_cast(sizeof(data)), read); + EXPECT_EQ(0, memcmp(data, buffer.data(), sizeof(data))); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); +} + +// Early data reads in DTLS should not be coalesced. +TEST_F(TlsConnectDatagram13, ZeroRttNoCoalesceReadDtls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. In two writes. + const uint8_t data[] = {1, 2, 3, 4, 5, 6}; + PRInt32 written = PR_Write(client_->ssl_fd(), data, 1); + EXPECT_EQ(1, written); + + written = PR_Write(client_->ssl_fd(), data + 1, sizeof(data) - 1); + EXPECT_EQ(static_cast(sizeof(data) - 1), written); + + // Consume the ClientHello and generate ServerHello..Finished. + server_->Handshake(); + + // Try to read all of the data. + std::vector buffer(sizeof(data)); + PRInt32 read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size()); + EXPECT_EQ(1, read); + EXPECT_EQ(0, memcmp(data, buffer.data(), 1)); + + // Read the remainder. + read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size()); + EXPECT_EQ(static_cast(sizeof(data) - 1), read); + EXPECT_EQ(0, memcmp(data + 1, buffer.data(), sizeof(data) - 1)); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); +} + +// Early data reads in DTLS should fail if the buffer is too small. +TEST_F(TlsConnectDatagram13, ZeroRttShortReadDtls) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + client_->Handshake(); // ClientHello + + // Write some early data. In two writes. + const uint8_t data[] = {1, 2, 3, 4, 5, 6}; + PRInt32 written = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_EQ(static_cast(sizeof(data)), written); + + // Consume the ClientHello and generate ServerHello..Finished. + server_->Handshake(); + + // Try to read all of the data into a small buffer. + std::vector buffer(sizeof(data)); + PRInt32 read = PR_Read(server_->ssl_fd(), buffer.data(), 1); + EXPECT_GT(0, read); + EXPECT_EQ(SSL_ERROR_RX_SHORT_DTLS_READ, PORT_GetError()); + + // Read again with more space. + read = PR_Read(server_->ssl_fd(), buffer.data(), buffer.size()); + EXPECT_EQ(static_cast(sizeof(data)), read); + EXPECT_EQ(0, memcmp(data, buffer.data(), sizeof(data))); + + Handshake(); // Complete the handshake. + ExpectEarlyDataAccepted(true); + CheckConnected(); +} + +// There are few ways in which TLS uses the clock and most of those operate on +// timescales that would be ridiculous to wait for in a test. This is the one +// test we have that uses the real clock. It tests that time passes by checking +// that a small sleep results in rejection of early data. 0-RTT has a +// configurable timer, which makes it ideal for this. +TEST_F(TlsConnectStreamTls13, TimePassesByDefault) { + // Calling EnsureTlsSetup() replaces the time function on client and server, + // and sets up anti-replay, which we don't want, so initialize each directly. + client_->EnsureTlsSetup(); + server_->EnsureTlsSetup(); + // StartConnect() calls EnsureTlsSetup(), so avoid that too. + client_->StartConnect(); + server_->StartConnect(); + + // Set a tiny anti-replay window. This has to be at least 2 milliseconds to + // have any chance of being relevant as that is the smallest window that we + // can detect. Anything smaller rounds to zero. + static const unsigned int kTinyWindowMs = 5; + ResetAntiReplay(static_cast(kTinyWindowMs * PR_USEC_PER_MSEC)); + server_->SetAntiReplayContext(anti_replay_); + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->Set0RttEnabled(true); + Handshake(); + CheckConnected(); + SendReceive(); // Absorb a session ticket. + CheckKeys(); + + // Clear the first window. + PR_Sleep(PR_MillisecondsToInterval(kTinyWindowMs)); + + Reset(); + client_->EnsureTlsSetup(); + server_->EnsureTlsSetup(); + client_->StartConnect(); + server_->StartConnect(); + + // Early data is rejected by the server only if time passes for it as well. + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false, []() { + // Sleep long enough that we minimize the risk of our RTT estimation being + // duped by stutters in test execution. This is very long to allow for + // flaky and low-end hardware, especially what our CI runs on. + PR_Sleep(PR_MillisecondsToInterval(1000)); + return true; + }); + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); +} + +// Test that SSL_CreateAntiReplayContext doesn't pass bad inputs. +TEST_F(TlsConnectStreamTls13, BadAntiReplayArgs) { + SSLAntiReplayContext* p; + // Zero or negative window. + EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, -1, 1, 1, &p)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 0, 1, 1, &p)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + // Zero k. + EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 1, 0, 1, &p)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + // Zero bits. + EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 1, 1, 0, &p)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(SECFailure, SSL_CreateAntiReplayContext(0, 1, 1, 1, nullptr)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Prove that these parameters do work, even if they are useless.. + EXPECT_EQ(SECSuccess, SSL_CreateAntiReplayContext(0, 1, 1, 1, &p)); + ASSERT_NE(nullptr, p); + ScopedSSLAntiReplayContext ctx(p); + + // The socket isn't a client or server until later, so configuring a client + // should work OK. + client_->EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(client_->ssl_fd(), ctx.get())); + EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(client_->ssl_fd(), nullptr)); +} + +// See also TlsConnectGenericResumption.ResumeServerIncompatibleCipher +TEST_P(TlsConnectTls13, ZeroRttDifferentCompatibleCipher) { + EnsureTlsSetup(); + server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + // Change the ciphersuite. Resumption is OK because the hash is the same, but + // early data will be rejected. + server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + ZeroRttSendReceive(true, false); + + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); + SendReceive(); +} + +// See also TlsConnectGenericResumption.ResumeServerIncompatibleCipher +TEST_P(TlsConnectTls13, ZeroRttDifferentIncompatibleCipher) { + EnsureTlsSetup(); + server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384); + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + // Resumption is rejected because the hash is different. + server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256); + ExpectResumption(RESUME_NONE); + + StartConnect(); + ZeroRttSendReceive(true, false); + + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); + SendReceive(); +} + #ifndef NSS_DISABLE_TLS_1_3 INSTANTIATE_TEST_CASE_P(Tls13ZeroRttReplayTest, TlsZeroRttReplayTest, TlsConnectTestBase::kTlsVariantsAll); diff --git a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc index 3a52ac20c..c1a810d04 100644 --- a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc @@ -176,14 +176,434 @@ TEST_P(TlsConnectGeneric, ClientAuth) { CheckKeys(); } -// In TLS 1.3, the client sends its cert rejection on the -// second flight, and since it has already received the -// server's Finished, it transitions to complete and -// then gets an alert from the server. The test harness -// doesn't handle this right yet. -TEST_P(TlsConnectStream, DISABLED_ClientAuthRequiredRejected) { +class TlsCertificateRequestContextRecorder : public TlsHandshakeFilter { + public: + TlsCertificateRequestContextRecorder(const std::shared_ptr& a, + uint8_t handshake_type) + : TlsHandshakeFilter(a, {handshake_type}), buffer_(), filtered_(false) { + EnableDecryption(); + } + + bool filtered() const { return filtered_; } + const DataBuffer& buffer() const { return buffer_; } + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + assert(1 < input.len()); + size_t len = input.data()[0]; + assert(len + 1 < input.len()); + buffer_.Assign(input.data() + 1, len); + filtered_ = true; + return KEEP; + } + + private: + DataBuffer buffer_; + bool filtered_; +}; + +// All stream only tests; DTLS isn't supported yet. + +TEST_F(TlsConnectStreamTls13, PostHandshakeAuth) { + EnsureTlsSetup(); + auto capture_cert_req = MakeTlsFilter( + server_, kTlsHandshakeCertificateRequest); + auto capture_certificate = + MakeTlsFilter( + client_, kTlsHandshakeCertificate); + client_->SetupClientAuth(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + size_t called = 0; + server_->SetAuthCertificateCallback( + [&called](TlsAgent*, PRBool, PRBool) -> SECStatus { + called++; + return SECSuccess; + }); + Connect(); + EXPECT_EQ(0U, called); + EXPECT_FALSE(capture_cert_req->filtered()); + EXPECT_FALSE(capture_certificate->filtered()); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + // Need to do a round-trip so that the post-handshake message is + // handled on both client and server. + server_->SendData(50); + client_->ReadBytes(50); + client_->SendData(50); + server_->ReadBytes(50); + EXPECT_EQ(1U, called); + EXPECT_TRUE(capture_cert_req->filtered()); + EXPECT_TRUE(capture_certificate->filtered()); + // Check if a non-empty request context is generated and it is + // properly sent back. + EXPECT_LT(0U, capture_cert_req->buffer().len()); + EXPECT_EQ(capture_cert_req->buffer().len(), + capture_certificate->buffer().len()); + EXPECT_EQ(0, memcmp(capture_cert_req->buffer().data(), + capture_certificate->buffer().data(), + capture_cert_req->buffer().len())); + ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); +} + +static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd, + CERTDistNames* caNames, + CERTCertificate** clientCert, + SECKEYPrivateKey** clientKey) { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + // use a different certificate than TlsAgent::kClient + if (!TlsAgent::LoadCertificate(TlsAgent::kRsa2048, &cert, &priv)) { + return SECFailure; + } + + *clientCert = cert.release(); + *clientKey = priv.release(); + return SECSuccess; +} + +TEST_F(TlsConnectStreamTls13, PostHandshakeAuthMultiple) { + client_->SetupClientAuth(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + size_t called = 0; + server_->SetAuthCertificateCallback( + [&called](TlsAgent*, PRBool, PRBool) -> SECStatus { + called++; + return SECSuccess; + }); + Connect(); + EXPECT_EQ(0U, called); + EXPECT_EQ(nullptr, SSL_PeerCertificate(server_->ssl_fd())); + // Send 1st CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); + client_->ReadBytes(50); + client_->SendData(50); + server_->ReadBytes(50); + EXPECT_EQ(1U, called); + ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); + // Send 2nd CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_GetClientAuthDataHook( + client_->ssl_fd(), GetClientAuthDataHook, nullptr)); + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); + client_->ReadBytes(50); + client_->SendData(50); + server_->ReadBytes(50); + EXPECT_EQ(2U, called); + ScopedCERTCertificate cert3(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert3.get()); + ScopedCERTCertificate cert4(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert4.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert4->derCert)); + EXPECT_FALSE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert1->derCert)); +} + +TEST_F(TlsConnectStreamTls13, PostHandshakeAuthConcurrent) { + client_->SetupClientAuth(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + Connect(); + // Send 1st CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + // Send 2nd CertificateRequest. + EXPECT_EQ(SECFailure, SSL_SendCertificateRequest(server_->ssl_fd())); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, PostHandshakeAuthBeforeKeyUpdate) { + client_->SetupClientAuth(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + Connect(); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + // Send KeyUpdate. + EXPECT_EQ(SECFailure, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDuringClientKeyUpdate) { + client_->SetupClientAuth(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + Connect(); + CheckEpochs(3, 3); + // Send CertificateRequest from server. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + // Send KeyUpdate from client. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE)); + server_->SendData(50); // server sends CertificateRequest + client_->SendData(50); // client sends KeyUpdate + server_->ReadBytes(50); // server receives KeyUpdate and defers response + CheckEpochs(4, 3); + client_->ReadBytes(50); // client receives CertificateRequest + client_->SendData( + 50); // client sends Certificate, CertificateVerify, Finished + server_->ReadBytes( + 50); // server receives Certificate, CertificateVerify, Finished + client_->CheckEpochs(3, 4); + server_->CheckEpochs(4, 4); + server_->SendData(50); // server sends KeyUpdate + client_->ReadBytes(50); // client receives KeyUpdate + client_->CheckEpochs(4, 4); +} + +TEST_F(TlsConnectStreamTls13, PostHandshakeAuthMissingExtension) { + client_->SetupClientAuth(); + Connect(); + // Send CertificateRequest, should fail due to missing + // post_handshake_auth extension. + EXPECT_EQ(SECFailure, SSL_SendCertificateRequest(server_->ssl_fd())); + EXPECT_EQ(SSL_ERROR_MISSING_POST_HANDSHAKE_AUTH_EXTENSION, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, PostHandshakeAuthAfterClientAuth) { + client_->SetupClientAuth(); server_->RequestClientAuth(true); - ConnectExpectFail(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + size_t called = 0; + server_->SetAuthCertificateCallback( + [&called](TlsAgent*, PRBool, PRBool) -> SECStatus { + called++; + return SECSuccess; + }); + Connect(); + EXPECT_EQ(1U, called); + ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_GetClientAuthDataHook( + client_->ssl_fd(), GetClientAuthDataHook, nullptr)); + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); + client_->ReadBytes(50); + client_->SendData(50); + server_->ReadBytes(50); + EXPECT_EQ(2U, called); + ScopedCERTCertificate cert3(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert3.get()); + ScopedCERTCertificate cert4(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert4.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert4->derCert)); + EXPECT_FALSE(SECITEM_ItemsAreEqual(&cert3->derCert, &cert1->derCert)); +} + +// Damages the request context in a CertificateRequest message. +// We don't modify a Certificate message instead, so that the client +// can compute CertificateVerify correctly. +class TlsDamageCertificateRequestContextFilter : public TlsHandshakeFilter { + public: + TlsDamageCertificateRequestContextFilter(const std::shared_ptr& a) + : TlsHandshakeFilter(a, {kTlsHandshakeCertificateRequest}) { + EnableDecryption(); + } + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + *output = input; + assert(1 < output->len()); + // The request context has a 1 octet length. + output->data()[1] ^= 73; + return CHANGE; + } +}; + +TEST_F(TlsConnectStreamTls13, PostHandshakeAuthContextMismatch) { + EnsureTlsSetup(); + MakeTlsFilter(server_); + client_->SetupClientAuth(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + Connect(); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); + client_->ReadBytes(50); + client_->SendData(50); + server_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ReadBytes(50); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CERTIFICATE, PORT_GetError()); + server_->ExpectReadWriteError(); + server_->SendData(50); + client_->ExpectReceiveAlert(kTlsAlertIllegalParameter); + client_->ReadBytes(50); + EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, PORT_GetError()); +} + +// Replaces signature in a CertificateVerify message. +class TlsDamageSignatureFilter : public TlsHandshakeFilter { + public: + TlsDamageSignatureFilter(const std::shared_ptr& a) + : TlsHandshakeFilter(a, {kTlsHandshakeCertificateVerify}) { + EnableDecryption(); + } + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + *output = input; + assert(2 < output->len()); + // The signature follows a 2-octet signature scheme. + output->data()[2] ^= 73; + return CHANGE; + } +}; + +TEST_F(TlsConnectStreamTls13, PostHandshakeAuthBadSignature) { + EnsureTlsSetup(); + MakeTlsFilter(client_); + client_->SetupClientAuth(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + Connect(); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); + client_->ReadBytes(50); + client_->SendData(50); + server_->ExpectSendAlert(kTlsAlertDecodeError); + server_->ReadBytes(50); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CERT_VERIFY, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, PostHandshakeAuthDecline) { + EnsureTlsSetup(); + auto capture_cert_req = MakeTlsFilter( + server_, kTlsHandshakeCertificateRequest); + auto capture_certificate = + MakeTlsFilter( + client_, kTlsHandshakeCertificate); + client_->SetupClientAuth(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + EXPECT_EQ(SECSuccess, + SSL_OptionSet(server_->ssl_fd(), SSL_REQUIRE_CERTIFICATE, + SSL_REQUIRE_ALWAYS)); + // Client to decline the certificate request. + EXPECT_EQ(SECSuccess, + SSL_GetClientAuthDataHook( + client_->ssl_fd(), + [](void*, PRFileDesc*, CERTDistNames*, CERTCertificate**, + SECKEYPrivateKey**) -> SECStatus { return SECFailure; }, + nullptr)); + size_t called = 0; + server_->SetAuthCertificateCallback( + [&called](TlsAgent*, PRBool, PRBool) -> SECStatus { + called++; + return SECSuccess; + }); + Connect(); + EXPECT_EQ(0U, called); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); // send Certificate Request + client_->ReadBytes(50); // read Certificate Request + client_->SendData(50); // send empty Certificate+Finished + server_->ExpectSendAlert(kTlsAlertCertificateRequired); + server_->ReadBytes(50); // read empty Certificate+Finished + server_->ExpectReadWriteError(); + server_->SendData(50); // send alert + // AuthCertificateCallback is not called, because the client sends + // an empty certificate_list. + EXPECT_EQ(0U, called); + EXPECT_TRUE(capture_cert_req->filtered()); + EXPECT_TRUE(capture_certificate->filtered()); + // Check if a non-empty request context is generated and it is + // properly sent back. + EXPECT_LT(0U, capture_cert_req->buffer().len()); + EXPECT_EQ(capture_cert_req->buffer().len(), + capture_certificate->buffer().len()); + EXPECT_EQ(0, memcmp(capture_cert_req->buffer().data(), + capture_certificate->buffer().data(), + capture_cert_req->buffer().len())); +} + +// Check if post-handshake auth still works when session tickets are enabled: +// https://bugzilla.mozilla.org/show_bug.cgi?id=1553443 +TEST_F(TlsConnectStreamTls13, PostHandshakeAuthWithSessionTicketsEnabled) { + EnsureTlsSetup(); + client_->SetupClientAuth(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_POST_HANDSHAKE_AUTH, PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_TRUE)); + size_t called = 0; + server_->SetAuthCertificateCallback( + [&called](TlsAgent*, PRBool, PRBool) -> SECStatus { + called++; + return SECSuccess; + }); + Connect(); + EXPECT_EQ(0U, called); + // Send CertificateRequest. + EXPECT_EQ(SECSuccess, SSL_GetClientAuthDataHook( + client_->ssl_fd(), GetClientAuthDataHook, nullptr)); + EXPECT_EQ(SECSuccess, SSL_SendCertificateRequest(server_->ssl_fd())) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); + server_->SendData(50); + client_->ReadBytes(50); + client_->SendData(50); + server_->ReadBytes(50); + EXPECT_EQ(1U, called); + ScopedCERTCertificate cert1(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); + ScopedCERTCertificate cert2(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); +} + +TEST_P(TlsConnectGenericPre13, ClientAuthRequiredRejected) { + server_->RequestClientAuth(true); + ConnectExpectAlert(server_, kTlsAlertBadCertificate); + client_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT); + server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE); +} + +// In TLS 1.3, the client will claim that the connection is done and then +// receive the alert afterwards. So drive the handshake manually. +TEST_P(TlsConnectTls13, ClientAuthRequiredRejected) { + server_->RequestClientAuth(true); + StartConnect(); + client_->Handshake(); // CH + server_->Handshake(); // SH.. (no resumption) + client_->Handshake(); // Next message + ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + ExpectAlert(server_, kTlsAlertCertificateRequired); + server_->Handshake(); // Alert + server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE); + client_->Handshake(); // Receive Alert + client_->CheckErrorCode(SSL_ERROR_RX_CERTIFICATE_REQUIRED_ALERT); } TEST_P(TlsConnectGeneric, ClientAuthRequestedRejected) { @@ -219,7 +639,9 @@ static void CheckSigScheme(std::shared_ptr& capture, EXPECT_EQ(expected_scheme, static_cast(scheme)); ScopedCERTCertificate remote_cert(SSL_PeerCertificate(peer->ssl_fd())); + ASSERT_NE(nullptr, remote_cert.get()); ScopedSECKEYPublicKey remote_key(CERT_ExtractPublicKey(remote_cert.get())); + ASSERT_NE(nullptr, remote_key.get()); EXPECT_EQ(expected_size, SECKEY_PublicKeyStrengthInBits(remote_key.get())); } @@ -273,9 +695,7 @@ class TlsReplaceSignatureSchemeFilter : public TlsHandshakeFilter { TlsReplaceSignatureSchemeFilter(const std::shared_ptr& a, SSLSignatureScheme scheme) : TlsHandshakeFilter(a, {kTlsHandshakeCertificateVerify}), - scheme_(scheme) { - EnableDecryption(); - } + scheme_(scheme) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, @@ -342,6 +762,59 @@ TEST_P(TlsConnectTls12, ClientAuthInconsistentPssSignatureScheme) { ConnectExpectAlert(server_, kTlsAlertIllegalParameter); } +TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureScheme) { + static const SSLSignatureScheme kSignatureScheme[] = { + ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pss_rsae_sha256}; + + Reset(TlsAgent::kServerRsa, "rsa"); + client_->SetSignatureSchemes(kSignatureScheme, + PR_ARRAY_SIZE(kSignatureScheme)); + server_->SetSignatureSchemes(kSignatureScheme, + PR_ARRAY_SIZE(kSignatureScheme)); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + + auto capture_cert_verify = MakeTlsFilter( + client_, kTlsHandshakeCertificateVerify); + capture_cert_verify->EnableDecryption(); + + Connect(); + CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_rsae_sha256, + 1024); +} + +// Client should refuse to connect without a usable signature scheme. +TEST_P(TlsConnectTls13, ClientAuthPkcs1SignatureSchemeOnly) { + static const SSLSignatureScheme kSignatureScheme[] = { + ssl_sig_rsa_pkcs1_sha256}; + + Reset(TlsAgent::kServerRsa, "rsa"); + client_->SetSignatureSchemes(kSignatureScheme, + PR_ARRAY_SIZE(kSignatureScheme)); + client_->SetupClientAuth(); + client_->StartConnect(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); +} + +// Though the client has a usable signature scheme, when a certificate is +// requested, it can't produce one. +TEST_P(TlsConnectTls13, ClientAuthPkcs1AndEcdsaScheme) { + static const SSLSignatureScheme kSignatureScheme[] = { + ssl_sig_rsa_pkcs1_sha256, ssl_sig_ecdsa_secp256r1_sha256}; + + Reset(TlsAgent::kServerRsa, "rsa"); + client_->SetSignatureSchemes(kSignatureScheme, + PR_ARRAY_SIZE(kSignatureScheme)); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter { public: TlsZeroCertificateRequestSigAlgsFilter(const std::shared_ptr& a) @@ -552,7 +1025,9 @@ TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) { TEST_P(TlsConnectTls13, UnsupportedSignatureSchemeAlert) { EnsureTlsSetup(); - MakeTlsFilter(server_, ssl_sig_none); + auto filter = + MakeTlsFilter(server_, ssl_sig_none); + filter->EnableDecryption(); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); @@ -563,15 +1038,16 @@ TEST_P(TlsConnectTls13, InconsistentSignatureSchemeAlert) { EnsureTlsSetup(); // This won't work because we use an RSA cert by default. - MakeTlsFilter( + auto filter = MakeTlsFilter( server_, ssl_sig_ecdsa_secp256r1_sha256); + filter->EnableDecryption(); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); client_->CheckErrorCode(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM); } -TEST_P(TlsConnectTls12Plus, RequestClientAuthWithSha384) { +TEST_P(TlsConnectTls12, RequestClientAuthWithSha384) { server_->SetSignatureSchemes(kSignatureSchemeRsaSha384, PR_ARRAY_SIZE(kSignatureSchemeRsaSha384)); server_->RequestClientAuth(false); @@ -888,11 +1364,11 @@ TEST_P(TlsConnectGeneric, AuthFailImmediate) { } static const SSLExtraServerCertData ServerCertDataRsaPkcs1Decrypt = { - ssl_auth_rsa_decrypt, nullptr, nullptr, nullptr}; + ssl_auth_rsa_decrypt, nullptr, nullptr, nullptr, nullptr, nullptr}; static const SSLExtraServerCertData ServerCertDataRsaPkcs1Sign = { - ssl_auth_rsa_sign, nullptr, nullptr, nullptr}; + ssl_auth_rsa_sign, nullptr, nullptr, nullptr, nullptr, nullptr}; static const SSLExtraServerCertData ServerCertDataRsaPss = { - ssl_auth_rsa_pss, nullptr, nullptr, nullptr}; + ssl_auth_rsa_pss, nullptr, nullptr, nullptr, nullptr, nullptr}; // Test RSA cert with usage=[signature, encipherment]. TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1SignAndKEX) { @@ -972,6 +1448,109 @@ TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPss) { &ServerCertDataRsaPss)); } +// A server should refuse to even start a handshake with +// misconfigured certificate and signature scheme. +TEST_P(TlsConnectTls12Plus, MisconfiguredCertScheme) { + Reset(TlsAgent::kServerDsa); + static const SSLSignatureScheme kScheme[] = {ssl_sig_ecdsa_secp256r1_sha256}; + server_->SetSignatureSchemes(kScheme, PR_ARRAY_SIZE(kScheme)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + // TLS 1.2 disables cipher suites, which leads to a different error. + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + } else { + server_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); + } + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// In TLS 1.2, disabling an EC group causes ECDSA to be invalid. +TEST_P(TlsConnectTls12, Tls12CertDisabledGroup) { + Reset(TlsAgent::kServerEcdsa256); + static const std::vector k25519 = {ssl_grp_ec_curve25519}; + server_->ConfigNamedGroups(k25519); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// In TLS 1.3, ECDSA configuration only depends on the signature scheme. +TEST_P(TlsConnectTls13, Tls13CertDisabledGroup) { + Reset(TlsAgent::kServerEcdsa256); + static const std::vector k25519 = {ssl_grp_ec_curve25519}; + server_->ConfigNamedGroups(k25519); + Connect(); +} + +// A client should refuse to even start a handshake with only DSA. +TEST_P(TlsConnectTls13, Tls13DsaOnlyClient) { + static const SSLSignatureScheme kDsa[] = {ssl_sig_dsa_sha256}; + client_->SetSignatureSchemes(kDsa, PR_ARRAY_SIZE(kDsa)); + client_->StartConnect(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); +} + +TEST_P(TlsConnectTls13, Tls13DsaOnlyServer) { + Reset(TlsAgent::kServerDsa); + static const SSLSignatureScheme kDsa[] = {ssl_sig_dsa_sha256}; + server_->SetSignatureSchemes(kDsa, PR_ARRAY_SIZE(kDsa)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(TlsConnectTls13, Tls13Pkcs1OnlyClient) { + static const SSLSignatureScheme kPkcs1[] = {ssl_sig_rsa_pkcs1_sha256}; + client_->SetSignatureSchemes(kPkcs1, PR_ARRAY_SIZE(kPkcs1)); + client_->StartConnect(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); +} + +TEST_P(TlsConnectTls13, Tls13Pkcs1OnlyServer) { + static const SSLSignatureScheme kPkcs1[] = {ssl_sig_rsa_pkcs1_sha256}; + server_->SetSignatureSchemes(kPkcs1, PR_ARRAY_SIZE(kPkcs1)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_SUPPORTED_SIGNATURE_ALGORITHM); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(TlsConnectTls13, Tls13DsaIsNotAdvertisedClient) { + EnsureTlsSetup(); + static const SSLSignatureScheme kSchemes[] = {ssl_sig_dsa_sha256, + ssl_sig_rsa_pss_rsae_sha256}; + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + auto capture = + MakeTlsFilter(client_, ssl_signature_algorithms_xtn); + Connect(); + // We should only have the one signature algorithm advertised. + static const uint8_t kExpectedExt[] = {0, 2, ssl_sig_rsa_pss_rsae_sha256 >> 8, + ssl_sig_rsa_pss_rsae_sha256 & 0xff}; + ASSERT_EQ(DataBuffer(kExpectedExt, sizeof(kExpectedExt)), + capture->extension()); +} + +TEST_P(TlsConnectTls13, Tls13DsaIsNotAdvertisedServer) { + EnsureTlsSetup(); + static const SSLSignatureScheme kSchemes[] = {ssl_sig_dsa_sha256, + ssl_sig_rsa_pss_rsae_sha256}; + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + auto capture = MakeTlsFilter( + server_, ssl_signature_algorithms_xtn, true); + capture->SetHandshakeTypes({kTlsHandshakeCertificateRequest}); + capture->EnableDecryption(); + server_->RequestClientAuth(false); // So we get a CertificateRequest. + Connect(); + // We should only have the one signature algorithm advertised. + static const uint8_t kExpectedExt[] = {0, 2, ssl_sig_rsa_pss_rsae_sha256 >> 8, + ssl_sig_rsa_pss_rsae_sha256 & 0xff}; + ASSERT_EQ(DataBuffer(kExpectedExt, sizeof(kExpectedExt)), + capture->extension()); +} + // variant, version, certificate, auth type, signature scheme typedef std::tuple @@ -1033,12 +1612,21 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigBoth) { INSTANTIATE_TEST_CASE_P( SignatureSchemeRsa, TlsSignatureSchemeConfiguration, ::testing::Combine( - TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus, + TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12, ::testing::Values(TlsAgent::kServerRsaSign), ::testing::Values(ssl_auth_rsa_sign), ::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384, ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_rsae_sha384))); +// RSASSA-PKCS1-v1_5 is not allowed to be used in TLS 1.3 +INSTANTIATE_TEST_CASE_P( + SignatureSchemeRsaTls13, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV13, + ::testing::Values(TlsAgent::kServerRsaSign), + ::testing::Values(ssl_auth_rsa_sign), + ::testing::Values(ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_rsae_sha384))); // PSS with SHA-512 needs a bigger key to work. INSTANTIATE_TEST_CASE_P( SignatureSchemeBigRsa, TlsSignatureSchemeConfiguration, diff --git a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc index 573c69c75..26e5fb502 100644 --- a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc @@ -43,10 +43,10 @@ class SignedCertificateTimestampsExtractor { } void assertTimestamps(const DataBuffer& timestamps) { - EXPECT_TRUE(auth_timestamps_); + ASSERT_NE(nullptr, auth_timestamps_); EXPECT_EQ(timestamps, *auth_timestamps_); - EXPECT_TRUE(handshake_timestamps_); + ASSERT_NE(nullptr, handshake_timestamps_); EXPECT_EQ(timestamps, *handshake_timestamps_); const SECItem* current = @@ -64,8 +64,8 @@ static const uint8_t kSctValue[] = {0x01, 0x23, 0x45, 0x67, 0x89}; static const SECItem kSctItem = {siBuffer, const_cast(kSctValue), sizeof(kSctValue)}; static const DataBuffer kSctBuffer(kSctValue, sizeof(kSctValue)); -static const SSLExtraServerCertData kExtraSctData = {ssl_auth_null, nullptr, - nullptr, &kSctItem}; +static const SSLExtraServerCertData kExtraSctData = { + ssl_auth_null, nullptr, nullptr, &kSctItem, nullptr, nullptr}; // Test timestamps extraction during a successful handshake. TEST_P(TlsConnectGenericPre13, SignedCertificateTimestampsLegacy) { @@ -147,8 +147,8 @@ static const SECItem kOcspItems[] = { {siBuffer, const_cast(kOcspValue2), sizeof(kOcspValue2)}}; static const SECItemArray kOcspResponses = {const_cast(kOcspItems), PR_ARRAY_SIZE(kOcspItems)}; -const static SSLExtraServerCertData kOcspExtraData = {ssl_auth_null, nullptr, - &kOcspResponses, nullptr}; +const static SSLExtraServerCertData kOcspExtraData = { + ssl_auth_null, nullptr, &kOcspResponses, nullptr, nullptr, nullptr}; TEST_P(TlsConnectGeneric, NoOcsp) { EnsureTlsSetup(); @@ -224,7 +224,7 @@ TEST_P(TlsConnectGeneric, OcspHugeSuccess) { const SECItemArray hugeOcspResponses = {const_cast(hugeOcspItems), PR_ARRAY_SIZE(hugeOcspItems)}; const SSLExtraServerCertData hugeOcspExtraData = { - ssl_auth_null, nullptr, &hugeOcspResponses, nullptr}; + ssl_auth_null, nullptr, &hugeOcspResponses, nullptr, nullptr, nullptr}; // The value should be available during the AuthCertificateCallback client_->SetAuthCertificateCallback([&](TlsAgent* agent, bool checksig, diff --git a/security/nss/gtests/ssl_gtest/ssl_cipherorder_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cipherorder_unittest.cc new file mode 100644 index 000000000..1e4f817e9 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_cipherorder_unittest.cc @@ -0,0 +1,241 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include + +#include "tls_connect.h" +#include "tls_filter.h" + +namespace nss_test { + +class TlsCipherOrderTest : public TlsConnectTestBase { + protected: + virtual void ConfigureTLS() { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + } + + virtual SECStatus BuildTestLists(std::vector &cs_initial_list, + std::vector &cs_new_list) { + // This is the current CipherSuites order of enabled CipherSuites as defined + // in ssl3con.c + const PRUint16 *kCipherSuites = SSL_GetImplementedCiphers(); + + for (unsigned int i = 0; i < kNumImplementedCiphers; i++) { + PRBool pref = PR_FALSE, policy = PR_FALSE; + SECStatus rv; + rv = SSL_CipherPolicyGet(kCipherSuites[i], &policy); + if (rv != SECSuccess) { + return SECFailure; + } + rv = SSL_CipherPrefGetDefault(kCipherSuites[i], &pref); + if (rv != SECSuccess) { + return SECFailure; + } + if (pref && policy) { + cs_initial_list.push_back(kCipherSuites[i]); + } + } + + // We will test set function with the first 15 enabled ciphers. + const PRUint16 kNumCiphersToSet = 15; + for (unsigned int i = 0; i < kNumCiphersToSet; i++) { + cs_new_list.push_back(cs_initial_list[i]); + } + cs_new_list[0] = cs_initial_list[1]; + cs_new_list[1] = cs_initial_list[0]; + return SECSuccess; + } + + public: + TlsCipherOrderTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} + const unsigned int kNumImplementedCiphers = SSL_GetNumImplementedCiphers(); +}; + +const PRUint16 kCSUnsupported[] = {20196, 10101}; +const PRUint16 kNumCSUnsupported = PR_ARRAY_SIZE(kCSUnsupported); +const PRUint16 kCSEmpty[] = {0}; + +// Get the active CipherSuites odered as they were compiled +TEST_F(TlsCipherOrderTest, CipherOrderGet) { + std::vector initial_cs_order; + std::vector new_cs_order; + SECStatus result = BuildTestLists(initial_cs_order, new_cs_order); + ASSERT_EQ(result, SECSuccess); + ConfigureTLS(); + + std::vector current_cs_order(SSL_GetNumImplementedCiphers() + 1); + unsigned int current_num_active_cs = 0; + result = SSL_CipherSuiteOrderGet(client_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, initial_cs_order.size()); + for (unsigned int i = 0; i < initial_cs_order.size(); i++) { + EXPECT_EQ(initial_cs_order[i], current_cs_order[i]); + } + // Get the chosen CipherSuite during the Handshake without any modification. + Connect(); + SSLChannelInfo channel; + result = SSL_GetChannelInfo(client_->ssl_fd(), &channel, sizeof channel); + ASSERT_EQ(result, SECSuccess); + EXPECT_EQ(channel.cipherSuite, initial_cs_order[0]); +} + +// The "server" used for gtests honor only its ciphersuites order. +// So, we apply the new set for the server instead of client. +// This is enough to test the effect of SSL_CipherSuiteOrderSet function. +TEST_F(TlsCipherOrderTest, CipherOrderSet) { + std::vector initial_cs_order; + std::vector new_cs_order; + SECStatus result = BuildTestLists(initial_cs_order, new_cs_order); + ASSERT_EQ(result, SECSuccess); + ConfigureTLS(); + + // change the server_ ciphersuites order. + result = SSL_CipherSuiteOrderSet(server_->ssl_fd(), new_cs_order.data(), + new_cs_order.size()); + ASSERT_EQ(result, SECSuccess); + + // The function expect an array. We are using vector for VStudio + // compatibility. + std::vector current_cs_order(SSL_GetNumImplementedCiphers() + 1); + unsigned int current_num_active_cs = 0; + result = SSL_CipherSuiteOrderGet(server_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, new_cs_order.size()); + for (unsigned int i = 0; i < new_cs_order.size(); i++) { + ASSERT_EQ(new_cs_order[i], current_cs_order[i]); + } + + Connect(); + SSLChannelInfo channel; + // changes in server_ order reflect in client chosen ciphersuite. + result = SSL_GetChannelInfo(client_->ssl_fd(), &channel, sizeof channel); + ASSERT_EQ(result, SECSuccess); + EXPECT_EQ(channel.cipherSuite, new_cs_order[0]); +} + +// Duplicate socket configuration from a model. +TEST_F(TlsCipherOrderTest, CipherOrderCopySocket) { + std::vector initial_cs_order; + std::vector new_cs_order; + SECStatus result = BuildTestLists(initial_cs_order, new_cs_order); + ASSERT_EQ(result, SECSuccess); + ConfigureTLS(); + + // Use the existing sockets for this test. + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(), + new_cs_order.size()); + ASSERT_EQ(result, SECSuccess); + + std::vector current_cs_order(SSL_GetNumImplementedCiphers() + 1); + unsigned int current_num_active_cs = 0; + result = SSL_CipherSuiteOrderGet(server_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, initial_cs_order.size()); + for (unsigned int i = 0; i < current_num_active_cs; i++) { + ASSERT_EQ(initial_cs_order[i], current_cs_order[i]); + } + + // Import/Duplicate configurations from client_ to server_ + PRFileDesc *rv = SSL_ImportFD(client_->ssl_fd(), server_->ssl_fd()); + EXPECT_NE(nullptr, rv); + + result = SSL_CipherSuiteOrderGet(server_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, new_cs_order.size()); + for (unsigned int i = 0; i < new_cs_order.size(); i++) { + EXPECT_EQ(new_cs_order.data()[i], current_cs_order[i]); + } +} + +// If the infomed num of elements is lower than the actual list size, only the +// first "informed num" elements will be considered. The rest is ignored. +TEST_F(TlsCipherOrderTest, CipherOrderSetLower) { + std::vector initial_cs_order; + std::vector new_cs_order; + SECStatus result = BuildTestLists(initial_cs_order, new_cs_order); + ASSERT_EQ(result, SECSuccess); + ConfigureTLS(); + + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(), + new_cs_order.size() - 1); + ASSERT_EQ(result, SECSuccess); + + std::vector current_cs_order(SSL_GetNumImplementedCiphers() + 1); + unsigned int current_num_active_cs = 0; + result = SSL_CipherSuiteOrderGet(client_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, new_cs_order.size() - 1); + for (unsigned int i = 0; i < new_cs_order.size() - 1; i++) { + ASSERT_EQ(new_cs_order.data()[i], current_cs_order[i]); + } +} + +// Testing Errors Controls +TEST_F(TlsCipherOrderTest, CipherOrderSetControls) { + std::vector initial_cs_order; + std::vector new_cs_order; + SECStatus result = BuildTestLists(initial_cs_order, new_cs_order); + ASSERT_EQ(result, SECSuccess); + ConfigureTLS(); + + // Create a new vector with diplicated entries + std::vector repeated_cs_order(SSL_GetNumImplementedCiphers() + 1); + std::copy(initial_cs_order.begin(), initial_cs_order.end(), + repeated_cs_order.begin()); + repeated_cs_order[0] = repeated_cs_order[1]; + + // Repeated ciphersuites in the list + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), repeated_cs_order.data(), + initial_cs_order.size()); + EXPECT_EQ(result, SECFailure); + + // Zero size for the sent list + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(), 0); + EXPECT_EQ(result, SECFailure); + + // Wrong size, greater than actual + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), new_cs_order.data(), + SSL_GetNumImplementedCiphers() + 1); + EXPECT_EQ(result, SECFailure); + + // Wrong ciphersuites, not implemented + result = SSL_CipherSuiteOrderSet(client_->ssl_fd(), kCSUnsupported, + kNumCSUnsupported); + EXPECT_EQ(result, SECFailure); + + // Null list + result = + SSL_CipherSuiteOrderSet(client_->ssl_fd(), nullptr, new_cs_order.size()); + EXPECT_EQ(result, SECFailure); + + // Empty list + result = + SSL_CipherSuiteOrderSet(client_->ssl_fd(), kCSEmpty, new_cs_order.size()); + EXPECT_EQ(result, SECFailure); + + // Confirm that the controls are working, as the current ciphersuites + // remained untouched + std::vector current_cs_order(SSL_GetNumImplementedCiphers() + 1); + unsigned int current_num_active_cs = 0; + result = SSL_CipherSuiteOrderGet(client_->ssl_fd(), current_cs_order.data(), + ¤t_num_active_cs); + ASSERT_EQ(result, SECSuccess); + ASSERT_EQ(current_num_active_cs, initial_cs_order.size()); + for (unsigned int i = 0; i < initial_cs_order.size(); i++) { + ASSERT_EQ(initial_cs_order[i], current_cs_order[i]); + } +} +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc index 194cbab47..7739fe76f 100644 --- a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc @@ -56,6 +56,9 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { std::vector groups = {group_}; + if (cert_group_ != ssl_grp_none) { + groups.push_back(cert_group_); + } client_->ConfigNamedGroups(groups); server_->ConfigNamedGroups(groups); kea_type_ = SSLInt_GetKEAType(group_); @@ -68,41 +71,48 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { virtual void SetupCertificate() { if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { switch (sig_scheme_) { - case ssl_sig_rsa_pkcs1_sha256: - case ssl_sig_rsa_pkcs1_sha384: - case ssl_sig_rsa_pkcs1_sha512: + case ssl_sig_rsa_pss_rsae_sha256: + std::cerr << "Signature scheme: rsa_pss_rsae_sha256" << std::endl; Reset(TlsAgent::kServerRsaSign); auth_type_ = ssl_auth_rsa_sign; break; - case ssl_sig_rsa_pss_rsae_sha256: case ssl_sig_rsa_pss_rsae_sha384: + std::cerr << "Signature scheme: rsa_pss_rsae_sha384" << std::endl; Reset(TlsAgent::kServerRsaSign); auth_type_ = ssl_auth_rsa_sign; break; case ssl_sig_rsa_pss_rsae_sha512: // You can't fit SHA-512 PSS in a 1024-bit key. + std::cerr << "Signature scheme: rsa_pss_rsae_sha512" << std::endl; Reset(TlsAgent::kRsa2048); auth_type_ = ssl_auth_rsa_sign; break; case ssl_sig_rsa_pss_pss_sha256: + std::cerr << "Signature scheme: rsa_pss_pss_sha256" << std::endl; Reset(TlsAgent::kServerRsaPss); auth_type_ = ssl_auth_rsa_pss; break; case ssl_sig_rsa_pss_pss_sha384: + std::cerr << "Signature scheme: rsa_pss_pss_sha384" << std::endl; Reset("rsa_pss384"); auth_type_ = ssl_auth_rsa_pss; break; case ssl_sig_rsa_pss_pss_sha512: + std::cerr << "Signature scheme: rsa_pss_pss_sha512" << std::endl; Reset("rsa_pss512"); auth_type_ = ssl_auth_rsa_pss; break; case ssl_sig_ecdsa_secp256r1_sha256: + std::cerr << "Signature scheme: ecdsa_secp256r1_sha256" << std::endl; Reset(TlsAgent::kServerEcdsa256); auth_type_ = ssl_auth_ecdsa; + cert_group_ = ssl_grp_ec_secp256r1; break; case ssl_sig_ecdsa_secp384r1_sha384: + std::cerr << "Signature scheme: ecdsa_secp384r1_sha384" << std::endl; Reset(TlsAgent::kServerEcdsa384); auth_type_ = ssl_auth_ecdsa; + cert_group_ = ssl_grp_ec_secp384r1; break; default: ADD_FAILURE() << "Unsupported signature scheme: " << sig_scheme_; @@ -118,9 +128,11 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { break; case ssl_auth_ecdsa: Reset(TlsAgent::kServerEcdsa256); + cert_group_ = ssl_grp_ec_secp256r1; break; case ssl_auth_ecdh_ecdsa: Reset(TlsAgent::kServerEcdhEcdsa); + cert_group_ = ssl_grp_ec_secp256r1; break; case ssl_auth_ecdh_rsa: Reset(TlsAgent::kServerEcdhRsa); @@ -198,6 +210,7 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { SSLAuthType auth_type_; SSLKEAType kea_type_; SSLNamedGroup group_; + SSLNamedGroup cert_group_ = ssl_grp_none; SSLSignatureScheme sig_scheme_; SSLCipherSuiteInfo csinfo_; }; @@ -330,6 +343,12 @@ static SSLSignatureScheme kSignatureSchemesParamsArr[] = { ssl_sig_rsa_pss_pss_sha256, ssl_sig_rsa_pss_pss_sha384, ssl_sig_rsa_pss_pss_sha512}; +static SSLSignatureScheme kSignatureSchemesParamsArrTls13[] = { + ssl_sig_ecdsa_secp256r1_sha256, ssl_sig_ecdsa_secp384r1_sha384, + ssl_sig_rsa_pss_rsae_sha256, ssl_sig_rsa_pss_rsae_sha384, + ssl_sig_rsa_pss_rsae_sha512, ssl_sig_rsa_pss_pss_sha256, + ssl_sig_rsa_pss_pss_sha384, ssl_sig_rsa_pss_pss_sha512}; + INSTANTIATE_CIPHER_TEST_P(RC4, Stream, V10ToV12, kDummyNamedGroupParams, kDummySignatureSchemesParams, TLS_RSA_WITH_RC4_128_SHA, @@ -394,7 +413,7 @@ INSTANTIATE_CIPHER_TEST_P( #ifndef NSS_DISABLE_TLS_1_3 INSTANTIATE_CIPHER_TEST_P(TLS13, All, V13, ::testing::ValuesIn(kFasterDHEGroups), - ::testing::ValuesIn(kSignatureSchemesParamsArr), + ::testing::ValuesIn(kSignatureSchemesParamsArrTls13), TLS_AES_128_GCM_SHA256, TLS_CHACHA20_POLY1305_SHA256, TLS_AES_256_GCM_SHA384); INSTANTIATE_CIPHER_TEST_P(TLS13AllGroups, All, V13, diff --git a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc index 0723c9bee..9cbe9566f 100644 --- a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc @@ -62,7 +62,6 @@ TEST_P(TlsConnectGenericPre13, DamageServerSignature) { EnsureTlsSetup(); auto filter = MakeTlsFilter( server_, kTlsHandshakeServerKeyExchange); - filter->EnableDecryption(); ExpectAlert(client_, kTlsAlertDecryptError); ConnectExpectFail(); client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); @@ -84,7 +83,9 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) { server_->RequestClientAuth(true); auto filter = MakeTlsFilter( client_, kTlsHandshakeCertificateVerify); - filter->EnableDecryption(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + filter->EnableDecryption(); + } server_->ExpectSendAlert(kTlsAlertDecryptError); // Do these handshakes by hand to avoid race condition on // the client processing the server's alert. diff --git a/security/nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc new file mode 100644 index 000000000..59ec3d393 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_debug_env_unittest.cc @@ -0,0 +1,53 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include +#include +#include + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +extern "C" { +extern FILE* ssl_trace_iob; + +#ifdef NSS_ALLOW_SSLKEYLOGFILE +extern FILE* ssl_keylog_iob; +#endif +} + +// These tests ensure that when the associated environment variables are unset +// that the lazily-initialized defaults are what they are supposed to be. + +#ifdef DEBUG +TEST_P(TlsConnectGeneric, DebugEnvTraceFileNotSet) { + char* ev = PR_GetEnvSecure("SSLDEBUGFILE"); + if (ev && ev[0]) { + // note: should use GTEST_SKIP when GTest gets updated to support it + return; + } + + Connect(); + EXPECT_EQ(stderr, ssl_trace_iob); +} +#endif + +#ifdef NSS_ALLOW_SSLKEYLOGFILE +TEST_P(TlsConnectGeneric, DebugEnvKeylogFileNotSet) { + char* ev = PR_GetEnvSecure("SSLKEYLOGFILE"); + if (ev && ev[0]) { + // note: should use GTEST_SKIP when GTest gets updated to support it + return; + } + + Connect(); + EXPECT_EQ(nullptr, ssl_keylog_iob); +} +#endif + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc index f1ccc2864..0fe88ea88 100644 --- a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc @@ -682,4 +682,100 @@ TEST_P(TlsConnectTls12, ConnectInconsistentSigAlgDHE) { ConnectExpectAlert(client_, kTlsAlertIllegalParameter); } +static void CheckSkeSigScheme( + std::shared_ptr& capture_ske, + uint16_t expected_scheme) { + TlsParser parser(capture_ske->buffer()); + EXPECT_TRUE(parser.SkipVariable(2)) << " read dh_p"; + EXPECT_TRUE(parser.SkipVariable(2)) << " read dh_q"; + EXPECT_TRUE(parser.SkipVariable(2)) << " read dh_Ys"; + + uint32_t tmp; + EXPECT_TRUE(parser.Read(&tmp, 2)) << " read sig_scheme"; + EXPECT_EQ(expected_scheme, static_cast(tmp)); +} + +TEST_P(TlsConnectTls12, ConnectSigAlgEnabledByPolicyDhe) { + EnableOnlyDheCiphers(); + + const std::vector schemes = {ssl_sig_rsa_pkcs1_sha1, + ssl_sig_rsa_pkcs1_sha384}; + + EnsureTlsSetup(); + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + server_->SetSignatureSchemes(schemes.data(), schemes.size()); + auto capture_ske = MakeTlsFilter( + server_, kTlsHandshakeServerKeyExchange); + + StartConnect(); + client_->Handshake(); // Send ClientHello + + // Enable SHA-1 by policy. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, NSS_USE_ALG_IN_SSL_KX, 0); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Handshake(); // Remainder of handshake + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha1); +} + +TEST_P(TlsConnectTls12, ConnectSigAlgDisabledByPolicyDhe) { + EnableOnlyDheCiphers(); + + const std::vector schemes = {ssl_sig_rsa_pkcs1_sha1, + ssl_sig_rsa_pkcs1_sha384}; + + EnsureTlsSetup(); + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + server_->SetSignatureSchemes(schemes.data(), schemes.size()); + auto capture_ske = MakeTlsFilter( + server_, kTlsHandshakeServerKeyExchange); + + StartConnect(); + client_->Handshake(); // Send ClientHello + + // Disable SHA-1 by policy after sending ClientHello so that CH + // includes SHA-1 signature scheme. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, 0, NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Handshake(); // Remainder of handshake + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha384); +} + +TEST_P(TlsConnectPre12, ConnectSigAlgDisabledByPolicyDhePre12) { + EnableOnlyDheCiphers(); + + EnsureTlsSetup(); + StartConnect(); + client_->Handshake(); // Send ClientHello + + // Disable SHA-1 by policy. This will cause the connection fail as + // TLS 1.1 or earlier uses combined SHA-1 + MD5 signature. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, 0, NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + server_->ExpectSendAlert(kTlsAlertHandshakeFailure); + client_->ExpectReceiveAlert(kTlsAlertHandshakeFailure); + + // Remainder of handshake + Handshake(); + + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_HASH_ALGORITHM); +} + } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc index f25efc77a..b441b5c10 100644 --- a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc @@ -66,6 +66,38 @@ TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) { Connect(); } +static void CheckAcks(const std::shared_ptr& acks, + size_t index, std::vector expected) { + ASSERT_LT(index, acks->count()); + const DataBuffer& buf = acks->record(index).buffer; + size_t offset = 2; + uint64_t len; + + EXPECT_EQ(2 + expected.size() * 8, buf.len()); + ASSERT_TRUE(buf.Read(0, 2, &len)); + ASSERT_EQ(static_cast(len + 2), buf.len()); + if ((2 + expected.size() * 8) != buf.len()) { + while (offset < buf.len()) { + uint64_t ack; + ASSERT_TRUE(buf.Read(offset, 8, &ack)); + offset += 8; + std::cerr << "Ack=0x" << std::hex << ack << std::dec << std::endl; + } + return; + } + + for (size_t i = 0; i < expected.size(); ++i) { + uint64_t a = expected[i]; + uint64_t ack; + ASSERT_TRUE(buf.Read(offset, 8, &ack)); + offset += 8; + if (a != ack) { + ADD_FAILURE() << "Wrong ack " << i << " expected=0x" << std::hex << a + << " got=0x" << ack << std::dec; + } + } +} + class TlsDropDatagram13 : public TlsConnectDatagram13, public ::testing::WithParamInterface { public: @@ -139,37 +171,6 @@ class TlsDropDatagram13 : public TlsConnectDatagram13, std::shared_ptr chain_; }; - void CheckAcks(const DropAckChain& chain, size_t index, - std::vector acks) { - const DataBuffer& buf = chain.ack_->record(index).buffer; - size_t offset = 2; - uint64_t len; - - EXPECT_EQ(2 + acks.size() * 8, buf.len()); - ASSERT_TRUE(buf.Read(0, 2, &len)); - ASSERT_EQ(static_cast(len + 2), buf.len()); - if ((2 + acks.size() * 8) != buf.len()) { - while (offset < buf.len()) { - uint64_t ack; - ASSERT_TRUE(buf.Read(offset, 8, &ack)); - offset += 8; - std::cerr << "Ack=0x" << std::hex << ack << std::dec << std::endl; - } - return; - } - - for (size_t i = 0; i < acks.size(); ++i) { - uint64_t a = acks[i]; - uint64_t ack; - ASSERT_TRUE(buf.Read(offset, 8, &ack)); - offset += 8; - if (a != ack) { - ADD_FAILURE() << "Wrong ack " << i << " expected=0x" << std::hex << a - << " got=0x" << ack << std::dec; - } - } - } - void CheckedHandshakeSendReceive() { Handshake(); CheckPostHandshake(); @@ -199,7 +200,7 @@ TEST_P(TlsDropDatagram13, DropClientFirstFlightOnce) { client_->Handshake(); server_->Handshake(); CheckedHandshakeSendReceive(); - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); } TEST_P(TlsDropDatagram13, DropServerFirstFlightOnce) { @@ -210,7 +211,7 @@ TEST_P(TlsDropDatagram13, DropServerFirstFlightOnce) { server_->Handshake(); server_filters_.drop_->Disable(); CheckedHandshakeSendReceive(); - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); } // Dropping the server's first record also does not produce @@ -223,7 +224,7 @@ TEST_P(TlsDropDatagram13, DropServerFirstRecordOnce) { server_->Handshake(); Handshake(); CheckedHandshakeSendReceive(); - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); } // Dropping the second packet of the server's flight should @@ -236,8 +237,8 @@ TEST_P(TlsDropDatagram13, DropServerSecondRecordOnce) { HandshakeAndAck(client_); expected_client_acks_ = 1; CheckedHandshakeSendReceive(); - CheckAcks(client_filters_, 0, {0}); // ServerHello - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(client_filters_.ack_, 0, {0}); // ServerHello + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); } // Drop the server ACK and verify that the client retransmits @@ -265,8 +266,8 @@ TEST_P(TlsDropDatagram13, DropServerAckOnce) { EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); CheckPostHandshake(); // There should be two copies of the finished ACK - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); - CheckAcks(server_filters_, 1, {0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 1, {0x0002000000000000ULL}); } // Drop the client certificate verify. @@ -281,10 +282,10 @@ TEST_P(TlsDropDatagram13, DropClientCertVerify) { expected_server_acks_ = 2; CheckedHandshakeSendReceive(); // Ack of the Cert. - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); // Ack of the whole client handshake. CheckAcks( - server_filters_, 1, + server_filters_.ack_, 1, {0x0002000000000000ULL, // CH (we drop everything after this on client) 0x0002000000000003ULL, // CT (2) 0x0002000000000004ULL}); // FIN (2) @@ -310,11 +311,11 @@ TEST_P(TlsDropDatagram13, DropFirstHalfOfServerCertificate) { // as the previous CT1). EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len()); CheckedHandshakeSendReceive(); - CheckAcks(client_filters_, 0, + CheckAcks(client_filters_.ack_, 0, {0, // SH 0x0002000000000000ULL, // EE 0x0002000000000002ULL}); // CT2 - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); } // Shrink the MTU down so that certs get split and drop the second piece. @@ -336,13 +337,13 @@ TEST_P(TlsDropDatagram13, DropSecondHalfOfServerCertificate) { // Check that the first record is CT1 EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len()); CheckedHandshakeSendReceive(); - CheckAcks(client_filters_, 0, + CheckAcks(client_filters_.ack_, 0, { 0, // SH 0x0002000000000000ULL, // EE 0x0002000000000001ULL, // CT1 }); - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); } // In this test, the Certificate message is sent four times, we drop all or part @@ -392,18 +393,18 @@ class TlsFragmentationAndRecoveryTest : public TlsDropDatagram13 { 0, // SH 0x0002000000000000ULL // EE }; - CheckAcks(client_filters_, 0, client_acks); + CheckAcks(client_filters_.ack_, 0, client_acks); // And from the second attempt for the half was kept (we delayed this ACK). client_acks.push_back(0x0002000000000000ULL + second_flight_count + ~dropped_half % 2); - CheckAcks(client_filters_, 1, client_acks); + CheckAcks(client_filters_.ack_, 1, client_acks); // And the third attempt where the first and last thirds got through. client_acks.push_back(0x0002000000000000ULL + second_flight_count + third_flight_count - 1); client_acks.push_back(0x0002000000000000ULL + second_flight_count + third_flight_count + 1); - CheckAcks(client_filters_, 2, client_acks); - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(client_filters_.ack_, 2, client_acks); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); } private: @@ -548,7 +549,7 @@ TEST_P(TlsDropDatagram13, NoDropsDuringZeroRtt) { CheckConnected(); SendReceive(); EXPECT_EQ(0U, client_filters_.ack_->count()); - CheckAcks(server_filters_, 0, + CheckAcks(server_filters_.ack_, 0, {0x0001000000000001ULL, // EOED 0x0002000000000000ULL}); // Finished } @@ -567,8 +568,8 @@ TEST_P(TlsDropDatagram13, DropEEDuringZeroRtt) { ExpectEarlyDataAccepted(true); CheckConnected(); SendReceive(); - CheckAcks(client_filters_, 0, {0}); - CheckAcks(server_filters_, 0, + CheckAcks(client_filters_.ack_, 0, {0}); + CheckAcks(server_filters_.ack_, 0, {0x0001000000000002ULL, // EOED 0x0002000000000000ULL}); // Finished } @@ -608,22 +609,22 @@ TEST_P(TlsDropDatagram13, ReorderServerEE) { expected_client_acks_ = 1; HandshakeAndAck(client_); CheckedHandshakeSendReceive(); - CheckAcks(client_filters_, 0, + CheckAcks(client_filters_.ack_, 0, { 0, // SH 0x0002000000000000, // EE }); - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); } // The client sends an out of order non-handshake message // but with the handshake key. class TlsSendCipherSpecCapturer { public: - TlsSendCipherSpecCapturer(std::shared_ptr& agent) - : send_cipher_specs_() { - SSLInt_SetCipherSpecChangeFunc(agent->ssl_fd(), CipherSpecChanged, - (void*)this); + TlsSendCipherSpecCapturer(const std::shared_ptr& agent) + : agent_(agent), send_cipher_specs_() { + EXPECT_EQ(SECSuccess, + SSL_SecretCallback(agent_->ssl_fd(), SecretCallback, this)); } std::shared_ptr spec(size_t i) { @@ -634,28 +635,42 @@ class TlsSendCipherSpecCapturer { } private: - static void CipherSpecChanged(void* arg, PRBool sending, - ssl3CipherSpec* newSpec) { - if (!sending) { + static void SecretCallback(PRFileDesc* fd, PRUint16 epoch, + SSLSecretDirection dir, PK11SymKey* secret, + void* arg) { + auto self = static_cast(arg); + std::cerr << self->agent_->role_str() << ": capture " << dir + << " secret for epoch " << epoch << std::endl; + + if (dir == ssl_secret_read) { return; } - auto self = static_cast(arg); - - auto spec = std::make_shared(); - bool ret = spec->Init(SSLInt_CipherSpecToEpoch(newSpec), - SSLInt_CipherSpecToAlgorithm(newSpec), - SSLInt_CipherSpecToKey(newSpec), - SSLInt_CipherSpecToIv(newSpec)); - EXPECT_EQ(true, ret); + SSLPreliminaryChannelInfo preinfo; + EXPECT_EQ(SECSuccess, + SSL_GetPreliminaryChannelInfo(self->agent_->ssl_fd(), &preinfo, + sizeof(preinfo))); + EXPECT_EQ(sizeof(preinfo), preinfo.length); + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite); + + SSLCipherSuiteInfo cipherinfo; + EXPECT_EQ(SECSuccess, + SSL_GetCipherSuiteInfo(preinfo.cipherSuite, &cipherinfo, + sizeof(cipherinfo))); + EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length); + + auto spec = std::make_shared(true, epoch); + EXPECT_TRUE(spec->SetKeys(&cipherinfo, secret)); self->send_cipher_specs_.push_back(spec); } + std::shared_ptr agent_; std::vector> send_cipher_specs_; }; -TEST_P(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) { +TEST_F(TlsConnectDatagram13, SendOutOfOrderAppWithHandshakeKey) { StartConnect(); + // Capturing secrets means that we can't use decrypting filters on the client. TlsSendCipherSpecCapturer capturer(client_); client_->Handshake(); server_->Handshake(); @@ -680,9 +695,12 @@ TEST_P(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) { EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, PORT_GetError()); } -TEST_P(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) { +TEST_F(TlsConnectDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) { StartConnect(); TlsSendCipherSpecCapturer capturer(client_); + auto acks = MakeTlsFilter(server_, ssl_ct_ack); + acks->EnableDecryption(); + client_->Handshake(); server_->Handshake(); client_->Handshake(); @@ -699,10 +717,10 @@ TEST_P(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) { ssl_ct_handshake, DataBuffer(buf, sizeof(buf)))); server_->Handshake(); - EXPECT_EQ(2UL, server_filters_.ack_->count()); + EXPECT_EQ(2UL, acks->count()); // The server acknowledges client Finished twice. - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); - CheckAcks(server_filters_, 1, {0x0002000000000000ULL}); + CheckAcks(acks, 0, {0x0002000000000000ULL}); + CheckAcks(acks, 1, {0x0002000000000000ULL}); } // Shrink the MTU down so that certs get split and then swap the first and @@ -726,7 +744,7 @@ TEST_P(TlsReorderDatagram13, ReorderServerCertificate) { ShiftDtlsTimers(); CheckedHandshakeSendReceive(); EXPECT_EQ(2UL, server_filters_.records_->count()); // ACK + Data - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 0, {0x0002000000000000ULL}); } TEST_P(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) { @@ -761,7 +779,8 @@ TEST_P(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) { CheckConnected(); EXPECT_EQ(0U, client_filters_.ack_->count()); // Acknowledgements for EOED and Finished. - CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 0, + {0x0001000000000002ULL, 0x0002000000000000ULL}); uint8_t buf[8]; rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); EXPECT_EQ(-1, rv); @@ -800,7 +819,8 @@ TEST_P(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) { CheckConnected(); EXPECT_EQ(0U, client_filters_.ack_->count()); // Acknowledgements for EOED and Finished. - CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL}); + CheckAcks(server_filters_.ack_, 0, + {0x0001000000000002ULL, 0x0002000000000000ULL}); uint8_t buf[8]; rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); EXPECT_EQ(-1, rv); diff --git a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc index f1cf1fabc..e62e002f3 100644 --- a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc @@ -666,6 +666,80 @@ TEST_P(TlsConnectTls12, ConnectIncorrectSigAlg) { client_->CheckErrorCode(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM); } +static void CheckSkeSigScheme( + std::shared_ptr &capture_ske, + uint16_t expected_scheme) { + TlsParser parser(capture_ske->buffer()); + uint32_t tmp = 0; + EXPECT_TRUE(parser.Read(&tmp, 1)) << " read curve_type"; + EXPECT_EQ(3U, tmp) << "curve type has to be 3"; + EXPECT_TRUE(parser.Skip(2)) << " read namedcurve"; + EXPECT_TRUE(parser.SkipVariable(1)) << " read public"; + + EXPECT_TRUE(parser.Read(&tmp, 2)) << " read sig_scheme"; + EXPECT_EQ(expected_scheme, static_cast(tmp)); +} + +TEST_P(TlsConnectTls12, ConnectSigAlgEnabledByPolicy) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + const std::vector schemes = {ssl_sig_rsa_pkcs1_sha1, + ssl_sig_rsa_pkcs1_sha384}; + + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + server_->SetSignatureSchemes(schemes.data(), schemes.size()); + auto capture_ske = MakeTlsFilter( + server_, kTlsHandshakeServerKeyExchange); + + StartConnect(); + client_->Handshake(); // Send ClientHello + + // Enable SHA-1 by policy. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, NSS_USE_ALG_IN_SSL_KX, 0); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Handshake(); // Remainder of handshake + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha1); +} + +TEST_P(TlsConnectTls12, ConnectSigAlgDisabledByPolicy) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + const std::vector schemes = {ssl_sig_rsa_pkcs1_sha1, + ssl_sig_rsa_pkcs1_sha384}; + + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + server_->SetSignatureSchemes(schemes.data(), schemes.size()); + auto capture_ske = MakeTlsFilter( + server_, kTlsHandshakeServerKeyExchange); + + StartConnect(); + client_->Handshake(); // Send ClientHello + + // Disable SHA-1 by policy. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_SHA1, 0, NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Handshake(); // Remainder of handshake + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + CheckSkeSigScheme(capture_ske, ssl_sig_rsa_pkcs1_sha384); +} + INSTANTIATE_TEST_CASE_P(KeyExchangeTest, TlsKeyExchangeTest, ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV11Plus)); diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc index 5819af746..d7f350c8c 100644 --- a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -436,14 +436,14 @@ TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) { } TEST_F(TlsExtensionTest13Stream, SignatureAlgorithmsPrecedingGarbage) { - // 31 unknown signature algorithms followed by sha-256, rsa + // 31 unknown signature algorithms followed by sha-256, rsa-pss const uint8_t val[] = { 0x00, 0x40, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x04, 0x01}; + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x08, 0x04}; DataBuffer extension(val, sizeof(val)); MakeTlsFilter(client_, ssl_signature_algorithms_xtn, extension); @@ -482,6 +482,73 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) { client_, ssl_elliptic_curves_xtn, extension)); } +TEST_P(TlsExtensionTest12, SupportedCurvesDisableX25519) { + // Disable session resumption. + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + + // Ensure that we can enable its use in the key exchange. + SECStatus rv = + NSS_SetAlgorithmPolicy(SEC_OID_CURVE25519, NSS_USE_ALG_IN_SSL_KX, 0); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + auto capture1 = + MakeTlsFilter(client_, ssl_elliptic_curves_xtn); + Connect(); + + EXPECT_TRUE(capture1->captured()); + const DataBuffer& ext1 = capture1->extension(); + + uint32_t count; + ASSERT_TRUE(ext1.Read(0, 2, &count)); + + // Whether or not we've seen x25519 offered in this handshake. + bool seen1_x25519 = false; + for (size_t offset = 2; offset <= count; offset++) { + uint32_t val; + ASSERT_TRUE(ext1.Read(offset, 2, &val)); + if (val == ssl_grp_ec_curve25519) { + seen1_x25519 = true; + break; + } + } + ASSERT_TRUE(seen1_x25519); + + // Ensure that we can disable its use in the key exchange. + rv = NSS_SetAlgorithmPolicy(SEC_OID_CURVE25519, 0, NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + // Clean up after the last run. + Reset(); + auto capture2 = + MakeTlsFilter(client_, ssl_elliptic_curves_xtn); + Connect(); + + EXPECT_TRUE(capture2->captured()); + const DataBuffer& ext2 = capture2->extension(); + + ASSERT_TRUE(ext2.Read(0, 2, &count)); + + // Whether or not we've seen x25519 offered in this handshake. + bool seen2_x25519 = false; + for (size_t offset = 2; offset <= count; offset++) { + uint32_t val; + ASSERT_TRUE(ext2.Read(offset, 2, &val)); + + if (val == ssl_grp_ec_curve25519) { + seen2_x25519 = true; + break; + } + } + + ASSERT_FALSE(seen2_x25519); +} + TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) { const uint8_t val[] = {0x00}; DataBuffer extension(val, sizeof(val)); @@ -547,6 +614,56 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { } } +// This only works on TLS 1.2, since it relies on DSA. +TEST_P(TlsExtensionTest12, SignatureAlgorithmDisableDSA) { + const std::vector schemes = { + ssl_sig_dsa_sha1, ssl_sig_dsa_sha256, ssl_sig_dsa_sha384, + ssl_sig_dsa_sha512, ssl_sig_rsa_pss_rsae_sha256}; + + // Connect with DSA enabled by policy. + SECStatus rv = NSS_SetAlgorithmPolicy(SEC_OID_ANSIX9_DSA_SIGNATURE, + NSS_USE_ALG_IN_SSL_KX, 0); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Reset(TlsAgent::kServerDsa); + auto capture1 = + MakeTlsFilter(client_, ssl_signature_algorithms_xtn); + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + Connect(); + + // Check if all the signature algorithms are advertised. + EXPECT_TRUE(capture1->captured()); + const DataBuffer& ext1 = capture1->extension(); + EXPECT_EQ(2U + 2U * schemes.size(), ext1.len()); + + // Connect with DSA disabled by policy. + rv = NSS_SetAlgorithmPolicy(SEC_OID_ANSIX9_DSA_SIGNATURE, 0, + NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, + 0); + ASSERT_EQ(SECSuccess, rv); + + Reset(TlsAgent::kServerDsa); + auto capture2 = + MakeTlsFilter(client_, ssl_signature_algorithms_xtn); + client_->SetSignatureSchemes(schemes.data(), schemes.size()); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + + // Check if no DSA algorithms are advertised. + EXPECT_TRUE(capture2->captured()); + const DataBuffer& ext2 = capture2->extension(); + EXPECT_EQ(2U + 2U, ext2.len()); + uint32_t v = 0; + EXPECT_TRUE(ext2.Read(2, 2, &v)); + EXPECT_EQ(ssl_sig_rsa_pss_rsae_sha256, v); +} + // Temporary test to verify that we choke on an empty ClientKeyShare. // This test will fail when we implement HelloRetryRequest. TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { @@ -1121,6 +1238,10 @@ INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(ExtensionDatagramOnly, TlsExtensionTestDtls, TlsConnectTestBase::kTlsV11Plus); +INSTANTIATE_TEST_CASE_P(ExtensionTls12, TlsExtensionTest12, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, + TlsConnectTestBase::kTlsV12)); + INSTANTIATE_TEST_CASE_P(ExtensionTls12Plus, TlsExtensionTest12Plus, ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus)); diff --git a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc index f033b7843..b222f15cb 100644 --- a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc @@ -22,7 +22,7 @@ namespace nss_test { const uint8_t kShortEmptyFinished[8] = {0}; const uint8_t kLongEmptyFinished[128] = {0}; -class TlsFuzzTest : public ::testing::Test {}; +class TlsFuzzTest : public TlsConnectGeneric {}; // Record the application data stream. class TlsApplicationDataRecorder : public TlsRecordFilter { @@ -46,16 +46,9 @@ class TlsApplicationDataRecorder : public TlsRecordFilter { DataBuffer buffer_; }; -// Ensure that ssl_Time() returns a constant value. -FUZZ_F(TlsFuzzTest, SSL_Time_Constant) { - PRUint32 now = ssl_TimeSec(); - PR_Sleep(PR_SecondsToInterval(2)); - EXPECT_EQ(ssl_TimeSec(), now); -} - // Check that due to the deterministic PRNG we derive // the same master secret in two consecutive TLS sessions. -FUZZ_P(TlsConnectGeneric, DeterministicExporter) { +FUZZ_P(TlsFuzzTest, DeterministicExporter) { const char kLabel[] = "label"; std::vector out1(32), out2(32); @@ -95,7 +88,7 @@ FUZZ_P(TlsConnectGeneric, DeterministicExporter) { // Check that due to the deterministic RNG two consecutive // TLS sessions will have the exact same transcript. -FUZZ_P(TlsConnectGeneric, DeterministicTranscript) { +FUZZ_P(TlsFuzzTest, DeterministicTranscript) { // Make sure we have RSA blinding params. Connect(); @@ -130,9 +123,7 @@ FUZZ_P(TlsConnectGeneric, DeterministicTranscript) { // with all supported TLS versions, STREAM and DGRAM. // Check that records are NOT encrypted. // Check that records don't have a MAC. -FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) { - EnsureTlsSetup(); - +FUZZ_P(TlsFuzzTest, ConnectSendReceive_NullCipher) { // Set up app data filters. auto client_recorder = MakeTlsFilter(client_); auto server_recorder = MakeTlsFilter(server_); @@ -157,7 +148,7 @@ FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) { } // Check that an invalid Finished message doesn't abort the connection. -FUZZ_P(TlsConnectGeneric, BogusClientFinished) { +FUZZ_P(TlsFuzzTest, BogusClientFinished) { EnsureTlsSetup(); MakeTlsFilter( @@ -168,7 +159,7 @@ FUZZ_P(TlsConnectGeneric, BogusClientFinished) { } // Check that an invalid Finished message doesn't abort the connection. -FUZZ_P(TlsConnectGeneric, BogusServerFinished) { +FUZZ_P(TlsFuzzTest, BogusServerFinished) { EnsureTlsSetup(); MakeTlsFilter( @@ -179,7 +170,7 @@ FUZZ_P(TlsConnectGeneric, BogusServerFinished) { } // Check that an invalid server auth signature doesn't abort the connection. -FUZZ_P(TlsConnectGeneric, BogusServerAuthSignature) { +FUZZ_P(TlsFuzzTest, BogusServerAuthSignature) { EnsureTlsSetup(); uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsHandshakeCertificateVerify @@ -190,7 +181,7 @@ FUZZ_P(TlsConnectGeneric, BogusServerAuthSignature) { } // Check that an invalid client auth signature doesn't abort the connection. -FUZZ_P(TlsConnectGeneric, BogusClientAuthSignature) { +FUZZ_P(TlsFuzzTest, BogusClientAuthSignature) { EnsureTlsSetup(); client_->SetupClientAuth(); server_->RequestClientAuth(true); @@ -199,7 +190,7 @@ FUZZ_P(TlsConnectGeneric, BogusClientAuthSignature) { } // Check that session ticket resumption works. -FUZZ_P(TlsConnectGeneric, SessionTicketResumption) { +FUZZ_P(TlsFuzzTest, SessionTicketResumption) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); Connect(); SendReceive(); @@ -212,7 +203,7 @@ FUZZ_P(TlsConnectGeneric, SessionTicketResumption) { } // Check that session tickets are not encrypted. -FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) { +FUZZ_P(TlsFuzzTest, UnencryptedSessionTickets) { ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); auto filter = MakeTlsFilter( @@ -220,23 +211,45 @@ FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) { Connect(); std::cerr << "ticket" << filter->buffer() << std::endl; - size_t offset = 4; /* lifetime */ + size_t offset = 4; // Skip lifetime. + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { - offset += 4; /* ticket_age_add */ + offset += 4; // Skip ticket_age_add. uint32_t nonce_len = 0; EXPECT_TRUE(filter->buffer().Read(offset, 1, &nonce_len)); offset += 1 + nonce_len; } - offset += 2 + /* ticket length */ - 2; /* TLS_EX_SESS_TICKET_VERSION */ + + offset += 2; // Skip the ticket length. + + // This bit parses the contents of the ticket, which would ordinarily be + // encrypted. Start by checking that we have the right version. This needs + // to be updated every time that TLS_EX_SESS_TICKET_VERSION is changed. But + // we don't use the #define. That way, any time that code is updated, this + // test will fail unless it is manually checked. + uint32_t ticket_version; + EXPECT_TRUE(filter->buffer().Read(offset, 2, &ticket_version)); + EXPECT_EQ(0x010aU, ticket_version); + offset += 2; + // Check the protocol version number. uint32_t tls_version = 0; EXPECT_TRUE(filter->buffer().Read(offset, sizeof(version_), &tls_version)); EXPECT_EQ(version_, static_cast(tls_version)); + offset += sizeof(version_); // Check the cipher suite. uint32_t suite = 0; - EXPECT_TRUE(filter->buffer().Read(offset + sizeof(version_), 2, &suite)); + EXPECT_TRUE(filter->buffer().Read(offset, 2, &suite)); client_->CheckCipherSuite(static_cast(suite)); } -} + +INSTANTIATE_TEST_CASE_P( + FuzzStream, TlsFuzzTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P( + FuzzDatagram, TlsFuzzTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp index be1c4ea32..6cff0fc9d 100644 --- a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp +++ b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp @@ -18,9 +18,11 @@ 'ssl_agent_unittest.cc', 'ssl_auth_unittest.cc', 'ssl_cert_ext_unittest.cc', + 'ssl_cipherorder_unittest.cc', 'ssl_ciphersuite_unittest.cc', 'ssl_custext_unittest.cc', 'ssl_damage_unittest.cc', + 'ssl_debug_env_unittest.cc', 'ssl_dhe_unittest.cc', 'ssl_drop_unittest.cc', 'ssl_ecdh_unittest.cc', @@ -32,11 +34,12 @@ 'ssl_gather_unittest.cc', 'ssl_gtest.cc', 'ssl_hrr_unittest.cc', - 'ssl_keylog_unittest.cc', 'ssl_keyupdate_unittest.cc', 'ssl_loopback_unittest.cc', 'ssl_misc_unittest.cc', + 'ssl_primitive_unittest.cc', 'ssl_record_unittest.cc', + 'ssl_recordsep_unittest.cc', 'ssl_recordsize_unittest.cc', 'ssl_resumption_unittest.cc', 'ssl_renegotiation_unittest.cc', @@ -52,7 +55,8 @@ 'tls_filter.cc', 'tls_hkdf_unittest.cc', 'tls_esni_unittest.cc', - 'tls_protect.cc' + 'tls_protect.cc', + 'tls_subcerts_unittest.cc' ], 'dependencies': [ '<(DEPTH)/exports.gyp:nss_exports', @@ -74,7 +78,7 @@ '<(DEPTH)/lib/libpkix/libpkix.gyp:libpkix', ], 'conditions': [ - [ 'test_build==1', { + [ 'static_libs==1', { 'dependencies': [ '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap_static', ], @@ -91,6 +95,14 @@ '<(DEPTH)/lib/dbm/src/src.gyp:dbm', ], }], + [ 'enable_sslkeylogfile==1 and sanitizer_flags==0', { + 'sources': [ + 'ssl_keylog_unittest.cc', + ], + 'defines': [ + 'NSS_ALLOW_SSLKEYLOGFILE', + ], + }], ], } ], diff --git a/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc index 322b64837..4713e52a2 100644 --- a/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc @@ -4,8 +4,6 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at http://mozilla.org/MPL/2.0/. */ -#ifdef NSS_ALLOW_SSLKEYLOGFILE - #include #include #include @@ -15,20 +13,59 @@ namespace nss_test { -static const std::string keylog_file_path = "keylog.txt"; -static const std::string keylog_env = "SSLKEYLOGFILE=" + keylog_file_path; +static const std::string kKeylogFilePath = "keylog.txt"; +static const std::string kKeylogBlankEnv = "SSLKEYLOGFILE="; +static const std::string kKeylogSetEnv = kKeylogBlankEnv + kKeylogFilePath; + +extern "C" { +extern FILE* ssl_keylog_iob; +} + +class KeyLogFileTestBase : public TlsConnectGeneric { + private: + std::string env_to_set_; -class KeyLogFileTest : public TlsConnectGeneric { public: + virtual void CheckKeyLog() = 0; + + KeyLogFileTestBase(std::string env) : env_to_set_(env) {} + void SetUp() override { TlsConnectGeneric::SetUp(); // Remove previous results (if any). - (void)remove(keylog_file_path.c_str()); - PR_SetEnv(keylog_env.c_str()); + (void)remove(kKeylogFilePath.c_str()); + PR_SetEnv(env_to_set_.c_str()); } - void CheckKeyLog() { - std::ifstream f(keylog_file_path); + void ConnectAndCheck() { + // This is a child process, ensure that error messages immediately + // propagate or else it will not be visible. + ::testing::GTEST_FLAG(throw_on_failure) = true; + + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + } else { + Connect(); + } + CheckKeyLog(); + _exit(0); + } +}; + +class KeyLogFileTest : public KeyLogFileTestBase { + public: + KeyLogFileTest() : KeyLogFileTestBase(kKeylogSetEnv) {} + + void CheckKeyLog() override { + std::ifstream f(kKeylogFilePath); std::map labels; std::set client_randoms; for (std::string line; std::getline(f, line);) { @@ -65,28 +102,6 @@ class KeyLogFileTest : public TlsConnectGeneric { ASSERT_EQ(4U, labels["EXPORTER_SECRET"]); } } - - void ConnectAndCheck() { - // This is a child process, ensure that error messages immediately - // propagate or else it will not be visible. - ::testing::GTEST_FLAG(throw_on_failure) = true; - - if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { - SetupForZeroRtt(); - client_->Set0RttEnabled(true); - server_->Set0RttEnabled(true); - ExpectResumption(RESUME_TICKET); - ZeroRttSendReceive(true, true); - Handshake(); - ExpectEarlyDataAccepted(true); - CheckConnected(); - SendReceive(); - } else { - Connect(); - } - CheckKeyLog(); - _exit(0); - } }; // Tests are run in a separate process to ensure that NSS is not initialized yet @@ -113,6 +128,37 @@ INSTANTIATE_TEST_CASE_P( TlsConnectTestBase::kTlsV13)); #endif -} // namespace nss_test +class KeyLogFileUnsetTest : public KeyLogFileTestBase { + public: + KeyLogFileUnsetTest() : KeyLogFileTestBase(kKeylogBlankEnv) {} + + void CheckKeyLog() override { + std::ifstream f(kKeylogFilePath); + EXPECT_FALSE(f.good()); + + EXPECT_EQ(nullptr, ssl_keylog_iob); + } +}; + +TEST_P(KeyLogFileUnsetTest, KeyLogFile) { + testing::GTEST_FLAG(death_test_style) = "threadsafe"; + + ASSERT_EXIT(ConnectAndCheck(), ::testing::ExitedWithCode(0), ""); +} -#endif // NSS_ALLOW_SSLKEYLOGFILE +INSTANTIATE_TEST_CASE_P( + KeyLogFileDTLS12, KeyLogFileUnsetTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12)); +INSTANTIATE_TEST_CASE_P( + KeyLogFileTLS12, KeyLogFileUnsetTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_CASE_P( + KeyLogFileTLS13, KeyLogFileUnsetTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV13)); +#endif + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc index d6ac99a58..b921d2c1e 100644 --- a/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc @@ -33,6 +33,37 @@ TEST_F(TlsConnectTest, KeyUpdateClient) { CheckEpochs(4, 3); } +TEST_F(TlsConnectStreamTls13, KeyUpdateTooEarly_Client) { + StartConnect(); + auto filter = MakeTlsFilter( + server_, kTlsHandshakeFinished, kTlsHandshakeKeyUpdate); + filter->EnableDecryption(); + + client_->Handshake(); + server_->Handshake(); + ExpectAlert(client_, kTlsAlertUnexpectedMessage); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_KEY_UPDATE); + server_->Handshake(); + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +TEST_F(TlsConnectStreamTls13, KeyUpdateTooEarly_Server) { + StartConnect(); + auto filter = MakeTlsFilter( + client_, kTlsHandshakeFinished, kTlsHandshakeKeyUpdate); + filter->EnableDecryption(); + + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->Handshake(); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_KEY_UPDATE); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + TEST_F(TlsConnectTest, KeyUpdateClientRequestUpdate) { ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); Connect(); diff --git a/security/nss/gtests/ssl_gtest/ssl_primitive_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_primitive_unittest.cc new file mode 100644 index 000000000..66ecdeb12 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_primitive_unittest.cc @@ -0,0 +1,218 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include + +#include "keyhi.h" +#include "pk11pub.h" +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslexp.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "scoped_ptrs_ssl.h" +#include "tls_connect.h" + +namespace nss_test { + +// From tls_hkdf_unittest.cc: +extern size_t GetHashLength(SSLHashType ht); + +class AeadTest : public ::testing::Test { + public: + AeadTest() : slot_(PK11_GetInternalSlot()) {} + + void InitSecret(SSLHashType hash_type) { + static const uint8_t kData[64] = {'s', 'e', 'c', 'r', 'e', 't'}; + SECItem key_item = {siBuffer, const_cast(kData), + static_cast(GetHashLength(hash_type))}; + PK11SymKey *s = + PK11_ImportSymKey(slot_.get(), CKM_SSL3_MASTER_KEY_DERIVE, + PK11_OriginUnwrap, CKA_DERIVE, &key_item, NULL); + ASSERT_NE(nullptr, s); + secret_.reset(s); + } + + void SetUp() override { + InitSecret(ssl_hash_sha256); + PORT_SetError(0); + } + + protected: + static void EncryptDecrypt(const ScopedSSLAeadContext &ctx, + const uint8_t *ciphertext, size_t ciphertext_len) { + static const uint8_t kAad[] = {'a', 'a', 'd'}; + static const uint8_t kPlaintext[] = {'t', 'e', 'x', 't'}; + static const size_t kMaxSize = 32; + + ASSERT_GE(kMaxSize, ciphertext_len); + ASSERT_LT(0U, ciphertext_len); + + uint8_t output[kMaxSize]; + unsigned int output_len = 0; + EXPECT_EQ(SECSuccess, SSL_AeadEncrypt(ctx.get(), 0, kAad, sizeof(kAad), + kPlaintext, sizeof(kPlaintext), + output, &output_len, sizeof(output))); + ASSERT_EQ(ciphertext_len, static_cast(output_len)); + EXPECT_EQ(0, memcmp(ciphertext, output, ciphertext_len)); + + memset(output, 0, sizeof(output)); + EXPECT_EQ(SECSuccess, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad), + ciphertext, ciphertext_len, output, + &output_len, sizeof(output))); + ASSERT_EQ(sizeof(kPlaintext), static_cast(output_len)); + EXPECT_EQ(0, memcmp(kPlaintext, output, sizeof(kPlaintext))); + + // Now for some tests of decryption failure. + // Truncate the input. + EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad), + ciphertext, ciphertext_len - 1, + output, &output_len, sizeof(output))); + EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError()); + + // Skip the first byte of the AAD. + EXPECT_EQ( + SECFailure, + SSL_AeadDecrypt(ctx.get(), 0, kAad + 1, sizeof(kAad) - 1, ciphertext, + ciphertext_len, output, &output_len, sizeof(output))); + EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError()); + + uint8_t input[kMaxSize] = {0}; + // Toggle a byte of the input. + memcpy(input, ciphertext, ciphertext_len); + input[0] ^= 9; + EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad), + input, ciphertext_len, output, + &output_len, sizeof(output))); + EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError()); + + // Toggle the last byte (the auth tag). + memcpy(input, ciphertext, ciphertext_len); + input[ciphertext_len - 1] ^= 77; + EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, kAad, sizeof(kAad), + input, ciphertext_len, output, + &output_len, sizeof(output))); + EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError()); + + // Toggle some of the AAD. + memcpy(input, kAad, sizeof(kAad)); + input[1] ^= 23; + EXPECT_EQ(SECFailure, SSL_AeadDecrypt(ctx.get(), 0, input, sizeof(kAad), + ciphertext, ciphertext_len, output, + &output_len, sizeof(output))); + EXPECT_EQ(SEC_ERROR_BAD_DATA, PORT_GetError()); + } + + protected: + ScopedPK11SymKey secret_; + + private: + ScopedPK11SlotInfo slot_; +}; + +// These tests all use fixed inputs: a fixed secret, a fixed label, and fixed +// inputs. So they have fixed outputs. +static const char *kLabel = "test "; +static const uint8_t kCiphertextAes128Gcm[] = { + 0x11, 0x14, 0xfc, 0x58, 0x4f, 0x44, 0xff, 0x8c, 0xb6, 0xd8, + 0x20, 0xb3, 0xfb, 0x50, 0xd9, 0x3b, 0xd4, 0xc6, 0xe1, 0x14}; +static const uint8_t kCiphertextAes256Gcm[] = { + 0xf7, 0x27, 0x35, 0x80, 0x88, 0xaf, 0x99, 0x85, 0xf2, 0x83, + 0xca, 0xbb, 0x95, 0x42, 0x09, 0x3f, 0x9c, 0xf3, 0x29, 0xf0}; +static const uint8_t kCiphertextChaCha20Poly1305[] = { + 0x4e, 0x89, 0x2c, 0xfa, 0xfc, 0x8c, 0x40, 0x55, 0x6d, 0x7e, + 0x99, 0xac, 0x8e, 0x54, 0x58, 0xb1, 0x18, 0xd2, 0x66, 0x22}; + +TEST_F(AeadTest, AeadBadVersion) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ(SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + secret_.get(), kLabel, strlen(kLabel), &ctx)); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadUnsupportedCipher) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ(SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_NULL_MD5, + secret_.get(), kLabel, strlen(kLabel), &ctx)); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadOlderCipher) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ( + SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_AES_128_CBC_SHA, + secret_.get(), kLabel, strlen(kLabel), &ctx)); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadNoLabel) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ(SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256, + secret_.get(), nullptr, 12, &ctx)); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadLongLabel) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ(SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256, + secret_.get(), "", 254, &ctx)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadNoPointer) { + SSLAeadContext *ctx = nullptr; + ASSERT_EQ(SECFailure, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256, + secret_.get(), kLabel, strlen(kLabel), nullptr)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + EXPECT_EQ(nullptr, ctx); +} + +TEST_F(AeadTest, AeadAes128Gcm) { + SSLAeadContext *ctxInit; + ASSERT_EQ(SECSuccess, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_128_GCM_SHA256, + secret_.get(), kLabel, strlen(kLabel), &ctxInit)); + ScopedSSLAeadContext ctx(ctxInit); + EXPECT_NE(nullptr, ctx); + + EncryptDecrypt(ctx, kCiphertextAes128Gcm, sizeof(kCiphertextAes128Gcm)); +} + +TEST_F(AeadTest, AeadAes256Gcm) { + SSLAeadContext *ctxInit = nullptr; + ASSERT_EQ(SECSuccess, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_AES_256_GCM_SHA384, + secret_.get(), kLabel, strlen(kLabel), &ctxInit)); + ScopedSSLAeadContext ctx(ctxInit); + EXPECT_NE(nullptr, ctx); + + EncryptDecrypt(ctx, kCiphertextAes256Gcm, sizeof(kCiphertextAes256Gcm)); +} + +TEST_F(AeadTest, AeadChaCha20Poly1305) { + SSLAeadContext *ctxInit; + ASSERT_EQ( + SECSuccess, + SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, TLS_CHACHA20_POLY1305_SHA256, + secret_.get(), kLabel, strlen(kLabel), &ctxInit)); + ScopedSSLAeadContext ctx(ctxInit); + EXPECT_NE(nullptr, ctx); + + EncryptDecrypt(ctx, kCiphertextChaCha20Poly1305, + sizeof(kCiphertextChaCha20Poly1305)); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc index f1e85e898..86783b86e 100644 --- a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc @@ -205,6 +205,42 @@ TEST_F(TlsConnectDatagram13, ShortHeadersServer) { SendReceive(); } +TEST_F(TlsConnectStreamTls13, UnencryptedFinishedMessage) { + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send first server flight + + // Record and drop the first record, which is the Finished. + auto recorder = std::make_shared(client_); + recorder->EnableDecryption(); + auto dropper = std::make_shared(1); + client_->SetFilter(std::make_shared( + ChainedPacketFilterInit({recorder, dropper}))); + client_->Handshake(); // Save and drop CFIN. + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + + ASSERT_EQ(1U, recorder->count()); + auto& finished = recorder->record(0); + + DataBuffer d; + size_t offset = d.Write(0, ssl_ct_handshake, 1); + offset = d.Write(offset, SSL_LIBRARY_VERSION_TLS_1_2, 2); + offset = d.Write(offset, finished.buffer.len(), 2); + d.Append(finished.buffer); + client_->SendDirect(d); + + // Now process the message. + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + // The server should generate an alert. + server_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_RECORD_TYPE); + // Have the client consume the alert. + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + const static size_t kContentSizesArr[] = { 1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288}; diff --git a/security/nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc new file mode 100644 index 000000000..393b50ffd --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_recordsep_unittest.cc @@ -0,0 +1,577 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include +#include "gtest_utils.h" +#include "nss_scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +class HandshakeSecretTracker { + public: + HandshakeSecretTracker(const std::shared_ptr& agent, + uint16_t first_read_epoch, uint16_t first_write_epoch) + : agent_(agent), + next_read_epoch_(first_read_epoch), + next_write_epoch_(first_write_epoch) { + EXPECT_EQ(SECSuccess, + SSL_SecretCallback(agent_->ssl_fd(), + HandshakeSecretTracker::SecretCb, this)); + } + + void CheckComplete() const { + EXPECT_EQ(0, next_read_epoch_); + EXPECT_EQ(0, next_write_epoch_); + } + + private: + static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir, + PK11SymKey* secret, void* arg) { + HandshakeSecretTracker* t = reinterpret_cast(arg); + t->SecretUpdated(epoch, dir, secret); + } + + void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir, + PK11SymKey* secret) { + if (g_ssl_gtest_verbose) { + std::cerr << agent_->role_str() << ": secret callback for " << dir + << " epoch " << epoch << std::endl; + } + + EXPECT_TRUE(secret); + uint16_t* p; + if (dir == ssl_secret_read) { + p = &next_read_epoch_; + } else { + ASSERT_EQ(ssl_secret_write, dir); + p = &next_write_epoch_; + } + EXPECT_EQ(*p, epoch); + switch (*p) { + case 1: // 1 == 0-RTT, next should be handshake. + case 2: // 2 == handshake, next should be application data. + (*p)++; + break; + + case 3: // 3 == application data, there should be no more. + // Use 0 as a sentinel value. + *p = 0; + break; + + default: + ADD_FAILURE() << "Unexpected next epoch: " << *p; + } + } + + std::shared_ptr agent_; + uint16_t next_read_epoch_; + uint16_t next_write_epoch_; +}; + +TEST_F(TlsConnectTest, HandshakeSecrets) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + EnsureTlsSetup(); + + HandshakeSecretTracker c(client_, 2, 2); + HandshakeSecretTracker s(server_, 2, 2); + + Connect(); + SendReceive(); + + c.CheckComplete(); + s.CheckComplete(); +} + +TEST_F(TlsConnectTest, ZeroRttSecrets) { + SetupForZeroRtt(); + + HandshakeSecretTracker c(client_, 2, 1); + HandshakeSecretTracker s(server_, 1, 2); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + + c.CheckComplete(); + s.CheckComplete(); +} + +class KeyUpdateTracker { + public: + KeyUpdateTracker(const std::shared_ptr& agent, + bool expect_read_secret) + : agent_(agent), expect_read_secret_(expect_read_secret), called_(false) { + EXPECT_EQ(SECSuccess, SSL_SecretCallback(agent_->ssl_fd(), + KeyUpdateTracker::SecretCb, this)); + } + + void CheckCalled() const { EXPECT_TRUE(called_); } + + private: + static void SecretCb(PRFileDesc* fd, PRUint16 epoch, SSLSecretDirection dir, + PK11SymKey* secret, void* arg) { + KeyUpdateTracker* t = reinterpret_cast(arg); + t->SecretUpdated(epoch, dir, secret); + } + + void SecretUpdated(PRUint16 epoch, SSLSecretDirection dir, + PK11SymKey* secret) { + EXPECT_EQ(4U, epoch); + EXPECT_EQ(expect_read_secret_, dir == ssl_secret_read); + EXPECT_TRUE(secret); + called_ = true; + } + + std::shared_ptr agent_; + bool expect_read_secret_; + bool called_; +}; + +TEST_F(TlsConnectTest, KeyUpdateSecrets) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + // The update is to the client write secret; the server read secret. + KeyUpdateTracker c(client_, false); + KeyUpdateTracker s(server_, true); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(4, 3); + c.CheckCalled(); + s.CheckCalled(); +} + +// BadPrSocket is an instance of a PR IO layer that crashes the test if it is +// ever used for reading or writing. It does that by failing to overwrite any +// of the DummyIOLayerMethods, which all crash when invoked. +class BadPrSocket : public DummyIOLayerMethods { + public: + BadPrSocket(std::shared_ptr& agent) : DummyIOLayerMethods() { + static PRDescIdentity bad_identity = PR_GetUniqueIdentity("bad NSPR id"); + fd_ = DummyIOLayerMethods::CreateFD(bad_identity, this); + + // This is terrible, but NSPR doesn't provide an easy way to replace the + // bottom layer of an IO stack. Take the DummyPrSocket and replace its + // NSPR method vtable with the ones from this object. + dummy_layer_ = + PR_GetIdentitiesLayer(agent->ssl_fd(), DummyPrSocket::LayerId()); + EXPECT_TRUE(dummy_layer_); + original_methods_ = dummy_layer_->methods; + original_secret_ = dummy_layer_->secret; + dummy_layer_->methods = fd_->methods; + dummy_layer_->secret = reinterpret_cast(this); + } + + // This will be destroyed before the agent, so we need to restore the state + // before we tampered with it. + virtual ~BadPrSocket() { + dummy_layer_->methods = original_methods_; + dummy_layer_->secret = original_secret_; + } + + private: + ScopedPRFileDesc fd_; + PRFileDesc* dummy_layer_; + const PRIOMethods* original_methods_; + PRFilePrivate* original_secret_; +}; + +class StagedRecords { + public: + StagedRecords(std::shared_ptr& agent) : agent_(agent), records_() { + EXPECT_EQ(SECSuccess, + SSL_RecordLayerWriteCallback( + agent_->ssl_fd(), StagedRecords::StageRecordData, this)); + } + + virtual ~StagedRecords() { + // Uninstall so that the callback doesn't fire during cleanup. + EXPECT_EQ(SECSuccess, + SSL_RecordLayerWriteCallback(agent_->ssl_fd(), nullptr, nullptr)); + } + + bool empty() const { return records_.empty(); } + + void ForwardAll(std::shared_ptr& peer) { + EXPECT_NE(agent_, peer) << "can't forward to self"; + for (auto r : records_) { + r.Forward(peer); + } + records_.clear(); + } + + // This forwards all saved data and checks the resulting state. + void ForwardAll(std::shared_ptr& peer, + TlsAgent::State expected_state) { + ForwardAll(peer); + switch (expected_state) { + case TlsAgent::STATE_CONNECTED: + // The handshake callback should have been called, so check that before + // checking that SSL_ForceHandshake succeeds. + EXPECT_EQ(expected_state, peer->state()); + EXPECT_EQ(SECSuccess, SSL_ForceHandshake(peer->ssl_fd())); + break; + + case TlsAgent::STATE_CONNECTING: + // Check that SSL_ForceHandshake() blocks. + EXPECT_EQ(SECFailure, SSL_ForceHandshake(peer->ssl_fd())); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + // Update and check the state. + peer->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, peer->state()); + break; + + default: + ADD_FAILURE() << "No idea how to handle this state"; + } + } + + void ForwardPartial(std::shared_ptr& peer) { + if (records_.empty()) { + ADD_FAILURE() << "No records to slice"; + return; + } + auto& last = records_.back(); + auto tail = last.SliceTail(); + ForwardAll(peer, TlsAgent::STATE_CONNECTING); + records_.push_back(tail); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, peer->state()); + } + + private: + // A single record. + class StagedRecord { + public: + StagedRecord(const std::string role, uint16_t epoch, SSLContentType ct, + const uint8_t* data, size_t len) + : role_(role), epoch_(epoch), content_type_(ct), data_(data, len) { + if (g_ssl_gtest_verbose) { + std::cerr << role_ << ": staged epoch " << epoch_ << " " + << content_type_ << ": " << data_ << std::endl; + } + } + + // This forwards staged data to the identified agent. + void Forward(std::shared_ptr& peer) { + // Now there should be staged data. + EXPECT_FALSE(data_.empty()); + if (g_ssl_gtest_verbose) { + std::cerr << role_ << ": forward " << data_ << std::endl; + } + EXPECT_EQ(SECSuccess, + SSL_RecordLayerData(peer->ssl_fd(), epoch_, content_type_, + data_.data(), + static_cast(data_.len()))); + } + + // Slices the tail off this record and returns it. + StagedRecord SliceTail() { + size_t slice = 1; + if (data_.len() <= slice) { + ADD_FAILURE() << "record too small to slice in two"; + slice = 0; + } + size_t keep = data_.len() - slice; + StagedRecord tail(role_, epoch_, content_type_, data_.data() + keep, + slice); + data_.Truncate(keep); + return tail; + } + + private: + std::string role_; + uint16_t epoch_; + SSLContentType content_type_; + DataBuffer data_; + }; + + // This is an SSLRecordWriteCallback that stages data. + static SECStatus StageRecordData(PRFileDesc* fd, PRUint16 epoch, + SSLContentType content_type, + const PRUint8* data, unsigned int len, + void* arg) { + auto stage = reinterpret_cast(arg); + stage->records_.push_back(StagedRecord(stage->agent_->role_str(), epoch, + content_type, data, + static_cast(len))); + return SECSuccess; + } + + std::shared_ptr& agent_; + std::deque records_; +}; + +// Attempting to feed application data in before the handshake is complete +// should be caught. +static void RefuseApplicationData(std::shared_ptr& peer, + uint16_t epoch) { + static const uint8_t d[] = {1, 2, 3}; + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(peer->ssl_fd(), epoch, ssl_ct_application_data, + d, static_cast(sizeof(d)))); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +static void SendForwardReceive(std::shared_ptr& sender, + StagedRecords& sender_stage, + std::shared_ptr& receiver) { + const size_t count = 10; + sender->SendData(count, count); + sender_stage.ForwardAll(receiver); + receiver->ReadBytes(count); +} + +TEST_P(TlsConnectStream, ReplaceRecordLayer) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + // BadPrSocket installs an IO layer that crashes when the SSL layer attempts + // to read or write. + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + + // StagedRecords installs a handler for unprotected data from the socket, and + // captures that data. + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + // Both peers should refuse application data from epoch 0. + RefuseApplicationData(client_, 0); + RefuseApplicationData(server_, 0); + + // This first call forwards nothing, but it causes the client to handshake, + // which starts things off. This stages the ClientHello as a result. + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + // This processes the ClientHello and stages the first server flight. + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + RefuseApplicationData(server_, 1); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // Process the server flight and the client is done. + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + } else { + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + RefuseApplicationData(client_, 1); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + } + CheckKeys(); + + // Reading and writing application data should work. + SendForwardReceive(client_, client_stage, server_); + SendForwardReceive(server_, server_stage, client_); +} + +static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) { + return SECWouldBlock; +} + +TEST_P(TlsConnectStream, ReplaceRecordLayerAsyncLateAuth) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + + // Prior to TLS 1.3, the client sends its second flight immediately. But in + // TLS 1.3, a client won't send a Finished until it is happy with the server + // certificate. So blocking certificate validation causes the client to send + // nothing. + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_TRUE(client_stage.empty()); + + // Client should have stopped reading when it saw the Certificate message, + // so it will be reading handshake epoch, and writing cleartext. + client_->CheckEpochs(2, 0); + // Server should be reading handshake, and writing application data. + server_->CheckEpochs(2, 3); + + // Handshake again and the client will read the remainder of the server's + // flight, but it will remain blocked. + client_->Handshake(); + ASSERT_TRUE(client_stage.empty()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + } else { + // In prior versions, the client's second flight is always sent. + ASSERT_FALSE(client_stage.empty()); + } + + // Now declare the certificate good. + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); + ASSERT_FALSE(client_stage.empty()); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + } else { + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + } + CheckKeys(); + + // Reading and writing application data should work. + SendForwardReceive(client_, client_stage, server_); +} + +TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerAsyncPostHandshake) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + + ASSERT_TRUE(client_stage.empty()); + client_->Handshake(); + ASSERT_TRUE(client_stage.empty()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // Now declare the certificate good. + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); + ASSERT_FALSE(client_stage.empty()); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + } else { + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + } + CheckKeys(); + + // Reading and writing application data should work. + SendForwardReceive(client_, client_stage, server_); + + // Post-handshake messages should work here. + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), nullptr, 0)); + SendForwardReceive(server_, server_stage, client_); +} + +// This test ensures that data is correctly forwarded when the handshake is +// resumed after asynchronous server certificate authentication, when +// SSL_AuthCertificateComplete() is called. The logic for resuming the +// handshake involves a different code path than the usual one, so this test +// exercises that code fully. +TEST_F(TlsConnectStreamTls13, ReplaceRecordLayerAsyncEarlyAuth) { + StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + + BadPrSocket bad_layer_client(client_); + BadPrSocket bad_layer_server(server_); + StagedRecords client_stage(client_); + StagedRecords server_stage(server_); + + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTING); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTING); + + // Send a partial flight on to the client. + // This includes enough to trigger the certificate callback. + server_stage.ForwardPartial(client_); + EXPECT_TRUE(client_stage.empty()); + + // Declare the certificate good. + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); + EXPECT_TRUE(client_stage.empty()); + + // Send the remainder of the server flight. + PRBool pending = PR_FALSE; + EXPECT_EQ(SECSuccess, + SSLInt_HasPendingHandshakeData(client_->ssl_fd(), &pending)); + EXPECT_EQ(PR_TRUE, pending); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + server_stage.ForwardAll(client_, TlsAgent::STATE_CONNECTED); + client_stage.ForwardAll(server_, TlsAgent::STATE_CONNECTED); + CheckKeys(); + + SendForwardReceive(server_, server_stage, client_); +} + +TEST_P(TlsConnectStream, ForwardDataFromWrongEpoch) { + const uint8_t data[] = {1}; + Connect(); + uint16_t next_epoch; + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), 2, ssl_ct_application_data, + data, sizeof(data))); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()) + << "Passing data from an old epoch is rejected"; + next_epoch = 4; + } else { + // Prior to TLS 1.3, the epoch is only updated once during the handshake. + next_epoch = 2; + } + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), next_epoch, + ssl_ct_application_data, data, sizeof(data))); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()) + << "Passing data from a future epoch blocks"; +} + +TEST_F(TlsConnectStreamTls13, ForwardInvalidData) { + const uint8_t data[1] = {0}; + + EnsureTlsSetup(); + // Zero-length data. + EXPECT_EQ(SECFailure, SSL_RecordLayerData(client_->ssl_fd(), 0, + ssl_ct_application_data, data, 0)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // NULL data. + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), 0, ssl_ct_application_data, + nullptr, 1)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectDatagram13, ForwardDataDtls) { + EnsureTlsSetup(); + const uint8_t data[1] = {0}; + EXPECT_EQ(SECFailure, + SSL_RecordLayerData(client_->ssl_fd(), 0, ssl_ct_application_data, + data, sizeof(data))); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_recordsize_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_recordsize_unittest.cc index 0a54ae1a8..f2003a358 100644 --- a/security/nss/gtests/ssl_gtest/ssl_recordsize_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_recordsize_unittest.cc @@ -123,9 +123,11 @@ TEST_P(TlsConnectGeneric, RecordSizeMaximum) { EnsureTlsSetup(); auto client_max = MakeTlsFilter(client_); - client_max->EnableDecryption(); auto server_max = MakeTlsFilter(server_); - server_max->EnableDecryption(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + client_max->EnableDecryption(); + server_max->EnableDecryption(); + } Connect(); client_->SendData(send_size, send_size); @@ -140,7 +142,9 @@ TEST_P(TlsConnectGeneric, RecordSizeMaximum) { TEST_P(TlsConnectGeneric, RecordSizeMinimumClient) { EnsureTlsSetup(); auto server_max = MakeTlsFilter(server_); - server_max->EnableDecryption(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + server_max->EnableDecryption(); + } client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64); Connect(); @@ -152,7 +156,9 @@ TEST_P(TlsConnectGeneric, RecordSizeMinimumClient) { TEST_P(TlsConnectGeneric, RecordSizeMinimumServer) { EnsureTlsSetup(); auto client_max = MakeTlsFilter(client_); - client_max->EnableDecryption(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + client_max->EnableDecryption(); + } server_->SetOption(SSL_RECORD_SIZE_LIMIT, 64); Connect(); @@ -164,9 +170,11 @@ TEST_P(TlsConnectGeneric, RecordSizeMinimumServer) { TEST_P(TlsConnectGeneric, RecordSizeAsymmetric) { EnsureTlsSetup(); auto client_max = MakeTlsFilter(client_); - client_max->EnableDecryption(); auto server_max = MakeTlsFilter(server_); - server_max->EnableDecryption(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + client_max->EnableDecryption(); + server_max->EnableDecryption(); + } client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64); server_->SetOption(SSL_RECORD_SIZE_LIMIT, 100); @@ -222,14 +230,15 @@ TEST_P(TlsConnectTls13, RecordSizePlaintextExceed) { // Tweak the ciphertext of server records so that they greatly exceed the limit. // This requires a much larger expansion than for plaintext to trigger the -// guard, which runs before decryption (current allowance is 304 octets). +// guard, which runs before decryption (current allowance is 320 octets, +// see MAX_EXPANSION in ssl3con.c). TEST_P(TlsConnectTls13, RecordSizeCiphertextExceed) { EnsureTlsSetup(); client_->SetOption(SSL_RECORD_SIZE_LIMIT, 64); Connect(); - auto server_expand = MakeTlsFilter(server_, 320); + auto server_expand = MakeTlsFilter(server_, 336); server_->SendData(100); client_->ExpectReadWriteError(); @@ -256,9 +265,11 @@ class TlsRecordPadder : public TlsRecordFilter { return KEEP; } + uint16_t protection_epoch; uint8_t inner_content_type; DataBuffer plaintext; - if (!Unprotect(header, record, &inner_content_type, &plaintext)) { + if (!Unprotect(header, record, &protection_epoch, &inner_content_type, + &plaintext)) { return KEEP; } @@ -267,8 +278,8 @@ class TlsRecordPadder : public TlsRecordFilter { } DataBuffer ciphertext; - bool ok = - Protect(header, inner_content_type, plaintext, &ciphertext, padding_); + bool ok = Protect(spec(protection_epoch), header, inner_content_type, + plaintext, &ciphertext, padding_); EXPECT_TRUE(ok); if (!ok) { return KEEP; @@ -334,7 +345,9 @@ TEST_P(TlsConnectGeneric, RecordSizeCapExtensionClient) { client_->SetOption(SSL_RECORD_SIZE_LIMIT, 16385); auto capture = MakeTlsFilter(client_, ssl_record_size_limit_xtn); - capture->EnableDecryption(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + capture->EnableDecryption(); + } Connect(); uint64_t val = 0; @@ -352,7 +365,9 @@ TEST_P(TlsConnectGeneric, RecordSizeCapExtensionServer) { server_->SetOption(SSL_RECORD_SIZE_LIMIT, 16385); auto capture = MakeTlsFilter(server_, ssl_record_size_limit_xtn); - capture->EnableDecryption(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + capture->EnableDecryption(); + } Connect(); uint64_t val = 0; @@ -393,10 +408,24 @@ TEST_P(TlsConnectGeneric, RecordSizeServerExtensionInvalid) { static const uint8_t v[] = {0xf4, 0x1f}; auto replace = MakeTlsFilter( server_, ssl_record_size_limit_xtn, DataBuffer(v, sizeof(v))); - replace->EnableDecryption(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + replace->EnableDecryption(); + } ConnectExpectAlert(client_, kTlsAlertIllegalParameter); } +TEST_P(TlsConnectGeneric, RecordSizeServerExtensionExtra) { + EnsureTlsSetup(); + server_->SetOption(SSL_RECORD_SIZE_LIMIT, 1000); + static const uint8_t v[] = {0x01, 0x00, 0x00}; + auto replace = MakeTlsFilter( + server_, ssl_record_size_limit_xtn, DataBuffer(v, sizeof(v))); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + replace->EnableDecryption(); + } + ConnectExpectAlert(client_, kTlsAlertDecodeError); +} + class RecordSizeDefaultsTest : public ::testing::Test { public: void SetUp() { diff --git a/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc index a902a5f7f..072a1836c 100644 --- a/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc @@ -11,6 +11,11 @@ #include "sslerr.h" #include "sslproto.h" +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + #include "gtest_utils.h" #include "tls_connect.h" @@ -34,6 +39,24 @@ TEST_P(TlsConnectStreamPre13, RenegotiateServer) { CheckConnected(); } +TEST_P(TlsConnectStreamPre13, RenegotiateRandoms) { + SSL3Random crand1, crand2, srand1, srand2; + Connect(); + EXPECT_EQ(SECSuccess, + SSLInt_GetHandshakeRandoms(client_->ssl_fd(), crand1, srand1)); + + // Renegotiate and check that both randoms have changed. + client_->PrepareForRenegotiate(); + server_->StartRenegotiate(); + Handshake(); + CheckConnected(); + EXPECT_EQ(SECSuccess, + SSLInt_GetHandshakeRandoms(client_->ssl_fd(), crand2, srand2)); + + EXPECT_NE(0, memcmp(crand1, crand2, sizeof(SSL3Random))); + EXPECT_NE(0, memcmp(srand1, srand2, sizeof(SSL3Random))); +} + // The renegotiation options shouldn't cause an error if TLS 1.3 is chosen. TEST_F(TlsConnectTest, RenegotiationConfigTls13) { EnsureTlsSetup(); diff --git a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc index 264bde67f..bfc3ccfeb 100644 --- a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc @@ -325,14 +325,17 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicketForget) { SendReceive(); } +// Tickets last two days maximum; this is a time longer than that. +static const PRTime kLongerThanTicketLifetime = + 3LL * 24 * 60 * 60 * PR_USEC_PER_SEC; + TEST_P(TlsConnectGenericResumption, ConnectWithExpiredTicketAtClient) { - SSLInt_SetTicketLifetime(1); // one second // This causes a ticket resumption. ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); Connect(); SendReceive(); - WAIT_(false, 1000); + AdvanceTime(kLongerThanTicketLifetime); Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); @@ -354,7 +357,6 @@ TEST_P(TlsConnectGenericResumption, ConnectWithExpiredTicketAtClient) { } TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) { - SSLInt_SetTicketLifetime(1); // one second // This causes a ticket resumption. ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); Connect(); @@ -373,7 +375,7 @@ TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) { EXPECT_TRUE(capture->captured()); EXPECT_LT(0U, capture->extension().len()); - WAIT_(false, 1000); // Let the ticket expire on the server. + AdvanceTime(kLongerThanTicketLifetime); Handshake(); CheckConnected(); @@ -421,6 +423,7 @@ static int32_t SwitchCertificates(TlsAgent* agent, const SECItem* srvNameArr, TEST_P(TlsConnectGeneric, ServerSNICertSwitch) { Connect(); ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); Reset(); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); @@ -429,6 +432,7 @@ TEST_P(TlsConnectGeneric, ServerSNICertSwitch) { Connect(); ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); CheckKeys(); EXPECT_FALSE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); } @@ -437,6 +441,7 @@ TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) { Reset(TlsAgent::kServerEcdsa256); Connect(); ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); Reset(); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); @@ -447,6 +452,7 @@ TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) { Connect(); ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); } @@ -531,6 +537,7 @@ TEST_P(TlsConnectTls13, TestTls13ResumeNoCertificateRequest) { Connect(); SendReceive(); // Need to read so that we absorb the session ticket. ScopedCERTCertificate cert1(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); @@ -546,6 +553,7 @@ TEST_P(TlsConnectTls13, TestTls13ResumeNoCertificateRequest) { // Sanity check whether the client certificate matches the one // decrypted from ticket. ScopedCERTCertificate cert2(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); } @@ -561,6 +569,7 @@ TEST_P(TlsConnectTls13, WriteBeforeHandshakeCompleteOnResumption) { Connect(); SendReceive(); // Absorb the session ticket. ScopedCERTCertificate cert1(SSL_LocalCertificate(client_->ssl_fd())); + ASSERT_NE(nullptr, cert1.get()); Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); @@ -577,6 +586,7 @@ TEST_P(TlsConnectTls13, WriteBeforeHandshakeCompleteOnResumption) { // Check whether the client certificate matches the one from the ticket. ScopedCERTCertificate cert2(SSL_PeerCertificate(server_->ssl_fd())); + ASSERT_NE(nullptr, cert2.get()); EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); } @@ -589,15 +599,17 @@ static uint16_t ChooseOneCipher(uint16_t version) { return TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA; } -static uint16_t ChooseAnotherCipher(uint16_t version) { +static uint16_t ChooseIncompatibleCipher(uint16_t version) { if (version >= SSL_LIBRARY_VERSION_TLS_1_3) { return TLS_AES_256_GCM_SHA384; } return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA; } -// Test that we don't resume when we can't negotiate the same cipher. -TEST_P(TlsConnectGenericResumption, TestResumeClientDifferentCipher) { +// Test that we don't resume when we can't negotiate the same cipher. Note that +// for TLS 1.3, resumption is allowed between compatible ciphers, that is those +// with the same KDF hash, but we choose an incompatible one here. +TEST_P(TlsConnectGenericResumption, ResumeClientIncompatibleCipher) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); client_->EnableSingleCipher(ChooseOneCipher(version_)); Connect(); @@ -607,7 +619,7 @@ TEST_P(TlsConnectGenericResumption, TestResumeClientDifferentCipher) { Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ExpectResumption(RESUME_NONE); - client_->EnableSingleCipher(ChooseAnotherCipher(version_)); + client_->EnableSingleCipher(ChooseIncompatibleCipher(version_)); uint16_t ticket_extension; if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { ticket_extension = ssl_tls13_pre_shared_key_xtn; @@ -622,24 +634,24 @@ TEST_P(TlsConnectGenericResumption, TestResumeClientDifferentCipher) { } // Test that we don't resume when we can't negotiate the same cipher. -TEST_P(TlsConnectGenericResumption, TestResumeServerDifferentCipher) { +TEST_P(TlsConnectGenericResumption, ResumeServerIncompatibleCipher) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); server_->EnableSingleCipher(ChooseOneCipher(version_)); Connect(); - SendReceive(); // Need to read so that we absorb the session ticket. + SendReceive(); // Absorb the session ticket. CheckKeys(); Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ExpectResumption(RESUME_NONE); - server_->EnableSingleCipher(ChooseAnotherCipher(version_)); + server_->EnableSingleCipher(ChooseIncompatibleCipher(version_)); Connect(); CheckKeys(); } // Test that the client doesn't tolerate the server picking a different cipher // suite for resumption. -TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { +TEST_P(TlsConnectStream, ResumptionOverrideCipher) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); server_->EnableSingleCipher(ChooseOneCipher(version_)); Connect(); @@ -648,8 +660,8 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - MakeTlsFilter(server_, - ChooseAnotherCipher(version_)); + MakeTlsFilter( + server_, ChooseIncompatibleCipher(version_)); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { client_->ExpectSendAlert(kTlsAlertIllegalParameter); @@ -668,6 +680,38 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { } } +// In TLS 1.3, it is possible to resume with a different cipher if it has the +// same hash. +TEST_P(TlsConnectTls13, ResumeClientCompatibleCipher) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + Connect(); + SendReceive(); // Absorb the session ticket. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + client_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256); + Connect(); + CheckKeys(); +} + +TEST_P(TlsConnectTls13, ResumeServerCompatibleCipher) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + Connect(); + SendReceive(); // Absorb the session ticket. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + server_->EnableSingleCipher(TLS_CHACHA20_POLY1305_SHA256); + Connect(); + CheckKeys(); +} + class SelectedVersionReplacer : public TlsHandshakeFilter { public: SelectedVersionReplacer(const std::shared_ptr& a, uint16_t version) @@ -757,7 +801,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { ASSERT_LT(0U, initialTicket.len()); ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd())); - ASSERT_TRUE(!!cert1.get()); + ASSERT_NE(nullptr, cert1.get()); Reset(); ClearStats(); @@ -773,7 +817,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { ASSERT_LT(0U, c2->extension().len()); ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); - ASSERT_TRUE(!!cert2.get()); + ASSERT_NE(nullptr, cert2.get()); // Check that the cipher suite is reported the same on both sides, though in // TLS 1.3 resumption actually negotiates a different cipher suite. @@ -1109,7 +1153,7 @@ TEST_P(TlsConnectGenericResumption, ReConnectAgainTicket) { ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); } -void CheckGetInfoResult(uint32_t alpnSize, uint32_t earlyDataSize, +void CheckGetInfoResult(PRTime now, uint32_t alpnSize, uint32_t earlyDataSize, ScopedCERTCertificate& cert, ScopedSSLResumptionTokenInfo& token) { ASSERT_TRUE(cert); @@ -1125,7 +1169,7 @@ void CheckGetInfoResult(uint32_t alpnSize, uint32_t earlyDataSize, ASSERT_EQ(earlyDataSize, token->maxEarlyDataSize); - ASSERT_LT(ssl_TimeUsec(), token->expirationTime); + ASSERT_LT(now, token->expirationTime); } // The client should generate a new, randomized session_id @@ -1174,8 +1218,9 @@ TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfo) { client_->GetTokenInfo(token); ScopedCERTCertificate cert( PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); + ASSERT_NE(nullptr, cert.get()); - CheckGetInfoResult(0, 0, cert, token); + CheckGetInfoResult(now(), 0, 0, cert, token); Handshake(); CheckConnected(); @@ -1183,6 +1228,56 @@ TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfo) { SendReceive(); } +TEST_P(TlsConnectGenericResumptionToken, RefuseExpiredTicketClient) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + + // Move the clock to the expiration time of the ticket. + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + AdvanceTime(token->expirationTime - now()); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + ASSERT_EQ(SECFailure, + SSL_SetResumptionToken(client_->ssl_fd(), + client_->GetResumptionToken().data(), + client_->GetResumptionToken().size())); + EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError()); +} + +TEST_P(TlsConnectGenericResumptionToken, RefuseExpiredTicketServer) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_NONE); + + // Start the handshake and send the ClientHello. + StartConnect(); + ASSERT_EQ(SECSuccess, + SSL_SetResumptionToken(client_->ssl_fd(), + client_->GetResumptionToken().data(), + client_->GetResumptionToken().size())); + client_->Handshake(); + + // Move the clock to the expiration time of the ticket. + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + AdvanceTime(token->expirationTime - now()); + + Handshake(); + CheckConnected(); +} + TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfoAlpn) { EnableAlpn(); ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); @@ -1204,8 +1299,9 @@ TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfoAlpn) { client_->GetTokenInfo(token); ScopedCERTCertificate cert( PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); + ASSERT_NE(nullptr, cert.get()); - CheckGetInfoResult(1, 0, cert, token); + CheckGetInfoResult(now(), 1, 0, cert, token); Handshake(); CheckConnected(); @@ -1216,7 +1312,7 @@ TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfoAlpn) { TEST_P(TlsConnectTls13ResumptionToken, ConnectResumeGetInfoZeroRtt) { EnableAlpn(); - SSLInt_RolloverAntiReplay(); + RolloverAntiReplay(); ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); server_->Set0RttEnabled(true); Connect(); @@ -1239,8 +1335,8 @@ TEST_P(TlsConnectTls13ResumptionToken, ConnectResumeGetInfoZeroRtt) { client_->GetTokenInfo(token); ScopedCERTCertificate cert( PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); - - CheckGetInfoResult(1, 1024, cert, token); + ASSERT_NE(nullptr, cert.get()); + CheckGetInfoResult(now(), 1, 1024, cert, token); ZeroRttSendReceive(true, true); Handshake(); @@ -1272,6 +1368,54 @@ TEST_P(TlsConnectGenericResumption, ConnectResumeClientAuth) { SendReceive(); } +// Check that resumption is blocked if the server requires client auth. +TEST_P(TlsConnectGenericResumption, ClientAuthRequiredOnResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->RequestClientAuth(false); + Connect(); + SendReceive(); + + Reset(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +// Check that resumption is blocked if the server requires client auth and +// the client fails to provide a certificate. +TEST_P(TlsConnectGenericResumption, ClientAuthRequiredOnResumptionNoCert) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->RequestClientAuth(false); + Connect(); + SendReceive(); + + Reset(); + server_->RequestClientAuth(true); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + // Drive handshake manually because TLS 1.3 needs it. + StartConnect(); + client_->Handshake(); // CH + server_->Handshake(); // SH.. (no resumption) + client_->Handshake(); // ... + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // In TLS 1.3, the client thinks that everything is OK here. + ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + ExpectAlert(server_, kTlsAlertCertificateRequired); + server_->Handshake(); // Alert + client_->Handshake(); // Receive Alert + client_->CheckErrorCode(SSL_ERROR_RX_CERTIFICATE_REQUIRED_ALERT); + } else { + ExpectAlert(server_, kTlsAlertBadCertificate); + server_->Handshake(); // Alert + client_->Handshake(); // Receive Alert + client_->CheckErrorCode(SSL_ERROR_BAD_CERT_ALERT); + } + server_->CheckErrorCode(SSL_ERROR_NO_CERTIFICATE); +} + TEST_F(TlsConnectStreamTls13, ExternalTokenAfterHrr) { ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); Connect(); diff --git a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc index ffc0893e9..3255bd512 100644 --- a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc @@ -55,6 +55,10 @@ TEST_P(TlsConnectGeneric, ServerNegotiateTls12) { // two validate that we can also detect fallback using the // SSL_SetDowngradeCheckVersion() API. TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) { + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_2); client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE); MakeTlsFilter(client_, SSL_LIBRARY_VERSION_TLS_1_1); @@ -116,11 +120,11 @@ TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) { TEST_F(TlsConnectTest, TestFallbackFromTls12) { client_->SetOption(SSL_ENABLE_HELLO_DOWNGRADE_CHECK, PR_TRUE); - client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_2); client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, SSL_LIBRARY_VERSION_TLS_1_1); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, SSL_LIBRARY_VERSION_TLS_1_2); + client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_2); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); @@ -269,4 +273,11 @@ TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) { ASSERT_LT(static_cast(SSL_LIBRARY_VERSION_TLS_1_2), version); } +// Offer 1.3 but with ClientHello.legacy_version == SSL 3.0. This +// causes a protocol version alert. See RFC 8446 Appendix D.5. +TEST_F(TlsConnectStreamTls13, Ssl30ClientHelloWithSupportedVersions) { + MakeTlsFilter(client_, SSL_LIBRARY_VERSION_3_0); + ConnectExpectAlert(server_, kTlsAlertProtocolVersion); +} + } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc index a75dbb7aa..44e685414 100644 --- a/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc @@ -214,12 +214,6 @@ class TestPolicyVersionRange ASSERT_EQ(SECSuccess, rv); rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, saved_max_dtls_); ASSERT_EQ(SECSuccess, rv); - // If it wasn't set initially, clear the bit that we set. - if (!(saved_algorithm_policy_ & NSS_USE_POLICY_IN_SSL)) { - rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, 0, - NSS_USE_POLICY_IN_SSL); - ASSERT_EQ(SECSuccess, rv); - } } private: @@ -233,16 +227,12 @@ class TestPolicyVersionRange ASSERT_EQ(SECSuccess, rv); rv = NSS_OptionGet(NSS_DTLS_VERSION_MAX_POLICY, &saved_max_dtls_); ASSERT_EQ(SECSuccess, rv); - rv = NSS_GetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, - &saved_algorithm_policy_); - ASSERT_EQ(SECSuccess, rv); } int32_t saved_min_tls_; int32_t saved_max_tls_; int32_t saved_min_dtls_; int32_t saved_max_dtls_; - uint32_t saved_algorithm_policy_; }; VersionPolicy saved_version_policy_; diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc index 6d792c520..4a7f91459 100644 --- a/security/nss/gtests/ssl_gtest/test_io.cc +++ b/security/nss/gtests/ssl_gtest/test_io.cc @@ -25,10 +25,13 @@ namespace nss_test { if (g_ssl_gtest_verbose) LOG(a); \ } while (false) +PRDescIdentity DummyPrSocket::LayerId() { + static PRDescIdentity id = PR_GetUniqueIdentity("dummysocket"); + return id; +} + ScopedPRFileDesc DummyPrSocket::CreateFD() { - static PRDescIdentity test_fd_identity = - PR_GetUniqueIdentity("testtransportadapter"); - return DummyIOLayerMethods::CreateFD(test_fd_identity, this); + return DummyIOLayerMethods::CreateFD(DummyPrSocket::LayerId(), this); } void DummyPrSocket::Reset() { @@ -136,19 +139,18 @@ int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) { DataBuffer filtered; PacketFilter::Action action = PacketFilter::KEEP; if (filter_) { + LOGV("Original packet: " << packet); action = filter_->Process(packet, &filtered); } switch (action) { case PacketFilter::CHANGE: - LOG("Original packet: " << packet); LOG("Filtered packet: " << filtered); dst->PacketReceived(filtered); break; case PacketFilter::DROP: - LOG("Droppped packet: " << packet); + LOG("Drop packet"); break; case PacketFilter::KEEP: - LOGV("Packet: " << packet); dst->PacketReceived(packet); break; } diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h index 062ae86c8..e262fb123 100644 --- a/security/nss/gtests/ssl_gtest/test_io.h +++ b/security/nss/gtests/ssl_gtest/test_io.h @@ -33,9 +33,11 @@ class PacketFilter { CHANGE, // change the packet to a different value DROP // drop the packet }; - PacketFilter(bool enabled = true) : enabled_(enabled) {} + explicit PacketFilter(bool on = true) : enabled_(on) {} virtual ~PacketFilter() {} + bool enabled() const { return enabled_; } + virtual Action Process(const DataBuffer& input, DataBuffer* output) { if (!enabled_) { return KEEP; @@ -68,6 +70,8 @@ class DummyPrSocket : public DummyIOLayerMethods { write_error_(0) {} virtual ~DummyPrSocket() {} + static PRDescIdentity LayerId(); + // Create a file descriptor that will reference this object. The fd must not // live longer than this adapter; call PR_Close() before. ScopedPRFileDesc CreateFD(); diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc index fb66196b5..88640481e 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.cc +++ b/security/nss/gtests/ssl_gtest/tls_agent.cc @@ -47,6 +47,8 @@ const std::string TlsAgent::kServerEcdsa521 = "ecdsa521"; const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa"; const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa"; const std::string TlsAgent::kServerDsa = "dsa"; +const std::string TlsAgent::kDelegatorEcdsa256 = "delegator_ecdsa256"; +const std::string TlsAgent::kDelegatorRsae2048 = "delegator_rsae2048"; static const uint8_t kCannedTls13ServerHello[] = { 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, @@ -127,16 +129,76 @@ void TlsAgent::SetState(State s) { ScopedCERTCertificate* cert, ScopedSECKEYPrivateKey* priv) { cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr)); + EXPECT_NE(nullptr, cert); + if (!cert) return false; EXPECT_NE(nullptr, cert->get()); if (!cert->get()) return false; priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr)); + EXPECT_NE(nullptr, priv); + if (!priv) return false; EXPECT_NE(nullptr, priv->get()); if (!priv->get()) return false; return true; } +// Loads a key pair from the certificate identified by |id|. +/*static*/ bool TlsAgent::LoadKeyPairFromCert(const std::string& name, + ScopedSECKEYPublicKey* pub, + ScopedSECKEYPrivateKey* priv) { + ScopedCERTCertificate cert; + if (!TlsAgent::LoadCertificate(name, &cert, priv)) { + return false; + } + + pub->reset(SECKEY_ExtractPublicKey(&cert->subjectPublicKeyInfo)); + if (!pub->get()) { + return false; + } + + return true; +} + +void TlsAgent::DelegateCredential(const std::string& name, + const ScopedSECKEYPublicKey& dc_pub, + SSLSignatureScheme dc_cert_verify_alg, + PRUint32 dc_valid_for, PRTime now, + SECItem* dc) { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey cert_priv; + EXPECT_TRUE(TlsAgent::LoadCertificate(name, &cert, &cert_priv)) + << "Could not load delegate certificate: " << name + << "; test db corrupt?"; + + EXPECT_EQ(SECSuccess, + SSL_DelegateCredential(cert.get(), cert_priv.get(), dc_pub.get(), + dc_cert_verify_alg, dc_valid_for, now, dc)); +} + +void TlsAgent::EnableDelegatedCredentials() { + ASSERT_TRUE(EnsureTlsSetup()); + SetOption(SSL_ENABLE_DELEGATED_CREDENTIALS, PR_TRUE); +} + +void TlsAgent::AddDelegatedCredential(const std::string& dc_name, + SSLSignatureScheme dc_cert_verify_alg, + PRUint32 dc_valid_for, PRTime now) { + ASSERT_TRUE(EnsureTlsSetup()); + + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(dc_name, &pub, &priv)); + + StackSECItem dc; + TlsAgent::DelegateCredential(name_, pub, dc_cert_verify_alg, dc_valid_for, + now, &dc); + + SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr, + nullptr, &dc, priv.get()}; + EXPECT_TRUE(ConfigServerCert(name_, true, &extra_data)); +} + bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits, const SSLExtraServerCertData* serverCertData) { ScopedCERTCertificate cert; @@ -224,6 +286,9 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; + // All these tests depend on having this disabled to start with. + SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_FALSE); + return true; } @@ -251,6 +316,10 @@ bool TlsAgent::MaybeSetResumptionToken() { return true; } +void TlsAgent::SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx) { + EXPECT_EQ(SECSuccess, SSL_SetAntiReplayContext(ssl_fd_.get(), ctx.get())); +} + void TlsAgent::SetupClientAuth() { EXPECT_TRUE(EnsureTlsSetup()); ASSERT_EQ(CLIENT, role_); @@ -279,7 +348,7 @@ SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd())); EXPECT_TRUE(peerCert) << "Client should be able to see the server cert"; - // See bug 1457716 + // See bug 1573945 // CheckCertReqAgainstDefaultCAs(caNames); ScopedCERTCertificate cert; @@ -640,6 +709,16 @@ void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, } } +void TlsAgent::CheckEpochs(uint16_t expected_read, + uint16_t expected_write) const { + uint16_t read_epoch = 0; + uint16_t write_epoch = 0; + EXPECT_EQ(SECSuccess, + SSL_GetCurrentEpoch(ssl_fd(), &read_epoch, &write_epoch)); + EXPECT_EQ(expected_read, read_epoch) << role_str() << " read epoch"; + EXPECT_EQ(expected_write, write_epoch) << role_str() << " write epoch"; +} + void TlsAgent::EnableSrtp() { EXPECT_TRUE(EnsureTlsSetup()); const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80, @@ -715,26 +794,26 @@ void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { } void TlsAgent::CheckPreliminaryInfo() { - SSLPreliminaryChannelInfo info; + SSLPreliminaryChannelInfo preinfo; EXPECT_EQ(SECSuccess, - SSL_GetPreliminaryChannelInfo(ssl_fd(), &info, sizeof(info))); - EXPECT_EQ(sizeof(info), info.length); - EXPECT_TRUE(info.valuesSet & ssl_preinfo_version); - EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite); + SSL_GetPreliminaryChannelInfo(ssl_fd(), &preinfo, sizeof(preinfo))); + EXPECT_EQ(sizeof(preinfo), preinfo.length); + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_version); + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite); // A version of 0 is invalid and indicates no expectation. This value is // initialized to 0 so that tests that don't explicitly set an expected // version can negotiate a version. if (!expected_version_) { - expected_version_ = info.protocolVersion; + expected_version_ = preinfo.protocolVersion; } - EXPECT_EQ(expected_version_, info.protocolVersion); + EXPECT_EQ(expected_version_, preinfo.protocolVersion); // As with the version; 0 is the null cipher suite (and also invalid). if (!expected_cipher_suite_) { - expected_cipher_suite_ = info.cipherSuite; + expected_cipher_suite_ = preinfo.cipherSuite; } - EXPECT_EQ(expected_cipher_suite_, info.cipherSuite); + EXPECT_EQ(expected_cipher_suite_, preinfo.cipherSuite); } // Check that all the expected callbacks have been called. @@ -766,6 +845,13 @@ void TlsAgent::ResetPreliminaryInfo() { expected_cipher_suite_ = 0; } +void TlsAgent::UpdatePreliminaryChannelInfo() { + SECStatus rv = SSL_GetPreliminaryChannelInfo(ssl_fd_.get(), &pre_info_, + sizeof(pre_info_)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(sizeof(pre_info_), pre_info_.length); +} + void TlsAgent::ValidateCipherSpecs() { PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd()); // We use one ciphersuite in each direction. @@ -828,6 +914,7 @@ void TlsAgent::Connected() { // Preliminary values are exposed through callbacks during the handshake. // If either expected values were set or the callbacks were called, check // that the final values are correct. + UpdatePreliminaryChannelInfo(); EXPECT_EQ(expected_version_, info_.protocolVersion); EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite); diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h index 020221868..4b6cce8e0 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.h +++ b/security/nss/gtests/ssl_gtest/tls_agent.h @@ -76,6 +76,8 @@ class TlsAgent : public PollTarget { static const std::string kServerEcdhEcdsa; static const std::string kServerEcdhRsa; static const std::string kServerDsa; + static const std::string kDelegatorEcdsa256; // draft-ietf-tls-subcerts + static const std::string kDelegatorRsae2048; // draft-ietf-tls-subcerts TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant); virtual ~TlsAgent(); @@ -108,9 +110,32 @@ class TlsAgent : public PollTarget { void PrepareForRenegotiate(); // Prepares for renegotiation, then actually triggers it. void StartRenegotiate(); + void SetAntiReplayContext(ScopedSSLAntiReplayContext& ctx); + static bool LoadCertificate(const std::string& name, ScopedCERTCertificate* cert, ScopedSECKEYPrivateKey* priv); + static bool LoadKeyPairFromCert(const std::string& name, + ScopedSECKEYPublicKey* pub, + ScopedSECKEYPrivateKey* priv); + + // Delegated credentials. + // + // Generate a delegated credential and sign it using the certificate + // associated with |name|. + static void DelegateCredential(const std::string& name, + const ScopedSECKEYPublicKey& dcPub, + SSLSignatureScheme dcCertVerifyAlg, + PRUint32 dcValidFor, PRTime now, SECItem* dc); + // Indicate support for the delegated credentials extension. + void EnableDelegatedCredentials(); + // Generate and configure a delegated credential to use in the handshake with + // clients that support this extension.. + void AddDelegatedCredential(const std::string& dc_name, + SSLSignatureScheme dcCertVerifyAlg, + PRUint32 dcValidFor, PRTime now); + void UpdatePreliminaryChannelInfo(); + bool ConfigServerCert(const std::string& name, bool updateKeyBits = false, const SSLExtraServerCertData* serverCertData = nullptr); bool ConfigServerCertWithChain(const std::string& name); @@ -139,6 +164,7 @@ class TlsAgent : public PollTarget { const std::string& expected = "") const; void EnableSrtp(); void CheckSrtp() const; + void CheckEpochs(uint16_t expected_read, uint16_t expected_write) const; void CheckErrorCode(int32_t expected) const; void WaitForErrorCode(int32_t expected, uint32_t delay) const; // Send data on the socket, encrypting it. @@ -199,16 +225,20 @@ class TlsAgent : public PollTarget { PRFileDesc* ssl_fd() const { return ssl_fd_.get(); } std::shared_ptr& adapter() { return adapter_; } + const SSLChannelInfo& info() const { + EXPECT_EQ(STATE_CONNECTED, state_); + return info_; + } + + const SSLPreliminaryChannelInfo& pre_info() const { return pre_info_; } + bool is_compressed() const { - return info_.compressionMethod != ssl_compression_null; + return info().compressionMethod != ssl_compression_null; } uint16_t server_key_bits() const { return server_key_bits_; } uint16_t min_version() const { return vrange_.min; } uint16_t max_version() const { return vrange_.max; } - uint16_t version() const { - EXPECT_EQ(STATE_CONNECTED, state_); - return info_.protocolVersion; - } + uint16_t version() const { return info().protocolVersion; } bool cipher_suite(uint16_t* suite) const { if (state_ != STATE_CONNECTED) return false; @@ -399,6 +429,7 @@ class TlsAgent : public PollTarget { bool handshake_callback_called_; bool resumption_callback_called_; SSLChannelInfo info_; + SSLPreliminaryChannelInfo pre_info_; SSLCipherSuiteInfo csinfo_; SSLVersionRange vrange_; PRErrorCode error_code_; diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc index c48ae38ec..28165cf7f 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.cc +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -106,6 +106,10 @@ std::string VersionString(uint16_t version) { } } +// The default anti-replay window for tests. Tests that rely on a different +// value call SSL_InitAntiReplay directly. +static PRTime kAntiReplayWindow = 100 * PR_USEC_PER_SEC; + TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version) : variant_(variant), @@ -167,18 +171,8 @@ void TlsConnectTestBase::CheckShares( void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const { - uint16_t read_epoch = 0; - uint16_t write_epoch = 0; - - EXPECT_EQ(SECSuccess, - SSLInt_GetEpochs(client_->ssl_fd(), &read_epoch, &write_epoch)); - EXPECT_EQ(server_epoch, read_epoch) << "client read epoch"; - EXPECT_EQ(client_epoch, write_epoch) << "client write epoch"; - - EXPECT_EQ(SECSuccess, - SSLInt_GetEpochs(server_->ssl_fd(), &read_epoch, &write_epoch)); - EXPECT_EQ(client_epoch, read_epoch) << "server read epoch"; - EXPECT_EQ(server_epoch, write_epoch) << "server write epoch"; + client_->CheckEpochs(server_epoch, client_epoch); + server_->CheckEpochs(client_epoch, server_epoch); } void TlsConnectTestBase::ClearStats() { @@ -193,12 +187,37 @@ void TlsConnectTestBase::ClearServerCache() { SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); } +void TlsConnectTestBase::SaveAlgorithmPolicy() { + saved_policies_.clear(); + for (auto it = algorithms_.begin(); it != algorithms_.end(); ++it) { + uint32_t policy; + SECStatus rv = NSS_GetAlgorithmPolicy(*it, &policy); + ASSERT_EQ(SECSuccess, rv); + saved_policies_.push_back(std::make_tuple(*it, policy)); + } +} + +void TlsConnectTestBase::RestoreAlgorithmPolicy() { + for (auto it = saved_policies_.begin(); it != saved_policies_.end(); ++it) { + auto algorithm = std::get<0>(*it); + auto policy = std::get<1>(*it); + SECStatus rv = NSS_SetAlgorithmPolicy( + algorithm, policy, NSS_USE_POLICY_IN_SSL | NSS_USE_ALG_IN_SSL_KX); + ASSERT_EQ(SECSuccess, rv); + } +} + +PRTime TlsConnectTestBase::TimeFunc(void* arg) { + return *reinterpret_cast(arg); +} + void TlsConnectTestBase::SetUp() { SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); SSLInt_ClearSelfEncryptKey(); - SSLInt_SetTicketLifetime(30); - SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3); + now_ = PR_Now(); + ResetAntiReplay(kAntiReplayWindow); ClearStats(); + SaveAlgorithmPolicy(); Init(); } @@ -209,6 +228,7 @@ void TlsConnectTestBase::TearDown() { SSL_ClearSessionCache(); SSLInt_ClearSelfEncryptKey(); SSL_ShutdownServerSessionIDCache(); + RestoreAlgorithmPolicy(); } void TlsConnectTestBase::Init() { @@ -220,6 +240,14 @@ void TlsConnectTestBase::Init() { } } +void TlsConnectTestBase::ResetAntiReplay(PRTime window) { + SSLAntiReplayContext* p_anti_replay = nullptr; + EXPECT_EQ(SECSuccess, + SSL_CreateAntiReplayContext(now_, window, 1, 3, &p_anti_replay)); + EXPECT_NE(nullptr, p_anti_replay); + anti_replay_.reset(p_anti_replay); +} + void TlsConnectTestBase::Reset() { // Take a copy of the names because they are about to disappear. std::string server_name = server_->name(); @@ -238,6 +266,8 @@ void TlsConnectTestBase::Reset(const std::string& server_name, server_->SkipVersionChecks(); } + std::cerr << "Reset server:" << server_name << ", client:" << client_name + << std::endl; Init(); } @@ -269,10 +299,14 @@ void TlsConnectTestBase::EnsureTlsSetup() { : nullptr)); EXPECT_TRUE(client_->EnsureTlsSetup(client_model_ ? client_model_->ssl_fd() : nullptr)); + server_->SetAntiReplayContext(anti_replay_); + EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(client_->ssl_fd(), + TlsConnectTestBase::TimeFunc, &now_)); + EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(server_->ssl_fd(), + TlsConnectTestBase::TimeFunc, &now_)); } void TlsConnectTestBase::Handshake() { - EnsureTlsSetup(); client_->SetServerKeyBits(server_->server_key_bits()); client_->Handshake(); server_->Handshake(); @@ -289,16 +323,16 @@ void TlsConnectTestBase::EnableExtendedMasterSecret() { } void TlsConnectTestBase::Connect() { - server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr); - client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr); + StartConnect(); client_->MaybeSetResumptionToken(); Handshake(); CheckConnected(); } void TlsConnectTestBase::StartConnect() { - server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr); - client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr); + EnsureTlsSetup(); + server_->StartConnect(); + client_->StartConnect(); } void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { @@ -666,8 +700,9 @@ void TlsConnectTestBase::SendReceive(size_t total) { // Do a first connection so we can do 0-RTT on the second one. void TlsConnectTestBase::SetupForZeroRtt() { + // Force rollover of the anti-replay window. // If we don't do this, then all 0-RTT attempts will be rejected. - SSLInt_RolloverAntiReplay(); + RolloverAntiReplay(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); @@ -779,12 +814,20 @@ void TlsConnectTestBase::ShiftDtlsTimers() { time_shift = time; } - if (time_shift == PR_INTERVAL_NO_TIMEOUT) { - return; + if (time_shift != PR_INTERVAL_NO_TIMEOUT) { + AdvanceTime(PR_IntervalToMicroseconds(time_shift)); + EXPECT_EQ(SECSuccess, + SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift)); + EXPECT_EQ(SECSuccess, + SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift)); } +} + +void TlsConnectTestBase::AdvanceTime(PRTime time_shift) { now_ += time_shift; } - EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift)); - EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift)); +// Advance time by a full anti-replay window. +void TlsConnectTestBase::RolloverAntiReplay() { + AdvanceTime(kAntiReplayWindow); } TlsConnectGeneric::TlsConnectGeneric() diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h index 000494501..23c60bf4f 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.h +++ b/security/nss/gtests/ssl_gtest/tls_connect.h @@ -48,6 +48,8 @@ class TlsConnectTestBase : public ::testing::Test { virtual void SetUp(); virtual void TearDown(); + PRTime now() const { return now_; } + // Initialize client and server. void Init(); // Clear the statistics. @@ -131,6 +133,13 @@ class TlsConnectTestBase : public ::testing::Test { // Move the DTLS timers for both endpoints to pop the next timer. void ShiftDtlsTimers(); + void AdvanceTime(PRTime time_shift); + + void ResetAntiReplay(PRTime window); + void RolloverAntiReplay(); + + void SaveAlgorithmPolicy(); + void RestoreAlgorithmPolicy(); protected: SSLProtocolVariant variant_; @@ -142,6 +151,7 @@ class TlsConnectTestBase : public ::testing::Test { SessionResumptionMode expected_resumption_mode_; uint8_t expected_resumptions_; std::vector> session_ids_; + ScopedSSLAntiReplayContext anti_replay_; // A simple value of "a", "b". Note that the preferred value of "a" is placed // at the end, because the NSS API follows the now defunct NPN specification, @@ -149,14 +159,24 @@ class TlsConnectTestBase : public ::testing::Test { // NSS will move this final entry to the front when used with ALPN. const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61}; + // A list of algorithm IDs whose policies need to be preserved + // around test cases. In particular, DSA is checked in + // ssl_extension_unittest.cc. + const std::vector algorithms_ = {SEC_OID_APPLY_SSL_POLICY, + SEC_OID_ANSIX9_DSA_SIGNATURE, + SEC_OID_CURVE25519, SEC_OID_SHA1}; + std::vector> saved_policies_; + private: void CheckResumption(SessionResumptionMode expected); void CheckExtendedMasterSecret(); void CheckEarlyDataAccepted(); + static PRTime TimeFunc(void* arg); bool expect_extended_master_secret_; bool expect_early_data_accepted_; bool skip_version_checks_; + PRTime now_; // Track groups and make sure that there are no duplicates. class DuplicateGroupChecker { diff --git a/security/nss/gtests/ssl_gtest/tls_esni_unittest.cc b/security/nss/gtests/ssl_gtest/tls_esni_unittest.cc index 3c860a0b2..26275e0bc 100644 --- a/security/nss/gtests/ssl_gtest/tls_esni_unittest.cc +++ b/security/nss/gtests/ssl_gtest/tls_esni_unittest.cc @@ -4,8 +4,6 @@ * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at http://mozilla.org/MPL/2.0/. */ -#include - #include "secerr.h" #include "ssl.h" @@ -57,7 +55,7 @@ static void UpdateEsniKeysChecksum(DataBuffer* buf) { buf->Write(2, sha256, 4); } -static void GenerateEsniKey(time_t windowStart, SSLNamedGroup group, +static void GenerateEsniKey(PRTime now, SSLNamedGroup group, std::vector& cipher_suites, DataBuffer* record, ScopedSECKEYPublicKey* pubKey = nullptr, @@ -70,11 +68,12 @@ static void GenerateEsniKey(time_t windowStart, SSLNamedGroup group, ASSERT_NE(nullptr, priv); SECITEM_FreeItem(&ecParams, PR_FALSE); PRUint8 encoded[1024]; - unsigned int encoded_len; + unsigned int encoded_len = 0; SECStatus rv = SSL_EncodeESNIKeys( - &cipher_suites[0], cipher_suites.size(), group, pub, 100, windowStart, - windowStart + 10, encoded, &encoded_len, sizeof(encoded)); + &cipher_suites[0], cipher_suites.size(), group, pub, 100, + (now / PR_USEC_PER_SEC) - 1, (now / PR_USEC_PER_SEC) + 10, encoded, + &encoded_len, sizeof(encoded)); ASSERT_EQ(SECSuccess, rv); ASSERT_GT(encoded_len, 0U); @@ -92,15 +91,15 @@ static void GenerateEsniKey(time_t windowStart, SSLNamedGroup group, record->Write(0, encoded, encoded_len); } -static void SetupEsni(const std::shared_ptr& client, +static void SetupEsni(PRTime now, const std::shared_ptr& client, const std::shared_ptr& server, SSLNamedGroup group = ssl_grp_ec_curve25519) { ScopedSECKEYPublicKey pub; ScopedSECKEYPrivateKey priv; DataBuffer record; - GenerateEsniKey(time(nullptr), ssl_grp_ec_curve25519, kDefaultSuites, &record, - &pub, &priv); + GenerateEsniKey(now, ssl_grp_ec_curve25519, kDefaultSuites, &record, &pub, + &priv); SECStatus rv = SSL_SetESNIKeyPair(server->ssl_fd(), priv.get(), record.data(), record.len()); ASSERT_EQ(SECSuccess, rv); @@ -124,77 +123,87 @@ static void CheckSniExtension(const DataBuffer& data) { ASSERT_EQ(expected, name); } -static void ClientInstallEsni(std::shared_ptr& agent, - const DataBuffer& record, PRErrorCode err = 0) { - SECStatus rv = - SSL_EnableESNI(agent->ssl_fd(), record.data(), record.len(), kDummySni); - if (err == 0) { - ASSERT_EQ(SECSuccess, rv); - } else { - ASSERT_EQ(SECFailure, rv); - ASSERT_EQ(err, PORT_GetError()); +class TlsAgentEsniTest : public TlsAgentTestClient13 { + public: + void SetUp() override { now_ = PR_Now(); } + + protected: + PRTime now() const { return now_; } + + void InstallEsni(const DataBuffer& record, PRErrorCode err = 0) { + SECStatus rv = SSL_EnableESNI(agent_->ssl_fd(), record.data(), record.len(), + kDummySni); + if (err == 0) { + ASSERT_EQ(SECSuccess, rv); + } else { + ASSERT_EQ(SECFailure, rv); + ASSERT_EQ(err, PORT_GetError()); + } } -} -TEST_P(TlsAgentTestClient13, EsniInstall) { + private: + PRTime now_ = 0; +}; + +TEST_P(TlsAgentEsniTest, EsniInstall) { EnsureInit(); DataBuffer record; - GenerateEsniKey(time(0), ssl_grp_ec_curve25519, kDefaultSuites, &record); - ClientInstallEsni(agent_, record); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kDefaultSuites, &record); + InstallEsni(record); } // The next set of tests fail at setup time. -TEST_P(TlsAgentTestClient13, EsniInvalidHash) { +TEST_P(TlsAgentEsniTest, EsniInvalidHash) { EnsureInit(); DataBuffer record; GenerateEsniKey(time(0), ssl_grp_ec_curve25519, kDefaultSuites, &record); record.data()[2]++; - ClientInstallEsni(agent_, record, SSL_ERROR_RX_MALFORMED_ESNI_KEYS); + InstallEsni(record, SSL_ERROR_RX_MALFORMED_ESNI_KEYS); } -TEST_P(TlsAgentTestClient13, EsniInvalidVersion) { +TEST_P(TlsAgentEsniTest, EsniInvalidVersion) { EnsureInit(); DataBuffer record; - GenerateEsniKey(time(0), ssl_grp_ec_curve25519, kDefaultSuites, &record); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kDefaultSuites, &record); record.Write(0, 0xffff, 2); - ClientInstallEsni(agent_, record, SSL_ERROR_UNSUPPORTED_VERSION); + InstallEsni(record, SSL_ERROR_UNSUPPORTED_VERSION); } -TEST_P(TlsAgentTestClient13, EsniShort) { +TEST_P(TlsAgentEsniTest, EsniShort) { EnsureInit(); DataBuffer record; - GenerateEsniKey(time(0), ssl_grp_ec_curve25519, kDefaultSuites, &record); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kDefaultSuites, &record); record.Truncate(record.len() - 1); UpdateEsniKeysChecksum(&record); - ClientInstallEsni(agent_, record, SSL_ERROR_RX_MALFORMED_ESNI_KEYS); + InstallEsni(record, SSL_ERROR_RX_MALFORMED_ESNI_KEYS); } -TEST_P(TlsAgentTestClient13, EsniLong) { +TEST_P(TlsAgentEsniTest, EsniLong) { EnsureInit(); DataBuffer record; - GenerateEsniKey(time(0), ssl_grp_ec_curve25519, kDefaultSuites, &record); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kDefaultSuites, &record); record.Write(record.len(), 1, 1); UpdateEsniKeysChecksum(&record); - ClientInstallEsni(agent_, record, SSL_ERROR_RX_MALFORMED_ESNI_KEYS); + InstallEsni(record, SSL_ERROR_RX_MALFORMED_ESNI_KEYS); } -TEST_P(TlsAgentTestClient13, EsniExtensionMismatch) { +TEST_P(TlsAgentEsniTest, EsniExtensionMismatch) { EnsureInit(); DataBuffer record; - GenerateEsniKey(time(0), ssl_grp_ec_curve25519, kDefaultSuites, &record); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kDefaultSuites, &record); record.Write(record.len() - 1, 1, 1); UpdateEsniKeysChecksum(&record); - ClientInstallEsni(agent_, record, SSL_ERROR_RX_MALFORMED_ESNI_KEYS); + InstallEsni(record, SSL_ERROR_RX_MALFORMED_ESNI_KEYS); } // The following tests fail by ignoring the Esni block. -TEST_P(TlsAgentTestClient13, EsniUnknownGroup) { +TEST_P(TlsAgentEsniTest, EsniUnknownGroup) { EnsureInit(); DataBuffer record; - GenerateEsniKey(time(0), ssl_grp_ec_curve25519, kDefaultSuites, &record); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kDefaultSuites, &record); record.Write(8, 0xffff, 2); // Fake group UpdateEsniKeysChecksum(&record); - ClientInstallEsni(agent_, record, 0); + InstallEsni(record, 0); auto filter = MakeTlsFilter(agent_, ssl_tls13_encrypted_sni_xtn); agent_->Handshake(); @@ -202,11 +211,11 @@ TEST_P(TlsAgentTestClient13, EsniUnknownGroup) { ASSERT_TRUE(!filter->captured()); } -TEST_P(TlsAgentTestClient13, EsniUnknownCS) { +TEST_P(TlsAgentEsniTest, EsniUnknownCS) { EnsureInit(); DataBuffer record; - GenerateEsniKey(time(0), ssl_grp_ec_curve25519, kBogusSuites, &record); - ClientInstallEsni(agent_, record, 0); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kBogusSuites, &record); + InstallEsni(record, 0); auto filter = MakeTlsFilter(agent_, ssl_tls13_encrypted_sni_xtn); agent_->Handshake(); @@ -214,12 +223,12 @@ TEST_P(TlsAgentTestClient13, EsniUnknownCS) { ASSERT_TRUE(!filter->captured()); } -TEST_P(TlsAgentTestClient13, EsniInvalidCS) { +TEST_P(TlsAgentEsniTest, EsniInvalidCS) { EnsureInit(); DataBuffer record; - GenerateEsniKey(time(0), ssl_grp_ec_curve25519, kTls12Suites, &record); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kTls12Suites, &record); UpdateEsniKeysChecksum(&record); - ClientInstallEsni(agent_, record, 0); + InstallEsni(record, 0); auto filter = MakeTlsFilter(agent_, ssl_tls13_encrypted_sni_xtn); agent_->Handshake(); @@ -227,36 +236,34 @@ TEST_P(TlsAgentTestClient13, EsniInvalidCS) { ASSERT_TRUE(!filter->captured()); } -TEST_P(TlsAgentTestClient13, EsniNotReady) { +TEST_P(TlsAgentEsniTest, EsniNotReady) { EnsureInit(); DataBuffer record; - GenerateEsniKey(time(0) + 1000, ssl_grp_ec_curve25519, kDefaultSuites, - &record); - ClientInstallEsni(agent_, record, 0); + GenerateEsniKey(now() + 1000, ssl_grp_ec_curve25519, kDefaultSuites, &record); + InstallEsni(record, 0); auto filter = MakeTlsFilter(agent_, ssl_tls13_encrypted_sni_xtn); agent_->Handshake(); ASSERT_TRUE(!filter->captured()); } -TEST_P(TlsAgentTestClient13, EsniExpired) { +TEST_P(TlsAgentEsniTest, EsniExpired) { EnsureInit(); DataBuffer record; - GenerateEsniKey(time(0) - 1000, ssl_grp_ec_curve25519, kDefaultSuites, - &record); - ClientInstallEsni(agent_, record, 0); + GenerateEsniKey(now() - 1000, ssl_grp_ec_curve25519, kDefaultSuites, &record); + InstallEsni(record, 0); auto filter = MakeTlsFilter(agent_, ssl_tls13_encrypted_sni_xtn); agent_->Handshake(); ASSERT_TRUE(!filter->captured()); } -TEST_P(TlsAgentTestClient13, NoSniSoNoEsni) { +TEST_P(TlsAgentEsniTest, NoSniSoNoEsni) { EnsureInit(); DataBuffer record; - GenerateEsniKey(time(0), ssl_grp_ec_curve25519, kDefaultSuites, &record); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kDefaultSuites, &record); SSL_SetURL(agent_->ssl_fd(), ""); - ClientInstallEsni(agent_, record, 0); + InstallEsni(record, 0); auto filter = MakeTlsFilter(agent_, ssl_tls13_encrypted_sni_xtn); agent_->Handshake(); @@ -275,7 +282,7 @@ static int32_t SniCallback(TlsAgent* agent, const SECItem* srvNameAddr, TEST_P(TlsConnectTls13, ConnectEsni) { EnsureTlsSetup(); - SetupEsni(client_, server_); + SetupEsni(now(), client_, server_); auto cFilterSni = MakeTlsFilter(client_, ssl_server_name_xtn); auto cFilterEsni = @@ -300,16 +307,19 @@ TEST_P(TlsConnectTls13, ConnectEsniHrr) { EnsureTlsSetup(); const std::vector groups = {ssl_grp_ec_secp384r1}; server_->ConfigNamedGroups(groups); - SetupEsni(client_, server_); + SetupEsni(now(), client_, server_); auto hrr_capture = MakeTlsFilter( server_, kTlsHandshakeHelloRetryRequest); auto filter = MakeTlsFilter(client_, ssl_server_name_xtn); - auto cfilter = - MakeTlsFilter(client_, ssl_server_name_xtn); + auto filter2 = + MakeTlsFilter(client_, ssl_server_name_xtn, true); + client_->SetFilter(std::make_shared( + ChainedPacketFilterInit({filter, filter2}))); server_->SetSniCallback(SniCallback); Connect(); - CheckSniExtension(cfilter->extension()); + CheckSniExtension(filter->extension()); + CheckSniExtension(filter2->extension()); EXPECT_NE(0UL, hrr_capture->buffer().len()); } @@ -319,8 +329,8 @@ TEST_P(TlsConnectTls13, ConnectEsniNoDummy) { ScopedSECKEYPrivateKey priv; DataBuffer record; - GenerateEsniKey(time(nullptr), ssl_grp_ec_curve25519, kDefaultSuites, &record, - &pub, &priv); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kDefaultSuites, &record, &pub, + &priv); SECStatus rv = SSL_SetESNIKeyPair(server_->ssl_fd(), priv.get(), record.data(), record.len()); ASSERT_EQ(SECSuccess, rv); @@ -343,8 +353,8 @@ TEST_P(TlsConnectTls13, ConnectEsniNullDummy) { ScopedSECKEYPrivateKey priv; DataBuffer record; - GenerateEsniKey(time(nullptr), ssl_grp_ec_curve25519, kDefaultSuites, &record, - &pub, &priv); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kDefaultSuites, &record, &pub, + &priv); SECStatus rv = SSL_SetESNIKeyPair(server_->ssl_fd(), priv.get(), record.data(), record.len()); ASSERT_EQ(SECSuccess, rv); @@ -369,14 +379,17 @@ TEST_P(TlsConnectTls13, ConnectEsniCSMismatch) { ScopedSECKEYPrivateKey priv; DataBuffer record; - GenerateEsniKey(time(nullptr), ssl_grp_ec_curve25519, kDefaultSuites, &record, - &pub, &priv); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kDefaultSuites, &record, &pub, + &priv); PRUint8 encoded[1024]; - unsigned int encoded_len; + unsigned int encoded_len = 0; SECStatus rv = SSL_EncodeESNIKeys( &kChaChaSuite[0], kChaChaSuite.size(), ssl_grp_ec_curve25519, pub.get(), - 100, time(0), time(0) + 10, encoded, &encoded_len, sizeof(encoded)); + 100, (now() / PR_USEC_PER_SEC) - 1, (now() / PR_USEC_PER_SEC) + 10, + encoded, &encoded_len, sizeof(encoded)); + ASSERT_EQ(SECSuccess, rv); + ASSERT_LT(0U, encoded_len); rv = SSL_SetESNIKeyPair(server_->ssl_fd(), priv.get(), encoded, encoded_len); ASSERT_EQ(SECSuccess, rv); rv = SSL_EnableESNI(client_->ssl_fd(), record.data(), record.len(), ""); @@ -387,7 +400,7 @@ TEST_P(TlsConnectTls13, ConnectEsniCSMismatch) { TEST_P(TlsConnectTls13, ConnectEsniP256) { EnsureTlsSetup(); - SetupEsni(client_, server_, ssl_grp_ec_secp256r1); + SetupEsni(now(), client_, server_, ssl_grp_ec_secp256r1); auto cfilter = MakeTlsFilter(client_, ssl_server_name_xtn); auto sfilter = @@ -400,18 +413,21 @@ TEST_P(TlsConnectTls13, ConnectEsniP256) { TEST_P(TlsConnectTls13, ConnectMismatchedEsniKeys) { EnsureTlsSetup(); - SetupEsni(client_, server_); + SetupEsni(now(), client_, server_); // Now install a new set of keys on the client, so we have a mismatch. DataBuffer record; - GenerateEsniKey(time(0), ssl_grp_ec_curve25519, kDefaultSuites, &record); - ClientInstallEsni(client_, record, 0); + GenerateEsniKey(now(), ssl_grp_ec_curve25519, kDefaultSuites, &record); + + SECStatus rv = + SSL_EnableESNI(client_->ssl_fd(), record.data(), record.len(), kDummySni); + ASSERT_EQ(SECSuccess, rv); ConnectExpectAlert(server_, illegal_parameter); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); } TEST_P(TlsConnectTls13, ConnectDamagedEsniExtensionCH) { EnsureTlsSetup(); - SetupEsni(client_, server_); + SetupEsni(now(), client_, server_); auto filter = MakeTlsFilter( client_, ssl_tls13_encrypted_sni_xtn, 50); // in the ciphertext ConnectExpectAlert(server_, illegal_parameter); @@ -420,7 +436,7 @@ TEST_P(TlsConnectTls13, ConnectDamagedEsniExtensionCH) { TEST_P(TlsConnectTls13, ConnectRemoveEsniExtensionEE) { EnsureTlsSetup(); - SetupEsni(client_, server_); + SetupEsni(now(), client_, server_); auto filter = MakeTlsFilter(server_, ssl_tls13_encrypted_sni_xtn); filter->EnableDecryption(); @@ -430,7 +446,7 @@ TEST_P(TlsConnectTls13, ConnectRemoveEsniExtensionEE) { TEST_P(TlsConnectTls13, ConnectShortEsniExtensionEE) { EnsureTlsSetup(); - SetupEsni(client_, server_); + SetupEsni(now(), client_, server_); DataBuffer shortNonce; auto filter = MakeTlsFilter( server_, ssl_tls13_encrypted_sni_xtn, shortNonce); @@ -441,7 +457,7 @@ TEST_P(TlsConnectTls13, ConnectShortEsniExtensionEE) { TEST_P(TlsConnectTls13, ConnectBogusEsniExtensionEE) { EnsureTlsSetup(); - SetupEsni(client_, server_); + SetupEsni(now(), client_, server_); const uint8_t bogusNonceBuf[16] = {0}; DataBuffer bogusNonce(bogusNonceBuf, sizeof(bogusNonceBuf)); auto filter = MakeTlsFilter( @@ -456,7 +472,7 @@ TEST_P(TlsConnectTls13, ConnectBogusEsniExtensionEE) { // The client then aborts when it sees the server did TLS 1.2. TEST_P(TlsConnectTls13, EsniButTLS12Server) { EnsureTlsSetup(); - SetupEsni(client_, server_); + SetupEsni(now(), client_, server_); client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, @@ -467,4 +483,4 @@ TEST_P(TlsConnectTls13, EsniButTLS12Server) { ASSERT_FALSE(SSLInt_ExtensionNegotiated(server_->ssl_fd(), ssl_tls13_encrypted_sni_xtn)); } -} +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc index 25ad606fc..b2917274b 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.cc +++ b/security/nss/gtests/ssl_gtest/tls_filter.cc @@ -45,40 +45,65 @@ void TlsVersioned::WriteStream(std::ostream& stream) const { } } +TlsRecordFilter::TlsRecordFilter(const std::shared_ptr& a) + : agent_(a) { + cipher_specs_.emplace_back(a->variant() == ssl_variant_datagram, 0); +} + void TlsRecordFilter::EnableDecryption() { - SSLInt_SetCipherSpecChangeFunc(agent()->ssl_fd(), CipherSpecChanged, - (void*)this); + EXPECT_EQ(SECSuccess, + SSL_SecretCallback(agent()->ssl_fd(), SecretCallback, this)); + decrypting_ = true; } -void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending, - ssl3CipherSpec* newSpec) { +void TlsRecordFilter::SecretCallback(PRFileDesc* fd, PRUint16 epoch, + SSLSecretDirection dir, PK11SymKey* secret, + void* arg) { TlsRecordFilter* self = static_cast(arg); - PRBool isServer = self->agent()->role() == TlsAgent::SERVER; - if (g_ssl_gtest_verbose) { - std::cerr << (isServer ? "server" : "client") << ": " - << (sending ? "send" : "receive") - << " cipher spec changed: " << newSpec->epoch << " (" - << newSpec->phase << ")" << std::endl; + std::cerr << self->agent()->role_str() << ": " << dir + << " secret changed for epoch " << epoch << std::endl; } - if (!sending) { + + if (dir == ssl_secret_read) { return; } - uint64_t seq_no; - if (self->agent()->variant() == ssl_variant_datagram) { - seq_no = static_cast(SSLInt_CipherSpecToEpoch(newSpec)) << 48; + for (auto& spec : self->cipher_specs_) { + ASSERT_NE(spec.epoch(), epoch) << "duplicate spec for epoch " << epoch; + } + + SSLPreliminaryChannelInfo preinfo; + EXPECT_EQ(SECSuccess, + SSL_GetPreliminaryChannelInfo(self->agent()->ssl_fd(), &preinfo, + sizeof(preinfo))); + EXPECT_EQ(sizeof(preinfo), preinfo.length); + + // Check the version. + if (preinfo.valuesSet & ssl_preinfo_version) { + EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion); + } else { + EXPECT_EQ(1U, epoch); + } + + uint16_t suite; + if (epoch == 1) { + // 0-RTT + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_0rtt_cipher_suite); + suite = preinfo.zeroRttCipherSuite; } else { - seq_no = 0; + EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite); + suite = preinfo.cipherSuite; } - self->in_sequence_number_ = seq_no; - self->out_sequence_number_ = seq_no; - self->dropped_record_ = false; - self->cipher_spec_.reset(new TlsCipherSpec()); - bool ret = self->cipher_spec_->Init( - SSLInt_CipherSpecToEpoch(newSpec), SSLInt_CipherSpecToAlgorithm(newSpec), - SSLInt_CipherSpecToKey(newSpec), SSLInt_CipherSpecToIv(newSpec)); - EXPECT_EQ(true, ret); + + SSLCipherSuiteInfo cipherinfo; + EXPECT_EQ(SECSuccess, + SSL_GetCipherSuiteInfo(suite, &cipherinfo, sizeof(cipherinfo))); + EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length); + + bool is_dtls = self->agent()->variant() == ssl_variant_datagram; + self->cipher_specs_.emplace_back(is_dtls, epoch); + EXPECT_TRUE(self->cipher_specs_.back().SetKeys(&cipherinfo, secret)); } bool TlsRecordFilter::is_dtls13() const { @@ -95,6 +120,23 @@ bool TlsRecordFilter::is_dtls13() const { info.canSendEarlyData; } +// Gets the cipher spec that matches the specified epoch. +TlsCipherSpec& TlsRecordFilter::spec(uint16_t write_epoch) { + for (auto& sp : cipher_specs_) { + if (sp.epoch() == write_epoch) { + return sp; + } + } + + // If we aren't decrypting, provide a cipher spec that does nothing other than + // count sequence numbers. + EXPECT_FALSE(decrypting_) << "No spec available for epoch " << write_epoch; + ; + bool is_dtls = agent()->variant() == ssl_variant_datagram; + cipher_specs_.emplace_back(is_dtls, write_epoch); + return cipher_specs_.back(); +} + PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) { // Disable during shutdown. @@ -108,34 +150,28 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, output->Allocate(input.len()); TlsParser parser(input); + // This uses the current write spec for the purposes of parsing the epoch and + // sequence number from the header. This might be wrong because we can + // receive records from older specs, but guessing is good enough: + // - In DTLS, parsing the sequence number corrects any errors. + // - In TLS, we don't use the sequence number unless decrypting, where we use + // trial decryption to get the right epoch. + uint16_t write_epoch = 0; + SECStatus rv = SSL_GetCurrentEpoch(agent()->ssl_fd(), nullptr, &write_epoch); + if (rv != SECSuccess) { + ADD_FAILURE() << "unable to read epoch"; + return KEEP; + } + uint64_t guess_seqno = static_cast(write_epoch) << 48; + while (parser.remaining()) { TlsRecordHeader header; DataBuffer record; - - if (!header.Parse(is_dtls13(), in_sequence_number_, &parser, &record)) { + if (!header.Parse(is_dtls13(), guess_seqno, &parser, &record)) { ADD_FAILURE() << "not a valid record"; return KEEP; } - // Track the sequence number, which is necessary for stream mode when - // decrypting and for TLS 1.3 datagram to recover the sequence number. - // - // We reset the counter when the cipher spec changes, but that notification - // appears before a record is sent. If multiple records are sent with - // different cipher specs, this would fail. This filters out cleartext - // records, so we don't get confused by handshake messages that are sent at - // the same time as encrypted records. Sequence numbers are therefore - // likely to be incorrect for cleartext records. - // - // This isn't perfectly robust: if there is a change from an active cipher - // spec to another active cipher spec (KeyUpdate for instance) AND writes - // are consolidated across that change, this code could use the wrong - // sequence numbers when re-encrypting records with the old keys. - if (header.content_type() == ssl_ct_application_data) { - in_sequence_number_ = - (std::max)(in_sequence_number_, header.sequence_number() + 1); - } - if (FilterRecord(header, record, &offset, output) != KEEP) { changed = true; } else { @@ -159,14 +195,16 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( DataBuffer filtered; uint8_t inner_content_type; DataBuffer plaintext; + uint16_t protection_epoch = 0; - if (!Unprotect(header, record, &inner_content_type, &plaintext)) { - if (g_ssl_gtest_verbose) { - std::cerr << "unprotect failed: " << header << ":" << record << std::endl; - } + if (!Unprotect(header, record, &protection_epoch, &inner_content_type, + &plaintext)) { + std::cerr << agent()->role_str() << ": unprotect failed: " << header << ":" + << record << std::endl; return KEEP; } + auto& protection_spec = spec(protection_epoch); TlsRecordHeader real_header(header.variant(), header.version(), inner_content_type, header.sequence_number()); @@ -174,7 +212,9 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( // In stream mode, even if something doesn't change we need to re-encrypt if // previous packets were dropped. if (action == KEEP) { - if (header.is_dtls() || !dropped_record_) { + if (header.is_dtls() || !protection_spec.record_dropped()) { + // Count every outgoing packet. + protection_spec.RecordProtected(); return KEEP; } filtered = plaintext; @@ -182,7 +222,7 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( if (action == DROP) { std::cerr << "record drop: " << header << ":" << record << std::endl; - dropped_record_ = true; + protection_spec.RecordDropped(); return DROP; } @@ -192,19 +232,18 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( std::cerr << "record new: " << filtered << std::endl; } - uint64_t seq_num; - if (header.is_dtls() || !cipher_spec_ || - header.content_type() != ssl_ct_application_data) { - seq_num = header.sequence_number(); - } else { - seq_num = out_sequence_number_++; + uint64_t seq_num = protection_spec.next_out_seqno(); + if (!decrypting_ && header.is_dtls()) { + // Copy over the epoch, which isn't tracked when not decrypting. + seq_num |= header.sequence_number() & (0xffffULL << 48); } + TlsRecordHeader out_header(header.variant(), header.version(), header.content_type(), seq_num); DataBuffer ciphertext; - bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext); - EXPECT_TRUE(rv); + bool rv = Protect(protection_spec, out_header, inner_content_type, filtered, + &ciphertext); if (!rv) { return KEEP; } @@ -227,15 +266,20 @@ uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t expected, uint32_t partial, size_t partial_bits) { EXPECT_GE(32U, partial_bits); - uint64_t mask = (1 << partial_bits) - 1; + uint64_t mask = (1ULL << partial_bits) - 1; // First we determine the highest possible value. This is half the - // expressible range above the expected value. - uint64_t cap = expected + (1ULL << (partial_bits - 1)); + // expressible range above the expected value, less 1. + // + // We subtract the extra 1 from the cap so that when given a choice between + // the equidistant expected+N and expected-N we want to chose the lower. With + // 0-RTT, we sometimes have to recover an epoch of 1 when we expect an epoch + // of 3 and with 2 partial bits, the alternative result of 5 is wrong. + uint64_t cap = expected + (1ULL << (partial_bits - 1)) - 1; // Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234. uint64_t seq_no = (cap & ~mask) | partial; // If the partial value is higher than the same partial piece from the cap, // then the real value has to be lower. e.g., xxxx1234 can't become xxxx5678. - if (partial > (cap & mask)) { + if (partial > (cap & mask) && (seq_no >= (1ULL << partial_bits))) { seq_no -= 1ULL << partial_bits; } return seq_no; @@ -375,16 +419,41 @@ size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset, bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, const DataBuffer& ciphertext, + uint16_t* protection_epoch, uint8_t* inner_content_type, DataBuffer* plaintext) { - if (!cipher_spec_ || header.content_type() != ssl_ct_application_data) { + if (!decrypting_ || header.content_type() != ssl_ct_application_data) { + // Maintain the epoch and sequence number for plaintext records. + uint16_t ep = 0; + if (agent()->variant() == ssl_variant_datagram) { + ep = static_cast(header.sequence_number() >> 48); + } + spec(ep).RecordUnprotected(header.sequence_number()); + *protection_epoch = ep; *inner_content_type = header.content_type(); *plaintext = ciphertext; return true; } - if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) { - return false; + uint16_t ep = 0; + if (agent()->variant() == ssl_variant_datagram) { + ep = static_cast(header.sequence_number() >> 48); + if (!spec(ep).Unprotect(header, ciphertext, plaintext)) { + return false; + } + } else { + // In TLS, records aren't clearly labelled with their epoch, and we + // can't just use the newest keys because the same flight of messages can + // contain multiple epochs. So... trial decrypt! + for (size_t i = cipher_specs_.size() - 1; i > 0; --i) { + if (cipher_specs_[i].Unprotect(header, ciphertext, plaintext)) { + ep = cipher_specs_[i].epoch(); + break; + } + } + if (!ep) { + return false; + } } size_t len = plaintext->len(); @@ -396,33 +465,45 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, return false; } + *protection_epoch = ep; *inner_content_type = plaintext->data()[len - 1]; plaintext->Truncate(len - 1); if (g_ssl_gtest_verbose) { - std::cerr << "unprotect: " << std::hex << header.sequence_number() - << std::dec << " type=" << static_cast(*inner_content_type) + std::cerr << agent()->role_str() << ": unprotect: epoch=" << ep + << " seq=" << std::hex << header.sequence_number() << std::dec << " " << *plaintext << std::endl; } return true; } -bool TlsRecordFilter::Protect(const TlsRecordHeader& header, +bool TlsRecordFilter::Protect(TlsCipherSpec& protection_spec, + const TlsRecordHeader& header, uint8_t inner_content_type, const DataBuffer& plaintext, DataBuffer* ciphertext, size_t padding) { - if (!cipher_spec_ || header.content_type() != ssl_ct_application_data) { + if (!protection_spec.is_protected()) { + // Not protected, just keep the sequence numbers updated. + protection_spec.RecordProtected(); *ciphertext = plaintext; return true; } - if (g_ssl_gtest_verbose) { - std::cerr << "protect: " << header.sequence_number() << std::endl; - } + DataBuffer padded; padded.Allocate(plaintext.len() + 1 + padding); size_t offset = padded.Write(0, plaintext.data(), plaintext.len()); padded.Write(offset, inner_content_type, 1); - return cipher_spec_->Protect(header, padded, ciphertext); + + bool ok = protection_spec.Protect(header, padded, ciphertext); + if (!ok) { + ADD_FAILURE() << "protect fail"; + } else if (g_ssl_gtest_verbose) { + std::cerr << agent()->role_str() + << ": protect: epoch=" << protection_spec.epoch() + << " seq=" << std::hex << header.sequence_number() << std::dec + << " " << *ciphertext << std::endl; + } + return ok; } bool IsHelloRetry(const DataBuffer& body) { diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h index 2b6e88645..64ee71c89 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.h +++ b/security/nss/gtests/ssl_gtest/tls_filter.h @@ -97,13 +97,7 @@ inline std::shared_ptr MakeTlsFilter(const std::shared_ptr& agent, // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: - TlsRecordFilter(const std::shared_ptr& a) - : agent_(a), - count_(0), - cipher_spec_(), - dropped_record_(false), - in_sequence_number_(0), - out_sequence_number_(0) {} + TlsRecordFilter(const std::shared_ptr& a); std::shared_ptr agent() const { return agent_.lock(); } @@ -118,10 +112,11 @@ class TlsRecordFilter : public PacketFilter { // behavior. void EnableDecryption(); bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText, - uint8_t* inner_content_type, DataBuffer* plaintext); - bool Protect(const TlsRecordHeader& header, uint8_t inner_content_type, - const DataBuffer& plaintext, DataBuffer* ciphertext, - size_t padding = 0); + uint16_t* protection_epoch, uint8_t* inner_content_type, + DataBuffer* plaintext); + bool Protect(TlsCipherSpec& protection_spec, const TlsRecordHeader& header, + uint8_t inner_content_type, const DataBuffer& plaintext, + DataBuffer* ciphertext, size_t padding = 0); protected: // There are two filter functions which can be overriden. Both are @@ -146,20 +141,17 @@ class TlsRecordFilter : public PacketFilter { } bool is_dtls13() const; + TlsCipherSpec& spec(uint16_t epoch); private: - static void CipherSpecChanged(void* arg, PRBool sending, - ssl3CipherSpec* newSpec); + static void SecretCallback(PRFileDesc* fd, PRUint16 epoch, + SSLSecretDirection dir, PK11SymKey* secret, + void* arg); std::weak_ptr agent_; - size_t count_; - std::unique_ptr cipher_spec_; - // Whether we dropped a record since the cipher spec changed. - bool dropped_record_; - // The sequence number we use for reading records as they are written. - uint64_t in_sequence_number_; - // The sequence number we use for writing modified records. - uint64_t out_sequence_number_; + size_t count_ = 0; + std::vector cipher_specs_; + bool decrypting_ = false; }; inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) { @@ -449,6 +441,80 @@ class TlsExtensionDropper : public TlsExtensionFilter { uint16_t extension_; }; +class TlsHandshakeDropper : public TlsHandshakeFilter { + public: + TlsHandshakeDropper(const std::shared_ptr& a) + : TlsHandshakeFilter(a) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + return DROP; + } +}; + +class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter { + public: + TlsEncryptedHandshakeMessageReplacer(const std::shared_ptr& a, + uint8_t old_ct, uint8_t new_ct) + : TlsRecordFilter(a), old_ct_(old_ct), new_ct_(new_ct) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& record, size_t* offset, + DataBuffer* output) override { + if (header.content_type() != ssl_ct_application_data) { + return KEEP; + } + + uint16_t protection_epoch = 0; + uint8_t inner_content_type; + DataBuffer plaintext; + if (!Unprotect(header, record, &protection_epoch, &inner_content_type, + &plaintext) || + !plaintext.len()) { + return KEEP; + } + + if (inner_content_type != ssl_ct_handshake) { + return KEEP; + } + + size_t off = 0; + uint32_t msg_len = 0; + uint32_t msg_type = 255; // Not a real message + do { + if (!plaintext.Read(off, 1, &msg_type) || msg_type == old_ct_) { + break; + } + + // Increment and check next messages + if (!plaintext.Read(++off, 3, &msg_len)) { + break; + } + off += 3 + msg_len; + } while (msg_type != old_ct_); + + if (msg_type == old_ct_) { + plaintext.Write(off, new_ct_, 1); + } + + DataBuffer ciphertext; + bool ok = Protect(spec(protection_epoch), header, inner_content_type, + plaintext, &ciphertext, 0); + if (!ok) { + return KEEP; + } + *offset = header.Write(output, *offset, ciphertext); + return CHANGE; + } + + private: + uint8_t old_ct_; + uint8_t new_ct_; +}; + class TlsExtensionInjector : public TlsHandshakeFilter { public: TlsExtensionInjector(const std::shared_ptr& a, uint16_t ext, @@ -557,9 +623,9 @@ class SelectiveDropFilter : public PacketFilter { class SelectiveRecordDropFilter : public TlsRecordFilter { public: SelectiveRecordDropFilter(const std::shared_ptr& a, - uint32_t pattern, bool enabled = true) + uint32_t pattern, bool on = true) : TlsRecordFilter(a), pattern_(pattern), counter_(0) { - if (!enabled) { + if (!on) { Disable(); } } diff --git a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc index 004da3b1c..e1ad9e9f0 100644 --- a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc +++ b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc @@ -7,6 +7,9 @@ #include #include "nss.h" #include "pk11pub.h" +#include "secerr.h" +#include "sslproto.h" +#include "sslexp.h" #include "tls13hkdf.h" #include "databuffer.h" @@ -56,6 +59,39 @@ const size_t kHashLength[] = { 64, /* ssl_hash_sha512 */ }; +size_t GetHashLength(SSLHashType hash) { + size_t i = static_cast(hash); + if (i < PR_ARRAY_SIZE(kHashLength)) { + return kHashLength[i]; + } + ADD_FAILURE() << "Unknown hash: " << hash; + return 0; +} + +CK_MECHANISM_TYPE GetHkdfMech(SSLHashType hash) { + switch (hash) { + case ssl_hash_sha256: + return CKM_NSS_HKDF_SHA256; + case ssl_hash_sha384: + return CKM_NSS_HKDF_SHA384; + default: + ADD_FAILURE() << "Unknown hash: " << hash; + } + return CKM_INVALID_MECHANISM; +} + +PRUint16 GetSomeCipherSuiteForHash(SSLHashType hash) { + switch (hash) { + case ssl_hash_sha256: + return TLS_AES_128_GCM_SHA256; + case ssl_hash_sha384: + return TLS_AES_256_GCM_SHA384; + default: + ADD_FAILURE() << "Unknown hash: " << hash; + } + return 0; +} + const std::string kHashName[] = {"None", "MD5", "SHA-1", "SHA-224", "SHA-256", "SHA-384", "SHA-512"}; @@ -64,7 +100,7 @@ static void ImportKey(ScopedPK11SymKey* to, const DataBuffer& key, ASSERT_LT(hash_type, sizeof(kHashLength)); ASSERT_LE(kHashLength[hash_type], key.len()); SECItem key_item = {siBuffer, const_cast(key.data()), - static_cast(kHashLength[hash_type])}; + static_cast(GetHashLength(hash_type))}; PK11SymKey* inner = PK11_ImportSymKey(slot, CKM_SSL3_MASTER_KEY_DERIVE, PK11_OriginUnwrap, @@ -112,15 +148,19 @@ class TlsHkdfTest : public ::testing::Test, ImportKey(&k2_, kKey2, hash_type_, slot_.get()); } - void VerifyKey(const ScopedPK11SymKey& key, const DataBuffer& expected) { + void VerifyKey(const ScopedPK11SymKey& key, CK_MECHANISM_TYPE expected_mech, + const DataBuffer& expected_value) { + EXPECT_EQ(expected_mech, PK11_GetMechanism(key.get())); + SECStatus rv = PK11_ExtractKeyValue(key.get()); ASSERT_EQ(SECSuccess, rv); SECItem* key_data = PK11_GetKeyData(key.get()); ASSERT_NE(nullptr, key_data); - EXPECT_EQ(expected.len(), key_data->len); - EXPECT_EQ(0, memcmp(expected.data(), key_data->data, expected.len())); + EXPECT_EQ(expected_value.len(), key_data->len); + EXPECT_EQ( + 0, memcmp(expected_value.data(), key_data->data, expected_value.len())); } void HkdfExtract(const ScopedPK11SymKey& ikmk1, const ScopedPK11SymKey& ikmk2, @@ -133,7 +173,15 @@ class TlsHkdfTest : public ::testing::Test, ScopedPK11SymKey prkk(prk); DumpKey("Output", prkk); - VerifyKey(prkk, expected); + VerifyKey(prkk, GetHkdfMech(base_hash), expected); + + // Now test the public wrapper. + PRUint16 cs = GetSomeCipherSuiteForHash(base_hash); + rv = SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3, cs, ikmk1.get(), + ikmk2.get(), &prk); + ASSERT_EQ(SECSuccess, rv); + ASSERT_NE(nullptr, prk); + VerifyKey(ScopedPK11SymKey(prk), GetHkdfMech(base_hash), expected); } void HkdfExpandLabel(ScopedPK11SymKey* prk, SSLHashType base_hash, @@ -150,6 +198,32 @@ class TlsHkdfTest : public ::testing::Test, ASSERT_EQ(SECSuccess, rv); DumpData("Output", &output[0], output.size()); EXPECT_EQ(0, memcmp(expected.data(), &output[0], expected.len())); + + // Verify that the public API produces the same result. + PRUint16 cs = GetSomeCipherSuiteForHash(base_hash); + PK11SymKey* secret; + rv = SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, cs, prk->get(), + session_hash, session_hash_len, label, label_len, + &secret); + EXPECT_EQ(SECSuccess, rv); + ASSERT_NE(nullptr, prk); + VerifyKey(ScopedPK11SymKey(secret), GetHkdfMech(base_hash), expected); + + // Verify that a key can be created with a different key type and size. + rv = SSL_HkdfExpandLabelWithMech( + SSL_LIBRARY_VERSION_TLS_1_3, cs, prk->get(), session_hash, + session_hash_len, label, label_len, CKM_DES3_CBC_PAD, 24, &secret); + EXPECT_EQ(SECSuccess, rv); + ASSERT_NE(nullptr, prk); + ScopedPK11SymKey with_mech(secret); + EXPECT_EQ(static_cast(CKM_DES3_CBC_PAD), + PK11_GetMechanism(with_mech.get())); + // Just verify that the key is the right size. + rv = PK11_ExtractKeyValue(with_mech.get()); + ASSERT_EQ(SECSuccess, rv); + SECItem* key_data = PK11_GetKeyData(with_mech.get()); + ASSERT_NE(nullptr, key_data); + EXPECT_EQ(24U, key_data->len); } protected: @@ -175,7 +249,7 @@ TEST_P(TlsHkdfTest, HkdfNullNull) { 0x10, 0xba, 0x18, 0xe2, 0x35, 0x7e, 0x71, 0x69, 0x71, 0xf9, 0x36, 0x2f, 0x2c, 0x2f, 0xe2, 0xa7, 0x6b, 0xfd, 0x78, 0xdf, 0xec, 0x4e, 0xa9, 0xb5}}; - const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); HkdfExtract(nullptr, nullptr, hash_type_, expected_data); } @@ -193,7 +267,7 @@ TEST_P(TlsHkdfTest, HkdfKey1Only) { 0x57, 0xc2, 0x76, 0x9f, 0x3f, 0x83, 0x45, 0x2f, 0xf6, 0xf3, 0x56, 0x1f, 0x58, 0x63, 0xdb, 0x88, 0xda, 0x40, 0xce, 0x63, 0x7d, 0x24, 0x37, 0xf3}}; - const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); HkdfExtract(k1_, nullptr, hash_type_, expected_data); } @@ -211,7 +285,7 @@ TEST_P(TlsHkdfTest, HkdfKey2Only) { 0xd4, 0x6a, 0xf6, 0xe5, 0xec, 0xea, 0xf8, 0x7d, 0x91, 0x71, 0x81, 0xf1, 0xdb, 0x3b, 0xaf, 0xbf, 0xde, 0x71, 0x61, 0x15, 0xeb, 0xb5, 0x5f, 0x68}}; - const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); HkdfExtract(nullptr, k2_, hash_type_, expected_data); } @@ -229,7 +303,7 @@ TEST_P(TlsHkdfTest, HkdfKey1Key2) { 0x1c, 0x5b, 0x98, 0x0b, 0x02, 0x92, 0x3f, 0xfd, 0x73, 0x5a, 0x6f, 0x2a, 0x95, 0xa3, 0xee, 0xf6, 0xd6, 0x8e, 0x6f, 0x86, 0xea, 0x63, 0xf8, 0x33}}; - const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); HkdfExtract(k1_, k2_, hash_type_, expected_data); } @@ -247,12 +321,122 @@ TEST_P(TlsHkdfTest, HkdfExpandLabel) { 0x74, 0xf7, 0x8b, 0x06, 0x38, 0x28, 0x06, 0x37, 0x75, 0x23, 0xa2, 0xb7, 0x34, 0xb1, 0x72, 0x2e, 0x59, 0x6d, 0x5a, 0x31, 0xf5, 0x53, 0xab, 0x99}}; - const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); - HkdfExpandLabel(&k1_, hash_type_, kSessionHash, kHashLength[hash_type_], + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); + HkdfExpandLabel(&k1_, hash_type_, kSessionHash, GetHashLength(hash_type_), kLabelMasterSecret, strlen(kLabelMasterSecret), expected_data); } +TEST_P(TlsHkdfTest, HkdfExpandLabelNoHash) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + {0xb7, 0x08, 0x00, 0xe3, 0x8e, 0x48, 0x68, 0x91, 0xb1, 0x0f, 0x5e, + 0x6f, 0x22, 0x53, 0x6b, 0x84, 0x69, 0x75, 0xaa, 0xa3, 0x2a, 0xe7, + 0xde, 0xaa, 0xc3, 0xd1, 0xb4, 0x05, 0x22, 0x5c, 0x68, 0xf5}, + {0x13, 0xd3, 0x36, 0x9f, 0x3c, 0x78, 0xa0, 0x32, 0x40, 0xee, 0x16, 0xe9, + 0x11, 0x12, 0x66, 0xc7, 0x51, 0xad, 0xd8, 0x3c, 0xa1, 0xa3, 0x97, 0x74, + 0xd7, 0x45, 0xff, 0xa7, 0x88, 0x9e, 0x52, 0x17, 0x2e, 0xaa, 0x3a, 0xd2, + 0x35, 0xd8, 0xd5, 0x35, 0xfd, 0x65, 0x70, 0x9f, 0xa9, 0xf9, 0xfa, 0x23}}; + + const DataBuffer expected_data(tv[hash_type_], GetHashLength(hash_type_)); + HkdfExpandLabel(&k1_, hash_type_, nullptr, 0, kLabelMasterSecret, + strlen(kLabelMasterSecret), expected_data); +} + +TEST_P(TlsHkdfTest, BadExtractWrapperInput) { + PK11SymKey* key = nullptr; + + // Bad version. + EXPECT_EQ(SECFailure, + SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + k1_.get(), k2_.get(), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Bad ciphersuite. + EXPECT_EQ(SECFailure, + SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_NULL_SHA, + k1_.get(), k2_.get(), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Old ciphersuite. + EXPECT_EQ(SECFailure, SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_RSA_WITH_AES_128_CBC_SHA, k1_.get(), + k2_.get(), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // NULL outparam.. + EXPECT_EQ(SECFailure, SSL_HkdfExtract(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_RSA_WITH_AES_128_CBC_SHA, k1_.get(), + k2_.get(), nullptr)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + EXPECT_EQ(nullptr, key); +} + +TEST_P(TlsHkdfTest, BadExpandLabelWrapperInput) { + PK11SymKey* key = nullptr; + static const char* kLabel = "label"; + + // Bad version. + EXPECT_EQ( + SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + k1_.get(), nullptr, 0, kLabel, strlen(kLabel), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Bad ciphersuite. + EXPECT_EQ( + SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, TLS_RSA_WITH_NULL_MD5, + k1_.get(), nullptr, 0, kLabel, strlen(kLabel), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Old ciphersuite. + EXPECT_EQ(SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_RSA_WITH_AES_128_CBC_SHA, k1_.get(), + nullptr, 0, kLabel, strlen(kLabel), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Null PRK. + EXPECT_EQ(SECFailure, SSL_HkdfExpandLabel( + SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + nullptr, nullptr, 0, kLabel, strlen(kLabel), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Null, non-zero-length handshake hash. + EXPECT_EQ( + SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_2, TLS_AES_128_GCM_SHA256, + k1_.get(), nullptr, 2, kLabel, strlen(kLabel), &key)); + + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + // Null, non-zero-length label. + EXPECT_EQ(SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, k1_.get(), nullptr, 0, + nullptr, strlen(kLabel), &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Null, empty label. + EXPECT_EQ(SECFailure, SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, k1_.get(), + nullptr, 0, nullptr, 0, &key)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + // Null key pointer.. + EXPECT_EQ(SECFailure, + SSL_HkdfExpandLabel(SSL_LIBRARY_VERSION_TLS_1_3, + TLS_AES_128_GCM_SHA256, k1_.get(), nullptr, 0, + kLabel, strlen(kLabel), nullptr)); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + EXPECT_EQ(nullptr, key); +} + static const SSLHashType kHashTypes[] = {ssl_hash_sha256, ssl_hash_sha384}; INSTANTIATE_TEST_CASE_P(AllHashFuncs, TlsHkdfTest, ::testing::ValuesIn(kHashTypes)); diff --git a/security/nss/gtests/ssl_gtest/tls_protect.cc b/security/nss/gtests/ssl_gtest/tls_protect.cc index c715a36a6..de91982f7 100644 --- a/security/nss/gtests/ssl_gtest/tls_protect.cc +++ b/security/nss/gtests/ssl_gtest/tls_protect.cc @@ -5,145 +5,98 @@ * You can obtain one at http://mozilla.org/MPL/2.0/. */ #include "tls_protect.h" +#include "sslproto.h" #include "tls_filter.h" namespace nss_test { -AeadCipher::~AeadCipher() { - if (key_) { - PK11_FreeSymKey(key_); +static uint64_t FirstSeqno(bool dtls, uint16_t epoc) { + if (dtls) { + return static_cast(epoc) << 48; } + return 0; } -bool AeadCipher::Init(PK11SymKey *key, const uint8_t *iv) { - key_ = PK11_ReferenceSymKey(key); - if (!key_) return false; - - memcpy(iv_, iv, sizeof(iv_)); - return true; -} - -void AeadCipher::FormatNonce(uint64_t seq, uint8_t *nonce) { - memcpy(nonce, iv_, 12); - - for (size_t i = 0; i < 8; ++i) { - nonce[12 - (i + 1)] ^= seq & 0xff; - seq >>= 8; +TlsCipherSpec::TlsCipherSpec(bool dtls, uint16_t epoc) + : dtls_(dtls), + epoch_(epoc), + in_seqno_(FirstSeqno(dtls, epoc)), + out_seqno_(FirstSeqno(dtls, epoc)) {} + +bool TlsCipherSpec::SetKeys(SSLCipherSuiteInfo* cipherinfo, + PK11SymKey* secret) { + SSLAeadContext* ctx; + SECStatus rv = SSL_MakeAead(SSL_LIBRARY_VERSION_TLS_1_3, + cipherinfo->cipherSuite, secret, "", + 0, // Use the default labels. + &ctx); + if (rv != SECSuccess) { + return false; } - - DataBuffer d(nonce, 12); -} - -bool AeadCipher::AeadInner(bool decrypt, void *params, size_t param_length, - const uint8_t *in, size_t inlen, uint8_t *out, - size_t *outlen, size_t maxlen) { - SECStatus rv; - unsigned int uoutlen = 0; - SECItem param = { - siBuffer, static_cast(params), - static_cast(param_length), - }; - - if (decrypt) { - rv = PK11_Decrypt(key_, mech_, ¶m, out, &uoutlen, maxlen, in, inlen); - } else { - rv = PK11_Encrypt(key_, mech_, ¶m, out, &uoutlen, maxlen, in, inlen); - } - *outlen = (int)uoutlen; - - return rv == SECSuccess; -} - -bool AeadCipherAesGcm::Aead(bool decrypt, const uint8_t *hdr, size_t hdr_len, - uint64_t seq, const uint8_t *in, size_t inlen, - uint8_t *out, size_t *outlen, size_t maxlen) { - CK_GCM_PARAMS aeadParams; - unsigned char nonce[12]; - - memset(&aeadParams, 0, sizeof(aeadParams)); - aeadParams.pIv = nonce; - aeadParams.ulIvLen = sizeof(nonce); - aeadParams.pAAD = const_cast(hdr); - aeadParams.ulAADLen = hdr_len; - aeadParams.ulTagBits = 128; - - FormatNonce(seq, nonce); - return AeadInner(decrypt, (unsigned char *)&aeadParams, sizeof(aeadParams), - in, inlen, out, outlen, maxlen); -} - -bool AeadCipherChacha20Poly1305::Aead(bool decrypt, const uint8_t *hdr, - size_t hdr_len, uint64_t seq, - const uint8_t *in, size_t inlen, - uint8_t *out, size_t *outlen, - size_t maxlen) { - CK_NSS_AEAD_PARAMS aeadParams; - unsigned char nonce[12]; - - memset(&aeadParams, 0, sizeof(aeadParams)); - aeadParams.pNonce = nonce; - aeadParams.ulNonceLen = sizeof(nonce); - aeadParams.pAAD = const_cast(hdr); - aeadParams.ulAADLen = hdr_len; - aeadParams.ulTagLen = 16; - - FormatNonce(seq, nonce); - return AeadInner(decrypt, (unsigned char *)&aeadParams, sizeof(aeadParams), - in, inlen, out, outlen, maxlen); + aead_.reset(ctx); + return true; } -bool TlsCipherSpec::Init(uint16_t epoc, SSLCipherAlgorithm cipher, - PK11SymKey *key, const uint8_t *iv) { - epoch_ = epoc; - switch (cipher) { - case ssl_calg_aes_gcm: - aead_.reset(new AeadCipherAesGcm()); - break; - case ssl_calg_chacha20: - aead_.reset(new AeadCipherChacha20Poly1305()); - break; - default: - return false; +bool TlsCipherSpec::Unprotect(const TlsRecordHeader& header, + const DataBuffer& ciphertext, + DataBuffer* plaintext) { + if (aead_ == nullptr) { + return false; } - - return aead_->Init(key, iv); -} - -bool TlsCipherSpec::Unprotect(const TlsRecordHeader &header, - const DataBuffer &ciphertext, - DataBuffer *plaintext) { // Make space. plaintext->Allocate(ciphertext.len()); auto header_bytes = header.header(); - size_t len; - bool ret = - aead_->Aead(true, header_bytes.data(), header_bytes.len(), - header.sequence_number(), ciphertext.data(), ciphertext.len(), - plaintext->data(), &len, plaintext->len()); - if (!ret) return false; + unsigned int len; + uint64_t seqno; + if (dtls_) { + seqno = header.sequence_number(); + } else { + seqno = in_seqno_; + } + SECStatus rv = + SSL_AeadDecrypt(aead_.get(), seqno, header_bytes.data(), + header_bytes.len(), ciphertext.data(), ciphertext.len(), + plaintext->data(), &len, plaintext->len()); + if (rv != SECSuccess) { + return false; + } - plaintext->Truncate(len); + RecordUnprotected(seqno); + plaintext->Truncate(static_cast(len)); return true; } -bool TlsCipherSpec::Protect(const TlsRecordHeader &header, - const DataBuffer &plaintext, - DataBuffer *ciphertext) { +bool TlsCipherSpec::Protect(const TlsRecordHeader& header, + const DataBuffer& plaintext, + DataBuffer* ciphertext) { + if (aead_ == nullptr) { + return false; + } // Make a padded buffer. - ciphertext->Allocate(plaintext.len() + 32); // Room for any plausible auth tag - size_t len; + unsigned int len; DataBuffer header_bytes; (void)header.WriteHeader(&header_bytes, 0, plaintext.len() + 16); - bool ret = - aead_->Aead(false, header_bytes.data(), header_bytes.len(), - header.sequence_number(), plaintext.data(), plaintext.len(), - ciphertext->data(), &len, ciphertext->len()); - if (!ret) return false; + uint64_t seqno; + if (dtls_) { + seqno = header.sequence_number(); + } else { + seqno = out_seqno_; + } + + SECStatus rv = + SSL_AeadEncrypt(aead_.get(), seqno, header_bytes.data(), + header_bytes.len(), plaintext.data(), plaintext.len(), + ciphertext->data(), &len, ciphertext->len()); + if (rv != SECSuccess) { + return false; + } + + RecordProtected(); ciphertext->Truncate(len); return true; diff --git a/security/nss/gtests/ssl_gtest/tls_protect.h b/security/nss/gtests/ssl_gtest/tls_protect.h index 6f129a4eb..b1febf887 100644 --- a/security/nss/gtests/ssl_gtest/tls_protect.h +++ b/security/nss/gtests/ssl_gtest/tls_protect.h @@ -10,71 +10,48 @@ #include #include -#include "databuffer.h" #include "pk11pub.h" #include "sslt.h" +#include "sslexp.h" + +#include "databuffer.h" +#include "scoped_ptrs_ssl.h" namespace nss_test { class TlsRecordHeader; -class AeadCipher { - public: - AeadCipher(CK_MECHANISM_TYPE mech) : mech_(mech), key_(nullptr) {} - virtual ~AeadCipher(); - - bool Init(PK11SymKey *key, const uint8_t *iv); - virtual bool Aead(bool decrypt, const uint8_t *hdr, size_t hdr_len, - uint64_t seq, const uint8_t *in, size_t inlen, uint8_t *out, - size_t *outlen, size_t maxlen) = 0; - - protected: - void FormatNonce(uint64_t seq, uint8_t *nonce); - bool AeadInner(bool decrypt, void *params, size_t param_length, - const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen, - size_t maxlen); - - CK_MECHANISM_TYPE mech_; - PK11SymKey *key_; - uint8_t iv_[12]; -}; - -class AeadCipherChacha20Poly1305 : public AeadCipher { - public: - AeadCipherChacha20Poly1305() : AeadCipher(CKM_NSS_CHACHA20_POLY1305) {} - - protected: - bool Aead(bool decrypt, const uint8_t *hdr, size_t hdr_len, uint64_t seq, - const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen, - size_t maxlen); -}; - -class AeadCipherAesGcm : public AeadCipher { - public: - AeadCipherAesGcm() : AeadCipher(CKM_AES_GCM) {} - - protected: - bool Aead(bool decrypt, const uint8_t *hdr, size_t hdr_len, uint64_t seq, - const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen, - size_t maxlen); -}; - // Our analog of ssl3CipherSpec class TlsCipherSpec { public: - TlsCipherSpec() : epoch_(0), aead_() {} + TlsCipherSpec(bool dtls, uint16_t epoc); + bool SetKeys(SSLCipherSuiteInfo* cipherinfo, PK11SymKey* secret); - bool Init(uint16_t epoch, SSLCipherAlgorithm cipher, PK11SymKey *key, - const uint8_t *iv); + bool Protect(const TlsRecordHeader& header, const DataBuffer& plaintext, + DataBuffer* ciphertext); + bool Unprotect(const TlsRecordHeader& header, const DataBuffer& ciphertext, + DataBuffer* plaintext); - bool Protect(const TlsRecordHeader &header, const DataBuffer &plaintext, - DataBuffer *ciphertext); - bool Unprotect(const TlsRecordHeader &header, const DataBuffer &ciphertext, - DataBuffer *plaintext); uint16_t epoch() const { return epoch_; } + uint64_t next_in_seqno() const { return in_seqno_; } + void RecordUnprotected(uint64_t seqno) { + // Reordering happens, so don't let this go backwards. + in_seqno_ = (std::max)(in_seqno_, seqno + 1); + } + uint64_t next_out_seqno() { return out_seqno_; } + void RecordProtected() { out_seqno_++; } + + void RecordDropped() { record_dropped_ = true; } + bool record_dropped() const { return record_dropped_; } + + bool is_protected() const { return aead_ != nullptr; } private: + bool dtls_; uint16_t epoch_; - std::unique_ptr aead_; + uint64_t in_seqno_; + uint64_t out_seqno_; + bool record_dropped_ = false; + ScopedSSLAeadContext aead_; }; } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_subcerts_unittest.cc b/security/nss/gtests/ssl_gtest/tls_subcerts_unittest.cc new file mode 100644 index 000000000..0882ef7ef --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_subcerts_unittest.cc @@ -0,0 +1,568 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include + +#include "prtime.h" +#include "secerr.h" +#include "ssl.h" + +#include "gtest_utils.h" +#include "tls_agent.h" +#include "tls_connect.h" + +namespace nss_test { + +const std::string kEcdsaDelegatorId = TlsAgent::kDelegatorEcdsa256; +const std::string kRsaeDelegatorId = TlsAgent::kDelegatorRsae2048; +const std::string kDCId = TlsAgent::kServerEcdsa256; +const SSLSignatureScheme kDCScheme = ssl_sig_ecdsa_secp256r1_sha256; +const PRUint32 kDCValidFor = 60 * 60 * 24 * 7 /* 1 week (seconds */; + +static void CheckPreliminaryPeerDelegCred( + const std::shared_ptr& client, bool expected, + PRUint32 key_bits = 0, SSLSignatureScheme sig_scheme = ssl_sig_none) { + EXPECT_NE(0U, (client->pre_info().valuesSet & ssl_preinfo_peer_auth)); + EXPECT_EQ(expected, client->pre_info().peerDelegCred); + if (expected) { + EXPECT_EQ(key_bits, client->pre_info().authKeyBits); + EXPECT_EQ(sig_scheme, client->pre_info().signatureScheme); + } +} + +static void CheckPeerDelegCred(const std::shared_ptr& client, + bool expected, PRUint32 key_bits = 0) { + EXPECT_EQ(expected, client->info().peerDelegCred); + EXPECT_EQ(expected, client->pre_info().peerDelegCred); + if (expected) { + EXPECT_EQ(key_bits, client->info().authKeyBits); + EXPECT_EQ(key_bits, client->pre_info().authKeyBits); + EXPECT_EQ(client->info().signatureScheme, + client->pre_info().signatureScheme); + } +} + +// AuthCertificate callbacks to simulate DC validation +static SECStatus CheckPreliminaryDC(TlsAgent* agent, bool checksig, + bool isServer) { + agent->UpdatePreliminaryChannelInfo(); + EXPECT_EQ(PR_TRUE, agent->pre_info().peerDelegCred); + EXPECT_EQ(256U, agent->pre_info().authKeyBits); + EXPECT_EQ(ssl_sig_ecdsa_secp256r1_sha256, agent->pre_info().signatureScheme); + return SECSuccess; +} + +static SECStatus CheckPreliminaryNoDC(TlsAgent* agent, bool checksig, + bool isServer) { + agent->UpdatePreliminaryChannelInfo(); + EXPECT_EQ(PR_FALSE, agent->pre_info().peerDelegCred); + return SECSuccess; +} + +// AuthCertificate callbacks for modifying DC attributes. +// This allows testing tls13_CertificateVerify for rejection +// of DC attributes that have changed since AuthCertificateHook +// may have handled them. +static SECStatus ModifyDCAuthKeyBits(TlsAgent* agent, bool checksig, + bool isServer) { + return SSLInt_TweakChannelInfoForDC(agent->ssl_fd(), + PR_TRUE, // Change authKeyBits + PR_FALSE); // Change scheme +} + +static SECStatus ModifyDCScheme(TlsAgent* agent, bool checksig, bool isServer) { + return SSLInt_TweakChannelInfoForDC(agent->ssl_fd(), + PR_FALSE, // Change authKeyBits + PR_TRUE); // Change scheme +} + +// Attempt to configure a DC when either the DC or DC private key is missing. +TEST_P(TlsConnectTls13, DCNotConfigured) { + // Load and delegate the credential. + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(kDCId, &pub, &priv)); + + StackSECItem dc; + TlsAgent::DelegateCredential(kEcdsaDelegatorId, pub, kDCScheme, kDCValidFor, + now(), &dc); + + // Attempt to install the certificate and DC with a missing DC private key. + EnsureTlsSetup(); + SSLExtraServerCertData extra_data_missing_dc_priv_key = { + ssl_auth_null, nullptr, nullptr, nullptr, &dc, nullptr}; + EXPECT_FALSE(server_->ConfigServerCert(kEcdsaDelegatorId, true, + &extra_data_missing_dc_priv_key)); + + // Attempt to install the certificate and with only the DC private key. + EnsureTlsSetup(); + SSLExtraServerCertData extra_data_missing_dc = { + ssl_auth_null, nullptr, nullptr, nullptr, nullptr, priv.get()}; + EXPECT_FALSE(server_->ConfigServerCert(kEcdsaDelegatorId, true, + &extra_data_missing_dc)); +} + +// Connected with ECDSA-P256. +TEST_P(TlsConnectTls13, DCConnectEcdsaP256) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa256, + ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor, + now()); + + auto cfilter = MakeTlsFilter( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, true, 256); + EXPECT_EQ(ssl_sig_ecdsa_secp256r1_sha256, client_->info().signatureScheme); +} + +// Connected with ECDSA-P521. +TEST_P(TlsConnectTls13, DCConnectEcdsaP521) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521, + ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor, + now()); + client_->EnableDelegatedCredentials(); + + auto cfilter = MakeTlsFilter( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, true, 521); + EXPECT_EQ(ssl_sig_ecdsa_secp521r1_sha512, client_->info().signatureScheme); +} + +// Connected with RSA-PSS, using an RSAE DC SPKI. +TEST_P(TlsConnectTls13, DCConnectRsaPssRsae) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential( + TlsAgent::kServerRsaPss, ssl_sig_rsa_pss_rsae_sha256, kDCValidFor, now()); + + auto cfilter = MakeTlsFilter( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, true, 1024); + EXPECT_EQ(ssl_sig_rsa_pss_rsae_sha256, client_->info().signatureScheme); +} + +// Connected with RSA-PSS, using a RSAE Delegator SPKI. +TEST_P(TlsConnectTls13, DCConnectRsaeDelegator) { + Reset(kRsaeDelegatorId); + + static const SSLSignatureScheme kSchemes[] = {ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_pss_sha256}; + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential( + TlsAgent::kServerRsaPss, ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now()); + + auto cfilter = MakeTlsFilter( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, true, 1024); + EXPECT_EQ(ssl_sig_rsa_pss_pss_sha256, client_->info().signatureScheme); +} + +// Connected with RSA-PSS, using a PSS SPKI. +TEST_P(TlsConnectTls13, DCConnectRsaPssPss) { + Reset(kEcdsaDelegatorId); + + // Need to enable PSS-PSS, which is not on by default. + static const SSLSignatureScheme kSchemes[] = {ssl_sig_ecdsa_secp256r1_sha256, + ssl_sig_rsa_pss_pss_sha256}; + client_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + server_->SetSignatureSchemes(kSchemes, PR_ARRAY_SIZE(kSchemes)); + + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential( + TlsAgent::kServerRsaPss, ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now()); + + auto cfilter = MakeTlsFilter( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, true, 1024); + EXPECT_EQ(ssl_sig_rsa_pss_pss_sha256, client_->info().signatureScheme); +} + +// Generate a weak key. We can't do this in the fixture because certutil +// won't sign with such a tiny key. That's OK, because this is fast(ish). +static void GenerateWeakRsaKey(ScopedSECKEYPrivateKey& priv, + ScopedSECKEYPublicKey& pub) { + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_TRUE(slot); + PK11RSAGenParams rsaparams; + // The absolute minimum size of RSA key that we can use with SHA-256 is + // 256bit (hash) + 256bit (salt) + 8 (start byte) + 8 (end byte) = 528. + rsaparams.keySizeInBits = 528; + rsaparams.pe = 65537; + + // Bug 1012786: PK11_GenerateKeyPair can fail if there is insufficient + // entropy to generate a random key. We can fake some. + for (int retry = 0; retry < 10; ++retry) { + SECKEYPublicKey* p_pub = nullptr; + priv.reset(PK11_GenerateKeyPair(slot.get(), CKM_RSA_PKCS_KEY_PAIR_GEN, + &rsaparams, &p_pub, false, false, nullptr)); + pub.reset(p_pub); + if (priv) { + return; + } + + ASSERT_FALSE(pub); + if (PORT_GetError() != SEC_ERROR_PKCS11_FUNCTION_FAILED) { + break; + } + + // https://xkcd.com/221/ + static const uint8_t FRESH_ENTROPY[16] = {4}; + ASSERT_EQ( + SECSuccess, + PK11_RandomUpdate( + const_cast(reinterpret_cast(FRESH_ENTROPY)), + sizeof(FRESH_ENTROPY))); + break; + } + ADD_FAILURE() << "Unable to generate an RSA key: " + << PORT_ErrorToName(PORT_GetError()); +} + +// Fail to connect with a weak RSA key. +TEST_P(TlsConnectTls13, DCWeakKey) { + Reset(kEcdsaDelegatorId); + EnsureTlsSetup(); + + ScopedSECKEYPrivateKey dc_priv; + ScopedSECKEYPublicKey dc_pub; + GenerateWeakRsaKey(dc_priv, dc_pub); + ASSERT_TRUE(dc_priv); + + // Construct a DC. + StackSECItem dc; + TlsAgent::DelegateCredential(kEcdsaDelegatorId, dc_pub, + ssl_sig_rsa_pss_rsae_sha256, kDCValidFor, now(), + &dc); + + // Configure the DC on the server. + SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr, + nullptr, &dc, dc_priv.get()}; + EXPECT_TRUE(server_->ConfigServerCert(kEcdsaDelegatorId, true, &extra_data)); + + client_->EnableDelegatedCredentials(); + + auto cfilter = MakeTlsFilter( + client_, ssl_delegated_credentials_xtn); + ConnectExpectAlert(client_, kTlsAlertInsufficientSecurity); +} + +class ReplaceDCSigScheme : public TlsHandshakeFilter { + public: + ReplaceDCSigScheme(const std::shared_ptr& a) + : TlsHandshakeFilter(a, {ssl_hs_certificate_verify}) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + *output = input; + output->Write(0, ssl_sig_ecdsa_secp384r1_sha384, 2); + return CHANGE; + } +}; + +// Aborted because of incorrect DC signature algorithm indication. +TEST_P(TlsConnectTls13, DCAbortBadExpectedCertVerifyAlg) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa256, + ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor, + now()); + auto filter = MakeTlsFilter(server_); + filter->EnableDecryption(); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_DC_CERT_VERIFY_ALG_MISMATCH); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Aborted because of invalid DC signature. +TEST_P(TlsConnectTls13, DCAbortBadSignature) { + Reset(kEcdsaDelegatorId); + EnsureTlsSetup(); + client_->EnableDelegatedCredentials(); + + ScopedSECKEYPublicKey pub; + ScopedSECKEYPrivateKey priv; + EXPECT_TRUE(TlsAgent::LoadKeyPairFromCert(kDCId, &pub, &priv)); + + StackSECItem dc; + TlsAgent::DelegateCredential(kEcdsaDelegatorId, pub, kDCScheme, kDCValidFor, + now(), &dc); + ASSERT_TRUE(dc.data != nullptr); + + // Flip the first bit of the DC so that the signature is invalid. + dc.data[0] ^= 0x01; + + SSLExtraServerCertData extra_data = {ssl_auth_null, nullptr, nullptr, + nullptr, &dc, priv.get()}; + EXPECT_TRUE(server_->ConfigServerCert(kEcdsaDelegatorId, true, &extra_data)); + + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_DC_BAD_SIGNATURE); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Aborted because of expired DC. +TEST_P(TlsConnectTls13, DCAbortExpired) { + Reset(kEcdsaDelegatorId); + server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now()); + client_->EnableDelegatedCredentials(); + // When the client checks the time, it will be at least one second after the + // DC expired. + AdvanceTime((static_cast(kDCValidFor) + 1) * PR_USEC_PER_SEC); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + client_->CheckErrorCode(SSL_ERROR_DC_EXPIRED); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Aborted because of invalid key usage. +TEST_P(TlsConnectTls13, DCAbortBadKeyUsage) { + // The sever does not have the delegationUsage extension. + Reset(TlsAgent::kServerEcdsa256); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now()); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); +} + +// Connected without DC because of no client indication. +TEST_P(TlsConnectTls13, DCConnectNoClientSupport) { + Reset(kEcdsaDelegatorId); + server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now()); + + auto cfilter = MakeTlsFilter( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_FALSE(cfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Connected without DC because of no server DC. +TEST_P(TlsConnectTls13, DCConnectNoServerSupport) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + + auto cfilter = MakeTlsFilter( + client_, ssl_delegated_credentials_xtn); + Connect(); + + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Connected without DC because client doesn't support TLS 1.3. +TEST_P(TlsConnectTls13, DCConnectClientNoTls13) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now()); + + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + + auto cfilter = MakeTlsFilter( + client_, ssl_delegated_credentials_xtn); + Connect(); + + // Should fallback to TLS 1.2 and not negotiate a DC. + EXPECT_FALSE(cfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Connected without DC because server doesn't support TLS 1.3. +TEST_P(TlsConnectTls13, DCConnectServerNoTls13) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(kDCId, kDCScheme, kDCValidFor, now()); + + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + + auto cfilter = MakeTlsFilter( + client_, ssl_delegated_credentials_xtn); + Connect(); + + // Should fallback to TLS 1.2 and not negotiate a DC. The client will still + // send the indication because it supports 1.3. + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Connected without DC because client doesn't support the signature scheme. +TEST_P(TlsConnectTls13, DCConnectExpectedCertVerifyAlgNotSupported) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + static const SSLSignatureScheme kClientSchemes[] = { + ssl_sig_ecdsa_secp256r1_sha256, + }; + client_->SetSignatureSchemes(kClientSchemes, PR_ARRAY_SIZE(kClientSchemes)); + + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521, + ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor, + now()); + + auto cfilter = MakeTlsFilter( + client_, ssl_delegated_credentials_xtn); + Connect(); + + // Client sends indication, but the server doesn't send a DC. + EXPECT_TRUE(cfilter->captured()); + CheckPeerDelegCred(client_, false); +} + +// Check that preliminary channel info properly reflects the DC. +TEST_P(TlsConnectTls13, DCCheckPreliminaryInfo) { + Reset(kEcdsaDelegatorId); + EnsureTlsSetup(); + client_->EnableDelegatedCredentials(); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa256, + ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor, + now()); + + auto filter = MakeTlsFilter(server_); + filter->SetHandshakeTypes( + {kTlsHandshakeCertificateVerify, kTlsHandshakeFinished}); + filter->EnableDecryption(); + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + + client_->SetAuthCertificateCallback(CheckPreliminaryDC); + client_->Handshake(); // Process response + + client_->UpdatePreliminaryChannelInfo(); + CheckPreliminaryPeerDelegCred(client_, true, 256, + ssl_sig_ecdsa_secp256r1_sha256); +} + +// Check that preliminary channel info properly reflects a lack of DC. +TEST_P(TlsConnectTls13, DCCheckPreliminaryInfoNoDC) { + Reset(kEcdsaDelegatorId); + EnsureTlsSetup(); + client_->EnableDelegatedCredentials(); + auto filter = MakeTlsFilter(server_); + filter->SetHandshakeTypes( + {kTlsHandshakeCertificateVerify, kTlsHandshakeFinished}); + filter->EnableDecryption(); + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + + client_->SetAuthCertificateCallback(CheckPreliminaryNoDC); + client_->Handshake(); // Process response + + client_->UpdatePreliminaryChannelInfo(); + CheckPreliminaryPeerDelegCred(client_, false); +} + +// Tweak the scheme in between |Cert| and |CertVerify|. +TEST_P(TlsConnectTls13, DCRejectModifiedDCScheme) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + client_->SetAuthCertificateCallback(ModifyDCScheme); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521, + ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor, + now()); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + client_->CheckErrorCode(SSL_ERROR_DC_CERT_VERIFY_ALG_MISMATCH); +} + +// Tweak the authKeyBits in between |Cert| and |CertVerify|. +TEST_P(TlsConnectTls13, DCRejectModifiedDCAuthKeyBits) { + Reset(kEcdsaDelegatorId); + client_->EnableDelegatedCredentials(); + client_->SetAuthCertificateCallback(ModifyDCAuthKeyBits); + server_->AddDelegatedCredential(TlsAgent::kServerEcdsa521, + ssl_sig_ecdsa_secp521r1_sha512, kDCValidFor, + now()); + ConnectExpectAlert(client_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + client_->CheckErrorCode(SSL_ERROR_DC_CERT_VERIFY_ALG_MISMATCH); +} + +class DCDelegation : public ::testing::Test {}; + +TEST_F(DCDelegation, DCDelegations) { + PRTime now = PR_Now(); + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + ASSERT_TRUE(TlsAgent::LoadCertificate(kEcdsaDelegatorId, &cert, &priv)); + + ScopedSECKEYPublicKey pub_rsa; + ScopedSECKEYPrivateKey priv_rsa; + ASSERT_TRUE( + TlsAgent::LoadKeyPairFromCert(TlsAgent::kServerRsa, &pub_rsa, &priv_rsa)); + + StackSECItem dc; + EXPECT_EQ(SECFailure, + SSL_DelegateCredential(cert.get(), priv.get(), pub_rsa.get(), + ssl_sig_ecdsa_secp256r1_sha256, kDCValidFor, + now, &dc)); + EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError()); + + // Using different PSS hashes should be OK. + EXPECT_EQ(SECSuccess, + SSL_DelegateCredential(cert.get(), priv.get(), pub_rsa.get(), + ssl_sig_rsa_pss_rsae_sha256, kDCValidFor, + now, &dc)); + // Make sure to reset |dc| after each success. + dc.Reset(); + EXPECT_EQ(SECSuccess, SSL_DelegateCredential( + cert.get(), priv.get(), pub_rsa.get(), + ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now, &dc)); + dc.Reset(); + EXPECT_EQ(SECSuccess, SSL_DelegateCredential( + cert.get(), priv.get(), pub_rsa.get(), + ssl_sig_rsa_pss_pss_sha384, kDCValidFor, now, &dc)); + dc.Reset(); + + ScopedSECKEYPublicKey pub_ecdsa; + ScopedSECKEYPrivateKey priv_ecdsa; + ASSERT_TRUE(TlsAgent::LoadKeyPairFromCert(TlsAgent::kServerEcdsa256, + &pub_ecdsa, &priv_ecdsa)); + + EXPECT_EQ(SECFailure, + SSL_DelegateCredential(cert.get(), priv.get(), pub_ecdsa.get(), + ssl_sig_rsa_pss_rsae_sha256, kDCValidFor, + now, &dc)); + EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError()); + EXPECT_EQ(SECFailure, SSL_DelegateCredential( + cert.get(), priv.get(), pub_ecdsa.get(), + ssl_sig_rsa_pss_pss_sha256, kDCValidFor, now, &dc)); + EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError()); + EXPECT_EQ(SECFailure, + SSL_DelegateCredential(cert.get(), priv.get(), pub_ecdsa.get(), + ssl_sig_ecdsa_secp384r1_sha384, kDCValidFor, + now, &dc)); + EXPECT_EQ(SSL_ERROR_INCORRECT_SIGNATURE_ALGORITHM, PORT_GetError()); +} + +} // namespace nss_test -- cgit v1.2.3