/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "tls_connect.h"
#include "sslexp.h"
extern "C" {
#include "libssl_internals.h"
}

#include <algorithm>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "databuffer.h"
#include "gtest_utils.h"
#include "nss_scoped_ptrs.h"
#include "sslproto.h"

extern std::string g_working_dir_path;

namespace nss_test {

// Parameter generators shared by the parameterized TLS fixtures. Each one
// wraps a static array with ::testing::ValuesIn.

static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsStream =
        ::testing::ValuesIn(kTlsVariantsStreamArr);
static const SSLProtocolVariant kTlsVariantsDatagramArr[] = {
    ssl_variant_datagram};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsDatagram =
        ::testing::ValuesIn(kTlsVariantsDatagramArr);
static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream,
                                                        ssl_variant_datagram};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsAll =
        ::testing::ValuesIn(kTlsVariantsAllArr);

static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 =
    ::testing::ValuesIn(kTlsV10Arr);
static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11 =
    ::testing::ValuesIn(kTlsV11Arr);
static const uint16_t kTlsV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12 =
    ::testing::ValuesIn(kTlsV12Arr);
static const uint16_t kTlsV10V11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
                                         SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10V11 =
    ::testing::ValuesIn(kTlsV10V11Arr);
static const uint16_t kTlsV10ToV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
                                           SSL_LIBRARY_VERSION_TLS_1_1,
                                           SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t>
    TlsConnectTestBase::kTlsV10ToV12 = ::testing::ValuesIn(kTlsV10ToV12Arr);
static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1,
                                         SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11V12 =
    ::testing::ValuesIn(kTlsV11V12Arr);

static const uint16_t kTlsV11PlusArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11Plus =
    ::testing::ValuesIn(kTlsV11PlusArr);
static const uint16_t kTlsV12PlusArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12Plus =
    ::testing::ValuesIn(kTlsV12PlusArr);
static const uint16_t kTlsV13Arr[] = {SSL_LIBRARY_VERSION_TLS_1_3};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV13 =
    ::testing::ValuesIn(kTlsV13Arr);
static const uint16_t kTlsVAllArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1,
    SSL_LIBRARY_VERSION_TLS_1_0};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsVAll =
    ::testing::ValuesIn(kTlsVAllArr);

// Returns a short human-readable label ("3.0", "1.0" ... "1.3") for an
// SSL/TLS library version code, or "(no version)" for 0. Unknown codes are
// reported to stderr and recorded as a (non-fatal) test failure.
std::string VersionString(uint16_t version) {
  switch (version) {
    case 0:
      return "(no version)";
    case SSL_LIBRARY_VERSION_3_0:
      // SSL 3.0 previously reported itself as "1.0", which made it
      // indistinguishable from TLS 1.0 in test logs.
      return "3.0";
    case SSL_LIBRARY_VERSION_TLS_1_0:
      return "1.0";
    case SSL_LIBRARY_VERSION_TLS_1_1:
      return "1.1";
    case SSL_LIBRARY_VERSION_TLS_1_2:
      return "1.2";
    case SSL_LIBRARY_VERSION_TLS_1_3:
      return "1.3";
    default:
      std::cerr << "Invalid version: " << version << std::endl;
      EXPECT_TRUE(false);
      return "";
  }
}

// The default anti-replay window for tests.  Tests that rely on a different
// value call ResetAntiReplay directly.
static PRTime kAntiReplayWindow = 100 * PR_USEC_PER_SEC; TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version) : variant_(variant), client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)), server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)), client_model_(nullptr), server_model_(nullptr), version_(version), expected_resumption_mode_(RESUME_NONE), expected_resumptions_(0), session_ids_(), expect_extended_master_secret_(false), expect_early_data_accepted_(false), skip_version_checks_(false) { std::string v; if (variant_ == ssl_variant_datagram && version_ == SSL_LIBRARY_VERSION_TLS_1_1) { v = "1.0"; } else { v = VersionString(version_); } std::cerr << "Version: " << variant_ << " " << v << std::endl; } TlsConnectTestBase::~TlsConnectTestBase() {} // Check the group of each of the supported groups void TlsConnectTestBase::CheckGroups( const DataBuffer& groups, std::function check_group) { DuplicateGroupChecker group_set; uint32_t tmp = 0; EXPECT_TRUE(groups.Read(0, 2, &tmp)); EXPECT_EQ(groups.len() - 2, static_cast(tmp)); for (size_t i = 2; i < groups.len(); i += 2) { EXPECT_TRUE(groups.Read(i, 2, &tmp)); SSLNamedGroup group = static_cast(tmp); group_set.AddAndCheckGroup(group); check_group(group); } } // Check the group of each of the shares void TlsConnectTestBase::CheckShares( const DataBuffer& shares, std::function check_group) { DuplicateGroupChecker group_set; uint32_t tmp = 0; EXPECT_TRUE(shares.Read(0, 2, &tmp)); EXPECT_EQ(shares.len() - 2, static_cast(tmp)); size_t i; for (i = 2; i < shares.len(); i += 4 + tmp) { ASSERT_TRUE(shares.Read(i, 2, &tmp)); SSLNamedGroup group = static_cast(tmp); group_set.AddAndCheckGroup(group); check_group(group); ASSERT_TRUE(shares.Read(i + 2, 2, &tmp)); } EXPECT_EQ(shares.len(), i); } void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const { client_->CheckEpochs(server_epoch, client_epoch); server_->CheckEpochs(client_epoch, 
server_epoch); } void TlsConnectTestBase::ClearStats() { // Clear statistics. SSL3Statistics* stats = SSL_GetStatistics(); memset(stats, 0, sizeof(*stats)); } void TlsConnectTestBase::ClearServerCache() { SSL_ShutdownServerSessionIDCache(); SSLInt_ClearSelfEncryptKey(); SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); } void TlsConnectTestBase::SaveAlgorithmPolicy() { saved_policies_.clear(); for (auto it = algorithms_.begin(); it != algorithms_.end(); ++it) { uint32_t policy; SECStatus rv = NSS_GetAlgorithmPolicy(*it, &policy); ASSERT_EQ(SECSuccess, rv); saved_policies_.push_back(std::make_tuple(*it, policy)); } } void TlsConnectTestBase::RestoreAlgorithmPolicy() { for (auto it = saved_policies_.begin(); it != saved_policies_.end(); ++it) { auto algorithm = std::get<0>(*it); auto policy = std::get<1>(*it); SECStatus rv = NSS_SetAlgorithmPolicy( algorithm, policy, NSS_USE_POLICY_IN_SSL | NSS_USE_ALG_IN_SSL_KX); ASSERT_EQ(SECSuccess, rv); } } PRTime TlsConnectTestBase::TimeFunc(void* arg) { return *reinterpret_cast(arg); } void TlsConnectTestBase::SetUp() { SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); SSLInt_ClearSelfEncryptKey(); now_ = PR_Now(); ResetAntiReplay(kAntiReplayWindow); ClearStats(); SaveAlgorithmPolicy(); Init(); } void TlsConnectTestBase::TearDown() { client_ = nullptr; server_ = nullptr; SSL_ClearSessionCache(); SSLInt_ClearSelfEncryptKey(); SSL_ShutdownServerSessionIDCache(); RestoreAlgorithmPolicy(); } void TlsConnectTestBase::Init() { client_->SetPeer(server_); server_->SetPeer(client_); if (version_) { ConfigureVersion(version_); } } void TlsConnectTestBase::ResetAntiReplay(PRTime window) { SSLAntiReplayContext* p_anti_replay = nullptr; EXPECT_EQ(SECSuccess, SSL_CreateAntiReplayContext(now_, window, 1, 3, &p_anti_replay)); EXPECT_NE(nullptr, p_anti_replay); anti_replay_.reset(p_anti_replay); } void TlsConnectTestBase::Reset() { // Take a copy of the names because they are about to disappear. 
std::string server_name = server_->name(); std::string client_name = client_->name(); Reset(server_name, client_name); } void TlsConnectTestBase::Reset(const std::string& server_name, const std::string& client_name) { auto token = client_->GetResumptionToken(); client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_)); client_->SetResumptionToken(token); server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_)); if (skip_version_checks_) { client_->SkipVersionChecks(); server_->SkipVersionChecks(); } std::cerr << "Reset server:" << server_name << ", client:" << client_name << std::endl; Init(); } void TlsConnectTestBase::MakeNewServer() { auto replacement = std::make_shared( server_->name(), TlsAgent::SERVER, server_->variant()); server_ = replacement; if (version_) { server_->SetVersionRange(version_, version_); } client_->SetPeer(server_); server_->SetPeer(client_); server_->StartConnect(); } void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected, uint8_t num_resumptions) { expected_resumption_mode_ = expected; if (expected != RESUME_NONE) { client_->ExpectResumption(); server_->ExpectResumption(); expected_resumptions_ = num_resumptions; } EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE); } void TlsConnectTestBase::EnsureTlsSetup() { EXPECT_TRUE(server_->EnsureTlsSetup(server_model_ ? server_model_->ssl_fd() : nullptr)); EXPECT_TRUE(client_->EnsureTlsSetup(client_model_ ? 
client_model_->ssl_fd() : nullptr)); server_->SetAntiReplayContext(anti_replay_); EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(client_->ssl_fd(), TlsConnectTestBase::TimeFunc, &now_)); EXPECT_EQ(SECSuccess, SSL_SetTimeFunc(server_->ssl_fd(), TlsConnectTestBase::TimeFunc, &now_)); } void TlsConnectTestBase::Handshake() { client_->SetServerKeyBits(server_->server_key_bits()); client_->Handshake(); server_->Handshake(); ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) && (server_->state() != TlsAgent::STATE_CONNECTING), 5000); } void TlsConnectTestBase::EnableExtendedMasterSecret() { client_->EnableExtendedMasterSecret(); server_->EnableExtendedMasterSecret(); ExpectExtendedMasterSecret(true); } void TlsConnectTestBase::Connect() { StartConnect(); client_->MaybeSetResumptionToken(); Handshake(); CheckConnected(); } void TlsConnectTestBase::StartConnect() { EnsureTlsSetup(); server_->StartConnect(); client_->StartConnect(); } void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { EnsureTlsSetup(); client_->EnableSingleCipher(cipher_suite); Connect(); SendReceive(); // Check that we used the right cipher suite. uint16_t actual; EXPECT_TRUE(client_->cipher_suite(&actual)); EXPECT_EQ(cipher_suite, actual); EXPECT_TRUE(server_->cipher_suite(&actual)); EXPECT_EQ(cipher_suite, actual); } void TlsConnectTestBase::CheckConnected() { // Have the client read handshake twice to make sure we get the // NST and the ACK. if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 && variant_ == ssl_variant_datagram) { client_->Handshake(); client_->Handshake(); auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd()); // Verify that we dropped the client's retransmission cipher suites. 
EXPECT_EQ(2, suites) << "Client has the wrong number of suites"; if (suites != 2) { SSLInt_PrintCipherSpecs("client", client_->ssl_fd()); } } EXPECT_EQ(client_->version(), server_->version()); if (!skip_version_checks_) { // Check the version is as expected EXPECT_EQ(std::min(client_->max_version(), server_->max_version()), client_->version()); } EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); uint16_t cipher_suite1, cipher_suite2; bool ret = client_->cipher_suite(&cipher_suite1); EXPECT_TRUE(ret); ret = server_->cipher_suite(&cipher_suite2); EXPECT_TRUE(ret); EXPECT_EQ(cipher_suite1, cipher_suite2); std::cerr << "Connected with version " << client_->version() << " cipher suite " << client_->cipher_suite_name() << std::endl; if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) { // Check and store session ids. std::vector sid_c1 = client_->session_id(); EXPECT_EQ(32U, sid_c1.size()); std::vector sid_s1 = server_->session_id(); EXPECT_EQ(32U, sid_s1.size()); EXPECT_EQ(sid_c1, sid_s1); session_ids_.push_back(sid_c1); } CheckExtendedMasterSecret(); CheckEarlyDataAccepted(); CheckResumption(expected_resumption_mode_); client_->CheckSecretsDestroyed(); server_->CheckSecretsDestroyed(); } void TlsConnectTestBase::CheckEarlyDataLimit( const std::shared_ptr& agent, size_t expected_size) { SSLPreliminaryChannelInfo preinfo; SECStatus rv = SSL_GetPreliminaryChannelInfo(agent->ssl_fd(), &preinfo, sizeof(preinfo)); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(expected_size, static_cast(preinfo.maxEarlyDataSize)); } void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const { if (kea_group != ssl_grp_none) { client_->CheckKEA(kea_type, kea_group); server_->CheckKEA(kea_type, kea_group); } server_->CheckAuthType(auth_type, sig_scheme); client_->CheckAuthType(auth_type, sig_scheme); } void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, 
SSLAuthType auth_type) const { SSLNamedGroup group; switch (kea_type) { case ssl_kea_ecdh: group = ssl_grp_ec_curve25519; break; case ssl_kea_dh: group = ssl_grp_ffdhe_2048; break; case ssl_kea_rsa: group = ssl_grp_none; break; default: EXPECT_TRUE(false) << "unexpected KEA"; group = ssl_grp_none; break; } SSLSignatureScheme scheme; switch (auth_type) { case ssl_auth_rsa_decrypt: scheme = ssl_sig_none; break; case ssl_auth_rsa_sign: if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) { scheme = ssl_sig_rsa_pss_rsae_sha256; } else { scheme = ssl_sig_rsa_pkcs1_sha256; } break; case ssl_auth_rsa_pss: scheme = ssl_sig_rsa_pss_rsae_sha256; break; case ssl_auth_ecdsa: scheme = ssl_sig_ecdsa_secp256r1_sha256; break; case ssl_auth_dsa: scheme = ssl_sig_dsa_sha1; break; default: EXPECT_TRUE(false) << "unexpected auth type"; scheme = static_cast(0x0100); break; } CheckKeys(kea_type, group, auth_type, scheme); } void TlsConnectTestBase::CheckKeys() const { CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); } void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group, SSLNamedGroup original_kea_group, SSLAuthType auth_type, SSLSignatureScheme sig_scheme) { CheckKeys(kea_type, kea_group, auth_type, sig_scheme); EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE); client_->CheckOriginalKEA(original_kea_group); server_->CheckOriginalKEA(original_kea_group); } void TlsConnectTestBase::ConnectExpectFail() { StartConnect(); Handshake(); ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); } void TlsConnectTestBase::ExpectAlert(std::shared_ptr& sender, uint8_t alert) { EnsureTlsSetup(); auto receiver = (sender == client_) ? 
server_ : client_; sender->ExpectSendAlert(alert); receiver->ExpectReceiveAlert(alert); } void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr& sender, uint8_t alert) { ExpectAlert(sender, alert); ConnectExpectFail(); } void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) { StartConnect(); client_->SetServerKeyBits(server_->server_key_bits()); client_->Handshake(); server_->Handshake(); auto failing_agent = server_; if (failing_side == TlsAgent::CLIENT) { failing_agent = client_; } ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000); } void TlsConnectTestBase::ConfigureVersion(uint16_t version) { version_ = version; client_->SetVersionRange(version, version); server_->SetVersionRange(version, version); } void TlsConnectTestBase::SetExpectedVersion(uint16_t version) { client_->SetExpectedVersion(version); server_->SetExpectedVersion(version); } void TlsConnectTestBase::AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash, uint16_t zeroRttSuite) { client_->AddPsk(psk, label, hash, zeroRttSuite); server_->AddPsk(psk, label, hash, zeroRttSuite); client_->ExpectPsk(); server_->ExpectPsk(); } void TlsConnectTestBase::DisableAllCiphers() { EnsureTlsSetup(); client_->DisableAllCiphers(); server_->DisableAllCiphers(); } void TlsConnectTestBase::EnableOnlyStaticRsaCiphers() { DisableAllCiphers(); client_->EnableCiphersByKeyExchange(ssl_kea_rsa); server_->EnableCiphersByKeyExchange(ssl_kea_rsa); } void TlsConnectTestBase::EnableOnlyDheCiphers() { if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { DisableAllCiphers(); client_->EnableCiphersByKeyExchange(ssl_kea_dh); server_->EnableCiphersByKeyExchange(ssl_kea_dh); } else { client_->ConfigNamedGroups(kFFDHEGroups); server_->ConfigNamedGroups(kFFDHEGroups); } } void TlsConnectTestBase::EnableSomeEcdhCiphers() { if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { client_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa); client_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa); 
server_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa); server_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa); } else { client_->ConfigNamedGroups(kECDHEGroups); server_->ConfigNamedGroups(kECDHEGroups); } } void TlsConnectTestBase::ConfigureSelfEncrypt() { ScopedCERTCertificate cert; ScopedSECKEYPrivateKey privKey; ASSERT_TRUE( TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey)); ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); ASSERT_TRUE(pubKey); EXPECT_EQ(SECSuccess, SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); } void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, SessionResumptionMode server) { client_->ConfigureSessionCache(client); server_->ConfigureSessionCache(server); if ((server & RESUME_TICKET) != 0) { ConfigureSelfEncrypt(); } } void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { EXPECT_NE(RESUME_BOTH, expected); int resume_count = expected ? expected_resumptions_ : 0; int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0; // Note: hch == server counter; hsh == client counter. SSL3Statistics* stats = SSL_GetStatistics(); EXPECT_EQ(resume_count, stats->hch_sid_cache_hits); EXPECT_EQ(resume_count, stats->hsh_sid_cache_hits); EXPECT_EQ(stateless_count, stats->hch_sid_stateless_resumes); EXPECT_EQ(stateless_count, stats->hsh_sid_stateless_resumes); if (expected != RESUME_NONE) { if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && client_->GetResumptionToken().size() == 0) { // Check that the last two session ids match. ASSERT_EQ(1U + expected_resumptions_, session_ids_.size()); EXPECT_EQ(session_ids_[session_ids_.size() - 1], session_ids_[session_ids_.size() - 2]); } else { // We've either chosen TLS 1.3 or are using an external resumption token, // both of which only use tickets. 
EXPECT_TRUE(expected & RESUME_TICKET); } } } static SECStatus NextProtoCallbackServer(void* arg, PRFileDesc* fd, const unsigned char* protos, unsigned int protos_len, unsigned char* protoOut, unsigned int* protoOutLen, unsigned int protoMaxLen) { EXPECT_EQ(protoMaxLen, 255U); TlsAgent* agent = reinterpret_cast(arg); // Check that agent->alpn_value_to_use_ is in protos. if (protos_len < 1) { return SECFailure; } for (size_t i = 0; i < protos_len;) { size_t l = protos[i]; EXPECT_LT(i + l, protos_len); if (i + l >= protos_len) { return SECFailure; } std::string protos_s(reinterpret_cast(protos + i + 1), l); if (protos_s == agent->alpn_value_to_use_) { size_t s_len = agent->alpn_value_to_use_.size(); EXPECT_LE(s_len, 255U); memcpy(protoOut, &agent->alpn_value_to_use_[0], s_len); *protoOutLen = s_len; return SECSuccess; } i += l + 1; } return SECFailure; } void TlsConnectTestBase::EnableAlpn() { client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); } void TlsConnectTestBase::EnableAlpnWithCallback( const std::vector& client_vals, std::string server_choice) { EnsureTlsSetup(); server_->alpn_value_to_use_ = server_choice; EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(client_->ssl_fd(), client_vals.data(), client_vals.size())); SECStatus rv = SSL_SetNextProtoCallback( server_->ssl_fd(), NextProtoCallbackServer, server_.get()); EXPECT_EQ(SECSuccess, rv); } void TlsConnectTestBase::EnableAlpn(const std::vector& vals) { client_->EnableAlpn(vals.data(), vals.size()); server_->EnableAlpn(vals.data(), vals.size()); } void TlsConnectTestBase::EnsureModelSockets() { // Make sure models agents are available. 
if (!client_model_) { ASSERT_EQ(server_model_, nullptr); client_model_.reset( new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)); server_model_.reset( new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)); if (skip_version_checks_) { client_model_->SkipVersionChecks(); server_model_->SkipVersionChecks(); } } } void TlsConnectTestBase::CheckAlpn(const std::string& val) { client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, val); server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, val); } void TlsConnectTestBase::EnableSrtp() { client_->EnableSrtp(); server_->EnableSrtp(); } void TlsConnectTestBase::CheckSrtp() const { client_->CheckSrtp(); server_->CheckSrtp(); } void TlsConnectTestBase::SendReceive(size_t total) { ASSERT_GT(total, client_->received_bytes()); ASSERT_GT(total, server_->received_bytes()); client_->SendData(total - server_->received_bytes()); server_->SendData(total - client_->received_bytes()); Receive(total); // Receive() is cumulative } // Do a first connection so we can do 0-RTT on the second one. void TlsConnectTestBase::SetupForZeroRtt() { // Force rollover of the anti-replay window. // If we don't do this, then all 0-RTT attempts will be rejected. RolloverAntiReplay(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. Connect(); SendReceive(); // Need to read so that we absorb the session ticket. CheckKeys(); Reset(); StartConnect(); } // Do a first connection so we can do resumption void TlsConnectTestBase::SetupForResume() { EnsureTlsSetup(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); Connect(); SendReceive(); // Need to read so that we absorb the session ticket. 
CheckKeys(); Reset(); } void TlsConnectTestBase::ZeroRttSendReceive( bool expect_writable, bool expect_readable, std::function post_clienthello_check) { const char* k0RttData = "ABCDEF"; const PRInt32 k0RttDataLen = static_cast(strlen(k0RttData)); client_->Handshake(); // Send ClientHello. if (post_clienthello_check) { if (!post_clienthello_check()) return; } PRInt32 rv = PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. if (expect_writable) { EXPECT_EQ(k0RttDataLen, rv); } else { EXPECT_EQ(SECFailure, rv); } server_->Handshake(); // Consume ClientHello std::vector buf(k0RttDataLen); rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read if (expect_readable) { std::cerr << "0-RTT read " << rv << " bytes\n"; EXPECT_EQ(k0RttDataLen, rv); } else { EXPECT_EQ(SECFailure, rv); EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()) << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); } // Do a second read. This should fail. rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); EXPECT_EQ(SECFailure, rv); EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); } void TlsConnectTestBase::Receive(size_t amount) { WAIT_(client_->received_bytes() == amount && server_->received_bytes() == amount, 2000); ASSERT_EQ(amount, client_->received_bytes()); ASSERT_EQ(amount, server_->received_bytes()); } void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) { expect_extended_master_secret_ = expected; } void TlsConnectTestBase::CheckExtendedMasterSecret() { client_->CheckExtendedMasterSecret(expect_extended_master_secret_); server_->CheckExtendedMasterSecret(expect_extended_master_secret_); } void TlsConnectTestBase::ExpectEarlyDataAccepted(bool expected) { expect_early_data_accepted_ = expected; } void TlsConnectTestBase::CheckEarlyDataAccepted() { client_->CheckEarlyDataAccepted(expect_early_data_accepted_); server_->CheckEarlyDataAccepted(expect_early_data_accepted_); } void TlsConnectTestBase::DisableECDHEServerKeyReuse() { 
server_->DisableECDHEServerKeyReuse(); } void TlsConnectTestBase::SkipVersionChecks() { skip_version_checks_ = true; client_->SkipVersionChecks(); server_->SkipVersionChecks(); } // Shift the DTLS timers, to the minimum time necessary to let the next timer // run on either client or server. This allows tests to skip waiting without // having timers run out of order. void TlsConnectTestBase::ShiftDtlsTimers() { PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT; PRIntervalTime time; SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); if (rv == SECSuccess) { time_shift = time; } rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); if (rv == SECSuccess && (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) { time_shift = time; } if (time_shift != PR_INTERVAL_NO_TIMEOUT) { AdvanceTime(PR_IntervalToMicroseconds(time_shift)); EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift)); EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift)); } } void TlsConnectTestBase::AdvanceTime(PRTime time_shift) { now_ += time_shift; } // Advance time by a full anti-replay window. 
void TlsConnectTestBase::RolloverAntiReplay() { AdvanceTime(kAntiReplayWindow); } TlsConnectGeneric::TlsConnectGeneric() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} TlsConnectPre12::TlsConnectPre12() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} TlsConnectTls12::TlsConnectTls12() : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_2) {} TlsConnectTls12Plus::TlsConnectTls12Plus() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} TlsConnectTls13::TlsConnectTls13() : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} TlsConnectGenericResumption::TlsConnectGenericResumption() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())), external_cache_(std::get<2>(GetParam())) {} TlsConnectTls13ResumptionToken::TlsConnectTls13ResumptionToken() : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} void TlsKeyExchangeTest::EnsureKeyShareSetup() { EnsureTlsSetup(); groups_capture_ = std::make_shared(client_, ssl_supported_groups_xtn); shares_capture_ = std::make_shared(client_, ssl_tls13_key_share_xtn); shares_capture2_ = std::make_shared( client_, ssl_tls13_key_share_xtn, true); std::vector> captures = { groups_capture_, shares_capture_, shares_capture2_}; client_->SetFilter(std::make_shared(captures)); capture_hrr_ = MakeTlsFilter( server_, kTlsHandshakeHelloRetryRequest); } void TlsKeyExchangeTest::ConfigNamedGroups( const std::vector& groups) { client_->ConfigNamedGroups(groups); server_->ConfigNamedGroups(groups); } std::vector TlsKeyExchangeTest::GetGroupDetails( const std::shared_ptr& capture) { EXPECT_TRUE(capture->captured()); const DataBuffer& ext = capture->extension(); uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast(tmp)); EXPECT_TRUE(ext.len() % 2 == 0); 
std::vector groups; for (size_t i = 1; i < ext.len() / 2; i += 1) { EXPECT_TRUE(ext.Read(2 * i, 2, &tmp)); groups.push_back(static_cast(tmp)); } return groups; } std::vector TlsKeyExchangeTest::GetShareDetails( const std::shared_ptr& capture) { EXPECT_TRUE(capture->captured()); const DataBuffer& ext = capture->extension(); uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast(tmp)); std::vector shares; size_t i = 2; while (i < ext.len()) { EXPECT_TRUE(ext.Read(i, 2, &tmp)); shares.push_back(static_cast(tmp)); EXPECT_TRUE(ext.Read(i + 2, 2, &tmp)); i += 4 + tmp; } EXPECT_EQ(ext.len(), i); return shares; } void TlsKeyExchangeTest::CheckKEXDetails( const std::vector& expected_groups, const std::vector& expected_shares, bool expect_hrr) { std::vector groups = GetGroupDetails(groups_capture_); EXPECT_EQ(expected_groups, groups); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { ASSERT_LT(0U, expected_shares.size()); std::vector shares = GetShareDetails(shares_capture_); EXPECT_EQ(expected_shares, shares); } else { EXPECT_FALSE(shares_capture_->captured()); } EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0); } void TlsKeyExchangeTest::CheckKEXDetails( const std::vector& expected_groups, const std::vector& expected_shares) { CheckKEXDetails(expected_groups, expected_shares, false); } void TlsKeyExchangeTest::CheckKEXDetails( const std::vector& expected_groups, const std::vector& expected_shares, SSLNamedGroup expected_share2) { CheckKEXDetails(expected_groups, expected_shares, true); for (auto it : expected_shares) { EXPECT_NE(expected_share2, it); } std::vector expected_shares2 = {expected_share2}; EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_)); } } // namespace nss_test