diff options
author | Moonchild <mcwerewolf@gmail.com> | 2018-06-12 00:58:35 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-12 00:58:35 +0200 |
commit | b0f5f9bc6bb3c8b5ab7b5120dbf7ec48f8445387 (patch) | |
tree | 40d946c5ff23b3c0c09558f478cc68e87cc71448 /security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc | |
parent | b1d82a62259c6888ea6f3f71f3e0973ea4b4e85e (diff) | |
parent | 505a561549b5226fd3c7905eaa61fe787dfad243 (diff) | |
download | UXP-b0f5f9bc6bb3c8b5ab7b5120dbf7ec48f8445387.tar UXP-b0f5f9bc6bb3c8b5ab7b5120dbf7ec48f8445387.tar.gz UXP-b0f5f9bc6bb3c8b5ab7b5120dbf7ec48f8445387.tar.lz UXP-b0f5f9bc6bb3c8b5ab7b5120dbf7ec48f8445387.tar.xz UXP-b0f5f9bc6bb3c8b5ab7b5120dbf7ec48f8445387.zip |
Merge pull request #477 from JustOff/PR_nss-3.36
Update NSS/NSPR to 3.36.4/4.19
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc | 222 |
1 files changed, 118 insertions, 104 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc index 4142ab07a..0453dabdb 100644 --- a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -19,8 +19,9 @@ namespace nss_test { class TlsExtensionTruncator : public TlsExtensionFilter { public: - TlsExtensionTruncator(uint16_t extension, size_t length) - : extension_(extension), length_(length) {} + TlsExtensionTruncator(const std::shared_ptr<TlsAgent>& agent, + uint16_t extension, size_t length) + : TlsExtensionFilter(agent), extension_(extension), length_(length) {} virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { @@ -42,8 +43,9 @@ class TlsExtensionTruncator : public TlsExtensionFilter { class TlsExtensionDamager : public TlsExtensionFilter { public: - TlsExtensionDamager(uint16_t extension, size_t index) - : extension_(extension), index_(index) {} + TlsExtensionDamager(const std::shared_ptr<TlsAgent>& agent, + uint16_t extension, size_t index) + : TlsExtensionFilter(agent), extension_(extension), index_(index) {} virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { @@ -63,8 +65,11 @@ class TlsExtensionDamager : public TlsExtensionFilter { class TlsExtensionAppender : public TlsHandshakeFilter { public: - TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data) - : TlsHandshakeFilter({handshake_type}), extension_(ext), data_(data) {} + TlsExtensionAppender(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type, uint16_t ext, DataBuffer& data) + : TlsHandshakeFilter(agent, {handshake_type}), + extension_(ext), + data_(data) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -124,13 +129,13 @@ class TlsExtensionTestBase : public TlsConnectTestBase { void ClientHelloErrorTest(std::shared_ptr<PacketFilter> filter, uint8_t desc = kTlsAlertDecodeError) { - client_->SetPacketFilter(filter); + client_->SetFilter(filter); ConnectExpectAlert(server_, desc); } void ServerHelloErrorTest(std::shared_ptr<PacketFilter> filter, uint8_t desc = kTlsAlertDecodeError) { - server_->SetPacketFilter(filter); + server_->SetFilter(filter); ConnectExpectAlert(client_, desc); } @@ -156,7 +161,7 @@ class TlsExtensionTestBase : public TlsConnectTestBase { StartConnect(); client_->Handshake(); // Send ClientHello server_->Handshake(); // Send HRR. - client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(type)); + MakeTlsFilter<TlsExtensionDropper>(client_, type); Handshake(); client_->CheckErrorCode(client_error); server_->CheckErrorCode(server_error); @@ -197,8 +202,8 @@ class TlsExtensionTest13 void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) { DataBuffer versions_buf(buf, len); - client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>( - ssl_tls13_supported_versions_xtn, versions_buf)); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_supported_versions_xtn, versions_buf); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -209,8 +214,8 @@ class TlsExtensionTest13 size_t index = versions_buf.Write(0, 2, 1); versions_buf.Write(index, version, 2); - client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>( - ssl_tls13_supported_versions_xtn, versions_buf)); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_supported_versions_xtn, versions_buf); ConnectExpectFail(); } }; @@ -241,26 +246,26 @@ class TlsExtensionTestPre13 : public TlsExtensionTestBase, TEST_P(TlsExtensionTestGeneric, DamageSniLength) { ClientHelloErrorTest( - std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 1)); + std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 1)); } TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) { ClientHelloErrorTest( - std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 4)); + std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 4)); } TEST_P(TlsExtensionTestGeneric, TruncateSni) { ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_server_name_xtn, 7)); + std::make_shared<TlsExtensionTruncator>(client_, ssl_server_name_xtn, 7)); } // A valid extension that appears twice will be reported as unsupported. TEST_P(TlsExtensionTestGeneric, RepeatSni) { DataBuffer extension; InitSimpleSni(&extension); - ClientHelloErrorTest( - std::make_shared<TlsExtensionInjector>(ssl_server_name_xtn, extension), - kTlsAlertIllegalParameter); + ClientHelloErrorTest(std::make_shared<TlsExtensionInjector>( + client_, ssl_server_name_xtn, extension), + kTlsAlertIllegalParameter); } // An SNI entry with zero length is considered invalid (strangely, not if it is @@ -272,23 +277,23 @@ TEST_P(TlsExtensionTestGeneric, BadSni) { extension.Allocate(simple.len() + 3); extension.Write(0, static_cast<uint32_t>(0), 3); extension.Write(3, simple); - ClientHelloErrorTest( - std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_server_name_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, EmptySni) { DataBuffer extension; extension.Allocate(2); extension.Write(0, static_cast<uint32_t>(0), 2); - ClientHelloErrorTest( - std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_server_name_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) { EnableAlpn(); DataBuffer extension; ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension), + client_, ssl_app_layer_protocol_xtn, extension), kTlsAlertIllegalParameter); } @@ -299,21 +304,21 @@ TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension), + client_, ssl_app_layer_protocol_xtn, extension), kTlsAlertNoApplicationProtocol); } TEST_P(TlsExtensionTestGeneric, OneByteAlpn) { EnableAlpn(); - ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 1)); + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_app_layer_protocol_xtn, 1)); } TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) { EnableAlpn(); // This will leave the length of the second entry, but no value. - ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 5)); + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_app_layer_protocol_xtn, 5)); } TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { @@ -321,7 +326,7 @@ TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { const uint8_t val[] = {0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + client_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, AlpnMismatch) { @@ -340,7 +345,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) { @@ -348,7 +353,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) { const uint8_t val[] = {0x00, 0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) { @@ -356,7 +361,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) { @@ -364,7 +369,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) { const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) { @@ -372,7 +377,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) { const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) { @@ -380,7 +385,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) { const uint8_t val[] = {0x00, 0x02, 0x99, 0x61}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) { @@ -388,55 +393,64 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) { const uint8_t val[] = {0x00, 0x02, 0x01, 0x67}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension), + server_, ssl_app_layer_protocol_xtn, extension), kTlsAlertIllegalParameter); } TEST_P(TlsExtensionTestDtls, SrtpShort) { EnableSrtp(); ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_use_srtp_xtn, 3)); + std::make_shared<TlsExtensionTruncator>(client_, ssl_use_srtp_xtn, 3)); } TEST_P(TlsExtensionTestDtls, SrtpOdd) { EnableSrtp(); const uint8_t val[] = {0x00, 0x01, 0xff, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - std::make_shared<TlsExtensionReplacer>(ssl_use_srtp_xtn, extension)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_use_srtp_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) { const uint8_t val[] = {0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension), + kTlsAlertHandshakeFailure); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsNoOverlap) { + const uint8_t val[] = {0x00, 0x02, 0xff, 0xff}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_signature_algorithms_xtn, extension), + kTlsAlertHandshakeFailure); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) { const uint8_t val[] = {0x00, 0x01, 0x04}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) { ClientHelloErrorTest( - std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn), + std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn), version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError : kTlsAlertMissingExtension); } @@ -445,75 +459,74 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) { const uint8_t val[] = {0x00, 0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) { const uint8_t val[] = {0x09, 0x99, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) { const uint8_t val[] = {0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) { const uint8_t val[] = {0x99, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) { const uint8_t val[] = {0x01, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) { const uint8_t val[] = {0x99}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) { const uint8_t val[] = {0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } // The extension has to contain a length. TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) { DataBuffer extension; ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } // This only works on TLS 1.2, since it relies on static RSA; otherwise libssl // picks the wrong cipher suite. TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { - const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_sha512, - ssl_sig_rsa_pss_sha384}; + const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_rsae_sha512, + ssl_sig_rsa_pss_rsae_sha384}; auto capture = - std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn); + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn); client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes)); - client_->SetPacketFilter(capture); EnableOnlyStaticRsaCiphers(); Connect(); @@ -531,9 +544,9 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { // Temporary test to verify that we choke on an empty ClientKeyShare. // This test will fail when we implement HelloRetryRequest. TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { - ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_tls13_key_share_xtn, 2), - kTlsAlertHandshakeFailure); + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_tls13_key_share_xtn, 2), + kTlsAlertHandshakeFailure); } // These tests only work in stream mode because the client sends a @@ -542,8 +555,7 @@ TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { // packet gets dropped. TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) { EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn)); + MakeTlsFilter<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn); client_->ExpectSendAlert(kTlsAlertMissingExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -563,8 +575,7 @@ TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) { 0x02}; DataBuffer buf(key_share, sizeof(key_share)); EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf)); + MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf); client_->ExpectSendAlert(kTlsAlertIllegalParameter); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -585,8 +596,7 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { 0x02}; DataBuffer buf(key_share, sizeof(key_share)); EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf)); + MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf); client_->ExpectSendAlert(kTlsAlertMissingExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -597,8 +607,8 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) { SetupForResume(); DataBuffer empty; - server_->SetPacketFilter(std::make_shared<TlsExtensionInjector>( - ssl_signature_algorithms_xtn, empty)); + MakeTlsFilter<TlsExtensionInjector>(server_, ssl_signature_algorithms_xtn, + empty); client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -618,8 +628,12 @@ typedef std::function<void(TlsPreSharedKeyReplacer*)> class TlsPreSharedKeyReplacer : public TlsExtensionFilter { public: - TlsPreSharedKeyReplacer(TlsPreSharedKeyReplacerFunc function) - : identities_(), binders_(), function_(function) {} + TlsPreSharedKeyReplacer(const std::shared_ptr<TlsAgent>& agent, + TlsPreSharedKeyReplacerFunc function) + : TlsExtensionFilter(agent), + identities_(), + binders_(), + function_(function) {} static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size, const std::unique_ptr<DataBuffer>& replace, @@ -733,8 +747,10 @@ class TlsPreSharedKeyReplacer : public TlsExtensionFilter { TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) { SetupForResume(); - client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([]( - TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); })); + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->identities_[0].identity.Truncate(0); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -744,10 +760,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) { TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1); - })); + }); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -757,10 +773,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Write(r->binders_[0].len(), 0xff, 1); - })); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -770,8 +786,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) { SetupForResume(); - client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>( - [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); })); + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -782,11 +798,11 @@ TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) { TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->identities_.push_back(r->identities_[0]); r->binders_.push_back(r->binders_[0]); - })); + }); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -797,10 +813,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) { TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->identities_.push_back(r->identities_[0]); - })); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -809,8 +825,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) { SetupForResume(); - client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([]( - TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); })); + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->binders_.push_back(r->binders_[0]); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -822,8 +840,8 @@ TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) { const uint8_t empty_buf[] = {0}; DataBuffer empty(empty_buf, 0); // Inject an unused extension after the PSK extension. - client_->SetPacketFilter(std::make_shared<TlsExtensionAppender>( - kTlsHandshakeClientHello, 0xffff, empty)); + MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, 0xffff, + empty); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -833,8 +851,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) { SetupForResume(); DataBuffer empty; - client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>( - ssl_tls13_psk_key_exchange_modes_xtn)); + MakeTlsFilter<TlsExtensionDropper>(client_, + ssl_tls13_psk_key_exchange_modes_xtn); ConnectExpectAlert(server_, kTlsAlertMissingExtension); client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES); @@ -849,8 +867,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { kTls13PskKe}; DataBuffer modes(ke_modes, sizeof(ke_modes)); - client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>( - ssl_tls13_psk_key_exchange_modes_xtn, modes)); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_psk_key_exchange_modes_xtn, modes); client_->ExpectSendAlert(kTlsAlertBadRecordMac); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -860,9 +878,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) { ConfigureSessionCache(RESUME_NONE, RESUME_NONE); - auto capture = std::make_shared<TlsExtensionCapture>( - ssl_tls13_psk_key_exchange_modes_xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_psk_key_exchange_modes_xtn); Connect(); EXPECT_FALSE(capture->captured()); } @@ -958,11 +975,9 @@ class TlsBogusExtensionTest : public TlsConnectTestBase, static uint8_t empty_buf[1] = {0}; DataBuffer empty(empty_buf, 0); auto filter = - std::make_shared<TlsExtensionAppender>(message, extension, empty); + MakeTlsFilter<TlsExtensionAppender>(server_, message, extension, empty); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { - server_->SetTlsRecordFilter(filter); - } else { - server_->SetPacketFilter(filter); + filter->EnableDecryption(); } } @@ -1078,8 +1093,7 @@ TEST_P(TlsConnectStream, IncludePadding) { SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name); EXPECT_EQ(SECSuccess, rv); - auto capture = std::make_shared<TlsExtensionCapture>(ssl_padding_xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, ssl_padding_xtn); client_->StartConnect(); client_->Handshake(); EXPECT_TRUE(capture->captured()); |