summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc298
1 files changed, 129 insertions, 169 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
index d15139419..0453dabdb 100644
--- a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -19,8 +19,9 @@ namespace nss_test {
class TlsExtensionTruncator : public TlsExtensionFilter {
public:
- TlsExtensionTruncator(uint16_t extension, size_t length)
- : extension_(extension), length_(length) {}
+ TlsExtensionTruncator(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, size_t length)
+ : TlsExtensionFilter(agent), extension_(extension), length_(length) {}
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) {
@@ -42,8 +43,9 @@ class TlsExtensionTruncator : public TlsExtensionFilter {
class TlsExtensionDamager : public TlsExtensionFilter {
public:
- TlsExtensionDamager(uint16_t extension, size_t index)
- : extension_(extension), index_(index) {}
+ TlsExtensionDamager(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, size_t index)
+ : TlsExtensionFilter(agent), extension_(extension), index_(index) {}
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) {
@@ -61,60 +63,17 @@ class TlsExtensionDamager : public TlsExtensionFilter {
size_t index_;
};
-class TlsExtensionInjector : public TlsHandshakeFilter {
- public:
- TlsExtensionInjector(uint16_t ext, DataBuffer& data)
- : extension_(ext), data_(data) {}
-
- virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
- const DataBuffer& input,
- DataBuffer* output) {
- TlsParser parser(input);
- if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
- return KEEP;
- }
- size_t offset = parser.consumed();
-
- *output = input;
-
- // Increase the size of the extensions.
- uint16_t ext_len;
- memcpy(&ext_len, output->data() + offset, sizeof(ext_len));
- ext_len = htons(ntohs(ext_len) + data_.len() + 4);
- memcpy(output->data() + offset, &ext_len, sizeof(ext_len));
-
- // Insert the extension type and length.
- DataBuffer type_length;
- type_length.Allocate(4);
- type_length.Write(0, extension_, 2);
- type_length.Write(2, data_.len(), 2);
- output->Splice(type_length, offset + 2);
-
- // Insert the payload.
- if (data_.len() > 0) {
- output->Splice(data_, offset + 6);
- }
-
- return CHANGE;
- }
-
- private:
- const uint16_t extension_;
- const DataBuffer data_;
-};
-
class TlsExtensionAppender : public TlsHandshakeFilter {
public:
- TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data)
- : handshake_type_(handshake_type), extension_(ext), data_(data) {}
+ TlsExtensionAppender(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type, uint16_t ext, DataBuffer& data)
+ : TlsHandshakeFilter(agent, {handshake_type}),
+ extension_(ext),
+ data_(data) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() != handshake_type_) {
- return KEEP;
- }
-
TlsParser parser(input);
if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
return KEEP;
@@ -159,7 +118,6 @@ class TlsExtensionAppender : public TlsHandshakeFilter {
return true;
}
- const uint8_t handshake_type_;
const uint16_t extension_;
const DataBuffer data_;
};
@@ -171,13 +129,13 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
void ClientHelloErrorTest(std::shared_ptr<PacketFilter> filter,
uint8_t desc = kTlsAlertDecodeError) {
- client_->SetPacketFilter(filter);
+ client_->SetFilter(filter);
ConnectExpectAlert(server_, desc);
}
void ServerHelloErrorTest(std::shared_ptr<PacketFilter> filter,
uint8_t desc = kTlsAlertDecodeError) {
- server_->SetPacketFilter(filter);
+ server_->SetFilter(filter);
ConnectExpectAlert(client_, desc);
}
@@ -200,11 +158,10 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
client_->ConfigNamedGroups(client_groups);
server_->ConfigNamedGroups(server_groups);
EnsureTlsSetup();
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake(); // Send ClientHello
server_->Handshake(); // Send HRR.
- client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(type));
+ MakeTlsFilter<TlsExtensionDropper>(client_, type);
Handshake();
client_->CheckErrorCode(client_error);
server_->CheckErrorCode(server_error);
@@ -245,8 +202,8 @@ class TlsExtensionTest13
void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) {
DataBuffer versions_buf(buf, len);
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_supported_versions_xtn, versions_buf));
+ MakeTlsFilter<TlsExtensionReplacer>(
+ client_, ssl_tls13_supported_versions_xtn, versions_buf);
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -257,8 +214,8 @@ class TlsExtensionTest13
size_t index = versions_buf.Write(0, 2, 1);
versions_buf.Write(index, version, 2);
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_supported_versions_xtn, versions_buf));
+ MakeTlsFilter<TlsExtensionReplacer>(
+ client_, ssl_tls13_supported_versions_xtn, versions_buf);
ConnectExpectFail();
}
};
@@ -289,26 +246,26 @@ class TlsExtensionTestPre13 : public TlsExtensionTestBase,
TEST_P(TlsExtensionTestGeneric, DamageSniLength) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 1));
+ std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 1));
}
TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 4));
+ std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 4));
}
TEST_P(TlsExtensionTestGeneric, TruncateSni) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_server_name_xtn, 7));
+ std::make_shared<TlsExtensionTruncator>(client_, ssl_server_name_xtn, 7));
}
// A valid extension that appears twice will be reported as unsupported.
TEST_P(TlsExtensionTestGeneric, RepeatSni) {
DataBuffer extension;
InitSimpleSni(&extension);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionInjector>(ssl_server_name_xtn, extension),
- kTlsAlertIllegalParameter);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionInjector>(
+ client_, ssl_server_name_xtn, extension),
+ kTlsAlertIllegalParameter);
}
// An SNI entry with zero length is considered invalid (strangely, not if it is
@@ -320,23 +277,23 @@ TEST_P(TlsExtensionTestGeneric, BadSni) {
extension.Allocate(simple.len() + 3);
extension.Write(0, static_cast<uint32_t>(0), 3);
extension.Write(3, simple);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_server_name_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, EmptySni) {
DataBuffer extension;
extension.Allocate(2);
extension.Write(0, static_cast<uint32_t>(0), 2);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_server_name_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) {
EnableAlpn();
DataBuffer extension;
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ client_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertIllegalParameter);
}
@@ -347,21 +304,21 @@ TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ client_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertNoApplicationProtocol);
}
TEST_P(TlsExtensionTestGeneric, OneByteAlpn) {
EnableAlpn();
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 1));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_app_layer_protocol_xtn, 1));
}
TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) {
EnableAlpn();
// This will leave the length of the second entry, but no value.
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 5));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_app_layer_protocol_xtn, 5));
}
TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
@@ -369,7 +326,7 @@ TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
const uint8_t val[] = {0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ client_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, AlpnMismatch) {
@@ -388,7 +345,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
@@ -396,7 +353,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
const uint8_t val[] = {0x00, 0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
@@ -404,7 +361,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
@@ -412,7 +369,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
@@ -420,7 +377,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
@@ -428,7 +385,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
const uint8_t val[] = {0x00, 0x02, 0x99, 0x61};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) {
@@ -436,55 +393,64 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) {
const uint8_t val[] = {0x00, 0x02, 0x01, 0x67};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ server_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertIllegalParameter);
}
TEST_P(TlsExtensionTestDtls, SrtpShort) {
EnableSrtp();
ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_use_srtp_xtn, 3));
+ std::make_shared<TlsExtensionTruncator>(client_, ssl_use_srtp_xtn, 3));
}
TEST_P(TlsExtensionTestDtls, SrtpOdd) {
EnableSrtp();
const uint8_t val[] = {0x00, 0x01, 0xff, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_use_srtp_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_use_srtp_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) {
const uint8_t val[] = {0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension),
+ kTlsAlertHandshakeFailure);
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsNoOverlap) {
+ const uint8_t val[] = {0x00, 0x02, 0xff, 0xff};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_signature_algorithms_xtn, extension),
+ kTlsAlertHandshakeFailure);
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) {
const uint8_t val[] = {0x00, 0x01, 0x04};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn),
+ std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn),
version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError
: kTlsAlertMissingExtension);
}
@@ -493,75 +459,74 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) {
const uint8_t val[] = {0x00, 0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) {
const uint8_t val[] = {0x09, 0x99, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) {
const uint8_t val[] = {0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) {
const uint8_t val[] = {0x99, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) {
const uint8_t val[] = {0x01, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) {
const uint8_t val[] = {0x99};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) {
const uint8_t val[] = {0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
// The extension has to contain a length.
TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) {
DataBuffer extension;
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
// This only works on TLS 1.2, since it relies on static RSA; otherwise libssl
// picks the wrong cipher suite.
TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
- const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_sha512,
- ssl_sig_rsa_pss_sha384};
+ const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_rsae_sha512,
+ ssl_sig_rsa_pss_rsae_sha384};
auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn);
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn);
client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes));
- client_->SetPacketFilter(capture);
EnableOnlyStaticRsaCiphers();
Connect();
@@ -579,9 +544,9 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
// Temporary test to verify that we choke on an empty ClientKeyShare.
// This test will fail when we implement HelloRetryRequest.
TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_tls13_key_share_xtn, 2),
- kTlsAlertHandshakeFailure);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_tls13_key_share_xtn, 2),
+ kTlsAlertHandshakeFailure);
}
// These tests only work in stream mode because the client sends a
@@ -590,8 +555,7 @@ TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
// packet gets dropped.
TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) {
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn));
+ MakeTlsFilter<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn);
client_->ExpectSendAlert(kTlsAlertMissingExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -611,8 +575,7 @@ TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) {
0x02};
DataBuffer buf(key_share, sizeof(key_share));
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf));
+ MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf);
client_->ExpectSendAlert(kTlsAlertIllegalParameter);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -633,8 +596,7 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
0x02};
DataBuffer buf(key_share, sizeof(key_share));
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf));
+ MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf);
client_->ExpectSendAlert(kTlsAlertMissingExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -645,8 +607,8 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) {
SetupForResume();
DataBuffer empty;
- server_->SetPacketFilter(std::make_shared<TlsExtensionInjector>(
- ssl_signature_algorithms_xtn, empty));
+ MakeTlsFilter<TlsExtensionInjector>(server_, ssl_signature_algorithms_xtn,
+ empty);
client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -666,8 +628,12 @@ typedef std::function<void(TlsPreSharedKeyReplacer*)>
class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
public:
- TlsPreSharedKeyReplacer(TlsPreSharedKeyReplacerFunc function)
- : identities_(), binders_(), function_(function) {}
+ TlsPreSharedKeyReplacer(const std::shared_ptr<TlsAgent>& agent,
+ TlsPreSharedKeyReplacerFunc function)
+ : TlsExtensionFilter(agent),
+ identities_(),
+ binders_(),
+ function_(function) {}
static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size,
const std::unique_ptr<DataBuffer>& replace,
@@ -781,8 +747,10 @@ class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([](
- TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); }));
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->identities_[0].identity.Truncate(0);
+ });
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -792,10 +760,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1);
- }));
+ });
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
@@ -805,10 +773,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->binders_[0].Write(r->binders_[0].len(), 0xff, 1);
- }));
+ });
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -818,8 +786,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>(
- [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }));
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); });
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -830,11 +798,11 @@ TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->identities_.push_back(r->identities_[0]);
r->binders_.push_back(r->binders_[0]);
- }));
+ });
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
@@ -845,10 +813,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->identities_.push_back(r->identities_[0]);
- }));
+ });
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -857,8 +825,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([](
- TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); }));
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->binders_.push_back(r->binders_[0]);
+ });
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -870,8 +840,8 @@ TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) {
const uint8_t empty_buf[] = {0};
DataBuffer empty(empty_buf, 0);
// Inject an unused extension after the PSK extension.
- client_->SetPacketFilter(std::make_shared<TlsExtensionAppender>(
- kTlsHandshakeClientHello, 0xffff, empty));
+ MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, 0xffff,
+ empty);
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -881,8 +851,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) {
SetupForResume();
DataBuffer empty;
- client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(
- ssl_tls13_psk_key_exchange_modes_xtn));
+ MakeTlsFilter<TlsExtensionDropper>(client_,
+ ssl_tls13_psk_key_exchange_modes_xtn);
ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES);
@@ -897,8 +867,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
kTls13PskKe};
DataBuffer modes(ke_modes, sizeof(ke_modes));
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_psk_key_exchange_modes_xtn, modes));
+ MakeTlsFilter<TlsExtensionReplacer>(
+ client_, ssl_tls13_psk_key_exchange_modes_xtn, modes);
client_->ExpectSendAlert(kTlsAlertBadRecordMac);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -908,9 +878,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) {
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
- auto capture = std::make_shared<TlsExtensionCapture>(
- ssl_tls13_psk_key_exchange_modes_xtn);
- client_->SetPacketFilter(capture);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_psk_key_exchange_modes_xtn);
Connect();
EXPECT_FALSE(capture->captured());
}
@@ -1006,12 +975,9 @@ class TlsBogusExtensionTest : public TlsConnectTestBase,
static uint8_t empty_buf[1] = {0};
DataBuffer empty(empty_buf, 0);
auto filter =
- std::make_shared<TlsExtensionAppender>(message, extension, empty);
+ MakeTlsFilter<TlsExtensionAppender>(server_, message, extension, empty);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
- server_->SetTlsRecordFilter(filter);
filter->EnableDecryption();
- } else {
- server_->SetPacketFilter(filter);
}
}
@@ -1032,17 +998,20 @@ class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest {
class TlsBogusExtensionTest13 : public TlsBogusExtensionTest {
protected:
void ConnectAndFail(uint8_t message) override {
- if (message == kTlsHandshakeHelloRetryRequest) {
+ if (message != kTlsHandshakeServerHello) {
ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
return;
}
- client_->StartConnect();
- server_->StartConnect();
+ FailWithAlert(kTlsAlertUnsupportedExtension);
+ }
+
+ void FailWithAlert(uint8_t alert) {
+ StartConnect();
client_->Handshake(); // ClientHello
server_->Handshake(); // ServerHello
- client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ client_->ExpectSendAlert(alert);
client_->Handshake();
if (variant_ == ssl_variant_stream) {
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
@@ -1067,9 +1036,12 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) {
Run(kTlsHandshakeCertificate);
}
+// It's perfectly valid to set unknown extensions in CertificateRequest.
TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) {
server_->RequestClientAuth(false);
- Run(kTlsHandshakeCertificateRequest);
+ AddFilter(kTlsHandshakeCertificateRequest, 0xff);
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
}
TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) {
@@ -1079,10 +1051,6 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) {
Run(kTlsHandshakeHelloRetryRequest);
}
-TEST_P(TlsBogusExtensionTest13, AddVersionExtensionServerHello) {
- Run(kTlsHandshakeServerHello, ssl_tls13_supported_versions_xtn);
-}
-
TEST_P(TlsBogusExtensionTest13, AddVersionExtensionEncryptedExtensions) {
Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn);
}
@@ -1096,13 +1064,6 @@ TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificateRequest) {
Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn);
}
-TEST_P(TlsBogusExtensionTest13, AddVersionExtensionHelloRetryRequest) {
- static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
- server_->ConfigNamedGroups(groups);
-
- Run(kTlsHandshakeHelloRetryRequest, ssl_tls13_supported_versions_xtn);
-}
-
// NewSessionTicket allows unknown extensions AND it isn't protected by the
// Finished. So adding an unknown extension doesn't cause an error.
TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) {
@@ -1132,8 +1093,7 @@ TEST_P(TlsConnectStream, IncludePadding) {
SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name);
EXPECT_EQ(SECSuccess, rv);
- auto capture = std::make_shared<TlsExtensionCapture>(ssl_padding_xtn);
- client_->SetPacketFilter(capture);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, ssl_padding_xtn);
client_->StartConnect();
client_->Handshake();
EXPECT_TRUE(capture->captured());