diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc | 77 |
1 files changed, 51 insertions, 26 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc index 4bc6e60ab..f1b78f52f 100644 --- a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc @@ -56,7 +56,8 @@ TEST_P(TlsConnectGeneric, CipherSuiteMismatch) { class TlsAlertRecorder : public TlsRecordFilter { public: - TlsAlertRecorder() : level_(255), description_(255) {} + TlsAlertRecorder(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent), level_(255), description_(255) {} PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, @@ -86,9 +87,9 @@ class TlsAlertRecorder : public TlsRecordFilter { class HelloTruncator : public TlsHandshakeFilter { public: - HelloTruncator() + HelloTruncator(const std::shared_ptr<TlsAgent>& agent) : TlsHandshakeFilter( - {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {} + agent, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {} PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { @@ -99,9 +100,8 @@ class HelloTruncator : public TlsHandshakeFilter { // Verify that when NSS reports that an alert is sent, it is actually sent. TEST_P(TlsConnectGeneric, CaptureAlertServer) { - client_->SetPacketFilter(std::make_shared<HelloTruncator>()); - auto alert_recorder = std::make_shared<TlsAlertRecorder>(); - server_->SetPacketFilter(alert_recorder); + MakeTlsFilter<HelloTruncator>(client_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(server_); ConnectExpectAlert(server_, kTlsAlertDecodeError); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); @@ -109,9 +109,8 @@ TEST_P(TlsConnectGeneric, CaptureAlertServer) { } TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { - server_->SetPacketFilter(std::make_shared<HelloTruncator>()); - auto alert_recorder = std::make_shared<TlsAlertRecorder>(); - client_->SetPacketFilter(alert_recorder); + MakeTlsFilter<HelloTruncator>(server_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_); ConnectExpectAlert(client_, kTlsAlertDecodeError); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); @@ -120,9 +119,8 @@ TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { // In TLS 1.3, the server can't read the client alert. TEST_P(TlsConnectTls13, CaptureAlertClient) { - server_->SetPacketFilter(std::make_shared<HelloTruncator>()); - auto alert_recorder = std::make_shared<TlsAlertRecorder>(); - client_->SetPacketFilter(alert_recorder); + MakeTlsFilter<HelloTruncator>(server_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_); StartConnect(); @@ -173,7 +171,8 @@ TEST_P(TlsConnectGeneric, ConnectSendReceive) { class SaveTlsRecord : public TlsRecordFilter { public: - SaveTlsRecord(size_t index) : index_(index), count_(0), contents_() {} + SaveTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index) + : TlsRecordFilter(agent), index_(index), count_(0), contents_() {} const DataBuffer& contents() const { return contents_; } @@ -198,8 +197,8 @@ class SaveTlsRecord : public TlsRecordFilter { TEST_F(TlsConnectStreamTls13, DecryptRecordClient) { EnsureTlsSetup(); // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer - auto saved = std::make_shared<SaveTlsRecord>(3); - client_->SetTlsRecordFilter(saved); + auto saved = MakeTlsFilter<SaveTlsRecord>(client_, 3); + saved->EnableDecryption(); Connect(); SendReceive(); @@ -215,8 +214,8 @@ TEST_F(TlsConnectStreamTls13, DecryptRecordServer) { EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer - auto saved = std::make_shared<SaveTlsRecord>(3); - server_->SetTlsRecordFilter(saved); + auto saved = MakeTlsFilter<SaveTlsRecord>(server_, 3); + saved->EnableDecryption(); Connect(); SendReceive(); @@ -228,7 +227,8 @@ TEST_F(TlsConnectStreamTls13, DecryptRecordServer) { class DropTlsRecord : public TlsRecordFilter { public: - DropTlsRecord(size_t index) : index_(index), count_(0) {} + DropTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index) + : TlsRecordFilter(agent), index_(index), count_(0) {} protected: PacketFilter::Action FilterRecord(const TlsRecordHeader& header, @@ -253,7 +253,8 @@ TEST_F(TlsConnectStreamTls13, DropRecordServer) { SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); // 0 = ServerHello, 1 = other handshake, 2 = first write - server_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2)); + auto filter = MakeTlsFilter<DropTlsRecord>(server_, 2); + filter->EnableDecryption(); Connect(); server_->SendData(23, 23); // This should be dropped, so it won't be counted. server_->ResetSentBytes(); @@ -263,7 +264,8 @@ TEST_F(TlsConnectStreamTls13, DropRecordServer) { TEST_F(TlsConnectStreamTls13, DropRecordClient) { EnsureTlsSetup(); // 0 = ClientHello, 1 = Finished, 2 = first write - client_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2)); + auto filter = MakeTlsFilter<DropTlsRecord>(client_, 2); + filter->EnableDecryption(); Connect(); client_->SendData(26, 26); // This should be dropped, so it won't be counted. client_->ResetSentBytes(); @@ -371,7 +373,8 @@ TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) { class TlsPreCCSHeaderInjector : public TlsRecordFilter { public: - TlsPreCCSHeaderInjector() {} + TlsPreCCSHeaderInjector(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent) {} virtual PacketFilter::Action FilterRecord( const TlsRecordHeader& record_header, const DataBuffer& input, size_t* offset, DataBuffer* output) override { @@ -388,14 +391,14 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter { }; TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) { - client_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>()); + MakeTlsFilter<TlsPreCCSHeaderInjector>(client_); ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); } TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) { - server_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>()); + MakeTlsFilter<TlsPreCCSHeaderInjector>(server_); StartConnect(); ExpectAlert(client_, kTlsAlertUnexpectedMessage); Handshake(); @@ -476,8 +479,7 @@ TEST_F(TlsConnectTest, OneNRecordSplitting) { ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0); EnsureTlsSetup(); ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA); - auto records = std::make_shared<TlsRecordRecorder>(); - server_->SetPacketFilter(records); + auto records = MakeTlsFilter<TlsRecordRecorder>(server_); // This should be split into 1, 16384 and 20. DataBuffer big_buffer; big_buffer.Allocate(1 + 16384 + 20); @@ -535,4 +537,27 @@ INSTANTIATE_TEST_CASE_P(Version12Plus, TlsConnectTls12Plus, ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus)); -} // namespace nspr_test +INSTANTIATE_TEST_CASE_P( + GenericStream, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll, + ::testing::Values(true, false))); +INSTANTIATE_TEST_CASE_P( + GenericDatagram, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus, + ::testing::Values(true, false))); + +INSTANTIATE_TEST_CASE_P( + GenericStream, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P( + GenericDatagram, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); + +INSTANTIATE_TEST_CASE_P(GenericDatagram, TlsConnectTls13ResumptionToken, + TlsConnectTestBase::kTlsVariantsAll); + +} // namespace nss_test |