diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_connect.cc | 190 |
1 files changed, 50 insertions, 140 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc index 0af5123e9..c8de5a1fe 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.cc +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -5,7 +5,6 @@ * You can obtain one at http://mozilla.org/MPL/2.0/. */ #include "tls_connect.h" -#include "sslexp.h" extern "C" { #include "libssl_internals.h" } @@ -89,8 +88,6 @@ std::string VersionString(uint16_t version) { switch (version) { case 0: return "(no version)"; - case SSL_LIBRARY_VERSION_3_0: - return "1.0"; case SSL_LIBRARY_VERSION_TLS_1_0: return "1.0"; case SSL_LIBRARY_VERSION_TLS_1_1: @@ -115,7 +112,6 @@ TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, server_model_(nullptr), version_(version), expected_resumption_mode_(RESUME_NONE), - expected_resumptions_(0), session_ids_(), expect_extended_master_secret_(false), expect_early_data_accepted_(false), @@ -165,22 +161,6 @@ void TlsConnectTestBase::CheckShares( EXPECT_EQ(shares.len(), i); } -void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch, - uint16_t server_epoch) const { - uint16_t read_epoch = 0; - uint16_t write_epoch = 0; - - EXPECT_EQ(SECSuccess, - SSLInt_GetEpochs(client_->ssl_fd(), &read_epoch, &write_epoch)); - EXPECT_EQ(server_epoch, read_epoch) << "client read epoch"; - EXPECT_EQ(client_epoch, write_epoch) << "client write epoch"; - - EXPECT_EQ(SECSuccess, - SSLInt_GetEpochs(server_->ssl_fd(), &read_epoch, &write_epoch)); - EXPECT_EQ(client_epoch, read_epoch) << "server read epoch"; - EXPECT_EQ(server_epoch, write_epoch) << "server write epoch"; -} - void TlsConnectTestBase::ClearStats() { // Clear statistics. SSL3Statistics* stats = SSL_GetStatistics(); @@ -198,7 +178,6 @@ void TlsConnectTestBase::SetUp() { SSLInt_ClearSelfEncryptKey(); SSLInt_SetTicketLifetime(30); SSLInt_SetMaxEarlyDataSize(1024); - SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3); ClearStats(); Init(); } @@ -240,27 +219,12 @@ void TlsConnectTestBase::Reset(const std::string& server_name, Init(); } -void TlsConnectTestBase::MakeNewServer() { - auto replacement = std::make_shared<TlsAgent>( - server_->name(), TlsAgent::SERVER, server_->variant()); - server_ = replacement; - if (version_) { - server_->SetVersionRange(version_, version_); - } - client_->SetPeer(server_); - server_->SetPeer(client_); - server_->StartConnect(); -} - -void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected, - uint8_t num_resumptions) { +void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) { expected_resumption_mode_ = expected; if (expected != RESUME_NONE) { client_->ExpectResumption(); server_->ExpectResumption(); - expected_resumptions_ = num_resumptions; } - EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE); } void TlsConnectTestBase::EnsureTlsSetup() { @@ -294,11 +258,6 @@ void TlsConnectTestBase::Connect() { CheckConnected(); } -void TlsConnectTestBase::StartConnect() { - server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr); - client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr); -} - void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { EnsureTlsSetup(); client_->EnableSingleCipher(cipher_suite); @@ -315,19 +274,6 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { } void TlsConnectTestBase::CheckConnected() { - // Have the client read handshake twice to make sure we get the - // NST and the ACK. - if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 && - variant_ == ssl_variant_datagram) { - client_->Handshake(); - client_->Handshake(); - auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd()); - // Verify that we dropped the client's retransmission cipher suites. - EXPECT_EQ(2, suites) << "Client has the wrong number of suites"; - if (suites != 2) { - SSLInt_PrintCipherSpecs("client", client_->ssl_fd()); - } - } EXPECT_EQ(client_->version(), server_->version()); if (!skip_version_checks_) { // Check the version is as expected @@ -368,12 +314,10 @@ void TlsConnectTestBase::CheckConnected() { void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const { - if (kea_group != ssl_grp_none) { - client_->CheckKEA(kea_type, kea_group); - server_->CheckKEA(kea_type, kea_group); - } - server_->CheckAuthType(auth_type, sig_scheme); + client_->CheckKEA(kea_type, kea_group); + server_->CheckKEA(kea_type, kea_group); client_->CheckAuthType(auth_type, sig_scheme); + server_->CheckAuthType(auth_type, sig_scheme); } void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, @@ -428,19 +372,9 @@ void TlsConnectTestBase::CheckKeys() const { CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); } -void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type, - SSLNamedGroup kea_group, - SSLNamedGroup original_kea_group, - SSLAuthType auth_type, - SSLSignatureScheme sig_scheme) { - CheckKeys(kea_type, kea_group, auth_type, sig_scheme); - EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE); - client_->CheckOriginalKEA(original_kea_group); - server_->CheckOriginalKEA(original_kea_group); -} - void TlsConnectTestBase::ConnectExpectFail() { - StartConnect(); + server_->StartConnect(); + client_->StartConnect(); Handshake(); ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); @@ -461,7 +395,8 @@ void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, } void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) { - StartConnect(); + server_->StartConnect(); + client_->StartConnect(); client_->SetServerKeyBits(server_->server_key_bits()); client_->Handshake(); server_->Handshake(); @@ -520,33 +455,29 @@ void TlsConnectTestBase::EnableSomeEcdhCiphers() { } } -void TlsConnectTestBase::ConfigureSelfEncrypt() { - ScopedCERTCertificate cert; - ScopedSECKEYPrivateKey privKey; - ASSERT_TRUE( - TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey)); - - ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); - ASSERT_TRUE(pubKey); - - EXPECT_EQ(SECSuccess, - SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); -} - void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, SessionResumptionMode server) { client_->ConfigureSessionCache(client); server_->ConfigureSessionCache(server); if ((server & RESUME_TICKET) != 0) { - ConfigureSelfEncrypt(); + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey privKey; + ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, + &privKey)); + + ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); + ASSERT_TRUE(pubKey); + + EXPECT_EQ(SECSuccess, + SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); } } void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { EXPECT_NE(RESUME_BOTH, expected); - int resume_count = expected ? expected_resumptions_ : 0; - int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0; + int resume_count = expected ? 1 : 0; + int stateless_count = (expected & RESUME_TICKET) ? 1 : 0; // Note: hch == server counter; hsh == client counter. SSL3Statistics* stats = SSL_GetStatistics(); @@ -559,7 +490,7 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { if (expected != RESUME_NONE) { if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) { // Check that the last two session ids match. - ASSERT_EQ(1U + expected_resumptions_, session_ids_.size()); + ASSERT_EQ(2U, session_ids_.size()); EXPECT_EQ(session_ids_[session_ids_.size() - 1], session_ids_[session_ids_.size() - 2]); } else { @@ -609,28 +540,31 @@ void TlsConnectTestBase::CheckSrtp() const { server_->CheckSrtp(); } -void TlsConnectTestBase::SendReceive(size_t total) { - ASSERT_GT(total, client_->received_bytes()); - ASSERT_GT(total, server_->received_bytes()); - client_->SendData(total - server_->received_bytes()); - server_->SendData(total - client_->received_bytes()); - Receive(total); // Receive() is cumulative +void TlsConnectTestBase::SendReceive() { + client_->SendData(50); + server_->SendData(50); + Receive(50); } // Do a first connection so we can do 0-RTT on the second one. void TlsConnectTestBase::SetupForZeroRtt() { - // If we don't do this, then all 0-RTT attempts will be rejected. - SSLInt_RolloverAntiReplay(); - ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. Connect(); SendReceive(); // Need to read so that we absorb the session ticket. CheckKeys(); Reset(); - StartConnect(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->StartConnect(); + client_->StartConnect(); } // Do a first connection so we can do resumption @@ -650,6 +584,10 @@ void TlsConnectTestBase::ZeroRttSendReceive( const char* k0RttData = "ABCDEF"; const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + if (expect_writable && expect_readable) { + ExpectAlert(client_, kTlsAlertEndOfEarlyData); + } + client_->Handshake(); // Send ClientHello. if (post_clienthello_check) { if (!post_clienthello_check()) return; @@ -661,7 +599,7 @@ void TlsConnectTestBase::ZeroRttSendReceive( } else { EXPECT_EQ(SECFailure, rv); } - server_->Handshake(); // Consume ClientHello + server_->Handshake(); // Consume ClientHello, EE, Finished. std::vector<uint8_t> buf(k0RttDataLen); rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read @@ -715,30 +653,6 @@ void TlsConnectTestBase::SkipVersionChecks() { server_->SkipVersionChecks(); } -// Shift the DTLS timers, to the minimum time necessary to let the next timer -// run on either client or server. This allows tests to skip waiting without -// having timers run out of order. -void TlsConnectTestBase::ShiftDtlsTimers() { - PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT; - PRIntervalTime time; - SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); - if (rv == SECSuccess) { - time_shift = time; - } - rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); - if (rv == SECSuccess && - (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) { - time_shift = time; - } - - if (time_shift == PR_INTERVAL_NO_TIMEOUT) { - return; - } - - EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift)); - EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift)); -} - TlsConnectGeneric::TlsConnectGeneric() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} @@ -777,15 +691,11 @@ void TlsKeyExchangeTest::ConfigNamedGroups( } std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( - const std::shared_ptr<TlsExtensionCapture>& capture) { - EXPECT_TRUE(capture->captured()); - const DataBuffer& ext = capture->extension(); - + const DataBuffer& ext) { uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); EXPECT_TRUE(ext.len() % 2 == 0); - std::vector<SSLNamedGroup> groups; for (size_t i = 1; i < ext.len() / 2; i += 1) { EXPECT_TRUE(ext.Read(2 * i, 2, &tmp)); @@ -795,14 +705,10 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( } std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( - const std::shared_ptr<TlsExtensionCapture>& capture) { - EXPECT_TRUE(capture->captured()); - const DataBuffer& ext = capture->extension(); - + const DataBuffer& ext) { uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); - std::vector<SSLNamedGroup> shares; size_t i = 2; while (i < ext.len()) { @@ -818,15 +724,17 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( void TlsKeyExchangeTest::CheckKEXDetails( const std::vector<SSLNamedGroup>& expected_groups, const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) { - std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_); + std::vector<SSLNamedGroup> groups = + GetGroupDetails(groups_capture_->extension()); EXPECT_EQ(expected_groups, groups); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { ASSERT_LT(0U, expected_shares.size()); - std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_); + std::vector<SSLNamedGroup> shares = + GetShareDetails(shares_capture_->extension()); EXPECT_EQ(expected_shares, shares); } else { - EXPECT_FALSE(shares_capture_->captured()); + EXPECT_EQ(0U, shares_capture_->extension().len()); } EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0); @@ -848,6 +756,8 @@ void TlsKeyExchangeTest::CheckKEXDetails( EXPECT_NE(expected_share2, it); } std::vector<SSLNamedGroup> expected_shares2 = {expected_share2}; - EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_)); + std::vector<SSLNamedGroup> shares = + GetShareDetails(shares_capture2_->extension()); + EXPECT_EQ(expected_shares2, shares); } } // namespace nss_test |