diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc | 94 |
1 files changed, 54 insertions, 40 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc index 335bfecfa..e4a9e5aed 100644 --- a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc @@ -22,8 +22,11 @@ namespace nss_test { class TlsHandshakeSkipFilter : public TlsRecordFilter { public: // A TLS record filter that skips handshake messages of the identified type. - TlsHandshakeSkipFilter(uint8_t handshake_type) - : handshake_type_(handshake_type), skipped_(false) {} + TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type) + : TlsRecordFilter(agent), + handshake_type_(handshake_type), + skipped_(false) {} protected: // Takes a record; if it is a handshake record, it removes the first handshake @@ -92,9 +95,14 @@ class TlsSkipTest : public TlsConnectTestBase, TlsSkipTest() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + void SetUp() override { + TlsConnectTestBase::SetUp(); + EnsureTlsSetup(); + } + void ServerSkipTest(std::shared_ptr<PacketFilter> filter, uint8_t alert = kTlsAlertUnexpectedMessage) { - server_->SetPacketFilter(filter); + server_->SetFilter(filter); ConnectExpectAlert(client_, alert); } }; @@ -105,9 +113,14 @@ class Tls13SkipTest : public TlsConnectTestBase, Tls13SkipTest() : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} - void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { + void SetUp() override { + TlsConnectTestBase::SetUp(); EnsureTlsSetup(); - server_->SetTlsRecordFilter(filter); + } + + void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { + filter->EnableDecryption(); + server_->SetFilter(filter); ExpectAlert(client_, kTlsAlertUnexpectedMessage); ConnectExpectFail(); client_->CheckErrorCode(error); @@ -115,8 +128,8 @@ class Tls13SkipTest : public TlsConnectTestBase, } void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { - EnsureTlsSetup(); - client_->SetTlsRecordFilter(filter); + filter->EnableDecryption(); + client_->SetFilter(filter); server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); ConnectExpectFailOneSide(TlsAgent::SERVER); @@ -129,48 +142,49 @@ class Tls13SkipTest : public TlsConnectTestBase, TEST_P(TlsSkipTest, SkipCertificateRsa) { EnableOnlyStaticRsaCiphers(); - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertificateDhe) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdhe) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipServerKeyExchange) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertAndKeyExch) { - auto chain = std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit{ - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate), - std::make_shared<TlsHandshakeSkipFilter>( - kTlsHandshakeServerKeyExchange)}); + auto chain = std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate), + std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)}); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } @@ -178,48 +192,48 @@ TEST_P(TlsSkipTest, SkipCertAndKeyExch) { TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) { Reset(TlsAgent::kServerEcdsa256); auto chain = std::make_shared<ChainedPacketFilter>(); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + chain->Add(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); + chain->Add(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(Tls13SkipTest, SkipEncryptedExtensions) { ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( - kTlsHandshakeEncryptedExtensions), + server_, kTlsHandshakeEncryptedExtensions), SSL_ERROR_RX_UNEXPECTED_CERTIFICATE); } TEST_P(Tls13SkipTest, SkipServerCertificate) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate), - SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); } TEST_P(Tls13SkipTest, SkipServerCertificateVerify) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify), - SSL_ERROR_RX_UNEXPECTED_FINISHED); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); } TEST_P(Tls13SkipTest, SkipClientCertificate) { client_->SetupClientAuth(); server_->RequestClientAuth(true); client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); - ClientSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate), - SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); + ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + client_, kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); } TEST_P(Tls13SkipTest, SkipClientCertificateVerify) { client_->SetupClientAuth(); server_->RequestClientAuth(true); client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); - ClientSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify), - SSL_ERROR_RX_UNEXPECTED_FINISHED); + ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + client_, kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); } INSTANTIATE_TEST_CASE_P( |