diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_connect.cc | 181 |
1 files changed, 118 insertions, 63 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc index d02549954..c8de5a1fe 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.cc +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -13,23 +13,27 @@ extern "C" { #include "databuffer.h" #include "gtest_utils.h" +#include "scoped_ptrs.h" #include "sslproto.h" extern std::string g_working_dir_path; namespace nss_test { -static const std::string kTlsModesStreamArr[] = {"TLS"}; -::testing::internal::ParamGenerator<std::string> - TlsConnectTestBase::kTlsModesStream = - ::testing::ValuesIn(kTlsModesStreamArr); -static const std::string kTlsModesDatagramArr[] = {"DTLS"}; -::testing::internal::ParamGenerator<std::string> - TlsConnectTestBase::kTlsModesDatagram = - ::testing::ValuesIn(kTlsModesDatagramArr); -static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"}; -::testing::internal::ParamGenerator<std::string> - TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr); +static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream}; +::testing::internal::ParamGenerator<SSLProtocolVariant> + TlsConnectTestBase::kTlsVariantsStream = + ::testing::ValuesIn(kTlsVariantsStreamArr); +static const SSLProtocolVariant kTlsVariantsDatagramArr[] = { + ssl_variant_datagram}; +::testing::internal::ParamGenerator<SSLProtocolVariant> + TlsConnectTestBase::kTlsVariantsDatagram = + ::testing::ValuesIn(kTlsVariantsDatagramArr); +static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream, + ssl_variant_datagram}; +::testing::internal::ParamGenerator<SSLProtocolVariant> + TlsConnectTestBase::kTlsVariantsAll = + ::testing::ValuesIn(kTlsVariantsAllArr); static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0}; ::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 = @@ -99,30 +103,29 @@ std::string VersionString(uint16_t version) { } } -TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version) - : mode_(mode), - client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_)), - server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_)), +TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, + uint16_t version) + : variant_(variant), + client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)), + server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)), client_model_(nullptr), server_model_(nullptr), version_(version), expected_resumption_mode_(RESUME_NONE), session_ids_(), expect_extended_master_secret_(false), - expect_early_data_accepted_(false) { + expect_early_data_accepted_(false), + skip_version_checks_(false) { std::string v; - if (mode_ == DGRAM && version_ == SSL_LIBRARY_VERSION_TLS_1_1) { + if (variant_ == ssl_variant_datagram && + version_ == SSL_LIBRARY_VERSION_TLS_1_1) { v = "1.0"; } else { v = VersionString(version_); } - std::cerr << "Version: " << mode_ << " " << v << std::endl; + std::cerr << "Version: " << variant_ << " " << v << std::endl; } -TlsConnectTestBase::TlsConnectTestBase(const std::string& mode, - uint16_t version) - : TlsConnectTestBase(TlsConnectTestBase::ToMode(mode), version) {} - TlsConnectTestBase::~TlsConnectTestBase() {} // Check the group of each of the supported groups @@ -166,35 +169,29 @@ void TlsConnectTestBase::ClearStats() { void TlsConnectTestBase::ClearServerCache() { SSL_ShutdownServerSessionIDCache(); - SSLInt_ClearSessionTicketKey(); + SSLInt_ClearSelfEncryptKey(); SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); } void TlsConnectTestBase::SetUp() { SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); - SSLInt_ClearSessionTicketKey(); + SSLInt_ClearSelfEncryptKey(); + SSLInt_SetTicketLifetime(30); + SSLInt_SetMaxEarlyDataSize(1024); ClearStats(); Init(); } void TlsConnectTestBase::TearDown() { - delete client_; - delete server_; - if (client_model_) { - ASSERT_NE(server_model_, nullptr); - delete client_model_; - delete server_model_; - } + client_ = nullptr; + server_ = nullptr; SSL_ClearSessionCache(); - SSLInt_ClearSessionTicketKey(); + SSLInt_ClearSelfEncryptKey(); SSL_ShutdownServerSessionIDCache(); } void TlsConnectTestBase::Init() { - EXPECT_TRUE(client_->Init()); - EXPECT_TRUE(server_->Init()); - client_->SetPeer(server_); server_->SetPeer(client_); @@ -212,11 +209,12 @@ void TlsConnectTestBase::Reset() { void TlsConnectTestBase::Reset(const std::string& server_name, const std::string& client_name) { - delete client_; - delete server_; - - client_ = new TlsAgent(client_name, TlsAgent::CLIENT, mode_); - server_ = new TlsAgent(server_name, TlsAgent::SERVER, mode_); + client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_)); + server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_)); + if (skip_version_checks_) { + client_->SkipVersionChecks(); + server_->SkipVersionChecks(); + } Init(); } @@ -276,10 +274,12 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { } void TlsConnectTestBase::CheckConnected() { - // Check the version is as expected EXPECT_EQ(client_->version(), server_->version()); - EXPECT_EQ(std::min(client_->max_version(), server_->max_version()), - client_->version()); + if (!skip_version_checks_) { + // Check the version is as expected + EXPECT_EQ(std::min(client_->max_version(), server_->max_version()), + client_->version()); + } EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); @@ -345,6 +345,13 @@ void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, scheme = ssl_sig_none; break; case ssl_auth_rsa_sign: + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) { + scheme = ssl_sig_rsa_pss_sha256; + } else { + scheme = ssl_sig_rsa_pkcs1_sha256; + } + break; + case ssl_auth_rsa_pss: scheme = ssl_sig_rsa_pss_sha256; break; case ssl_auth_ecdsa: @@ -373,7 +380,36 @@ void TlsConnectTestBase::ConnectExpectFail() { ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); } +void TlsConnectTestBase::ExpectAlert(std::shared_ptr<TlsAgent>& sender, + uint8_t alert) { + EnsureTlsSetup(); + auto receiver = (sender == client_) ? server_ : client_; + sender->ExpectSendAlert(alert); + receiver->ExpectReceiveAlert(alert); +} + +void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, + uint8_t alert) { + ExpectAlert(sender, alert); + ConnectExpectFail(); +} + +void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) { + server_->StartConnect(); + client_->StartConnect(); + client_->SetServerKeyBits(server_->server_key_bits()); + client_->Handshake(); + server_->Handshake(); + + auto failing_agent = server_; + if (failing_side == TlsAgent::CLIENT) { + failing_agent = client_; + } + ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000); +} + void TlsConnectTestBase::ConfigureVersion(uint16_t version) { + version_ = version; client_->SetVersionRange(version, version); server_->SetVersionRange(version, version); } @@ -424,10 +460,16 @@ void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, client_->ConfigureSessionCache(client); server_->ConfigureSessionCache(server); if ((server & RESUME_TICKET) != 0) { - // This is an abomination. NSS encrypts session tickets with the server's - // RSA public key. That means we need the server to have an RSA certificate - // even if it won't be used for the connection. - server_->ConfigServerCert(TlsAgent::kServerRsaDecrypt); + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey privKey; + ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, + &privKey)); + + ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); + ASSERT_TRUE(pubKey); + + EXPECT_EQ(SECSuccess, + SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); } } @@ -472,13 +514,15 @@ void TlsConnectTestBase::EnsureModelSockets() { // Make sure models agents are available. if (!client_model_) { ASSERT_EQ(server_model_, nullptr); - client_model_ = new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_); - server_model_ = new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_); + client_model_.reset( + new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)); + server_model_.reset( + new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)); + if (skip_version_checks_) { + client_model_->SkipVersionChecks(); + server_model_->SkipVersionChecks(); + } } - - // Initialise agents. - ASSERT_TRUE(client_model_->Init()); - ASSERT_TRUE(server_model_->Init()); } void TlsConnectTestBase::CheckAlpn(const std::string& val) { @@ -540,6 +584,10 @@ void TlsConnectTestBase::ZeroRttSendReceive( const char* k0RttData = "ABCDEF"; const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + if (expect_writable && expect_readable) { + ExpectAlert(client_, kTlsAlertEndOfEarlyData); + } + client_->Handshake(); // Send ClientHello. if (post_clienthello_check) { if (!post_clienthello_check()) return; @@ -599,6 +647,12 @@ void TlsConnectTestBase::DisableECDHEServerKeyReuse() { server_->DisableECDHEServerKeyReuse(); } +void TlsConnectTestBase::SkipVersionChecks() { + skip_version_checks_ = true; + client_->SkipVersionChecks(); + server_->SkipVersionChecks(); +} + TlsConnectGeneric::TlsConnectGeneric() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} @@ -616,16 +670,17 @@ TlsConnectTls13::TlsConnectTls13() void TlsKeyExchangeTest::EnsureKeyShareSetup() { EnsureTlsSetup(); - groups_capture_ = new TlsExtensionCapture(ssl_supported_groups_xtn); - shares_capture_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn); - shares_capture2_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn, true); - std::vector<PacketFilter*> captures; - captures.push_back(groups_capture_); - captures.push_back(shares_capture_); - captures.push_back(shares_capture2_); - client_->SetPacketFilter(new ChainedPacketFilter(captures)); - capture_hrr_ = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest); + groups_capture_ = + std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn); + shares_capture_ = + std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + shares_capture2_ = + std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn, true); + std::vector<std::shared_ptr<PacketFilter>> captures = { + groups_capture_, shares_capture_, shares_capture2_}; + client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures)); + capture_hrr_ = std::make_shared<TlsInspectorRecordHandshakeMessage>( + kTlsHandshakeHelloRetryRequest); server_->SetPacketFilter(capture_hrr_); } |