diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc | 298 |
1 files changed, 129 insertions, 169 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc index d15139419..0453dabdb 100644 --- a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -19,8 +19,9 @@ namespace nss_test { class TlsExtensionTruncator : public TlsExtensionFilter { public: - TlsExtensionTruncator(uint16_t extension, size_t length) - : extension_(extension), length_(length) {} + TlsExtensionTruncator(const std::shared_ptr<TlsAgent>& agent, + uint16_t extension, size_t length) + : TlsExtensionFilter(agent), extension_(extension), length_(length) {} virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { @@ -42,8 +43,9 @@ class TlsExtensionTruncator : public TlsExtensionFilter { class TlsExtensionDamager : public TlsExtensionFilter { public: - TlsExtensionDamager(uint16_t extension, size_t index) - : extension_(extension), index_(index) {} + TlsExtensionDamager(const std::shared_ptr<TlsAgent>& agent, + uint16_t extension, size_t index) + : TlsExtensionFilter(agent), extension_(extension), index_(index) {} virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { @@ -61,60 +63,17 @@ class TlsExtensionDamager : public TlsExtensionFilter { size_t index_; }; -class TlsExtensionInjector : public TlsHandshakeFilter { - public: - TlsExtensionInjector(uint16_t ext, DataBuffer& data) - : extension_(ext), data_(data) {} - - virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, - const DataBuffer& input, - DataBuffer* output) { - TlsParser parser(input); - if (!TlsExtensionFilter::FindExtensions(&parser, header)) { - return KEEP; - } - size_t offset = parser.consumed(); - - *output = input; - - // Increase the size of the extensions. - uint16_t ext_len; - memcpy(&ext_len, output->data() + offset, sizeof(ext_len)); - ext_len = htons(ntohs(ext_len) + data_.len() + 4); - memcpy(output->data() + offset, &ext_len, sizeof(ext_len)); - - // Insert the extension type and length. - DataBuffer type_length; - type_length.Allocate(4); - type_length.Write(0, extension_, 2); - type_length.Write(2, data_.len(), 2); - output->Splice(type_length, offset + 2); - - // Insert the payload. - if (data_.len() > 0) { - output->Splice(data_, offset + 6); - } - - return CHANGE; - } - - private: - const uint16_t extension_; - const DataBuffer data_; -}; - class TlsExtensionAppender : public TlsHandshakeFilter { public: - TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data) - : handshake_type_(handshake_type), extension_(ext), data_(data) {} + TlsExtensionAppender(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type, uint16_t ext, DataBuffer& data) + : TlsHandshakeFilter(agent, {handshake_type}), + extension_(ext), + data_(data) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != handshake_type_) { - return KEEP; - } - TlsParser parser(input); if (!TlsExtensionFilter::FindExtensions(&parser, header)) { return KEEP; @@ -159,7 +118,6 @@ class TlsExtensionAppender : public TlsHandshakeFilter { return true; } - const uint8_t handshake_type_; const uint16_t extension_; const DataBuffer data_; }; @@ -171,13 +129,13 @@ class TlsExtensionTestBase : public TlsConnectTestBase { void ClientHelloErrorTest(std::shared_ptr<PacketFilter> filter, uint8_t desc = kTlsAlertDecodeError) { - client_->SetPacketFilter(filter); + client_->SetFilter(filter); ConnectExpectAlert(server_, desc); } void ServerHelloErrorTest(std::shared_ptr<PacketFilter> filter, uint8_t desc = kTlsAlertDecodeError) { - server_->SetPacketFilter(filter); + server_->SetFilter(filter); ConnectExpectAlert(client_, desc); } @@ -200,11 +158,10 @@ class TlsExtensionTestBase : public TlsConnectTestBase { client_->ConfigNamedGroups(client_groups); server_->ConfigNamedGroups(server_groups); EnsureTlsSetup(); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); // Send ClientHello server_->Handshake(); // Send HRR. - client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(type)); + MakeTlsFilter<TlsExtensionDropper>(client_, type); Handshake(); client_->CheckErrorCode(client_error); server_->CheckErrorCode(server_error); @@ -245,8 +202,8 @@ class TlsExtensionTest13 void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) { DataBuffer versions_buf(buf, len); - client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>( - ssl_tls13_supported_versions_xtn, versions_buf)); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_supported_versions_xtn, versions_buf); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -257,8 +214,8 @@ class TlsExtensionTest13 size_t index = versions_buf.Write(0, 2, 1); versions_buf.Write(index, version, 2); - client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>( - ssl_tls13_supported_versions_xtn, versions_buf)); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_supported_versions_xtn, versions_buf); ConnectExpectFail(); } }; @@ -289,26 +246,26 @@ class TlsExtensionTestPre13 : public TlsExtensionTestBase, TEST_P(TlsExtensionTestGeneric, DamageSniLength) { ClientHelloErrorTest( - std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 1)); + std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 1)); } TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) { ClientHelloErrorTest( - std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 4)); + std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 4)); } TEST_P(TlsExtensionTestGeneric, TruncateSni) { ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_server_name_xtn, 7)); + std::make_shared<TlsExtensionTruncator>(client_, ssl_server_name_xtn, 7)); } // A valid extension that appears twice will be reported as unsupported. TEST_P(TlsExtensionTestGeneric, RepeatSni) { DataBuffer extension; InitSimpleSni(&extension); - ClientHelloErrorTest( - std::make_shared<TlsExtensionInjector>(ssl_server_name_xtn, extension), - kTlsAlertIllegalParameter); + ClientHelloErrorTest(std::make_shared<TlsExtensionInjector>( + client_, ssl_server_name_xtn, extension), + kTlsAlertIllegalParameter); } // An SNI entry with zero length is considered invalid (strangely, not if it is @@ -320,23 +277,23 @@ TEST_P(TlsExtensionTestGeneric, BadSni) { extension.Allocate(simple.len() + 3); extension.Write(0, static_cast<uint32_t>(0), 3); extension.Write(3, simple); - ClientHelloErrorTest( - std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_server_name_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, EmptySni) { DataBuffer extension; extension.Allocate(2); extension.Write(0, static_cast<uint32_t>(0), 2); - ClientHelloErrorTest( - std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_server_name_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) { EnableAlpn(); DataBuffer extension; ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension), + client_, ssl_app_layer_protocol_xtn, extension), kTlsAlertIllegalParameter); } @@ -347,21 +304,21 @@ TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension), + client_, ssl_app_layer_protocol_xtn, extension), kTlsAlertNoApplicationProtocol); } TEST_P(TlsExtensionTestGeneric, OneByteAlpn) { EnableAlpn(); - ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 1)); + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_app_layer_protocol_xtn, 1)); } TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) { EnableAlpn(); // This will leave the length of the second entry, but no value. - ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 5)); + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_app_layer_protocol_xtn, 5)); } TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { @@ -369,7 +326,7 @@ TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { const uint8_t val[] = {0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + client_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, AlpnMismatch) { @@ -388,7 +345,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) { @@ -396,7 +353,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) { const uint8_t val[] = {0x00, 0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) { @@ -404,7 +361,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) { @@ -412,7 +369,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) { const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) { @@ -420,7 +377,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) { const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) { @@ -428,7 +385,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) { const uint8_t val[] = {0x00, 0x02, 0x99, 0x61}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) { @@ -436,55 +393,64 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) { const uint8_t val[] = {0x00, 0x02, 0x01, 0x67}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension), + server_, ssl_app_layer_protocol_xtn, extension), kTlsAlertIllegalParameter); } TEST_P(TlsExtensionTestDtls, SrtpShort) { EnableSrtp(); ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_use_srtp_xtn, 3)); + std::make_shared<TlsExtensionTruncator>(client_, ssl_use_srtp_xtn, 3)); } TEST_P(TlsExtensionTestDtls, SrtpOdd) { EnableSrtp(); const uint8_t val[] = {0x00, 0x01, 0xff, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - std::make_shared<TlsExtensionReplacer>(ssl_use_srtp_xtn, extension)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_use_srtp_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) { const uint8_t val[] = {0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension), + kTlsAlertHandshakeFailure); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsNoOverlap) { + const uint8_t val[] = {0x00, 0x02, 0xff, 0xff}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_signature_algorithms_xtn, extension), + kTlsAlertHandshakeFailure); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) { const uint8_t val[] = {0x00, 0x01, 0x04}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) { ClientHelloErrorTest( - std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn), + std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn), version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError : kTlsAlertMissingExtension); } @@ -493,75 +459,74 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) { const uint8_t val[] = {0x00, 0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) { const uint8_t val[] = {0x09, 0x99, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) { const uint8_t val[] = {0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) { const uint8_t val[] = {0x99, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) { const uint8_t val[] = {0x01, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) { const uint8_t val[] = {0x99}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) { const uint8_t val[] = {0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } // The extension has to contain a length. TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) { DataBuffer extension; ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } // This only works on TLS 1.2, since it relies on static RSA; otherwise libssl // picks the wrong cipher suite. TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { - const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_sha512, - ssl_sig_rsa_pss_sha384}; + const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_rsae_sha512, + ssl_sig_rsa_pss_rsae_sha384}; auto capture = - std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn); + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn); client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes)); - client_->SetPacketFilter(capture); EnableOnlyStaticRsaCiphers(); Connect(); @@ -579,9 +544,9 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { // Temporary test to verify that we choke on an empty ClientKeyShare. // This test will fail when we implement HelloRetryRequest. TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { - ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_tls13_key_share_xtn, 2), - kTlsAlertHandshakeFailure); + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_tls13_key_share_xtn, 2), + kTlsAlertHandshakeFailure); } // These tests only work in stream mode because the client sends a @@ -590,8 +555,7 @@ TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { // packet gets dropped. TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) { EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn)); + MakeTlsFilter<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn); client_->ExpectSendAlert(kTlsAlertMissingExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -611,8 +575,7 @@ TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) { 0x02}; DataBuffer buf(key_share, sizeof(key_share)); EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf)); + MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf); client_->ExpectSendAlert(kTlsAlertIllegalParameter); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -633,8 +596,7 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { 0x02}; DataBuffer buf(key_share, sizeof(key_share)); EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf)); + MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf); client_->ExpectSendAlert(kTlsAlertMissingExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -645,8 +607,8 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) { SetupForResume(); DataBuffer empty; - server_->SetPacketFilter(std::make_shared<TlsExtensionInjector>( - ssl_signature_algorithms_xtn, empty)); + MakeTlsFilter<TlsExtensionInjector>(server_, ssl_signature_algorithms_xtn, + empty); client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -666,8 +628,12 @@ typedef std::function<void(TlsPreSharedKeyReplacer*)> class TlsPreSharedKeyReplacer : public TlsExtensionFilter { public: - TlsPreSharedKeyReplacer(TlsPreSharedKeyReplacerFunc function) - : identities_(), binders_(), function_(function) {} + TlsPreSharedKeyReplacer(const std::shared_ptr<TlsAgent>& agent, + TlsPreSharedKeyReplacerFunc function) + : TlsExtensionFilter(agent), + identities_(), + binders_(), + function_(function) {} static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size, const std::unique_ptr<DataBuffer>& replace, @@ -781,8 +747,10 @@ class TlsPreSharedKeyReplacer : public TlsExtensionFilter { TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) { SetupForResume(); - client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([]( - TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); })); + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->identities_[0].identity.Truncate(0); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -792,10 +760,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) { TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1); - })); + }); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -805,10 +773,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Write(r->binders_[0].len(), 0xff, 1); - })); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -818,8 +786,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) { SetupForResume(); - client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>( - [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); })); + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -830,11 +798,11 @@ TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) { TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->identities_.push_back(r->identities_[0]); r->binders_.push_back(r->binders_[0]); - })); + }); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -845,10 +813,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) { TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->identities_.push_back(r->identities_[0]); - })); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -857,8 +825,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) { SetupForResume(); - client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([]( - TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); })); + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->binders_.push_back(r->binders_[0]); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -870,8 +840,8 @@ TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) { const uint8_t empty_buf[] = {0}; DataBuffer empty(empty_buf, 0); // Inject an unused extension after the PSK extension. - client_->SetPacketFilter(std::make_shared<TlsExtensionAppender>( - kTlsHandshakeClientHello, 0xffff, empty)); + MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, 0xffff, + empty); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -881,8 +851,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) { SetupForResume(); DataBuffer empty; - client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>( - ssl_tls13_psk_key_exchange_modes_xtn)); + MakeTlsFilter<TlsExtensionDropper>(client_, + ssl_tls13_psk_key_exchange_modes_xtn); ConnectExpectAlert(server_, kTlsAlertMissingExtension); client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES); @@ -897,8 +867,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { kTls13PskKe}; DataBuffer modes(ke_modes, sizeof(ke_modes)); - client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>( - ssl_tls13_psk_key_exchange_modes_xtn, modes)); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_psk_key_exchange_modes_xtn, modes); client_->ExpectSendAlert(kTlsAlertBadRecordMac); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -908,9 +878,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) { ConfigureSessionCache(RESUME_NONE, RESUME_NONE); - auto capture = std::make_shared<TlsExtensionCapture>( - ssl_tls13_psk_key_exchange_modes_xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_psk_key_exchange_modes_xtn); Connect(); EXPECT_FALSE(capture->captured()); } @@ -1006,12 +975,9 @@ class TlsBogusExtensionTest : public TlsConnectTestBase, static uint8_t empty_buf[1] = {0}; DataBuffer empty(empty_buf, 0); auto filter = - std::make_shared<TlsExtensionAppender>(message, extension, empty); + MakeTlsFilter<TlsExtensionAppender>(server_, message, extension, empty); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { - server_->SetTlsRecordFilter(filter); filter->EnableDecryption(); - } else { - server_->SetPacketFilter(filter); } } @@ -1032,17 +998,20 @@ class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest { class TlsBogusExtensionTest13 : public TlsBogusExtensionTest { protected: void ConnectAndFail(uint8_t message) override { - if (message == kTlsHandshakeHelloRetryRequest) { + if (message != kTlsHandshakeServerHello) { ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension); return; } - client_->StartConnect(); - server_->StartConnect(); + FailWithAlert(kTlsAlertUnsupportedExtension); + } + + void FailWithAlert(uint8_t alert) { + StartConnect(); client_->Handshake(); // ClientHello server_->Handshake(); // ServerHello - client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + client_->ExpectSendAlert(alert); client_->Handshake(); if (variant_ == ssl_variant_stream) { server_->ExpectSendAlert(kTlsAlertBadRecordMac); @@ -1067,9 +1036,12 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) { Run(kTlsHandshakeCertificate); } +// It's perfectly valid to set unknown extensions in CertificateRequest. TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) { server_->RequestClientAuth(false); - Run(kTlsHandshakeCertificateRequest); + AddFilter(kTlsHandshakeCertificateRequest, 0xff); + ConnectExpectAlert(client_, kTlsAlertDecryptError); + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); } TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) { @@ -1079,10 +1051,6 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) { Run(kTlsHandshakeHelloRetryRequest); } -TEST_P(TlsBogusExtensionTest13, AddVersionExtensionServerHello) { - Run(kTlsHandshakeServerHello, ssl_tls13_supported_versions_xtn); -} - TEST_P(TlsBogusExtensionTest13, AddVersionExtensionEncryptedExtensions) { Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn); } @@ -1096,13 +1064,6 @@ TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificateRequest) { Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn); } -TEST_P(TlsBogusExtensionTest13, AddVersionExtensionHelloRetryRequest) { - static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; - server_->ConfigNamedGroups(groups); - - Run(kTlsHandshakeHelloRetryRequest, ssl_tls13_supported_versions_xtn); -} - // NewSessionTicket allows unknown extensions AND it isn't protected by the // Finished. So adding an unknown extension doesn't cause an error. TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) { @@ -1132,8 +1093,7 @@ TEST_P(TlsConnectStream, IncludePadding) { SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name); EXPECT_EQ(SECSuccess, rv); - auto capture = std::make_shared<TlsExtensionCapture>(ssl_padding_xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, ssl_padding_xtn); client_->StartConnect(); client_->Handshake(); EXPECT_TRUE(capture->captured()); |