diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc | 80 |
1 files changed, 67 insertions, 13 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc index 4142ab07a..d15139419 100644 --- a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -61,14 +61,60 @@ class TlsExtensionDamager : public TlsExtensionFilter { size_t index_; }; +class TlsExtensionInjector : public TlsHandshakeFilter { + public: + TlsExtensionInjector(uint16_t ext, DataBuffer& data) + : extension_(ext), data_(data) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindExtensions(&parser, header)) { + return KEEP; + } + size_t offset = parser.consumed(); + + *output = input; + + // Increase the size of the extensions. + uint16_t ext_len; + memcpy(&ext_len, output->data() + offset, sizeof(ext_len)); + ext_len = htons(ntohs(ext_len) + data_.len() + 4); + memcpy(output->data() + offset, &ext_len, sizeof(ext_len)); + + // Insert the extension type and length. + DataBuffer type_length; + type_length.Allocate(4); + type_length.Write(0, extension_, 2); + type_length.Write(2, data_.len(), 2); + output->Splice(type_length, offset + 2); + + // Insert the payload. + if (data_.len() > 0) { + output->Splice(data_, offset + 6); + } + + return CHANGE; + } + + private: + const uint16_t extension_; + const DataBuffer data_; +}; + class TlsExtensionAppender : public TlsHandshakeFilter { public: TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data) - : TlsHandshakeFilter({handshake_type}), extension_(ext), data_(data) {} + : handshake_type_(handshake_type), extension_(ext), data_(data) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { + if (header.handshake_type() != handshake_type_) { + return KEEP; + } + TlsParser parser(input); if (!TlsExtensionFilter::FindExtensions(&parser, header)) { return KEEP; @@ -113,6 +159,7 @@ class TlsExtensionAppender : public TlsHandshakeFilter { return true; } + const uint8_t handshake_type_; const uint16_t extension_; const DataBuffer data_; }; @@ -153,7 +200,8 @@ class TlsExtensionTestBase : public TlsConnectTestBase { client_->ConfigNamedGroups(client_groups); server_->ConfigNamedGroups(server_groups); EnsureTlsSetup(); - StartConnect(); + client_->StartConnect(); + server_->StartConnect(); client_->Handshake(); // Send ClientHello server_->Handshake(); // Send HRR. client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(type)); @@ -961,6 +1009,7 @@ class TlsBogusExtensionTest : public TlsConnectTestBase, std::make_shared<TlsExtensionAppender>(message, extension, empty); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { server_->SetTlsRecordFilter(filter); + filter->EnableDecryption(); } else { server_->SetPacketFilter(filter); } @@ -983,20 +1032,17 @@ class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest { class TlsBogusExtensionTest13 : public TlsBogusExtensionTest { protected: void ConnectAndFail(uint8_t message) override { - if (message != kTlsHandshakeServerHello) { + if (message == kTlsHandshakeHelloRetryRequest) { ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension); return; } - FailWithAlert(kTlsAlertUnsupportedExtension); - } - - void FailWithAlert(uint8_t alert) { - StartConnect(); + client_->StartConnect(); + server_->StartConnect(); client_->Handshake(); // ClientHello server_->Handshake(); // ServerHello - client_->ExpectSendAlert(alert); + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); client_->Handshake(); if (variant_ == ssl_variant_stream) { server_->ExpectSendAlert(kTlsAlertBadRecordMac); @@ -1021,12 +1067,9 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) { Run(kTlsHandshakeCertificate); } -// It's perfectly valid to set unknown extensions in CertificateRequest. TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) { server_->RequestClientAuth(false); - AddFilter(kTlsHandshakeCertificateRequest, 0xff); - ConnectExpectAlert(client_, kTlsAlertDecryptError); - client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); + Run(kTlsHandshakeCertificateRequest); } TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) { @@ -1036,6 +1079,10 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) { Run(kTlsHandshakeHelloRetryRequest); } +TEST_P(TlsBogusExtensionTest13, AddVersionExtensionServerHello) { + Run(kTlsHandshakeServerHello, ssl_tls13_supported_versions_xtn); +} + TEST_P(TlsBogusExtensionTest13, AddVersionExtensionEncryptedExtensions) { Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn); } @@ -1049,6 +1096,13 @@ TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificateRequest) { Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn); } +TEST_P(TlsBogusExtensionTest13, AddVersionExtensionHelloRetryRequest) { + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + Run(kTlsHandshakeHelloRetryRequest, ssl_tls13_supported_versions_xtn); +} + // NewSessionTicket allows unknown extensions AND it isn't protected by the // Finished. So adding an unknown extension doesn't cause an error. TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) { |