diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_connect.cc | 190 |
1 files changed, 140 insertions, 50 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc index c8de5a1fe..0af5123e9 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.cc +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -5,6 +5,7 @@ * You can obtain one at http://mozilla.org/MPL/2.0/. */ #include "tls_connect.h" +#include "sslexp.h" extern "C" { #include "libssl_internals.h" } @@ -88,6 +89,8 @@ std::string VersionString(uint16_t version) { switch (version) { case 0: return "(no version)"; + case SSL_LIBRARY_VERSION_3_0: + return "1.0"; case SSL_LIBRARY_VERSION_TLS_1_0: return "1.0"; case SSL_LIBRARY_VERSION_TLS_1_1: @@ -112,6 +115,7 @@ TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, server_model_(nullptr), version_(version), expected_resumption_mode_(RESUME_NONE), + expected_resumptions_(0), session_ids_(), expect_extended_master_secret_(false), expect_early_data_accepted_(false), @@ -161,6 +165,22 @@ void TlsConnectTestBase::CheckShares( EXPECT_EQ(shares.len(), i); } +void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch, + uint16_t server_epoch) const { + uint16_t read_epoch = 0; + uint16_t write_epoch = 0; + + EXPECT_EQ(SECSuccess, + SSLInt_GetEpochs(client_->ssl_fd(), &read_epoch, &write_epoch)); + EXPECT_EQ(server_epoch, read_epoch) << "client read epoch"; + EXPECT_EQ(client_epoch, write_epoch) << "client write epoch"; + + EXPECT_EQ(SECSuccess, + SSLInt_GetEpochs(server_->ssl_fd(), &read_epoch, &write_epoch)); + EXPECT_EQ(client_epoch, read_epoch) << "server read epoch"; + EXPECT_EQ(server_epoch, write_epoch) << "server write epoch"; +} + void TlsConnectTestBase::ClearStats() { // Clear statistics. SSL3Statistics* stats = SSL_GetStatistics(); @@ -178,6 +198,7 @@ void TlsConnectTestBase::SetUp() { SSLInt_ClearSelfEncryptKey(); SSLInt_SetTicketLifetime(30); SSLInt_SetMaxEarlyDataSize(1024); + SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3); ClearStats(); Init(); } @@ -219,12 +240,27 @@ void TlsConnectTestBase::Reset(const std::string& server_name, Init(); } -void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) { +void TlsConnectTestBase::MakeNewServer() { + auto replacement = std::make_shared<TlsAgent>( + server_->name(), TlsAgent::SERVER, server_->variant()); + server_ = replacement; + if (version_) { + server_->SetVersionRange(version_, version_); + } + client_->SetPeer(server_); + server_->SetPeer(client_); + server_->StartConnect(); +} + +void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected, + uint8_t num_resumptions) { expected_resumption_mode_ = expected; if (expected != RESUME_NONE) { client_->ExpectResumption(); server_->ExpectResumption(); + expected_resumptions_ = num_resumptions; } + EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE); } void TlsConnectTestBase::EnsureTlsSetup() { @@ -258,6 +294,11 @@ void TlsConnectTestBase::Connect() { CheckConnected(); } +void TlsConnectTestBase::StartConnect() { + server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr); + client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr); +} + void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { EnsureTlsSetup(); client_->EnableSingleCipher(cipher_suite); @@ -274,6 +315,19 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { } void TlsConnectTestBase::CheckConnected() { + // Have the client read handshake twice to make sure we get the + // NST and the ACK. + if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 && + variant_ == ssl_variant_datagram) { + client_->Handshake(); + client_->Handshake(); + auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd()); + // Verify that we dropped the client's retransmission cipher suites. + EXPECT_EQ(2, suites) << "Client has the wrong number of suites"; + if (suites != 2) { + SSLInt_PrintCipherSpecs("client", client_->ssl_fd()); + } + } EXPECT_EQ(client_->version(), server_->version()); if (!skip_version_checks_) { // Check the version is as expected @@ -314,10 +368,12 @@ void TlsConnectTestBase::CheckConnected() { void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const { - client_->CheckKEA(kea_type, kea_group); - server_->CheckKEA(kea_type, kea_group); - client_->CheckAuthType(auth_type, sig_scheme); + if (kea_group != ssl_grp_none) { + client_->CheckKEA(kea_type, kea_group); + server_->CheckKEA(kea_type, kea_group); + } server_->CheckAuthType(auth_type, sig_scheme); + client_->CheckAuthType(auth_type, sig_scheme); } void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, @@ -372,9 +428,19 @@ void TlsConnectTestBase::CheckKeys() const { CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); } +void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type, + SSLNamedGroup kea_group, + SSLNamedGroup original_kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) { + CheckKeys(kea_type, kea_group, auth_type, sig_scheme); + EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE); + client_->CheckOriginalKEA(original_kea_group); + server_->CheckOriginalKEA(original_kea_group); +} + void TlsConnectTestBase::ConnectExpectFail() { - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); Handshake(); ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); @@ -395,8 +461,7 @@ void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, } void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) { - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->SetServerKeyBits(server_->server_key_bits()); client_->Handshake(); server_->Handshake(); @@ -455,29 +520,33 @@ void TlsConnectTestBase::EnableSomeEcdhCiphers() { } } +void TlsConnectTestBase::ConfigureSelfEncrypt() { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey privKey; + ASSERT_TRUE( + TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey)); + + ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); + ASSERT_TRUE(pubKey); + + EXPECT_EQ(SECSuccess, + SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); +} + void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, SessionResumptionMode server) { client_->ConfigureSessionCache(client); server_->ConfigureSessionCache(server); if ((server & RESUME_TICKET) != 0) { - ScopedCERTCertificate cert; - ScopedSECKEYPrivateKey privKey; - ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, - &privKey)); - - ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); - ASSERT_TRUE(pubKey); - - EXPECT_EQ(SECSuccess, - SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); + ConfigureSelfEncrypt(); } } void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { EXPECT_NE(RESUME_BOTH, expected); - int resume_count = expected ? 1 : 0; - int stateless_count = (expected & RESUME_TICKET) ? 1 : 0; + int resume_count = expected ? expected_resumptions_ : 0; + int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0; // Note: hch == server counter; hsh == client counter. SSL3Statistics* stats = SSL_GetStatistics(); @@ -490,7 +559,7 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { if (expected != RESUME_NONE) { if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) { // Check that the last two session ids match. - ASSERT_EQ(2U, session_ids_.size()); + ASSERT_EQ(1U + expected_resumptions_, session_ids_.size()); EXPECT_EQ(session_ids_[session_ids_.size() - 1], session_ids_[session_ids_.size() - 2]); } else { @@ -540,31 +609,28 @@ void TlsConnectTestBase::CheckSrtp() const { server_->CheckSrtp(); } -void TlsConnectTestBase::SendReceive() { - client_->SendData(50); - server_->SendData(50); - Receive(50); +void TlsConnectTestBase::SendReceive(size_t total) { + ASSERT_GT(total, client_->received_bytes()); + ASSERT_GT(total, server_->received_bytes()); + client_->SendData(total - server_->received_bytes()); + server_->SendData(total - client_->received_bytes()); + Receive(total); // Receive() is cumulative } // Do a first connection so we can do 0-RTT on the second one. void TlsConnectTestBase::SetupForZeroRtt() { + // If we don't do this, then all 0-RTT attempts will be rejected. + SSLInt_RolloverAntiReplay(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. Connect(); SendReceive(); // Need to read so that we absorb the session ticket. CheckKeys(); Reset(); - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); } // Do a first connection so we can do resumption @@ -584,10 +650,6 @@ void TlsConnectTestBase::ZeroRttSendReceive( const char* k0RttData = "ABCDEF"; const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); - if (expect_writable && expect_readable) { - ExpectAlert(client_, kTlsAlertEndOfEarlyData); - } - client_->Handshake(); // Send ClientHello. if (post_clienthello_check) { if (!post_clienthello_check()) return; @@ -599,7 +661,7 @@ void TlsConnectTestBase::ZeroRttSendReceive( } else { EXPECT_EQ(SECFailure, rv); } - server_->Handshake(); // Consume ClientHello, EE, Finished. + server_->Handshake(); // Consume ClientHello std::vector<uint8_t> buf(k0RttDataLen); rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read @@ -653,6 +715,30 @@ void TlsConnectTestBase::SkipVersionChecks() { server_->SkipVersionChecks(); } +// Shift the DTLS timers, to the minimum time necessary to let the next timer +// run on either client or server. This allows tests to skip waiting without +// having timers run out of order. +void TlsConnectTestBase::ShiftDtlsTimers() { + PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT; + PRIntervalTime time; + SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); + if (rv == SECSuccess) { + time_shift = time; + } + rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); + if (rv == SECSuccess && + (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) { + time_shift = time; + } + + if (time_shift == PR_INTERVAL_NO_TIMEOUT) { + return; + } + + EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift)); + EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift)); +} + TlsConnectGeneric::TlsConnectGeneric() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} @@ -691,11 +777,15 @@ void TlsKeyExchangeTest::ConfigNamedGroups( } std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( - const DataBuffer& ext) { + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); EXPECT_TRUE(ext.len() % 2 == 0); + std::vector<SSLNamedGroup> groups; for (size_t i = 1; i < ext.len() / 2; i += 1) { EXPECT_TRUE(ext.Read(2 * i, 2, &tmp)); @@ -705,10 +795,14 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( } std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( - const DataBuffer& ext) { + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); + std::vector<SSLNamedGroup> shares; size_t i = 2; while (i < ext.len()) { @@ -724,17 +818,15 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( void TlsKeyExchangeTest::CheckKEXDetails( const std::vector<SSLNamedGroup>& expected_groups, const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) { - std::vector<SSLNamedGroup> groups = - GetGroupDetails(groups_capture_->extension()); + std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_); EXPECT_EQ(expected_groups, groups); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { ASSERT_LT(0U, expected_shares.size()); - std::vector<SSLNamedGroup> shares = - GetShareDetails(shares_capture_->extension()); + std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_); EXPECT_EQ(expected_shares, shares); } else { - EXPECT_EQ(0U, shares_capture_->extension().len()); + EXPECT_FALSE(shares_capture_->captured()); } EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0); @@ -756,8 +848,6 @@ void TlsKeyExchangeTest::CheckKEXDetails( EXPECT_NE(expected_share2, it); } std::vector<SSLNamedGroup> expected_shares2 = {expected_share2}; - std::vector<SSLNamedGroup> shares = - GetShareDetails(shares_capture2_->extension()); - EXPECT_EQ(expected_shares2, shares); + EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_)); } } // namespace nss_test |