diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_connect.cc | 226 |
1 files changed, 164 insertions, 62 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc index c8de5a1fe..8567b392f 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.cc +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -5,6 +5,7 @@ * You can obtain one at http://mozilla.org/MPL/2.0/. */ #include "tls_connect.h" +#include "sslexp.h" extern "C" { #include "libssl_internals.h" } @@ -88,6 +89,8 @@ std::string VersionString(uint16_t version) { switch (version) { case 0: return "(no version)"; + case SSL_LIBRARY_VERSION_3_0: + return "1.0"; case SSL_LIBRARY_VERSION_TLS_1_0: return "1.0"; case SSL_LIBRARY_VERSION_TLS_1_1: @@ -112,6 +115,7 @@ TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, server_model_(nullptr), version_(version), expected_resumption_mode_(RESUME_NONE), + expected_resumptions_(0), session_ids_(), expect_extended_master_secret_(false), expect_early_data_accepted_(false), @@ -161,6 +165,22 @@ void TlsConnectTestBase::CheckShares( EXPECT_EQ(shares.len(), i); } +void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch, + uint16_t server_epoch) const { + uint16_t read_epoch = 0; + uint16_t write_epoch = 0; + + EXPECT_EQ(SECSuccess, + SSLInt_GetEpochs(client_->ssl_fd(), &read_epoch, &write_epoch)); + EXPECT_EQ(server_epoch, read_epoch) << "client read epoch"; + EXPECT_EQ(client_epoch, write_epoch) << "client write epoch"; + + EXPECT_EQ(SECSuccess, + SSLInt_GetEpochs(server_->ssl_fd(), &read_epoch, &write_epoch)); + EXPECT_EQ(client_epoch, read_epoch) << "server read epoch"; + EXPECT_EQ(server_epoch, write_epoch) << "server write epoch"; +} + void TlsConnectTestBase::ClearStats() { // Clear statistics. SSL3Statistics* stats = SSL_GetStatistics(); @@ -177,7 +197,7 @@ void TlsConnectTestBase::SetUp() { SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); SSLInt_ClearSelfEncryptKey(); SSLInt_SetTicketLifetime(30); - SSLInt_SetMaxEarlyDataSize(1024); + SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3); ClearStats(); Init(); } @@ -209,7 +229,9 @@ void TlsConnectTestBase::Reset() { void TlsConnectTestBase::Reset(const std::string& server_name, const std::string& client_name) { + auto token = client_->GetResumptionToken(); client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_)); + client_->SetResumptionToken(token); server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_)); if (skip_version_checks_) { client_->SkipVersionChecks(); @@ -219,12 +241,27 @@ void TlsConnectTestBase::Reset(const std::string& server_name, Init(); } -void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) { +void TlsConnectTestBase::MakeNewServer() { + auto replacement = std::make_shared<TlsAgent>( + server_->name(), TlsAgent::SERVER, server_->variant()); + server_ = replacement; + if (version_) { + server_->SetVersionRange(version_, version_); + } + client_->SetPeer(server_); + server_->SetPeer(client_); + server_->StartConnect(); +} + +void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected, + uint8_t num_resumptions) { expected_resumption_mode_ = expected; if (expected != RESUME_NONE) { client_->ExpectResumption(); server_->ExpectResumption(); + expected_resumptions_ = num_resumptions; } + EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE); } void TlsConnectTestBase::EnsureTlsSetup() { @@ -254,10 +291,16 @@ void TlsConnectTestBase::EnableExtendedMasterSecret() { void TlsConnectTestBase::Connect() { server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr); client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr); + client_->MaybeSetResumptionToken(); Handshake(); CheckConnected(); } +void TlsConnectTestBase::StartConnect() { + server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr); + client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr); +} + void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { EnsureTlsSetup(); client_->EnableSingleCipher(cipher_suite); @@ -274,6 +317,19 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { } void TlsConnectTestBase::CheckConnected() { + // Have the client read handshake twice to make sure we get the + // NST and the ACK. + if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 && + variant_ == ssl_variant_datagram) { + client_->Handshake(); + client_->Handshake(); + auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd()); + // Verify that we dropped the client's retransmission cipher suites. + EXPECT_EQ(2, suites) << "Client has the wrong number of suites"; + if (suites != 2) { + SSLInt_PrintCipherSpecs("client", client_->ssl_fd()); + } + } EXPECT_EQ(client_->version(), server_->version()); if (!skip_version_checks_) { // Check the version is as expected @@ -314,10 +370,12 @@ void TlsConnectTestBase::CheckConnected() { void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const { - client_->CheckKEA(kea_type, kea_group); - server_->CheckKEA(kea_type, kea_group); - client_->CheckAuthType(auth_type, sig_scheme); + if (kea_group != ssl_grp_none) { + client_->CheckKEA(kea_type, kea_group); + server_->CheckKEA(kea_type, kea_group); + } server_->CheckAuthType(auth_type, sig_scheme); + client_->CheckAuthType(auth_type, sig_scheme); } void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, @@ -346,13 +404,13 @@ void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, break; case ssl_auth_rsa_sign: if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) { - scheme = ssl_sig_rsa_pss_sha256; + scheme = ssl_sig_rsa_pss_rsae_sha256; } else { scheme = ssl_sig_rsa_pkcs1_sha256; } break; case ssl_auth_rsa_pss: - scheme = ssl_sig_rsa_pss_sha256; + scheme = ssl_sig_rsa_pss_rsae_sha256; break; case ssl_auth_ecdsa: scheme = ssl_sig_ecdsa_secp256r1_sha256; @@ -372,9 +430,19 @@ void TlsConnectTestBase::CheckKeys() const { CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); } +void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type, + SSLNamedGroup kea_group, + SSLNamedGroup original_kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) { + CheckKeys(kea_type, kea_group, auth_type, sig_scheme); + EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE); + client_->CheckOriginalKEA(original_kea_group); + server_->CheckOriginalKEA(original_kea_group); +} + void TlsConnectTestBase::ConnectExpectFail() { - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); Handshake(); ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); @@ -395,8 +463,7 @@ void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, } void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) { - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->SetServerKeyBits(server_->server_key_bits()); client_->Handshake(); server_->Handshake(); @@ -455,29 +522,33 @@ void TlsConnectTestBase::EnableSomeEcdhCiphers() { } } +void TlsConnectTestBase::ConfigureSelfEncrypt() { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey privKey; + ASSERT_TRUE( + TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey)); + + ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); + ASSERT_TRUE(pubKey); + + EXPECT_EQ(SECSuccess, + SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); +} + void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, SessionResumptionMode server) { client_->ConfigureSessionCache(client); server_->ConfigureSessionCache(server); if ((server & RESUME_TICKET) != 0) { - ScopedCERTCertificate cert; - ScopedSECKEYPrivateKey privKey; - ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, - &privKey)); - - ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); - ASSERT_TRUE(pubKey); - - EXPECT_EQ(SECSuccess, - SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); + ConfigureSelfEncrypt(); } } void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { EXPECT_NE(RESUME_BOTH, expected); - int resume_count = expected ? 1 : 0; - int stateless_count = (expected & RESUME_TICKET) ? 1 : 0; + int resume_count = expected ? expected_resumptions_ : 0; + int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0; // Note: hch == server counter; hsh == client counter. SSL3Statistics* stats = SSL_GetStatistics(); @@ -490,7 +561,7 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { if (expected != RESUME_NONE) { if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) { // Check that the last two session ids match. - ASSERT_EQ(2U, session_ids_.size()); + ASSERT_EQ(1U + expected_resumptions_, session_ids_.size()); EXPECT_EQ(session_ids_[session_ids_.size() - 1], session_ids_[session_ids_.size() - 2]); } else { @@ -540,31 +611,28 @@ void TlsConnectTestBase::CheckSrtp() const { server_->CheckSrtp(); } -void TlsConnectTestBase::SendReceive() { - client_->SendData(50); - server_->SendData(50); - Receive(50); +void TlsConnectTestBase::SendReceive(size_t total) { + ASSERT_GT(total, client_->received_bytes()); + ASSERT_GT(total, server_->received_bytes()); + client_->SendData(total - server_->received_bytes()); + server_->SendData(total - client_->received_bytes()); + Receive(total); // Receive() is cumulative } // Do a first connection so we can do 0-RTT on the second one. void TlsConnectTestBase::SetupForZeroRtt() { + // If we don't do this, then all 0-RTT attempts will be rejected. + SSLInt_RolloverAntiReplay(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. Connect(); SendReceive(); // Need to read so that we absorb the session ticket. CheckKeys(); Reset(); - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); } // Do a first connection so we can do resumption @@ -584,10 +652,6 @@ void TlsConnectTestBase::ZeroRttSendReceive( const char* k0RttData = "ABCDEF"; const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); - if (expect_writable && expect_readable) { - ExpectAlert(client_, kTlsAlertEndOfEarlyData); - } - client_->Handshake(); // Send ClientHello. if (post_clienthello_check) { if (!post_clienthello_check()) return; @@ -599,7 +663,7 @@ void TlsConnectTestBase::ZeroRttSendReceive( } else { EXPECT_EQ(SECFailure, rv); } - server_->Handshake(); // Consume ClientHello, EE, Finished. + server_->Handshake(); // Consume ClientHello std::vector<uint8_t> buf(k0RttDataLen); rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read @@ -608,7 +672,8 @@ void TlsConnectTestBase::ZeroRttSendReceive( EXPECT_EQ(k0RttDataLen, rv); } else { EXPECT_EQ(SECFailure, rv); - EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); } // Do a second read. this should fail. @@ -653,6 +718,30 @@ void TlsConnectTestBase::SkipVersionChecks() { server_->SkipVersionChecks(); } +// Shift the DTLS timers, to the minimum time necessary to let the next timer +// run on either client or server. This allows tests to skip waiting without +// having timers run out of order. +void TlsConnectTestBase::ShiftDtlsTimers() { + PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT; + PRIntervalTime time; + SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); + if (rv == SECSuccess) { + time_shift = time; + } + rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); + if (rv == SECSuccess && + (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) { + time_shift = time; + } + + if (time_shift == PR_INTERVAL_NO_TIMEOUT) { + return; + } + + EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift)); + EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift)); +} + TlsConnectGeneric::TlsConnectGeneric() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} @@ -668,20 +757,29 @@ TlsConnectTls12Plus::TlsConnectTls12Plus() TlsConnectTls13::TlsConnectTls13() : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} +TlsConnectGenericResumption::TlsConnectGenericResumption() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())), + external_cache_(std::get<2>(GetParam())) {} + +TlsConnectTls13ResumptionToken::TlsConnectTls13ResumptionToken() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + +TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + void TlsKeyExchangeTest::EnsureKeyShareSetup() { EnsureTlsSetup(); groups_capture_ = - std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn); + std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); shares_capture_ = - std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); - shares_capture2_ = - std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn, true); + std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); + shares_capture2_ = std::make_shared<TlsExtensionCapture>( + client_, ssl_tls13_key_share_xtn, true); std::vector<std::shared_ptr<PacketFilter>> captures = { groups_capture_, shares_capture_, shares_capture2_}; - client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures)); - capture_hrr_ = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeHelloRetryRequest); - server_->SetPacketFilter(capture_hrr_); + client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures)); + capture_hrr_ = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeHelloRetryRequest); } void TlsKeyExchangeTest::ConfigNamedGroups( @@ -691,11 +789,15 @@ void TlsKeyExchangeTest::ConfigNamedGroups( } std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( - const DataBuffer& ext) { + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); EXPECT_TRUE(ext.len() % 2 == 0); + std::vector<SSLNamedGroup> groups; for (size_t i = 1; i < ext.len() / 2; i += 1) { EXPECT_TRUE(ext.Read(2 * i, 2, &tmp)); @@ -705,10 +807,14 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( } std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( - const DataBuffer& ext) { + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); + std::vector<SSLNamedGroup> shares; size_t i = 2; while (i < ext.len()) { @@ -724,17 +830,15 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( void TlsKeyExchangeTest::CheckKEXDetails( const std::vector<SSLNamedGroup>& expected_groups, const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) { - std::vector<SSLNamedGroup> groups = - GetGroupDetails(groups_capture_->extension()); + std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_); EXPECT_EQ(expected_groups, groups); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { ASSERT_LT(0U, expected_shares.size()); - std::vector<SSLNamedGroup> shares = - GetShareDetails(shares_capture_->extension()); + std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_); EXPECT_EQ(expected_shares, shares); } else { - EXPECT_EQ(0U, shares_capture_->extension().len()); + EXPECT_FALSE(shares_capture_->captured()); } EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0); @@ -756,8 +860,6 @@ void TlsKeyExchangeTest::CheckKEXDetails( EXPECT_NE(expected_share2, it); } std::vector<SSLNamedGroup> expected_shares2 = {expected_share2}; - std::vector<SSLNamedGroup> shares = - GetShareDetails(shares_capture2_->extension()); - EXPECT_EQ(expected_shares2, shares); + EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_)); } } // namespace nss_test |