From f83f62e1bff0c2aedc32e67fe369ba923c5b104a Mon Sep 17 00:00:00 2001 From: JustOff Date: Sat, 9 Jun 2018 15:11:22 +0300 Subject: Update NSS to 3.36.4-RTM --- .../nss/gtests/ssl_gtest/bloomfilter_unittest.cc | 2 +- security/nss/gtests/ssl_gtest/libssl_internals.c | 4 - security/nss/gtests/ssl_gtest/libssl_internals.h | 1 - security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc | 80 +++--- .../nss/gtests/ssl_gtest/ssl_agent_unittest.cc | 5 +- security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc | 176 ++++++++----- .../nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc | 10 +- .../gtests/ssl_gtest/ssl_ciphersuite_unittest.cc | 12 +- .../nss/gtests/ssl_gtest/ssl_custext_unittest.cc | 25 +- .../nss/gtests/ssl_gtest/ssl_damage_unittest.cc | 24 +- security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc | 82 +++--- security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc | 78 +++--- security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc | 64 +++-- .../nss/gtests/ssl_gtest/ssl_extension_unittest.cc | 222 +++++++++-------- .../nss/gtests/ssl_gtest/ssl_fragment_unittest.cc | 4 +- security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc | 45 ++-- security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc | 91 ++++--- .../nss/gtests/ssl_gtest/ssl_keylog_unittest.cc | 4 +- .../nss/gtests/ssl_gtest/ssl_loopback_unittest.cc | 77 ++++-- .../nss/gtests/ssl_gtest/ssl_record_unittest.cc | 14 +- .../gtests/ssl_gtest/ssl_resumption_unittest.cc | 276 +++++++++++++++------ security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc | 94 ++++--- .../nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc | 24 +- .../gtests/ssl_gtest/ssl_tls13compat_unittest.cc | 68 +++-- .../ssl_gtest/ssl_v2_client_hello_unittest.cc | 8 +- .../nss/gtests/ssl_gtest/ssl_version_unittest.cc | 39 ++- .../gtests/ssl_gtest/ssl_versionpolicy_unittest.cc | 6 +- security/nss/gtests/ssl_gtest/test_io.cc | 4 - security/nss/gtests/ssl_gtest/test_io.h | 6 +- security/nss/gtests/ssl_gtest/tls_agent.cc | 53 +++- security/nss/gtests/ssl_gtest/tls_agent.h | 38 ++- security/nss/gtests/ssl_gtest/tls_connect.cc | 36 ++- security/nss/gtests/ssl_gtest/tls_connect.h | 54 +++- security/nss/gtests/ssl_gtest/tls_filter.cc | 8 +- security/nss/gtests/ssl_gtest/tls_filter.h | 141 +++++++---- 35 files changed, 1142 insertions(+), 733 deletions(-) (limited to 'security/nss/gtests/ssl_gtest') diff --git a/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc index 110cfa13a..6efe06ec7 100644 --- a/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc +++ b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc @@ -105,4 +105,4 @@ static const BloomFilterConfig kBloomFilterConfigurations[] = { INSTANTIATE_TEST_CASE_P(BloomFilterConfigurations, BloomFilterTest, ::testing::ValuesIn(kBloomFilterConfigurations)); -} // namespace nspr_test +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.c b/security/nss/gtests/ssl_gtest/libssl_internals.c index 887d85278..17b4ffe49 100644 --- a/security/nss/gtests/ssl_gtest/libssl_internals.c +++ b/security/nss/gtests/ssl_gtest/libssl_internals.c @@ -332,10 +332,6 @@ void SSLInt_SetTicketLifetime(uint32_t lifetime) { ssl_ticket_lifetime = lifetime; } -void SSLInt_SetMaxEarlyDataSize(uint32_t size) { - ssl_max_early_data_size = size; -} - SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) { sslSocket *ss; diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.h b/security/nss/gtests/ssl_gtest/libssl_internals.h index 95d4afdaf..3efb362c2 100644 --- a/security/nss/gtests/ssl_gtest/libssl_internals.h +++ b/security/nss/gtests/ssl_gtest/libssl_internals.h @@ -50,7 +50,6 @@ PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec); SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec); const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec); void SSLInt_SetTicketLifetime(uint32_t lifetime); -void SSLInt_SetMaxEarlyDataSize(uint32_t size); SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size); void SSLInt_RolloverAntiReplay(void); diff --git a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc index a60295490..08781af71 100644 --- a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc @@ -94,7 +94,7 @@ class TlsZeroRttReplayTest : public TlsConnectTls13 { // Now run a true 0-RTT handshake, but capture the first packet. auto first_packet = std::make_shared(); - client_->SetPacketFilter(first_packet); + client_->SetFilter(first_packet); client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); ExpectResumption(RESUME_TICKET); @@ -116,8 +116,7 @@ class TlsZeroRttReplayTest : public TlsConnectTls13 { // Capture the early_data extension, which should not appear. auto early_data_ext = - std::make_shared(ssl_tls13_early_data_xtn); - server_->SetPacketFilter(early_data_ext); + MakeTlsFilter(server_, ssl_tls13_early_data_xtn); // Finally, replay the ClientHello and force the server to consume it. Stop // after the server sends its first flight; the client will not be able to @@ -405,6 +404,9 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) { // The client should abort the connection when sending a 0-rtt handshake but // the servers responds with a TLS 1.2 ServerHello. (with app data) TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) { + const char* k0RttData = "ABCDEF"; + const PRInt32 k0RttDataLen = static_cast(strlen(k0RttData)); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); server_->Set0RttEnabled(true); // set ticket_allow_early_data Connect(); @@ -422,27 +424,28 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) { // Send the early data xtn in the CH, followed by early app data. The server // will fail right after sending its flight, when receiving the early data. client_->Set0RttEnabled(true); - ZeroRttSendReceive(true, false, [this]() { - client_->ExpectSendAlert(kTlsAlertIllegalParameter); - if (variant_ == ssl_variant_stream) { - server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); - } - return true; - }); - - client_->Handshake(); - server_->Handshake(); - ASSERT_TRUE_WAIT( - (client_->error_code() == SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA), 2000); + client_->Handshake(); // Send ClientHello. + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + EXPECT_EQ(k0RttDataLen, rv); - // DTLS will timeout as we bump the epoch when installing the early app data - // cipher suite. Thus the encrypted alert will be ignored. if (variant_ == ssl_variant_stream) { - // The server sends an alert when receiving the early app data record. - ASSERT_TRUE_WAIT( - (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA), - 2000); + // When the server receives the early data, it will fail. + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + server_->Handshake(); // Consume ClientHello + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA); + } else { + // If it's datagram, we just discard the early data. + server_->Handshake(); // Consume ClientHello + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); } + + // The client now reads the ServerHello and fails. + ASSERT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA); } static void CheckEarlyDataLimit(const std::shared_ptr& agent, @@ -455,10 +458,13 @@ static void CheckEarlyDataLimit(const std::shared_ptr& agent, } TEST_P(TlsConnectTls13, SendTooMuchEarlyData) { + EnsureTlsSetup(); const char* big_message = "0123456789abcdef"; const size_t short_size = strlen(big_message) - 1; const PRInt32 short_length = static_cast(short_size); - SSLInt_SetMaxEarlyDataSize(static_cast(short_size)); + EXPECT_EQ(SECSuccess, + SSL_SetMaxEarlyDataSize(server_->ssl_fd(), + static_cast(short_size))); SetupForZeroRtt(); client_->Set0RttEnabled(true); @@ -510,8 +516,10 @@ TEST_P(TlsConnectTls13, SendTooMuchEarlyData) { } TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) { + EnsureTlsSetup(); + const size_t limit = 5; - SSLInt_SetMaxEarlyDataSize(limit); + EXPECT_EQ(SECSuccess, SSL_SetMaxEarlyDataSize(server_->ssl_fd(), limit)); SetupForZeroRtt(); client_->Set0RttEnabled(true); @@ -521,6 +529,8 @@ TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) { client_->Handshake(); // Send ClientHello CheckEarlyDataLimit(client_, limit); + server_->Handshake(); // Process ClientHello, send server flight. + // Lift the limit on the client. EXPECT_EQ(SECSuccess, SSLInt_SetSocketMaxEarlyDataSize(client_->ssl_fd(), 1000)); @@ -534,21 +544,31 @@ TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) { // This error isn't fatal for DTLS. ExpectAlert(server_, kTlsAlertUnexpectedMessage); } - server_->Handshake(); // Process ClientHello, send server flight. - server_->Handshake(); // Just to make sure that we don't read ahead. + + server_->Handshake(); // This reads the early data and maybe throws an error. + if (variant_ == ssl_variant_stream) { + server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA); + } else { + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + } CheckEarlyDataLimit(server_, limit); - // Attempt to read early data. + // Attempt to read early data. This will get an error. std::vector buf(strlen(message) + 1); EXPECT_GT(0, PR_Read(server_->ssl_fd(), buf.data(), buf.capacity())); if (variant_ == ssl_variant_stream) { - server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA); + EXPECT_EQ(SSL_ERROR_HANDSHAKE_FAILED, PORT_GetError()); + } else { + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); } - client_->Handshake(); // Process the handshake. - client_->Handshake(); // Process the alert. + client_->Handshake(); // Process the server's first flight. if (variant_ == ssl_variant_stream) { + client_->Handshake(); // Process the alert. client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + } else { + server_->Handshake(); // Finish connecting. + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); } } @@ -583,7 +603,7 @@ TEST_P(TlsConnectTls13, ZeroRttOrdering) { // Now, coalesce the next three things from the client: early data, second // flight and 1-RTT data. auto coalesce = std::make_shared(); - client_->SetPacketFilter(coalesce); + client_->SetFilter(coalesce); // Send (and hold) early data. static const std::vector early_data = {3, 2, 1}; diff --git a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc index 0aa9a4c78..f0c57e8b1 100644 --- a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc @@ -160,9 +160,8 @@ TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) { SSL_LIBRARY_VERSION_TLS_1_3); agent_->StartConnect(); agent_->Set0RttEnabled(true); - auto filter = std::make_shared( - kTlsHandshakeClientHello); - agent_->SetPacketFilter(filter); + auto filter = + MakeTlsFilter(agent_, kTlsHandshakeClientHello); PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData)); EXPECT_EQ(-1, rv); int32_t err = PORT_GetError(); diff --git a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc index dbcdd92ea..7f2b2840d 100644 --- a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc @@ -95,10 +95,9 @@ TEST_P(TlsConnectGeneric, ClientAuthBigRsa) { } // Offset is the position in the captured buffer where the signature sits. -static void CheckSigScheme( - std::shared_ptr& capture, size_t offset, - std::shared_ptr& peer, uint16_t expected_scheme, - size_t expected_size) { +static void CheckSigScheme(std::shared_ptr& capture, + size_t offset, std::shared_ptr& peer, + uint16_t expected_scheme, size_t expected_size) { EXPECT_LT(offset + 2U, capture->buffer().len()); uint32_t scheme = 0; @@ -114,9 +113,8 @@ static void CheckSigScheme( // in the default certificate. TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) { EnsureTlsSetup(); - auto capture_ske = std::make_shared( - kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(capture_ske); + auto capture_ske = MakeTlsFilter( + server_, kTlsHandshakeServerKeyExchange); Connect(); CheckKeys(); @@ -127,15 +125,14 @@ TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) { EXPECT_TRUE(buffer.Read(1, 2, &tmp)) << "read NamedCurve"; EXPECT_EQ(ssl_grp_ec_curve25519, tmp); EXPECT_TRUE(buffer.Read(3, 1, &tmp)) << " read ECPoint"; - CheckSigScheme(capture_ske, 4 + tmp, client_, ssl_sig_rsa_pss_sha256, 1024); + CheckSigScheme(capture_ske, 4 + tmp, client_, ssl_sig_rsa_pss_rsae_sha256, + 1024); } TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) { EnsureTlsSetup(); - auto capture_cert_verify = - std::make_shared( - kTlsHandshakeCertificateVerify); - client_->SetPacketFilter(capture_cert_verify); + auto capture_cert_verify = MakeTlsFilter( + client_, kTlsHandshakeCertificateVerify); client_->SetupClientAuth(); server_->RequestClientAuth(true); Connect(); @@ -146,21 +143,20 @@ TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) { TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) { Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048); - auto capture_cert_verify = - std::make_shared( - kTlsHandshakeCertificateVerify); - client_->SetPacketFilter(capture_cert_verify); + auto capture_cert_verify = MakeTlsFilter( + client_, kTlsHandshakeCertificateVerify); client_->SetupClientAuth(); server_->RequestClientAuth(true); Connect(); CheckKeys(); - CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_sha256, 2048); + CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_rsae_sha256, + 2048); } class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter { public: - TlsZeroCertificateRequestSigAlgsFilter() - : TlsHandshakeFilter({kTlsHandshakeCertificateRequest}) {} + TlsZeroCertificateRequestSigAlgsFilter(const std::shared_ptr& agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeCertificateRequest}) {} virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { @@ -205,12 +201,9 @@ class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter { // supported_signature_algorithms in the CertificateRequest message. TEST_P(TlsConnectTls12, ClientAuthNoSigAlgsFallback) { EnsureTlsSetup(); - auto filter = std::make_shared(); - server_->SetPacketFilter(filter); - auto capture_cert_verify = - std::make_shared( - kTlsHandshakeCertificateVerify); - client_->SetPacketFilter(capture_cert_verify); + MakeTlsFilter(server_); + auto capture_cert_verify = MakeTlsFilter( + client_, kTlsHandshakeCertificateVerify); client_->SetupClientAuth(); server_->RequestClientAuth(true); @@ -358,8 +351,7 @@ TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) { // The signature_algorithms extension is mandatory in TLS 1.3. TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) { - client_->SetPacketFilter( - std::make_shared(ssl_signature_algorithms_xtn)); + MakeTlsFilter(client_, ssl_signature_algorithms_xtn); ConnectExpectAlert(server_, kTlsAlertMissingExtension); client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION); @@ -368,8 +360,7 @@ TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) { // TLS 1.2 has trouble detecting this sort of modification: it uses SHA1 and // only fails when the Finished is checked. TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) { - client_->SetPacketFilter( - std::make_shared(ssl_signature_algorithms_xtn)); + MakeTlsFilter(client_, ssl_signature_algorithms_xtn); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -387,11 +378,11 @@ class BeforeFinished : public TlsRecordFilter { enum HandshakeState { BEFORE_CCS, AFTER_CCS, DONE }; public: - BeforeFinished(std::shared_ptr& client, - std::shared_ptr& server, VoidFunction before_ccs, - VoidFunction before_finished) - : client_(client), - server_(server), + BeforeFinished(const std::shared_ptr& server, + const std::shared_ptr& client, + VoidFunction before_ccs, VoidFunction before_finished) + : TlsRecordFilter(server), + client_(client), before_ccs_(before_ccs), before_finished_(before_finished), state_(BEFORE_CCS) {} @@ -411,7 +402,7 @@ class BeforeFinished : public TlsRecordFilter { // but that means that they both get processed together. DataBuffer ccs; header.Write(&ccs, 0, body); - server_.lock()->SendDirect(ccs); + agent()->SendDirect(ccs); client_.lock()->Handshake(); state_ = AFTER_CCS; // Request that the original record be dropped by the filter. @@ -436,7 +427,6 @@ class BeforeFinished : public TlsRecordFilter { private: std::weak_ptr client_; - std::weak_ptr server_; VoidFunction before_ccs_; VoidFunction before_finished_; HandshakeState state_; @@ -461,11 +451,11 @@ class BeforeFinished13 : public PacketFilter { }; public: - BeforeFinished13(std::shared_ptr& client, - std::shared_ptr& server, + BeforeFinished13(const std::shared_ptr& server, + const std::shared_ptr& client, VoidFunction before_finished) - : client_(client), - server_(server), + : server_(server), + client_(client), before_finished_(before_finished), records_(0) {} @@ -497,8 +487,8 @@ class BeforeFinished13 : public PacketFilter { } private: - std::weak_ptr client_; std::weak_ptr server_; + std::weak_ptr client_; VoidFunction before_finished_; size_t records_; }; @@ -512,11 +502,9 @@ static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) { // processed by the client, SSL_AuthCertificateComplete() is called. TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) { client_->SetAuthCertificateCallback(AuthCompleteBlock); - server_->SetPacketFilter( - std::make_shared(client_, server_, [this]() { - EXPECT_EQ(SECSuccess, - SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); - })); + MakeTlsFilter(server_, client_, [this]() { + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + }); Connect(); } @@ -544,13 +532,13 @@ TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) { TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) { client_->EnableFalseStart(); - server_->SetPacketFilter(std::make_shared( - client_, server_, + MakeTlsFilter( + server_, client_, [this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); }, [this]() { // Write something, which used to fail: bug 1235366. client_->SendData(10); - })); + }); Connect(); server_->SendData(10); @@ -560,8 +548,8 @@ TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) { TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) { client_->EnableFalseStart(); client_->SetAuthCertificateCallback(AuthCompleteBlock); - server_->SetPacketFilter(std::make_shared( - client_, server_, + MakeTlsFilter( + server_, client_, []() { // Do nothing before CCS }, @@ -572,7 +560,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) { SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); EXPECT_TRUE(client_->can_falsestart_hook_called()); client_->SendData(10); - })); + }); Connect(); server_->SendData(10); @@ -606,7 +594,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); // The client should send nothing from here on. - client_->SetPacketFilter(std::make_shared()); + client_->SetFilter(std::make_shared()); client_->Handshake(); EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); @@ -616,8 +604,33 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); - // Remove this before closing or the close_notify alert will trigger it. - client_->DeletePacketFilter(); + // Remove filter before closing or the close_notify alert will trigger it. + client_->ClearFilter(); +} + +TEST_P(TlsConnectGenericPre13, AuthCompleteFailDelayed) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + client_->Handshake(); // Send ClientKeyExchange and Finished + server_->Handshake(); // Send Finished + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + // The client should send nothing from here on. + client_->SetFilter(std::make_shared()); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // Report failure. + client_->ClearFilter(); + client_->ExpectSendAlert(kTlsAlertBadCertificate); + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), + SSL_ERROR_BAD_CERTIFICATE)); + client_->Handshake(); // Fail + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); } // TLS 1.3 handles a delayed AuthComplete callback differently since the @@ -632,12 +645,12 @@ TEST_P(TlsConnectTls13, AuthCompleteDelayed) { EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); // The client will send nothing until AuthCertificateComplete is called. - client_->SetPacketFilter(std::make_shared()); + client_->SetFilter(std::make_shared()); client_->Handshake(); EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); // This should allow the handshake to complete now. - client_->DeletePacketFilter(); + client_->ClearFilter(); EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); client_->Handshake(); // Send Finished server_->Handshake(); // Transition to connected and send NewSessionTicket @@ -645,6 +658,44 @@ TEST_P(TlsConnectTls13, AuthCompleteDelayed) { EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); } +TEST_P(TlsConnectTls13, AuthCompleteFailDelayed) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + + // The client will send nothing until AuthCertificateComplete is called. + client_->SetFilter(std::make_shared()); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // Report failure. + client_->ClearFilter(); + ExpectAlert(client_, kTlsAlertBadCertificate); + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), + SSL_ERROR_BAD_CERTIFICATE)); + client_->Handshake(); // This should now fail. + server_->Handshake(); // Get the error. + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); +} + +static SECStatus AuthCompleteFail(TlsAgent*, PRBool, PRBool) { + PORT_SetError(SSL_ERROR_BAD_CERTIFICATE); + return SECFailure; +} + +TEST_P(TlsConnectGeneric, AuthFailImmediate) { + client_->SetAuthCertificateCallback(AuthCompleteFail); + + StartConnect(); + ConnectExpectAlert(client_, kTlsAlertBadCertificate); + client_->CheckErrorCode(SSL_ERROR_BAD_CERTIFICATE); +} + static const SSLExtraServerCertData ServerCertDataRsaPkcs1Decrypt = { ssl_auth_rsa_decrypt, nullptr, nullptr, nullptr}; static const SSLExtraServerCertData ServerCertDataRsaPkcs1Sign = { @@ -767,8 +818,7 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) { TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) { Reset(certificate_); auto capture = - std::make_shared(ssl_signature_algorithms_xtn); - client_->SetPacketFilter(capture); + MakeTlsFilter(client_, ssl_signature_algorithms_xtn); TestSignatureSchemeConfig(client_); const DataBuffer& ext = capture->extension(); @@ -796,8 +846,8 @@ INSTANTIATE_TEST_CASE_P( ::testing::Values(TlsAgent::kServerRsaSign), ::testing::Values(ssl_auth_rsa_sign), ::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384, - ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_sha256, - ssl_sig_rsa_pss_sha384))); + ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_rsae_sha384))); // PSS with SHA-512 needs a bigger key to work. INSTANTIATE_TEST_CASE_P( SignatureSchemeBigRsa, TlsSignatureSchemeConfiguration, @@ -805,7 +855,7 @@ INSTANTIATE_TEST_CASE_P( TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kRsa2048), ::testing::Values(ssl_auth_rsa_sign), - ::testing::Values(ssl_sig_rsa_pss_sha512))); + ::testing::Values(ssl_sig_rsa_pss_rsae_sha512))); INSTANTIATE_TEST_CASE_P( SignatureSchemeRsaSha1, TlsSignatureSchemeConfiguration, ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, @@ -842,4 +892,4 @@ INSTANTIATE_TEST_CASE_P( TlsAgent::kServerEcdsa384), ::testing::Values(ssl_auth_ecdsa), ::testing::Values(ssl_sig_ecdsa_sha1))); -} +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc index 36ee104af..573c69c75 100644 --- a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc @@ -180,9 +180,8 @@ TEST_P(TlsConnectGenericPre13, OcspMangled) { server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData)); static const uint8_t val[] = {1}; - auto replacer = std::make_shared( - ssl_cert_status_xtn, DataBuffer(val, sizeof(val))); - server_->SetPacketFilter(replacer); + auto replacer = MakeTlsFilter( + server_, ssl_cert_status_xtn, DataBuffer(val, sizeof(val))); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); @@ -192,8 +191,7 @@ TEST_P(TlsConnectGeneric, OcspSuccess) { EnsureTlsSetup(); client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); auto capture_ocsp = - std::make_shared(ssl_cert_status_xtn); - server_->SetPacketFilter(capture_ocsp); + MakeTlsFilter(server_, ssl_cert_status_xtn); // The value should be available during the AuthCertificateCallback client_->SetAuthCertificateCallback([](TlsAgent* agent, bool checksig, @@ -245,4 +243,4 @@ TEST_P(TlsConnectGeneric, OcspHugeSuccess) { Connect(); } -} // namespace nspr_test +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc index 810656868..fa2238be7 100644 --- a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc @@ -74,12 +74,12 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { Reset(TlsAgent::kServerRsaSign); auth_type_ = ssl_auth_rsa_sign; break; - case ssl_sig_rsa_pss_sha256: - case ssl_sig_rsa_pss_sha384: + case ssl_sig_rsa_pss_rsae_sha256: + case ssl_sig_rsa_pss_rsae_sha384: Reset(TlsAgent::kServerRsaSign); auth_type_ = ssl_auth_rsa_sign; break; - case ssl_sig_rsa_pss_sha512: + case ssl_sig_rsa_pss_rsae_sha512: // You can't fit SHA-512 PSS in a 1024-bit key. Reset(TlsAgent::kRsa2048); auth_type_ = ssl_auth_rsa_sign; @@ -313,8 +313,8 @@ static const auto kDummySignatureSchemesParams = static SSLSignatureScheme kSignatureSchemesParamsArr[] = { ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384, ssl_sig_rsa_pkcs1_sha512, ssl_sig_ecdsa_secp256r1_sha256, - ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_rsa_pss_sha256, - ssl_sig_rsa_pss_sha384, ssl_sig_rsa_pss_sha512, + ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_rsae_sha384, ssl_sig_rsa_pss_rsae_sha512, }; #endif @@ -466,4 +466,4 @@ static const SecStatusParams kSecStatusTestValuesArr[] = { INSTANTIATE_TEST_CASE_P(TestSecurityStatus, SecurityStatusTest, ::testing::ValuesIn(kSecStatusTestValuesArr)); -} // namespace nspr_test +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc index dad944a1f..c2f582a93 100644 --- a/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc @@ -150,9 +150,8 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionWriterDisable) { client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NoopExtensionWriter, nullptr, NoopExtensionHandler, nullptr); EXPECT_EQ(SECSuccess, rv); - auto capture = - std::make_shared(ssl_signed_cert_timestamp_xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter( + client_, ssl_signed_cert_timestamp_xtn); Connect(); // So nothing will be sent. @@ -204,9 +203,8 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionOverride) { EXPECT_EQ(SECSuccess, rv); // Capture it to see what we got. - auto capture = - std::make_shared(ssl_signed_cert_timestamp_xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter( + client_, ssl_signed_cert_timestamp_xtn); ConnectExpectAlert(server_, kTlsAlertDecodeError); @@ -246,8 +244,7 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionClientToServer) { EXPECT_EQ(SECSuccess, rv); // Capture it to see what we got. - auto capture = std::make_shared(extension_code); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter(client_, extension_code); // Handle it so that the handshake completes. rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, @@ -290,9 +287,8 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientSH) { EXPECT_EQ(SECSuccess, rv); // Capture the extension from the ServerHello only and check it. - auto capture = std::make_shared(extension_code); + auto capture = MakeTlsFilter(server_, extension_code); capture->SetHandshakeTypes({kTlsHandshakeServerHello}); - server_->SetPacketFilter(capture); Connect(); @@ -329,9 +325,9 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientEE) { EXPECT_EQ(SECSuccess, rv); // Capture the extension from the EncryptedExtensions only and check it. - auto capture = std::make_shared(extension_code); + auto capture = MakeTlsFilter(server_, extension_code); capture->SetHandshakeTypes({kTlsHandshakeEncryptedExtensions}); - server_->SetTlsRecordFilter(capture); + capture->EnableDecryption(); Connect(); @@ -350,8 +346,7 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionUnsolicitedServer) { EXPECT_EQ(SECSuccess, rv); // Capture it to see what we got. - auto capture = std::make_shared(extension_code); - server_->SetPacketFilter(capture); + auto capture = MakeTlsFilter(server_, extension_code); client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); @@ -500,4 +495,4 @@ TEST_F(TlsConnectStreamTls13, CustomExtensionOverrunBuffer) { client_->CheckErrorCode(SEC_ERROR_APPLICATION_CALLBACK_ERROR); } -} // namespace "nss_test" +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc index d1668b823..b8836d7fc 100644 --- a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc @@ -50,19 +50,19 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) { SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, SSL_LIBRARY_VERSION_TLS_1_3); - server_->SetPacketFilter(std::make_shared( + MakeTlsFilter( server_, client_, 0, // ServerHello. - [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); })); + [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); }); ConnectExpectAlert(client_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); } TEST_P(TlsConnectGenericPre13, DamageServerSignature) { EnsureTlsSetup(); - auto filter = - std::make_shared(kTlsHandshakeServerKeyExchange); - server_->SetTlsRecordFilter(filter); + auto filter = MakeTlsFilter( + server_, kTlsHandshakeServerKeyExchange); + filter->EnableDecryption(); ExpectAlert(client_, kTlsAlertDecryptError); ConnectExpectFail(); client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); @@ -71,9 +71,9 @@ TEST_P(TlsConnectGenericPre13, DamageServerSignature) { TEST_P(TlsConnectTls13, DamageServerSignature) { EnsureTlsSetup(); - auto filter = - std::make_shared(kTlsHandshakeCertificateVerify); - server_->SetTlsRecordFilter(filter); + auto filter = MakeTlsFilter( + server_, kTlsHandshakeCertificateVerify); + filter->EnableDecryption(); ConnectExpectAlert(client_, kTlsAlertDecryptError); client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); } @@ -82,9 +82,9 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) { EnsureTlsSetup(); client_->SetupClientAuth(); server_->RequestClientAuth(true); - auto filter = - std::make_shared(kTlsHandshakeCertificateVerify); - client_->SetTlsRecordFilter(filter); + auto filter = MakeTlsFilter( + client_, kTlsHandshakeCertificateVerify); + filter->EnableDecryption(); server_->ExpectSendAlert(kTlsAlertDecryptError); // Do these handshakes by hand to avoid race condition on // the client processing the server's alert. @@ -100,4 +100,4 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) { server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); } -} // namespace nspr_test +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc index 4aa3bb639..cdafa7a84 100644 --- a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc @@ -24,7 +24,7 @@ TEST_P(TlsConnectGeneric, ConnectDhe) { EnableOnlyDheCiphers(); Connect(); CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { @@ -32,12 +32,12 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { client_->ConfigNamedGroups(kAllDHEGroups); auto groups_capture = - std::make_shared(ssl_supported_groups_xtn); + std::make_shared(client_, ssl_supported_groups_xtn); auto shares_capture = - std::make_shared(ssl_tls13_key_share_xtn); + std::make_shared(client_, ssl_tls13_key_share_xtn); std::vector> captures = {groups_capture, shares_capture}; - client_->SetPacketFilter(std::make_shared(captures)); + client_->SetFilter(std::make_shared(captures)); Connect(); @@ -61,12 +61,12 @@ TEST_P(TlsConnectGeneric, ConnectFfdheClient) { EnableOnlyDheCiphers(); client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); auto groups_capture = - std::make_shared(ssl_supported_groups_xtn); + std::make_shared(client_, ssl_supported_groups_xtn); auto shares_capture = - std::make_shared(ssl_tls13_key_share_xtn); + std::make_shared(client_, ssl_tls13_key_share_xtn); std::vector> captures = {groups_capture, shares_capture}; - client_->SetPacketFilter(std::make_shared(captures)); + client_->SetFilter(std::make_shared(captures)); Connect(); @@ -103,8 +103,8 @@ TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) { class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter { public: - TlsDheServerKeyExchangeDamager() - : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {} + TlsDheServerKeyExchangeDamager(const std::shared_ptr& agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {} virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { @@ -122,7 +122,7 @@ class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter { TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) { EnableOnlyDheCiphers(); client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); - server_->SetPacketFilter(std::make_shared()); + MakeTlsFilter(server_); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); @@ -141,8 +141,9 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter { kYZeroPad }; - TlsDheSkeChangeY(uint8_t handshake_type, ChangeYTo change) - : TlsHandshakeFilter({handshake_type}), change_Y_(change) {} + TlsDheSkeChangeY(const std::shared_ptr& agent, + uint8_t handshake_type, ChangeYTo change) + : TlsHandshakeFilter(agent, {handshake_type}), change_Y_(change) {} protected: void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset, @@ -207,8 +208,9 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter { class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { public: - TlsDheSkeChangeYServer(ChangeYTo change, bool modify) - : TlsDheSkeChangeY(kTlsHandshakeServerKeyExchange, change), + TlsDheSkeChangeYServer(const std::shared_ptr& agent, + ChangeYTo change, bool modify) + : TlsDheSkeChangeY(agent, kTlsHandshakeServerKeyExchange, change), modify_(modify), p_() {} @@ -245,9 +247,9 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { class TlsDheSkeChangeYClient : public TlsDheSkeChangeY { public: TlsDheSkeChangeYClient( - ChangeYTo change, + const std::shared_ptr& agent, ChangeYTo change, std::shared_ptr server_filter) - : TlsDheSkeChangeY(kTlsHandshakeClientKeyExchange, change), + : TlsDheSkeChangeY(agent, kTlsHandshakeClientKeyExchange, change), server_filter_(server_filter) {} protected: @@ -282,8 +284,7 @@ TEST_P(TlsDamageDHYTest, DamageServerY) { client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); } TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); - server_->SetPacketFilter( - std::make_shared(change, true)); + MakeTlsFilter(server_, change, true); if (change == TlsDheSkeChangeY::kYZeroPad) { ExpectAlert(client_, kTlsAlertDecryptError); @@ -312,14 +313,12 @@ TEST_P(TlsDamageDHYTest, DamageClientY) { client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); } // The filter on the server is required to capture the prime. - auto server_filter = - std::make_shared(TlsDheSkeChangeY::kYZero, false); - server_->SetPacketFilter(server_filter); + auto server_filter = MakeTlsFilter( + server_, TlsDheSkeChangeY::kYZero, false); // The client filter does the damage. TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); - client_->SetPacketFilter( - std::make_shared(change, server_filter)); + MakeTlsFilter(client_, change, server_filter); if (change == TlsDheSkeChangeY::kYZeroPad) { ExpectAlert(server_, kTlsAlertDecryptError); @@ -358,7 +357,9 @@ INSTANTIATE_TEST_CASE_P( class TlsDheSkeMakePEven : public TlsHandshakeFilter { public: - TlsDheSkeMakePEven() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {} + TlsDheSkeMakePEven(const std::shared_ptr& agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {} + virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { @@ -379,7 +380,7 @@ class TlsDheSkeMakePEven : public TlsHandshakeFilter { // Even without requiring named groups, an even value for p is bad news. TEST_P(TlsConnectGenericPre13, MakeDhePEven) { EnableOnlyDheCiphers(); - server_->SetPacketFilter(std::make_shared()); + MakeTlsFilter(server_); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); @@ -389,7 +390,9 @@ TEST_P(TlsConnectGenericPre13, MakeDhePEven) { class TlsDheSkeZeroPadP : public TlsHandshakeFilter { public: - TlsDheSkeZeroPadP() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {} + TlsDheSkeZeroPadP(const std::shared_ptr& agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {} + virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { @@ -407,7 +410,7 @@ class TlsDheSkeZeroPadP : public TlsHandshakeFilter { // Zero padding only causes signature failure. TEST_P(TlsConnectGenericPre13, PadDheP) { EnableOnlyDheCiphers(); - server_->SetPacketFilter(std::make_shared()); + MakeTlsFilter(server_); ConnectExpectAlert(client_, kTlsAlertDecryptError); @@ -455,7 +458,7 @@ TEST_P(TlsConnectGenericPre13, NamedGroupMismatchPre13) { Connect(); CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_custom, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } // Same test but for TLS 1.3. This has to fail. @@ -499,8 +502,8 @@ TEST_P(TlsConnectGenericPre13, PreferredFfdhe) { Connect(); client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); - client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); - server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); + client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); + server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); } TEST_P(TlsConnectGenericPre13, MismatchDHE) { @@ -524,29 +527,28 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) { Connect(); SendReceive(); // Need to read so that we absorb the session ticket. CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); EnableOnlyDheCiphers(); auto clientCapture = - std::make_shared(ssl_tls13_pre_shared_key_xtn); - client_->SetPacketFilter(clientCapture); + MakeTlsFilter(client_, ssl_tls13_pre_shared_key_xtn); auto serverCapture = - std::make_shared(ssl_tls13_pre_shared_key_xtn); - server_->SetPacketFilter(serverCapture); + MakeTlsFilter(server_, ssl_tls13_pre_shared_key_xtn); ExpectResumption(RESUME_TICKET); Connect(); CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); ASSERT_LT(0UL, clientCapture->extension().len()); ASSERT_LT(0UL, serverCapture->extension().len()); } class TlsDheSkeChangeSignature : public TlsHandshakeFilter { public: - TlsDheSkeChangeSignature(uint16_t version, const uint8_t* data, size_t len) - : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}), + TlsDheSkeChangeSignature(const std::shared_ptr& agent, + uint16_t version, const uint8_t* data, size_t len) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}), version_(version), data_(data), len_(len) {} @@ -595,8 +597,8 @@ TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) { const std::vector client_groups = {ssl_grp_ffdhe_2048}; client_->ConfigNamedGroups(client_groups); - server_->SetPacketFilter(std::make_shared( - version_, kBogusDheSignature, sizeof(kBogusDheSignature))); + MakeTlsFilter(server_, version_, kBogusDheSignature, + sizeof(kBogusDheSignature)); ConnectExpectAlert(client_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); diff --git a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc index c059e9938..ee8906deb 100644 --- a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc @@ -22,13 +22,13 @@ extern "C" { namespace nss_test { TEST_P(TlsConnectDatagramPre13, DropClientFirstFlightOnce) { - client_->SetPacketFilter(std::make_shared(0x1)); + client_->SetFilter(std::make_shared(0x1)); Connect(); SendReceive(); } TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) { - server_->SetPacketFilter(std::make_shared(0x1)); + server_->SetFilter(std::make_shared(0x1)); Connect(); SendReceive(); } @@ -37,32 +37,32 @@ TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) { // flights that they send. Note: In DTLS 1.3, the shorter handshake means that // this will also drop some application data, so we can't call SendReceive(). TEST_P(TlsConnectDatagramPre13, DropAllFirstTransmissions) { - client_->SetPacketFilter(std::make_shared(0x15)); - server_->SetPacketFilter(std::make_shared(0x5)); + client_->SetFilter(std::make_shared(0x15)); + server_->SetFilter(std::make_shared(0x5)); Connect(); } // This drops the server's first flight three times. TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightThrice) { - server_->SetPacketFilter(std::make_shared(0x7)); + server_->SetFilter(std::make_shared(0x7)); Connect(); } // This drops the client's second flight once TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightOnce) { - client_->SetPacketFilter(std::make_shared(0x2)); + client_->SetFilter(std::make_shared(0x2)); Connect(); } // This drops the client's second flight three times. TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightThrice) { - client_->SetPacketFilter(std::make_shared(0xe)); + client_->SetFilter(std::make_shared(0xe)); Connect(); } // This drops the server's second flight three times. TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) { - server_->SetPacketFilter(std::make_shared(0xe)); + server_->SetFilter(std::make_shared(0xe)); Connect(); } @@ -74,7 +74,7 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 { expected_client_acks_(0), expected_server_acks_(1) {} - void SetUp() { + void SetUp() override { TlsConnectDatagram13::SetUp(); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); SetFilters(); @@ -82,12 +82,8 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 { void SetFilters() { EnsureTlsSetup(); - client_->SetPacketFilter(client_filters_.chain_); - client_filters_.ack_->SetAgent(client_.get()); - client_filters_.ack_->EnableDecryption(); - server_->SetPacketFilter(server_filters_.chain_); - server_filters_.ack_->SetAgent(server_.get()); - server_filters_.ack_->EnableDecryption(); + client_filters_.Init(client_); + server_filters_.Init(server_); } void HandshakeAndAck(const std::shared_ptr& agent) { @@ -119,11 +115,17 @@ class TlsDropDatagram13 : public TlsConnectDatagram13 { class DropAckChain { public: DropAckChain() - : records_(std::make_shared()), - ack_(std::make_shared(content_ack)), - drop_(std::make_shared(0, false)), - chain_(std::make_shared( - ChainedPacketFilterInit({records_, ack_, drop_}))) {} + : records_(nullptr), ack_(nullptr), drop_(nullptr), chain_(nullptr) {} + + void Init(const std::shared_ptr& agent) { + records_ = std::make_shared(agent); + ack_ = std::make_shared(agent, content_ack); + ack_->EnableDecryption(); + drop_ = std::make_shared(agent, 0, false); + chain_ = std::make_shared( + ChainedPacketFilterInit({records_, ack_, drop_})); + agent->SetFilter(chain_); + } const TlsRecord& record(size_t i) const { return records_->record(i); } @@ -227,7 +229,7 @@ TEST_F(TlsDropDatagram13, DropServerSecondRecordOnce) { HandshakeAndAck(client_); expected_client_acks_ = 1; CheckedHandshakeSendReceive(); - CheckAcks(client_filters_, 0, {0}); + CheckAcks(client_filters_, 0, {0}); // ServerHello CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); } @@ -257,7 +259,7 @@ TEST_F(TlsDropDatagram13, DropServerAckOnce) { CheckPostHandshake(); // There should be two copies of the finished ACK CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_, 1, {0x0002000000000000ULL}); } // Drop the client certificate verify. @@ -276,10 +278,9 @@ TEST_F(TlsDropDatagram13, DropClientCertVerify) { // Ack of the whole client handshake. CheckAcks( server_filters_, 1, - {0x0002000000000000ULL, // CH (we drop everything after this on client) - 0x0002000000000003ULL, // CT (2) - 0x0002000000000004ULL} // FIN (2) - ); + {0x0002000000000000ULL, // CH (we drop everything after this on client) + 0x0002000000000003ULL, // CT (2) + 0x0002000000000004ULL}); // FIN (2) } // Shrink the MTU down so that certs get split and drop the first piece. @@ -303,10 +304,9 @@ TEST_F(TlsDropDatagram13, DropFirstHalfOfServerCertificate) { EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len()); CheckedHandshakeSendReceive(); CheckAcks(client_filters_, 0, - {0, // SH - 0x0002000000000000ULL, // EE - 0x0002000000000002ULL} // CT2 - ); + {0, // SH + 0x0002000000000000ULL, // EE + 0x0002000000000002ULL}); // CT2 CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); } @@ -540,7 +540,10 @@ TEST_F(TlsDropDatagram13, NoDropsDuringZeroRtt) { ExpectEarlyDataAccepted(true); CheckConnected(); SendReceive(); - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + EXPECT_EQ(0U, client_filters_.ack_->count()); + CheckAcks(server_filters_, 0, + {0x0001000000000001ULL, // EOED + 0x0002000000000000ULL}); // Finished } TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) { @@ -558,7 +561,9 @@ TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) { CheckConnected(); SendReceive(); CheckAcks(client_filters_, 0, {0}); - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_, 0, + {0x0001000000000002ULL, // EOED + 0x0002000000000000ULL}); // Finished } class TlsReorderDatagram13 : public TlsDropDatagram13 { @@ -688,6 +693,7 @@ TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) { kTlsHandshakeType, DataBuffer(buf, sizeof(buf)))); server_->Handshake(); EXPECT_EQ(2UL, server_filters_.ack_->count()); + // The server acknowledges client Finished twice. CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); CheckAcks(server_filters_, 1, {0x0002000000000000ULL}); } @@ -746,7 +752,9 @@ TEST_F(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) { ReSend(TlsAgent::CLIENT, std::vector({1, 0, 2})); server_->Handshake(); CheckConnected(); - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + EXPECT_EQ(0U, client_filters_.ack_->count()); + // Acknowledgements for EOED and Finished. + CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL}); uint8_t buf[8]; rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); EXPECT_EQ(-1, rv); @@ -783,7 +791,9 @@ TEST_F(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) { ReSend(TlsAgent::CLIENT, std::vector({1, 2, 0})); server_->Handshake(); CheckConnected(); - CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + EXPECT_EQ(0U, client_filters_.ack_->count()); + // Acknowledgements for EOED and Finished. + CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL}); uint8_t buf[8]; rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); EXPECT_EQ(-1, rv); diff --git a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc index e0f8b1f55..3c7cd2ecf 100644 --- a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc @@ -69,20 +69,19 @@ TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) { server_->ConfigNamedGroups(groups); Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } // This causes a HelloRetryRequest in TLS 1.3. Earlier versions don't care. TEST_P(TlsConnectGeneric, ConnectEcdheP384Server) { EnsureTlsSetup(); - auto hrr_capture = std::make_shared( - kTlsHandshakeHelloRetryRequest); - server_->SetPacketFilter(hrr_capture); + auto hrr_capture = MakeTlsFilter( + server_, kTlsHandshakeHelloRetryRequest); const std::vector groups = {ssl_grp_ec_secp384r1}; server_->ConfigNamedGroups(groups); Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); EXPECT_EQ(version_ == SSL_LIBRARY_VERSION_TLS_1_3, hrr_capture->buffer().len() != 0); } @@ -112,7 +111,7 @@ TEST_P(TlsKeyExchangeTest, P384Priority) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); std::vector shares = {ssl_grp_ec_secp384r1}; CheckKEXDetails(groups, shares); @@ -129,7 +128,7 @@ TEST_P(TlsKeyExchangeTest, DuplicateGroupConfig) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); std::vector shares = {ssl_grp_ec_secp384r1}; std::vector expectedGroups = {ssl_grp_ec_secp384r1, @@ -147,7 +146,7 @@ TEST_P(TlsKeyExchangeTest, P384PriorityDHEnabled) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { std::vector shares = {ssl_grp_ec_secp384r1}; @@ -172,7 +171,7 @@ TEST_P(TlsConnectGenericPre13, P384PriorityOnServer) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) { @@ -188,13 +187,13 @@ TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter { public: - TlsKeyExchangeGroupCapture() - : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}), + TlsKeyExchangeGroupCapture(const std::shared_ptr &agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}), group_(ssl_grp_none) {} SSLNamedGroup group() const { return group_; } @@ -221,10 +220,8 @@ class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter { // P-256 is supported by the client (<= 1.2 only). TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) { EnsureTlsSetup(); - client_->SetPacketFilter( - std::make_shared(ssl_supported_groups_xtn)); - auto group_capture = std::make_shared(); - server_->SetPacketFilter(group_capture); + MakeTlsFilter(client_, ssl_supported_groups_xtn); + auto group_capture = MakeTlsFilter(server_); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); @@ -236,8 +233,7 @@ TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) { // Supported groups is mandatory in TLS 1.3. TEST_P(TlsConnectTls13, DropSupportedGroupExtension) { EnsureTlsSetup(); - client_->SetPacketFilter( - std::make_shared(ssl_supported_groups_xtn)); + MakeTlsFilter(client_, ssl_supported_groups_xtn); ConnectExpectAlert(server_, kTlsAlertMissingExtension); client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); server_->CheckErrorCode(SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION); @@ -276,7 +272,7 @@ TEST_P(TlsConnectStreamPre13, ConfiguredGroupsRenegotiate) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); CheckConnected(); // The renegotiation has to use the same preferences as the original session. @@ -284,7 +280,7 @@ TEST_P(TlsConnectStreamPre13, ConfiguredGroupsRenegotiate) { client_->StartRenegotiate(); Handshake(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } TEST_P(TlsKeyExchangeTest, Curve25519) { @@ -318,7 +314,7 @@ TEST_P(TlsConnectGenericPre13, GroupPreferenceServerPriority) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } #ifndef NSS_DISABLE_TLS_1_3 @@ -337,7 +333,7 @@ TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityClient13) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector shares = {ssl_grp_ec_secp256r1}; CheckKEXDetails(client_groups, shares); } @@ -357,7 +353,7 @@ TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityServer13) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector shares = {ssl_grp_ec_curve25519}; CheckKEXDetails(client_groups, shares); } @@ -379,7 +375,7 @@ TEST_P(TlsKeyExchangeTest13, EqualPriorityTestRetryECServer13) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector shares = {ssl_grp_ec_curve25519}; CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); } @@ -401,7 +397,7 @@ TEST_P(TlsKeyExchangeTest13, NotEqualPriorityWithIntermediateGroup13) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector shares = {ssl_grp_ec_curve25519}; CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); } @@ -423,7 +419,7 @@ TEST_P(TlsKeyExchangeTest13, Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector shares = {ssl_grp_ec_curve25519}; CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); } @@ -445,7 +441,7 @@ TEST_P(TlsKeyExchangeTest13, Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector shares = {ssl_grp_ec_curve25519}; CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); } @@ -507,7 +503,7 @@ TEST_P(TlsKeyExchangeTest13, MultipleClientShares) { // The server would accept 25519 but its preferred group (P256) has to win. CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector shares = {ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1}; CheckKEXDetails(client_groups, shares); @@ -516,7 +512,8 @@ TEST_P(TlsKeyExchangeTest13, MultipleClientShares) { // Replace the point in the client key exchange message with an empty one class ECCClientKEXFilter : public TlsHandshakeFilter { public: - ECCClientKEXFilter() : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}) {} + ECCClientKEXFilter(const std::shared_ptr &client) + : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, @@ -532,7 +529,8 @@ class ECCClientKEXFilter : public TlsHandshakeFilter { // Replace the point in the server key exchange message with an empty one class ECCServerKEXFilter : public TlsHandshakeFilter { public: - ECCServerKEXFilter() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {} + ECCServerKEXFilter(const std::shared_ptr &server) + : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, @@ -550,15 +548,13 @@ class ECCServerKEXFilter : public TlsHandshakeFilter { }; TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyServerPoint) { - // add packet filter - server_->SetPacketFilter(std::make_shared()); + MakeTlsFilter(server_); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH); } TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) { - // add packet filter - client_->SetPacketFilter(std::make_shared()); + MakeTlsFilter(client_); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_KEY_EXCH); } diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc index 4142ab07a..0453dabdb 100644 --- a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -19,8 +19,9 @@ namespace nss_test { class TlsExtensionTruncator : public TlsExtensionFilter { public: - TlsExtensionTruncator(uint16_t extension, size_t length) - : extension_(extension), length_(length) {} + TlsExtensionTruncator(const std::shared_ptr& agent, + uint16_t extension, size_t length) + : TlsExtensionFilter(agent), extension_(extension), length_(length) {} virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { @@ -42,8 +43,9 @@ class TlsExtensionTruncator : public TlsExtensionFilter { class TlsExtensionDamager : public TlsExtensionFilter { public: - TlsExtensionDamager(uint16_t extension, size_t index) - : extension_(extension), index_(index) {} + TlsExtensionDamager(const std::shared_ptr& agent, + uint16_t extension, size_t index) + : TlsExtensionFilter(agent), extension_(extension), index_(index) {} virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { @@ -63,8 +65,11 @@ class TlsExtensionDamager : public TlsExtensionFilter { class TlsExtensionAppender : public TlsHandshakeFilter { public: - TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data) - : TlsHandshakeFilter({handshake_type}), extension_(ext), data_(data) {} + TlsExtensionAppender(const std::shared_ptr& agent, + uint8_t handshake_type, uint16_t ext, DataBuffer& data) + : TlsHandshakeFilter(agent, {handshake_type}), + extension_(ext), + data_(data) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -124,13 +129,13 @@ class TlsExtensionTestBase : public TlsConnectTestBase { void ClientHelloErrorTest(std::shared_ptr filter, uint8_t desc = kTlsAlertDecodeError) { - client_->SetPacketFilter(filter); + client_->SetFilter(filter); ConnectExpectAlert(server_, desc); } void ServerHelloErrorTest(std::shared_ptr filter, uint8_t desc = kTlsAlertDecodeError) { - server_->SetPacketFilter(filter); + server_->SetFilter(filter); ConnectExpectAlert(client_, desc); } @@ -156,7 +161,7 @@ class TlsExtensionTestBase : public TlsConnectTestBase { StartConnect(); client_->Handshake(); // Send ClientHello server_->Handshake(); // Send HRR. - client_->SetPacketFilter(std::make_shared(type)); + MakeTlsFilter(client_, type); Handshake(); client_->CheckErrorCode(client_error); server_->CheckErrorCode(server_error); @@ -197,8 +202,8 @@ class TlsExtensionTest13 void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) { DataBuffer versions_buf(buf, len); - client_->SetPacketFilter(std::make_shared( - ssl_tls13_supported_versions_xtn, versions_buf)); + MakeTlsFilter( + client_, ssl_tls13_supported_versions_xtn, versions_buf); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -209,8 +214,8 @@ class TlsExtensionTest13 size_t index = versions_buf.Write(0, 2, 1); versions_buf.Write(index, version, 2); - client_->SetPacketFilter(std::make_shared( - ssl_tls13_supported_versions_xtn, versions_buf)); + MakeTlsFilter( + client_, ssl_tls13_supported_versions_xtn, versions_buf); ConnectExpectFail(); } }; @@ -241,26 +246,26 @@ class TlsExtensionTestPre13 : public TlsExtensionTestBase, TEST_P(TlsExtensionTestGeneric, DamageSniLength) { ClientHelloErrorTest( - std::make_shared(ssl_server_name_xtn, 1)); + std::make_shared(client_, ssl_server_name_xtn, 1)); } TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) { ClientHelloErrorTest( - std::make_shared(ssl_server_name_xtn, 4)); + std::make_shared(client_, ssl_server_name_xtn, 4)); } TEST_P(TlsExtensionTestGeneric, TruncateSni) { ClientHelloErrorTest( - std::make_shared(ssl_server_name_xtn, 7)); + std::make_shared(client_, ssl_server_name_xtn, 7)); } // A valid extension that appears twice will be reported as unsupported. TEST_P(TlsExtensionTestGeneric, RepeatSni) { DataBuffer extension; InitSimpleSni(&extension); - ClientHelloErrorTest( - std::make_shared(ssl_server_name_xtn, extension), - kTlsAlertIllegalParameter); + ClientHelloErrorTest(std::make_shared( + client_, ssl_server_name_xtn, extension), + kTlsAlertIllegalParameter); } // An SNI entry with zero length is considered invalid (strangely, not if it is @@ -272,23 +277,23 @@ TEST_P(TlsExtensionTestGeneric, BadSni) { extension.Allocate(simple.len() + 3); extension.Write(0, static_cast(0), 3); extension.Write(3, simple); - ClientHelloErrorTest( - std::make_shared(ssl_server_name_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + client_, ssl_server_name_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, EmptySni) { DataBuffer extension; extension.Allocate(2); extension.Write(0, static_cast(0), 2); - ClientHelloErrorTest( - std::make_shared(ssl_server_name_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + client_, ssl_server_name_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) { EnableAlpn(); DataBuffer extension; ClientHelloErrorTest(std::make_shared( - ssl_app_layer_protocol_xtn, extension), + client_, ssl_app_layer_protocol_xtn, extension), kTlsAlertIllegalParameter); } @@ -299,21 +304,21 @@ TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_app_layer_protocol_xtn, extension), + client_, ssl_app_layer_protocol_xtn, extension), kTlsAlertNoApplicationProtocol); } TEST_P(TlsExtensionTestGeneric, OneByteAlpn) { EnableAlpn(); - ClientHelloErrorTest( - std::make_shared(ssl_app_layer_protocol_xtn, 1)); + ClientHelloErrorTest(std::make_shared( + client_, ssl_app_layer_protocol_xtn, 1)); } TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) { EnableAlpn(); // This will leave the length of the second entry, but no value. - ClientHelloErrorTest( - std::make_shared(ssl_app_layer_protocol_xtn, 5)); + ClientHelloErrorTest(std::make_shared( + client_, ssl_app_layer_protocol_xtn, 5)); } TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { @@ -321,7 +326,7 @@ TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { const uint8_t val[] = {0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_app_layer_protocol_xtn, extension)); + client_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, AlpnMismatch) { @@ -340,7 +345,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) { @@ -348,7 +353,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) { const uint8_t val[] = {0x00, 0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) { @@ -356,7 +361,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) { @@ -364,7 +369,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) { const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) { @@ -372,7 +377,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) { const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) { @@ -380,7 +385,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) { const uint8_t val[] = {0x00, 0x02, 0x99, 0x61}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) { @@ -388,55 +393,64 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) { const uint8_t val[] = {0x00, 0x02, 0x01, 0x67}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared( - ssl_app_layer_protocol_xtn, extension), + server_, ssl_app_layer_protocol_xtn, extension), kTlsAlertIllegalParameter); } TEST_P(TlsExtensionTestDtls, SrtpShort) { EnableSrtp(); ClientHelloErrorTest( - std::make_shared(ssl_use_srtp_xtn, 3)); + std::make_shared(client_, ssl_use_srtp_xtn, 3)); } TEST_P(TlsExtensionTestDtls, SrtpOdd) { EnableSrtp(); const uint8_t val[] = {0x00, 0x01, 0xff, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - std::make_shared(ssl_use_srtp_xtn, extension)); + ClientHelloErrorTest(std::make_shared( + client_, ssl_use_srtp_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) { const uint8_t val[] = {0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension), + kTlsAlertHandshakeFailure); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsNoOverlap) { + const uint8_t val[] = {0x00, 0x02, 0xff, 0xff}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared( + client_, ssl_signature_algorithms_xtn, extension), + kTlsAlertHandshakeFailure); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) { const uint8_t val[] = {0x00, 0x01, 0x04}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) { ClientHelloErrorTest( - std::make_shared(ssl_supported_groups_xtn), + std::make_shared(client_, ssl_supported_groups_xtn), version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError : kTlsAlertMissingExtension); } @@ -445,75 +459,74 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) { const uint8_t val[] = {0x00, 0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) { const uint8_t val[] = {0x09, 0x99, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) { const uint8_t val[] = {0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) { const uint8_t val[] = {0x99, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) { const uint8_t val[] = {0x01, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) { const uint8_t val[] = {0x99}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) { const uint8_t val[] = {0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } // The extension has to contain a length. TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) { DataBuffer extension; ClientHelloErrorTest(std::make_shared( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } // This only works on TLS 1.2, since it relies on static RSA; otherwise libssl // picks the wrong cipher suite. TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { - const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_sha512, - ssl_sig_rsa_pss_sha384}; + const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_rsae_sha512, + ssl_sig_rsa_pss_rsae_sha384}; auto capture = - std::make_shared(ssl_signature_algorithms_xtn); + MakeTlsFilter(client_, ssl_signature_algorithms_xtn); client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes)); - client_->SetPacketFilter(capture); EnableOnlyStaticRsaCiphers(); Connect(); @@ -531,9 +544,9 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { // Temporary test to verify that we choke on an empty ClientKeyShare. // This test will fail when we implement HelloRetryRequest. TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { - ClientHelloErrorTest( - std::make_shared(ssl_tls13_key_share_xtn, 2), - kTlsAlertHandshakeFailure); + ClientHelloErrorTest(std::make_shared( + client_, ssl_tls13_key_share_xtn, 2), + kTlsAlertHandshakeFailure); } // These tests only work in stream mode because the client sends a @@ -542,8 +555,7 @@ TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { // packet gets dropped. TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) { EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared(ssl_tls13_key_share_xtn)); + MakeTlsFilter(server_, ssl_tls13_key_share_xtn); client_->ExpectSendAlert(kTlsAlertMissingExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -563,8 +575,7 @@ TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) { 0x02}; DataBuffer buf(key_share, sizeof(key_share)); EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared(ssl_tls13_key_share_xtn, buf)); + MakeTlsFilter(server_, ssl_tls13_key_share_xtn, buf); client_->ExpectSendAlert(kTlsAlertIllegalParameter); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -585,8 +596,7 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { 0x02}; DataBuffer buf(key_share, sizeof(key_share)); EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared(ssl_tls13_key_share_xtn, buf)); + MakeTlsFilter(server_, ssl_tls13_key_share_xtn, buf); client_->ExpectSendAlert(kTlsAlertMissingExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -597,8 +607,8 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) { SetupForResume(); DataBuffer empty; - server_->SetPacketFilter(std::make_shared( - ssl_signature_algorithms_xtn, empty)); + MakeTlsFilter(server_, ssl_signature_algorithms_xtn, + empty); client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -618,8 +628,12 @@ typedef std::function class TlsPreSharedKeyReplacer : public TlsExtensionFilter { public: - TlsPreSharedKeyReplacer(TlsPreSharedKeyReplacerFunc function) - : identities_(), binders_(), function_(function) {} + TlsPreSharedKeyReplacer(const std::shared_ptr& agent, + TlsPreSharedKeyReplacerFunc function) + : TlsExtensionFilter(agent), + identities_(), + binders_(), + function_(function) {} static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size, const std::unique_ptr& replace, @@ -733,8 +747,10 @@ class TlsPreSharedKeyReplacer : public TlsExtensionFilter { TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) { SetupForResume(); - client_->SetPacketFilter(std::make_shared([]( - TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); })); + MakeTlsFilter( + client_, [](TlsPreSharedKeyReplacer* r) { + r->identities_[0].identity.Truncate(0); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -744,10 +760,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) { TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1); - })); + }); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -757,10 +773,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Write(r->binders_[0].len(), 0xff, 1); - })); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -770,8 +786,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) { SetupForResume(); - client_->SetPacketFilter(std::make_shared( - [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); })); + MakeTlsFilter( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -782,11 +798,11 @@ TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) { TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter( + client_, [](TlsPreSharedKeyReplacer* r) { r->identities_.push_back(r->identities_[0]); r->binders_.push_back(r->binders_[0]); - })); + }); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -797,10 +813,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) { TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter( + client_, [](TlsPreSharedKeyReplacer* r) { r->identities_.push_back(r->identities_[0]); - })); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -809,8 +825,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) { SetupForResume(); - client_->SetPacketFilter(std::make_shared([]( - TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); })); + MakeTlsFilter( + client_, [](TlsPreSharedKeyReplacer* r) { + r->binders_.push_back(r->binders_[0]); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -822,8 +840,8 @@ TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) { const uint8_t empty_buf[] = {0}; DataBuffer empty(empty_buf, 0); // Inject an unused extension after the PSK extension. - client_->SetPacketFilter(std::make_shared( - kTlsHandshakeClientHello, 0xffff, empty)); + MakeTlsFilter(client_, kTlsHandshakeClientHello, 0xffff, + empty); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -833,8 +851,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) { SetupForResume(); DataBuffer empty; - client_->SetPacketFilter(std::make_shared( - ssl_tls13_psk_key_exchange_modes_xtn)); + MakeTlsFilter(client_, + ssl_tls13_psk_key_exchange_modes_xtn); ConnectExpectAlert(server_, kTlsAlertMissingExtension); client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES); @@ -849,8 +867,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { kTls13PskKe}; DataBuffer modes(ke_modes, sizeof(ke_modes)); - client_->SetPacketFilter(std::make_shared( - ssl_tls13_psk_key_exchange_modes_xtn, modes)); + MakeTlsFilter( + client_, ssl_tls13_psk_key_exchange_modes_xtn, modes); client_->ExpectSendAlert(kTlsAlertBadRecordMac); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -860,9 +878,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) { ConfigureSessionCache(RESUME_NONE, RESUME_NONE); - auto capture = std::make_shared( - ssl_tls13_psk_key_exchange_modes_xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter( + client_, ssl_tls13_psk_key_exchange_modes_xtn); Connect(); EXPECT_FALSE(capture->captured()); } @@ -958,11 +975,9 @@ class TlsBogusExtensionTest : public TlsConnectTestBase, static uint8_t empty_buf[1] = {0}; DataBuffer empty(empty_buf, 0); auto filter = - std::make_shared(message, extension, empty); + MakeTlsFilter(server_, message, extension, empty); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { - server_->SetTlsRecordFilter(filter); - } else { - server_->SetPacketFilter(filter); + filter->EnableDecryption(); } } @@ -1078,8 +1093,7 @@ TEST_P(TlsConnectStream, IncludePadding) { SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name); EXPECT_EQ(SECSuccess, rv); - auto capture = std::make_shared(ssl_padding_xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter(client_, ssl_padding_xtn); client_->StartConnect(); client_->Handshake(); EXPECT_TRUE(capture->captured()); diff --git a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc index 64b824786..f4940bf28 100644 --- a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc @@ -149,13 +149,13 @@ class RecordFragmenter : public PacketFilter { }; TEST_P(TlsConnectDatagram, FragmentClientPackets) { - client_->SetPacketFilter(std::make_shared()); + client_->SetFilter(std::make_shared()); Connect(); SendReceive(); } TEST_P(TlsConnectDatagram, FragmentServerPackets) { - server_->SetPacketFilter(std::make_shared()); + server_->SetFilter(std::make_shared()); Connect(); SendReceive(); } diff --git a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc index ab4c0eab7..99448321c 100644 --- a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc @@ -27,7 +27,8 @@ class TlsFuzzTest : public ::testing::Test {}; // Record the application data stream. class TlsApplicationDataRecorder : public TlsRecordFilter { public: - TlsApplicationDataRecorder() : buffer_() {} + TlsApplicationDataRecorder(const std::shared_ptr& agent) + : TlsRecordFilter(agent), buffer_() {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, @@ -106,16 +107,16 @@ FUZZ_P(TlsConnectGeneric, DeterministicTranscript) { DisableECDHEServerKeyReuse(); DataBuffer buffer; - client_->SetPacketFilter(std::make_shared(buffer)); - server_->SetPacketFilter(std::make_shared(buffer)); + MakeTlsFilter(client_, buffer); + MakeTlsFilter(server_, buffer); // Reset the RNG state. EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0)); Connect(); // Ensure the filters go away before |buffer| does. - client_->DeletePacketFilter(); - server_->DeletePacketFilter(); + client_->ClearFilter(); + server_->ClearFilter(); if (last.len() > 0) { EXPECT_EQ(last, buffer); @@ -133,10 +134,8 @@ FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) { EnsureTlsSetup(); // Set up app data filters. - auto client_recorder = std::make_shared(); - client_->SetPacketFilter(client_recorder); - auto server_recorder = std::make_shared(); - server_->SetPacketFilter(server_recorder); + auto client_recorder = MakeTlsFilter(client_); + auto server_recorder = MakeTlsFilter(server_); Connect(); @@ -161,10 +160,9 @@ FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) { FUZZ_P(TlsConnectGeneric, BogusClientFinished) { EnsureTlsSetup(); - auto i1 = std::make_shared( - kTlsHandshakeFinished, + MakeTlsFilter( + client_, kTlsHandshakeFinished, DataBuffer(kShortEmptyFinished, sizeof(kShortEmptyFinished))); - client_->SetPacketFilter(i1); Connect(); SendReceive(); } @@ -173,10 +171,9 @@ FUZZ_P(TlsConnectGeneric, BogusClientFinished) { FUZZ_P(TlsConnectGeneric, BogusServerFinished) { EnsureTlsSetup(); - auto i1 = std::make_shared( - kTlsHandshakeFinished, + MakeTlsFilter( + server_, kTlsHandshakeFinished, DataBuffer(kLongEmptyFinished, sizeof(kLongEmptyFinished))); - server_->SetPacketFilter(i1); Connect(); SendReceive(); } @@ -187,7 +184,7 @@ FUZZ_P(TlsConnectGeneric, BogusServerAuthSignature) { uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsHandshakeCertificateVerify : kTlsHandshakeServerKeyExchange; - server_->SetPacketFilter(std::make_shared(msg_type)); + MakeTlsFilter(server_, msg_type); Connect(); SendReceive(); } @@ -197,8 +194,7 @@ FUZZ_P(TlsConnectGeneric, BogusClientAuthSignature) { EnsureTlsSetup(); client_->SetupClientAuth(); server_->RequestClientAuth(true); - client_->SetPacketFilter( - std::make_shared(kTlsHandshakeCertificateVerify)); + MakeTlsFilter(client_, kTlsHandshakeCertificateVerify); Connect(); } @@ -219,29 +215,28 @@ FUZZ_P(TlsConnectGeneric, SessionTicketResumption) { FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) { ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); - auto i1 = std::make_shared( - kTlsHandshakeNewSessionTicket); - server_->SetPacketFilter(i1); + auto filter = MakeTlsFilter( + server_, kTlsHandshakeNewSessionTicket); Connect(); - std::cerr << "ticket" << i1->buffer() << std::endl; + std::cerr << "ticket" << filter->buffer() << std::endl; size_t offset = 4; /* lifetime */ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { offset += 4; /* ticket_age_add */ uint32_t nonce_len = 0; - EXPECT_TRUE(i1->buffer().Read(offset, 1, &nonce_len)); + EXPECT_TRUE(filter->buffer().Read(offset, 1, &nonce_len)); offset += 1 + nonce_len; } offset += 2 + /* ticket length */ 2; /* TLS_EX_SESS_TICKET_VERSION */ // Check the protocol version number. uint32_t tls_version = 0; - EXPECT_TRUE(i1->buffer().Read(offset, sizeof(version_), &tls_version)); + EXPECT_TRUE(filter->buffer().Read(offset, sizeof(version_), &tls_version)); EXPECT_EQ(version_, static_cast(tls_version)); // Check the cipher suite. uint32_t suite = 0; - EXPECT_TRUE(i1->buffer().Read(offset + sizeof(version_), 2, &suite)); + EXPECT_TRUE(filter->buffer().Read(offset + sizeof(version_), 2, &suite)); client_->CheckCipherSuite(static_cast(suite)); } } diff --git a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc index 93e19a720..05ae87034 100644 --- a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc @@ -35,17 +35,15 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { // Send first ClientHello and send 0-RTT data auto capture_early_data = - std::make_shared(ssl_tls13_early_data_xtn); - client_->SetPacketFilter(capture_early_data); + MakeTlsFilter(client_, ssl_tls13_early_data_xtn); client_->Handshake(); EXPECT_EQ(k0RttDataLen, PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen)); // 0-RTT write. EXPECT_TRUE(capture_early_data->captured()); // Send the HelloRetryRequest - auto hrr_capture = std::make_shared( - kTlsHandshakeHelloRetryRequest); - server_->SetPacketFilter(hrr_capture); + auto hrr_capture = MakeTlsFilter( + server_, kTlsHandshakeHelloRetryRequest); server_->Handshake(); EXPECT_LT(0U, hrr_capture->buffer().len()); @@ -56,8 +54,7 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { // Make a new capture for the early data. capture_early_data = - std::make_shared(ssl_tls13_early_data_xtn); - client_->SetPacketFilter(capture_early_data); + MakeTlsFilter(client_, ssl_tls13_early_data_xtn); // Complete the handshake successfully Handshake(); @@ -71,6 +68,10 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { // packet. If the record is split into two packets, or there are multiple // handshake packets, this will break. class CorrectMessageSeqAfterHrrFilter : public TlsRecordFilter { + public: + CorrectMessageSeqAfterHrrFilter(const std::shared_ptr& agent) + : TlsRecordFilter(agent) {} + protected: PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, @@ -131,8 +132,7 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) { // Correct the DTLS message sequence number after an HRR. if (variant_ == ssl_variant_datagram) { - client_->SetPacketFilter( - std::make_shared()); + MakeTlsFilter(client_); } server_->SetPeer(client_); @@ -151,7 +151,8 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) { class KeyShareReplayer : public TlsExtensionFilter { public: - KeyShareReplayer() {} + KeyShareReplayer(const std::shared_ptr& agent) + : TlsExtensionFilter(agent) {} virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, @@ -178,7 +179,7 @@ class KeyShareReplayer : public TlsExtensionFilter { // server should reject this. TEST_P(TlsConnectTls13, RetryWithSameKeyShare) { EnsureTlsSetup(); - client_->SetPacketFilter(std::make_shared()); + MakeTlsFilter(client_); static const std::vector groups = {ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1}; server_->ConfigNamedGroups(groups); @@ -192,7 +193,7 @@ TEST_P(TlsConnectTls13, RetryWithSameKeyShare) { TEST_P(TlsConnectTls13, RetryWithTwoShares) { EnsureTlsSetup(); EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); - client_->SetPacketFilter(std::make_shared()); + MakeTlsFilter(client_); static const std::vector groups = {ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1}; @@ -238,9 +239,9 @@ TEST_P(TlsConnectTls13, RetryCallbackAcceptGroupMismatch) { return ssl_hello_retry_accept; }; - auto capture = std::make_shared(ssl_tls13_cookie_xtn); + auto capture = + MakeTlsFilter(server_, ssl_tls13_cookie_xtn); capture->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); - server_->SetPacketFilter(capture); static const std::vector groups = {ssl_grp_ec_secp384r1}; server_->ConfigNamedGroups(groups); @@ -359,14 +360,14 @@ SSLHelloRetryRequestAction RetryHello(PRBool firstHello, TEST_P(TlsConnectTls13, RetryCallbackRetry) { EnsureTlsSetup(); - auto capture_hrr = std::make_shared( - ssl_hs_hello_retry_request); + auto capture_hrr = std::make_shared( + server_, ssl_hs_hello_retry_request); auto capture_key_share = - std::make_shared(ssl_tls13_key_share_xtn); + std::make_shared(server_, ssl_tls13_key_share_xtn); capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); std::vector> chain = {capture_hrr, capture_key_share}; - server_->SetPacketFilter(std::make_shared(chain)); + server_->SetFilter(std::make_shared(chain)); size_t cb_called = 0; EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), @@ -383,8 +384,7 @@ TEST_P(TlsConnectTls13, RetryCallbackRetry) { << "no key_share extension expected"; auto capture_cookie = - std::make_shared(ssl_tls13_cookie_xtn); - client_->SetPacketFilter(capture_cookie); + MakeTlsFilter(client_, ssl_tls13_cookie_xtn); Handshake(); CheckConnected(); @@ -413,9 +413,8 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) { EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); auto capture_server = - std::make_shared(ssl_tls13_key_share_xtn); + MakeTlsFilter(server_, ssl_tls13_key_share_xtn); capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); - server_->SetPacketFilter(capture_server); size_t cb_called = 0; EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), @@ -431,8 +430,7 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) { << "no key_share extension expected from server"; auto capture_client_2nd = - std::make_shared(ssl_tls13_key_share_xtn); - client_->SetPacketFilter(capture_client_2nd); + MakeTlsFilter(client_, ssl_tls13_key_share_xtn); Handshake(); CheckConnected(); @@ -449,12 +447,12 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithGroupMismatch) { EnsureTlsSetup(); auto capture_cookie = - std::make_shared(ssl_tls13_cookie_xtn); + std::make_shared(server_, ssl_tls13_cookie_xtn); capture_cookie->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); auto capture_key_share = - std::make_shared(ssl_tls13_key_share_xtn); + std::make_shared(server_, ssl_tls13_key_share_xtn); capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); - server_->SetPacketFilter(std::make_shared( + server_->SetFilter(std::make_shared( ChainedPacketFilterInit{capture_cookie, capture_key_share})); static const std::vector groups = {ssl_grp_ec_secp384r1}; @@ -493,9 +491,8 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithToken) { EnsureTlsSetup(); auto capture_key_share = - std::make_shared(ssl_tls13_key_share_xtn); + MakeTlsFilter(server_, ssl_tls13_key_share_xtn); capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); - server_->SetPacketFilter(capture_key_share); size_t cb_called = 0; EXPECT_EQ(SECSuccess, @@ -513,9 +510,8 @@ TEST_P(TlsConnectTls13, RetryCallbackRetryWithTokenAndGroupMismatch) { server_->ConfigNamedGroups(groups); auto capture_key_share = - std::make_shared(ssl_tls13_key_share_xtn); + MakeTlsFilter(server_, ssl_tls13_key_share_xtn); capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); - server_->SetPacketFilter(capture_key_share); size_t cb_called = 0; EXPECT_EQ(SECSuccess, @@ -589,8 +585,7 @@ TEST_P(TlsConnectTls13, RetryStatefulDropCookie) { EnsureTlsSetup(); TriggerHelloRetryRequest(client_, server_); - client_->SetPacketFilter( - std::make_shared(ssl_tls13_cookie_xtn)); + MakeTlsFilter(client_, ssl_tls13_cookie_xtn); ExpectAlert(server_, kTlsAlertMissingExtension); Handshake(); @@ -603,8 +598,8 @@ TEST_F(TlsConnectStreamTls13, RetryStatelessDamageFirstClientHello) { ConfigureSelfEncrypt(); EnsureTlsSetup(); - auto damage_ch = std::make_shared(0xfff3, DataBuffer()); - client_->SetPacketFilter(damage_ch); + auto damage_ch = + MakeTlsFilter(client_, 0xfff3, DataBuffer()); TriggerHelloRetryRequest(client_, server_); MakeNewServer(); @@ -625,8 +620,8 @@ TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) { TriggerHelloRetryRequest(client_, server_); MakeNewServer(); - auto damage_ch = std::make_shared(0xfff3, DataBuffer()); - client_->SetPacketFilter(damage_ch); + auto damage_ch = + MakeTlsFilter(client_, 0xfff3, DataBuffer()); // Key exchange fails when the handshake continues because client and server // disagree about the transcript. @@ -640,7 +635,7 @@ TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) { // Read the cipher suite from the HRR and disable it on the identified agent. static void DisableSuiteFromHrr( std::shared_ptr& agent, - std::shared_ptr& capture_hrr) { + std::shared_ptr& capture_hrr) { uint32_t tmp; size_t offset = 2 + 32; // skip version + server_random ASSERT_TRUE( @@ -657,9 +652,8 @@ TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteClient) { ConfigureSelfEncrypt(); EnsureTlsSetup(); - auto capture_hrr = std::make_shared( - ssl_hs_hello_retry_request); - server_->SetPacketFilter(capture_hrr); + auto capture_hrr = + MakeTlsFilter(server_, ssl_hs_hello_retry_request); TriggerHelloRetryRequest(client_, server_); MakeNewServer(); @@ -678,9 +672,8 @@ TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteServer) { ConfigureSelfEncrypt(); EnsureTlsSetup(); - auto capture_hrr = std::make_shared( - ssl_hs_hello_retry_request); - server_->SetPacketFilter(capture_hrr); + auto capture_hrr = + MakeTlsFilter(server_, ssl_hs_hello_retry_request); TriggerHelloRetryRequest(client_, server_); MakeNewServer(); @@ -761,8 +754,8 @@ TEST_F(TlsConnectStreamTls13, RetryWithDifferentCipherSuite) { static const std::vector groups = {ssl_grp_ec_secp384r1}; server_->ConfigNamedGroups(groups); // Then switch out the default suite (TLS_AES_128_GCM_SHA256). - server_->SetPacketFilter(std::make_shared( - TLS_CHACHA20_POLY1305_SHA256)); + MakeTlsFilter(server_, + TLS_CHACHA20_POLY1305_SHA256); client_->ExpectSendAlert(kTlsAlertIllegalParameter); server_->ExpectSendAlert(kTlsAlertBadRecordMac); @@ -777,7 +770,7 @@ TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) { static const std::vector groups = {ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1}; server_->ConfigNamedGroups(groups); - server_->SetPacketFilter(std::make_shared(0x2)); + server_->SetFilter(std::make_shared(0x2)); Connect(); } @@ -833,9 +826,9 @@ TEST_P(TlsKeyExchange13, EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); auto capture_server = - std::make_shared(ssl_tls13_key_share_xtn); + std::make_shared(server_, ssl_tls13_key_share_xtn); capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); - server_->SetPacketFilter(std::make_shared( + server_->SetFilter(std::make_shared( ChainedPacketFilterInit{capture_hrr_, capture_server})); size_t cb_called = 0; diff --git a/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc index 8ed342305..322b64837 100644 --- a/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc @@ -20,8 +20,8 @@ static const std::string keylog_env = "SSLKEYLOGFILE=" + keylog_file_path; class KeyLogFileTest : public TlsConnectGeneric { public: - void SetUp() { - TlsConnectTestBase::SetUp(); + void SetUp() override { + TlsConnectGeneric::SetUp(); // Remove previous results (if any). (void)remove(keylog_file_path.c_str()); PR_SetEnv(keylog_env.c_str()); diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc index 4bc6e60ab..f1b78f52f 100644 --- a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc @@ -56,7 +56,8 @@ TEST_P(TlsConnectGeneric, CipherSuiteMismatch) { class TlsAlertRecorder : public TlsRecordFilter { public: - TlsAlertRecorder() : level_(255), description_(255) {} + TlsAlertRecorder(const std::shared_ptr& agent) + : TlsRecordFilter(agent), level_(255), description_(255) {} PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, @@ -86,9 +87,9 @@ class TlsAlertRecorder : public TlsRecordFilter { class HelloTruncator : public TlsHandshakeFilter { public: - HelloTruncator() + HelloTruncator(const std::shared_ptr& agent) : TlsHandshakeFilter( - {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {} + agent, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {} PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { @@ -99,9 +100,8 @@ class HelloTruncator : public TlsHandshakeFilter { // Verify that when NSS reports that an alert is sent, it is actually sent. TEST_P(TlsConnectGeneric, CaptureAlertServer) { - client_->SetPacketFilter(std::make_shared()); - auto alert_recorder = std::make_shared(); - server_->SetPacketFilter(alert_recorder); + MakeTlsFilter(client_); + auto alert_recorder = MakeTlsFilter(server_); ConnectExpectAlert(server_, kTlsAlertDecodeError); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); @@ -109,9 +109,8 @@ TEST_P(TlsConnectGeneric, CaptureAlertServer) { } TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { - server_->SetPacketFilter(std::make_shared()); - auto alert_recorder = std::make_shared(); - client_->SetPacketFilter(alert_recorder); + MakeTlsFilter(server_); + auto alert_recorder = MakeTlsFilter(client_); ConnectExpectAlert(client_, kTlsAlertDecodeError); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); @@ -120,9 +119,8 @@ TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { // In TLS 1.3, the server can't read the client alert. TEST_P(TlsConnectTls13, CaptureAlertClient) { - server_->SetPacketFilter(std::make_shared()); - auto alert_recorder = std::make_shared(); - client_->SetPacketFilter(alert_recorder); + MakeTlsFilter(server_); + auto alert_recorder = MakeTlsFilter(client_); StartConnect(); @@ -173,7 +171,8 @@ TEST_P(TlsConnectGeneric, ConnectSendReceive) { class SaveTlsRecord : public TlsRecordFilter { public: - SaveTlsRecord(size_t index) : index_(index), count_(0), contents_() {} + SaveTlsRecord(const std::shared_ptr& agent, size_t index) + : TlsRecordFilter(agent), index_(index), count_(0), contents_() {} const DataBuffer& contents() const { return contents_; } @@ -198,8 +197,8 @@ class SaveTlsRecord : public TlsRecordFilter { TEST_F(TlsConnectStreamTls13, DecryptRecordClient) { EnsureTlsSetup(); // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer - auto saved = std::make_shared(3); - client_->SetTlsRecordFilter(saved); + auto saved = MakeTlsFilter(client_, 3); + saved->EnableDecryption(); Connect(); SendReceive(); @@ -215,8 +214,8 @@ TEST_F(TlsConnectStreamTls13, DecryptRecordServer) { EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer - auto saved = std::make_shared(3); - server_->SetTlsRecordFilter(saved); + auto saved = MakeTlsFilter(server_, 3); + saved->EnableDecryption(); Connect(); SendReceive(); @@ -228,7 +227,8 @@ TEST_F(TlsConnectStreamTls13, DecryptRecordServer) { class DropTlsRecord : public TlsRecordFilter { public: - DropTlsRecord(size_t index) : index_(index), count_(0) {} + DropTlsRecord(const std::shared_ptr& agent, size_t index) + : TlsRecordFilter(agent), index_(index), count_(0) {} protected: PacketFilter::Action FilterRecord(const TlsRecordHeader& header, @@ -253,7 +253,8 @@ TEST_F(TlsConnectStreamTls13, DropRecordServer) { SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); // 0 = ServerHello, 1 = other handshake, 2 = first write - server_->SetTlsRecordFilter(std::make_shared(2)); + auto filter = MakeTlsFilter(server_, 2); + filter->EnableDecryption(); Connect(); server_->SendData(23, 23); // This should be dropped, so it won't be counted. server_->ResetSentBytes(); @@ -263,7 +264,8 @@ TEST_F(TlsConnectStreamTls13, DropRecordServer) { TEST_F(TlsConnectStreamTls13, DropRecordClient) { EnsureTlsSetup(); // 0 = ClientHello, 1 = Finished, 2 = first write - client_->SetTlsRecordFilter(std::make_shared(2)); + auto filter = MakeTlsFilter(client_, 2); + filter->EnableDecryption(); Connect(); client_->SendData(26, 26); // This should be dropped, so it won't be counted. client_->ResetSentBytes(); @@ -371,7 +373,8 @@ TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) { class TlsPreCCSHeaderInjector : public TlsRecordFilter { public: - TlsPreCCSHeaderInjector() {} + TlsPreCCSHeaderInjector(const std::shared_ptr& agent) + : TlsRecordFilter(agent) {} virtual PacketFilter::Action FilterRecord( const TlsRecordHeader& record_header, const DataBuffer& input, size_t* offset, DataBuffer* output) override { @@ -388,14 +391,14 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter { }; TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) { - client_->SetPacketFilter(std::make_shared()); + MakeTlsFilter(client_); ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); } TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) { - server_->SetPacketFilter(std::make_shared()); + MakeTlsFilter(server_); StartConnect(); ExpectAlert(client_, kTlsAlertUnexpectedMessage); Handshake(); @@ -476,8 +479,7 @@ TEST_F(TlsConnectTest, OneNRecordSplitting) { ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0); EnsureTlsSetup(); ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA); - auto records = std::make_shared(); - server_->SetPacketFilter(records); + auto records = MakeTlsFilter(server_); // This should be split into 1, 16384 and 20. DataBuffer big_buffer; big_buffer.Allocate(1 + 16384 + 20); @@ -535,4 +537,27 @@ INSTANTIATE_TEST_CASE_P(Version12Plus, TlsConnectTls12Plus, ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus)); -} // namespace nspr_test +INSTANTIATE_TEST_CASE_P( + GenericStream, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll, + ::testing::Values(true, false))); +INSTANTIATE_TEST_CASE_P( + GenericDatagram, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus, + ::testing::Values(true, false))); + +INSTANTIATE_TEST_CASE_P( + GenericStream, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P( + GenericDatagram, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); + +INSTANTIATE_TEST_CASE_P(GenericDatagram, TlsConnectTls13ResumptionToken, + TlsConnectTestBase::kTlsVariantsAll); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc index d1d496f49..3b8727850 100644 --- a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc @@ -103,8 +103,8 @@ TEST_P(TlsPaddingTest, LastByteOfPadWrong) { class RecordReplacer : public TlsRecordFilter { public: - RecordReplacer(size_t size) - : TlsRecordFilter(), enabled_(false), size_(size) {} + RecordReplacer(const std::shared_ptr& agent, size_t size) + : TlsRecordFilter(agent), enabled_(false), size_(size) {} PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& data, @@ -135,8 +135,8 @@ TEST_F(TlsConnectStreamTls13, LargeRecord) { EnsureTlsSetup(); const size_t record_limit = 16384; - auto replacer = std::make_shared(record_limit); - client_->SetTlsRecordFilter(replacer); + auto replacer = MakeTlsFilter(client_, record_limit); + replacer->EnableDecryption(); Connect(); replacer->Enable(); @@ -149,8 +149,8 @@ TEST_F(TlsConnectStreamTls13, TooLargeRecord) { EnsureTlsSetup(); const size_t record_limit = 16384; - auto replacer = std::make_shared(record_limit + 1); - client_->SetTlsRecordFilter(replacer); + auto replacer = MakeTlsFilter(client_, record_limit + 1); + replacer->EnableDecryption(); Connect(); replacer->Enable(); @@ -177,4 +177,4 @@ auto kTrueFalse = ::testing::ValuesIn(kTrueFalseArr); INSTANTIATE_TEST_CASE_P(TlsPadding, TlsPaddingTest, ::testing::Combine(kContentSizes, kTrueFalse)); -} // namespace nspr_test +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc index a413caf2c..eb78c0585 100644 --- a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc @@ -60,7 +60,7 @@ TEST_P(TlsConnectGenericPre13, ConnectResumed) { Connect(); } -TEST_P(TlsConnectGeneric, ConnectClientCacheDisabled) { +TEST_P(TlsConnectGenericResumption, ConnectClientCacheDisabled) { ConfigureSessionCache(RESUME_NONE, RESUME_SESSIONID); Connect(); SendReceive(); @@ -71,7 +71,7 @@ TEST_P(TlsConnectGeneric, ConnectClientCacheDisabled) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectServerCacheDisabled) { +TEST_P(TlsConnectGenericResumption, ConnectServerCacheDisabled) { ConfigureSessionCache(RESUME_SESSIONID, RESUME_NONE); Connect(); SendReceive(); @@ -82,7 +82,7 @@ TEST_P(TlsConnectGeneric, ConnectServerCacheDisabled) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectSessionCacheDisabled) { +TEST_P(TlsConnectGenericResumption, ConnectSessionCacheDisabled) { ConfigureSessionCache(RESUME_NONE, RESUME_NONE); Connect(); SendReceive(); @@ -93,7 +93,7 @@ TEST_P(TlsConnectGeneric, ConnectSessionCacheDisabled) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeSupportBoth) { +TEST_P(TlsConnectGenericResumption, ConnectResumeSupportBoth) { // This prefers tickets. ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); Connect(); @@ -106,7 +106,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeSupportBoth) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeClientTicketServerBoth) { +TEST_P(TlsConnectGenericResumption, ConnectResumeClientTicketServerBoth) { // This causes no resumption because the client needs the // session cache to resume even with tickets. ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH); @@ -120,7 +120,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientTicketServerBoth) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicket) { +TEST_P(TlsConnectGenericResumption, ConnectResumeClientBothTicketServerTicket) { // This causes a ticket resumption. ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); Connect(); @@ -133,7 +133,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicket) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeClientServerTicketOnly) { +TEST_P(TlsConnectGenericResumption, ConnectResumeClientServerTicketOnly) { // This causes no resumption because the client needs the // session cache to resume even with tickets. ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); @@ -147,7 +147,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientServerTicketOnly) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeClientBothServerNone) { +TEST_P(TlsConnectGenericResumption, ConnectResumeClientBothServerNone) { ConfigureSessionCache(RESUME_BOTH, RESUME_NONE); Connect(); SendReceive(); @@ -159,7 +159,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothServerNone) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeClientNoneServerBoth) { +TEST_P(TlsConnectGenericResumption, ConnectResumeClientNoneServerBoth) { ConfigureSessionCache(RESUME_NONE, RESUME_BOTH); Connect(); SendReceive(); @@ -202,7 +202,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicketForget) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtClient) { +TEST_P(TlsConnectGenericResumption, ConnectWithExpiredTicketAtClient) { SSLInt_SetTicketLifetime(1); // one second // This causes a ticket resumption. ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); @@ -219,8 +219,7 @@ TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtClient) { SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) ? ssl_tls13_pre_shared_key_xtn : ssl_session_ticket_xtn; - auto capture = std::make_shared(xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter(client_, xtn); Connect(); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { @@ -245,8 +244,7 @@ TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) { SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) ? ssl_tls13_pre_shared_key_xtn : ssl_session_ticket_xtn; - auto capture = std::make_shared(xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter(client_, xtn); StartConnect(); client_->Handshake(); EXPECT_TRUE(capture->captured()); @@ -327,25 +325,23 @@ TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) { // Prior to TLS 1.3, we were not fully ephemeral; though 1.3 fixes that TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) { - auto i1 = std::make_shared( - kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i1); + auto filter = MakeTlsFilter( + server_, kTlsHandshakeServerKeyExchange); Connect(); CheckKeys(); TlsServerKeyExchangeEcdhe dhe1; - EXPECT_TRUE(dhe1.Parse(i1->buffer())); + EXPECT_TRUE(dhe1.Parse(filter->buffer())); // Restart Reset(); - auto i2 = std::make_shared( - kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i2); + auto filter2 = MakeTlsFilter( + server_, kTlsHandshakeServerKeyExchange); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); Connect(); CheckKeys(); TlsServerKeyExchangeEcdhe dhe2; - EXPECT_TRUE(dhe2.Parse(i2->buffer())); + EXPECT_TRUE(dhe2.Parse(filter2->buffer())); // Make sure they are the same. EXPECT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len()); @@ -356,26 +352,24 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) { // This test parses the ServerKeyExchange, which isn't in 1.3 TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) { server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); - auto i1 = std::make_shared( - kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i1); + auto filter = MakeTlsFilter( + server_, kTlsHandshakeServerKeyExchange); Connect(); CheckKeys(); TlsServerKeyExchangeEcdhe dhe1; - EXPECT_TRUE(dhe1.Parse(i1->buffer())); + EXPECT_TRUE(dhe1.Parse(filter->buffer())); // Restart Reset(); server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); - auto i2 = std::make_shared( - kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i2); + auto filter2 = MakeTlsFilter( + server_, kTlsHandshakeServerKeyExchange); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); Connect(); CheckKeys(); TlsServerKeyExchangeEcdhe dhe2; - EXPECT_TRUE(dhe2.Parse(i2->buffer())); + EXPECT_TRUE(dhe2.Parse(filter2->buffer())); // Make sure they are different. EXPECT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) && @@ -397,7 +391,7 @@ TEST_P(TlsConnectTls13, TestTls13ResumeDifferentGroup) { server_->ConfigNamedGroups(kFFDHEGroups); Connect(); CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } // We need to enable different cipher suites at different times in the following @@ -417,7 +411,7 @@ static uint16_t ChooseAnotherCipher(uint16_t version) { } // Test that we don't resume when we can't negotiate the same cipher. -TEST_P(TlsConnectGeneric, TestResumeClientDifferentCipher) { +TEST_P(TlsConnectGenericResumption, TestResumeClientDifferentCipher) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); client_->EnableSingleCipher(ChooseOneCipher(version_)); Connect(); @@ -434,15 +428,15 @@ TEST_P(TlsConnectGeneric, TestResumeClientDifferentCipher) { } else { ticket_extension = ssl_session_ticket_xtn; } - auto ticket_capture = std::make_shared(ticket_extension); - client_->SetPacketFilter(ticket_capture); + auto ticket_capture = + MakeTlsFilter(client_, ticket_extension); Connect(); CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); EXPECT_EQ(0U, ticket_capture->extension().len()); } // Test that we don't resume when we can't negotiate the same cipher. -TEST_P(TlsConnectGeneric, TestResumeServerDifferentCipher) { +TEST_P(TlsConnectGenericResumption, TestResumeServerDifferentCipher) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); server_->EnableSingleCipher(ChooseOneCipher(version_)); Connect(); @@ -468,8 +462,8 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - server_->SetPacketFilter(std::make_shared( - ChooseAnotherCipher(version_))); + MakeTlsFilter(server_, + ChooseAnotherCipher(version_)); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { client_->ExpectSendAlert(kTlsAlertIllegalParameter); @@ -490,8 +484,10 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { class SelectedVersionReplacer : public TlsHandshakeFilter { public: - SelectedVersionReplacer(uint16_t version) - : TlsHandshakeFilter({kTlsHandshakeServerHello}), version_(version) {} + SelectedVersionReplacer(const std::shared_ptr& agent, + uint16_t version) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}), + version_(version) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, @@ -543,8 +539,7 @@ TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) { // Enable the lower version on the client. client_->SetVersionRange(version_ - 1, version_); server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); - server_->SetPacketFilter( - std::make_shared(override_version)); + MakeTlsFilter(server_, override_version); ConnectExpectAlert(client_, kTlsAlertHandshakeFailure); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); @@ -567,12 +562,12 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); ExpectResumption(RESUME_TICKET); - auto c1 = std::make_shared(ssl_tls13_pre_shared_key_xtn); - client_->SetPacketFilter(c1); + auto c1 = + MakeTlsFilter(client_, ssl_tls13_pre_shared_key_xtn); Connect(); SendReceive(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); // The filter will go away when we reset, so save the captured extension. DataBuffer initialTicket(c1->extension()); ASSERT_LT(0U, initialTicket.len()); @@ -584,13 +579,13 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { ClearStats(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); - auto c2 = std::make_shared(ssl_tls13_pre_shared_key_xtn); - client_->SetPacketFilter(c2); + auto c2 = + MakeTlsFilter(client_, ssl_tls13_pre_shared_key_xtn); ExpectResumption(RESUME_TICKET); Connect(); SendReceive(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); ASSERT_LT(0U, c2->extension().len()); ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); @@ -656,9 +651,9 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); - auto nst_capture = std::make_shared( - ssl_hs_new_session_ticket); - server_->SetTlsRecordFilter(nst_capture); + auto nst_capture = + MakeTlsFilter(server_, ssl_hs_new_session_ticket); + nst_capture->EnableDecryption(); Connect(); // Clear the session ticket keys to invalidate the old ticket. @@ -679,8 +674,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) { ExpectResumption(RESUME_TICKET); auto psk_capture = - std::make_shared(ssl_tls13_pre_shared_key_xtn); - client_->SetPacketFilter(psk_capture); + MakeTlsFilter(client_, ssl_tls13_pre_shared_key_xtn); Connect(); SendReceive(); @@ -696,9 +690,9 @@ TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) { EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); - auto nst_capture = std::make_shared( - ssl_hs_new_session_ticket); - server_->SetTlsRecordFilter(nst_capture); + auto nst_capture = + MakeTlsFilter(server_, ssl_hs_new_session_ticket); + nst_capture->EnableDecryption(); Connect(); EXPECT_EQ(0U, nst_capture->buffer().len()) << "expect nothing captured yet"; @@ -715,8 +709,7 @@ TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) { ExpectResumption(RESUME_TICKET); auto psk_capture = - std::make_shared(ssl_tls13_pre_shared_key_xtn); - client_->SetPacketFilter(psk_capture); + MakeTlsFilter(client_, ssl_tls13_pre_shared_key_xtn); Connect(); SendReceive(); @@ -819,20 +812,20 @@ TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) { // We will eventually fail the (sid.version == SH.version) check. std::vector> filters; filters.push_back(std::make_shared( - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)); - filters.push_back( - std::make_shared(SSL_LIBRARY_VERSION_TLS_1_2)); + server_, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)); + filters.push_back(std::make_shared( + server_, SSL_LIBRARY_VERSION_TLS_1_2)); // Drop a bunch of extensions so that we get past the SH processing. The // version extension says TLS 1.3, which is counter to our goal, the others // are not permitted in TLS 1.2 handshakes. + filters.push_back(std::make_shared( + server_, ssl_tls13_supported_versions_xtn)); filters.push_back( - std::make_shared(ssl_tls13_supported_versions_xtn)); - filters.push_back( - std::make_shared(ssl_tls13_key_share_xtn)); - filters.push_back( - std::make_shared(ssl_tls13_pre_shared_key_xtn)); - server_->SetPacketFilter(std::make_shared(filters)); + std::make_shared(server_, ssl_tls13_key_share_xtn)); + filters.push_back(std::make_shared( + server_, ssl_tls13_pre_shared_key_xtn)); + server_->SetFilter(std::make_shared(filters)); // The client here generates an unexpected_message alert when it receives an // encrypted handshake message from the server (EncryptedExtension). The @@ -845,13 +838,13 @@ TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) { server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); } -TEST_P(TlsConnectGeneric, ReConnectTicket) { +TEST_P(TlsConnectGenericResumption, ReConnectTicket) { ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); server_->EnableSingleCipher(ChooseOneCipher(version_)); Connect(); SendReceive(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); // Resume Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); @@ -859,7 +852,7 @@ TEST_P(TlsConnectGeneric, ReConnectTicket) { Connect(); // Only the client knows this. CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, - ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); } TEST_P(TlsConnectGenericPre13, ReConnectCache) { @@ -868,22 +861,22 @@ TEST_P(TlsConnectGenericPre13, ReConnectCache) { Connect(); SendReceive(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); // Resume Reset(); ExpectResumption(RESUME_SESSIONID); Connect(); CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, - ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); } -TEST_P(TlsConnectGeneric, ReConnectAgainTicket) { +TEST_P(TlsConnectGenericResumption, ReConnectAgainTicket) { ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); server_->EnableSingleCipher(ChooseOneCipher(version_)); Connect(); SendReceive(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); // Resume Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); @@ -891,7 +884,7 @@ TEST_P(TlsConnectGeneric, ReConnectAgainTicket) { Connect(); // Only the client knows this. CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, - ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); // Resume connection again Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); @@ -899,7 +892,140 @@ TEST_P(TlsConnectGeneric, ReConnectAgainTicket) { Connect(); // Only the client knows this. CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, - ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); +} + +void CheckGetInfoResult(uint32_t alpnSize, uint32_t earlyDataSize, + ScopedCERTCertificate& cert, + ScopedSSLResumptionTokenInfo& token) { + ASSERT_TRUE(cert); + ASSERT_TRUE(token->peerCert); + + // Check that the server cert is the correct one. + ASSERT_EQ(cert->derCert.len, token->peerCert->derCert.len); + EXPECT_EQ(0, memcmp(cert->derCert.data, token->peerCert->derCert.data, + cert->derCert.len)); + + ASSERT_EQ(alpnSize, token->alpnSelectionLen); + EXPECT_EQ(0, memcmp("a", token->alpnSelection, token->alpnSelectionLen)); + + ASSERT_EQ(earlyDataSize, token->maxEarlyDataSize); +} + +TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfo) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + // Get resumption token infos + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + ScopedCERTCertificate cert( + PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); + + CheckGetInfoResult(0, 0, cert, token); + + Handshake(); + CheckConnected(); + + SendReceive(); +} + +TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfoAlpn) { + EnableAlpn(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + CheckAlpn("a"); + SendReceive(); + + Reset(); + EnableAlpn(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + // Get resumption token infos + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + ScopedCERTCertificate cert( + PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); + + CheckGetInfoResult(1, 0, cert, token); + + Handshake(); + CheckConnected(); + CheckAlpn("a"); + + SendReceive(); +} + +TEST_P(TlsConnectTls13ResumptionToken, ConnectResumeGetInfoZeroRtt) { + EnableAlpn(); + SSLInt_RolloverAntiReplay(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->Set0RttEnabled(true); + Connect(); + CheckAlpn("a"); + SendReceive(); + + Reset(); + EnableAlpn(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + server_->Set0RttEnabled(true); + client_->Set0RttEnabled(true); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + // Get resumption token infos + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + ScopedCERTCertificate cert( + PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); + + CheckGetInfoResult(1, 1024, cert, token); + + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + CheckAlpn("a"); + + SendReceive(); +} + +// Resumption on sessions with client authentication only works with internal +// caching. +TEST_P(TlsConnectGenericResumption, ConnectResumeClientAuth) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + SendReceive(); + EXPECT_FALSE(client_->resumption_callback_called()); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + if (use_external_cache()) { + ExpectResumption(RESUME_NONE); + } else { + ExpectResumption(RESUME_TICKET); + } + Connect(); + SendReceive(); } } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc index 335bfecfa..e4a9e5aed 100644 --- a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc @@ -22,8 +22,11 @@ namespace nss_test { class TlsHandshakeSkipFilter : public TlsRecordFilter { public: // A TLS record filter that skips handshake messages of the identified type. - TlsHandshakeSkipFilter(uint8_t handshake_type) - : handshake_type_(handshake_type), skipped_(false) {} + TlsHandshakeSkipFilter(const std::shared_ptr& agent, + uint8_t handshake_type) + : TlsRecordFilter(agent), + handshake_type_(handshake_type), + skipped_(false) {} protected: // Takes a record; if it is a handshake record, it removes the first handshake @@ -92,9 +95,14 @@ class TlsSkipTest : public TlsConnectTestBase, TlsSkipTest() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + void SetUp() override { + TlsConnectTestBase::SetUp(); + EnsureTlsSetup(); + } + void ServerSkipTest(std::shared_ptr filter, uint8_t alert = kTlsAlertUnexpectedMessage) { - server_->SetPacketFilter(filter); + server_->SetFilter(filter); ConnectExpectAlert(client_, alert); } }; @@ -105,9 +113,14 @@ class Tls13SkipTest : public TlsConnectTestBase, Tls13SkipTest() : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} - void ServerSkipTest(std::shared_ptr filter, int32_t error) { + void SetUp() override { + TlsConnectTestBase::SetUp(); EnsureTlsSetup(); - server_->SetTlsRecordFilter(filter); + } + + void ServerSkipTest(std::shared_ptr filter, int32_t error) { + filter->EnableDecryption(); + server_->SetFilter(filter); ExpectAlert(client_, kTlsAlertUnexpectedMessage); ConnectExpectFail(); client_->CheckErrorCode(error); @@ -115,8 +128,8 @@ class Tls13SkipTest : public TlsConnectTestBase, } void ClientSkipTest(std::shared_ptr filter, int32_t error) { - EnsureTlsSetup(); - client_->SetTlsRecordFilter(filter); + filter->EnableDecryption(); + client_->SetFilter(filter); server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); ConnectExpectFailOneSide(TlsAgent::SERVER); @@ -129,48 +142,49 @@ class Tls13SkipTest : public TlsConnectTestBase, TEST_P(TlsSkipTest, SkipCertificateRsa) { EnableOnlyStaticRsaCiphers(); - ServerSkipTest( - std::make_shared(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertificateDhe) { - ServerSkipTest( - std::make_shared(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdhe) { - ServerSkipTest( - std::make_shared(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest( - std::make_shared(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipServerKeyExchange) { - ServerSkipTest( - std::make_shared(kTlsHandshakeServerKeyExchange)); + ServerSkipTest(std::make_shared( + server_, kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest( - std::make_shared(kTlsHandshakeServerKeyExchange)); + ServerSkipTest(std::make_shared( + server_, kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertAndKeyExch) { - auto chain = std::make_shared(ChainedPacketFilterInit{ - std::make_shared(kTlsHandshakeCertificate), - std::make_shared( - kTlsHandshakeServerKeyExchange)}); + auto chain = std::make_shared( + ChainedPacketFilterInit{std::make_shared( + server_, kTlsHandshakeCertificate), + std::make_shared( + server_, kTlsHandshakeServerKeyExchange)}); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } @@ -178,48 +192,48 @@ TEST_P(TlsSkipTest, SkipCertAndKeyExch) { TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) { Reset(TlsAgent::kServerEcdsa256); auto chain = std::make_shared(); - chain->Add( - std::make_shared(kTlsHandshakeCertificate)); - chain->Add( - std::make_shared(kTlsHandshakeServerKeyExchange)); + chain->Add(std::make_shared( + server_, kTlsHandshakeCertificate)); + chain->Add(std::make_shared( + server_, kTlsHandshakeServerKeyExchange)); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(Tls13SkipTest, SkipEncryptedExtensions) { ServerSkipTest(std::make_shared( - kTlsHandshakeEncryptedExtensions), + server_, kTlsHandshakeEncryptedExtensions), SSL_ERROR_RX_UNEXPECTED_CERTIFICATE); } TEST_P(Tls13SkipTest, SkipServerCertificate) { - ServerSkipTest( - std::make_shared(kTlsHandshakeCertificate), - SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); + ServerSkipTest(std::make_shared( + server_, kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); } TEST_P(Tls13SkipTest, SkipServerCertificateVerify) { - ServerSkipTest( - std::make_shared(kTlsHandshakeCertificateVerify), - SSL_ERROR_RX_UNEXPECTED_FINISHED); + ServerSkipTest(std::make_shared( + server_, kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); } TEST_P(Tls13SkipTest, SkipClientCertificate) { client_->SetupClientAuth(); server_->RequestClientAuth(true); client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); - ClientSkipTest( - std::make_shared(kTlsHandshakeCertificate), - SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); + ClientSkipTest(std::make_shared( + client_, kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); } TEST_P(Tls13SkipTest, SkipClientCertificateVerify) { client_->SetupClientAuth(); server_->RequestClientAuth(true); client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); - ClientSkipTest( - std::make_shared(kTlsHandshakeCertificateVerify), - SSL_ERROR_RX_UNEXPECTED_FINISHED); + ClientSkipTest(std::make_shared( + client_, kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); } INSTANTIATE_TEST_CASE_P( diff --git a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc index e7fe44d92..e5fccc12b 100644 --- a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc @@ -48,10 +48,9 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSA) { // This test is stream so we can catch the bad_record_mac alert. TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) { EnableOnlyStaticRsaCiphers(); - auto i1 = std::make_shared( - kTlsHandshakeClientKeyExchange, + MakeTlsFilter( + client_, kTlsHandshakeClientKeyExchange, DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); - client_->SetPacketFilter(i1); ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } @@ -59,8 +58,7 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) { // This test is stream so we can catch the bad_record_mac alert. TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) { EnableOnlyStaticRsaCiphers(); - client_->SetPacketFilter( - std::make_shared(server_)); + MakeTlsFilter(client_, server_); ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } @@ -69,8 +67,7 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) { // ConnectStaticRSABogusPMSVersionDetect. TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) { EnableOnlyStaticRsaCiphers(); - client_->SetPacketFilter( - std::make_shared(server_)); + MakeTlsFilter(client_, server_); server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE); Connect(); } @@ -79,10 +76,9 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) { TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) { EnableOnlyStaticRsaCiphers(); EnableExtendedMasterSecret(); - auto inspect = std::make_shared( - kTlsHandshakeClientKeyExchange, + MakeTlsFilter( + client_, kTlsHandshakeClientKeyExchange, DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); - client_->SetPacketFilter(inspect); ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } @@ -91,8 +87,7 @@ TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) { EnableOnlyStaticRsaCiphers(); EnableExtendedMasterSecret(); - client_->SetPacketFilter( - std::make_shared(server_)); + MakeTlsFilter(client_, server_); ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } @@ -100,10 +95,9 @@ TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) { EnableOnlyStaticRsaCiphers(); EnableExtendedMasterSecret(); - client_->SetPacketFilter( - std::make_shared(server_)); + MakeTlsFilter(client_, server_); server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE); Connect(); } -} // namespace nspr_test +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc index 75cee52fc..f5ccf096b 100644 --- a/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc @@ -67,10 +67,7 @@ class Tls13CompatTest : public TlsConnectStreamTls13 { private: struct Recorders { - Recorders() - : records_(new TlsRecordRecorder()), - hello_(new TlsInspectorRecordHandshakeMessage(std::set( - {kTlsHandshakeClientHello, kTlsHandshakeServerHello}))) {} + Recorders() : records_(nullptr), hello_(nullptr) {} uint8_t session_id_length() const { // session_id is always after version (2) and random (32). @@ -91,12 +88,22 @@ class Tls13CompatTest : public TlsConnectStreamTls13 { } void Install(std::shared_ptr& agent) { - agent->SetPacketFilter(std::make_shared( + if (records_ && records_->agent() == agent) { + // Avoid replacing the filters if they are already installed on this + // agent. This ensures that InstallFilters() can be used after + // MakeNewServer() without losing state on the client filters. + return; + } + records_.reset(new TlsRecordRecorder(agent)); + hello_.reset(new TlsHandshakeRecorder( + agent, std::set( + {kTlsHandshakeClientHello, kTlsHandshakeServerHello}))); + agent->SetFilter(std::make_shared( ChainedPacketFilterInit({records_, hello_}))); } std::shared_ptr records_; - std::shared_ptr hello_; + std::shared_ptr hello_; }; void CheckRecordsAreTls12(const std::string& agent, @@ -171,16 +178,20 @@ TEST_F(Tls13CompatTest, EnabledStatelessHrr) { server_->StartConnect(); client_->Handshake(); server_->Handshake(); + + // The server should send CCS before HRR. CheckForCCS(false, true); - // A new server should just work, but not send another CCS. + // A new server should complete the handshake, and not send CCS. MakeNewServer(); InstallFilters(); server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); Handshake(); CheckConnected(); - CheckForCompatHandshake(); + CheckRecordVersions(); + CheckHelloVersions(); + CheckForCCS(true, false); } TEST_F(Tls13CompatTest, EnabledHrrZeroRtt) { @@ -262,10 +273,8 @@ TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello12) { TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) { EnsureTlsSetup(); client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE); - auto client_records = std::make_shared(); - client_->SetPacketFilter(client_records); - auto server_records = std::make_shared(); - server_->SetPacketFilter(server_records); + auto client_records = MakeTlsFilter(client_); + auto server_records = MakeTlsFilter(server_); Connect(); ASSERT_EQ(2U, client_records->count()); // CH, Fin @@ -283,7 +292,8 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) { class AddSessionIdFilter : public TlsHandshakeFilter { public: - AddSessionIdFilter() : TlsHandshakeFilter({ssl_hs_client_hello}) {} + AddSessionIdFilter(const std::shared_ptr& client) + : TlsHandshakeFilter(client, {ssl_hs_client_hello}) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, @@ -303,14 +313,14 @@ class AddSessionIdFilter : public TlsHandshakeFilter { // mode. It should be ignored instead. TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) { EnsureTlsSetup(); - auto client_records = std::make_shared(); - client_->SetPacketFilter( + auto client_records = std::make_shared(client_); + client_->SetFilter( std::make_shared(ChainedPacketFilterInit( - {client_records, std::make_shared()}))); - auto server_hello = std::make_shared( - kTlsHandshakeServerHello); - auto server_records = std::make_shared(); - server_->SetPacketFilter(std::make_shared( + {client_records, std::make_shared(client_)}))); + auto server_hello = + std::make_shared(server_, kTlsHandshakeServerHello); + auto server_records = std::make_shared(server_); + server_->SetFilter(std::make_shared( ChainedPacketFilterInit({server_records, server_hello}))); StartConnect(); client_->Handshake(); @@ -334,4 +344,20 @@ TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) { EXPECT_EQ(0U, session_id_len); } -} // nss_test +TEST_F(Tls13CompatTest, ConnectWith12ThenAttemptToResume13CompatMode) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2); + Connect(); + + Reset(); + ExpectResumption(RESUME_NONE); + version_ = SSL_LIBRARY_VERSION_TLS_1_3; + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + EnableCompatMode(); + Connect(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc index 2f8ddd6fe..100595732 100644 --- a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc @@ -23,7 +23,8 @@ namespace nss_test { // Replaces the client hello with an SSLv2 version once. class SSLv2ClientHelloFilter : public PacketFilter { public: - SSLv2ClientHelloFilter(std::shared_ptr& client, uint16_t version) + SSLv2ClientHelloFilter(const std::shared_ptr& client, + uint16_t version) : replaced_(false), client_(client), version_(version), @@ -147,10 +148,9 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase { SSLv2ClientHelloTestF(SSLProtocolVariant variant, uint16_t version) : TlsConnectTestBase(variant, version), filter_(nullptr) {} - void SetUp() { + void SetUp() override { TlsConnectTestBase::SetUp(); - filter_ = std::make_shared(client_, version_); - client_->SetPacketFilter(filter_); + filter_ = MakeTlsFilter(client_, version_); } void SetExpectedVersion(uint16_t version) { diff --git a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc index 9db293b07..4e9099561 100644 --- a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc @@ -56,18 +56,15 @@ TEST_P(TlsConnectGeneric, ServerNegotiateTls12) { // two validate that we can also detect fallback using the // SSL_SetDowngradeCheckVersion() API. TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) { - client_->SetPacketFilter( - std::make_shared( - SSL_LIBRARY_VERSION_TLS_1_1)); + MakeTlsFilter(client_, + SSL_LIBRARY_VERSION_TLS_1_1); ConnectExpectFail(); ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); } /* Attempt to negotiate the bogus DTLS 1.1 version. */ TEST_F(DtlsConnectTest, TestDtlsVersion11) { - client_->SetPacketFilter( - std::make_shared( - ((~0x0101) & 0xffff))); + MakeTlsFilter(client_, ((~0x0101) & 0xffff)); ConnectExpectFail(); // It's kind of surprising that SSL_ERROR_NO_CYPHER_OVERLAP is // what is returned here, but this is deliberate in ssl3_HandleAlert(). @@ -78,9 +75,8 @@ TEST_F(DtlsConnectTest, TestDtlsVersion11) { // Disabled as long as we have draft version. TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) { EnsureTlsSetup(); - client_->SetPacketFilter( - std::make_shared( - SSL_LIBRARY_VERSION_TLS_1_2)); + MakeTlsFilter(client_, + SSL_LIBRARY_VERSION_TLS_1_2); client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, @@ -92,9 +88,8 @@ TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) { // TLS 1.1 clients do not check the random values, so we should // instead get a handshake failure alert from the server. TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) { - client_->SetPacketFilter( - std::make_shared( - SSL_LIBRARY_VERSION_TLS_1_0)); + MakeTlsFilter(client_, + SSL_LIBRARY_VERSION_TLS_1_0); client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, SSL_LIBRARY_VERSION_TLS_1_1); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, @@ -177,12 +172,10 @@ class Tls13NoSupportedVersions : public TlsConnectStreamTls12 { client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_2); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, max_server_version); - client_->SetPacketFilter( - std::make_shared( - overwritten_client_version)); - auto capture = std::make_shared( - kTlsHandshakeServerHello); - server_->SetPacketFilter(capture); + MakeTlsFilter(client_, + overwritten_client_version); + auto capture = + MakeTlsFilter(server_, kTlsHandshakeServerHello); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -214,12 +207,10 @@ TEST_F(Tls13NoSupportedVersions, // Offer 1.3 but with ClientHello.legacy_version == TLS 1.4. This // causes a bad MAC error when we read EncryptedExtensions. TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) { - client_->SetPacketFilter( - std::make_shared( - SSL_LIBRARY_VERSION_TLS_1_3 + 1)); - auto capture = - std::make_shared(ssl_tls13_supported_versions_xtn); - server_->SetPacketFilter(capture); + MakeTlsFilter(client_, + SSL_LIBRARY_VERSION_TLS_1_3 + 1); + auto capture = MakeTlsFilter( + server_, ssl_tls13_supported_versions_xtn); client_->ExpectSendAlert(kTlsAlertBadRecordMac); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); diff --git a/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc index eda96831c..7f3c4a896 100644 --- a/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc @@ -189,12 +189,12 @@ class TestPolicyVersionRange } } - void SetUp() { - SetPolicy(policy_.range()); + void SetUp() override { TlsConnectTestBase::SetUp(); + SetPolicy(policy_.range()); } - void TearDown() { + void TearDown() override { TlsConnectTestBase::TearDown(); saved_version_policy_.RestoreOriginalPolicy(); } diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc index adcdbfbaf..728217851 100644 --- a/security/nss/gtests/ssl_gtest/test_io.cc +++ b/security/nss/gtests/ssl_gtest/test_io.cc @@ -25,10 +25,6 @@ namespace nss_test { if (g_ssl_gtest_verbose) LOG(a); \ } while (false) -void DummyPrSocket::SetPacketFilter(std::shared_ptr filter) { - filter_ = filter; -} - ScopedPRFileDesc DummyPrSocket::CreateFD() { static PRDescIdentity test_fd_identity = PR_GetUniqueIdentity("testtransportadapter"); diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h index 469d90a7c..dbeb6b9d4 100644 --- a/security/nss/gtests/ssl_gtest/test_io.h +++ b/security/nss/gtests/ssl_gtest/test_io.h @@ -74,7 +74,9 @@ class DummyPrSocket : public DummyIOLayerMethods { std::weak_ptr& peer() { return peer_; } void SetPeer(const std::shared_ptr& peer) { peer_ = peer; } - void SetPacketFilter(std::shared_ptr filter); + void SetPacketFilter(const std::shared_ptr& filter) { + filter_ = filter; + } // Drops peer, packet filter and any outstanding packets. void Reset(); @@ -176,6 +178,6 @@ class Poller { timers_; }; -} // end of namespace +} // namespace nss_test #endif diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc index 3b939bba8..2f71caedb 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.cc +++ b/security/nss/gtests/ssl_gtest/tls_agent.cc @@ -12,6 +12,7 @@ #include "sslerr.h" #include "sslexp.h" #include "sslproto.h" +#include "tls_filter.h" #include "tls_parser.h" extern "C" { @@ -66,6 +67,7 @@ TlsAgent::TlsAgent(const std::string& name, Role role, expected_sent_alert_(kTlsAlertCloseNotify), expected_sent_alert_level_(kTlsAlertWarning), handshake_callback_called_(false), + resumption_callback_called_(false), error_code_(0), send_ctr_(0), recv_ctr_(0), @@ -73,7 +75,8 @@ TlsAgent::TlsAgent(const std::string& name, Role role, handshake_callback_(), auth_certificate_callback_(), sni_callback_(), - skip_version_checks_(false) { + skip_version_checks_(false), + resumption_token_() { memset(&info_, 0, sizeof(info_)); memset(&csinfo_, 0, sizeof(csinfo_)); SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_); @@ -182,6 +185,10 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { ScopedCERTCertList anchors(CERT_NewCertList()); rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get()); if (rv != SECSuccess) return false; + + rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; } else { rv = SSL_SetURL(ssl_fd(), "server"); EXPECT_EQ(SECSuccess, rv); @@ -207,6 +214,29 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { return true; } +bool TlsAgent::MaybeSetResumptionToken() { + if (!resumption_token_.empty()) { + SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(), + resumption_token_.size()); + + // rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR + // if the resumption token was bad (expired/malformed/etc.). + if (expect_resumption_) { + // Only in case we expect resumption this has to be successful. We might + // not expect resumption due to some reason but the token is totally fine. + EXPECT_EQ(SECSuccess, rv); + } + if (rv != SECSuccess) { + EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError()); + resumption_token_.clear(); + EXPECT_FALSE(expect_resumption_); + if (expect_resumption_) return false; + } + } + + return true; +} + void TlsAgent::SetupClientAuth() { EXPECT_TRUE(EnsureTlsSetup()); ASSERT_EQ(CLIENT, role_); @@ -386,6 +416,27 @@ void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { } } +SECStatus ResumptionTokenCallback(PRFileDesc* fd, + const PRUint8* resumptionToken, + unsigned int len, void* ctx) { + EXPECT_NE(nullptr, resumptionToken); + if (!resumptionToken) { + return SECFailure; + } + + std::vector new_token(resumptionToken, resumptionToken + len); + reinterpret_cast(ctx)->SetResumptionToken(new_token); + reinterpret_cast(ctx)->SetResumptionCallbackCalled(); + return SECSuccess; +} + +void TlsAgent::SetResumptionTokenCallback() { + EXPECT_TRUE(EnsureTlsSetup()); + SECStatus rv = + SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this); + EXPECT_EQ(SECSuccess, rv); +} + void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) { *minver = vrange_.min; *maxver = vrange_.max; diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h index b3fd892ae..6cd6d5073 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.h +++ b/security/nss/gtests/ssl_gtest/tls_agent.h @@ -14,7 +14,6 @@ #include #include "test_io.h" -#include "tls_filter.h" #define GTEST_HAS_RTTI 0 #include "gtest/gtest.h" @@ -37,7 +36,10 @@ enum SessionResumptionMode { RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET }; +class PacketFilter; class TlsAgent; +class TlsCipherSpec; +struct TlsRecord; const extern std::vector kAllDHEGroups; const extern std::vector kECDHEGroups; @@ -80,18 +82,10 @@ class TlsAgent : public PollTarget { adapter_->SetPeer(peer->adapter_); } - // Set a filter that can access plaintext (TLS 1.3 only). - void SetTlsRecordFilter(std::shared_ptr filter) { - filter->SetAgent(this); + void SetFilter(std::shared_ptr filter) { adapter_->SetPacketFilter(filter); - filter->EnableDecryption(); } - - void SetPacketFilter(std::shared_ptr filter) { - adapter_->SetPacketFilter(filter); - } - - void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); } + void ClearFilter() { adapter_->SetPacketFilter(nullptr); } void StartConnect(PRFileDesc* model = nullptr); void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group, @@ -165,6 +159,24 @@ class TlsAgent : public PollTarget { void DisableECDHEServerKeyReuse(); bool GetPeerChainLength(size_t* count); void CheckCipherSuite(uint16_t cipher_suite); + void SetResumptionTokenCallback(); + bool MaybeSetResumptionToken(); + void SetResumptionToken(const std::vector& resumption_token) { + resumption_token_ = resumption_token; + } + const std::vector& GetResumptionToken() const { + return resumption_token_; + } + void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) { + SECStatus rv = SSL_GetResumptionTokenInfo( + resumption_token_.data(), resumption_token_.size(), token.get(), + sizeof(SSLResumptionTokenInfo)); + ASSERT_EQ(SECSuccess, rv); + } + void SetResumptionCallbackCalled() { resumption_callback_called_ = true; } + bool resumption_callback_called() const { + return resumption_callback_called_; + } const std::string& name() const { return name_; } @@ -382,6 +394,7 @@ class TlsAgent : public PollTarget { uint8_t expected_sent_alert_; uint8_t expected_sent_alert_level_; bool handshake_callback_called_; + bool resumption_callback_called_; SSLChannelInfo info_; SSLCipherSuiteInfo csinfo_; SSLVersionRange vrange_; @@ -393,6 +406,7 @@ class TlsAgent : public PollTarget { AuthCertificateCallbackFunction auth_certificate_callback_; SniCallbackFunction sni_callback_; bool skip_version_checks_; + std::vector resumption_token_; }; inline std::ostream& operator<<(std::ostream& stream, @@ -443,7 +457,7 @@ class TlsAgentTestBase : public ::testing::Test { void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, int32_t error_code = 0); - std::unique_ptr agent_; + std::shared_ptr agent_; TlsAgent::Role role_; SSLProtocolVariant variant_; uint16_t version_; diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc index 0af5123e9..8567b392f 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.cc +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -197,7 +197,6 @@ void TlsConnectTestBase::SetUp() { SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); SSLInt_ClearSelfEncryptKey(); SSLInt_SetTicketLifetime(30); - SSLInt_SetMaxEarlyDataSize(1024); SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3); ClearStats(); Init(); @@ -230,7 +229,9 @@ void TlsConnectTestBase::Reset() { void TlsConnectTestBase::Reset(const std::string& server_name, const std::string& client_name) { + auto token = client_->GetResumptionToken(); client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_)); + client_->SetResumptionToken(token); server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_)); if (skip_version_checks_) { client_->SkipVersionChecks(); @@ -290,6 +291,7 @@ void TlsConnectTestBase::EnableExtendedMasterSecret() { void TlsConnectTestBase::Connect() { server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr); client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr); + client_->MaybeSetResumptionToken(); Handshake(); CheckConnected(); } @@ -402,13 +404,13 @@ void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, break; case ssl_auth_rsa_sign: if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) { - scheme = ssl_sig_rsa_pss_sha256; + scheme = ssl_sig_rsa_pss_rsae_sha256; } else { scheme = ssl_sig_rsa_pkcs1_sha256; } break; case ssl_auth_rsa_pss: - scheme = ssl_sig_rsa_pss_sha256; + scheme = ssl_sig_rsa_pss_rsae_sha256; break; case ssl_auth_ecdsa: scheme = ssl_sig_ecdsa_secp256r1_sha256; @@ -670,7 +672,8 @@ void TlsConnectTestBase::ZeroRttSendReceive( EXPECT_EQ(k0RttDataLen, rv); } else { EXPECT_EQ(SECFailure, rv); - EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); } // Do a second read. this should fail. @@ -754,20 +757,29 @@ TlsConnectTls12Plus::TlsConnectTls12Plus() TlsConnectTls13::TlsConnectTls13() : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} +TlsConnectGenericResumption::TlsConnectGenericResumption() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())), + external_cache_(std::get<2>(GetParam())) {} + +TlsConnectTls13ResumptionToken::TlsConnectTls13ResumptionToken() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + +TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + void TlsKeyExchangeTest::EnsureKeyShareSetup() { EnsureTlsSetup(); groups_capture_ = - std::make_shared(ssl_supported_groups_xtn); + std::make_shared(client_, ssl_supported_groups_xtn); shares_capture_ = - std::make_shared(ssl_tls13_key_share_xtn); - shares_capture2_ = - std::make_shared(ssl_tls13_key_share_xtn, true); + std::make_shared(client_, ssl_tls13_key_share_xtn); + shares_capture2_ = std::make_shared( + client_, ssl_tls13_key_share_xtn, true); std::vector> captures = { groups_capture_, shares_capture_, shares_capture2_}; - client_->SetPacketFilter(std::make_shared(captures)); - capture_hrr_ = std::make_shared( - kTlsHandshakeHelloRetryRequest); - server_->SetPacketFilter(capture_hrr_); + client_->SetFilter(std::make_shared(captures)); + capture_hrr_ = MakeTlsFilter( + server_, kTlsHandshakeHelloRetryRequest); } void TlsKeyExchangeTest::ConfigNamedGroups( diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h index c650dda1d..7dffe7f8a 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.h +++ b/security/nss/gtests/ssl_gtest/tls_connect.h @@ -45,8 +45,8 @@ class TlsConnectTestBase : public ::testing::Test { TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version); virtual ~TlsConnectTestBase(); - void SetUp(); - void TearDown(); + virtual void SetUp(); + virtual void TearDown(); // Initialize client and server. void Init(); @@ -55,7 +55,7 @@ class TlsConnectTestBase : public ::testing::Test { // Clear the server session cache. void ClearServerCache(); // Make sure TLS is configured for a connection. - void EnsureTlsSetup(); + virtual void EnsureTlsSetup(); // Reset and keep the same certificate names void Reset(); // Reset, and update the certificate names on both peers @@ -208,6 +208,52 @@ class TlsConnectGeneric : public TlsConnectTestBase, TlsConnectGeneric(); }; +class TlsConnectGenericResumption + : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { + private: + bool external_cache_; + + public: + TlsConnectGenericResumption(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + // Enable external resumption token cache. + if (external_cache_) { + client_->SetResumptionTokenCallback(); + } + } + + bool use_external_cache() const { return external_cache_; } +}; + +class TlsConnectTls13ResumptionToken + : public TlsConnectTestBase, + public ::testing::WithParamInterface { + public: + TlsConnectTls13ResumptionToken(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + client_->SetResumptionTokenCallback(); + } +}; + +class TlsConnectGenericResumptionToken + : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple> { + public: + TlsConnectGenericResumptionToken(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + client_->SetResumptionTokenCallback(); + } +}; + // A Pre TLS 1.2 generic test. class TlsConnectPre12 : public TlsConnectTestBase, public ::testing::WithParamInterface< @@ -273,7 +319,7 @@ class TlsKeyExchangeTest : public TlsConnectGeneric { std::shared_ptr groups_capture_; std::shared_ptr shares_capture_; std::shared_ptr shares_capture2_; - std::shared_ptr capture_hrr_; + std::shared_ptr capture_hrr_; void EnsureKeyShareSetup(); void ConfigNamedGroups(const std::vector& groups); diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc index 89f201295..d34b13bcb 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.cc +++ b/security/nss/gtests/ssl_gtest/tls_filter.cc @@ -452,7 +452,7 @@ size_t TlsHandshakeFilter::HandshakeHeader::Write( return offset; } -PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake( +PacketFilter::Action TlsHandshakeRecorder::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { // Only do this once. @@ -763,7 +763,7 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, if (counter_++ == record_) { DataBuffer buf; header.Write(&buf, 0, body); - src_.lock()->SendDirect(buf); + agent()->SendDirect(buf); dest_.lock()->Handshake(); func_(); return DROP; @@ -772,7 +772,7 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, return KEEP; } -PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake( +PacketFilter::Action TlsClientHelloVersionChanger::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { EXPECT_EQ(SECSuccess, @@ -808,7 +808,7 @@ PacketFilter::Action SelectiveRecordDropFilter::FilterRecord( return pattern; } -PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake( +PacketFilter::Action TlsClientHelloVersionSetter::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { *output = input; diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h index 1db3b90f6..1bbe190ab 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.h +++ b/security/nss/gtests/ssl_gtest/tls_filter.h @@ -13,6 +13,7 @@ #include #include "test_io.h" +#include "tls_agent.h" #include "tls_parser.h" #include "tls_protect.h" @@ -23,7 +24,6 @@ extern "C" { namespace nss_test { class TlsCipherSpec; -class TlsAgent; class TlsVersioned { public: @@ -71,19 +71,27 @@ struct TlsRecord { const DataBuffer buffer; }; +// Make a filter and install it on a TlsAgent. +template +inline std::shared_ptr MakeTlsFilter(const std::shared_ptr& agent, + Args&&... args) { + auto filter = std::make_shared(agent, std::forward(args)...); + agent->SetFilter(filter); + return filter; +} + // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: - TlsRecordFilter() - : agent_(nullptr), + TlsRecordFilter(const std::shared_ptr& agent) + : agent_(agent), count_(0), cipher_spec_(), dropped_record_(false), in_sequence_number_(0), out_sequence_number_(0) {} - void SetAgent(const TlsAgent* agent) { agent_ = agent; } - const TlsAgent* agent() const { return agent_; } + std::shared_ptr agent() const { return agent_.lock(); } // External interface. Overrides PacketFilter. PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); @@ -126,7 +134,7 @@ class TlsRecordFilter : public PacketFilter { static void CipherSpecChanged(void* arg, PRBool sending, ssl3CipherSpec* newSpec); - const TlsAgent* agent_; + std::weak_ptr agent_; size_t count_; std::unique_ptr cipher_spec_; // Whether we dropped a record since the cipher spec changed. @@ -175,9 +183,13 @@ inline std::ostream& operator<<(std::ostream& stream, // records and that they don't span records or anything crazy like that. class TlsHandshakeFilter : public TlsRecordFilter { public: - TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {} - TlsHandshakeFilter(const std::set& types) - : handshake_types_(types), preceding_fragment_() {} + TlsHandshakeFilter(const std::shared_ptr& agent) + : TlsRecordFilter(agent), handshake_types_(), preceding_fragment_() {} + TlsHandshakeFilter(const std::shared_ptr& agent, + const std::set& types) + : TlsRecordFilter(agent), + handshake_types_(types), + preceding_fragment_() {} // This filter can be set to be selective based on handshake message type. If // this function isn't used (or the set is empty), then all handshake messages @@ -229,12 +241,14 @@ class TlsHandshakeFilter : public TlsRecordFilter { }; // Make a copy of the first instance of a handshake message. -class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { +class TlsHandshakeRecorder : public TlsHandshakeFilter { public: - TlsInspectorRecordHandshakeMessage(uint8_t handshake_type) - : TlsHandshakeFilter({handshake_type}), buffer_() {} - TlsInspectorRecordHandshakeMessage(const std::set& handshake_types) - : TlsHandshakeFilter(handshake_types), buffer_() {} + TlsHandshakeRecorder(const std::shared_ptr& agent, + uint8_t handshake_type) + : TlsHandshakeFilter(agent, {handshake_type}), buffer_() {} + TlsHandshakeRecorder(const std::shared_ptr& agent, + const std::set& handshake_types) + : TlsHandshakeFilter(agent, handshake_types), buffer_() {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -251,9 +265,10 @@ class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { // Replace all instances of a handshake message. class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { public: - TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type, + TlsInspectorReplaceHandshakeMessage(const std::shared_ptr& agent, + uint8_t handshake_type, const DataBuffer& replacement) - : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {} + : TlsHandshakeFilter(agent, {handshake_type}), buffer_(replacement) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -266,9 +281,11 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { // Make a copy of each record of a given type. class TlsRecordRecorder : public TlsRecordFilter { public: - TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {} - TlsRecordRecorder() - : filter_(false), + TlsRecordRecorder(const std::shared_ptr& agent, uint8_t ct) + : TlsRecordFilter(agent), filter_(true), ct_(ct), records_() {} + TlsRecordRecorder(const std::shared_ptr& agent) + : TlsRecordFilter(agent), + filter_(false), ct_(content_handshake), // dummy ( is C++14) records_() {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, @@ -289,7 +306,9 @@ class TlsRecordRecorder : public TlsRecordFilter { // Make a copy of the complete conversation. class TlsConversationRecorder : public TlsRecordFilter { public: - TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {} + TlsConversationRecorder(const std::shared_ptr& agent, + DataBuffer& buffer) + : TlsRecordFilter(agent), buffer_(buffer) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, @@ -302,6 +321,8 @@ class TlsConversationRecorder : public TlsRecordFilter { // Make a copy of the records class TlsHeaderRecorder : public TlsRecordFilter { public: + TlsHeaderRecorder(const std::shared_ptr& agent) + : TlsRecordFilter(agent) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); @@ -338,13 +359,15 @@ typedef std::function class TlsExtensionFilter : public TlsHandshakeFilter { public: - TlsExtensionFilter() - : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello, + TlsExtensionFilter(const std::shared_ptr& agent) + : TlsHandshakeFilter(agent, + {kTlsHandshakeClientHello, kTlsHandshakeServerHello, kTlsHandshakeHelloRetryRequest, kTlsHandshakeEncryptedExtensions}) {} - TlsExtensionFilter(const std::set& types) - : TlsHandshakeFilter(types) {} + TlsExtensionFilter(const std::shared_ptr& agent, + const std::set& types) + : TlsHandshakeFilter(agent, types) {} static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); @@ -365,8 +388,13 @@ class TlsExtensionFilter : public TlsHandshakeFilter { class TlsExtensionCapture : public TlsExtensionFilter { public: - TlsExtensionCapture(uint16_t ext, bool last = false) - : extension_(ext), captured_(false), last_(last), data_() {} + TlsExtensionCapture(const std::shared_ptr& agent, uint16_t ext, + bool last = false) + : TlsExtensionFilter(agent), + extension_(ext), + captured_(false), + last_(last), + data_() {} const DataBuffer& extension() const { return data_; } bool captured() const { return captured_; } @@ -385,8 +413,9 @@ class TlsExtensionCapture : public TlsExtensionFilter { class TlsExtensionReplacer : public TlsExtensionFilter { public: - TlsExtensionReplacer(uint16_t extension, const DataBuffer& data) - : extension_(extension), data_(data) {} + TlsExtensionReplacer(const std::shared_ptr& agent, + uint16_t extension, const DataBuffer& data) + : TlsExtensionFilter(agent), extension_(extension), data_(data) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) override; @@ -398,7 +427,9 @@ class TlsExtensionReplacer : public TlsExtensionFilter { class TlsExtensionDropper : public TlsExtensionFilter { public: - TlsExtensionDropper(uint16_t extension) : extension_(extension) {} + TlsExtensionDropper(const std::shared_ptr& agent, + uint16_t extension) + : TlsExtensionFilter(agent), extension_(extension) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer&, DataBuffer*) override; @@ -408,8 +439,9 @@ class TlsExtensionDropper : public TlsExtensionFilter { class TlsExtensionInjector : public TlsHandshakeFilter { public: - TlsExtensionInjector(uint16_t ext, const DataBuffer& data) - : extension_(ext), data_(data) {} + TlsExtensionInjector(const std::shared_ptr& agent, uint16_t ext, + const DataBuffer& data) + : TlsHandshakeFilter(agent), extension_(ext), data_(data) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, @@ -426,16 +458,20 @@ typedef std::function VoidFunction; class AfterRecordN : public TlsRecordFilter { public: - AfterRecordN(std::shared_ptr& src, std::shared_ptr& dest, - unsigned int record, VoidFunction func) - : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {} + AfterRecordN(const std::shared_ptr& src, + const std::shared_ptr& dest, unsigned int record, + VoidFunction func) + : TlsRecordFilter(src), + dest_(dest), + record_(record), + func_(func), + counter_(0) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) override; private: - std::weak_ptr src_; std::weak_ptr dest_; unsigned int record_; VoidFunction func_; @@ -444,10 +480,12 @@ class AfterRecordN : public TlsRecordFilter { // When we see the ClientKeyExchange from |client|, increment the // ClientHelloVersion on |server|. -class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter { +class TlsClientHelloVersionChanger : public TlsHandshakeFilter { public: - TlsInspectorClientHelloVersionChanger(std::shared_ptr& server) - : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {} + TlsClientHelloVersionChanger(const std::shared_ptr& client, + const std::shared_ptr& server) + : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}), + server_(server) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -477,14 +515,16 @@ class SelectiveDropFilter : public PacketFilter { // datagram, we just drop one. class SelectiveRecordDropFilter : public TlsRecordFilter { public: - SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true) - : pattern_(pattern), counter_(0) { + SelectiveRecordDropFilter(const std::shared_ptr& agent, + uint32_t pattern, bool enabled = true) + : TlsRecordFilter(agent), pattern_(pattern), counter_(0) { if (!enabled) { Disable(); } } - SelectiveRecordDropFilter(std::initializer_list records) - : SelectiveRecordDropFilter(ToPattern(records), true) {} + SelectiveRecordDropFilter(const std::shared_ptr& agent, + std::initializer_list records) + : SelectiveRecordDropFilter(agent, ToPattern(records), true) {} void Reset(uint32_t pattern) { counter_ = 0; @@ -509,10 +549,12 @@ class SelectiveRecordDropFilter : public TlsRecordFilter { }; // Set the version number in the ClientHello. -class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { +class TlsClientHelloVersionSetter : public TlsHandshakeFilter { public: - TlsInspectorClientHelloVersionSetter(uint16_t version) - : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {} + TlsClientHelloVersionSetter(const std::shared_ptr& agent, + uint16_t version) + : TlsHandshakeFilter(agent, {kTlsHandshakeClientHello}), + version_(version) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -525,7 +567,8 @@ class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { // Damages the last byte of a handshake message. class TlsLastByteDamager : public TlsHandshakeFilter { public: - TlsLastByteDamager(uint8_t type) : type_(type) {} + TlsLastByteDamager(const std::shared_ptr& agent, uint8_t type) + : TlsHandshakeFilter(agent), type_(type) {} PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { @@ -545,8 +588,10 @@ class TlsLastByteDamager : public TlsHandshakeFilter { class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { public: - SelectedCipherSuiteReplacer(uint16_t suite) - : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {} + SelectedCipherSuiteReplacer(const std::shared_ptr& agent, + uint16_t suite) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}), + cipher_suite_(suite) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, -- cgit v1.2.3