diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc | 117 |
1 files changed, 63 insertions, 54 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc index a130ef77f..e4a9e5aed 100644 --- a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc @@ -22,8 +22,11 @@ namespace nss_test { class TlsHandshakeSkipFilter : public TlsRecordFilter { public: // A TLS record filter that skips handshake messages of the identified type. - TlsHandshakeSkipFilter(uint8_t handshake_type) - : handshake_type_(handshake_type), skipped_(false) {} + TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type) + : TlsRecordFilter(agent), + handshake_type_(handshake_type), + skipped_(false) {} protected: // Takes a record; if it is a handshake record, it removes the first handshake @@ -43,7 +46,14 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter { size_t start = parser.consumed(); TlsHandshakeFilter::HandshakeHeader header; DataBuffer ignored; - if (!header.Parse(&parser, record_header, &ignored)) { + bool complete = false; + if (!header.Parse(&parser, record_header, DataBuffer(), &ignored, + &complete)) { + ADD_FAILURE() << "Error parsing handshake header"; + return KEEP; + } + if (!complete) { + ADD_FAILURE() << "Don't want to deal with fragmented input"; return KEEP; } @@ -85,9 +95,14 @@ class TlsSkipTest : public TlsConnectTestBase, TlsSkipTest() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + void SetUp() override { + TlsConnectTestBase::SetUp(); + EnsureTlsSetup(); + } + void ServerSkipTest(std::shared_ptr<PacketFilter> filter, uint8_t alert = kTlsAlertUnexpectedMessage) { - server_->SetPacketFilter(filter); + server_->SetFilter(filter); ConnectExpectAlert(client_, alert); } }; @@ -98,29 +113,23 @@ class Tls13SkipTest : public TlsConnectTestBase, Tls13SkipTest() : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} - void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { + void SetUp() override { + TlsConnectTestBase::SetUp(); EnsureTlsSetup(); - server_->SetTlsRecordFilter(filter); + } + + void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { filter->EnableDecryption(); - client_->ExpectSendAlert(kTlsAlertUnexpectedMessage); - if (variant_ == ssl_variant_stream) { - server_->ExpectSendAlert(kTlsAlertBadRecordMac); - ConnectExpectFail(); - } else { - ConnectExpectFailOneSide(TlsAgent::CLIENT); - } + server_->SetFilter(filter); + ExpectAlert(client_, kTlsAlertUnexpectedMessage); + ConnectExpectFail(); client_->CheckErrorCode(error); - if (variant_ == ssl_variant_stream) { - server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); - } else { - ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); - } + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); } void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { - EnsureTlsSetup(); - client_->SetTlsRecordFilter(filter); filter->EnableDecryption(); + client_->SetFilter(filter); server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); ConnectExpectFailOneSide(TlsAgent::SERVER); @@ -133,49 +142,49 @@ class Tls13SkipTest : public TlsConnectTestBase, TEST_P(TlsSkipTest, SkipCertificateRsa) { EnableOnlyStaticRsaCiphers(); - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertificateDhe) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdhe) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipServerKeyExchange) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertAndKeyExch) { - auto chain = std::make_shared<ChainedPacketFilter>(); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + auto chain = std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate), + std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)}); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } @@ -183,48 +192,48 @@ TEST_P(TlsSkipTest, SkipCertAndKeyExch) { TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) { Reset(TlsAgent::kServerEcdsa256); auto chain = std::make_shared<ChainedPacketFilter>(); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + chain->Add(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); + chain->Add(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(Tls13SkipTest, SkipEncryptedExtensions) { ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( - kTlsHandshakeEncryptedExtensions), + server_, kTlsHandshakeEncryptedExtensions), SSL_ERROR_RX_UNEXPECTED_CERTIFICATE); } TEST_P(Tls13SkipTest, SkipServerCertificate) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate), - SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); } TEST_P(Tls13SkipTest, SkipServerCertificateVerify) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify), - SSL_ERROR_RX_UNEXPECTED_FINISHED); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); } TEST_P(Tls13SkipTest, SkipClientCertificate) { client_->SetupClientAuth(); server_->RequestClientAuth(true); client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); - ClientSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate), - SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); + ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + client_, kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); } TEST_P(Tls13SkipTest, SkipClientCertificateVerify) { client_->SetupClientAuth(); server_->RequestClientAuth(true); client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); - ClientSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify), - SSL_ERROR_RX_UNEXPECTED_FINISHED); + ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + client_, kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); } INSTANTIATE_TEST_CASE_P( |