summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc143
1 files changed, 112 insertions, 31 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
index 523a37499..a130ef77f 100644
--- a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -28,9 +28,9 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter {
protected:
// Takes a record; if it is a handshake record, it removes the first handshake
// message that is of handshake_type_ type.
- virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header,
- const DataBuffer& input,
- DataBuffer* output) {
+ virtual PacketFilter::Action FilterRecord(
+ const TlsRecordHeader& record_header, const DataBuffer& input,
+ DataBuffer* output) {
if (record_header.content_type() != kTlsHandshakeType) {
return KEEP;
}
@@ -78,81 +78,162 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter {
bool skipped_;
};
-class TlsSkipTest
- : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+class TlsSkipTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
protected:
TlsSkipTest()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
- void ServerSkipTest(PacketFilter* filter,
+ void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
uint8_t alert = kTlsAlertUnexpectedMessage) {
- auto alert_recorder = new TlsAlertRecorder();
- client_->SetPacketFilter(alert_recorder);
- if (filter) {
- server_->SetPacketFilter(filter);
+ server_->SetPacketFilter(filter);
+ ConnectExpectAlert(client_, alert);
+ }
+};
+
+class Tls13SkipTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ protected:
+ Tls13SkipTest()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+ void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ EnsureTlsSetup();
+ server_->SetTlsRecordFilter(filter);
+ filter->EnableDecryption();
+ client_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+ } else {
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ }
+ client_->CheckErrorCode(error);
+ if (variant_ == ssl_variant_stream) {
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ } else {
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
}
- ConnectExpectFail();
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(alert, alert_recorder->description());
+ }
+
+ void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ EnsureTlsSetup();
+ client_->SetTlsRecordFilter(filter);
+ filter->EnableDecryption();
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFailOneSide(TlsAgent::SERVER);
+
+ server_->CheckErrorCode(error);
+ ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+
+ client_->Handshake(); // Make sure to consume the alert the server sends.
}
};
TEST_P(TlsSkipTest, SkipCertificateRsa) {
EnableOnlyStaticRsaCiphers();
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertificateDhe) {
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdhe) {
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipServerKeyExchange) {
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
- auto chain = new ChainedPacketFilter();
- chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
- chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ auto chain = std::make_shared<ChainedPacketFilter>();
+ chain->Add(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ chain->Add(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- auto chain = new ChainedPacketFilter();
- chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
- chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ auto chain = std::make_shared<ChainedPacketFilter>();
+ chain->Add(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ chain->Add(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
-INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest,
- ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
- TlsConnectTestBase::kTlsV10));
+TEST_P(Tls13SkipTest, SkipEncryptedExtensions) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ kTlsHandshakeEncryptedExtensions),
+ SSL_ERROR_RX_UNEXPECTED_CERTIFICATE);
+}
+
+TEST_P(Tls13SkipTest, SkipServerCertificate) {
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+}
+
+TEST_P(Tls13SkipTest, SkipServerCertificateVerify) {
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
+}
+
+TEST_P(Tls13SkipTest, SkipClientCertificate) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ ClientSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+}
+
+TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ ClientSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
+}
+
+INSTANTIATE_TEST_CASE_P(
+ SkipTls10, TlsSkipTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10));
INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV11V12));
-
+INSTANTIATE_TEST_CASE_P(Skip13Variants, Tls13SkipTest,
+ TlsConnectTestBase::kTlsVariantsAll);
} // namespace nss_test