summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc77
1 files changed, 51 insertions, 26 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
index 4bc6e60ab..f1b78f52f 100644
--- a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -56,7 +56,8 @@ TEST_P(TlsConnectGeneric, CipherSuiteMismatch) {
class TlsAlertRecorder : public TlsRecordFilter {
public:
- TlsAlertRecorder() : level_(255), description_(255) {}
+ TlsAlertRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent), level_(255), description_(255) {}
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
@@ -86,9 +87,9 @@ class TlsAlertRecorder : public TlsRecordFilter {
class HelloTruncator : public TlsHandshakeFilter {
public:
- HelloTruncator()
+ HelloTruncator(const std::shared_ptr<TlsAgent>& agent)
: TlsHandshakeFilter(
- {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {}
+ agent, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {}
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) override {
@@ -99,9 +100,8 @@ class HelloTruncator : public TlsHandshakeFilter {
// Verify that when NSS reports that an alert is sent, it is actually sent.
TEST_P(TlsConnectGeneric, CaptureAlertServer) {
- client_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- server_->SetPacketFilter(alert_recorder);
+ MakeTlsFilter<HelloTruncator>(client_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(server_);
ConnectExpectAlert(server_, kTlsAlertDecodeError);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
@@ -109,9 +109,8 @@ TEST_P(TlsConnectGeneric, CaptureAlertServer) {
}
TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
- server_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- client_->SetPacketFilter(alert_recorder);
+ MakeTlsFilter<HelloTruncator>(server_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_);
ConnectExpectAlert(client_, kTlsAlertDecodeError);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
@@ -120,9 +119,8 @@ TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
// In TLS 1.3, the server can't read the client alert.
TEST_P(TlsConnectTls13, CaptureAlertClient) {
- server_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- client_->SetPacketFilter(alert_recorder);
+ MakeTlsFilter<HelloTruncator>(server_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_);
StartConnect();
@@ -173,7 +171,8 @@ TEST_P(TlsConnectGeneric, ConnectSendReceive) {
class SaveTlsRecord : public TlsRecordFilter {
public:
- SaveTlsRecord(size_t index) : index_(index), count_(0), contents_() {}
+ SaveTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index)
+ : TlsRecordFilter(agent), index_(index), count_(0), contents_() {}
const DataBuffer& contents() const { return contents_; }
@@ -198,8 +197,8 @@ class SaveTlsRecord : public TlsRecordFilter {
TEST_F(TlsConnectStreamTls13, DecryptRecordClient) {
EnsureTlsSetup();
// 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer
- auto saved = std::make_shared<SaveTlsRecord>(3);
- client_->SetTlsRecordFilter(saved);
+ auto saved = MakeTlsFilter<SaveTlsRecord>(client_, 3);
+ saved->EnableDecryption();
Connect();
SendReceive();
@@ -215,8 +214,8 @@ TEST_F(TlsConnectStreamTls13, DecryptRecordServer) {
EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
// 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer
- auto saved = std::make_shared<SaveTlsRecord>(3);
- server_->SetTlsRecordFilter(saved);
+ auto saved = MakeTlsFilter<SaveTlsRecord>(server_, 3);
+ saved->EnableDecryption();
Connect();
SendReceive();
@@ -228,7 +227,8 @@ TEST_F(TlsConnectStreamTls13, DecryptRecordServer) {
class DropTlsRecord : public TlsRecordFilter {
public:
- DropTlsRecord(size_t index) : index_(index), count_(0) {}
+ DropTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index)
+ : TlsRecordFilter(agent), index_(index), count_(0) {}
protected:
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
@@ -253,7 +253,8 @@ TEST_F(TlsConnectStreamTls13, DropRecordServer) {
SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
// 0 = ServerHello, 1 = other handshake, 2 = first write
- server_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2));
+ auto filter = MakeTlsFilter<DropTlsRecord>(server_, 2);
+ filter->EnableDecryption();
Connect();
server_->SendData(23, 23); // This should be dropped, so it won't be counted.
server_->ResetSentBytes();
@@ -263,7 +264,8 @@ TEST_F(TlsConnectStreamTls13, DropRecordServer) {
TEST_F(TlsConnectStreamTls13, DropRecordClient) {
EnsureTlsSetup();
// 0 = ClientHello, 1 = Finished, 2 = first write
- client_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2));
+ auto filter = MakeTlsFilter<DropTlsRecord>(client_, 2);
+ filter->EnableDecryption();
Connect();
client_->SendData(26, 26); // This should be dropped, so it won't be counted.
client_->ResetSentBytes();
@@ -371,7 +373,8 @@ TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) {
class TlsPreCCSHeaderInjector : public TlsRecordFilter {
public:
- TlsPreCCSHeaderInjector() {}
+ TlsPreCCSHeaderInjector(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent) {}
virtual PacketFilter::Action FilterRecord(
const TlsRecordHeader& record_header, const DataBuffer& input,
size_t* offset, DataBuffer* output) override {
@@ -388,14 +391,14 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter {
};
TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) {
- client_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>());
+ MakeTlsFilter<TlsPreCCSHeaderInjector>(client_);
ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
}
TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) {
- server_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>());
+ MakeTlsFilter<TlsPreCCSHeaderInjector>(server_);
StartConnect();
ExpectAlert(client_, kTlsAlertUnexpectedMessage);
Handshake();
@@ -476,8 +479,7 @@ TEST_F(TlsConnectTest, OneNRecordSplitting) {
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0);
EnsureTlsSetup();
ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA);
- auto records = std::make_shared<TlsRecordRecorder>();
- server_->SetPacketFilter(records);
+ auto records = MakeTlsFilter<TlsRecordRecorder>(server_);
// This should be split into 1, 16384 and 20.
DataBuffer big_buffer;
big_buffer.Allocate(1 + 16384 + 20);
@@ -535,4 +537,27 @@ INSTANTIATE_TEST_CASE_P(Version12Plus, TlsConnectTls12Plus,
::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV12Plus));
-} // namespace nspr_test
+INSTANTIATE_TEST_CASE_P(
+ GenericStream, TlsConnectGenericResumption,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll,
+ ::testing::Values(true, false)));
+INSTANTIATE_TEST_CASE_P(
+ GenericDatagram, TlsConnectGenericResumption,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus,
+ ::testing::Values(true, false)));
+
+INSTANTIATE_TEST_CASE_P(
+ GenericStream, TlsConnectGenericResumptionToken,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_CASE_P(
+ GenericDatagram, TlsConnectGenericResumptionToken,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus));
+
+INSTANTIATE_TEST_CASE_P(GenericDatagram, TlsConnectTls13ResumptionToken,
+ TlsConnectTestBase::kTlsVariantsAll);
+
+} // namespace nss_test