diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc | 134 |
1 files changed, 58 insertions, 76 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc index 97943303a..cdafa7a84 100644 --- a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc @@ -24,7 +24,7 @@ TEST_P(TlsConnectGeneric, ConnectDhe) { EnableOnlyDheCiphers(); Connect(); CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { @@ -32,12 +32,12 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { client_->ConfigNamedGroups(kAllDHEGroups); auto groups_capture = - std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn); + std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); auto shares_capture = - std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture, shares_capture}; - client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures)); + client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures)); Connect(); @@ -59,15 +59,14 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { TEST_P(TlsConnectGeneric, ConnectFfdheClient) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); auto groups_capture = - std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn); + std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); auto shares_capture = - std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture, shares_capture}; - client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures)); + client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures)); Connect(); @@ -90,8 +89,7 @@ TEST_P(TlsConnectGeneric, ConnectFfdheClient) { // because the client automatically sends the supported groups extension. TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + server_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { Connect(); @@ -105,14 +103,11 @@ TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) { class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter { public: - TlsDheServerKeyExchangeDamager() {} + TlsDheServerKeyExchangeDamager(const std::shared_ptr<TlsAgent>& agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {} virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - // Damage the first octet of dh_p. Anything other than the known prime will // be rejected as "weak" when we have SSL_REQUIRE_DH_NAMED_GROUPS enabled. *output = input; @@ -126,9 +121,8 @@ class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter { // the signature until everything else has been checked. TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); - server_->SetPacketFilter(std::make_shared<TlsDheServerKeyExchangeDamager>()); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); + MakeTlsFilter<TlsDheServerKeyExchangeDamager>(server_); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); @@ -147,7 +141,9 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter { kYZeroPad }; - TlsDheSkeChangeY(ChangeYTo change) : change_Y_(change) {} + TlsDheSkeChangeY(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type, ChangeYTo change) + : TlsHandshakeFilter(agent, {handshake_type}), change_Y_(change) {} protected: void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset, @@ -212,8 +208,11 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter { class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { public: - TlsDheSkeChangeYServer(ChangeYTo change, bool modify) - : TlsDheSkeChangeY(change), modify_(modify), p_() {} + TlsDheSkeChangeYServer(const std::shared_ptr<TlsAgent>& agent, + ChangeYTo change, bool modify) + : TlsDheSkeChangeY(agent, kTlsHandshakeServerKeyExchange, change), + modify_(modify), + p_() {} const DataBuffer& prime() const { return p_; } @@ -221,10 +220,6 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - size_t offset = 2; // Read dh_p uint32_t dh_len = 0; @@ -252,18 +247,15 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { class TlsDheSkeChangeYClient : public TlsDheSkeChangeY { public: TlsDheSkeChangeYClient( - ChangeYTo change, + const std::shared_ptr<TlsAgent>& agent, ChangeYTo change, std::shared_ptr<const TlsDheSkeChangeYServer> server_filter) - : TlsDheSkeChangeY(change), server_filter_(server_filter) {} + : TlsDheSkeChangeY(agent, kTlsHandshakeClientKeyExchange, change), + server_filter_(server_filter) {} protected: virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeClientKeyExchange) { - return KEEP; - } - ChangeY(input, output, 0, server_filter_->prime()); return CHANGE; } @@ -289,12 +281,10 @@ class TlsDamageDHYTest TEST_P(TlsDamageDHYTest, DamageServerY) { EnableOnlyDheCiphers(); if (std::get<3>(GetParam())) { - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); } TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); - server_->SetPacketFilter( - std::make_shared<TlsDheSkeChangeYServer>(change, true)); + MakeTlsFilter<TlsDheSkeChangeYServer>(server_, change, true); if (change == TlsDheSkeChangeY::kYZeroPad) { ExpectAlert(client_, kTlsAlertDecryptError); @@ -320,18 +310,15 @@ TEST_P(TlsDamageDHYTest, DamageServerY) { TEST_P(TlsDamageDHYTest, DamageClientY) { EnableOnlyDheCiphers(); if (std::get<3>(GetParam())) { - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); } // The filter on the server is required to capture the prime. - auto server_filter = - std::make_shared<TlsDheSkeChangeYServer>(TlsDheSkeChangeY::kYZero, false); - server_->SetPacketFilter(server_filter); + auto server_filter = MakeTlsFilter<TlsDheSkeChangeYServer>( + server_, TlsDheSkeChangeY::kYZero, false); // The client filter does the damage. TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); - client_->SetPacketFilter( - std::make_shared<TlsDheSkeChangeYClient>(change, server_filter)); + MakeTlsFilter<TlsDheSkeChangeYClient>(client_, change, server_filter); if (change == TlsDheSkeChangeY::kYZeroPad) { ExpectAlert(server_, kTlsAlertDecryptError); @@ -370,13 +357,12 @@ INSTANTIATE_TEST_CASE_P( class TlsDheSkeMakePEven : public TlsHandshakeFilter { public: + TlsDheSkeMakePEven(const std::shared_ptr<TlsAgent>& agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {} + virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - // Find the end of dh_p uint32_t dh_len = 0; EXPECT_TRUE(input.Read(0, 2, &dh_len)); @@ -394,7 +380,7 @@ class TlsDheSkeMakePEven : public TlsHandshakeFilter { // Even without requiring named groups, an even value for p is bad news. TEST_P(TlsConnectGenericPre13, MakeDhePEven) { EnableOnlyDheCiphers(); - server_->SetPacketFilter(std::make_shared<TlsDheSkeMakePEven>()); + MakeTlsFilter<TlsDheSkeMakePEven>(server_); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); @@ -404,13 +390,12 @@ TEST_P(TlsConnectGenericPre13, MakeDhePEven) { class TlsDheSkeZeroPadP : public TlsHandshakeFilter { public: + TlsDheSkeZeroPadP(const std::shared_ptr<TlsAgent>& agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {} + virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - *output = input; uint32_t dh_len = 0; EXPECT_TRUE(input.Read(0, 2, &dh_len)); @@ -425,7 +410,7 @@ class TlsDheSkeZeroPadP : public TlsHandshakeFilter { // Zero padding only causes signature failure. TEST_P(TlsConnectGenericPre13, PadDheP) { EnableOnlyDheCiphers(); - server_->SetPacketFilter(std::make_shared<TlsDheSkeZeroPadP>()); + MakeTlsFilter<TlsDheSkeZeroPadP>(server_); ConnectExpectAlert(client_, kTlsAlertDecryptError); @@ -445,8 +430,7 @@ TEST_P(TlsConnectGenericPre13, PadDheP) { // Note: This test case can take ages to generate the weak DH key. TEST_P(TlsConnectGenericPre13, WeakDHGroup) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); EXPECT_EQ(SECSuccess, SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE)); @@ -474,7 +458,7 @@ TEST_P(TlsConnectGenericPre13, NamedGroupMismatchPre13) { Connect(); CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_custom, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } // Same test but for TLS 1.3. This has to fail. @@ -496,8 +480,7 @@ TEST_P(TlsConnectTls13, NamedGroupMismatch13) { // custom group in contrast to the previous test. TEST_P(TlsConnectGenericPre13, RequireNamedGroupsMismatchPre13) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072}; static const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048}; @@ -519,14 +502,13 @@ TEST_P(TlsConnectGenericPre13, PreferredFfdhe) { Connect(); client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); - client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); - server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); + client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); + server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); } TEST_P(TlsConnectGenericPre13, MismatchDHE) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); static const SSLDHEGroupType serverGroups[] = {ssl_ff_dhe_3072_group}; EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), serverGroups, PR_ARRAY_SIZE(serverGroups))); @@ -544,37 +526,37 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); Connect(); SendReceive(); // Need to read so that we absorb the session ticket. - CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); EnableOnlyDheCiphers(); auto clientCapture = - std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn); - client_->SetPacketFilter(clientCapture); + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); auto serverCapture = - std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn); - server_->SetPacketFilter(serverCapture); + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_pre_shared_key_xtn); ExpectResumption(RESUME_TICKET); Connect(); - CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); ASSERT_LT(0UL, clientCapture->extension().len()); ASSERT_LT(0UL, serverCapture->extension().len()); } class TlsDheSkeChangeSignature : public TlsHandshakeFilter { public: - TlsDheSkeChangeSignature(uint16_t version, const uint8_t* data, size_t len) - : version_(version), data_(data), len_(len) {} + TlsDheSkeChangeSignature(const std::shared_ptr<TlsAgent>& agent, + uint16_t version, const uint8_t* data, size_t len) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}), + version_(version), + data_(data), + len_(len) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - TlsParser parser(input); EXPECT_TRUE(parser.SkipVariable(2)); // dh_p EXPECT_TRUE(parser.SkipVariable(2)); // dh_g @@ -615,8 +597,8 @@ TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) { const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ffdhe_2048}; client_->ConfigNamedGroups(client_groups); - server_->SetPacketFilter(std::make_shared<TlsDheSkeChangeSignature>( - version_, kBogusDheSignature, sizeof(kBogusDheSignature))); + MakeTlsFilter<TlsDheSkeChangeSignature>(server_, version_, kBogusDheSignature, + sizeof(kBogusDheSignature)); ConnectExpectAlert(client_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); |