diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc | 143 |
1 files changed, 112 insertions, 31 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc index 523a37499..a130ef77f 100644 --- a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc @@ -28,9 +28,9 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter { protected: // Takes a record; if it is a handshake record, it removes the first handshake // message that is of handshake_type_ type. - virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header, - const DataBuffer& input, - DataBuffer* output) { + virtual PacketFilter::Action FilterRecord( + const TlsRecordHeader& record_header, const DataBuffer& input, + DataBuffer* output) { if (record_header.content_type() != kTlsHandshakeType) { return KEEP; } @@ -78,81 +78,162 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter { bool skipped_; }; -class TlsSkipTest - : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { +class TlsSkipTest : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { protected: TlsSkipTest() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} - void ServerSkipTest(PacketFilter* filter, + void ServerSkipTest(std::shared_ptr<PacketFilter> filter, uint8_t alert = kTlsAlertUnexpectedMessage) { - auto alert_recorder = new TlsAlertRecorder(); - client_->SetPacketFilter(alert_recorder); - if (filter) { - server_->SetPacketFilter(filter); + server_->SetPacketFilter(filter); + ConnectExpectAlert(client_, alert); + } +}; + +class Tls13SkipTest : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { + protected: + Tls13SkipTest() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + + void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { + EnsureTlsSetup(); + server_->SetTlsRecordFilter(filter); + filter->EnableDecryption(); + client_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); + } else { + ConnectExpectFailOneSide(TlsAgent::CLIENT); + } + client_->CheckErrorCode(error); + if (variant_ == ssl_variant_stream) { + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + } else { + ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); } - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(alert, alert_recorder->description()); + } + + void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { + EnsureTlsSetup(); + client_->SetTlsRecordFilter(filter); + filter->EnableDecryption(); + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + ConnectExpectFailOneSide(TlsAgent::SERVER); + + server_->CheckErrorCode(error); + ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + + client_->Handshake(); // Make sure to consume the alert the server sends. } }; TEST_P(TlsSkipTest, SkipCertificateRsa) { EnableOnlyStaticRsaCiphers(); - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + ServerSkipTest( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertificateDhe) { - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + ServerSkipTest( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdhe) { - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + ServerSkipTest( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + ServerSkipTest( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipServerKeyExchange) { - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + ServerSkipTest( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + ServerSkipTest( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertAndKeyExch) { - auto chain = new ChainedPacketFilter(); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + auto chain = std::make_shared<ChainedPacketFilter>(); + chain->Add( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + chain->Add( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) { Reset(TlsAgent::kServerEcdsa256); - auto chain = new ChainedPacketFilter(); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + auto chain = std::make_shared<ChainedPacketFilter>(); + chain->Add( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + chain->Add( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } -INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10)); +TEST_P(Tls13SkipTest, SkipEncryptedExtensions) { + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + kTlsHandshakeEncryptedExtensions), + SSL_ERROR_RX_UNEXPECTED_CERTIFICATE); +} + +TEST_P(Tls13SkipTest, SkipServerCertificate) { + ServerSkipTest( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); +} + +TEST_P(Tls13SkipTest, SkipServerCertificateVerify) { + ServerSkipTest( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); +} + +TEST_P(Tls13SkipTest, SkipClientCertificate) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); + ClientSkipTest( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); +} + +TEST_P(Tls13SkipTest, SkipClientCertificateVerify) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); + ClientSkipTest( + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); +} + +INSTANTIATE_TEST_CASE_P( + SkipTls10, TlsSkipTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10)); INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest, - ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV11V12)); - +INSTANTIATE_TEST_CASE_P(Skip13Variants, Tls13SkipTest, + TlsConnectTestBase::kTlsVariantsAll); } // namespace nss_test |