diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest')
46 files changed, 5531 insertions, 1450 deletions
diff --git a/security/nss/gtests/ssl_gtest/Makefile b/security/nss/gtests/ssl_gtest/Makefile index a9a9290e0..95c111aeb 100644 --- a/security/nss/gtests/ssl_gtest/Makefile +++ b/security/nss/gtests/ssl_gtest/Makefile @@ -29,10 +29,6 @@ include ../common/gtest.mk CFLAGS += -I$(CORE_DEPTH)/lib/ssl -ifdef NSS_SSL_ENABLE_ZLIB -include $(CORE_DEPTH)/coreconf/zlib.mk -endif - ifdef NSS_DISABLE_TLS_1_3 NSS_DISABLE_TLS_1_3=1 # Run parameterized tests only, for which we can easily exclude TLS 1.3 diff --git a/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc new file mode 100644 index 000000000..6efe06ec7 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc @@ -0,0 +1,108 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +extern "C" { +#include "sslbloom.h" +} + +#include "gtest_utils.h" + +namespace nss_test { + +// Some random-ish inputs to test with. These don't result in collisions in any +// of the configurations that are tested below. +static const uint8_t kHashes1[] = { + 0x79, 0x53, 0xb8, 0xdd, 0x6b, 0x98, 0xce, 0x00, 0xb7, 0xdc, 0xe8, + 0x03, 0x70, 0x8c, 0xe3, 0xac, 0x06, 0x8b, 0x22, 0xfd, 0x0e, 0x34, + 0x48, 0xe6, 0xe5, 0xe0, 0x8a, 0xd6, 0x16, 0x18, 0xe5, 0x48}; +static const uint8_t kHashes2[] = { + 0xc6, 0xdd, 0x6e, 0xc4, 0x76, 0xb8, 0x55, 0xf2, 0xa4, 0xfc, 0x59, + 0x04, 0xa4, 0x90, 0xdc, 0xa7, 0xa7, 0x0d, 0x94, 0x8f, 0xc2, 0xdc, + 0x15, 0x6d, 0x48, 0x93, 0x9d, 0x05, 0xbb, 0x9a, 0xbc, 0xc1}; + +typedef struct { + unsigned int k; + unsigned int bits; +} BloomFilterConfig; + +class BloomFilterTest + : public ::testing::Test, + public ::testing::WithParamInterface<BloomFilterConfig> { + public: + BloomFilterTest() : filter_() {} + + void SetUp() { Init(); } + + void TearDown() { sslBloom_Destroy(&filter_); } + + protected: + void Init() { + if (filter_.filter) { + sslBloom_Destroy(&filter_); + } + ASSERT_EQ(SECSuccess, + sslBloom_Init(&filter_, GetParam().k, GetParam().bits)); + } + + bool Check(const uint8_t* hashes) { + return sslBloom_Check(&filter_, hashes) ? true : false; + } + + void Add(const uint8_t* hashes, bool expect_collision = false) { + EXPECT_EQ(expect_collision, sslBloom_Add(&filter_, hashes) ? true : false); + EXPECT_TRUE(Check(hashes)); + } + + sslBloomFilter filter_; +}; + +TEST_P(BloomFilterTest, InitOnly) {} + +TEST_P(BloomFilterTest, AddToEmpty) { + EXPECT_FALSE(Check(kHashes1)); + Add(kHashes1); +} + +TEST_P(BloomFilterTest, AddTwo) { + Add(kHashes1); + Add(kHashes2); +} + +TEST_P(BloomFilterTest, AddOneTwice) { + Add(kHashes1); + Add(kHashes1, true); +} + +TEST_P(BloomFilterTest, Zero) { + Add(kHashes1); + sslBloom_Zero(&filter_); + EXPECT_FALSE(Check(kHashes1)); + EXPECT_FALSE(Check(kHashes2)); +} + +TEST_P(BloomFilterTest, Fill) { + sslBloom_Fill(&filter_); + EXPECT_TRUE(Check(kHashes1)); + EXPECT_TRUE(Check(kHashes2)); +} + +static const BloomFilterConfig kBloomFilterConfigurations[] = { + {1, 1}, // 1 hash, 1 bit input - high chance of collision. + {1, 2}, // 1 hash, 2 bits - smaller than the basic unit size. + {1, 3}, // 1 hash, 3 bits - same as basic unit size. + {1, 4}, // 1 hash, 4 bits - 2 octets each. + {3, 10}, // 3 hashes over a reasonable number of bits. + {3, 3}, // Test that we can read multiple bits. + {4, 15}, // A credible filter. + {2, 18}, // A moderately large allocation. + {16, 16}, // Insane, use all of the bits from the hashes. + {16, 9}, // This also uses all of the bits from the hashes. +}; + +INSTANTIATE_TEST_CASE_P(BloomFilterConfigurations, BloomFilterTest, + ::testing::ValuesIn(kBloomFilterConfigurations)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.c b/security/nss/gtests/ssl_gtest/libssl_internals.c index 97b8354ae..17b4ffe49 100644 --- a/security/nss/gtests/ssl_gtest/libssl_internals.c +++ b/security/nss/gtests/ssl_gtest/libssl_internals.c @@ -34,18 +34,17 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, return SECFailure; } - ssl3_InitState(ss); ssl3_RestartHandshakeHashes(ss); // Ensure we don't overrun hs.client_random. rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len); - // Zero the client_random struct. - PORT_Memset(&ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH); + // Zero the client_random. + PORT_Memset(ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH); // Copy over the challenge bytes. size_t offset = SSL3_RANDOM_LENGTH - rnd_len; - PORT_Memcpy(&ss->ssl3.hs.client_random.rand[offset], rnd, rnd_len); + PORT_Memcpy(ss->ssl3.hs.client_random + offset, rnd, rnd_len); // Rehash the SSLv2 client hello message. return ssl3_UpdateHandshakeHashes(ss, msg, msg_len); @@ -73,10 +72,11 @@ SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) { return SECFailure; } ss->ssl3.mtu = mtu; + ss->ssl3.hs.rtRetries = 0; /* Avoid DTLS shrinking the MTU any more. */ return SECSuccess; } -PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) { +PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd) { PRCList *cur_p; PRInt32 ct = 0; @@ -92,7 +92,7 @@ PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) { return ct; } -void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) { +void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd) { PRCList *cur_p; sslSocket *ss = ssl_FindSocket(fd); @@ -100,27 +100,31 @@ void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) { return; } - fprintf(stderr, "Cipher specs\n"); + fprintf(stderr, "Cipher specs for %s\n", label); for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs); cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) { ssl3CipherSpec *spec = (ssl3CipherSpec *)cur_p; - fprintf(stderr, " %s\n", spec->phase); + fprintf(stderr, " %s spec epoch=%d (%s) refct=%d\n", SPEC_DIR(spec), + spec->epoch, spec->phase, spec->refCt); } } -/* Force a timer expiry by backdating when the timer was started. - * We could set the remaining time to 0 but then backoff would not - * work properly if we decide to test it. */ -void SSLInt_ForceTimerExpiry(PRFileDesc *fd) { +/* Force a timer expiry by backdating when all active timers were started. We + * could set the remaining time to 0 but then backoff would not work properly if + * we decide to test it. */ +SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift) { + size_t i; sslSocket *ss = ssl_FindSocket(fd); if (!ss) { - return; + return SECFailure; } - if (!ss->ssl3.hs.rtTimerCb) return; - - ss->ssl3.hs.rtTimerStarted = - PR_IntervalNow() - PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs + 1); + for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) { + if (ss->ssl3.hs.timers[i].cb) { + ss->ssl3.hs.timers[i].started -= shift; + } + } + return SECSuccess; } #define CHECK_SECRET(secret) \ @@ -136,7 +140,6 @@ PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) { } CHECK_SECRET(currentSecret); - CHECK_SECRET(resumptionMasterSecret); CHECK_SECRET(dheSecret); CHECK_SECRET(clientEarlyTrafficSecret); CHECK_SECRET(clientHsTrafficSecret); @@ -226,28 +229,7 @@ PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) { return PR_TRUE; } -PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd) { - sslSocket *ss = ssl_FindSocket(fd); - if (!ss) { - return PR_FALSE; - } - - ssl_GetSSL3HandshakeLock(ss); - ssl_GetXmitBufLock(ss); - - SECStatus rv = tls13_SendNewSessionTicket(ss); - if (rv == SECSuccess) { - rv = ssl3_FlushHandshake(ss, 0); - } - - ssl_ReleaseXmitBufLock(ss); - ssl_ReleaseSSL3HandshakeLock(ss); - - return rv == SECSuccess; -} - SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { - PRUint64 epoch; sslSocket *ss; ssl3CipherSpec *spec; @@ -255,43 +237,40 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { if (!ss) { return SECFailure; } - if (to >= (1ULL << 48)) { + if (to >= RECORD_SEQ_MAX) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } ssl_GetSpecWriteLock(ss); spec = ss->ssl3.crSpec; - epoch = spec->read_seq_num >> 48; - spec->read_seq_num = (epoch << 48) | to; + spec->seqNum = to; /* For DTLS, we need to fix the record sequence number. For this, we can just * scrub the entire structure on the assumption that the new sequence number * is far enough past the last received sequence number. */ - if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { + if (spec->seqNum <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } - dtls_RecordSetRecvd(&spec->recvdRecords, to); + dtls_RecordSetRecvd(&spec->recvdRecords, spec->seqNum); ssl_ReleaseSpecWriteLock(ss); return SECSuccess; } SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) { - PRUint64 epoch; sslSocket *ss; ss = ssl_FindSocket(fd); if (!ss) { return SECFailure; } - if (to >= (1ULL << 48)) { + if (to >= RECORD_SEQ_MAX) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } ssl_GetSpecWriteLock(ss); - epoch = ss->ssl3.cwSpec->write_seq_num >> 48; - ss->ssl3.cwSpec->write_seq_num = (epoch << 48) | to; + ss->ssl3.cwSpec->seqNum = to; ssl_ReleaseSpecWriteLock(ss); return SECSuccess; } @@ -305,9 +284,9 @@ SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) { return SECFailure; } ssl_GetSpecReadLock(ss); - to = ss->ssl3.cwSpec->write_seq_num + DTLS_RECVD_RECORDS_WINDOW + extra; + to = ss->ssl3.cwSpec->seqNum + DTLS_RECVD_RECORDS_WINDOW + extra; ssl_ReleaseSpecReadLock(ss); - return SSLInt_AdvanceWriteSeqNum(fd, to & RECORD_SEQ_MAX); + return SSLInt_AdvanceWriteSeqNum(fd, to); } SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) { @@ -333,56 +312,26 @@ SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd, return SECSuccess; } -static ssl3KeyMaterial *GetKeyingMaterial(PRBool isServer, - ssl3CipherSpec *spec) { - return isServer ? &spec->server : &spec->client; +PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec) { + return spec->keyMaterial.key; } -PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec) { - return GetKeyingMaterial(isServer, spec)->write_key; +SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec) { + return spec->cipherDef->calg; } -SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer, - ssl3CipherSpec *spec) { - return spec->cipher_def->calg; +const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec) { + return spec->keyMaterial.iv; } -unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec) { - return GetKeyingMaterial(isServer, spec)->write_iv; -} - -SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd) { - sslSocket *ss; - - ss = ssl_FindSocket(fd); - if (!ss) { - return SECFailure; - } - - ss->opt.enableShortHeaders = PR_TRUE; - return SECSuccess; -} - -SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result) { - sslSocket *ss; - - ss = ssl_FindSocket(fd); - if (!ss) { - return SECFailure; - } - - *result = ss->ssl3.hs.shortHeaders; - return SECSuccess; +PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec) { + return spec->epoch; } void SSLInt_SetTicketLifetime(uint32_t lifetime) { ssl_ticket_lifetime = lifetime; } -void SSLInt_SetMaxEarlyDataSize(uint32_t size) { - ssl_max_early_data_size = size; -} - SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) { sslSocket *ss; @@ -405,3 +354,21 @@ SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) { return SECSuccess; } + +void SSLInt_RolloverAntiReplay(void) { + tls13_AntiReplayRollover(ssl_TimeUsec()); +} + +SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch, + PRUint16 *writeEpoch) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss || !readEpoch || !writeEpoch) { + return SECFailure; + } + + ssl_GetSpecReadLock(ss); + *readEpoch = ss->ssl3.crSpec->epoch; + *writeEpoch = ss->ssl3.cwSpec->epoch; + ssl_ReleaseSpecReadLock(ss); + return SECSuccess; +} diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.h b/security/nss/gtests/ssl_gtest/libssl_internals.h index 33709c4b4..3efb362c2 100644 --- a/security/nss/gtests/ssl_gtest/libssl_internals.h +++ b/security/nss/gtests/ssl_gtest/libssl_internals.h @@ -24,9 +24,9 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext); void SSLInt_ClearSelfEncryptKey(); void SSLInt_SetSelfEncryptMacKey(PK11SymKey *key); -PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd); -void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd); -void SSLInt_ForceTimerExpiry(PRFileDesc *fd); +PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd); +void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd); +SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift); SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu); PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd); PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd); @@ -35,23 +35,22 @@ PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd); SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len); PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType); PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type); -PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd); SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to); SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to); SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra); SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group); +SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch, + PRUint16 *writeEpoch); SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd, sslCipherSpecChangedFunc func, void *arg); -PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec); -SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer, - ssl3CipherSpec *spec); -unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec); -SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd); -SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result); +PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec); +PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec); +SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec); +const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec); void SSLInt_SetTicketLifetime(uint32_t lifetime); -void SSLInt_SetMaxEarlyDataSize(uint32_t size); SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size); +void SSLInt_RolloverAntiReplay(void); #endif // ndef libssl_internals_h_ diff --git a/security/nss/gtests/ssl_gtest/manifest.mn b/security/nss/gtests/ssl_gtest/manifest.mn index cc729c0f1..5d893bab3 100644 --- a/security/nss/gtests/ssl_gtest/manifest.mn +++ b/security/nss/gtests/ssl_gtest/manifest.mn @@ -12,11 +12,13 @@ CSRCS = \ $(NULL) CPPSRCS = \ + bloomfilter_unittest.cc \ ssl_0rtt_unittest.cc \ ssl_agent_unittest.cc \ ssl_auth_unittest.cc \ ssl_cert_ext_unittest.cc \ ssl_ciphersuite_unittest.cc \ + ssl_custext_unittest.cc \ ssl_damage_unittest.cc \ ssl_dhe_unittest.cc \ ssl_drop_unittest.cc \ @@ -29,11 +31,16 @@ CPPSRCS = \ ssl_gather_unittest.cc \ ssl_gtest.cc \ ssl_hrr_unittest.cc \ + ssl_keylog_unittest.cc \ + ssl_keyupdate_unittest.cc \ ssl_loopback_unittest.cc \ + ssl_misc_unittest.cc \ ssl_record_unittest.cc \ ssl_resumption_unittest.cc \ + ssl_renegotiation_unittest.cc \ ssl_skip_unittest.cc \ ssl_staticrsa_unittest.cc \ + ssl_tls13compat_unittest.cc \ ssl_v2_client_hello_unittest.cc \ ssl_version_unittest.cc \ ssl_versionpolicy_unittest.cc \ diff --git a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc index 85b7011a1..08781af71 100644 --- a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc @@ -7,6 +7,7 @@ #include "secerr.h" #include "ssl.h" #include "sslerr.h" +#include "sslexp.h" #include "sslproto.h" extern "C" { @@ -44,6 +45,92 @@ TEST_P(TlsConnectTls13, ZeroRttServerRejectByOption) { SendReceive(); } +TEST_P(TlsConnectTls13, ZeroRttApparentReplayAfterRestart) { + // The test fixtures call SSL_SetupAntiReplay() in SetUp(). This results in + // 0-RTT being rejected until at least one window passes. SetupFor0Rtt() + // forces a rollover of the anti-replay filters, which clears this state. + // Here, we do the setup manually here without that forced rollover. + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + StartConnect(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + +class TlsZeroRttReplayTest : public TlsConnectTls13 { + private: + class SaveFirstPacket : public PacketFilter { + public: + PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + if (!packet_.len() && input.len()) { + packet_ = input; + } + return KEEP; + } + + const DataBuffer& packet() const { return packet_; } + + private: + DataBuffer packet_; + }; + + protected: + void RunTest(bool rollover) { + // Run the initial handshake + SetupForZeroRtt(); + + // Now run a true 0-RTT handshake, but capture the first packet. + auto first_packet = std::make_shared<SaveFirstPacket>(); + client_->SetFilter(first_packet); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + EXPECT_LT(0U, first_packet->packet().len()); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + + if (rollover) { + SSLInt_RolloverAntiReplay(); + } + + // Now replay that packet against the server. + Reset(); + server_->StartConnect(); + server_->Set0RttEnabled(true); + + // Capture the early_data extension, which should not appear. + auto early_data_ext = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_early_data_xtn); + + // Finally, replay the ClientHello and force the server to consume it. Stop + // after the server sends its first flight; the client will not be able to + // complete this handshake. + server_->adapter()->PacketReceived(first_packet->packet()); + server_->Handshake(); + EXPECT_FALSE(early_data_ext->captured()); + } +}; + +TEST_P(TlsZeroRttReplayTest, ZeroRttReplay) { RunTest(false); } + +TEST_P(TlsZeroRttReplayTest, ZeroRttReplayAfterRollover) { RunTest(true); } + // Test that we don't try to send 0-RTT data when the server sent // us a ticket without the 0-RTT flags. TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) { @@ -52,8 +139,7 @@ TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) { SendReceive(); // Need to read so that we absorb the session ticket. CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); Reset(); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); // Now turn on 0-RTT but too late for the ticket. client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); @@ -80,8 +166,7 @@ TEST_P(TlsConnectTls13, ZeroRttServerForgetTicket) { TEST_P(TlsConnectTls13, ZeroRttServerOnly) { ExpectResumption(RESUME_NONE); server_->Set0RttEnabled(true); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); // Client sends ordinary ClientHello. client_->Handshake(); @@ -99,6 +184,61 @@ TEST_P(TlsConnectTls13, ZeroRttServerOnly) { CheckKeys(); } +// A small sleep after sending the ClientHello means that the ticket age that +// arrives at the server is too low. With a small tolerance for variation in +// ticket age (which is determined by the |window| parameter that is passed to +// SSL_SetupAntiReplay()), the server then rejects early data. +TEST_P(TlsConnectTls13, ZeroRttRejectOldTicket) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + EXPECT_EQ(SECSuccess, SSL_SetupAntiReplay(1, 1, 3)); + SSLInt_RolloverAntiReplay(); // Make sure to flush replay state. + SSLInt_RolloverAntiReplay(); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false, []() { + PR_Sleep(PR_MillisecondsToInterval(10)); + return true; + }); + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); + SendReceive(); +} + +// In this test, we falsely inflate the estimate of the RTT by delaying the +// ServerHello on the first handshake. This results in the server estimating a +// higher value of the ticket age than the client ultimately provides. Add a +// small tolerance for variation in ticket age and the ticket will appear to +// arrive prematurely, causing the server to reject early data. +TEST_P(TlsConnectTls13, ZeroRttRejectPrematureTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->Set0RttEnabled(true); + StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + PR_Sleep(PR_MillisecondsToInterval(10)); + Handshake(); // Remainder of handshake + CheckConnected(); + SendReceive(); + CheckKeys(); + + Reset(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + EXPECT_EQ(SECSuccess, SSL_SetupAntiReplay(1, 1, 3)); + SSLInt_RolloverAntiReplay(); // Make sure to flush replay state. + SSLInt_RolloverAntiReplay(); + ExpectResumption(RESUME_TICKET); + ExpectEarlyDataAccepted(false); + StartConnect(); + ZeroRttSendReceive(true, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) { EnableAlpn(); SetupForZeroRtt(); @@ -117,6 +257,14 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) { CheckAlpn("a"); } +// NOTE: In this test and those below, the client always sends +// post-ServerHello alerts with the handshake keys, even if the server +// has accepted 0-RTT. In some cases, as with errors in +// EncryptedExtensions, the client can't know the server's behavior, +// and in others it's just simpler. What the server is expecting +// depends on whether it accepted 0-RTT or not. Eventually, we may +// make the server trial decrypt. +// // Have the server negotiate a different ALPN value, and therefore // reject 0-RTT. TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeServer) { @@ -155,12 +303,17 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnServer) { client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a"); EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, sizeof(b))); client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b"); - ExpectAlert(client_, kTlsAlertIllegalParameter); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); return true; }); - Handshake(); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + } else { + client_->Handshake(); + } client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID); - server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); } // Set up with no ALPN and then set the client so it thinks it has ALPN. @@ -175,12 +328,17 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnClient) { PRUint8 b[] = {'b'}; EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, 1)); client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b"); - ExpectAlert(client_, kTlsAlertIllegalParameter); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); return true; }); - Handshake(); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + } else { + client_->Handshake(); + } client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID); - server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); } // Remove the old ALPN value and so the client will not offer early data. @@ -218,9 +376,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) { SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_2); - client_->StartConnect(); - server_->StartConnect(); - + StartConnect(); // We will send the early data xtn without sending actual early data. Thus // a 1.2 server shouldn't fail until the client sends an alert because the // client sends end_of_early_data only after reading the server's flight. @@ -248,6 +404,9 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) { // The client should abort the connection when sending a 0-rtt handshake but // the servers responds with a TLS 1.2 ServerHello. (with app data) TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) { + const char* k0RttData = "ABCDEF"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); server_->Set0RttEnabled(true); // set ticket_allow_early_data Connect(); @@ -261,33 +420,32 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) { SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_2); - client_->StartConnect(); - server_->StartConnect(); - + StartConnect(); // Send the early data xtn in the CH, followed by early app data. The server // will fail right after sending its flight, when receiving the early data. client_->Set0RttEnabled(true); - ZeroRttSendReceive(true, false, [this]() { - client_->ExpectSendAlert(kTlsAlertIllegalParameter); - if (variant_ == ssl_variant_stream) { - server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); - } - return true; - }); - - client_->Handshake(); - server_->Handshake(); - ASSERT_TRUE_WAIT( - (client_->error_code() == SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA), 2000); + client_->Handshake(); // Send ClientHello. + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + EXPECT_EQ(k0RttDataLen, rv); - // DTLS will timeout as we bump the epoch when installing the early app data - // cipher suite. Thus the encrypted alert will be ignored. if (variant_ == ssl_variant_stream) { - // The server sends an alert when receiving the early app data record. - ASSERT_TRUE_WAIT( - (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA), - 2000); + // When the server receives the early data, it will fail. + server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); + server_->Handshake(); // Consume ClientHello + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA); + } else { + // If it's datagram, we just discard the early data. + server_->Handshake(); // Consume ClientHello + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); } + + // The client now reads the ServerHello and fails. + ASSERT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA); } static void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent, @@ -300,17 +458,19 @@ static void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent, } TEST_P(TlsConnectTls13, SendTooMuchEarlyData) { + EnsureTlsSetup(); const char* big_message = "0123456789abcdef"; const size_t short_size = strlen(big_message) - 1; const PRInt32 short_length = static_cast<PRInt32>(short_size); - SSLInt_SetMaxEarlyDataSize(static_cast<PRUint32>(short_size)); + EXPECT_EQ(SECSuccess, + SSL_SetMaxEarlyDataSize(server_->ssl_fd(), + static_cast<PRUint32>(short_size))); SetupForZeroRtt(); client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); ExpectResumption(RESUME_TICKET); - ExpectAlert(client_, kTlsAlertEndOfEarlyData); client_->Handshake(); CheckEarlyDataLimit(client_, short_size); @@ -356,18 +516,21 @@ TEST_P(TlsConnectTls13, SendTooMuchEarlyData) { } TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) { + EnsureTlsSetup(); + const size_t limit = 5; - SSLInt_SetMaxEarlyDataSize(limit); + EXPECT_EQ(SECSuccess, SSL_SetMaxEarlyDataSize(server_->ssl_fd(), limit)); SetupForZeroRtt(); client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); ExpectResumption(RESUME_TICKET); - client_->ExpectSendAlert(kTlsAlertEndOfEarlyData); client_->Handshake(); // Send ClientHello CheckEarlyDataLimit(client_, limit); + server_->Handshake(); // Process ClientHello, send server flight. + // Lift the limit on the client. EXPECT_EQ(SECSuccess, SSLInt_SetSocketMaxEarlyDataSize(client_->ssl_fd(), 1000)); @@ -381,22 +544,114 @@ TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) { // This error isn't fatal for DTLS. ExpectAlert(server_, kTlsAlertUnexpectedMessage); } - server_->Handshake(); // Process ClientHello, send server flight. - server_->Handshake(); // Just to make sure that we don't read ahead. + + server_->Handshake(); // This reads the early data and maybe throws an error. + if (variant_ == ssl_variant_stream) { + server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA); + } else { + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + } CheckEarlyDataLimit(server_, limit); - // Attempt to read early data. + // Attempt to read early data. This will get an error. std::vector<uint8_t> buf(strlen(message) + 1); EXPECT_GT(0, PR_Read(server_->ssl_fd(), buf.data(), buf.capacity())); if (variant_ == ssl_variant_stream) { - server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA); + EXPECT_EQ(SSL_ERROR_HANDSHAKE_FAILED, PORT_GetError()); + } else { + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); } - client_->Handshake(); // Process the handshake. - client_->Handshake(); // Process the alert. + client_->Handshake(); // Process the server's first flight. if (variant_ == ssl_variant_stream) { + client_->Handshake(); // Process the alert. client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + } else { + server_->Handshake(); // Finish connecting. + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); } } +class PacketCoalesceFilter : public PacketFilter { + public: + PacketCoalesceFilter() : packet_data_() {} + + void SendCoalesced(std::shared_ptr<TlsAgent> agent) { + agent->SendDirect(packet_data_); + } + + protected: + PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + packet_data_.Write(packet_data_.len(), input); + return DROP; + } + + private: + DataBuffer packet_data_; +}; + +TEST_P(TlsConnectTls13, ZeroRttOrdering) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + // Send out the ClientHello. + client_->Handshake(); + + // Now, coalesce the next three things from the client: early data, second + // flight and 1-RTT data. + auto coalesce = std::make_shared<PacketCoalesceFilter>(); + client_->SetFilter(coalesce); + + // Send (and hold) early data. + static const std::vector<uint8_t> early_data = {3, 2, 1}; + EXPECT_EQ(static_cast<PRInt32>(early_data.size()), + PR_Write(client_->ssl_fd(), early_data.data(), early_data.size())); + + // Send (and hold) the second client handshake flight. + // The client sends EndOfEarlyData after seeing the server Finished. + server_->Handshake(); + client_->Handshake(); + + // Send (and hold) 1-RTT data. + static const std::vector<uint8_t> late_data = {7, 8, 9, 10}; + EXPECT_EQ(static_cast<PRInt32>(late_data.size()), + PR_Write(client_->ssl_fd(), late_data.data(), late_data.size())); + + // Now release them all at once. + coalesce->SendCoalesced(client_); + + // Now ensure that the three steps are exposed in the right order on the + // server: delivery of early data, handshake callback, delivery of 1-RTT. + size_t step = 0; + server_->SetHandshakeCallback([&step](TlsAgent*) { + EXPECT_EQ(1U, step); + ++step; + }); + + std::vector<uint8_t> buf(10); + PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.size()); + ASSERT_EQ(static_cast<PRInt32>(early_data.size()), read); + buf.resize(read); + EXPECT_EQ(early_data, buf); + EXPECT_EQ(0U, step); + ++step; + + // The third read should be after the handshake callback and should return the + // data that was sent after the handshake completed. + buf.resize(10); + read = PR_Read(server_->ssl_fd(), buf.data(), buf.size()); + ASSERT_EQ(static_cast<PRInt32>(late_data.size()), read); + buf.resize(read); + EXPECT_EQ(late_data, buf); + EXPECT_EQ(2U, step); +} + +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_CASE_P(Tls13ZeroRttReplayTest, TlsZeroRttReplayTest, + TlsConnectTestBase::kTlsVariantsAll); +#endif + } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc index 5035a338d..f0c57e8b1 100644 --- a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc @@ -31,7 +31,7 @@ const static uint8_t kCannedTls13ClientHello[] = { 0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x09, 0x00, 0x00, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x12, 0x00, 0x10, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x01, - 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x28, 0x00, + 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x33, 0x00, 0x47, 0x00, 0x45, 0x00, 0x17, 0x00, 0x41, 0x04, 0x86, 0x4a, 0xb9, 0xdc, 0x6a, 0x38, 0xa7, 0xce, 0xe7, 0xc2, 0x4f, 0xa6, 0x28, 0xb9, 0xdc, 0x65, 0xbf, 0x73, 0x47, 0x3c, 0x9c, 0x65, 0x8c, 0x47, 0x6d, 0x57, 0x22, 0x8a, @@ -44,13 +44,14 @@ const static uint8_t kCannedTls13ClientHello[] = { 0x02, 0x05, 0x02, 0x06, 0x02, 0x02, 0x02}; const static uint8_t kCannedTls13ServerHello[] = { - 0x7f, kD13, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, 0xf0, - 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, 0xdf, 0xe5, - 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, 0x08, 0x13, 0x01, - 0x00, 0x28, 0x00, 0x28, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, - 0x23, 0x17, 0x64, 0x23, 0x03, 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, - 0x24, 0xa1, 0x6c, 0xa9, 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, - 0xcb, 0xe3, 0x08, 0x84, 0xae, 0x19}; + 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, + 0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, + 0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, + 0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, + 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03, + 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9, + 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08, + 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x7f, kD13}; static const char *k0RttData = "ABCDEF"; TEST_P(TlsAgentTest, EarlyFinished) { @@ -159,9 +160,8 @@ TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) { SSL_LIBRARY_VERSION_TLS_1_3); agent_->StartConnect(); agent_->Set0RttEnabled(true); - auto filter = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeClientHello); - agent_->SetPacketFilter(filter); + auto filter = + MakeTlsFilter<TlsHandshakeRecorder>(agent_, kTlsHandshakeClientHello); PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData)); EXPECT_EQ(-1, rv); int32_t err = PORT_GetError(); diff --git a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc index dbcbc9aa3..7f2b2840d 100644 --- a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc @@ -29,7 +29,25 @@ TEST_P(TlsConnectGeneric, ServerAuthBigRsa) { } TEST_P(TlsConnectGeneric, ServerAuthRsaChain) { - Reset(TlsAgent::kServerRsaChain); + Reset("rsa_chain"); + Connect(); + CheckKeys(); + size_t chain_length; + EXPECT_TRUE(client_->GetPeerChainLength(&chain_length)); + EXPECT_EQ(2UL, chain_length); +} + +TEST_P(TlsConnectGeneric, ServerAuthRsaPssChain) { + Reset("rsa_pss_chain"); + Connect(); + CheckKeys(); + size_t chain_length; + EXPECT_TRUE(client_->GetPeerChainLength(&chain_length)); + EXPECT_EQ(2UL, chain_length); +} + +TEST_P(TlsConnectGeneric, ServerAuthRsaCARsaPssChain) { + Reset("rsa_ca_rsa_pss_chain"); Connect(); CheckKeys(); size_t chain_length; @@ -77,10 +95,9 @@ TEST_P(TlsConnectGeneric, ClientAuthBigRsa) { } // Offset is the position in the captured buffer where the signature sits. -static void CheckSigScheme( - std::shared_ptr<TlsInspectorRecordHandshakeMessage>& capture, size_t offset, - std::shared_ptr<TlsAgent>& peer, uint16_t expected_scheme, - size_t expected_size) { +static void CheckSigScheme(std::shared_ptr<TlsHandshakeRecorder>& capture, + size_t offset, std::shared_ptr<TlsAgent>& peer, + uint16_t expected_scheme, size_t expected_size) { EXPECT_LT(offset + 2U, capture->buffer().len()); uint32_t scheme = 0; @@ -96,9 +113,8 @@ static void CheckSigScheme( // in the default certificate. TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) { EnsureTlsSetup(); - auto capture_ske = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(capture_ske); + auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); Connect(); CheckKeys(); @@ -109,15 +125,14 @@ TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) { EXPECT_TRUE(buffer.Read(1, 2, &tmp)) << "read NamedCurve"; EXPECT_EQ(ssl_grp_ec_curve25519, tmp); EXPECT_TRUE(buffer.Read(3, 1, &tmp)) << " read ECPoint"; - CheckSigScheme(capture_ske, 4 + tmp, client_, ssl_sig_rsa_pss_sha256, 1024); + CheckSigScheme(capture_ske, 4 + tmp, client_, ssl_sig_rsa_pss_rsae_sha256, + 1024); } TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) { EnsureTlsSetup(); - auto capture_cert_verify = - std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeCertificateVerify); - client_->SetPacketFilter(capture_cert_verify); + auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>( + client_, kTlsHandshakeCertificateVerify); client_->SetupClientAuth(); server_->RequestClientAuth(true); Connect(); @@ -128,26 +143,23 @@ TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) { TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) { Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048); - auto capture_cert_verify = - std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeCertificateVerify); - client_->SetPacketFilter(capture_cert_verify); + auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>( + client_, kTlsHandshakeCertificateVerify); client_->SetupClientAuth(); server_->RequestClientAuth(true); Connect(); CheckKeys(); - CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_sha256, 2048); + CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_rsae_sha256, + 2048); } class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter { public: + TlsZeroCertificateRequestSigAlgsFilter(const std::shared_ptr<TlsAgent>& agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeCertificateRequest}) {} virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeCertificateRequest) { - return KEEP; - } - TlsParser parser(input); std::cerr << "Zeroing CertReq.supported_signature_algorithms" << std::endl; @@ -189,12 +201,9 @@ class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter { // supported_signature_algorithms in the CertificateRequest message. TEST_P(TlsConnectTls12, ClientAuthNoSigAlgsFallback) { EnsureTlsSetup(); - auto filter = std::make_shared<TlsZeroCertificateRequestSigAlgsFilter>(); - server_->SetPacketFilter(filter); - auto capture_cert_verify = - std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeCertificateVerify); - client_->SetPacketFilter(capture_cert_verify); + MakeTlsFilter<TlsZeroCertificateRequestSigAlgsFilter>(server_); + auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>( + client_, kTlsHandshakeCertificateVerify); client_->SetupClientAuth(); server_->RequestClientAuth(true); @@ -342,8 +351,7 @@ TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) { // The signature_algorithms extension is mandatory in TLS 1.3. TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) { - client_->SetPacketFilter( - std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn)); + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_signature_algorithms_xtn); ConnectExpectAlert(server_, kTlsAlertMissingExtension); client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION); @@ -352,8 +360,7 @@ TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) { // TLS 1.2 has trouble detecting this sort of modification: it uses SHA1 and // only fails when the Finished is checked. TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) { - client_->SetPacketFilter( - std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn)); + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_signature_algorithms_xtn); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -371,11 +378,11 @@ class BeforeFinished : public TlsRecordFilter { enum HandshakeState { BEFORE_CCS, AFTER_CCS, DONE }; public: - BeforeFinished(std::shared_ptr<TlsAgent>& client, - std::shared_ptr<TlsAgent>& server, VoidFunction before_ccs, - VoidFunction before_finished) - : client_(client), - server_(server), + BeforeFinished(const std::shared_ptr<TlsAgent>& server, + const std::shared_ptr<TlsAgent>& client, + VoidFunction before_ccs, VoidFunction before_finished) + : TlsRecordFilter(server), + client_(client), before_ccs_(before_ccs), before_finished_(before_finished), state_(BEFORE_CCS) {} @@ -395,7 +402,7 @@ class BeforeFinished : public TlsRecordFilter { // but that means that they both get processed together. DataBuffer ccs; header.Write(&ccs, 0, body); - server_.lock()->SendDirect(ccs); + agent()->SendDirect(ccs); client_.lock()->Handshake(); state_ = AFTER_CCS; // Request that the original record be dropped by the filter. @@ -420,7 +427,6 @@ class BeforeFinished : public TlsRecordFilter { private: std::weak_ptr<TlsAgent> client_; - std::weak_ptr<TlsAgent> server_; VoidFunction before_ccs_; VoidFunction before_finished_; HandshakeState state_; @@ -445,11 +451,11 @@ class BeforeFinished13 : public PacketFilter { }; public: - BeforeFinished13(std::shared_ptr<TlsAgent>& client, - std::shared_ptr<TlsAgent>& server, + BeforeFinished13(const std::shared_ptr<TlsAgent>& server, + const std::shared_ptr<TlsAgent>& client, VoidFunction before_finished) - : client_(client), - server_(server), + : server_(server), + client_(client), before_finished_(before_finished), records_(0) {} @@ -481,8 +487,8 @@ class BeforeFinished13 : public PacketFilter { } private: - std::weak_ptr<TlsAgent> client_; std::weak_ptr<TlsAgent> server_; + std::weak_ptr<TlsAgent> client_; VoidFunction before_finished_; size_t records_; }; @@ -496,11 +502,9 @@ static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) { // processed by the client, SSL_AuthCertificateComplete() is called. TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) { client_->SetAuthCertificateCallback(AuthCompleteBlock); - server_->SetPacketFilter( - std::make_shared<BeforeFinished13>(client_, server_, [this]() { - EXPECT_EQ(SECSuccess, - SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); - })); + MakeTlsFilter<BeforeFinished13>(server_, client_, [this]() { + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + }); Connect(); } @@ -528,13 +532,13 @@ TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) { TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) { client_->EnableFalseStart(); - server_->SetPacketFilter(std::make_shared<BeforeFinished>( - client_, server_, + MakeTlsFilter<BeforeFinished>( + server_, client_, [this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); }, [this]() { // Write something, which used to fail: bug 1235366. client_->SendData(10); - })); + }); Connect(); server_->SendData(10); @@ -544,8 +548,8 @@ TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) { TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) { client_->EnableFalseStart(); client_->SetAuthCertificateCallback(AuthCompleteBlock); - server_->SetPacketFilter(std::make_shared<BeforeFinished>( - client_, server_, + MakeTlsFilter<BeforeFinished>( + server_, client_, []() { // Do nothing before CCS }, @@ -556,7 +560,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) { SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); EXPECT_TRUE(client_->can_falsestart_hook_called()); client_->SendData(10); - })); + }); Connect(); server_->SendData(10); @@ -581,8 +585,7 @@ class EnforceNoActivity : public PacketFilter { TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { client_->SetAuthCertificateCallback(AuthCompleteBlock); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->Handshake(); // Send ClientHello server_->Handshake(); // Send ServerHello client_->Handshake(); // Send ClientKeyExchange and Finished @@ -591,7 +594,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); // The client should send nothing from here on. - client_->SetPacketFilter(std::make_shared<EnforceNoActivity>()); + client_->SetFilter(std::make_shared<EnforceNoActivity>()); client_->Handshake(); EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); @@ -601,8 +604,33 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); - // Remove this before closing or the close_notify alert will trigger it. - client_->DeletePacketFilter(); + // Remove filter before closing or the close_notify alert will trigger it. + client_->ClearFilter(); +} + +TEST_P(TlsConnectGenericPre13, AuthCompleteFailDelayed) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + client_->Handshake(); // Send ClientKeyExchange and Finished + server_->Handshake(); // Send Finished + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + // The client should send nothing from here on. + client_->SetFilter(std::make_shared<EnforceNoActivity>()); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // Report failure. + client_->ClearFilter(); + client_->ExpectSendAlert(kTlsAlertBadCertificate); + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), + SSL_ERROR_BAD_CERTIFICATE)); + client_->Handshake(); // Fail + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); } // TLS 1.3 handles a delayed AuthComplete callback differently since the @@ -610,20 +638,19 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { TEST_P(TlsConnectTls13, AuthCompleteDelayed) { client_->SetAuthCertificateCallback(AuthCompleteBlock); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->Handshake(); // Send ClientHello server_->Handshake(); // Send ServerHello EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); // The client will send nothing until AuthCertificateComplete is called. - client_->SetPacketFilter(std::make_shared<EnforceNoActivity>()); + client_->SetFilter(std::make_shared<EnforceNoActivity>()); client_->Handshake(); EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); // This should allow the handshake to complete now. - client_->DeletePacketFilter(); + client_->ClearFilter(); EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); client_->Handshake(); // Send Finished server_->Handshake(); // Transition to connected and send NewSessionTicket @@ -631,6 +658,44 @@ TEST_P(TlsConnectTls13, AuthCompleteDelayed) { EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); } +TEST_P(TlsConnectTls13, AuthCompleteFailDelayed) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + + // The client will send nothing until AuthCertificateComplete is called. + client_->SetFilter(std::make_shared<EnforceNoActivity>()); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // Report failure. + client_->ClearFilter(); + ExpectAlert(client_, kTlsAlertBadCertificate); + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), + SSL_ERROR_BAD_CERTIFICATE)); + client_->Handshake(); // This should now fail. + server_->Handshake(); // Get the error. + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); +} + +static SECStatus AuthCompleteFail(TlsAgent*, PRBool, PRBool) { + PORT_SetError(SSL_ERROR_BAD_CERTIFICATE); + return SECFailure; +} + +TEST_P(TlsConnectGeneric, AuthFailImmediate) { + client_->SetAuthCertificateCallback(AuthCompleteFail); + + StartConnect(); + ConnectExpectAlert(client_, kTlsAlertBadCertificate); + client_->CheckErrorCode(SSL_ERROR_BAD_CERTIFICATE); +} + static const SSLExtraServerCertData ServerCertDataRsaPkcs1Decrypt = { ssl_auth_rsa_decrypt, nullptr, nullptr, nullptr}; static const SSLExtraServerCertData ServerCertDataRsaPkcs1Sign = { @@ -753,8 +818,7 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) { TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) { Reset(certificate_); auto capture = - std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn); - client_->SetPacketFilter(capture); + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn); TestSignatureSchemeConfig(client_); const DataBuffer& ext = capture->extension(); @@ -782,8 +846,8 @@ INSTANTIATE_TEST_CASE_P( ::testing::Values(TlsAgent::kServerRsaSign), ::testing::Values(ssl_auth_rsa_sign), ::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384, - ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_sha256, - ssl_sig_rsa_pss_sha384))); + ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_rsae_sha384))); // PSS with SHA-512 needs a bigger key to work. INSTANTIATE_TEST_CASE_P( SignatureSchemeBigRsa, TlsSignatureSchemeConfiguration, @@ -791,7 +855,7 @@ INSTANTIATE_TEST_CASE_P( TlsConnectTestBase::kTlsV12Plus, ::testing::Values(TlsAgent::kRsa2048), ::testing::Values(ssl_auth_rsa_sign), - ::testing::Values(ssl_sig_rsa_pss_sha512))); + ::testing::Values(ssl_sig_rsa_pss_rsae_sha512))); INSTANTIATE_TEST_CASE_P( SignatureSchemeRsaSha1, TlsSignatureSchemeConfiguration, ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, @@ -828,4 +892,4 @@ INSTANTIATE_TEST_CASE_P( TlsAgent::kServerEcdsa384), ::testing::Values(ssl_auth_ecdsa), ::testing::Values(ssl_sig_ecdsa_sha1))); -} +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc index 3463782e0..573c69c75 100644 --- a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc @@ -82,9 +82,8 @@ TEST_P(TlsConnectGenericPre13, SignedCertificateTimestampsLegacy) { ssl_kea_rsa)); EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(), &kSctItem, ssl_kea_rsa)); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, - PR_TRUE)); + + client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); SignedCertificateTimestampsExtractor timestamps_extractor(client_); Connect(); @@ -96,9 +95,7 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsSuccess) { EnsureTlsSetup(); EXPECT_TRUE( server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData)); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, - PR_TRUE)); + client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); SignedCertificateTimestampsExtractor timestamps_extractor(client_); Connect(); @@ -120,9 +117,7 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveClient) { TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveServer) { EnsureTlsSetup(); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, - PR_TRUE)); + client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); SignedCertificateTimestampsExtractor timestamps_extractor(client_); Connect(); @@ -173,23 +168,20 @@ TEST_P(TlsConnectGeneric, OcspNotRequested) { // Even if the client asks, the server has nothing unless it is configured. TEST_P(TlsConnectGeneric, OcspNotProvided) { EnsureTlsSetup(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); client_->SetAuthCertificateCallback(CheckNoOCSP); Connect(); } TEST_P(TlsConnectGenericPre13, OcspMangled) { EnsureTlsSetup(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); EXPECT_TRUE( server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData)); static const uint8_t val[] = {1}; - auto replacer = std::make_shared<TlsExtensionReplacer>( - ssl_cert_status_xtn, DataBuffer(val, sizeof(val))); - server_->SetPacketFilter(replacer); + auto replacer = MakeTlsFilter<TlsExtensionReplacer>( + server_, ssl_cert_status_xtn, DataBuffer(val, sizeof(val))); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); @@ -197,11 +189,9 @@ TEST_P(TlsConnectGenericPre13, OcspMangled) { TEST_P(TlsConnectGeneric, OcspSuccess) { EnsureTlsSetup(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); auto capture_ocsp = - std::make_shared<TlsExtensionCapture>(ssl_cert_status_xtn); - server_->SetPacketFilter(capture_ocsp); + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_cert_status_xtn); // The value should be available during the AuthCertificateCallback client_->SetAuthCertificateCallback([](TlsAgent* agent, bool checksig, @@ -225,8 +215,7 @@ TEST_P(TlsConnectGeneric, OcspSuccess) { TEST_P(TlsConnectGeneric, OcspHugeSuccess) { EnsureTlsSetup(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); uint8_t hugeOcspValue[16385]; memset(hugeOcspValue, 0xa1, sizeof(hugeOcspValue)); @@ -254,4 +243,4 @@ TEST_P(TlsConnectGeneric, OcspHugeSuccess) { Connect(); } -} // namespace nspr_test +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc index 85c30b2bf..fa2238be7 100644 --- a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc @@ -31,11 +31,11 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { public: TlsCipherSuiteTestBase(SSLProtocolVariant variant, uint16_t version, uint16_t cipher_suite, SSLNamedGroup group, - SSLSignatureScheme signature_scheme) + SSLSignatureScheme sig_scheme) : TlsConnectTestBase(variant, version), cipher_suite_(cipher_suite), group_(group), - signature_scheme_(signature_scheme), + sig_scheme_(sig_scheme), csinfo_({0}) { SECStatus rv = SSL_GetCipherSuiteInfo(cipher_suite_, &csinfo_, sizeof(csinfo_)); @@ -60,26 +60,26 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { server_->ConfigNamedGroups(groups); kea_type_ = SSLInt_GetKEAType(group_); - client_->SetSignatureSchemes(&signature_scheme_, 1); - server_->SetSignatureSchemes(&signature_scheme_, 1); + client_->SetSignatureSchemes(&sig_scheme_, 1); + server_->SetSignatureSchemes(&sig_scheme_, 1); } } virtual void SetupCertificate() { if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { - switch (signature_scheme_) { + switch (sig_scheme_) { case ssl_sig_rsa_pkcs1_sha256: case ssl_sig_rsa_pkcs1_sha384: case ssl_sig_rsa_pkcs1_sha512: Reset(TlsAgent::kServerRsaSign); auth_type_ = ssl_auth_rsa_sign; break; - case ssl_sig_rsa_pss_sha256: - case ssl_sig_rsa_pss_sha384: + case ssl_sig_rsa_pss_rsae_sha256: + case ssl_sig_rsa_pss_rsae_sha384: Reset(TlsAgent::kServerRsaSign); auth_type_ = ssl_auth_rsa_sign; break; - case ssl_sig_rsa_pss_sha512: + case ssl_sig_rsa_pss_rsae_sha512: // You can't fit SHA-512 PSS in a 1024-bit key. Reset(TlsAgent::kRsa2048); auth_type_ = ssl_auth_rsa_sign; @@ -93,8 +93,7 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { auth_type_ = ssl_auth_ecdsa; break; default: - ASSERT_TRUE(false) << "Unsupported signature scheme: " - << signature_scheme_; + ADD_FAILURE() << "Unsupported signature scheme: " << sig_scheme_; break; } } else { @@ -187,7 +186,7 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { SSLAuthType auth_type_; SSLKEAType kea_type_; SSLNamedGroup group_; - SSLSignatureScheme signature_scheme_; + SSLSignatureScheme sig_scheme_; SSLCipherSuiteInfo csinfo_; }; @@ -236,27 +235,29 @@ TEST_P(TlsCipherSuiteTest, ResumeCipherSuite) { ConnectAndCheckCipherSuite(); } -// This only works for stream ciphers because we modify the sequence number - -// which is included explicitly in the DTLS record header - and that trips a -// different error code. Note that the message that the client sends would not -// decrypt (the nonce/IV wouldn't match), but the record limit is hit before -// attempting to decrypt a record. TEST_P(TlsCipherSuiteTest, ReadLimit) { SetupCertificate(); EnableSingleCipher(); ConnectAndCheckCipherSuite(); - EXPECT_EQ(SECSuccess, - SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last_safe_write())); - EXPECT_EQ(SECSuccess, - SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last_safe_write())); + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + uint64_t last = last_safe_write(); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last)); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last)); - client_->SendData(10, 10); - server_->ReadBytes(); // This should be OK. + client_->SendData(10, 10); + server_->ReadBytes(); // This should be OK. + } else { + // In TLS 1.3, reading or writing triggers a KeyUpdate. That would mean + // that the sequence numbers would reset and we wouldn't hit the limit. So + // we move the sequence number to one less than the limit directly and don't + // test sending and receiving just before the limit. + uint64_t last = record_limit() - 1; + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last)); + } - // The payload needs to be big enough to pass for encrypted. In the extreme - // case (TLS 1.3), this means 1 for payload, 1 for content type and 16 for - // authentication tag. - static const uint8_t payload[18] = {6}; + // The payload needs to be big enough to pass for encrypted. The code checks + // the limit before it tries to decrypt. + static const uint8_t payload[32] = {6}; DataBuffer record; uint64_t epoch; if (variant_ == ssl_variant_datagram) { @@ -271,13 +272,17 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) { TlsAgentTestBase::MakeRecord(variant_, kTlsApplicationDataType, version_, payload, sizeof(payload), &record, (epoch << 48) | record_limit()); - server_->adapter()->PacketReceived(record); + client_->SendDirect(record); server_->ExpectReadWriteError(); server_->ReadBytes(); EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, server_->error_code()); } TEST_P(TlsCipherSuiteTest, WriteLimit) { + // This asserts in TLS 1.3 because we expect an automatic update. + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + return; + } SetupCertificate(); EnableSingleCipher(); ConnectAndCheckCipherSuite(); @@ -308,8 +313,8 @@ static const auto kDummySignatureSchemesParams = static SSLSignatureScheme kSignatureSchemesParamsArr[] = { ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384, ssl_sig_rsa_pkcs1_sha512, ssl_sig_ecdsa_secp256r1_sha256, - ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_rsa_pss_sha256, - ssl_sig_rsa_pss_sha384, ssl_sig_rsa_pss_sha512, + ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_rsa_pss_rsae_sha256, + ssl_sig_rsa_pss_rsae_sha384, ssl_sig_rsa_pss_rsae_sha512, }; #endif @@ -461,4 +466,4 @@ static const SecStatusParams kSecStatusTestValuesArr[] = { INSTANTIATE_TEST_CASE_P(TestSecurityStatus, SecurityStatusTest, ::testing::ValuesIn(kSecStatusTestValuesArr)); -} // namespace nspr_test +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc new file mode 100644 index 000000000..c2f582a93 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc @@ -0,0 +1,498 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "ssl3prot.h" +#include "sslerr.h" +#include "sslproto.h" +#include "sslexp.h" + +#include <memory> + +#include "tls_connect.h" + +namespace nss_test { + +static void IncrementCounterArg(void *arg) { + if (arg) { + auto *called = reinterpret_cast<size_t *>(arg); + ++*called; + } +} + +PRBool NoopExtensionWriter(PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) { + IncrementCounterArg(arg); + return PR_FALSE; +} + +PRBool EmptyExtensionWriter(PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) { + IncrementCounterArg(arg); + return PR_TRUE; +} + +SECStatus NoopExtensionHandler(PRFileDesc *fd, SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, void *arg) { + return SECSuccess; +} + +// All of the (current) set of supported extensions, plus a few extra. +static const uint16_t kManyExtensions[] = { + ssl_server_name_xtn, + ssl_cert_status_xtn, + ssl_supported_groups_xtn, + ssl_ec_point_formats_xtn, + ssl_signature_algorithms_xtn, + ssl_signature_algorithms_cert_xtn, + ssl_use_srtp_xtn, + ssl_app_layer_protocol_xtn, + ssl_signed_cert_timestamp_xtn, + ssl_padding_xtn, + ssl_extended_master_secret_xtn, + ssl_session_ticket_xtn, + ssl_tls13_key_share_xtn, + ssl_tls13_pre_shared_key_xtn, + ssl_tls13_early_data_xtn, + ssl_tls13_supported_versions_xtn, + ssl_tls13_cookie_xtn, + ssl_tls13_psk_key_exchange_modes_xtn, + ssl_tls13_ticket_early_data_info_xtn, + ssl_tls13_certificate_authorities_xtn, + ssl_next_proto_nego_xtn, + ssl_renegotiation_info_xtn, + ssl_tls13_short_header_xtn, + 1, + 0xffff}; +// The list here includes all extensions we expect to use (SSL_MAX_EXTENSIONS), +// plus the deprecated values (see sslt.h), and two extra dummy values. +PR_STATIC_ASSERT((SSL_MAX_EXTENSIONS + 5) == PR_ARRAY_SIZE(kManyExtensions)); + +void InstallManyWriters(std::shared_ptr<TlsAgent> agent, + SSLExtensionWriter writer, size_t *installed = nullptr, + size_t *called = nullptr) { + for (size_t i = 0; i < PR_ARRAY_SIZE(kManyExtensions); ++i) { + SSLExtensionSupport support = ssl_ext_none; + SECStatus rv = SSL_GetExtensionSupport(kManyExtensions[i], &support); + ASSERT_EQ(SECSuccess, rv) << "SSL_GetExtensionSupport cannot fail"; + + rv = SSL_InstallExtensionHooks(agent->ssl_fd(), kManyExtensions[i], writer, + called, NoopExtensionHandler, nullptr); + if (support == ssl_ext_native_only) { + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + } else { + if (installed) { + ++*installed; + } + EXPECT_EQ(SECSuccess, rv); + } + } +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopClient) { + EnsureTlsSetup(); + size_t installed = 0; + size_t called = 0; + InstallManyWriters(client_, NoopExtensionWriter, &installed, &called); + EXPECT_LT(0U, installed); + Connect(); + EXPECT_EQ(installed, called); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopServer) { + EnsureTlsSetup(); + size_t installed = 0; + size_t called = 0; + InstallManyWriters(server_, NoopExtensionWriter, &installed, &called); + EXPECT_LT(0U, installed); + Connect(); + // Extension writers are all called for each of ServerHello, + // EncryptedExtensions, and Certificate. + EXPECT_EQ(installed * 3, called); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterClient) { + EnsureTlsSetup(); + InstallManyWriters(client_, EmptyExtensionWriter); + InstallManyWriters(server_, EmptyExtensionWriter); + Connect(); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterServer) { + EnsureTlsSetup(); + InstallManyWriters(server_, EmptyExtensionWriter); + // Sending extensions that the client doesn't expect leads to extensions + // appearing even if the client didn't send one, or in the wrong messages. + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); +} + +// Install an writer to disable sending of a natively-supported extension. +TEST_F(TlsConnectStreamTls13, CustomExtensionWriterDisable) { + EnsureTlsSetup(); + + // This option enables sending the extension via the native support. + SECStatus rv = SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); + EXPECT_EQ(SECSuccess, rv); + + // This installs an override that doesn't do anything. You have to specify + // something; passing all nullptr values removes an existing handler. + rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NoopExtensionWriter, + nullptr, NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_signed_cert_timestamp_xtn); + + Connect(); + // So nothing will be sent. + EXPECT_FALSE(capture->captured()); +} + +// An extension that is unlikely to be parsed as valid. +static uint8_t kNonsenseExtension[] = {91, 82, 73, 64, 55, 46, 37, 28, 19}; + +static PRBool NonsenseExtensionWriter(PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) { + TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg); + EXPECT_NE(nullptr, agent); + EXPECT_NE(nullptr, data); + EXPECT_NE(nullptr, len); + EXPECT_EQ(0U, *len); + EXPECT_LT(0U, maxLen); + EXPECT_EQ(agent->ssl_fd(), fd); + + if (message != ssl_hs_client_hello && message != ssl_hs_server_hello && + message != ssl_hs_encrypted_extensions) { + return PR_FALSE; + } + + *len = static_cast<unsigned int>(sizeof(kNonsenseExtension)); + EXPECT_GE(maxLen, *len); + if (maxLen < *len) { + return PR_FALSE; + } + PORT_Memcpy(data, kNonsenseExtension, *len); + return PR_TRUE; +} + +// Override the extension handler for an natively-supported and produce +// nonsense, which results in a handshake failure. +TEST_F(TlsConnectStreamTls13, CustomExtensionOverride) { + EnsureTlsSetup(); + + // This option enables sending the extension via the native support. + SECStatus rv = SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); + EXPECT_EQ(SECSuccess, rv); + + // This installs an override that sends nonsense. + rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NonsenseExtensionWriter, + client_.get(), NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture it to see what we got. + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_signed_cert_timestamp_xtn); + + ConnectExpectAlert(server_, kTlsAlertDecodeError); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +static SECStatus NonsenseExtensionHandler(PRFileDesc *fd, + SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, + void *arg) { + TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg); + EXPECT_EQ(agent->ssl_fd(), fd); + if (agent->role() == TlsAgent::SERVER) { + EXPECT_EQ(ssl_hs_client_hello, message); + } else { + EXPECT_TRUE(message == ssl_hs_server_hello || + message == ssl_hs_encrypted_extensions); + } + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + DataBuffer(data, len)); + EXPECT_NE(nullptr, alert); + return SECSuccess; +} + +// Send nonsense in an extension from client to server. +TEST_F(TlsConnectStreamTls13, CustomExtensionClientToServer) { + EnsureTlsSetup(); + + // This installs an override that sends nonsense. + const uint16_t extension_code = 0xffe5; + SECStatus rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), extension_code, NonsenseExtensionWriter, client_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture it to see what we got. + auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, extension_code); + + // Handle it so that the handshake completes. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NoopExtensionWriter, nullptr, + NonsenseExtensionHandler, server_.get()); + EXPECT_EQ(SECSuccess, rv); + + Connect(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +static PRBool NonsenseExtensionWriterSH(PRFileDesc *fd, + SSLHandshakeType message, PRUint8 *data, + unsigned int *len, unsigned int maxLen, + void *arg) { + if (message == ssl_hs_server_hello) { + return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg); + } + return PR_FALSE; +} + +// Send nonsense in an extension from server to client, in ServerHello. +TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientSH) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff5e; + SECStatus rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr, + NonsenseExtensionHandler, client_.get()); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NonsenseExtensionWriterSH, server_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture the extension from the ServerHello only and check it. + auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code); + capture->SetHandshakeTypes({kTlsHandshakeServerHello}); + + Connect(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +static PRBool NonsenseExtensionWriterEE(PRFileDesc *fd, + SSLHandshakeType message, PRUint8 *data, + unsigned int *len, unsigned int maxLen, + void *arg) { + if (message == ssl_hs_encrypted_extensions) { + return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg); + } + return PR_FALSE; +} + +// Send nonsense in an extension from server to client, in EncryptedExtensions. +TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientEE) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff5e; + SECStatus rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr, + NonsenseExtensionHandler, client_.get()); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NonsenseExtensionWriterEE, server_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture the extension from the EncryptedExtensions only and check it. + auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code); + capture->SetHandshakeTypes({kTlsHandshakeEncryptedExtensions}); + capture->EnableDecryption(); + + Connect(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionUnsolicitedServer) { + EnsureTlsSetup(); + + const uint16_t extension_code = 0xff5e; + SECStatus rv = SSL_InstallExtensionHooks( + server_->ssl_fd(), extension_code, NonsenseExtensionWriter, server_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture it to see what we got. + auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code); + + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +SECStatus RejectExtensionHandler(PRFileDesc *fd, SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, void *arg) { + return SECFailure; +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionServerReject) { + EnsureTlsSetup(); + + // This installs an override that sends nonsense. + const uint16_t extension_code = 0xffe7; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Reject the extension for no good reason. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NoopExtensionWriter, nullptr, + RejectExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); +} + +// Send nonsense in an extension from client to server. +TEST_F(TlsConnectStreamTls13, CustomExtensionClientReject) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff58; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + RejectExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + client_->ExpectSendAlert(kTlsAlertHandshakeFailure); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); +} + +static const uint8_t kCustomAlert = 0xf6; + +SECStatus AlertExtensionHandler(PRFileDesc *fd, SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, void *arg) { + *alert = kCustomAlert; + return SECFailure; +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionServerRejectAlert) { + EnsureTlsSetup(); + + // This installs an override that sends nonsense. + const uint16_t extension_code = 0xffea; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Reject the extension for no good reason. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NoopExtensionWriter, nullptr, + AlertExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + ConnectExpectAlert(server_, kCustomAlert); +} + +// Send nonsense in an extension from client to server. +TEST_F(TlsConnectStreamTls13, CustomExtensionClientRejectAlert) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff5a; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + AlertExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + client_->ExpectSendAlert(kCustomAlert); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); +} + +// Configure a custom extension hook badly. +TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyWriter) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + SECStatus rv = + SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6c, EmptyExtensionWriter, + nullptr, nullptr, nullptr); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyHandler) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + SECStatus rv = + SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6d, nullptr, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionOverrunBuffer) { + EnsureTlsSetup(); + // This doesn't actually overrun the buffer, but it says that it does. + auto overrun_writer = [](PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) -> PRBool { + *len = maxLen + 1; + return PR_TRUE; + }; + SECStatus rv = + SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff71, overrun_writer, + nullptr, NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + client_->StartConnect(); + client_->Handshake(); + client_->CheckErrorCode(SEC_ERROR_APPLICATION_CALLBACK_ERROR); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc index 69fd00331..b8836d7fc 100644 --- a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc @@ -29,8 +29,7 @@ TEST_F(TlsConnectTest, DamageSecretHandleClientFinished) { SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, SSL_LIBRARY_VERSION_TLS_1_3); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->Handshake(); server_->Handshake(); std::cerr << "Damaging HS secret" << std::endl; @@ -51,23 +50,19 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) { SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, SSL_LIBRARY_VERSION_TLS_1_3); - client_->ExpectSendAlert(kTlsAlertDecryptError); - // The server can't read the client's alert, so it also sends an alert. - server_->ExpectSendAlert(kTlsAlertBadRecordMac); - server_->SetPacketFilter(std::make_shared<AfterRecordN>( + MakeTlsFilter<AfterRecordN>( server_, client_, 0, // ServerHello. - [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); })); - ConnectExpectFail(); + [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); }); + ConnectExpectAlert(client_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); - server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); } TEST_P(TlsConnectGenericPre13, DamageServerSignature) { EnsureTlsSetup(); - auto filter = - std::make_shared<TlsLastByteDamager>(kTlsHandshakeServerKeyExchange); - server_->SetTlsRecordFilter(filter); + auto filter = MakeTlsFilter<TlsLastByteDamager>( + server_, kTlsHandshakeServerKeyExchange); + filter->EnableDecryption(); ExpectAlert(client_, kTlsAlertDecryptError); ConnectExpectFail(); client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); @@ -76,19 +71,10 @@ TEST_P(TlsConnectGenericPre13, DamageServerSignature) { TEST_P(TlsConnectTls13, DamageServerSignature) { EnsureTlsSetup(); - auto filter = - std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify); - server_->SetTlsRecordFilter(filter); + auto filter = MakeTlsFilter<TlsLastByteDamager>( + server_, kTlsHandshakeCertificateVerify); filter->EnableDecryption(); - client_->ExpectSendAlert(kTlsAlertDecryptError); - // The server can't read the client's alert, so it also sends an alert. - if (variant_ == ssl_variant_stream) { - server_->ExpectSendAlert(kTlsAlertBadRecordMac); - ConnectExpectFail(); - server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); - } else { - ConnectExpectFailOneSide(TlsAgent::CLIENT); - } + ConnectExpectAlert(client_, kTlsAlertDecryptError); client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); } @@ -96,15 +82,13 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) { EnsureTlsSetup(); client_->SetupClientAuth(); server_->RequestClientAuth(true); - auto filter = - std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify); - client_->SetTlsRecordFilter(filter); - server_->ExpectSendAlert(kTlsAlertDecryptError); + auto filter = MakeTlsFilter<TlsLastByteDamager>( + client_, kTlsHandshakeCertificateVerify); filter->EnableDecryption(); + server_->ExpectSendAlert(kTlsAlertDecryptError); // Do these handshakes by hand to avoid race condition on // the client processing the server's alert. - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); server_->Handshake(); client_->Handshake(); @@ -116,4 +100,4 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) { server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); } -} // namespace nspr_test +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc index 97943303a..cdafa7a84 100644 --- a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc @@ -24,7 +24,7 @@ TEST_P(TlsConnectGeneric, ConnectDhe) { EnableOnlyDheCiphers(); Connect(); CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { @@ -32,12 +32,12 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { client_->ConfigNamedGroups(kAllDHEGroups); auto groups_capture = - std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn); + std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); auto shares_capture = - std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture, shares_capture}; - client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures)); + client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures)); Connect(); @@ -59,15 +59,14 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { TEST_P(TlsConnectGeneric, ConnectFfdheClient) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); auto groups_capture = - std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn); + std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); auto shares_capture = - std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture, shares_capture}; - client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures)); + client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures)); Connect(); @@ -90,8 +89,7 @@ TEST_P(TlsConnectGeneric, ConnectFfdheClient) { // because the client automatically sends the supported groups extension. TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + server_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { Connect(); @@ -105,14 +103,11 @@ TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) { class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter { public: - TlsDheServerKeyExchangeDamager() {} + TlsDheServerKeyExchangeDamager(const std::shared_ptr<TlsAgent>& agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {} virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - // Damage the first octet of dh_p. Anything other than the known prime will // be rejected as "weak" when we have SSL_REQUIRE_DH_NAMED_GROUPS enabled. *output = input; @@ -126,9 +121,8 @@ class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter { // the signature until everything else has been checked. TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); - server_->SetPacketFilter(std::make_shared<TlsDheServerKeyExchangeDamager>()); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); + MakeTlsFilter<TlsDheServerKeyExchangeDamager>(server_); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); @@ -147,7 +141,9 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter { kYZeroPad }; - TlsDheSkeChangeY(ChangeYTo change) : change_Y_(change) {} + TlsDheSkeChangeY(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type, ChangeYTo change) + : TlsHandshakeFilter(agent, {handshake_type}), change_Y_(change) {} protected: void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset, @@ -212,8 +208,11 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter { class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { public: - TlsDheSkeChangeYServer(ChangeYTo change, bool modify) - : TlsDheSkeChangeY(change), modify_(modify), p_() {} + TlsDheSkeChangeYServer(const std::shared_ptr<TlsAgent>& agent, + ChangeYTo change, bool modify) + : TlsDheSkeChangeY(agent, kTlsHandshakeServerKeyExchange, change), + modify_(modify), + p_() {} const DataBuffer& prime() const { return p_; } @@ -221,10 +220,6 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - size_t offset = 2; // Read dh_p uint32_t dh_len = 0; @@ -252,18 +247,15 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { class TlsDheSkeChangeYClient : public TlsDheSkeChangeY { public: TlsDheSkeChangeYClient( - ChangeYTo change, + const std::shared_ptr<TlsAgent>& agent, ChangeYTo change, std::shared_ptr<const TlsDheSkeChangeYServer> server_filter) - : TlsDheSkeChangeY(change), server_filter_(server_filter) {} + : TlsDheSkeChangeY(agent, kTlsHandshakeClientKeyExchange, change), + server_filter_(server_filter) {} protected: virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeClientKeyExchange) { - return KEEP; - } - ChangeY(input, output, 0, server_filter_->prime()); return CHANGE; } @@ -289,12 +281,10 @@ class TlsDamageDHYTest TEST_P(TlsDamageDHYTest, DamageServerY) { EnableOnlyDheCiphers(); if (std::get<3>(GetParam())) { - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); } TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); - server_->SetPacketFilter( - std::make_shared<TlsDheSkeChangeYServer>(change, true)); + MakeTlsFilter<TlsDheSkeChangeYServer>(server_, change, true); if (change == TlsDheSkeChangeY::kYZeroPad) { ExpectAlert(client_, kTlsAlertDecryptError); @@ -320,18 +310,15 @@ TEST_P(TlsDamageDHYTest, DamageServerY) { TEST_P(TlsDamageDHYTest, DamageClientY) { EnableOnlyDheCiphers(); if (std::get<3>(GetParam())) { - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); } // The filter on the server is required to capture the prime. - auto server_filter = - std::make_shared<TlsDheSkeChangeYServer>(TlsDheSkeChangeY::kYZero, false); - server_->SetPacketFilter(server_filter); + auto server_filter = MakeTlsFilter<TlsDheSkeChangeYServer>( + server_, TlsDheSkeChangeY::kYZero, false); // The client filter does the damage. TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); - client_->SetPacketFilter( - std::make_shared<TlsDheSkeChangeYClient>(change, server_filter)); + MakeTlsFilter<TlsDheSkeChangeYClient>(client_, change, server_filter); if (change == TlsDheSkeChangeY::kYZeroPad) { ExpectAlert(server_, kTlsAlertDecryptError); @@ -370,13 +357,12 @@ INSTANTIATE_TEST_CASE_P( class TlsDheSkeMakePEven : public TlsHandshakeFilter { public: + TlsDheSkeMakePEven(const std::shared_ptr<TlsAgent>& agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {} + virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - // Find the end of dh_p uint32_t dh_len = 0; EXPECT_TRUE(input.Read(0, 2, &dh_len)); @@ -394,7 +380,7 @@ class TlsDheSkeMakePEven : public TlsHandshakeFilter { // Even without requiring named groups, an even value for p is bad news. TEST_P(TlsConnectGenericPre13, MakeDhePEven) { EnableOnlyDheCiphers(); - server_->SetPacketFilter(std::make_shared<TlsDheSkeMakePEven>()); + MakeTlsFilter<TlsDheSkeMakePEven>(server_); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); @@ -404,13 +390,12 @@ TEST_P(TlsConnectGenericPre13, MakeDhePEven) { class TlsDheSkeZeroPadP : public TlsHandshakeFilter { public: + TlsDheSkeZeroPadP(const std::shared_ptr<TlsAgent>& agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {} + virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - *output = input; uint32_t dh_len = 0; EXPECT_TRUE(input.Read(0, 2, &dh_len)); @@ -425,7 +410,7 @@ class TlsDheSkeZeroPadP : public TlsHandshakeFilter { // Zero padding only causes signature failure. TEST_P(TlsConnectGenericPre13, PadDheP) { EnableOnlyDheCiphers(); - server_->SetPacketFilter(std::make_shared<TlsDheSkeZeroPadP>()); + MakeTlsFilter<TlsDheSkeZeroPadP>(server_); ConnectExpectAlert(client_, kTlsAlertDecryptError); @@ -445,8 +430,7 @@ TEST_P(TlsConnectGenericPre13, PadDheP) { // Note: This test case can take ages to generate the weak DH key. TEST_P(TlsConnectGenericPre13, WeakDHGroup) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); EXPECT_EQ(SECSuccess, SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE)); @@ -474,7 +458,7 @@ TEST_P(TlsConnectGenericPre13, NamedGroupMismatchPre13) { Connect(); CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_custom, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } // Same test but for TLS 1.3. This has to fail. @@ -496,8 +480,7 @@ TEST_P(TlsConnectTls13, NamedGroupMismatch13) { // custom group in contrast to the previous test. TEST_P(TlsConnectGenericPre13, RequireNamedGroupsMismatchPre13) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072}; static const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048}; @@ -519,14 +502,13 @@ TEST_P(TlsConnectGenericPre13, PreferredFfdhe) { Connect(); client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); - client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); - server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); + client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); + server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); } TEST_P(TlsConnectGenericPre13, MismatchDHE) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); static const SSLDHEGroupType serverGroups[] = {ssl_ff_dhe_3072_group}; EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), serverGroups, PR_ARRAY_SIZE(serverGroups))); @@ -544,37 +526,37 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); Connect(); SendReceive(); // Need to read so that we absorb the session ticket. - CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); EnableOnlyDheCiphers(); auto clientCapture = - std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn); - client_->SetPacketFilter(clientCapture); + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); auto serverCapture = - std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn); - server_->SetPacketFilter(serverCapture); + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_pre_shared_key_xtn); ExpectResumption(RESUME_TICKET); Connect(); - CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); ASSERT_LT(0UL, clientCapture->extension().len()); ASSERT_LT(0UL, serverCapture->extension().len()); } class TlsDheSkeChangeSignature : public TlsHandshakeFilter { public: - TlsDheSkeChangeSignature(uint16_t version, const uint8_t* data, size_t len) - : version_(version), data_(data), len_(len) {} + TlsDheSkeChangeSignature(const std::shared_ptr<TlsAgent>& agent, + uint16_t version, const uint8_t* data, size_t len) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}), + version_(version), + data_(data), + len_(len) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - TlsParser parser(input); EXPECT_TRUE(parser.SkipVariable(2)); // dh_p EXPECT_TRUE(parser.SkipVariable(2)); // dh_g @@ -615,8 +597,8 @@ TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) { const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ffdhe_2048}; client_->ConfigNamedGroups(client_groups); - server_->SetPacketFilter(std::make_shared<TlsDheSkeChangeSignature>( - version_, kBogusDheSignature, sizeof(kBogusDheSignature))); + MakeTlsFilter<TlsDheSkeChangeSignature>(server_, version_, kBogusDheSignature, + sizeof(kBogusDheSignature)); ConnectExpectAlert(client_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); diff --git a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc index 3cc3b0e62..ee8906deb 100644 --- a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc @@ -6,6 +6,7 @@ #include "secerr.h" #include "ssl.h" +#include "sslexp.h" extern "C" { // This is not something that should make you happy. @@ -20,14 +21,14 @@ extern "C" { namespace nss_test { -TEST_P(TlsConnectDatagram, DropClientFirstFlightOnce) { - client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1)); +TEST_P(TlsConnectDatagramPre13, DropClientFirstFlightOnce) { + client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1)); Connect(); SendReceive(); } -TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) { - server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1)); +TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) { + server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1)); Connect(); SendReceive(); } @@ -35,36 +36,770 @@ TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) { // This drops the first transmission from both the client and server of all // flights that they send. Note: In DTLS 1.3, the shorter handshake means that // this will also drop some application data, so we can't call SendReceive(). -TEST_P(TlsConnectDatagram, DropAllFirstTransmissions) { - client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x15)); - server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x5)); +TEST_P(TlsConnectDatagramPre13, DropAllFirstTransmissions) { + client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x15)); + server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x5)); Connect(); } // This drops the server's first flight three times. -TEST_P(TlsConnectDatagram, DropServerFirstFlightThrice) { - server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x7)); +TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightThrice) { + server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x7)); Connect(); } // This drops the client's second flight once -TEST_P(TlsConnectDatagram, DropClientSecondFlightOnce) { - client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2)); +TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightOnce) { + client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2)); Connect(); } // This drops the client's second flight three times. -TEST_P(TlsConnectDatagram, DropClientSecondFlightThrice) { - client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe)); +TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightThrice) { + client_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe)); Connect(); } // This drops the server's second flight three times. -TEST_P(TlsConnectDatagram, DropServerSecondFlightThrice) { - server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe)); +TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) { + server_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe)); Connect(); } +class TlsDropDatagram13 : public TlsConnectDatagram13 { + public: + TlsDropDatagram13() + : client_filters_(), + server_filters_(), + expected_client_acks_(0), + expected_server_acks_(1) {} + + void SetUp() override { + TlsConnectDatagram13::SetUp(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + SetFilters(); + } + + void SetFilters() { + EnsureTlsSetup(); + client_filters_.Init(client_); + server_filters_.Init(server_); + } + + void HandshakeAndAck(const std::shared_ptr<TlsAgent>& agent) { + agent->Handshake(); // Read flight. + ShiftDtlsTimers(); + agent->Handshake(); // Generate ACK. + } + + void ShrinkPostServerHelloMtu() { + // Abuse the custom extension mechanism to modify the MTU so that the + // Certificate message is split into two pieces. + ASSERT_EQ( + SECSuccess, + SSL_InstallExtensionHooks( + server_->ssl_fd(), 1, + [](PRFileDesc* fd, SSLHandshakeType message, PRUint8* data, + unsigned int* len, unsigned int maxLen, void* arg) -> PRBool { + SSLInt_SetMTU(fd, 500); // Splits the certificate. + return PR_FALSE; + }, + nullptr, + [](PRFileDesc* fd, SSLHandshakeType message, const PRUint8* data, + unsigned int len, SSLAlertDescription* alert, + void* arg) -> SECStatus { return SECSuccess; }, + nullptr)); + } + + protected: + class DropAckChain { + public: + DropAckChain() + : records_(nullptr), ack_(nullptr), drop_(nullptr), chain_(nullptr) {} + + void Init(const std::shared_ptr<TlsAgent>& agent) { + records_ = std::make_shared<TlsRecordRecorder>(agent); + ack_ = std::make_shared<TlsRecordRecorder>(agent, content_ack); + ack_->EnableDecryption(); + drop_ = std::make_shared<SelectiveRecordDropFilter>(agent, 0, false); + chain_ = std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({records_, ack_, drop_})); + agent->SetFilter(chain_); + } + + const TlsRecord& record(size_t i) const { return records_->record(i); } + + std::shared_ptr<TlsRecordRecorder> records_; + std::shared_ptr<TlsRecordRecorder> ack_; + std::shared_ptr<SelectiveRecordDropFilter> drop_; + std::shared_ptr<PacketFilter> chain_; + }; + + void CheckAcks(const DropAckChain& chain, size_t index, + std::vector<uint64_t> acks) { + const DataBuffer& buf = chain.ack_->record(index).buffer; + size_t offset = 0; + + EXPECT_EQ(acks.size() * 8, buf.len()); + if ((acks.size() * 8) != buf.len()) { + while (offset < buf.len()) { + uint64_t ack; + ASSERT_TRUE(buf.Read(offset, 8, &ack)); + offset += 8; + std::cerr << "Ack=0x" << std::hex << ack << std::dec << std::endl; + } + return; + } + + for (size_t i = 0; i < acks.size(); ++i) { + uint64_t a = acks[i]; + uint64_t ack; + ASSERT_TRUE(buf.Read(offset, 8, &ack)); + offset += 8; + if (a != ack) { + ADD_FAILURE() << "Wrong ack " << i << " expected=0x" << std::hex << a + << " got=0x" << ack << std::dec; + } + } + } + + void CheckedHandshakeSendReceive() { + Handshake(); + CheckPostHandshake(); + } + + void CheckPostHandshake() { + CheckConnected(); + SendReceive(); + EXPECT_EQ(expected_client_acks_, client_filters_.ack_->count()); + EXPECT_EQ(expected_server_acks_, server_filters_.ack_->count()); + } + + protected: + DropAckChain client_filters_; + DropAckChain server_filters_; + size_t expected_client_acks_; + size_t expected_server_acks_; +}; + +// All of these tests produce a minimum one ACK, from the server +// to the client upon receiving the client Finished. +// Dropping complete first and second flights does not produce +// ACKs +TEST_F(TlsDropDatagram13, DropClientFirstFlightOnce) { + client_filters_.drop_->Reset({0}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + CheckedHandshakeSendReceive(); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +TEST_F(TlsDropDatagram13, DropServerFirstFlightOnce) { + server_filters_.drop_->Reset(0xff); + StartConnect(); + client_->Handshake(); + // Send the first flight, all dropped. + server_->Handshake(); + server_filters_.drop_->Disable(); + CheckedHandshakeSendReceive(); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// Dropping the server's first record also does not produce +// an ACK because the next record is ignored. +// TODO(ekr@rtfm.com): We should generate an empty ACK. +TEST_F(TlsDropDatagram13, DropServerFirstRecordOnce) { + server_filters_.drop_->Reset({0}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + Handshake(); + CheckedHandshakeSendReceive(); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// Dropping the second packet of the server's flight should +// produce an ACK. +TEST_F(TlsDropDatagram13, DropServerSecondRecordOnce) { + server_filters_.drop_->Reset({1}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + HandshakeAndAck(client_); + expected_client_acks_ = 1; + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_, 0, {0}); // ServerHello + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// Drop the server ACK and verify that the client retransmits +// the ClientHello. +TEST_F(TlsDropDatagram13, DropServerAckOnce) { + StartConnect(); + client_->Handshake(); + server_->Handshake(); + // At this point the server has sent it's first flight, + // so make it drop the ACK. + server_filters_.drop_->Reset({0}); + client_->Handshake(); // Send the client Finished. + server_->Handshake(); // Receive the Finished and send the ACK. + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + // Wait for the DTLS timeout to make sure we retransmit the + // Finished. + ShiftDtlsTimers(); + client_->Handshake(); // Retransmit the Finished. + server_->Handshake(); // Read the Finished and send an ACK. + uint8_t buf[1]; + PRInt32 rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf)); + expected_server_acks_ = 2; + EXPECT_GT(0, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + CheckPostHandshake(); + // There should be two copies of the finished ACK + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_, 1, {0x0002000000000000ULL}); +} + +// Drop the client certificate verify. +TEST_F(TlsDropDatagram13, DropClientCertVerify) { + StartConnect(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + client_->Handshake(); + server_->Handshake(); + // Have the client drop Cert Verify + client_filters_.drop_->Reset({1}); + expected_server_acks_ = 2; + CheckedHandshakeSendReceive(); + // Ack of the Cert. + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + // Ack of the whole client handshake. + CheckAcks( + server_filters_, 1, + {0x0002000000000000ULL, // CH (we drop everything after this on client) + 0x0002000000000003ULL, // CT (2) + 0x0002000000000004ULL}); // FIN (2) +} + +// Shrink the MTU down so that certs get split and drop the first piece. +TEST_F(TlsDropDatagram13, DropFirstHalfOfServerCertificate) { + server_filters_.drop_->Reset({2}); + StartConnect(); + ShrinkPostServerHelloMtu(); + client_->Handshake(); + server_->Handshake(); + // Check that things got split. + EXPECT_EQ(6UL, + server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN + size_t ct1_size = server_filters_.record(2).buffer.len(); + server_filters_.records_->Clear(); + expected_client_acks_ = 1; + HandshakeAndAck(client_); + server_->Handshake(); // Retransmit + EXPECT_EQ(3UL, server_filters_.records_->count()); // CT2, CV, FIN + // Check that the first record is CT1 (which is identical to the same + // as the previous CT1). + EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len()); + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_, 0, + {0, // SH + 0x0002000000000000ULL, // EE + 0x0002000000000002ULL}); // CT2 + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// Shrink the MTU down so that certs get split and drop the second piece. +TEST_F(TlsDropDatagram13, DropSecondHalfOfServerCertificate) { + server_filters_.drop_->Reset({3}); + StartConnect(); + ShrinkPostServerHelloMtu(); + client_->Handshake(); + server_->Handshake(); + // Check that things got split. + EXPECT_EQ(6UL, + server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN + size_t ct1_size = server_filters_.record(3).buffer.len(); + server_filters_.records_->Clear(); + expected_client_acks_ = 1; + HandshakeAndAck(client_); + server_->Handshake(); // Retransmit + EXPECT_EQ(3UL, server_filters_.records_->count()); // CT1, CV, FIN + // Check that the first record is CT1 + EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len()); + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_, 0, + { + 0, // SH + 0x0002000000000000ULL, // EE + 0x0002000000000001ULL, // CT1 + }); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// In this test, the Certificate message is sent four times, we drop all or part +// of the first three attempts: +// 1. Without fragmentation so that we can see how big it is - we drop that. +// 2. In two pieces - we drop half AND the resulting ACK. +// 3. In three pieces - we drop the middle piece. +// +// After that we let all the ACKs through and allow the handshake to complete +// without further interference. +// +// This allows us to test that ranges of handshake messages are sent correctly +// even when there are overlapping acknowledgments; that ACKs with duplicate or +// overlapping message ranges are handled properly; and that extra +// retransmissions are handled properly. +class TlsFragmentationAndRecoveryTest : public TlsDropDatagram13 { + public: + TlsFragmentationAndRecoveryTest() : cert_len_(0) {} + + protected: + void RunTest(size_t dropped_half) { + FirstFlightDropCertificate(); + + SecondAttemptDropHalf(dropped_half); + size_t dropped_half_size = server_record_len(dropped_half); + size_t second_flight_count = server_filters_.records_->count(); + + ThirdAttemptDropMiddle(); + size_t repaired_third_size = server_record_len((dropped_half == 0) ? 0 : 2); + size_t third_flight_count = server_filters_.records_->count(); + + AckAndCompleteRetransmission(); + size_t final_server_flight_count = server_filters_.records_->count(); + EXPECT_LE(3U, final_server_flight_count); // CT(sixth), CV, Fin + CheckSizeOfSixth(dropped_half_size, repaired_third_size); + + SendDelayedAck(); + // Same number of messages as the last flight. + EXPECT_EQ(final_server_flight_count, server_filters_.records_->count()); + // Double check that the Certificate size is still correct. + CheckSizeOfSixth(dropped_half_size, repaired_third_size); + + CompleteHandshake(final_server_flight_count); + + // This is the ACK for the first attempt to send a whole certificate. + std::vector<uint64_t> client_acks = { + 0, // SH + 0x0002000000000000ULL // EE + }; + CheckAcks(client_filters_, 0, client_acks); + // And from the second attempt for the half was kept (we delayed this ACK). + client_acks.push_back(0x0002000000000000ULL + second_flight_count + + ~dropped_half % 2); + CheckAcks(client_filters_, 1, client_acks); + // And the third attempt where the first and last thirds got through. + client_acks.push_back(0x0002000000000000ULL + second_flight_count + + third_flight_count - 1); + client_acks.push_back(0x0002000000000000ULL + second_flight_count + + third_flight_count + 1); + CheckAcks(client_filters_, 2, client_acks); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + } + + private: + void FirstFlightDropCertificate() { + StartConnect(); + client_->Handshake(); + + // Note: 1 << N is the Nth packet, starting from zero. + server_filters_.drop_->Reset(1 << 2); // Drop Cert0. + server_->Handshake(); + EXPECT_EQ(5U, server_filters_.records_->count()); // SH, EE, CT, CV, Fin + cert_len_ = server_filters_.records_->record(2).buffer.len(); + + HandshakeAndAck(client_); + EXPECT_EQ(2U, client_filters_.records_->count()); + } + + // Lower the MTU so that the server has to split the certificate in two + // pieces. The server resends Certificate (in two), plus CV and Fin. + void SecondAttemptDropHalf(size_t dropped_half) { + ASSERT_LE(0U, dropped_half); + ASSERT_GT(2U, dropped_half); + server_filters_.records_->Clear(); + server_filters_.drop_->Reset({dropped_half}); // Drop Cert1[half] + SplitServerMtu(2); + server_->Handshake(); + EXPECT_LE(4U, server_filters_.records_->count()); // CT x2, CV, Fin + + // Generate and capture the ACK from the client. + client_filters_.drop_->Reset({0}); + HandshakeAndAck(client_); + EXPECT_EQ(3U, client_filters_.records_->count()); + } + + // Lower the MTU again so that the server sends Certificate cut into three + // pieces. Drop the middle piece. + void ThirdAttemptDropMiddle() { + server_filters_.records_->Clear(); + server_filters_.drop_->Reset({1}); // Drop Cert2[1] (of 3) + SplitServerMtu(3); + // Because we dropped the client ACK, the server retransmits on a timer. + ShiftDtlsTimers(); + server_->Handshake(); + EXPECT_LE(5U, server_filters_.records_->count()); // CT x3, CV, Fin + } + + void AckAndCompleteRetransmission() { + // Generate ACKs. + HandshakeAndAck(client_); + // The server should send the final sixth of the certificate: the client has + // acknowledged the first half and the last third. Also send CV and Fin. + server_filters_.records_->Clear(); + server_->Handshake(); + } + + void CheckSizeOfSixth(size_t size_of_half, size_t size_of_third) { + // Work out if the final sixth is the right size. We get the records with + // overheads added, which obscures the length of the payload. We want to + // ensure that the server only sent the missing sixth of the Certificate. + // + // We captured |size_of_half + overhead| and |size_of_third + overhead| and + // want to calculate |size_of_third - size_of_third + overhead|. We can't + // calculate |overhead|, but it is is (currently) always a handshake message + // header, a content type, and an authentication tag: + static const size_t record_overhead = 12 + 1 + 16; + EXPECT_EQ(size_of_half - size_of_third + record_overhead, + server_filters_.records_->record(0).buffer.len()); + } + + void SendDelayedAck() { + // Send the ACK we held back. The reordered ACK doesn't add new + // information, + // but triggers an extra retransmission of the missing records again (even + // though the client has all that it needs). + client_->SendRecordDirect(client_filters_.records_->record(2)); + server_filters_.records_->Clear(); + server_->Handshake(); + } + + void CompleteHandshake(size_t extra_retransmissions) { + // All this messing around shouldn't cause a failure... + Handshake(); + // ...but it leaves a mess. Add an extra few calls to Handshake() for the + // client so that it absorbs the extra retransmissions. + for (size_t i = 0; i < extra_retransmissions; ++i) { + client_->Handshake(); + } + CheckConnected(); + } + + // Split the server MTU so that the Certificate is split into |count| pieces. + // The calculation doesn't need to be perfect as long as the Certificate + // message is split into the right number of pieces. + void SplitServerMtu(size_t count) { + // Set the MTU based on the formula: + // bare_size = cert_len_ - actual_overhead + // MTU = ceil(bare_size / count) + pessimistic_overhead + // + // actual_overhead is the amount of actual overhead on the record we + // captured, which is (note that our length doesn't include the header): + static const size_t actual_overhead = 12 + // handshake message header + 1 + // content type + 16; // authentication tag + size_t bare_size = cert_len_ - actual_overhead; + + // pessimistic_overhead is the amount of expansion that NSS assumes will be + // added to each handshake record. Right now, that is DTLS_MIN_FRAGMENT: + static const size_t pessimistic_overhead = + 12 + // handshake message header + 1 + // content type + 13 + // record header length + 64; // maximum record expansion: IV, MAC and block cipher expansion + + size_t mtu = (bare_size + count - 1) / count + pessimistic_overhead; + if (g_ssl_gtest_verbose) { + std::cerr << "server: set MTU to " << mtu << std::endl; + } + EXPECT_EQ(SECSuccess, SSLInt_SetMTU(server_->ssl_fd(), mtu)); + } + + size_t server_record_len(size_t index) const { + return server_filters_.records_->record(index).buffer.len(); + } + + size_t cert_len_; +}; + +TEST_F(TlsFragmentationAndRecoveryTest, DropFirstHalf) { RunTest(0); } + +TEST_F(TlsFragmentationAndRecoveryTest, DropSecondHalf) { RunTest(1); } + +TEST_F(TlsDropDatagram13, NoDropsDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + EXPECT_EQ(0U, client_filters_.ack_->count()); + CheckAcks(server_filters_, 0, + {0x0001000000000001ULL, // EOED + 0x0002000000000000ULL}); // Finished +} + +TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + server_filters_.drop_->Reset({1}); + ZeroRttSendReceive(true, true); + HandshakeAndAck(client_); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + CheckAcks(client_filters_, 0, {0}); + CheckAcks(server_filters_, 0, + {0x0001000000000002ULL, // EOED + 0x0002000000000000ULL}); // Finished +} + +class TlsReorderDatagram13 : public TlsDropDatagram13 { + public: + TlsReorderDatagram13() {} + + // Send records from the records buffer in the given order. + void ReSend(TlsAgent::Role side, std::vector<size_t> indices) { + std::shared_ptr<TlsAgent> agent; + std::shared_ptr<TlsRecordRecorder> records; + + if (side == TlsAgent::CLIENT) { + agent = client_; + records = client_filters_.records_; + } else { + agent = server_; + records = server_filters_.records_; + } + + for (auto i : indices) { + agent->SendRecordDirect(records->record(i)); + } + } +}; + +// Reorder the server records so that EE comes at the end +// of the flight and will still produce an ACK. +TEST_F(TlsDropDatagram13, ReorderServerEE) { + server_filters_.drop_->Reset({1}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + // We dropped EE, now reinject. + server_->SendRecordDirect(server_filters_.record(1)); + expected_client_acks_ = 1; + HandshakeAndAck(client_); + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_, 0, + { + 0, // SH + 0x0002000000000000, // EE + }); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// The client sends an out of order non-handshake message +// but with the handshake key. +class TlsSendCipherSpecCapturer { + public: + TlsSendCipherSpecCapturer(std::shared_ptr<TlsAgent>& agent) + : send_cipher_specs_() { + SSLInt_SetCipherSpecChangeFunc(agent->ssl_fd(), CipherSpecChanged, + (void*)this); + } + + std::shared_ptr<TlsCipherSpec> spec(size_t i) { + if (i >= send_cipher_specs_.size()) { + return nullptr; + } + return send_cipher_specs_[i]; + } + + private: + static void CipherSpecChanged(void* arg, PRBool sending, + ssl3CipherSpec* newSpec) { + if (!sending) { + return; + } + + auto self = static_cast<TlsSendCipherSpecCapturer*>(arg); + + auto spec = std::make_shared<TlsCipherSpec>(); + bool ret = spec->Init(SSLInt_CipherSpecToEpoch(newSpec), + SSLInt_CipherSpecToAlgorithm(newSpec), + SSLInt_CipherSpecToKey(newSpec), + SSLInt_CipherSpecToIv(newSpec)); + EXPECT_EQ(true, ret); + self->send_cipher_specs_.push_back(spec); + } + + std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_; +}; + +TEST_F(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) { + StartConnect(); + TlsSendCipherSpecCapturer capturer(client_); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + server_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + // After the client sends Finished, inject an app data record + // with the handshake key. This should produce an alert. + uint8_t buf[] = {'a', 'b', 'c'}; + auto spec = capturer.spec(0); + ASSERT_NE(nullptr, spec.get()); + ASSERT_EQ(2, spec->epoch()); + ASSERT_TRUE(client_->SendEncryptedRecord( + spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002, + kTlsApplicationDataType, DataBuffer(buf, sizeof(buf)))); + + // Now have the server consume the bogus message. + server_->ExpectSendAlert(illegal_parameter, kTlsAlertFatal); + server_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); + EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, PORT_GetError()); +} + +TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) { + StartConnect(); + TlsSendCipherSpecCapturer capturer(client_); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + server_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + // Inject a new bogus handshake record, which the server responds + // to by just ACKing the original one (we ignore the contents). + uint8_t buf[] = {'a', 'b', 'c'}; + auto spec = capturer.spec(0); + ASSERT_NE(nullptr, spec.get()); + ASSERT_EQ(2, spec->epoch()); + ASSERT_TRUE(client_->SendEncryptedRecord( + spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002, + kTlsHandshakeType, DataBuffer(buf, sizeof(buf)))); + server_->Handshake(); + EXPECT_EQ(2UL, server_filters_.ack_->count()); + // The server acknowledges client Finished twice. + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_, 1, {0x0002000000000000ULL}); +} + +// Shrink the MTU down so that certs get split and then swap the first and +// second pieces of the server certificate. +TEST_F(TlsReorderDatagram13, ReorderServerCertificate) { + StartConnect(); + ShrinkPostServerHelloMtu(); + client_->Handshake(); + // Drop the entire handshake flight so we can reorder. + server_filters_.drop_->Reset(0xff); + server_->Handshake(); + // Check that things got split. + EXPECT_EQ(6UL, + server_filters_.records_->count()); // CH, EE, CT1, CT2, CV, FIN + // Now re-send things in a different order. + ReSend(TlsAgent::SERVER, std::vector<size_t>{0, 1, 3, 2, 4, 5}); + // Clear. + server_filters_.drop_->Disable(); + server_filters_.records_->Clear(); + // Wait for client to send ACK. + ShiftDtlsTimers(); + CheckedHandshakeSendReceive(); + EXPECT_EQ(2UL, server_filters_.records_->count()); // ACK + Data + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +TEST_F(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + // Send the client's first flight of zero RTT data. + ZeroRttSendReceive(true, true); + // Now send another client application data record but + // capture it. + client_filters_.records_->Clear(); + client_filters_.drop_->Reset(0xff); + const char* k0RttData = "123456"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + EXPECT_EQ(k0RttDataLen, rv); + EXPECT_EQ(1UL, client_filters_.records_->count()); // data + server_->Handshake(); + client_->Handshake(); + ExpectEarlyDataAccepted(true); + // The server still hasn't received anything at this point. + EXPECT_EQ(3UL, client_filters_.records_->count()); // data, EOED, FIN + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + // Now re-send the client's messages: EOED, data, FIN + ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 0, 2})); + server_->Handshake(); + CheckConnected(); + EXPECT_EQ(0U, client_filters_.ack_->count()); + // Acknowledgements for EOED and Finished. + CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL}); + uint8_t buf[8]; + rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(-1, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +TEST_F(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + // Send the client's first flight of zero RTT data. + ZeroRttSendReceive(true, true); + // Now send another client application data record but + // capture it. + client_filters_.records_->Clear(); + client_filters_.drop_->Reset(0xff); + const char* k0RttData = "123456"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + EXPECT_EQ(k0RttDataLen, rv); + EXPECT_EQ(1UL, client_filters_.records_->count()); // data + server_->Handshake(); + client_->Handshake(); + ExpectEarlyDataAccepted(true); + // The server still hasn't received anything at this point. + EXPECT_EQ(3UL, client_filters_.records_->count()); // EOED, FIN, Data + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + // Now re-send the client's messages: EOED, FIN, Data + ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 2, 0})); + server_->Handshake(); + CheckConnected(); + EXPECT_EQ(0U, client_filters_.ack_->count()); + // Acknowledgements for EOED and Finished. + CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL}); + uint8_t buf[8]; + rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(-1, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + static void GetCipherAndLimit(uint16_t version, uint16_t* cipher, uint64_t* limit = nullptr) { uint64_t l; @@ -111,7 +846,6 @@ TEST_P(TlsConnectDatagram12Plus, MissAWindow) { GetCipherAndLimit(version_, &cipher); server_->EnableSingleCipher(cipher); Connect(); - EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 0)); SendReceive(); } @@ -129,5 +863,7 @@ TEST_P(TlsConnectDatagram12Plus, MissAWindowAndOne) { INSTANTIATE_TEST_CASE_P(Datagram12Plus, TlsConnectDatagram12Plus, TlsConnectTestBase::kTlsV12Plus); +INSTANTIATE_TEST_CASE_P(DatagramPre13, TlsConnectDatagramPre13, + TlsConnectTestBase::kTlsV11V12); } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc index 1e406b6c2..3c7cd2ecf 100644 --- a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc @@ -69,20 +69,19 @@ TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) { server_->ConfigNamedGroups(groups); Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } // This causes a HelloRetryRequest in TLS 1.3. Earlier versions don't care. TEST_P(TlsConnectGeneric, ConnectEcdheP384Server) { EnsureTlsSetup(); - auto hrr_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeHelloRetryRequest); - server_->SetPacketFilter(hrr_capture); + auto hrr_capture = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeHelloRetryRequest); const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; server_->ConfigNamedGroups(groups); Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); EXPECT_EQ(version_ == SSL_LIBRARY_VERSION_TLS_1_3, hrr_capture->buffer().len() != 0); } @@ -112,7 +111,7 @@ TEST_P(TlsKeyExchangeTest, P384Priority) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1}; CheckKEXDetails(groups, shares); @@ -129,7 +128,7 @@ TEST_P(TlsKeyExchangeTest, DuplicateGroupConfig) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1}; std::vector<SSLNamedGroup> expectedGroups = {ssl_grp_ec_secp384r1, @@ -147,7 +146,7 @@ TEST_P(TlsKeyExchangeTest, P384PriorityDHEnabled) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1}; @@ -172,7 +171,7 @@ TEST_P(TlsConnectGenericPre13, P384PriorityOnServer) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) { @@ -188,12 +187,14 @@ TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter { public: - TlsKeyExchangeGroupCapture() : group_(ssl_grp_none) {} + TlsKeyExchangeGroupCapture(const std::shared_ptr<TlsAgent> &agent) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}), + group_(ssl_grp_none) {} SSLNamedGroup group() const { return group_; } @@ -201,10 +202,6 @@ class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter { virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, const DataBuffer &input, DataBuffer *output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - uint32_t value = 0; EXPECT_TRUE(input.Read(0, 1, &value)); EXPECT_EQ(3U, value) << "curve type has to be 3"; @@ -223,10 +220,8 @@ class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter { // P-256 is supported by the client (<= 1.2 only). TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) { EnsureTlsSetup(); - client_->SetPacketFilter( - std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn)); - auto group_capture = std::make_shared<TlsKeyExchangeGroupCapture>(); - server_->SetPacketFilter(group_capture); + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_supported_groups_xtn); + auto group_capture = MakeTlsFilter<TlsKeyExchangeGroupCapture>(server_); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); @@ -238,8 +233,7 @@ TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) { // Supported groups is mandatory in TLS 1.3. TEST_P(TlsConnectTls13, DropSupportedGroupExtension) { EnsureTlsSetup(); - client_->SetPacketFilter( - std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn)); + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_supported_groups_xtn); ConnectExpectAlert(server_, kTlsAlertMissingExtension); client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); server_->CheckErrorCode(SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION); @@ -278,7 +272,7 @@ TEST_P(TlsConnectStreamPre13, ConfiguredGroupsRenegotiate) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); CheckConnected(); // The renegotiation has to use the same preferences as the original session. @@ -286,7 +280,7 @@ TEST_P(TlsConnectStreamPre13, ConfiguredGroupsRenegotiate) { client_->StartRenegotiate(); Handshake(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } TEST_P(TlsKeyExchangeTest, Curve25519) { @@ -320,7 +314,7 @@ TEST_P(TlsConnectGenericPre13, GroupPreferenceServerPriority) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); } #ifndef NSS_DISABLE_TLS_1_3 @@ -339,7 +333,7 @@ TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityClient13) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp256r1}; CheckKEXDetails(client_groups, shares); } @@ -359,7 +353,7 @@ TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityServer13) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; CheckKEXDetails(client_groups, shares); } @@ -381,7 +375,7 @@ TEST_P(TlsKeyExchangeTest13, EqualPriorityTestRetryECServer13) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); } @@ -403,7 +397,7 @@ TEST_P(TlsKeyExchangeTest13, NotEqualPriorityWithIntermediateGroup13) { Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); } @@ -425,7 +419,7 @@ TEST_P(TlsKeyExchangeTest13, Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); } @@ -447,7 +441,7 @@ TEST_P(TlsKeyExchangeTest13, Connect(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); } @@ -509,7 +503,7 @@ TEST_P(TlsKeyExchangeTest13, MultipleClientShares) { // The server would accept 25519 but its preferred group (P256) has to win. CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, - ssl_sig_rsa_pss_sha256); + ssl_sig_rsa_pss_rsae_sha256); const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1}; CheckKEXDetails(client_groups, shares); @@ -518,16 +512,13 @@ TEST_P(TlsKeyExchangeTest13, MultipleClientShares) { // Replace the point in the client key exchange message with an empty one class ECCClientKEXFilter : public TlsHandshakeFilter { public: - ECCClientKEXFilter() {} + ECCClientKEXFilter(const std::shared_ptr<TlsAgent> &client) + : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, const DataBuffer &input, DataBuffer *output) { - if (header.handshake_type() != kTlsHandshakeClientKeyExchange) { - return KEEP; - } - // Replace the client key exchange message with an empty point output->Allocate(1); output->Write(0, 0U, 1); // set point length 0 @@ -538,20 +529,17 @@ class ECCClientKEXFilter : public TlsHandshakeFilter { // Replace the point in the server key exchange message with an empty one class ECCServerKEXFilter : public TlsHandshakeFilter { public: - ECCServerKEXFilter() {} + ECCServerKEXFilter(const std::shared_ptr<TlsAgent> &server) + : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, const DataBuffer &input, DataBuffer *output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - // Replace the server key exchange message with an empty point output->Allocate(4); output->Write(0, 3U, 1); // named curve - uint32_t curve; + uint32_t curve = 0; EXPECT_TRUE(input.Read(1, 2, &curve)); // get curve id output->Write(1, curve, 2); // write curve id output->Write(3, 0U, 1); // point length 0 @@ -560,15 +548,13 @@ class ECCServerKEXFilter : public TlsHandshakeFilter { }; TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyServerPoint) { - // add packet filter - server_->SetPacketFilter(std::make_shared<ECCServerKEXFilter>()); + MakeTlsFilter<ECCServerKEXFilter>(server_); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH); } TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) { - // add packet filter - client_->SetPacketFilter(std::make_shared<ECCClientKEXFilter>()); + MakeTlsFilter<ECCClientKEXFilter>(client_); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_KEY_EXCH); } diff --git a/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc index be407b42e..c42883eb7 100644 --- a/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc @@ -118,7 +118,6 @@ int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr, TEST_P(TlsConnectTls13, EarlyExporter) { SetupForZeroRtt(); - ExpectAlert(client_, kTlsAlertEndOfEarlyData); client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); ExpectResumption(RESUME_TICKET); diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc index d15139419..0453dabdb 100644 --- a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -19,8 +19,9 @@ namespace nss_test { class TlsExtensionTruncator : public TlsExtensionFilter { public: - TlsExtensionTruncator(uint16_t extension, size_t length) - : extension_(extension), length_(length) {} + TlsExtensionTruncator(const std::shared_ptr<TlsAgent>& agent, + uint16_t extension, size_t length) + : TlsExtensionFilter(agent), extension_(extension), length_(length) {} virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { @@ -42,8 +43,9 @@ class TlsExtensionTruncator : public TlsExtensionFilter { class TlsExtensionDamager : public TlsExtensionFilter { public: - TlsExtensionDamager(uint16_t extension, size_t index) - : extension_(extension), index_(index) {} + TlsExtensionDamager(const std::shared_ptr<TlsAgent>& agent, + uint16_t extension, size_t index) + : TlsExtensionFilter(agent), extension_(extension), index_(index) {} virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { @@ -61,60 +63,17 @@ class TlsExtensionDamager : public TlsExtensionFilter { size_t index_; }; -class TlsExtensionInjector : public TlsHandshakeFilter { - public: - TlsExtensionInjector(uint16_t ext, DataBuffer& data) - : extension_(ext), data_(data) {} - - virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, - const DataBuffer& input, - DataBuffer* output) { - TlsParser parser(input); - if (!TlsExtensionFilter::FindExtensions(&parser, header)) { - return KEEP; - } - size_t offset = parser.consumed(); - - *output = input; - - // Increase the size of the extensions. - uint16_t ext_len; - memcpy(&ext_len, output->data() + offset, sizeof(ext_len)); - ext_len = htons(ntohs(ext_len) + data_.len() + 4); - memcpy(output->data() + offset, &ext_len, sizeof(ext_len)); - - // Insert the extension type and length. - DataBuffer type_length; - type_length.Allocate(4); - type_length.Write(0, extension_, 2); - type_length.Write(2, data_.len(), 2); - output->Splice(type_length, offset + 2); - - // Insert the payload. - if (data_.len() > 0) { - output->Splice(data_, offset + 6); - } - - return CHANGE; - } - - private: - const uint16_t extension_; - const DataBuffer data_; -}; - class TlsExtensionAppender : public TlsHandshakeFilter { public: - TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data) - : handshake_type_(handshake_type), extension_(ext), data_(data) {} + TlsExtensionAppender(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type, uint16_t ext, DataBuffer& data) + : TlsHandshakeFilter(agent, {handshake_type}), + extension_(ext), + data_(data) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != handshake_type_) { - return KEEP; - } - TlsParser parser(input); if (!TlsExtensionFilter::FindExtensions(&parser, header)) { return KEEP; @@ -159,7 +118,6 @@ class TlsExtensionAppender : public TlsHandshakeFilter { return true; } - const uint8_t handshake_type_; const uint16_t extension_; const DataBuffer data_; }; @@ -171,13 +129,13 @@ class TlsExtensionTestBase : public TlsConnectTestBase { void ClientHelloErrorTest(std::shared_ptr<PacketFilter> filter, uint8_t desc = kTlsAlertDecodeError) { - client_->SetPacketFilter(filter); + client_->SetFilter(filter); ConnectExpectAlert(server_, desc); } void ServerHelloErrorTest(std::shared_ptr<PacketFilter> filter, uint8_t desc = kTlsAlertDecodeError) { - server_->SetPacketFilter(filter); + server_->SetFilter(filter); ConnectExpectAlert(client_, desc); } @@ -200,11 +158,10 @@ class TlsExtensionTestBase : public TlsConnectTestBase { client_->ConfigNamedGroups(client_groups); server_->ConfigNamedGroups(server_groups); EnsureTlsSetup(); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); // Send ClientHello server_->Handshake(); // Send HRR. - client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(type)); + MakeTlsFilter<TlsExtensionDropper>(client_, type); Handshake(); client_->CheckErrorCode(client_error); server_->CheckErrorCode(server_error); @@ -245,8 +202,8 @@ class TlsExtensionTest13 void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) { DataBuffer versions_buf(buf, len); - client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>( - ssl_tls13_supported_versions_xtn, versions_buf)); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_supported_versions_xtn, versions_buf); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -257,8 +214,8 @@ class TlsExtensionTest13 size_t index = versions_buf.Write(0, 2, 1); versions_buf.Write(index, version, 2); - client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>( - ssl_tls13_supported_versions_xtn, versions_buf)); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_supported_versions_xtn, versions_buf); ConnectExpectFail(); } }; @@ -289,26 +246,26 @@ class TlsExtensionTestPre13 : public TlsExtensionTestBase, TEST_P(TlsExtensionTestGeneric, DamageSniLength) { ClientHelloErrorTest( - std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 1)); + std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 1)); } TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) { ClientHelloErrorTest( - std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 4)); + std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 4)); } TEST_P(TlsExtensionTestGeneric, TruncateSni) { ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_server_name_xtn, 7)); + std::make_shared<TlsExtensionTruncator>(client_, ssl_server_name_xtn, 7)); } // A valid extension that appears twice will be reported as unsupported. TEST_P(TlsExtensionTestGeneric, RepeatSni) { DataBuffer extension; InitSimpleSni(&extension); - ClientHelloErrorTest( - std::make_shared<TlsExtensionInjector>(ssl_server_name_xtn, extension), - kTlsAlertIllegalParameter); + ClientHelloErrorTest(std::make_shared<TlsExtensionInjector>( + client_, ssl_server_name_xtn, extension), + kTlsAlertIllegalParameter); } // An SNI entry with zero length is considered invalid (strangely, not if it is @@ -320,23 +277,23 @@ TEST_P(TlsExtensionTestGeneric, BadSni) { extension.Allocate(simple.len() + 3); extension.Write(0, static_cast<uint32_t>(0), 3); extension.Write(3, simple); - ClientHelloErrorTest( - std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_server_name_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, EmptySni) { DataBuffer extension; extension.Allocate(2); extension.Write(0, static_cast<uint32_t>(0), 2); - ClientHelloErrorTest( - std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_server_name_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) { EnableAlpn(); DataBuffer extension; ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension), + client_, ssl_app_layer_protocol_xtn, extension), kTlsAlertIllegalParameter); } @@ -347,21 +304,21 @@ TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension), + client_, ssl_app_layer_protocol_xtn, extension), kTlsAlertNoApplicationProtocol); } TEST_P(TlsExtensionTestGeneric, OneByteAlpn) { EnableAlpn(); - ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 1)); + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_app_layer_protocol_xtn, 1)); } TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) { EnableAlpn(); // This will leave the length of the second entry, but no value. - ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 5)); + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_app_layer_protocol_xtn, 5)); } TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { @@ -369,7 +326,7 @@ TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { const uint8_t val[] = {0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + client_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, AlpnMismatch) { @@ -388,7 +345,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) { @@ -396,7 +353,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) { const uint8_t val[] = {0x00, 0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) { @@ -404,7 +361,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) { @@ -412,7 +369,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) { const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) { @@ -420,7 +377,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) { const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) { @@ -428,7 +385,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) { const uint8_t val[] = {0x00, 0x02, 0x99, 0x61}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension)); + server_, ssl_app_layer_protocol_xtn, extension)); } TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) { @@ -436,55 +393,64 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) { const uint8_t val[] = {0x00, 0x02, 0x01, 0x67}; DataBuffer extension(val, sizeof(val)); ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_app_layer_protocol_xtn, extension), + server_, ssl_app_layer_protocol_xtn, extension), kTlsAlertIllegalParameter); } TEST_P(TlsExtensionTestDtls, SrtpShort) { EnableSrtp(); ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_use_srtp_xtn, 3)); + std::make_shared<TlsExtensionTruncator>(client_, ssl_use_srtp_xtn, 3)); } TEST_P(TlsExtensionTestDtls, SrtpOdd) { EnableSrtp(); const uint8_t val[] = {0x00, 0x01, 0xff, 0x00}; DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest( - std::make_shared<TlsExtensionReplacer>(ssl_use_srtp_xtn, extension)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_use_srtp_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) { const uint8_t val[] = {0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) { const uint8_t val[] = {0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension), + kTlsAlertHandshakeFailure); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsNoOverlap) { + const uint8_t val[] = {0x00, 0x02, 0xff, 0xff}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( + client_, ssl_signature_algorithms_xtn, extension), + kTlsAlertHandshakeFailure); } TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) { const uint8_t val[] = {0x00, 0x01, 0x04}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_signature_algorithms_xtn, extension)); + client_, ssl_signature_algorithms_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) { ClientHelloErrorTest( - std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn), + std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn), version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError : kTlsAlertMissingExtension); } @@ -493,75 +459,74 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) { const uint8_t val[] = {0x00, 0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) { const uint8_t val[] = {0x09, 0x99, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) { const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_elliptic_curves_xtn, extension)); + client_, ssl_elliptic_curves_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) { const uint8_t val[] = {0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) { const uint8_t val[] = {0x99, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) { const uint8_t val[] = {0x01, 0x00, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_ec_point_formats_xtn, extension)); + client_, ssl_ec_point_formats_xtn, extension)); } TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) { const uint8_t val[] = {0x99}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) { const uint8_t val[] = {0x01, 0x00}; DataBuffer extension(val, sizeof(val)); ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } // The extension has to contain a length. TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) { DataBuffer extension; ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>( - ssl_renegotiation_info_xtn, extension)); + client_, ssl_renegotiation_info_xtn, extension)); } // This only works on TLS 1.2, since it relies on static RSA; otherwise libssl // picks the wrong cipher suite. TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { - const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_sha512, - ssl_sig_rsa_pss_sha384}; + const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_rsae_sha512, + ssl_sig_rsa_pss_rsae_sha384}; auto capture = - std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn); + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn); client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes)); - client_->SetPacketFilter(capture); EnableOnlyStaticRsaCiphers(); Connect(); @@ -579,9 +544,9 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { // Temporary test to verify that we choke on an empty ClientKeyShare. // This test will fail when we implement HelloRetryRequest. TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { - ClientHelloErrorTest( - std::make_shared<TlsExtensionTruncator>(ssl_tls13_key_share_xtn, 2), - kTlsAlertHandshakeFailure); + ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>( + client_, ssl_tls13_key_share_xtn, 2), + kTlsAlertHandshakeFailure); } // These tests only work in stream mode because the client sends a @@ -590,8 +555,7 @@ TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { // packet gets dropped. TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) { EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn)); + MakeTlsFilter<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn); client_->ExpectSendAlert(kTlsAlertMissingExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -611,8 +575,7 @@ TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) { 0x02}; DataBuffer buf(key_share, sizeof(key_share)); EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf)); + MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf); client_->ExpectSendAlert(kTlsAlertIllegalParameter); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -633,8 +596,7 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { 0x02}; DataBuffer buf(key_share, sizeof(key_share)); EnsureTlsSetup(); - server_->SetPacketFilter( - std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf)); + MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf); client_->ExpectSendAlert(kTlsAlertMissingExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -645,8 +607,8 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) { SetupForResume(); DataBuffer empty; - server_->SetPacketFilter(std::make_shared<TlsExtensionInjector>( - ssl_signature_algorithms_xtn, empty)); + MakeTlsFilter<TlsExtensionInjector>(server_, ssl_signature_algorithms_xtn, + empty); client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -666,8 +628,12 @@ typedef std::function<void(TlsPreSharedKeyReplacer*)> class TlsPreSharedKeyReplacer : public TlsExtensionFilter { public: - TlsPreSharedKeyReplacer(TlsPreSharedKeyReplacerFunc function) - : identities_(), binders_(), function_(function) {} + TlsPreSharedKeyReplacer(const std::shared_ptr<TlsAgent>& agent, + TlsPreSharedKeyReplacerFunc function) + : TlsExtensionFilter(agent), + identities_(), + binders_(), + function_(function) {} static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size, const std::unique_ptr<DataBuffer>& replace, @@ -781,8 +747,10 @@ class TlsPreSharedKeyReplacer : public TlsExtensionFilter { TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) { SetupForResume(); - client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([]( - TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); })); + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->identities_[0].identity.Truncate(0); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -792,10 +760,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) { TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1); - })); + }); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -805,10 +773,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Write(r->binders_[0].len(), 0xff, 1); - })); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -818,8 +786,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) { SetupForResume(); - client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>( - [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); })); + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -830,11 +798,11 @@ TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) { TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->identities_.push_back(r->identities_[0]); r->binders_.push_back(r->binders_[0]); - })); + }); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -845,10 +813,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) { TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { SetupForResume(); - client_->SetPacketFilter( - std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) { + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { r->identities_.push_back(r->identities_[0]); - })); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -857,8 +825,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) { SetupForResume(); - client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([]( - TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); })); + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->binders_.push_back(r->binders_[0]); + }); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -870,8 +840,8 @@ TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) { const uint8_t empty_buf[] = {0}; DataBuffer empty(empty_buf, 0); // Inject an unused extension after the PSK extension. - client_->SetPacketFilter(std::make_shared<TlsExtensionAppender>( - kTlsHandshakeClientHello, 0xffff, empty)); + MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, 0xffff, + empty); ConnectExpectAlert(server_, kTlsAlertIllegalParameter); client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); @@ -881,8 +851,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) { SetupForResume(); DataBuffer empty; - client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>( - ssl_tls13_psk_key_exchange_modes_xtn)); + MakeTlsFilter<TlsExtensionDropper>(client_, + ssl_tls13_psk_key_exchange_modes_xtn); ConnectExpectAlert(server_, kTlsAlertMissingExtension); client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES); @@ -897,8 +867,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { kTls13PskKe}; DataBuffer modes(ke_modes, sizeof(ke_modes)); - client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>( - ssl_tls13_psk_key_exchange_modes_xtn, modes)); + MakeTlsFilter<TlsExtensionReplacer>( + client_, ssl_tls13_psk_key_exchange_modes_xtn, modes); client_->ExpectSendAlert(kTlsAlertBadRecordMac); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); @@ -908,9 +878,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) { ConfigureSessionCache(RESUME_NONE, RESUME_NONE); - auto capture = std::make_shared<TlsExtensionCapture>( - ssl_tls13_psk_key_exchange_modes_xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + client_, ssl_tls13_psk_key_exchange_modes_xtn); Connect(); EXPECT_FALSE(capture->captured()); } @@ -1006,12 +975,9 @@ class TlsBogusExtensionTest : public TlsConnectTestBase, static uint8_t empty_buf[1] = {0}; DataBuffer empty(empty_buf, 0); auto filter = - std::make_shared<TlsExtensionAppender>(message, extension, empty); + MakeTlsFilter<TlsExtensionAppender>(server_, message, extension, empty); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { - server_->SetTlsRecordFilter(filter); filter->EnableDecryption(); - } else { - server_->SetPacketFilter(filter); } } @@ -1032,17 +998,20 @@ class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest { class TlsBogusExtensionTest13 : public TlsBogusExtensionTest { protected: void ConnectAndFail(uint8_t message) override { - if (message == kTlsHandshakeHelloRetryRequest) { + if (message != kTlsHandshakeServerHello) { ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension); return; } - client_->StartConnect(); - server_->StartConnect(); + FailWithAlert(kTlsAlertUnsupportedExtension); + } + + void FailWithAlert(uint8_t alert) { + StartConnect(); client_->Handshake(); // ClientHello server_->Handshake(); // ServerHello - client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + client_->ExpectSendAlert(alert); client_->Handshake(); if (variant_ == ssl_variant_stream) { server_->ExpectSendAlert(kTlsAlertBadRecordMac); @@ -1067,9 +1036,12 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) { Run(kTlsHandshakeCertificate); } +// It's perfectly valid to set unknown extensions in CertificateRequest. TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) { server_->RequestClientAuth(false); - Run(kTlsHandshakeCertificateRequest); + AddFilter(kTlsHandshakeCertificateRequest, 0xff); + ConnectExpectAlert(client_, kTlsAlertDecryptError); + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); } TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) { @@ -1079,10 +1051,6 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) { Run(kTlsHandshakeHelloRetryRequest); } -TEST_P(TlsBogusExtensionTest13, AddVersionExtensionServerHello) { - Run(kTlsHandshakeServerHello, ssl_tls13_supported_versions_xtn); -} - TEST_P(TlsBogusExtensionTest13, AddVersionExtensionEncryptedExtensions) { Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn); } @@ -1096,13 +1064,6 @@ TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificateRequest) { Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn); } -TEST_P(TlsBogusExtensionTest13, AddVersionExtensionHelloRetryRequest) { - static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; - server_->ConfigNamedGroups(groups); - - Run(kTlsHandshakeHelloRetryRequest, ssl_tls13_supported_versions_xtn); -} - // NewSessionTicket allows unknown extensions AND it isn't protected by the // Finished. So adding an unknown extension doesn't cause an error. TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) { @@ -1132,8 +1093,7 @@ TEST_P(TlsConnectStream, IncludePadding) { SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name); EXPECT_EQ(SECSuccess, rv); - auto capture = std::make_shared<TlsExtensionCapture>(ssl_padding_xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, ssl_padding_xtn); client_->StartConnect(); client_->Handshake(); EXPECT_TRUE(capture->captured()); diff --git a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc index 44cacce46..f4940bf28 100644 --- a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc @@ -51,10 +51,16 @@ class RecordFragmenter : public PacketFilter { while (parser.remaining()) { TlsHandshakeFilter::HandshakeHeader handshake_header; DataBuffer handshake_body; - if (!handshake_header.Parse(&parser, record_header, &handshake_body)) { + bool complete = false; + if (!handshake_header.Parse(&parser, record_header, DataBuffer(), + &handshake_body, &complete)) { ADD_FAILURE() << "couldn't parse handshake header"; return false; } + if (!complete) { + ADD_FAILURE() << "don't want to deal with fragmented messages"; + return false; + } DataBuffer record_fragment; // We can't fragment handshake records that are too small. @@ -82,7 +88,7 @@ class RecordFragmenter : public PacketFilter { while (parser.remaining()) { TlsRecordHeader header; DataBuffer record; - if (!header.Parse(&parser, &record)) { + if (!header.Parse(0, &parser, &record)) { ADD_FAILURE() << "bad record header"; return false; } @@ -143,13 +149,13 @@ class RecordFragmenter : public PacketFilter { }; TEST_P(TlsConnectDatagram, FragmentClientPackets) { - client_->SetPacketFilter(std::make_shared<RecordFragmenter>()); + client_->SetFilter(std::make_shared<RecordFragmenter>()); Connect(); SendReceive(); } TEST_P(TlsConnectDatagram, FragmentServerPackets) { - server_->SetPacketFilter(std::make_shared<RecordFragmenter>()); + server_->SetFilter(std::make_shared<RecordFragmenter>()); Connect(); SendReceive(); } diff --git a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc index 1587b66de..99448321c 100644 --- a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc @@ -27,7 +27,8 @@ class TlsFuzzTest : public ::testing::Test {}; // Record the application data stream. class TlsApplicationDataRecorder : public TlsRecordFilter { public: - TlsApplicationDataRecorder() : buffer_() {} + TlsApplicationDataRecorder(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent), buffer_() {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, @@ -47,9 +48,9 @@ class TlsApplicationDataRecorder : public TlsRecordFilter { // Ensure that ssl_Time() returns a constant value. FUZZ_F(TlsFuzzTest, SSL_Time_Constant) { - PRUint32 now = ssl_Time(); + PRUint32 now = ssl_TimeSec(); PR_Sleep(PR_SecondsToInterval(2)); - EXPECT_EQ(ssl_Time(), now); + EXPECT_EQ(ssl_TimeSec(), now); } // Check that due to the deterministic PRNG we derive @@ -106,16 +107,16 @@ FUZZ_P(TlsConnectGeneric, DeterministicTranscript) { DisableECDHEServerKeyReuse(); DataBuffer buffer; - client_->SetPacketFilter(std::make_shared<TlsConversationRecorder>(buffer)); - server_->SetPacketFilter(std::make_shared<TlsConversationRecorder>(buffer)); + MakeTlsFilter<TlsConversationRecorder>(client_, buffer); + MakeTlsFilter<TlsConversationRecorder>(server_, buffer); // Reset the RNG state. EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0)); Connect(); // Ensure the filters go away before |buffer| does. - client_->DeletePacketFilter(); - server_->DeletePacketFilter(); + client_->ClearFilter(); + server_->ClearFilter(); if (last.len() > 0) { EXPECT_EQ(last, buffer); @@ -133,10 +134,8 @@ FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) { EnsureTlsSetup(); // Set up app data filters. - auto client_recorder = std::make_shared<TlsApplicationDataRecorder>(); - client_->SetPacketFilter(client_recorder); - auto server_recorder = std::make_shared<TlsApplicationDataRecorder>(); - server_->SetPacketFilter(server_recorder); + auto client_recorder = MakeTlsFilter<TlsApplicationDataRecorder>(client_); + auto server_recorder = MakeTlsFilter<TlsApplicationDataRecorder>(server_); Connect(); @@ -161,10 +160,9 @@ FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) { FUZZ_P(TlsConnectGeneric, BogusClientFinished) { EnsureTlsSetup(); - auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>( - kTlsHandshakeFinished, + MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>( + client_, kTlsHandshakeFinished, DataBuffer(kShortEmptyFinished, sizeof(kShortEmptyFinished))); - client_->SetPacketFilter(i1); Connect(); SendReceive(); } @@ -173,10 +171,9 @@ FUZZ_P(TlsConnectGeneric, BogusClientFinished) { FUZZ_P(TlsConnectGeneric, BogusServerFinished) { EnsureTlsSetup(); - auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>( - kTlsHandshakeFinished, + MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>( + server_, kTlsHandshakeFinished, DataBuffer(kLongEmptyFinished, sizeof(kLongEmptyFinished))); - server_->SetPacketFilter(i1); Connect(); SendReceive(); } @@ -187,7 +184,7 @@ FUZZ_P(TlsConnectGeneric, BogusServerAuthSignature) { uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsHandshakeCertificateVerify : kTlsHandshakeServerKeyExchange; - server_->SetPacketFilter(std::make_shared<TlsLastByteDamager>(msg_type)); + MakeTlsFilter<TlsLastByteDamager>(server_, msg_type); Connect(); SendReceive(); } @@ -197,8 +194,7 @@ FUZZ_P(TlsConnectGeneric, BogusClientAuthSignature) { EnsureTlsSetup(); client_->SetupClientAuth(); server_->RequestClientAuth(true); - client_->SetPacketFilter( - std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify)); + MakeTlsFilter<TlsLastByteDamager>(client_, kTlsHandshakeCertificateVerify); Connect(); } @@ -215,82 +211,32 @@ FUZZ_P(TlsConnectGeneric, SessionTicketResumption) { SendReceive(); } -class TlsSessionTicketMacDamager : public TlsExtensionFilter { - public: - TlsSessionTicketMacDamager() {} - virtual PacketFilter::Action FilterExtension(uint16_t extension_type, - const DataBuffer& input, - DataBuffer* output) { - if (extension_type != ssl_session_ticket_xtn && - extension_type != ssl_tls13_pre_shared_key_xtn) { - return KEEP; - } - - *output = input; - - // Handle everything before TLS 1.3. - if (extension_type == ssl_session_ticket_xtn) { - // Modify the last byte of the MAC. - output->data()[output->len() - 1] ^= 0xff; - } - - // Handle TLS 1.3. - if (extension_type == ssl_tls13_pre_shared_key_xtn) { - TlsParser parser(input); - - uint32_t ids_len; - EXPECT_TRUE(parser.Read(&ids_len, 2) && ids_len > 0); - - uint32_t ticket_len; - EXPECT_TRUE(parser.Read(&ticket_len, 2) && ticket_len > 0); - - // Modify the last byte of the MAC. - output->data()[2 + 2 + ticket_len - 1] ^= 0xff; - } - - return CHANGE; - } -}; - -// Check that session ticket resumption works with a bad MAC. -FUZZ_P(TlsConnectGeneric, SessionTicketResumptionBadMac) { - ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - Connect(); - SendReceive(); - - Reset(); - ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - ExpectResumption(RESUME_TICKET); - - client_->SetPacketFilter(std::make_shared<TlsSessionTicketMacDamager>()); - Connect(); - SendReceive(); -} - // Check that session tickets are not encrypted. FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) { ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); - auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeNewSessionTicket); - server_->SetPacketFilter(i1); + auto filter = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeNewSessionTicket); Connect(); + std::cerr << "ticket" << filter->buffer() << std::endl; size_t offset = 4; /* lifetime */ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { - offset += 1 + 1 + /* ke_modes */ - 1 + 1; /* auth_modes */ + offset += 4; /* ticket_age_add */ + uint32_t nonce_len = 0; + EXPECT_TRUE(filter->buffer().Read(offset, 1, &nonce_len)); + offset += 1 + nonce_len; } offset += 2 + /* ticket length */ 2; /* TLS_EX_SESS_TICKET_VERSION */ // Check the protocol version number. uint32_t tls_version = 0; - EXPECT_TRUE(i1->buffer().Read(offset, sizeof(version_), &tls_version)); + EXPECT_TRUE(filter->buffer().Read(offset, sizeof(version_), &tls_version)); EXPECT_EQ(version_, static_cast<decltype(version_)>(tls_version)); // Check the cipher suite. uint32_t suite = 0; - EXPECT_TRUE(i1->buffer().Read(offset + sizeof(version_), 2, &suite)); + EXPECT_TRUE(filter->buffer().Read(offset + sizeof(version_), 2, &suite)); client_->CheckCipherSuite(static_cast<uint16_t>(suite)); } } diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.cc b/security/nss/gtests/ssl_gtest/ssl_gtest.cc index cd10076b8..2fff9d7cb 100644 --- a/security/nss/gtests/ssl_gtest/ssl_gtest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_gtest.cc @@ -6,6 +6,7 @@ #include <cstdlib> #include "test_io.h" +#include "databuffer.h" #define GTEST_HAS_RTTI 0 #include "gtest/gtest.h" @@ -28,6 +29,7 @@ int main(int argc, char** argv) { ++i; } else if (!strcmp(argv[i], "-v")) { g_ssl_gtest_verbose = true; + nss_test::DataBuffer::SetLogLimit(16384); } } diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp index 8cd7d1009..e2a8d830a 100644 --- a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp +++ b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp @@ -11,6 +11,7 @@ 'target_name': 'ssl_gtest', 'type': 'executable', 'sources': [ + 'bloomfilter_unittest.cc', 'libssl_internals.c', 'selfencrypt_unittest.cc', 'ssl_0rtt_unittest.cc', @@ -18,6 +19,7 @@ 'ssl_auth_unittest.cc', 'ssl_cert_ext_unittest.cc', 'ssl_ciphersuite_unittest.cc', + 'ssl_custext_unittest.cc', 'ssl_damage_unittest.cc', 'ssl_dhe_unittest.cc', 'ssl_drop_unittest.cc', @@ -30,11 +32,16 @@ 'ssl_gather_unittest.cc', 'ssl_gtest.cc', 'ssl_hrr_unittest.cc', + 'ssl_keylog_unittest.cc', + 'ssl_keyupdate_unittest.cc', 'ssl_loopback_unittest.cc', + 'ssl_misc_unittest.cc', 'ssl_record_unittest.cc', 'ssl_resumption_unittest.cc', + 'ssl_renegotiation_unittest.cc', 'ssl_skip_unittest.cc', 'ssl_staticrsa_unittest.cc', + 'ssl_tls13compat_unittest.cc', 'ssl_v2_client_hello_unittest.cc', 'ssl_version_unittest.cc', 'ssl_versionpolicy_unittest.cc', diff --git a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc index 39055f641..05ae87034 100644 --- a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc @@ -35,17 +35,15 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { // Send first ClientHello and send 0-RTT data auto capture_early_data = - std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn); - client_->SetPacketFilter(capture_early_data); + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn); client_->Handshake(); EXPECT_EQ(k0RttDataLen, PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen)); // 0-RTT write. EXPECT_TRUE(capture_early_data->captured()); // Send the HelloRetryRequest - auto hrr_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeHelloRetryRequest); - server_->SetPacketFilter(hrr_capture); + auto hrr_capture = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeHelloRetryRequest); server_->Handshake(); EXPECT_LT(0U, hrr_capture->buffer().len()); @@ -56,8 +54,7 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { // Make a new capture for the early data. capture_early_data = - std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn); - client_->SetPacketFilter(capture_early_data); + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn); // Complete the handshake successfully Handshake(); @@ -71,6 +68,10 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { // packet. If the record is split into two packets, or there are multiple // handshake packets, this will break. class CorrectMessageSeqAfterHrrFilter : public TlsRecordFilter { + public: + CorrectMessageSeqAfterHrrFilter(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent) {} + protected: PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, @@ -131,8 +132,7 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) { // Correct the DTLS message sequence number after an HRR. if (variant_ == ssl_variant_datagram) { - client_->SetPacketFilter( - std::make_shared<CorrectMessageSeqAfterHrrFilter>()); + MakeTlsFilter<CorrectMessageSeqAfterHrrFilter>(client_); } server_->SetPeer(client_); @@ -151,7 +151,8 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) { class KeyShareReplayer : public TlsExtensionFilter { public: - KeyShareReplayer() {} + KeyShareReplayer(const std::shared_ptr<TlsAgent>& agent) + : TlsExtensionFilter(agent) {} virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, @@ -178,7 +179,22 @@ class KeyShareReplayer : public TlsExtensionFilter { // server should reject this. TEST_P(TlsConnectTls13, RetryWithSameKeyShare) { EnsureTlsSetup(); - client_->SetPacketFilter(std::make_shared<KeyShareReplayer>()); + MakeTlsFilter<KeyShareReplayer>(client_); + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(groups); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code()); + EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code()); +} + +// Here we modify the second ClientHello so that the client retries with the +// same shares, even though the server wanted something else. +TEST_P(TlsConnectTls13, RetryWithTwoShares) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + MakeTlsFilter<KeyShareReplayer>(client_); + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1}; server_->ConfigNamedGroups(groups); @@ -187,13 +203,574 @@ TEST_P(TlsConnectTls13, RetryWithSameKeyShare) { EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code()); } +TEST_P(TlsConnectTls13, RetryCallbackAccept) { + EnsureTlsSetup(); + + auto accept_hello = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<bool*>(arg); + *called = true; + + EXPECT_TRUE(firstHello); + EXPECT_EQ(0U, clientTokenLen); + return ssl_hello_retry_accept; + }; + + bool cb_run = false; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + accept_hello, &cb_run)); + Connect(); + EXPECT_TRUE(cb_run); +} + +TEST_P(TlsConnectTls13, RetryCallbackAcceptGroupMismatch) { + EnsureTlsSetup(); + + auto accept_hello_twice = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, + unsigned int appTokenMax, void* arg) { + auto* called = reinterpret_cast<size_t*>(arg); + ++*called; + + EXPECT_EQ(0U, clientTokenLen); + return ssl_hello_retry_accept; + }; + + auto capture = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn); + capture->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + size_t cb_run = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), accept_hello_twice, &cb_run)); + Connect(); + EXPECT_EQ(2U, cb_run); + EXPECT_TRUE(capture->captured()) << "expected a cookie in HelloRetryRequest"; +} + +TEST_P(TlsConnectTls13, RetryCallbackFail) { + EnsureTlsSetup(); + + auto fail_hello = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<bool*>(arg); + *called = true; + + EXPECT_TRUE(firstHello); + EXPECT_EQ(0U, clientTokenLen); + return ssl_hello_retry_fail; + }; + + bool cb_run = false; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + fail_hello, &cb_run)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_APPLICATION_ABORT); + EXPECT_TRUE(cb_run); +} + +// Asking for retry twice isn't allowed. +TEST_P(TlsConnectTls13, RetryCallbackRequestHrrTwice) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + return ssl_hello_retry_request; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +// Accepting the CH and modifying the token isn't allowed. +TEST_P(TlsConnectTls13, RetryCallbackAcceptAndSetToken) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + *appTokenLen = 1; + return ssl_hello_retry_accept; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +// As above, but with reject. +TEST_P(TlsConnectTls13, RetryCallbackRejectAndSetToken) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + *appTokenLen = 1; + return ssl_hello_retry_fail; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +// This is a (pretend) buffer overflow. +TEST_P(TlsConnectTls13, RetryCallbackSetTooLargeToken) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + *appTokenLen = appTokenMax + 1; + return ssl_hello_retry_accept; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +SSLHelloRetryRequestAction RetryHello(PRBool firstHello, + const PRUint8* clientToken, + unsigned int clientTokenLen, + PRUint8* appToken, + unsigned int* appTokenLen, + unsigned int appTokenMax, void* arg) { + auto* called = reinterpret_cast<size_t*>(arg); + ++*called; + + EXPECT_EQ(0U, clientTokenLen); + return firstHello ? ssl_hello_retry_request : ssl_hello_retry_accept; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetry) { + EnsureTlsSetup(); + + auto capture_hrr = std::make_shared<TlsHandshakeRecorder>( + server_, ssl_hs_hello_retry_request); + auto capture_key_share = + std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + std::vector<std::shared_ptr<PacketFilter>> chain = {capture_hrr, + capture_key_share}; + server_->SetFilter(std::make_shared<ChainedPacketFilter>(chain)); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + + // Do the first message exchange. + StartConnect(); + client_->Handshake(); + server_->Handshake(); + + EXPECT_EQ(1U, cb_called) << "callback should be called once here"; + EXPECT_LT(0U, capture_hrr->buffer().len()) << "HelloRetryRequest expected"; + EXPECT_FALSE(capture_key_share->captured()) + << "no key_share extension expected"; + + auto capture_cookie = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_cookie_xtn); + + Handshake(); + CheckConnected(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_cookie->captured()) << "should have a cookie"; +} + +static size_t CountShares(const DataBuffer& key_share) { + size_t count = 0; + uint32_t len = 0; + size_t offset = 2; + + EXPECT_TRUE(key_share.Read(0, 2, &len)); + EXPECT_EQ(key_share.len() - 2, len); + while (offset < key_share.len()) { + offset += 2; // Skip KeyShareEntry.group + EXPECT_TRUE(key_share.Read(offset, 2, &len)); + offset += 2 + len; // Skip KeyShareEntry.key_exchange + ++count; + } + return count; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + + auto capture_server = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + + // Do the first message exchange. + StartConnect(); + client_->Handshake(); + server_->Handshake(); + + EXPECT_EQ(1U, cb_called) << "callback should be called once here"; + EXPECT_FALSE(capture_server->captured()) + << "no key_share extension expected from server"; + + auto capture_client_2nd = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); + + Handshake(); + CheckConnected(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_client_2nd->captured()) << "client should send key_share"; + EXPECT_EQ(2U, CountShares(capture_client_2nd->extension())) + << "client should still send two shares"; +} + +// The callback should be run even if we have another reason to send +// HelloRetryRequest. In this case, the server sends HRR because the server +// wants a P-384 key share and the client didn't offer one. +TEST_P(TlsConnectTls13, RetryCallbackRetryWithGroupMismatch) { + EnsureTlsSetup(); + + auto capture_cookie = + std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn); + capture_cookie->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + auto capture_key_share = + std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + server_->SetFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit{capture_cookie, capture_key_share})); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + Connect(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_cookie->captured()) << "cookie expected"; + EXPECT_TRUE(capture_key_share->captured()) << "key_share expected"; +} + +static const uint8_t kApplicationToken[] = {0x92, 0x44, 0x00}; + +SSLHelloRetryRequestAction RetryHelloWithToken( + PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen, + PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<size_t*>(arg); + ++*called; + + if (firstHello) { + memcpy(appToken, kApplicationToken, sizeof(kApplicationToken)); + *appTokenLen = sizeof(kApplicationToken); + return ssl_hello_retry_request; + } + + EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)), + DataBuffer(clientToken, static_cast<size_t>(clientTokenLen))); + return ssl_hello_retry_accept; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetryWithToken) { + EnsureTlsSetup(); + + auto capture_key_share = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, + SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHelloWithToken, &cb_called)); + Connect(); + EXPECT_EQ(2U, cb_called); + EXPECT_FALSE(capture_key_share->captured()) << "no key share expected"; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetryWithTokenAndGroupMismatch) { + EnsureTlsSetup(); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + auto capture_key_share = + MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, + SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHelloWithToken, &cb_called)); + Connect(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_key_share->captured()) << "key share expected"; +} + +SSLHelloRetryRequestAction CheckTicketToken( + PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen, + PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<bool*>(arg); + *called = true; + + EXPECT_TRUE(firstHello); + EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)), + DataBuffer(clientToken, static_cast<size_t>(clientTokenLen))); + return ssl_hello_retry_accept; +} + +// Stream because SSL_SendSessionTicket only supports that. +TEST_F(TlsConnectStreamTls13, RetryCallbackWithSessionTicketToken) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + EXPECT_EQ(SECSuccess, + SSL_SendSessionTicket(server_->ssl_fd(), kApplicationToken, + sizeof(kApplicationToken))); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + + bool cb_run = false; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), CheckTicketToken, &cb_run)); + Connect(); + EXPECT_TRUE(cb_run); +} + +void TriggerHelloRetryRequest(std::shared_ptr<TlsAgent>& client, + std::shared_ptr<TlsAgent>& server) { + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server->ssl_fd(), + RetryHello, &cb_called)); + + // Start the handshake. + client->StartConnect(); + server->StartConnect(); + client->Handshake(); + server->Handshake(); + EXPECT_EQ(1U, cb_called); +} + +TEST_P(TlsConnectTls13, RetryStateless) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + Handshake(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, RetryStatefulDropCookie) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeTlsFilter<TlsExtensionDropper>(client_, ssl_tls13_cookie_xtn); + + ExpectAlert(server_, kTlsAlertMissingExtension); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_MISSING_COOKIE_EXTENSION); +} + +// Stream only because DTLS drops bad packets. +TEST_F(TlsConnectStreamTls13, RetryStatelessDamageFirstClientHello) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + auto damage_ch = + MakeTlsFilter<TlsExtensionInjector>(client_, 0xfff3, DataBuffer()); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + // Key exchange fails when the handshake continues because client and server + // disagree about the transcript. + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + +TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + auto damage_ch = + MakeTlsFilter<TlsExtensionInjector>(client_, 0xfff3, DataBuffer()); + + // Key exchange fails when the handshake continues because client and server + // disagree about the transcript. + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + +// Read the cipher suite from the HRR and disable it on the identified agent. +static void DisableSuiteFromHrr( + std::shared_ptr<TlsAgent>& agent, + std::shared_ptr<TlsHandshakeRecorder>& capture_hrr) { + uint32_t tmp; + size_t offset = 2 + 32; // skip version + server_random + ASSERT_TRUE( + capture_hrr->buffer().Read(offset, 1, &tmp)); // session_id length + EXPECT_EQ(0U, tmp); + offset += 1 + tmp; + ASSERT_TRUE(capture_hrr->buffer().Read(offset, 2, &tmp)); // suite + EXPECT_EQ( + SECSuccess, + SSL_CipherPrefSet(agent->ssl_fd(), static_cast<uint16_t>(tmp), PR_FALSE)); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteClient) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + auto capture_hrr = + MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_hello_retry_request); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + DisableSuiteFromHrr(client_, capture_hrr); + + // The client thinks that the HelloRetryRequest is bad, even though its + // because it changed its mind about the cipher suite. + ExpectAlert(client_, kTlsAlertIllegalParameter); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteServer) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + auto capture_hrr = + MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_hello_retry_request); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + DisableSuiteFromHrr(server_, capture_hrr); + + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableGroupClient) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + client_->ConfigNamedGroups(groups); + + // We're into undefined behavior on the client side, but - at the point this + // test was written - the client here doesn't amend its key shares because the + // server doesn't ask it to. The server notices that the key share (x25519) + // doesn't match the negotiated group (P-384) and objects. + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableGroupServer) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessBadCookie) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + + // Now replace the self-encrypt MAC key with a garbage key. + static const uint8_t bad_hmac_key[32] = {0}; + SECItem key_item = {siBuffer, const_cast<uint8_t*>(bad_hmac_key), + sizeof(bad_hmac_key)}; + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + PK11SymKey* hmac_key = + PK11_ImportSymKey(slot.get(), CKM_SHA256_HMAC, PK11_OriginUnwrap, + CKA_SIGN, &key_item, nullptr); + ASSERT_NE(nullptr, hmac_key); + SSLInt_SetSelfEncryptMacKey(hmac_key); // Passes ownership. + + MakeNewServer(); + + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Stream because the server doesn't consume the alert and terminate. +TEST_F(TlsConnectStreamTls13, RetryWithDifferentCipherSuite) { + EnsureTlsSetup(); + // Force a HelloRetryRequest. + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + // Then switch out the default suite (TLS_AES_128_GCM_SHA256). + MakeTlsFilter<SelectedCipherSuiteReplacer>(server_, + TLS_CHACHA20_POLY1305_SHA256); + + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); + EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code()); +} + // This tests that the second attempt at sending a ClientHello (after receiving // a HelloRetryRequest) is correctly retransmitted. TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) { static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1}; server_->ConfigNamedGroups(groups); - server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2)); + server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2)); Connect(); } @@ -233,6 +810,54 @@ TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrrExtraShares) { CheckKEXDetails(client_groups, client_groups); } +// The callback should be run even if we have another reason to send +// HelloRetryRequest. In this case, the server sends HRR because the server +// wants an X25519 key share and the client didn't offer one. +TEST_P(TlsKeyExchange13, + RetryCallbackRetryWithGroupMismatchAndAdditionalShares) { + EnsureKeyShareSetup(); + + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_curve25519}; + server_->ConfigNamedGroups(server_groups); + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + + auto capture_server = + std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn); + capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + server_->SetFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit{capture_hrr_, capture_server})); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + + // Do the first message exchange. + StartConnect(); + client_->Handshake(); + server_->Handshake(); + + EXPECT_EQ(1U, cb_called) << "callback should be called once here"; + EXPECT_TRUE(capture_server->captured()) << "key_share extension expected"; + + uint32_t server_group = 0; + EXPECT_TRUE(capture_server->extension().Read(0, 2, &server_group)); + EXPECT_EQ(ssl_grp_ec_curve25519, static_cast<SSLNamedGroup>(server_group)); + + Handshake(); + CheckConnected(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(shares_capture2_->captured()) << "client should send shares"; + + CheckKeys(); + static const std::vector<SSLNamedGroup> client_shares( + client_groups.begin(), client_groups.begin() + 2); + CheckKEXDetails(client_groups, client_shares, server_groups[0]); +} + TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) { EnsureTlsSetup(); client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, @@ -245,8 +870,7 @@ TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) { static const std::vector<SSLNamedGroup> server_groups = { ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1}; server_->ConfigNamedGroups(server_groups); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); server_->Handshake(); @@ -276,15 +900,30 @@ class HelloRetryRequestAgentTest : public TlsAgentTestClient { void MakeCannedHrr(const uint8_t* body, size_t len, DataBuffer* hrr_record, uint32_t seq_num = 0) const { DataBuffer hrr_data; - hrr_data.Allocate(len + 4); + const uint8_t ssl_hello_retry_random[] = { + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, + 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, + 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C}; + + hrr_data.Allocate(len + 6); size_t i = 0; + i = hrr_data.Write(i, 0x0303, 2); + i = hrr_data.Write(i, ssl_hello_retry_random, + sizeof(ssl_hello_retry_random)); + i = hrr_data.Write(i, static_cast<uint32_t>(0), 1); // session_id + i = hrr_data.Write(i, TLS_AES_128_GCM_SHA256, 2); + i = hrr_data.Write(i, ssl_compression_null, 1); + // Add extensions. First a length, which includes the supported version. + i = hrr_data.Write(i, static_cast<uint32_t>(len) + 6, 2); + // Now the supported version. + i = hrr_data.Write(i, ssl_tls13_supported_versions_xtn, 2); + i = hrr_data.Write(i, 2, 2); i = hrr_data.Write(i, 0x7f00 | TLS_1_3_DRAFT_VERSION, 2); - i = hrr_data.Write(i, static_cast<uint32_t>(len), 2); if (len) { hrr_data.Write(i, body, len); } DataBuffer hrr; - MakeHandshakeMessage(kTlsHandshakeHelloRetryRequest, hrr_data.data(), + MakeHandshakeMessage(kTlsHandshakeServerHello, hrr_data.data(), hrr_data.len(), &hrr, seq_num); MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, hrr.data(), hrr.len(), hrr_record, seq_num); @@ -334,28 +973,6 @@ TEST_P(HelloRetryRequestAgentTest, HandleNoopHelloRetryRequest) { SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST); } -TEST_P(HelloRetryRequestAgentTest, HandleHelloRetryRequestCookie) { - const uint8_t canned_cookie_hrr[] = { - static_cast<uint8_t>(ssl_tls13_cookie_xtn >> 8), - static_cast<uint8_t>(ssl_tls13_cookie_xtn), - 0, - 5, // length of cookie extension - 0, - 3, // cookie value length - 0xc0, - 0x0c, - 0x13}; - DataBuffer hrr; - MakeCannedHrr(canned_cookie_hrr, sizeof(canned_cookie_hrr), &hrr); - auto capture = std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn); - agent_->SetPacketFilter(capture); - ProcessMessage(hrr, TlsAgent::STATE_CONNECTING); - const size_t cookie_pos = 2 + 2; // cookie_xtn, extension len - DataBuffer cookie(canned_cookie_hrr + cookie_pos, - sizeof(canned_cookie_hrr) - cookie_pos); - EXPECT_EQ(cookie, capture->extension()); -} - INSTANTIATE_TEST_CASE_P(HelloRetryRequestAgentTests, HelloRetryRequestAgentTest, ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV13)); diff --git a/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc new file mode 100644 index 000000000..322b64837 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc @@ -0,0 +1,118 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifdef NSS_ALLOW_SSLKEYLOGFILE + +#include <cstdlib> +#include <fstream> +#include <sstream> + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +static const std::string keylog_file_path = "keylog.txt"; +static const std::string keylog_env = "SSLKEYLOGFILE=" + keylog_file_path; + +class KeyLogFileTest : public TlsConnectGeneric { + public: + void SetUp() override { + TlsConnectGeneric::SetUp(); + // Remove previous results (if any). + (void)remove(keylog_file_path.c_str()); + PR_SetEnv(keylog_env.c_str()); + } + + void CheckKeyLog() { + std::ifstream f(keylog_file_path); + std::map<std::string, size_t> labels; + std::set<std::string> client_randoms; + for (std::string line; std::getline(f, line);) { + if (line[0] == '#') { + continue; + } + + std::istringstream iss(line); + std::string label, client_random, secret; + iss >> label >> client_random >> secret; + + ASSERT_EQ(64U, client_random.size()); + client_randoms.insert(client_random); + labels[label]++; + } + + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_EQ(1U, client_randoms.size()); + } else { + /* two handshakes for 0-RTT */ + ASSERT_EQ(2U, client_randoms.size()); + } + + // Every entry occurs twice (one log from server, one from client). + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_EQ(2U, labels["CLIENT_RANDOM"]); + } else { + ASSERT_EQ(2U, labels["CLIENT_EARLY_TRAFFIC_SECRET"]); + ASSERT_EQ(2U, labels["EARLY_EXPORTER_SECRET"]); + ASSERT_EQ(4U, labels["CLIENT_HANDSHAKE_TRAFFIC_SECRET"]); + ASSERT_EQ(4U, labels["SERVER_HANDSHAKE_TRAFFIC_SECRET"]); + ASSERT_EQ(4U, labels["CLIENT_TRAFFIC_SECRET_0"]); + ASSERT_EQ(4U, labels["SERVER_TRAFFIC_SECRET_0"]); + ASSERT_EQ(4U, labels["EXPORTER_SECRET"]); + } + } + + void ConnectAndCheck() { + // This is a child process, ensure that error messages immediately + // propagate or else it will not be visible. + ::testing::GTEST_FLAG(throw_on_failure) = true; + + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + } else { + Connect(); + } + CheckKeyLog(); + _exit(0); + } +}; + +// Tests are run in a separate process to ensure that NSS is not initialized yet +// and can process the SSLKEYLOGFILE environment variable. + +TEST_P(KeyLogFileTest, KeyLogFile) { + testing::GTEST_FLAG(death_test_style) = "threadsafe"; + + ASSERT_EXIT(ConnectAndCheck(), ::testing::ExitedWithCode(0), ""); +} + +INSTANTIATE_TEST_CASE_P( + KeyLogFileDTLS12, KeyLogFileTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12)); +INSTANTIATE_TEST_CASE_P( + KeyLogFileTLS12, KeyLogFileTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_CASE_P( + KeyLogFileTLS13, KeyLogFileTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV13)); +#endif + +} // namespace nss_test + +#endif // NSS_ALLOW_SSLKEYLOGFILE diff --git a/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc new file mode 100644 index 000000000..d03775c25 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc @@ -0,0 +1,178 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +// All stream only tests; DTLS isn't supported yet. + +TEST_F(TlsConnectTest, KeyUpdateClient) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(4, 3); +} + +TEST_F(TlsConnectTest, KeyUpdateClientRequestUpdate) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE)); + // SendReceive() only gives each peer one chance to read. This isn't enough + // when the read on one side generates another handshake message. A second + // read gives each peer an extra chance to consume the KeyUpdate. + SendReceive(50); + SendReceive(60); // Cumulative count. + CheckEpochs(4, 4); +} + +TEST_F(TlsConnectTest, KeyUpdateServer) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(3, 4); +} + +TEST_F(TlsConnectTest, KeyUpdateServerRequestUpdate) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(4, 4); +} + +TEST_F(TlsConnectTest, KeyUpdateConsecutiveRequests) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + // The server should have updated twice, but the client should have declined + // to respond to the second request from the server, since it doesn't send + // anything in between those two requests. + CheckEpochs(4, 5); +} + +// Check that a local update can be immediately followed by a remotely triggered +// update even if there is no use of the keys. +TEST_F(TlsConnectTest, KeyUpdateLocalUpdateThenConsecutiveRequests) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + // This should trigger an update on the client. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE)); + // The client should update for the first request. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + // ...but not the second. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + // Both should have updated twice. + CheckEpochs(5, 5); +} + +TEST_F(TlsConnectTest, KeyUpdateMultiple) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(5, 6); +} + +// Both ask the other for an update, and both should react. +TEST_F(TlsConnectTest, KeyUpdateBothRequest) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(5, 5); +} + +// If the sequence number exceeds the number of writes before an automatic +// update (currently 3/4 of the max records for the cipher suite), then the +// stack should send an update automatically (but not request one). +TEST_F(TlsConnectTest, KeyUpdateAutomaticOnWrite) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256); + + // Set this to one below the write threshold. + uint64_t threshold = (0x5aULL << 28) * 3 / 4; + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold)); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold)); + + // This should be OK. + client_->SendData(10); + server_->ReadBytes(); + + // This should cause the client to update. + client_->SendData(10); + server_->ReadBytes(); + + SendReceive(100); + CheckEpochs(4, 3); +} + +// If the sequence number exceeds a certain number of reads (currently 7/8 of +// the max records for the cipher suite), then the stack should send AND request +// an update automatically. However, the sender (client) will be above its +// automatic update threshold, so the KeyUpdate - that it sends with the old +// cipher spec - will exceed the receiver (server) automatic update threshold. +// The receiver gets a packet with a sequence number over its automatic read +// update threshold. Even though the sender has updated, the code that checks +// the sequence numbers at the receiver doesn't know this and it will request an +// update. This causes two updates: one from the sender (without requesting a +// response) and one from the receiver (which does request a response). +TEST_F(TlsConnectTest, KeyUpdateAutomaticOnRead) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256); + + // Move to right at the read threshold. Unlike the write test, we can't send + // packets because that would cause the client to update, which would spoil + // the test. + uint64_t threshold = ((0x5aULL << 28) * 7 / 8) + 1; + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold)); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold)); + + // This should cause the client to update, but not early enough to prevent the + // server from updating also. + client_->SendData(10); + server_->ReadBytes(); + + // Need two SendReceive() calls to ensure that the update that the server + // requested is properly generated and consumed. + SendReceive(70); + SendReceive(80); + CheckEpochs(5, 4); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc index 77703dd8e..f1b78f52f 100644 --- a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc @@ -6,6 +6,7 @@ #include <functional> #include <memory> +#include <vector> #include "secerr.h" #include "ssl.h" #include "sslerr.h" @@ -55,7 +56,8 @@ TEST_P(TlsConnectGeneric, CipherSuiteMismatch) { class TlsAlertRecorder : public TlsRecordFilter { public: - TlsAlertRecorder() : level_(255), description_(255) {} + TlsAlertRecorder(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent), level_(255), description_(255) {} PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, @@ -84,13 +86,13 @@ class TlsAlertRecorder : public TlsRecordFilter { }; class HelloTruncator : public TlsHandshakeFilter { + public: + HelloTruncator(const std::shared_ptr<TlsAgent>& agent) + : TlsHandshakeFilter( + agent, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {} PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeClientHello && - header.handshake_type() != kTlsHandshakeServerHello) { - return KEEP; - } output->Assign(input.data(), input.len() - 1); return CHANGE; } @@ -98,19 +100,17 @@ class HelloTruncator : public TlsHandshakeFilter { // Verify that when NSS reports that an alert is sent, it is actually sent. TEST_P(TlsConnectGeneric, CaptureAlertServer) { - client_->SetPacketFilter(std::make_shared<HelloTruncator>()); - auto alert_recorder = std::make_shared<TlsAlertRecorder>(); - server_->SetPacketFilter(alert_recorder); + MakeTlsFilter<HelloTruncator>(client_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(server_); - ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + ConnectExpectAlert(server_, kTlsAlertDecodeError); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description()); + EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); } TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { - server_->SetPacketFilter(std::make_shared<HelloTruncator>()); - auto alert_recorder = std::make_shared<TlsAlertRecorder>(); - client_->SetPacketFilter(alert_recorder); + MakeTlsFilter<HelloTruncator>(server_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_); ConnectExpectAlert(client_, kTlsAlertDecodeError); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); @@ -119,12 +119,10 @@ TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { // In TLS 1.3, the server can't read the client alert. TEST_P(TlsConnectTls13, CaptureAlertClient) { - server_->SetPacketFilter(std::make_shared<HelloTruncator>()); - auto alert_recorder = std::make_shared<TlsAlertRecorder>(); - client_->SetPacketFilter(alert_recorder); + MakeTlsFilter<HelloTruncator>(server_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->Handshake(); client_->ExpectSendAlert(kTlsAlertDecodeError); @@ -166,26 +164,111 @@ TEST_P(TlsConnectDatagram, ConnectSrtp) { SendReceive(); } -// 1.3 is disabled in the next few tests because we don't -// presently support resumption in 1.3. -TEST_P(TlsConnectStreamPre13, ConnectAndClientRenegotiate) { +TEST_P(TlsConnectGeneric, ConnectSendReceive) { Connect(); - server_->PrepareForRenegotiate(); - client_->StartRenegotiate(); - Handshake(); - CheckConnected(); + SendReceive(); } -TEST_P(TlsConnectStreamPre13, ConnectAndServerRenegotiate) { +class SaveTlsRecord : public TlsRecordFilter { + public: + SaveTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index) + : TlsRecordFilter(agent), index_(index), count_(0), contents_() {} + + const DataBuffer& contents() const { return contents_; } + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (count_++ == index_) { + contents_ = data; + } + return KEEP; + } + + private: + const size_t index_; + size_t count_; + DataBuffer contents_; +}; + +// Check that decrypting filters work and can read any record. +// This test (currently) only works in TLS 1.3 where we can decrypt. +TEST_F(TlsConnectStreamTls13, DecryptRecordClient) { + EnsureTlsSetup(); + // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer + auto saved = MakeTlsFilter<SaveTlsRecord>(client_, 3); + saved->EnableDecryption(); Connect(); - client_->PrepareForRenegotiate(); - server_->StartRenegotiate(); - Handshake(); - CheckConnected(); + SendReceive(); + + static const uint8_t data[] = {0xde, 0xad, 0xdc}; + DataBuffer buf(data, sizeof(data)); + client_->SendBuffer(buf); + EXPECT_EQ(buf, saved->contents()); } -TEST_P(TlsConnectGeneric, ConnectSendReceive) { +TEST_F(TlsConnectStreamTls13, DecryptRecordServer) { + EnsureTlsSetup(); + // Disable tickets so that we are sure to not get NewSessionTicket. + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer + auto saved = MakeTlsFilter<SaveTlsRecord>(server_, 3); + saved->EnableDecryption(); + Connect(); + SendReceive(); + + static const uint8_t data[] = {0xde, 0xad, 0xd5}; + DataBuffer buf(data, sizeof(data)); + server_->SendBuffer(buf); + EXPECT_EQ(buf, saved->contents()); +} + +class DropTlsRecord : public TlsRecordFilter { + public: + DropTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index) + : TlsRecordFilter(agent), index_(index), count_(0) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (count_++ == index_) { + return DROP; + } + return KEEP; + } + + private: + const size_t index_; + size_t count_; +}; + +// Test that decrypting filters work correctly and are able to drop records. +TEST_F(TlsConnectStreamTls13, DropRecordServer) { + EnsureTlsSetup(); + // Disable session tickets so that the server doesn't send an extra record. + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + + // 0 = ServerHello, 1 = other handshake, 2 = first write + auto filter = MakeTlsFilter<DropTlsRecord>(server_, 2); + filter->EnableDecryption(); + Connect(); + server_->SendData(23, 23); // This should be dropped, so it won't be counted. + server_->ResetSentBytes(); + SendReceive(); +} + +TEST_F(TlsConnectStreamTls13, DropRecordClient) { + EnsureTlsSetup(); + // 0 = ClientHello, 1 = Finished, 2 = first write + auto filter = MakeTlsFilter<DropTlsRecord>(client_, 2); + filter->EnableDecryption(); Connect(); + client_->SendData(26, 26); // This should be dropped, so it won't be counted. + client_->ResetSentBytes(); SendReceive(); } @@ -224,32 +307,74 @@ TEST_P(TlsConnectStream, ShortRead) { ASSERT_EQ(50U, client_->received_bytes()); } -TEST_P(TlsConnectGeneric, ConnectWithCompressionMaybe) { +// We enable compression via the API but it's disabled internally, +// so we should never get it. +TEST_P(TlsConnectGeneric, ConnectWithCompressionEnabled) { EnsureTlsSetup(); - client_->EnableCompression(); - server_->EnableCompression(); + client_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); + server_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); Connect(); - EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && - variant_ != ssl_variant_datagram, - client_->is_compressed()); + EXPECT_FALSE(client_->is_compressed()); SendReceive(); } -TEST_P(TlsConnectDatagram, TestDtlsHolddownExpiry) { +class TlsHolddownTest : public TlsConnectDatagram { + protected: + // This causes all timers to run to completion. It advances the clock and + // handshakes on both peers until both peers have no more timers pending, + // which should happen at the end of a handshake. This is necessary to ensure + // that the relatively long holddown timer expires, but that any other timers + // also expire and run correctly. + void RunAllTimersDown() { + while (true) { + PRIntervalTime time; + SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); + if (rv != SECSuccess) { + rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); + if (rv != SECSuccess) { + break; // Neither peer has an outstanding timer. + } + } + + if (g_ssl_gtest_verbose) { + std::cerr << "Shifting timers" << std::endl; + } + ShiftDtlsTimers(); + Handshake(); + } + } +}; + +TEST_P(TlsHolddownTest, TestDtlsHolddownExpiry) { Connect(); - std::cerr << "Expiring holddown timer\n"; - SSLInt_ForceTimerExpiry(client_->ssl_fd()); - SSLInt_ForceTimerExpiry(server_->ssl_fd()); + std::cerr << "Expiring holddown timer" << std::endl; + RunAllTimersDown(); SendReceive(); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { // One for send, one for receive. - EXPECT_EQ(2, SSLInt_CountTls13CipherSpecs(client_->ssl_fd())); + EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); } } +TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + Connect(); + RunAllTimersDown(); + SendReceive(); + // One for send, one for receive. + EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); +} + class TlsPreCCSHeaderInjector : public TlsRecordFilter { public: - TlsPreCCSHeaderInjector() {} + TlsPreCCSHeaderInjector(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent) {} virtual PacketFilter::Action FilterRecord( const TlsRecordHeader& record_header, const DataBuffer& input, size_t* offset, DataBuffer* output) override { @@ -266,16 +391,15 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter { }; TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) { - client_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>()); + MakeTlsFilter<TlsPreCCSHeaderInjector>(client_); ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); } TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) { - server_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>()); - client_->StartConnect(); - server_->StartConnect(); + MakeTlsFilter<TlsPreCCSHeaderInjector>(server_); + StartConnect(); ExpectAlert(client_, kTlsAlertUnexpectedMessage); Handshake(); EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); @@ -306,21 +430,64 @@ TEST_P(TlsConnectTls13, AlertWrongLevel) { TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) { EnsureTlsSetup(); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); server_->Handshake(); // Send first flight. - client_->adapter()->CloseWrites(); + client_->adapter()->SetWriteError(PR_IO_ERROR); client_->Handshake(); // This will get an error, but shouldn't crash. client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE); } -TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) { - client_->SetShortHeadersEnabled(); - server_->SetShortHeadersEnabled(); - client_->ExpectShortHeaders(); - server_->ExpectShortHeaders(); +TEST_P(TlsConnectDatagram, BlockedWrite) { Connect(); + + // Mark the socket as blocked. + client_->adapter()->SetWriteError(PR_WOULD_BLOCK_ERROR); + static const uint8_t data[] = {1, 2, 3}; + int32_t rv = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_GT(0, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Remove the write error and though the previous write failed, future reads + // and writes should just work as if it never happened. + client_->adapter()->SetWriteError(0); + SendReceive(); +} + +TEST_F(TlsConnectTest, ConnectSSLv3) { + ConfigureVersion(SSL_LIBRARY_VERSION_3_0); + EnableOnlyStaticRsaCiphers(); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +TEST_F(TlsConnectTest, ConnectSSLv3ClientAuth) { + ConfigureVersion(SSL_LIBRARY_VERSION_3_0); + EnableOnlyStaticRsaCiphers(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +static size_t ExpectedCbcLen(size_t in, size_t hmac = 20, size_t block = 16) { + // MAC-then-Encrypt expansion formula: + return ((in + hmac + (block - 1)) / block) * block; +} + +TEST_F(TlsConnectTest, OneNRecordSplitting) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0); + EnsureTlsSetup(); + ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA); + auto records = MakeTlsFilter<TlsRecordRecorder>(server_); + // This should be split into 1, 16384 and 20. + DataBuffer big_buffer; + big_buffer.Allocate(1 + 16384 + 20); + server_->SendBuffer(big_buffer); + ASSERT_EQ(3U, records->count()); + EXPECT_EQ(ExpectedCbcLen(1), records->record(0).buffer.len()); + EXPECT_EQ(ExpectedCbcLen(16384), records->record(1).buffer.len()); + EXPECT_EQ(ExpectedCbcLen(20), records->record(2).buffer.len()); } INSTANTIATE_TEST_CASE_P( @@ -336,6 +503,8 @@ INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream, TlsConnectTestBase::kTlsVAll); INSTANTIATE_TEST_CASE_P(DatagramOnly, TlsConnectDatagram, TlsConnectTestBase::kTlsV11Plus); +INSTANTIATE_TEST_CASE_P(DatagramHolddown, TlsHolddownTest, + TlsConnectTestBase::kTlsV11Plus); INSTANTIATE_TEST_CASE_P( Pre12Stream, TlsConnectPre12, @@ -368,4 +537,27 @@ INSTANTIATE_TEST_CASE_P(Version12Plus, TlsConnectTls12Plus, ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus)); -} // namespace nspr_test +INSTANTIATE_TEST_CASE_P( + GenericStream, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll, + ::testing::Values(true, false))); +INSTANTIATE_TEST_CASE_P( + GenericDatagram, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus, + ::testing::Values(true, false))); + +INSTANTIATE_TEST_CASE_P( + GenericStream, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P( + GenericDatagram, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); + +INSTANTIATE_TEST_CASE_P(GenericDatagram, TlsConnectTls13ResumptionToken, + TlsConnectTestBase::kTlsVariantsAll); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc new file mode 100644 index 000000000..2b1b92dcd --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc @@ -0,0 +1,20 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "sslexp.h" + +#include "gtest_utils.h" + +namespace nss_test { + +class MiscTest : public ::testing::Test {}; + +TEST_F(MiscTest, NonExistentExperimentalAPI) { + EXPECT_EQ(nullptr, SSL_GetExperimentalAPI("blah")); + EXPECT_EQ(SSL_ERROR_UNSUPPORTED_EXPERIMENTAL_API, PORT_GetError()); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc index ef81b222c..3b8727850 100644 --- a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc @@ -10,6 +10,8 @@ #include "databuffer.h" #include "gtest_utils.h" +#include "tls_connect.h" +#include "tls_filter.h" namespace nss_test { @@ -51,8 +53,8 @@ class TlsPaddingTest << " total length=" << plaintext_.len() << std::endl; std::cerr << "Plaintext: " << plaintext_ << std::endl; sslBuffer s; - s.buf = const_cast<unsigned char *>( - static_cast<const unsigned char *>(plaintext_.data())); + s.buf = const_cast<unsigned char*>( + static_cast<const unsigned char*>(plaintext_.data())); s.len = plaintext_.len(); SECStatus rv = ssl_RemoveTLSCBCPadding(&s, kMacSize); if (expect_success) { @@ -99,6 +101,73 @@ TEST_P(TlsPaddingTest, LastByteOfPadWrong) { } } +class RecordReplacer : public TlsRecordFilter { + public: + RecordReplacer(const std::shared_ptr<TlsAgent>& agent, size_t size) + : TlsRecordFilter(agent), enabled_(false), size_(size) {} + + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (!enabled_) { + return KEEP; + } + + EXPECT_EQ(kTlsApplicationDataType, header.content_type()); + changed->Allocate(size_); + + for (size_t i = 0; i < size_; ++i) { + changed->data()[i] = i & 0xff; + } + + enabled_ = false; + return CHANGE; + } + + void Enable() { enabled_ = true; } + + private: + bool enabled_; + size_t size_; +}; + +TEST_F(TlsConnectStreamTls13, LargeRecord) { + EnsureTlsSetup(); + + const size_t record_limit = 16384; + auto replacer = MakeTlsFilter<RecordReplacer>(client_, record_limit); + replacer->EnableDecryption(); + Connect(); + + replacer->Enable(); + client_->SendData(10); + WAIT_(server_->received_bytes() == record_limit, 2000); + ASSERT_EQ(record_limit, server_->received_bytes()); +} + +TEST_F(TlsConnectStreamTls13, TooLargeRecord) { + EnsureTlsSetup(); + + const size_t record_limit = 16384; + auto replacer = MakeTlsFilter<RecordReplacer>(client_, record_limit + 1); + replacer->EnableDecryption(); + Connect(); + + replacer->Enable(); + ExpectAlert(server_, kTlsAlertRecordOverflow); + client_->SendData(10); // This is expanded. + + uint8_t buf[record_limit + 2]; + PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, PORT_GetError()); + + // Read the server alert. + rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(SSL_ERROR_RECORD_OVERFLOW_ALERT, PORT_GetError()); +} + const static size_t kContentSizesArr[] = { 1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288}; @@ -108,4 +177,4 @@ auto kTrueFalse = ::testing::ValuesIn(kTrueFalseArr); INSTANTIATE_TEST_CASE_P(TlsPadding, TlsPaddingTest, ::testing::Combine(kContentSizes, kTrueFalse)); -} // namespace nspr_test +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc new file mode 100644 index 000000000..a902a5f7f --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc @@ -0,0 +1,212 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +// 1.3 is disabled in the next few tests because we don't +// presently support resumption in 1.3. +TEST_P(TlsConnectStreamPre13, RenegotiateClient) { + Connect(); + server_->PrepareForRenegotiate(); + client_->StartRenegotiate(); + Handshake(); + CheckConnected(); +} + +TEST_P(TlsConnectStreamPre13, RenegotiateServer) { + Connect(); + client_->PrepareForRenegotiate(); + server_->StartRenegotiate(); + Handshake(); + CheckConnected(); +} + +// The renegotiation options shouldn't cause an error if TLS 1.3 is chosen. +TEST_F(TlsConnectTest, RenegotiationConfigTls13) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetOption(SSL_ENABLE_RENEGOTIATION, SSL_RENEGOTIATE_UNRESTRICTED); + server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); + Connect(); + SendReceive(); + CheckKeys(); +} + +TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + // Set the client so it will accept any version from 1.0 + // to |version_|. + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // Reset version so that the checks succeed. + uint16_t test_version = version_; + version_ = SSL_LIBRARY_VERSION_TLS_1_0; + Connect(); + + // Now renegotiate, with the server being set to do + // |version_|. + client_->PrepareForRenegotiate(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + server_->StartRenegotiate(); + + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + } else { + ExpectAlert(server_, kTlsAlertProtocolVersion); + } + + Handshake(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + // In TLS 1.3, the server detects this problem. + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); + } else { + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); + } +} + +TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + // Set the client so it will accept any version from 1.0 + // to |version_|. + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // Reset version so that the checks succeed. + uint16_t test_version = version_; + version_ = SSL_LIBRARY_VERSION_TLS_1_0; + Connect(); + + // Now renegotiate, with the server being set to do + // |version_|. + server_->PrepareForRenegotiate(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + client_->StartRenegotiate(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + } else { + ExpectAlert(server_, kTlsAlertProtocolVersion); + } + Handshake(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + // In TLS 1.3, the server detects this problem. + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); + } else { + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); + } +} + +TEST_P(TlsConnectStream, ConnectAndServerRenegotiateLower) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + Connect(); + + // Now renegotiate with the server set to TLS 1.0. + client_->PrepareForRenegotiate(); + server_->PrepareForRenegotiate(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + + SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(SECFailure, rv); + return; + } + ASSERT_EQ(SECSuccess, rv); + + // Now, before handshaking, tweak the server configuration. + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + + // The server should catch the own error. + ExpectAlert(server_, kTlsAlertProtocolVersion); + + Handshake(); + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); +} + +TEST_P(TlsConnectStream, ConnectAndServerWontRenegotiateLower) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + Connect(); + + // Now renegotiate with the server set to TLS 1.0. + client_->PrepareForRenegotiate(); + server_->PrepareForRenegotiate(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + + EXPECT_EQ(SECFailure, SSL_ReHandshake(server_->ssl_fd(), PR_TRUE)); +} + +TEST_P(TlsConnectStream, ConnectAndClientWontRenegotiateLower) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + Connect(); + + // Now renegotiate with the client set to TLS 1.0. + client_->PrepareForRenegotiate(); + server_->PrepareForRenegotiate(); + server_->ResetPreliminaryInfo(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // The client will refuse to renegotiate down. + EXPECT_EQ(SECFailure, SSL_ReHandshake(client_->ssl_fd(), PR_TRUE)); +} + +TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); +} + +TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc index ce0e3ca8d..eb78c0585 100644 --- a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc @@ -9,6 +9,7 @@ #include "secerr.h" #include "ssl.h" #include "sslerr.h" +#include "sslexp.h" #include "sslproto.h" extern "C" { @@ -59,7 +60,7 @@ TEST_P(TlsConnectGenericPre13, ConnectResumed) { Connect(); } -TEST_P(TlsConnectGeneric, ConnectClientCacheDisabled) { +TEST_P(TlsConnectGenericResumption, ConnectClientCacheDisabled) { ConfigureSessionCache(RESUME_NONE, RESUME_SESSIONID); Connect(); SendReceive(); @@ -70,7 +71,7 @@ TEST_P(TlsConnectGeneric, ConnectClientCacheDisabled) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectServerCacheDisabled) { +TEST_P(TlsConnectGenericResumption, ConnectServerCacheDisabled) { ConfigureSessionCache(RESUME_SESSIONID, RESUME_NONE); Connect(); SendReceive(); @@ -81,7 +82,7 @@ TEST_P(TlsConnectGeneric, ConnectServerCacheDisabled) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectSessionCacheDisabled) { +TEST_P(TlsConnectGenericResumption, ConnectSessionCacheDisabled) { ConfigureSessionCache(RESUME_NONE, RESUME_NONE); Connect(); SendReceive(); @@ -92,7 +93,7 @@ TEST_P(TlsConnectGeneric, ConnectSessionCacheDisabled) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeSupportBoth) { +TEST_P(TlsConnectGenericResumption, ConnectResumeSupportBoth) { // This prefers tickets. ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); Connect(); @@ -105,7 +106,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeSupportBoth) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeClientTicketServerBoth) { +TEST_P(TlsConnectGenericResumption, ConnectResumeClientTicketServerBoth) { // This causes no resumption because the client needs the // session cache to resume even with tickets. ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH); @@ -119,7 +120,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientTicketServerBoth) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicket) { +TEST_P(TlsConnectGenericResumption, ConnectResumeClientBothTicketServerTicket) { // This causes a ticket resumption. ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); Connect(); @@ -132,7 +133,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicket) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeClientServerTicketOnly) { +TEST_P(TlsConnectGenericResumption, ConnectResumeClientServerTicketOnly) { // This causes no resumption because the client needs the // session cache to resume even with tickets. ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); @@ -146,7 +147,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientServerTicketOnly) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeClientBothServerNone) { +TEST_P(TlsConnectGenericResumption, ConnectResumeClientBothServerNone) { ConfigureSessionCache(RESUME_BOTH, RESUME_NONE); Connect(); SendReceive(); @@ -158,7 +159,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothServerNone) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectResumeClientNoneServerBoth) { +TEST_P(TlsConnectGenericResumption, ConnectResumeClientNoneServerBoth) { ConfigureSessionCache(RESUME_NONE, RESUME_BOTH); Connect(); SendReceive(); @@ -201,7 +202,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicketForget) { SendReceive(); } -TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtClient) { +TEST_P(TlsConnectGenericResumption, ConnectWithExpiredTicketAtClient) { SSLInt_SetTicketLifetime(1); // one second // This causes a ticket resumption. ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); @@ -218,8 +219,7 @@ TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtClient) { SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) ? ssl_tls13_pre_shared_key_xtn : ssl_session_ticket_xtn; - auto capture = std::make_shared<TlsExtensionCapture>(xtn); - client_->SetPacketFilter(capture); + auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, xtn); Connect(); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { @@ -244,10 +244,8 @@ TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) { SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) ? ssl_tls13_pre_shared_key_xtn : ssl_session_ticket_xtn; - auto capture = std::make_shared<TlsExtensionCapture>(xtn); - client_->SetPacketFilter(capture); - client_->StartConnect(); - server_->StartConnect(); + auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, xtn); + StartConnect(); client_->Handshake(); EXPECT_TRUE(capture->captured()); EXPECT_LT(0U, capture->extension().len()); @@ -327,25 +325,23 @@ TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) { // Prior to TLS 1.3, we were not fully ephemeral; though 1.3 fixes that TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) { - auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i1); + auto filter = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); Connect(); CheckKeys(); TlsServerKeyExchangeEcdhe dhe1; - EXPECT_TRUE(dhe1.Parse(i1->buffer())); + EXPECT_TRUE(dhe1.Parse(filter->buffer())); // Restart Reset(); - auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i2); + auto filter2 = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); Connect(); CheckKeys(); TlsServerKeyExchangeEcdhe dhe2; - EXPECT_TRUE(dhe2.Parse(i2->buffer())); + EXPECT_TRUE(dhe2.Parse(filter2->buffer())); // Make sure they are the same. EXPECT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len()); @@ -355,32 +351,25 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) { // This test parses the ServerKeyExchange, which isn't in 1.3 TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) { - server_->EnsureTlsSetup(); - SECStatus rv = - SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); - EXPECT_EQ(SECSuccess, rv); - auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i1); + server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); + auto filter = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); Connect(); CheckKeys(); TlsServerKeyExchangeEcdhe dhe1; - EXPECT_TRUE(dhe1.Parse(i1->buffer())); + EXPECT_TRUE(dhe1.Parse(filter->buffer())); // Restart Reset(); - server_->EnsureTlsSetup(); - rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); - EXPECT_EQ(SECSuccess, rv); - auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i2); + server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); + auto filter2 = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeServerKeyExchange); ConfigureSessionCache(RESUME_NONE, RESUME_NONE); Connect(); CheckKeys(); TlsServerKeyExchangeEcdhe dhe2; - EXPECT_TRUE(dhe2.Parse(i2->buffer())); + EXPECT_TRUE(dhe2.Parse(filter2->buffer())); // Make sure they are different. EXPECT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) && @@ -401,7 +390,8 @@ TEST_P(TlsConnectTls13, TestTls13ResumeDifferentGroup) { client_->ConfigNamedGroups(kFFDHEGroups); server_->ConfigNamedGroups(kFFDHEGroups); Connect(); - CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); } // We need to enable different cipher suites at different times in the following @@ -421,7 +411,7 @@ static uint16_t ChooseAnotherCipher(uint16_t version) { } // Test that we don't resume when we can't negotiate the same cipher. -TEST_P(TlsConnectGeneric, TestResumeClientDifferentCipher) { +TEST_P(TlsConnectGenericResumption, TestResumeClientDifferentCipher) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); client_->EnableSingleCipher(ChooseOneCipher(version_)); Connect(); @@ -438,15 +428,15 @@ TEST_P(TlsConnectGeneric, TestResumeClientDifferentCipher) { } else { ticket_extension = ssl_session_ticket_xtn; } - auto ticket_capture = std::make_shared<TlsExtensionCapture>(ticket_extension); - client_->SetPacketFilter(ticket_capture); + auto ticket_capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ticket_extension); Connect(); CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); EXPECT_EQ(0U, ticket_capture->extension().len()); } // Test that we don't resume when we can't negotiate the same cipher. -TEST_P(TlsConnectGeneric, TestResumeServerDifferentCipher) { +TEST_P(TlsConnectGenericResumption, TestResumeServerDifferentCipher) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); server_->EnableSingleCipher(ChooseOneCipher(version_)); Connect(); @@ -461,36 +451,6 @@ TEST_P(TlsConnectGeneric, TestResumeServerDifferentCipher) { CheckKeys(); } -class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { - public: - SelectedCipherSuiteReplacer(uint16_t suite) : cipher_suite_(suite) {} - - protected: - PacketFilter::Action FilterHandshake(const HandshakeHeader& header, - const DataBuffer& input, - DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeServerHello) { - return KEEP; - } - - *output = input; - uint32_t temp = 0; - EXPECT_TRUE(input.Read(0, 2, &temp)); - // Cipher suite is after version(2) and random(32). - size_t pos = 34; - if (temp < SSL_LIBRARY_VERSION_TLS_1_3) { - // In old versions, we have to skip a session_id too. - EXPECT_TRUE(input.Read(pos, 1, &temp)); - pos += 1 + temp; - } - output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2); - return CHANGE; - } - - private: - uint16_t cipher_suite_; -}; - // Test that the client doesn't tolerate the server picking a different cipher // suite for resumption. TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { @@ -502,8 +462,8 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - server_->SetPacketFilter(std::make_shared<SelectedCipherSuiteReplacer>( - ChooseAnotherCipher(version_))); + MakeTlsFilter<SelectedCipherSuiteReplacer>(server_, + ChooseAnotherCipher(version_)); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { client_->ExpectSendAlert(kTlsAlertIllegalParameter); @@ -524,16 +484,15 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { class SelectedVersionReplacer : public TlsHandshakeFilter { public: - SelectedVersionReplacer(uint16_t version) : version_(version) {} + SelectedVersionReplacer(const std::shared_ptr<TlsAgent>& agent, + uint16_t version) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}), + version_(version) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeServerHello) { - return KEEP; - } - *output = input; output->Write(0, static_cast<uint32_t>(version_), 2); return CHANGE; @@ -580,8 +539,7 @@ TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) { // Enable the lower version on the client. client_->SetVersionRange(version_ - 1, version_); server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); - server_->SetPacketFilter( - std::make_shared<SelectedVersionReplacer>(override_version)); + MakeTlsFilter<SelectedVersionReplacer>(server_, override_version); ConnectExpectAlert(client_, kTlsAlertHandshakeFailure); client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); @@ -604,12 +562,12 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); ExpectResumption(RESUME_TICKET); - auto c1 = std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn); - client_->SetPacketFilter(c1); + auto c1 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); Connect(); SendReceive(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_none); + ssl_sig_rsa_pss_rsae_sha256); // The filter will go away when we reset, so save the captured extension. DataBuffer initialTicket(c1->extension()); ASSERT_LT(0U, initialTicket.len()); @@ -621,13 +579,13 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { ClearStats(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); - auto c2 = std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn); - client_->SetPacketFilter(c2); + auto c2 = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); ExpectResumption(RESUME_TICKET); Connect(); SendReceive(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_none); + ssl_sig_rsa_pss_rsae_sha256); ASSERT_LT(0U, c2->extension().len()); ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); @@ -652,7 +610,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) { // Clear the session ticket keys to invalidate the old ticket. SSLInt_ClearSelfEncryptKey(); - SSLInt_SendNewSessionTicket(server_->ssl_fd()); + SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0); SendReceive(); // Need to read so that we absorb the session tickets. CheckKeys(); @@ -666,6 +624,144 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) { SendReceive(); } +// Check that the value captured in a NewSessionTicket message matches the value +// captured from a pre_shared_key extension. +void NstTicketMatchesPskIdentity(const DataBuffer& nst, const DataBuffer& psk) { + uint32_t len; + + size_t offset = 4 + 4; // Skip ticket_lifetime and ticket_age_add. + ASSERT_TRUE(nst.Read(offset, 1, &len)); + offset += 1 + len; // Skip ticket_nonce. + + ASSERT_TRUE(nst.Read(offset, 2, &len)); + offset += 2; // Skip the ticket length. + ASSERT_LE(offset + len, nst.len()); + DataBuffer nst_ticket(nst.data() + offset, static_cast<size_t>(len)); + + offset = 2; // Skip the identities length. + ASSERT_TRUE(psk.Read(offset, 2, &len)); + offset += 2; // Skip the identity length. + ASSERT_LE(offset + len, psk.len()); + DataBuffer psk_ticket(psk.data() + offset, static_cast<size_t>(len)); + + EXPECT_EQ(nst_ticket, psk_ticket); +} + +TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + auto nst_capture = + MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket); + nst_capture->EnableDecryption(); + Connect(); + + // Clear the session ticket keys to invalidate the old ticket. + SSLInt_ClearSelfEncryptKey(); + nst_capture->Reset(); + uint8_t token[] = {0x20, 0x20, 0xff, 0x00}; + EXPECT_EQ(SECSuccess, + SSL_SendSessionTicket(server_->ssl_fd(), token, sizeof(token))); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + EXPECT_LT(0U, nst_capture->buffer().len()); + + // Resume the connection. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + + auto psk_capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); + Connect(); + SendReceive(); + + NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension()); +} + +// Disable SSL_ENABLE_SESSION_TICKETS but ensure that tickets can still be sent +// by invoking SSL_SendSessionTicket directly (and that the ticket is usable). +TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + + auto nst_capture = + MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket); + nst_capture->EnableDecryption(); + Connect(); + + EXPECT_EQ(0U, nst_capture->buffer().len()) << "expect nothing captured yet"; + + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)); + EXPECT_LT(0U, nst_capture->buffer().len()) << "should capture now"; + + SendReceive(); // Ensure that the client reads the ticket. + + // Resume the connection. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + + auto psk_capture = + MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn); + Connect(); + SendReceive(); + + NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension()); +} + +// Test calling SSL_SendSessionTicket in inappropriate conditions. +TEST_F(TlsConnectTest, SendSessionTicketInappropriate) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2); + + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(client_->ssl_fd(), NULL, 0)) + << "clients can't send tickets"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + StartConnect(); + + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)) + << "no ticket before the handshake has started"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + Handshake(); + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)) + << "no special tickets in TLS 1.2"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectTest, SendSessionTicketMassiveToken) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + // It should be safe to set length with a NULL token because the length should + // be checked before reading token. + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0x1ffff)) + << "this is clearly too big"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + static const uint8_t big_token[0xffff] = {1}; + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), big_token, + sizeof(big_token))) + << "this is too big, but that's not immediately obvious"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectDatagram13, SendSessionTicketDtls) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)) + << "no extra tickets in DTLS until we have Ack support"; + EXPECT_EQ(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_VERSION, PORT_GetError()); +} + TEST_F(TlsConnectTest, TestTls13ResumptionDowngrade) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); @@ -716,16 +812,220 @@ TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) { // We will eventually fail the (sid.version == SH.version) check. std::vector<std::shared_ptr<PacketFilter>> filters; filters.push_back(std::make_shared<SelectedCipherSuiteReplacer>( - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)); + server_, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)); + filters.push_back(std::make_shared<SelectedVersionReplacer>( + server_, SSL_LIBRARY_VERSION_TLS_1_2)); + + // Drop a bunch of extensions so that we get past the SH processing. The + // version extension says TLS 1.3, which is counter to our goal, the others + // are not permitted in TLS 1.2 handshakes. + filters.push_back(std::make_shared<TlsExtensionDropper>( + server_, ssl_tls13_supported_versions_xtn)); filters.push_back( - std::make_shared<SelectedVersionReplacer>(SSL_LIBRARY_VERSION_TLS_1_2)); - server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(filters)); - - client_->ExpectSendAlert(kTlsAlertDecodeError); + std::make_shared<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn)); + filters.push_back(std::make_shared<TlsExtensionDropper>( + server_, ssl_tls13_pre_shared_key_xtn)); + server_->SetFilter(std::make_shared<ChainedPacketFilter>(filters)); + + // The client here generates an unexpected_message alert when it receives an + // encrypted handshake message from the server (EncryptedExtension). The + // client expects to receive an unencrypted TLS 1.2 Certificate message. + // The server can't decrypt the alert. + client_->ExpectSendAlert(kTlsAlertUnexpectedMessage); server_->ExpectSendAlert(kTlsAlertBadRecordMac); // Server can't read ConnectExpectFail(); - client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA); server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); } +TEST_P(TlsConnectGenericResumption, ReConnectTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + // Resume + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + Connect(); + // Only the client knows this. + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); +} + +TEST_P(TlsConnectGenericPre13, ReConnectCache) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + // Resume + Reset(); + ExpectResumption(RESUME_SESSIONID); + Connect(); + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); +} + +TEST_P(TlsConnectGenericResumption, ReConnectAgainTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_rsae_sha256); + // Resume + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + Connect(); + // Only the client knows this. + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); + // Resume connection again + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET, 2); + Connect(); + // Only the client knows this. + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256); +} + +void CheckGetInfoResult(uint32_t alpnSize, uint32_t earlyDataSize, + ScopedCERTCertificate& cert, + ScopedSSLResumptionTokenInfo& token) { + ASSERT_TRUE(cert); + ASSERT_TRUE(token->peerCert); + + // Check that the server cert is the correct one. + ASSERT_EQ(cert->derCert.len, token->peerCert->derCert.len); + EXPECT_EQ(0, memcmp(cert->derCert.data, token->peerCert->derCert.data, + cert->derCert.len)); + + ASSERT_EQ(alpnSize, token->alpnSelectionLen); + EXPECT_EQ(0, memcmp("a", token->alpnSelection, token->alpnSelectionLen)); + + ASSERT_EQ(earlyDataSize, token->maxEarlyDataSize); +} + +TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfo) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + // Get resumption token infos + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + ScopedCERTCertificate cert( + PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); + + CheckGetInfoResult(0, 0, cert, token); + + Handshake(); + CheckConnected(); + + SendReceive(); +} + +TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfoAlpn) { + EnableAlpn(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + CheckAlpn("a"); + SendReceive(); + + Reset(); + EnableAlpn(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + // Get resumption token infos + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + ScopedCERTCertificate cert( + PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); + + CheckGetInfoResult(1, 0, cert, token); + + Handshake(); + CheckConnected(); + CheckAlpn("a"); + + SendReceive(); +} + +TEST_P(TlsConnectTls13ResumptionToken, ConnectResumeGetInfoZeroRtt) { + EnableAlpn(); + SSLInt_RolloverAntiReplay(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->Set0RttEnabled(true); + Connect(); + CheckAlpn("a"); + SendReceive(); + + Reset(); + EnableAlpn(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + + StartConnect(); + server_->Set0RttEnabled(true); + client_->Set0RttEnabled(true); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + // Get resumption token infos + SSLResumptionTokenInfo tokenInfo = {0}; + ScopedSSLResumptionTokenInfo token(&tokenInfo); + client_->GetTokenInfo(token); + ScopedCERTCertificate cert( + PK11_FindCertFromNickname(server_->name().c_str(), nullptr)); + + CheckGetInfoResult(1, 1024, cert, token); + + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + CheckAlpn("a"); + + SendReceive(); +} + +// Resumption on sessions with client authentication only works with internal +// caching. +TEST_P(TlsConnectGenericResumption, ConnectResumeClientAuth) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + SendReceive(); + EXPECT_FALSE(client_->resumption_callback_called()); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + if (use_external_cache()) { + ExpectResumption(RESUME_NONE); + } else { + ExpectResumption(RESUME_TICKET); + } + Connect(); + SendReceive(); +} + } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc index a130ef77f..e4a9e5aed 100644 --- a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc @@ -22,8 +22,11 @@ namespace nss_test { class TlsHandshakeSkipFilter : public TlsRecordFilter { public: // A TLS record filter that skips handshake messages of the identified type. - TlsHandshakeSkipFilter(uint8_t handshake_type) - : handshake_type_(handshake_type), skipped_(false) {} + TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type) + : TlsRecordFilter(agent), + handshake_type_(handshake_type), + skipped_(false) {} protected: // Takes a record; if it is a handshake record, it removes the first handshake @@ -43,7 +46,14 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter { size_t start = parser.consumed(); TlsHandshakeFilter::HandshakeHeader header; DataBuffer ignored; - if (!header.Parse(&parser, record_header, &ignored)) { + bool complete = false; + if (!header.Parse(&parser, record_header, DataBuffer(), &ignored, + &complete)) { + ADD_FAILURE() << "Error parsing handshake header"; + return KEEP; + } + if (!complete) { + ADD_FAILURE() << "Don't want to deal with fragmented input"; return KEEP; } @@ -85,9 +95,14 @@ class TlsSkipTest : public TlsConnectTestBase, TlsSkipTest() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + void SetUp() override { + TlsConnectTestBase::SetUp(); + EnsureTlsSetup(); + } + void ServerSkipTest(std::shared_ptr<PacketFilter> filter, uint8_t alert = kTlsAlertUnexpectedMessage) { - server_->SetPacketFilter(filter); + server_->SetFilter(filter); ConnectExpectAlert(client_, alert); } }; @@ -98,29 +113,23 @@ class Tls13SkipTest : public TlsConnectTestBase, Tls13SkipTest() : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} - void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { + void SetUp() override { + TlsConnectTestBase::SetUp(); EnsureTlsSetup(); - server_->SetTlsRecordFilter(filter); + } + + void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { filter->EnableDecryption(); - client_->ExpectSendAlert(kTlsAlertUnexpectedMessage); - if (variant_ == ssl_variant_stream) { - server_->ExpectSendAlert(kTlsAlertBadRecordMac); - ConnectExpectFail(); - } else { - ConnectExpectFailOneSide(TlsAgent::CLIENT); - } + server_->SetFilter(filter); + ExpectAlert(client_, kTlsAlertUnexpectedMessage); + ConnectExpectFail(); client_->CheckErrorCode(error); - if (variant_ == ssl_variant_stream) { - server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); - } else { - ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); - } + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); } void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { - EnsureTlsSetup(); - client_->SetTlsRecordFilter(filter); filter->EnableDecryption(); + client_->SetFilter(filter); server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); ConnectExpectFailOneSide(TlsAgent::SERVER); @@ -133,49 +142,49 @@ class Tls13SkipTest : public TlsConnectTestBase, TEST_P(TlsSkipTest, SkipCertificateRsa) { EnableOnlyStaticRsaCiphers(); - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertificateDhe) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdhe) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipCertificateEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); } TEST_P(TlsSkipTest, SkipServerKeyExchange) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) { Reset(TlsAgent::kServerEcdsa256); - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(TlsSkipTest, SkipCertAndKeyExch) { - auto chain = std::make_shared<ChainedPacketFilter>(); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + auto chain = std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate), + std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)}); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } @@ -183,48 +192,48 @@ TEST_P(TlsSkipTest, SkipCertAndKeyExch) { TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) { Reset(TlsAgent::kServerEcdsa256); auto chain = std::make_shared<ChainedPacketFilter>(); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + chain->Add(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate)); + chain->Add(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeServerKeyExchange)); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } TEST_P(Tls13SkipTest, SkipEncryptedExtensions) { ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( - kTlsHandshakeEncryptedExtensions), + server_, kTlsHandshakeEncryptedExtensions), SSL_ERROR_RX_UNEXPECTED_CERTIFICATE); } TEST_P(Tls13SkipTest, SkipServerCertificate) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate), - SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); } TEST_P(Tls13SkipTest, SkipServerCertificateVerify) { - ServerSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify), - SSL_ERROR_RX_UNEXPECTED_FINISHED); + ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + server_, kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); } TEST_P(Tls13SkipTest, SkipClientCertificate) { client_->SetupClientAuth(); server_->RequestClientAuth(true); client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); - ClientSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate), - SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); + ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + client_, kTlsHandshakeCertificate), + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); } TEST_P(Tls13SkipTest, SkipClientCertificateVerify) { client_->SetupClientAuth(); server_->RequestClientAuth(true); client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage); - ClientSkipTest( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify), - SSL_ERROR_RX_UNEXPECTED_FINISHED); + ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>( + client_, kTlsHandshakeCertificateVerify), + SSL_ERROR_RX_UNEXPECTED_FINISHED); } INSTANTIATE_TEST_CASE_P( diff --git a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc index 8db1f30e1..e5fccc12b 100644 --- a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc @@ -48,10 +48,9 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSA) { // This test is stream so we can catch the bad_record_mac alert. TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) { EnableOnlyStaticRsaCiphers(); - auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>( - kTlsHandshakeClientKeyExchange, + MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>( + client_, kTlsHandshakeClientKeyExchange, DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); - client_->SetPacketFilter(i1); ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } @@ -59,8 +58,7 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) { // This test is stream so we can catch the bad_record_mac alert. TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) { EnableOnlyStaticRsaCiphers(); - client_->SetPacketFilter( - std::make_shared<TlsInspectorClientHelloVersionChanger>(server_)); + MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_); ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } @@ -69,9 +67,8 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) { // ConnectStaticRSABogusPMSVersionDetect. TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) { EnableOnlyStaticRsaCiphers(); - client_->SetPacketFilter( - std::make_shared<TlsInspectorClientHelloVersionChanger>(server_)); - server_->DisableRollbackDetection(); + MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_); + server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE); Connect(); } @@ -79,10 +76,9 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) { TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) { EnableOnlyStaticRsaCiphers(); EnableExtendedMasterSecret(); - auto inspect = std::make_shared<TlsInspectorReplaceHandshakeMessage>( - kTlsHandshakeClientKeyExchange, + MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>( + client_, kTlsHandshakeClientKeyExchange, DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); - client_->SetPacketFilter(inspect); ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } @@ -91,8 +87,7 @@ TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) { EnableOnlyStaticRsaCiphers(); EnableExtendedMasterSecret(); - client_->SetPacketFilter( - std::make_shared<TlsInspectorClientHelloVersionChanger>(server_)); + MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_); ConnectExpectAlert(server_, kTlsAlertBadRecordMac); } @@ -100,10 +95,9 @@ TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) { EnableOnlyStaticRsaCiphers(); EnableExtendedMasterSecret(); - client_->SetPacketFilter( - std::make_shared<TlsInspectorClientHelloVersionChanger>(server_)); - server_->DisableRollbackDetection(); + MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_); + server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE); Connect(); } -} // namespace nspr_test +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc new file mode 100644 index 000000000..f5ccf096b --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc @@ -0,0 +1,363 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <memory> +#include <vector> +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +class Tls13CompatTest : public TlsConnectStreamTls13 { + protected: + void EnableCompatMode() { + client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE); + } + + void InstallFilters() { + EnsureTlsSetup(); + client_recorders_.Install(client_); + server_recorders_.Install(server_); + } + + void CheckRecordVersions() { + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_0, + client_recorders_.records_->record(0).header.version()); + CheckRecordsAreTls12("client", client_recorders_.records_, 1); + CheckRecordsAreTls12("server", server_recorders_.records_, 0); + } + + void CheckHelloVersions() { + uint32_t ver; + ASSERT_TRUE(server_recorders_.hello_->buffer().Read(0, 2, &ver)); + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver)); + ASSERT_TRUE(client_recorders_.hello_->buffer().Read(0, 2, &ver)); + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver)); + } + + void CheckForCCS(bool expected_client, bool expected_server) { + client_recorders_.CheckForCCS(expected_client); + server_recorders_.CheckForCCS(expected_server); + } + + void CheckForRegularHandshake() { + CheckRecordVersions(); + CheckHelloVersions(); + EXPECT_EQ(0U, client_recorders_.session_id_length()); + EXPECT_EQ(0U, server_recorders_.session_id_length()); + CheckForCCS(false, false); + } + + void CheckForCompatHandshake() { + CheckRecordVersions(); + CheckHelloVersions(); + EXPECT_EQ(32U, client_recorders_.session_id_length()); + EXPECT_EQ(32U, server_recorders_.session_id_length()); + CheckForCCS(true, true); + } + + private: + struct Recorders { + Recorders() : records_(nullptr), hello_(nullptr) {} + + uint8_t session_id_length() const { + // session_id is always after version (2) and random (32). + uint32_t len = 0; + EXPECT_TRUE(hello_->buffer().Read(2 + 32, 1, &len)); + return static_cast<uint8_t>(len); + } + + void CheckForCCS(bool expected) const { + EXPECT_LT(0U, records_->count()); + for (size_t i = 0; i < records_->count(); ++i) { + // Only the second record can be a CCS. + bool expected_match = expected && (i == 1); + EXPECT_EQ(expected_match, + kTlsChangeCipherSpecType == + records_->record(i).header.content_type()); + } + } + + void Install(std::shared_ptr<TlsAgent>& agent) { + if (records_ && records_->agent() == agent) { + // Avoid replacing the filters if they are already installed on this + // agent. This ensures that InstallFilters() can be used after + // MakeNewServer() without losing state on the client filters. + return; + } + records_.reset(new TlsRecordRecorder(agent)); + hello_.reset(new TlsHandshakeRecorder( + agent, std::set<uint8_t>( + {kTlsHandshakeClientHello, kTlsHandshakeServerHello}))); + agent->SetFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({records_, hello_}))); + } + + std::shared_ptr<TlsRecordRecorder> records_; + std::shared_ptr<TlsHandshakeRecorder> hello_; + }; + + void CheckRecordsAreTls12(const std::string& agent, + const std::shared_ptr<TlsRecordRecorder>& records, + size_t start) { + EXPECT_LE(start, records->count()); + for (size_t i = start; i < records->count(); ++i) { + EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, + records->record(i).header.version()) + << agent << ": record " << i << " has wrong version"; + } + } + + Recorders client_recorders_; + Recorders server_recorders_; +}; + +TEST_F(Tls13CompatTest, Disabled) { + InstallFilters(); + Connect(); + CheckForRegularHandshake(); +} + +TEST_F(Tls13CompatTest, Enabled) { + EnableCompatMode(); + InstallFilters(); + Connect(); + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledZeroRtt) { + SetupForZeroRtt(); + EnableCompatMode(); + InstallFilters(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + CheckForCCS(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledHrr) { + EnableCompatMode(); + InstallFilters(); + + // Force a HelloRetryRequest. The server sends CCS immediately. + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + CheckForCCS(false, true); + + Handshake(); + CheckConnected(); + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledStatelessHrr) { + EnableCompatMode(); + InstallFilters(); + + // Force a HelloRetryRequest + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + + // The server should send CCS before HRR. + CheckForCCS(false, true); + + // A new server should complete the handshake, and not send CCS. + MakeNewServer(); + InstallFilters(); + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + + Handshake(); + CheckConnected(); + CheckRecordVersions(); + CheckHelloVersions(); + CheckForCCS(true, false); +} + +TEST_F(Tls13CompatTest, EnabledHrrZeroRtt) { + SetupForZeroRtt(); + EnableCompatMode(); + InstallFilters(); + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + + // With 0-RTT, the client sends CCS immediately. With HRR, the server sends + // CCS immediately too. + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false); + CheckForCCS(true, true); + + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); + CheckForCompatHandshake(); +} + +static const uint8_t kCannedCcs[] = { + kTlsChangeCipherSpecType, + SSL_LIBRARY_VERSION_TLS_1_2 >> 8, + SSL_LIBRARY_VERSION_TLS_1_2 & 0xff, + 0, + 1, // length + 1 // change_cipher_spec_choice +}; + +// A ChangeCipherSpec is ignored by a server because we have to tolerate it for +// compatibility mode. That doesn't mean that we have to tolerate it +// unconditionally. If we negotiate 1.3, we expect to see a cookie extension. +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello13) { + EnsureTlsSetup(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + // Client sends CCS before starting the handshake. + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +// A ChangeCipherSpec is ignored by a server because we have to tolerate it for +// compatibility mode. That doesn't mean that we have to tolerate it +// unconditionally. If we negotiate 1.3, we expect to see a cookie extension. +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHelloTwice) { + EnsureTlsSetup(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + // Client sends CCS before starting the handshake. + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +// If we negotiate 1.2, we abort. +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello12) { + EnsureTlsSetup(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + // Client sends CCS before starting the handshake. + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) { + EnsureTlsSetup(); + client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE); + auto client_records = MakeTlsFilter<TlsRecordRecorder>(client_); + auto server_records = MakeTlsFilter<TlsRecordRecorder>(server_); + Connect(); + + ASSERT_EQ(2U, client_records->count()); // CH, Fin + EXPECT_EQ(kTlsHandshakeType, client_records->record(0).header.content_type()); + EXPECT_EQ(kTlsApplicationDataType, + client_records->record(1).header.content_type()); + + ASSERT_EQ(6U, server_records->count()); // SH, EE, CT, CV, Fin, Ack + EXPECT_EQ(kTlsHandshakeType, server_records->record(0).header.content_type()); + for (size_t i = 1; i < server_records->count(); ++i) { + EXPECT_EQ(kTlsApplicationDataType, + server_records->record(i).header.content_type()); + } +} + +class AddSessionIdFilter : public TlsHandshakeFilter { + public: + AddSessionIdFilter(const std::shared_ptr<TlsAgent>& client) + : TlsHandshakeFilter(client, {ssl_hs_client_hello}) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + uint32_t session_id_len = 0; + EXPECT_TRUE(input.Read(2 + 32, 1, &session_id_len)); + EXPECT_EQ(0U, session_id_len); + uint8_t session_id[33] = {32}; // 32 for length, the rest zero. + *output = input; + output->Splice(session_id, sizeof(session_id), 34, 1); + return CHANGE; + } +}; + +// Adding a session ID to a DTLS ClientHello should not trigger compatibility +// mode. It should be ignored instead. +TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) { + EnsureTlsSetup(); + auto client_records = std::make_shared<TlsRecordRecorder>(client_); + client_->SetFilter( + std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit( + {client_records, std::make_shared<AddSessionIdFilter>(client_)}))); + auto server_hello = + std::make_shared<TlsHandshakeRecorder>(server_, kTlsHandshakeServerHello); + auto server_records = std::make_shared<TlsRecordRecorder>(server_); + server_->SetFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({server_records, server_hello}))); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + // The client will consume the ServerHello, but discard everything else + // because it doesn't decrypt. And don't wait around for the client to ACK. + client_->Handshake(); + + ASSERT_EQ(1U, client_records->count()); + EXPECT_EQ(kTlsHandshakeType, client_records->record(0).header.content_type()); + + ASSERT_EQ(5U, server_records->count()); // SH, EE, CT, CV, Fin + EXPECT_EQ(kTlsHandshakeType, server_records->record(0).header.content_type()); + for (size_t i = 1; i < server_records->count(); ++i) { + EXPECT_EQ(kTlsApplicationDataType, + server_records->record(i).header.content_type()); + } + + uint32_t session_id_len = 0; + EXPECT_TRUE(server_hello->buffer().Read(2 + 32, 1, &session_id_len)); + EXPECT_EQ(0U, session_id_len); +} + +TEST_F(Tls13CompatTest, ConnectWith12ThenAttemptToResume13CompatMode) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2); + Connect(); + + Reset(); + ExpectResumption(RESUME_NONE); + version_ = SSL_LIBRARY_VERSION_TLS_1_3; + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + EnableCompatMode(); + Connect(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc index 110e3e0b6..100595732 100644 --- a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc @@ -23,7 +23,8 @@ namespace nss_test { // Replaces the client hello with an SSLv2 version once. class SSLv2ClientHelloFilter : public PacketFilter { public: - SSLv2ClientHelloFilter(std::shared_ptr<TlsAgent>& client, uint16_t version) + SSLv2ClientHelloFilter(const std::shared_ptr<TlsAgent>& client, + uint16_t version) : replaced_(false), client_(client), version_(version), @@ -147,17 +148,9 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase { SSLv2ClientHelloTestF(SSLProtocolVariant variant, uint16_t version) : TlsConnectTestBase(variant, version), filter_(nullptr) {} - void SetUp() { + void SetUp() override { TlsConnectTestBase::SetUp(); - filter_ = std::make_shared<SSLv2ClientHelloFilter>(client_, version_); - client_->SetPacketFilter(filter_); - } - - void RequireSafeRenegotiation() { - server_->EnsureTlsSetup(); - SECStatus rv = - SSL_OptionSet(server_->ssl_fd(), SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); - EXPECT_EQ(rv, SECSuccess); + filter_ = MakeTlsFilter<SSLv2ClientHelloFilter>(client_, version_); } void SetExpectedVersion(uint16_t version) { @@ -319,7 +312,7 @@ TEST_P(SSLv2ClientHelloTest, BigClientRandom) { // Connection must fail if we require safe renegotiation but the client doesn't // include TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites. TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) { - RequireSafeRenegotiation(); + server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); EXPECT_EQ(SSL_ERROR_UNSAFE_NEGOTIATION, server_->error_code()); @@ -328,7 +321,7 @@ TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) { // Connection must succeed when requiring safe renegotiation and the client // includes TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites. TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiationWithSCSV) { - RequireSafeRenegotiation(); + server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA, TLS_EMPTY_RENEGOTIATION_INFO_SCSV}; SetAvailableCipherSuites(cipher_suites); diff --git a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc index 379a67e35..4e9099561 100644 --- a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc @@ -56,18 +56,15 @@ TEST_P(TlsConnectGeneric, ServerNegotiateTls12) { // two validate that we can also detect fallback using the // SSL_SetDowngradeCheckVersion() API. TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) { - client_->SetPacketFilter( - std::make_shared<TlsInspectorClientHelloVersionSetter>( - SSL_LIBRARY_VERSION_TLS_1_1)); + MakeTlsFilter<TlsClientHelloVersionSetter>(client_, + SSL_LIBRARY_VERSION_TLS_1_1); ConnectExpectFail(); ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); } /* Attempt to negotiate the bogus DTLS 1.1 version. */ TEST_F(DtlsConnectTest, TestDtlsVersion11) { - client_->SetPacketFilter( - std::make_shared<TlsInspectorClientHelloVersionSetter>( - ((~0x0101) & 0xffff))); + MakeTlsFilter<TlsClientHelloVersionSetter>(client_, ((~0x0101) & 0xffff)); ConnectExpectFail(); // It's kind of surprising that SSL_ERROR_NO_CYPHER_OVERLAP is // what is returned here, but this is deliberate in ssl3_HandleAlert(). @@ -78,9 +75,8 @@ TEST_F(DtlsConnectTest, TestDtlsVersion11) { // Disabled as long as we have draft version. TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) { EnsureTlsSetup(); - client_->SetPacketFilter( - std::make_shared<TlsInspectorClientHelloVersionSetter>( - SSL_LIBRARY_VERSION_TLS_1_2)); + MakeTlsFilter<TlsClientHelloVersionSetter>(client_, + SSL_LIBRARY_VERSION_TLS_1_2); client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, @@ -92,9 +88,8 @@ TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) { // TLS 1.1 clients do not check the random values, so we should // instead get a handshake failure alert from the server. TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) { - client_->SetPacketFilter( - std::make_shared<TlsInspectorClientHelloVersionSetter>( - SSL_LIBRARY_VERSION_TLS_1_0)); + MakeTlsFilter<TlsClientHelloVersionSetter>(client_, + SSL_LIBRARY_VERSION_TLS_1_0); client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, SSL_LIBRARY_VERSION_TLS_1_1); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, @@ -128,12 +123,12 @@ TEST_F(TlsConnectTest, TestFallbackFromTls13) { #endif TEST_P(TlsConnectGeneric, TestFallbackSCSVVersionMatch) { - client_->SetFallbackSCSVEnabled(true); + client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE); Connect(); } TEST_P(TlsConnectGenericPre13, TestFallbackSCSVVersionMismatch) { - client_->SetFallbackSCSVEnabled(true); + client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE); server_->SetVersionRange(version_, version_ + 1); ConnectExpectAlert(server_, kTlsAlertInappropriateFallback); client_->CheckErrorCode(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT); @@ -155,107 +150,10 @@ TEST_F(TlsConnectTest, DisallowSSLv3HelloWithTLSv13Enabled) { EXPECT_EQ(SECFailure, rv); } -TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) { - if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { - return; - } - // Set the client so it will accept any version from 1.0 - // to |version_|. - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, - SSL_LIBRARY_VERSION_TLS_1_0); - // Reset version so that the checks succeed. - uint16_t test_version = version_; - version_ = SSL_LIBRARY_VERSION_TLS_1_0; - Connect(); - - // Now renegotiate, with the server being set to do - // |version_|. - client_->PrepareForRenegotiate(); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); - // Reset version and cipher suite so that the preinfo callback - // doesn't fail. - server_->ResetPreliminaryInfo(); - server_->StartRenegotiate(); - - if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { - ExpectAlert(server_, kTlsAlertUnexpectedMessage); - } else { - ExpectAlert(client_, kTlsAlertIllegalParameter); - } - - Handshake(); - if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { - // In TLS 1.3, the server detects this problem. - client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); - server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); - } else { - client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); - server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); - } -} - -TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) { - if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { - return; - } - // Set the client so it will accept any version from 1.0 - // to |version_|. - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, - SSL_LIBRARY_VERSION_TLS_1_0); - // Reset version so that the checks succeed. - uint16_t test_version = version_; - version_ = SSL_LIBRARY_VERSION_TLS_1_0; - Connect(); - - // Now renegotiate, with the server being set to do - // |version_|. - server_->PrepareForRenegotiate(); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); - // Reset version and cipher suite so that the preinfo callback - // doesn't fail. - server_->ResetPreliminaryInfo(); - client_->StartRenegotiate(); - if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { - ExpectAlert(server_, kTlsAlertUnexpectedMessage); - } else { - ExpectAlert(client_, kTlsAlertIllegalParameter); - } - Handshake(); - if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { - // In TLS 1.3, the server detects this problem. - client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); - server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); - } else { - client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); - server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); - } -} - -TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) { - EnsureTlsSetup(); - ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); - Connect(); - SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE); - EXPECT_EQ(SECFailure, rv); - EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); -} - -TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) { - EnsureTlsSetup(); - ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); - Connect(); - SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE); - EXPECT_EQ(SECFailure, rv); - EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); -} - TEST_P(TlsConnectGeneric, AlertBeforeServerHello) { EnsureTlsSetup(); client_->ExpectReceiveAlert(kTlsAlertUnrecognizedName, kTlsAlertWarning); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); // Send ClientHello. static const uint8_t kWarningAlert[] = {kTlsAlertWarning, kTlsAlertUnrecognizedName}; @@ -274,12 +172,10 @@ class Tls13NoSupportedVersions : public TlsConnectStreamTls12 { client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_2); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, max_server_version); - client_->SetPacketFilter( - std::make_shared<TlsInspectorClientHelloVersionSetter>( - overwritten_client_version)); - auto capture = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeServerHello); - server_->SetPacketFilter(capture); + MakeTlsFilter<TlsClientHelloVersionSetter>(client_, + overwritten_client_version); + auto capture = + MakeTlsFilter<TlsHandshakeRecorder>(server_, kTlsHandshakeServerHello); ConnectExpectAlert(server_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); @@ -311,23 +207,21 @@ TEST_F(Tls13NoSupportedVersions, // Offer 1.3 but with ClientHello.legacy_version == TLS 1.4. This // causes a bad MAC error when we read EncryptedExtensions. TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) { - client_->SetPacketFilter( - std::make_shared<TlsInspectorClientHelloVersionSetter>( - SSL_LIBRARY_VERSION_TLS_1_3 + 1)); - auto capture = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeServerHello); - server_->SetPacketFilter(capture); + MakeTlsFilter<TlsClientHelloVersionSetter>(client_, + SSL_LIBRARY_VERSION_TLS_1_3 + 1); + auto capture = MakeTlsFilter<TlsExtensionCapture>( + server_, ssl_tls13_supported_versions_xtn); client_->ExpectSendAlert(kTlsAlertBadRecordMac); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); - const DataBuffer& server_hello = capture->buffer(); - ASSERT_GT(server_hello.len(), 2U); - uint32_t ver; - ASSERT_TRUE(server_hello.Read(0, 2, &ver)); + + ASSERT_EQ(2U, capture->extension().len()); + uint32_t version = 0; + ASSERT_TRUE(capture->extension().Read(0, 2, &version)); // This way we don't need to change with new draft version. - ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), ver); + ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), version); } } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc index eda96831c..7f3c4a896 100644 --- a/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc @@ -189,12 +189,12 @@ class TestPolicyVersionRange } } - void SetUp() { - SetPolicy(policy_.range()); + void SetUp() override { TlsConnectTestBase::SetUp(); + SetPolicy(policy_.range()); } - void TearDown() { + void TearDown() override { TlsConnectTestBase::TearDown(); saved_version_policy_.RestoreOriginalPolicy(); } diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc index b9f0c672e..728217851 100644 --- a/security/nss/gtests/ssl_gtest/test_io.cc +++ b/security/nss/gtests/ssl_gtest/test_io.cc @@ -25,10 +25,6 @@ namespace nss_test { if (g_ssl_gtest_verbose) LOG(a); \ } while (false) -void DummyPrSocket::SetPacketFilter(std::shared_ptr<PacketFilter> filter) { - filter_ = filter; -} - ScopedPRFileDesc DummyPrSocket::CreateFD() { static PRDescIdentity test_fd_identity = PR_GetUniqueIdentity("testtransportadapter"); @@ -98,8 +94,13 @@ int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen, } int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) { + if (write_error_) { + PR_SetError(write_error_, 0); + return -1; + } + auto peer = peer_.lock(); - if (!peer || !writeable_) { + if (!peer) { PR_SetError(PR_IO_ERROR, 0); return -1; } @@ -109,7 +110,7 @@ int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) { DataBuffer filtered; PacketFilter::Action action = PacketFilter::KEEP; if (filter_) { - action = filter_->Filter(packet, &filtered); + action = filter_->Process(packet, &filtered); } switch (action) { case PacketFilter::CHANGE: diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h index ac2497222..dbeb6b9d4 100644 --- a/security/nss/gtests/ssl_gtest/test_io.h +++ b/security/nss/gtests/ssl_gtest/test_io.h @@ -33,9 +33,18 @@ class PacketFilter { CHANGE, // change the packet to a different value DROP // drop the packet }; - + PacketFilter(bool enabled = true) : enabled_(enabled) {} virtual ~PacketFilter() {} + virtual Action Process(const DataBuffer& input, DataBuffer* output) { + if (!enabled_) { + return KEEP; + } + return Filter(input, output); + } + void Enable() { enabled_ = true; } + void Disable() { enabled_ = false; } + // The packet filter takes input and has the option of mutating it. // // A filter that modifies the data places the modified data in *output and @@ -43,6 +52,9 @@ class PacketFilter { // case the value in *output is ignored. A Filter can return DROP, in which // case the packet is dropped (and *output is ignored). virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0; + + private: + bool enabled_; }; class DummyPrSocket : public DummyIOLayerMethods { @@ -53,7 +65,7 @@ class DummyPrSocket : public DummyIOLayerMethods { peer_(), input_(), filter_(nullptr), - writeable_(true) {} + write_error_(0) {} virtual ~DummyPrSocket() {} // Create a file descriptor that will reference this object. The fd must not @@ -62,7 +74,9 @@ class DummyPrSocket : public DummyIOLayerMethods { std::weak_ptr<DummyPrSocket>& peer() { return peer_; } void SetPeer(const std::shared_ptr<DummyPrSocket>& peer) { peer_ = peer; } - void SetPacketFilter(std::shared_ptr<PacketFilter> filter); + void SetPacketFilter(const std::shared_ptr<PacketFilter>& filter) { + filter_ = filter; + } // Drops peer, packet filter and any outstanding packets. void Reset(); @@ -71,7 +85,7 @@ class DummyPrSocket : public DummyIOLayerMethods { int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags, PRIntervalTime to) override; int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override; - void CloseWrites() { writeable_ = false; } + void SetWriteError(PRErrorCode code) { write_error_ = code; } SSLProtocolVariant variant() const { return variant_; } bool readable() const { return !input_.empty(); } @@ -98,7 +112,7 @@ class DummyPrSocket : public DummyIOLayerMethods { std::weak_ptr<DummyPrSocket> peer_; std::queue<Packet> input_; std::shared_ptr<PacketFilter> filter_; - bool writeable_; + PRErrorCode write_error_; }; // Marker interface. @@ -164,6 +178,6 @@ class Poller { timers_; }; -} // end of namespace +} // namespace nss_test #endif diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc index d6d91f7f7..2f71caedb 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.cc +++ b/security/nss/gtests/ssl_gtest/tls_agent.cc @@ -10,7 +10,9 @@ #include "pk11func.h" #include "ssl.h" #include "sslerr.h" +#include "sslexp.h" #include "sslproto.h" +#include "tls_filter.h" #include "tls_parser.h" extern "C" { @@ -35,7 +37,6 @@ const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt const std::string TlsAgent::kServerRsaSign = "rsa_sign"; const std::string TlsAgent::kServerRsaPss = "rsa_pss"; const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt"; -const std::string TlsAgent::kServerRsaChain = "rsa_chain"; const std::string TlsAgent::kServerEcdsa256 = "ecdsa256"; const std::string TlsAgent::kServerEcdsa384 = "ecdsa384"; const std::string TlsAgent::kServerEcdsa521 = "ecdsa521"; @@ -66,6 +67,7 @@ TlsAgent::TlsAgent(const std::string& name, Role role, expected_sent_alert_(kTlsAlertCloseNotify), expected_sent_alert_level_(kTlsAlertWarning), handshake_callback_called_(false), + resumption_callback_called_(false), error_code_(0), send_ctr_(0), recv_ctr_(0), @@ -73,8 +75,8 @@ TlsAgent::TlsAgent(const std::string& name, Role role, handshake_callback_(), auth_certificate_callback_(), sni_callback_(), - expect_short_headers_(false), - skip_version_checks_(false) { + skip_version_checks_(false), + resumption_token_() { memset(&info_, 0, sizeof(info_)); memset(&csinfo_, 0, sizeof(csinfo_)); SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_); @@ -93,11 +95,11 @@ TlsAgent::~TlsAgent() { // Add failures manually, if any, so we don't throw in a destructor. if (expected_received_alert_ != kTlsAlertCloseNotify || expected_received_alert_level_ != kTlsAlertWarning) { - ADD_FAILURE() << "Wrong expected_received_alert status"; + ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str(); } if (expected_sent_alert_ != kTlsAlertCloseNotify || expected_sent_alert_level_ != kTlsAlertWarning) { - ADD_FAILURE() << "Wrong expected_sent_alert status"; + ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str(); } } @@ -183,6 +185,10 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { ScopedCERTCertList anchors(CERT_NewCertList()); rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get()); if (rv != SECSuccess) return false; + + rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; } else { rv = SSL_SetURL(ssl_fd(), "server"); EXPECT_EQ(SECSuccess, rv); @@ -208,6 +214,29 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { return true; } +bool TlsAgent::MaybeSetResumptionToken() { + if (!resumption_token_.empty()) { + SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(), + resumption_token_.size()); + + // rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR + // if the resumption token was bad (expired/malformed/etc.). + if (expect_resumption_) { + // Only in case we expect resumption this has to be successful. We might + // not expect resumption due to some reason but the token is totally fine. + EXPECT_EQ(SECSuccess, rv); + } + if (rv != SECSuccess) { + EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError()); + resumption_token_.clear(); + EXPECT_FALSE(expect_resumption_); + if (expect_resumption_) return false; + } + } + + return true; +} + void TlsAgent::SetupClientAuth() { EXPECT_TRUE(EnsureTlsSetup()); ASSERT_EQ(CLIENT, role_); @@ -258,13 +287,10 @@ void TlsAgent::CheckCipherSuite(uint16_t cipher_suite) { } void TlsAgent::RequestClientAuth(bool requireAuth) { - EXPECT_TRUE(EnsureTlsSetup()); ASSERT_EQ(SERVER, role_); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(ssl_fd(), SSL_REQUEST_CERTIFICATE, PR_TRUE)); - EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_REQUIRE_CERTIFICATE, - requireAuth ? PR_TRUE : PR_FALSE)); + SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE); + SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook( ssl_fd(), &TlsAgent::ClientAuthenticated, this)); @@ -376,42 +402,8 @@ void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) { EXPECT_EQ(SECSuccess, rv); } -void TlsAgent::SetSessionTicketsEnabled(bool en) { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS, - en ? PR_TRUE : PR_FALSE); - EXPECT_EQ(SECSuccess, rv); -} - -void TlsAgent::SetSessionCacheEnabled(bool en) { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE); - EXPECT_EQ(SECSuccess, rv); -} - void TlsAgent::Set0RttEnabled(bool en) { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = - SSL_OptionSet(ssl_fd(), SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); - EXPECT_EQ(SECSuccess, rv); -} - -void TlsAgent::SetFallbackSCSVEnabled(bool en) { - EXPECT_TRUE(role_ == CLIENT && EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALLBACK_SCSV, - en ? PR_TRUE : PR_FALSE); - EXPECT_EQ(SECSuccess, rv); -} - -void TlsAgent::SetShortHeadersEnabled() { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd()); - EXPECT_EQ(SECSuccess, rv); + SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); } void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { @@ -424,6 +416,27 @@ void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { } } +SECStatus ResumptionTokenCallback(PRFileDesc* fd, + const PRUint8* resumptionToken, + unsigned int len, void* ctx) { + EXPECT_NE(nullptr, resumptionToken); + if (!resumptionToken) { + return SECFailure; + } + + std::vector<uint8_t> new_token(resumptionToken, resumptionToken + len); + reinterpret_cast<TlsAgent*>(ctx)->SetResumptionToken(new_token); + reinterpret_cast<TlsAgent*>(ctx)->SetResumptionCallbackCalled(); + return SECSuccess; +} + +void TlsAgent::SetResumptionTokenCallback() { + EXPECT_TRUE(EnsureTlsSetup()); + SECStatus rv = + SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this); + EXPECT_EQ(SECSuccess, rv); +} + void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) { *minver = vrange_.min; *maxver = vrange_.max; @@ -437,8 +450,6 @@ void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; } void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; } -void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; } - void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; } void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes, @@ -517,6 +528,12 @@ void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group, } } +void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const { + if (kea_group != ssl_grp_ffdhe_custom) { + EXPECT_EQ(kea_group, info_.originalKeaGroup); + } +} + void TlsAgent::CheckAuthType(SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const { EXPECT_EQ(STATE_CONNECTED, state_); @@ -569,8 +586,7 @@ void TlsAgent::EnableFalseStart() { falsestart_enabled_ = true; EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback( ssl_fd(), CanFalseStartCallback, this)); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALSE_START, PR_TRUE)); + SetOption(SSL_ENABLE_FALSE_START, PR_TRUE); } void TlsAgent::ExpectResumption() { expect_resumption_ = true; } @@ -578,7 +594,7 @@ void TlsAgent::ExpectResumption() { expect_resumption_ = true; } void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { EXPECT_TRUE(EnsureTlsSetup()); - EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_ENABLE_ALPN, PR_TRUE)); + SetOption(SSL_ENABLE_ALPN, PR_TRUE); EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len)); } @@ -622,12 +638,8 @@ void TlsAgent::CheckErrorCode(int32_t expected) const { } static uint8_t GetExpectedAlertLevel(uint8_t alert) { - switch (alert) { - case kTlsAlertCloseNotify: - case kTlsAlertEndOfEarlyData: - return kTlsAlertWarning; - default: - break; + if (alert == kTlsAlertCloseNotify) { + return kTlsAlertWarning; } return kTlsAlertFatal; } @@ -730,6 +742,50 @@ void TlsAgent::ResetPreliminaryInfo() { expected_cipher_suite_ = 0; } +void TlsAgent::ValidateCipherSpecs() { + PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd()); + // We use one ciphersuite in each direction. + PRInt32 expected = 2; + if (variant_ == ssl_variant_datagram) { + // For DTLS 1.3, the client retains the cipher spec for early data and the + // handshake so that it can retransmit EndOfEarlyData and its final flight. + // It also retains the handshake read cipher spec so that it can read ACKs + // from the server. The server retains the handshake read cipher spec so it + // can read the client's retransmitted Finished. + if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + if (role_ == CLIENT) { + expected = info_.earlyDataAccepted ? 5 : 4; + } else { + expected = 3; + } + } else { + // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec + // until the holddown timer runs down. + if (expect_resumption_) { + if (role_ == CLIENT) { + expected = 3; + } + } else { + if (role_ == SERVER) { + expected = 3; + } + } + } + } + // This function will be run before the handshake completes if false start is + // enabled. In that case, the client will still be reading cleartext, but + // will have a spec prepared for reading ciphertext. With DTLS, the client + // will also have a spec retained for retransmission of handshake messages. + if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) { + EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_); + expected = (variant_ == ssl_variant_datagram) ? 4 : 3; + } + EXPECT_EQ(expected, cipherSpecs); + if (expected != cipherSpecs) { + SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd()); + } +} + void TlsAgent::Connected() { if (state_ == STATE_CONNECTED) { return; @@ -743,6 +799,8 @@ void TlsAgent::Connected() { EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(info_), info_.length); + EXPECT_EQ(expect_resumption_, info_.resumed == PR_TRUE); + // Preliminary values are exposed through callbacks during the handshake. // If either expected values were set or the callbacks were called, check // that the final values are correct. @@ -753,32 +811,13 @@ void TlsAgent::Connected() { EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(csinfo_), csinfo_.length); - if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { - PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd()); - // We use one ciphersuite in each direction, plus one that's kept around - // by DTLS for retransmission. - PRInt32 expected = - ((variant_ == ssl_variant_datagram) && (role_ == CLIENT)) ? 3 : 2; - EXPECT_EQ(expected, cipherSuites); - if (expected != cipherSuites) { - SSLInt_PrintTls13CipherSpecs(ssl_fd()); - } - } + ValidateCipherSpecs(); - PRBool short_headers; - rv = SSLInt_UsingShortHeaders(ssl_fd(), &short_headers); - EXPECT_EQ(SECSuccess, rv); - EXPECT_EQ((PRBool)expect_short_headers_, short_headers); SetState(STATE_CONNECTED); } void TlsAgent::EnableExtendedMasterSecret() { - ASSERT_TRUE(EnsureTlsSetup()); - - SECStatus rv = - SSL_OptionSet(ssl_fd(), SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); - - ASSERT_EQ(SECSuccess, rv); + SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); } void TlsAgent::CheckExtendedMasterSecret(bool expected) { @@ -801,21 +840,6 @@ void TlsAgent::CheckSecretsDestroyed() { ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd())); } -void TlsAgent::DisableRollbackDetection() { - ASSERT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ROLLBACK_DETECTION, PR_FALSE); - - ASSERT_EQ(SECSuccess, rv); -} - -void TlsAgent::EnableCompression() { - ASSERT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_DEFLATE, PR_TRUE); - ASSERT_EQ(SECSuccess, rv); -} - void TlsAgent::SetDowngradeCheckVersion(uint16_t version) { ASSERT_TRUE(EnsureTlsSetup()); @@ -883,6 +907,14 @@ void TlsAgent::SendDirect(const DataBuffer& buf) { } } +void TlsAgent::SendRecordDirect(const TlsRecord& record) { + DataBuffer buf; + + auto rv = record.header.Write(&buf, 0, record.buffer); + EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv); + SendDirect(buf); +} + static bool ErrorIsNonFatal(PRErrorCode code) { return code == PR_WOULD_BLOCK_ERROR || code == SSL_ERROR_RX_SHORT_DTLS_READ; } @@ -918,6 +950,27 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) { } } +bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, + uint16_t wireVersion, uint64_t seq, + uint8_t ct, const DataBuffer& buf) { + LOGV("Writing " << buf.len() << " bytes"); + // Ensure we are a TLS 1.3 cipher agent. + EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3); + TlsRecordHeader header(wireVersion, kTlsApplicationDataType, seq); + DataBuffer padded = buf; + padded.Write(padded.len(), ct, 1); + DataBuffer ciphertext; + if (!spec->Protect(header, padded, &ciphertext)) { + return false; + } + + DataBuffer record; + auto rv = header.Write(&record, 0, ciphertext); + EXPECT_EQ(header.header_length() + ciphertext.len(), rv); + SendDirect(record); + return true; +} + void TlsAgent::ReadBytes(size_t amount) { uint8_t block[16384]; @@ -951,23 +1004,20 @@ void TlsAgent::ReadBytes(size_t amount) { void TlsAgent::ResetSentBytes() { send_ctr_ = 0; } -void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, - mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); - EXPECT_EQ(SECSuccess, rv); +void TlsAgent::SetOption(int32_t option, int value) { + ASSERT_TRUE(EnsureTlsSetup()); + EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value)); +} - rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS, - mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); - EXPECT_EQ(SECSuccess, rv); +void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { + SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); + SetOption(SSL_ENABLE_SESSION_TICKETS, + mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); } void TlsAgent::DisableECDHEServerKeyReuse() { - ASSERT_TRUE(EnsureTlsSetup()); ASSERT_EQ(TlsAgent::SERVER, role_); - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); - EXPECT_EQ(SECSuccess, rv); + SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); } static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"}; diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h index 4bccb9a84..6cd6d5073 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.h +++ b/security/nss/gtests/ssl_gtest/tls_agent.h @@ -14,7 +14,6 @@ #include <iostream> #include "test_io.h" -#include "tls_filter.h" #define GTEST_HAS_RTTI 0 #include "gtest/gtest.h" @@ -37,7 +36,10 @@ enum SessionResumptionMode { RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET }; +class PacketFilter; class TlsAgent; +class TlsCipherSpec; +struct TlsRecord; const extern std::vector<SSLNamedGroup> kAllDHEGroups; const extern std::vector<SSLNamedGroup> kECDHEGroups; @@ -66,7 +68,6 @@ class TlsAgent : public PollTarget { static const std::string kServerRsaSign; static const std::string kServerRsaPss; static const std::string kServerRsaDecrypt; - static const std::string kServerRsaChain; // A cert that requires a chain. static const std::string kServerEcdsa256; static const std::string kServerEcdsa384; static const std::string kServerEcdsa521; @@ -81,20 +82,15 @@ class TlsAgent : public PollTarget { adapter_->SetPeer(peer->adapter_); } - void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) { - filter->SetAgent(this); + void SetFilter(std::shared_ptr<PacketFilter> filter) { adapter_->SetPacketFilter(filter); } - - void SetPacketFilter(std::shared_ptr<PacketFilter> filter) { - adapter_->SetPacketFilter(filter); - } - - void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); } + void ClearFilter() { adapter_->SetPacketFilter(nullptr); } void StartConnect(PRFileDesc* model = nullptr); void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group, size_t kea_size = 0) const; + void CheckOriginalKEA(SSLNamedGroup kea_group) const; void CheckAuthType(SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const; @@ -121,12 +117,10 @@ class TlsAgent : public PollTarget { void SetupClientAuth(); void RequestClientAuth(bool requireAuth); + void SetOption(int32_t option, int value); void ConfigureSessionCache(SessionResumptionMode mode); - void SetSessionTicketsEnabled(bool en); - void SetSessionCacheEnabled(bool en); void Set0RttEnabled(bool en); void SetFallbackSCSVEnabled(bool en); - void SetShortHeadersEnabled(); void SetVersionRange(uint16_t minver, uint16_t maxver); void GetVersionRange(uint16_t* minver, uint16_t* maxver); void CheckPreliminaryInfo(); @@ -136,7 +130,6 @@ class TlsAgent : public PollTarget { void ExpectReadWriteError(); void EnableFalseStart(); void ExpectResumption(); - void ExpectShortHeaders(); void SkipVersionChecks(); void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count); void EnableAlpn(const uint8_t* val, size_t len); @@ -149,27 +142,49 @@ class TlsAgent : public PollTarget { // Send data on the socket, encrypting it. void SendData(size_t bytes, size_t blocksize = 1024); void SendBuffer(const DataBuffer& buf); + bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, + uint16_t wireVersion, uint64_t seq, uint8_t ct, + const DataBuffer& buf); // Send data directly to the underlying socket, skipping the TLS layer. void SendDirect(const DataBuffer& buf); + void SendRecordDirect(const TlsRecord& record); void ReadBytes(size_t max = 16384U); void ResetSentBytes(); // Hack to test drops. void EnableExtendedMasterSecret(); void CheckExtendedMasterSecret(bool expected); void CheckEarlyDataAccepted(bool expected); - void DisableRollbackDetection(); - void EnableCompression(); void SetDowngradeCheckVersion(uint16_t version); void CheckSecretsDestroyed(); void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); void DisableECDHEServerKeyReuse(); bool GetPeerChainLength(size_t* count); void CheckCipherSuite(uint16_t cipher_suite); + void SetResumptionTokenCallback(); + bool MaybeSetResumptionToken(); + void SetResumptionToken(const std::vector<uint8_t>& resumption_token) { + resumption_token_ = resumption_token; + } + const std::vector<uint8_t>& GetResumptionToken() const { + return resumption_token_; + } + void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) { + SECStatus rv = SSL_GetResumptionTokenInfo( + resumption_token_.data(), resumption_token_.size(), token.get(), + sizeof(SSLResumptionTokenInfo)); + ASSERT_EQ(SECSuccess, rv); + } + void SetResumptionCallbackCalled() { resumption_callback_called_ = true; } + bool resumption_callback_called() const { + return resumption_callback_called_; + } const std::string& name() const { return name_; } Role role() const { return role_; } std::string role_str() const { return role_ == SERVER ? "server" : "client"; } + SSLProtocolVariant variant() const { return variant_; } + State state() const { return state_; } const CERTCertificate* peer_cert() const { @@ -253,6 +268,7 @@ class TlsAgent : public PollTarget { const static char* states[]; void SetState(State state); + void ValidateCipherSpecs(); // Dummy auth certificate hook. static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, @@ -378,6 +394,7 @@ class TlsAgent : public PollTarget { uint8_t expected_sent_alert_; uint8_t expected_sent_alert_level_; bool handshake_callback_called_; + bool resumption_callback_called_; SSLChannelInfo info_; SSLCipherSuiteInfo csinfo_; SSLVersionRange vrange_; @@ -388,8 +405,8 @@ class TlsAgent : public PollTarget { HandshakeCallbackFunction handshake_callback_; AuthCertificateCallbackFunction auth_certificate_callback_; SniCallbackFunction sni_callback_; - bool expect_short_headers_; bool skip_version_checks_; + std::vector<uint8_t> resumption_token_; }; inline std::ostream& operator<<(std::ostream& stream, @@ -440,7 +457,7 @@ class TlsAgentTestBase : public ::testing::Test { void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, int32_t error_code = 0); - std::unique_ptr<TlsAgent> agent_; + std::shared_ptr<TlsAgent> agent_; TlsAgent::Role role_; SSLProtocolVariant variant_; uint16_t version_; diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc index c8de5a1fe..8567b392f 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.cc +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -5,6 +5,7 @@ * You can obtain one at http://mozilla.org/MPL/2.0/. */ #include "tls_connect.h" +#include "sslexp.h" extern "C" { #include "libssl_internals.h" } @@ -88,6 +89,8 @@ std::string VersionString(uint16_t version) { switch (version) { case 0: return "(no version)"; + case SSL_LIBRARY_VERSION_3_0: + return "1.0"; case SSL_LIBRARY_VERSION_TLS_1_0: return "1.0"; case SSL_LIBRARY_VERSION_TLS_1_1: @@ -112,6 +115,7 @@ TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, server_model_(nullptr), version_(version), expected_resumption_mode_(RESUME_NONE), + expected_resumptions_(0), session_ids_(), expect_extended_master_secret_(false), expect_early_data_accepted_(false), @@ -161,6 +165,22 @@ void TlsConnectTestBase::CheckShares( EXPECT_EQ(shares.len(), i); } +void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch, + uint16_t server_epoch) const { + uint16_t read_epoch = 0; + uint16_t write_epoch = 0; + + EXPECT_EQ(SECSuccess, + SSLInt_GetEpochs(client_->ssl_fd(), &read_epoch, &write_epoch)); + EXPECT_EQ(server_epoch, read_epoch) << "client read epoch"; + EXPECT_EQ(client_epoch, write_epoch) << "client write epoch"; + + EXPECT_EQ(SECSuccess, + SSLInt_GetEpochs(server_->ssl_fd(), &read_epoch, &write_epoch)); + EXPECT_EQ(client_epoch, read_epoch) << "server read epoch"; + EXPECT_EQ(server_epoch, write_epoch) << "server write epoch"; +} + void TlsConnectTestBase::ClearStats() { // Clear statistics. SSL3Statistics* stats = SSL_GetStatistics(); @@ -177,7 +197,7 @@ void TlsConnectTestBase::SetUp() { SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); SSLInt_ClearSelfEncryptKey(); SSLInt_SetTicketLifetime(30); - SSLInt_SetMaxEarlyDataSize(1024); + SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3); ClearStats(); Init(); } @@ -209,7 +229,9 @@ void TlsConnectTestBase::Reset() { void TlsConnectTestBase::Reset(const std::string& server_name, const std::string& client_name) { + auto token = client_->GetResumptionToken(); client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_)); + client_->SetResumptionToken(token); server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_)); if (skip_version_checks_) { client_->SkipVersionChecks(); @@ -219,12 +241,27 @@ void TlsConnectTestBase::Reset(const std::string& server_name, Init(); } -void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) { +void TlsConnectTestBase::MakeNewServer() { + auto replacement = std::make_shared<TlsAgent>( + server_->name(), TlsAgent::SERVER, server_->variant()); + server_ = replacement; + if (version_) { + server_->SetVersionRange(version_, version_); + } + client_->SetPeer(server_); + server_->SetPeer(client_); + server_->StartConnect(); +} + +void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected, + uint8_t num_resumptions) { expected_resumption_mode_ = expected; if (expected != RESUME_NONE) { client_->ExpectResumption(); server_->ExpectResumption(); + expected_resumptions_ = num_resumptions; } + EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE); } void TlsConnectTestBase::EnsureTlsSetup() { @@ -254,10 +291,16 @@ void TlsConnectTestBase::EnableExtendedMasterSecret() { void TlsConnectTestBase::Connect() { server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr); client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr); + client_->MaybeSetResumptionToken(); Handshake(); CheckConnected(); } +void TlsConnectTestBase::StartConnect() { + server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr); + client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr); +} + void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { EnsureTlsSetup(); client_->EnableSingleCipher(cipher_suite); @@ -274,6 +317,19 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { } void TlsConnectTestBase::CheckConnected() { + // Have the client read handshake twice to make sure we get the + // NST and the ACK. + if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 && + variant_ == ssl_variant_datagram) { + client_->Handshake(); + client_->Handshake(); + auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd()); + // Verify that we dropped the client's retransmission cipher suites. + EXPECT_EQ(2, suites) << "Client has the wrong number of suites"; + if (suites != 2) { + SSLInt_PrintCipherSpecs("client", client_->ssl_fd()); + } + } EXPECT_EQ(client_->version(), server_->version()); if (!skip_version_checks_) { // Check the version is as expected @@ -314,10 +370,12 @@ void TlsConnectTestBase::CheckConnected() { void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const { - client_->CheckKEA(kea_type, kea_group); - server_->CheckKEA(kea_type, kea_group); - client_->CheckAuthType(auth_type, sig_scheme); + if (kea_group != ssl_grp_none) { + client_->CheckKEA(kea_type, kea_group); + server_->CheckKEA(kea_type, kea_group); + } server_->CheckAuthType(auth_type, sig_scheme); + client_->CheckAuthType(auth_type, sig_scheme); } void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, @@ -346,13 +404,13 @@ void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, break; case ssl_auth_rsa_sign: if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) { - scheme = ssl_sig_rsa_pss_sha256; + scheme = ssl_sig_rsa_pss_rsae_sha256; } else { scheme = ssl_sig_rsa_pkcs1_sha256; } break; case ssl_auth_rsa_pss: - scheme = ssl_sig_rsa_pss_sha256; + scheme = ssl_sig_rsa_pss_rsae_sha256; break; case ssl_auth_ecdsa: scheme = ssl_sig_ecdsa_secp256r1_sha256; @@ -372,9 +430,19 @@ void TlsConnectTestBase::CheckKeys() const { CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); } +void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type, + SSLNamedGroup kea_group, + SSLNamedGroup original_kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) { + CheckKeys(kea_type, kea_group, auth_type, sig_scheme); + EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE); + client_->CheckOriginalKEA(original_kea_group); + server_->CheckOriginalKEA(original_kea_group); +} + void TlsConnectTestBase::ConnectExpectFail() { - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); Handshake(); ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); @@ -395,8 +463,7 @@ void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, } void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) { - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->SetServerKeyBits(server_->server_key_bits()); client_->Handshake(); server_->Handshake(); @@ -455,29 +522,33 @@ void TlsConnectTestBase::EnableSomeEcdhCiphers() { } } +void TlsConnectTestBase::ConfigureSelfEncrypt() { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey privKey; + ASSERT_TRUE( + TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey)); + + ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); + ASSERT_TRUE(pubKey); + + EXPECT_EQ(SECSuccess, + SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); +} + void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, SessionResumptionMode server) { client_->ConfigureSessionCache(client); server_->ConfigureSessionCache(server); if ((server & RESUME_TICKET) != 0) { - ScopedCERTCertificate cert; - ScopedSECKEYPrivateKey privKey; - ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, - &privKey)); - - ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); - ASSERT_TRUE(pubKey); - - EXPECT_EQ(SECSuccess, - SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); + ConfigureSelfEncrypt(); } } void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { EXPECT_NE(RESUME_BOTH, expected); - int resume_count = expected ? 1 : 0; - int stateless_count = (expected & RESUME_TICKET) ? 1 : 0; + int resume_count = expected ? expected_resumptions_ : 0; + int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0; // Note: hch == server counter; hsh == client counter. SSL3Statistics* stats = SSL_GetStatistics(); @@ -490,7 +561,7 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { if (expected != RESUME_NONE) { if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) { // Check that the last two session ids match. - ASSERT_EQ(2U, session_ids_.size()); + ASSERT_EQ(1U + expected_resumptions_, session_ids_.size()); EXPECT_EQ(session_ids_[session_ids_.size() - 1], session_ids_[session_ids_.size() - 2]); } else { @@ -540,31 +611,28 @@ void TlsConnectTestBase::CheckSrtp() const { server_->CheckSrtp(); } -void TlsConnectTestBase::SendReceive() { - client_->SendData(50); - server_->SendData(50); - Receive(50); +void TlsConnectTestBase::SendReceive(size_t total) { + ASSERT_GT(total, client_->received_bytes()); + ASSERT_GT(total, server_->received_bytes()); + client_->SendData(total - server_->received_bytes()); + server_->SendData(total - client_->received_bytes()); + Receive(total); // Receive() is cumulative } // Do a first connection so we can do 0-RTT on the second one. void TlsConnectTestBase::SetupForZeroRtt() { + // If we don't do this, then all 0-RTT attempts will be rejected. + SSLInt_RolloverAntiReplay(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. Connect(); SendReceive(); // Need to read so that we absorb the session ticket. CheckKeys(); Reset(); - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); } // Do a first connection so we can do resumption @@ -584,10 +652,6 @@ void TlsConnectTestBase::ZeroRttSendReceive( const char* k0RttData = "ABCDEF"; const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); - if (expect_writable && expect_readable) { - ExpectAlert(client_, kTlsAlertEndOfEarlyData); - } - client_->Handshake(); // Send ClientHello. if (post_clienthello_check) { if (!post_clienthello_check()) return; @@ -599,7 +663,7 @@ void TlsConnectTestBase::ZeroRttSendReceive( } else { EXPECT_EQ(SECFailure, rv); } - server_->Handshake(); // Consume ClientHello, EE, Finished. + server_->Handshake(); // Consume ClientHello std::vector<uint8_t> buf(k0RttDataLen); rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read @@ -608,7 +672,8 @@ void TlsConnectTestBase::ZeroRttSendReceive( EXPECT_EQ(k0RttDataLen, rv); } else { EXPECT_EQ(SECFailure, rv); - EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()) + << "Unexpected error: " << PORT_ErrorToName(PORT_GetError()); } // Do a second read. this should fail. @@ -653,6 +718,30 @@ void TlsConnectTestBase::SkipVersionChecks() { server_->SkipVersionChecks(); } +// Shift the DTLS timers, to the minimum time necessary to let the next timer +// run on either client or server. This allows tests to skip waiting without +// having timers run out of order. +void TlsConnectTestBase::ShiftDtlsTimers() { + PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT; + PRIntervalTime time; + SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); + if (rv == SECSuccess) { + time_shift = time; + } + rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); + if (rv == SECSuccess && + (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) { + time_shift = time; + } + + if (time_shift == PR_INTERVAL_NO_TIMEOUT) { + return; + } + + EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift)); + EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift)); +} + TlsConnectGeneric::TlsConnectGeneric() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} @@ -668,20 +757,29 @@ TlsConnectTls12Plus::TlsConnectTls12Plus() TlsConnectTls13::TlsConnectTls13() : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} +TlsConnectGenericResumption::TlsConnectGenericResumption() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())), + external_cache_(std::get<2>(GetParam())) {} + +TlsConnectTls13ResumptionToken::TlsConnectTls13ResumptionToken() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + +TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + void TlsKeyExchangeTest::EnsureKeyShareSetup() { EnsureTlsSetup(); groups_capture_ = - std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn); + std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn); shares_capture_ = - std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); - shares_capture2_ = - std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn, true); + std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn); + shares_capture2_ = std::make_shared<TlsExtensionCapture>( + client_, ssl_tls13_key_share_xtn, true); std::vector<std::shared_ptr<PacketFilter>> captures = { groups_capture_, shares_capture_, shares_capture2_}; - client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures)); - capture_hrr_ = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeHelloRetryRequest); - server_->SetPacketFilter(capture_hrr_); + client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures)); + capture_hrr_ = MakeTlsFilter<TlsHandshakeRecorder>( + server_, kTlsHandshakeHelloRetryRequest); } void TlsKeyExchangeTest::ConfigNamedGroups( @@ -691,11 +789,15 @@ void TlsKeyExchangeTest::ConfigNamedGroups( } std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( - const DataBuffer& ext) { + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); EXPECT_TRUE(ext.len() % 2 == 0); + std::vector<SSLNamedGroup> groups; for (size_t i = 1; i < ext.len() / 2; i += 1) { EXPECT_TRUE(ext.Read(2 * i, 2, &tmp)); @@ -705,10 +807,14 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( } std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( - const DataBuffer& ext) { + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); + std::vector<SSLNamedGroup> shares; size_t i = 2; while (i < ext.len()) { @@ -724,17 +830,15 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( void TlsKeyExchangeTest::CheckKEXDetails( const std::vector<SSLNamedGroup>& expected_groups, const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) { - std::vector<SSLNamedGroup> groups = - GetGroupDetails(groups_capture_->extension()); + std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_); EXPECT_EQ(expected_groups, groups); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { ASSERT_LT(0U, expected_shares.size()); - std::vector<SSLNamedGroup> shares = - GetShareDetails(shares_capture_->extension()); + std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_); EXPECT_EQ(expected_shares, shares); } else { - EXPECT_EQ(0U, shares_capture_->extension().len()); + EXPECT_FALSE(shares_capture_->captured()); } EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0); @@ -756,8 +860,6 @@ void TlsKeyExchangeTest::CheckKEXDetails( EXPECT_NE(expected_share2, it); } std::vector<SSLNamedGroup> expected_shares2 = {expected_share2}; - std::vector<SSLNamedGroup> shares = - GetShareDetails(shares_capture2_->extension()); - EXPECT_EQ(expected_shares2, shares); + EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_)); } } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h index 73e8dc81a..7dffe7f8a 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.h +++ b/security/nss/gtests/ssl_gtest/tls_connect.h @@ -45,8 +45,8 @@ class TlsConnectTestBase : public ::testing::Test { TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version); virtual ~TlsConnectTestBase(); - void SetUp(); - void TearDown(); + virtual void SetUp(); + virtual void TearDown(); // Initialize client and server. void Init(); @@ -55,13 +55,17 @@ class TlsConnectTestBase : public ::testing::Test { // Clear the server session cache. void ClearServerCache(); // Make sure TLS is configured for a connection. - void EnsureTlsSetup(); + virtual void EnsureTlsSetup(); // Reset and keep the same certificate names void Reset(); // Reset, and update the certificate names on both peers void Reset(const std::string& server_name, const std::string& client_name = "client"); + // Replace the server. + void MakeNewServer(); + // Set up + void StartConnect(); // Run the handshake. void Handshake(); // Connect and check that it works. @@ -81,20 +85,28 @@ class TlsConnectTestBase : public ::testing::Test { void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const; // This version assumes defaults. void CheckKeys() const; + // Check that keys on resumed sessions. + void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group, + SSLNamedGroup original_kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme); void CheckGroups(const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group); void CheckShares(const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group); + void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const; void ConfigureVersion(uint16_t version); void SetExpectedVersion(uint16_t version); // Expect resumption of a particular type. - void ExpectResumption(SessionResumptionMode expected); + void ExpectResumption(SessionResumptionMode expected, + uint8_t num_resumed = 1); void DisableAllCiphers(); void EnableOnlyStaticRsaCiphers(); void EnableOnlyDheCiphers(); void EnableSomeEcdhCiphers(); void EnableExtendedMasterSecret(); + void ConfigureSelfEncrypt(); void ConfigureSessionCache(SessionResumptionMode client, SessionResumptionMode server); void EnableAlpn(); @@ -103,7 +115,7 @@ class TlsConnectTestBase : public ::testing::Test { void CheckAlpn(const std::string& val); void EnableSrtp(); void CheckSrtp() const; - void SendReceive(); + void SendReceive(size_t total = 50); void SetupForZeroRtt(); void SetupForResume(); void ZeroRttSendReceive( @@ -115,6 +127,9 @@ class TlsConnectTestBase : public ::testing::Test { void DisableECDHEServerKeyReuse(); void SkipVersionChecks(); + // Move the DTLS timers for both endpoints to pop the next timer. + void ShiftDtlsTimers(); + protected: SSLProtocolVariant variant_; std::shared_ptr<TlsAgent> client_; @@ -123,6 +138,7 @@ class TlsConnectTestBase : public ::testing::Test { std::unique_ptr<TlsAgent> server_model_; uint16_t version_; SessionResumptionMode expected_resumption_mode_; + uint8_t expected_resumptions_; std::vector<std::vector<uint8_t>> session_ids_; // A simple value of "a", "b". Note that the preferred value of "a" is placed @@ -192,6 +208,52 @@ class TlsConnectGeneric : public TlsConnectTestBase, TlsConnectGeneric(); }; +class TlsConnectGenericResumption + : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t, bool>> { + private: + bool external_cache_; + + public: + TlsConnectGenericResumption(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + // Enable external resumption token cache. + if (external_cache_) { + client_->SetResumptionTokenCallback(); + } + } + + bool use_external_cache() const { return external_cache_; } +}; + +class TlsConnectTls13ResumptionToken + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { + public: + TlsConnectTls13ResumptionToken(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + client_->SetResumptionTokenCallback(); + } +}; + +class TlsConnectGenericResumptionToken + : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsConnectGenericResumptionToken(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + client_->SetResumptionTokenCallback(); + } +}; + // A Pre TLS 1.2 generic test. class TlsConnectPre12 : public TlsConnectTestBase, public ::testing::WithParamInterface< @@ -244,6 +306,11 @@ class TlsConnectDatagram13 : public TlsConnectTestBase { : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {} }; +class TlsConnectDatagramPre13 : public TlsConnectDatagram { + public: + TlsConnectDatagramPre13() {} +}; + // A variant that is used only with Pre13. class TlsConnectGenericPre13 : public TlsConnectGeneric {}; @@ -252,12 +319,14 @@ class TlsKeyExchangeTest : public TlsConnectGeneric { std::shared_ptr<TlsExtensionCapture> groups_capture_; std::shared_ptr<TlsExtensionCapture> shares_capture_; std::shared_ptr<TlsExtensionCapture> shares_capture2_; - std::shared_ptr<TlsInspectorRecordHandshakeMessage> capture_hrr_; + std::shared_ptr<TlsHandshakeRecorder> capture_hrr_; void EnsureKeyShareSetup(); void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); - std::vector<SSLNamedGroup> GetGroupDetails(const DataBuffer& ext); - std::vector<SSLNamedGroup> GetShareDetails(const DataBuffer& ext); + std::vector<SSLNamedGroup> GetGroupDetails( + const std::shared_ptr<TlsExtensionCapture>& capture); + std::vector<SSLNamedGroup> GetShareDetails( + const std::shared_ptr<TlsExtensionCapture>& capture); void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, const std::vector<SSLNamedGroup>& expectedShares); void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc index 76d9aaaff..d34b13bcb 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.cc +++ b/security/nss/gtests/ssl_gtest/tls_filter.cc @@ -12,6 +12,7 @@ extern "C" { #include "libssl_internals.h" } +#include <cassert> #include <iostream> #include "gtest_utils.h" #include "tls_agent.h" @@ -57,17 +58,22 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending, PRBool isServer = self->agent()->role() == TlsAgent::SERVER; if (g_ssl_gtest_verbose) { - std::cerr << "Cipher spec changed. Role=" - << (isServer ? "server" : "client") - << " direction=" << (sending ? "send" : "receive") << std::endl; + std::cerr << (isServer ? "server" : "client") << ": " + << (sending ? "send" : "receive") + << " cipher spec changed: " << newSpec->epoch << " (" + << newSpec->phase << ")" << std::endl; + } + if (!sending) { + return; } - if (!sending) return; + self->in_sequence_number_ = 0; + self->out_sequence_number_ = 0; + self->dropped_record_ = false; self->cipher_spec_.reset(new TlsCipherSpec()); - bool ret = - self->cipher_spec_->Init(SSLInt_CipherSpecToAlgorithm(isServer, newSpec), - SSLInt_CipherSpecToKey(isServer, newSpec), - SSLInt_CipherSpecToIv(isServer, newSpec)); + bool ret = self->cipher_spec_->Init( + SSLInt_CipherSpecToEpoch(newSpec), SSLInt_CipherSpecToAlgorithm(newSpec), + SSLInt_CipherSpecToKey(newSpec), SSLInt_CipherSpecToIv(newSpec)); EXPECT_EQ(true, ret); } @@ -83,11 +89,23 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, TlsRecordHeader header; DataBuffer record; - if (!header.Parse(&parser, &record)) { + if (!header.Parse(in_sequence_number_, &parser, &record)) { ADD_FAILURE() << "not a valid record"; return KEEP; } + // Track the sequence number, which is necessary for stream mode (the + // sequence number is in the header for datagram). + // + // This isn't perfectly robust. If there is a change from an active cipher + // spec to another active cipher spec (KeyUpdate for instance) AND writes + // are consolidated across that change AND packets were dropped from the + // older epoch, we will not correctly re-encrypt records in the old epoch to + // update their sequence numbers. + if (cipher_spec_ && header.content_type() == kTlsApplicationDataType) { + ++in_sequence_number_; + } + if (FilterRecord(header, record, &offset, output) != KEEP) { changed = true; } else { @@ -120,30 +138,49 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( header.sequence_number()}; PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered); + // In stream mode, even if something doesn't change we need to re-encrypt if + // previous packets were dropped. if (action == KEEP) { - return KEEP; + if (header.is_dtls() || !dropped_record_) { + return KEEP; + } + filtered = plaintext; } if (action == DROP) { - std::cerr << "record drop: " << record << std::endl; + std::cerr << "record drop: " << header << ":" << record << std::endl; + dropped_record_ = true; return DROP; } EXPECT_GT(0x10000U, filtered.len()); - std::cerr << "record old: " << plaintext << std::endl; - std::cerr << "record new: " << filtered << std::endl; + if (action != KEEP) { + std::cerr << "record old: " << plaintext << std::endl; + std::cerr << "record new: " << filtered << std::endl; + } + + uint64_t seq_num; + if (header.is_dtls() || !cipher_spec_ || + header.content_type() != kTlsApplicationDataType) { + seq_num = header.sequence_number(); + } else { + seq_num = out_sequence_number_++; + } + TlsRecordHeader out_header = {header.version(), header.content_type(), + seq_num}; DataBuffer ciphertext; - bool rv = Protect(header, inner_content_type, filtered, &ciphertext); + bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext); EXPECT_TRUE(rv); if (!rv) { return KEEP; } - *offset = header.Write(output, *offset, ciphertext); + *offset = out_header.Write(output, *offset, ciphertext); return CHANGE; } -bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) { +bool TlsRecordHeader::Parse(uint64_t sequence_number, TlsParser* parser, + DataBuffer* body) { if (!parser->Read(&content_type_)) { return false; } @@ -154,7 +191,7 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) { } version_ = version; - sequence_number_ = 0; + // If this is DTLS, overwrite the sequence number. if (IsDtls(version)) { uint32_t tmp; if (!parser->Read(&tmp, 4)) { @@ -165,6 +202,8 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) { return false; } sequence_number_ |= static_cast<uint64_t>(tmp); + } else { + sequence_number_ = sequence_number; } return parser->ReadVariable(body, 2); } @@ -193,7 +232,9 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, return true; } - if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) return false; + if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) { + return false; + } size_t len = plaintext->len(); while (len > 0 && !plaintext->data()[len - 1]) { @@ -206,6 +247,11 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, *inner_content_type = plaintext->data()[len - 1]; plaintext->Truncate(len - 1); + if (g_ssl_gtest_verbose) { + std::cerr << "unprotect: " << std::hex << header.sequence_number() + << std::dec << " type=" << static_cast<int>(*inner_content_type) + << " " << *plaintext << std::endl; + } return true; } @@ -218,16 +264,44 @@ bool TlsRecordFilter::Protect(const TlsRecordHeader& header, *ciphertext = plaintext; return true; } + if (g_ssl_gtest_verbose) { + std::cerr << "protect: " << header.sequence_number() << std::endl; + } DataBuffer padded = plaintext; padded.Write(padded.len(), inner_content_type, 1); return cipher_spec_->Protect(header, padded, ciphertext); } +bool IsHelloRetry(const DataBuffer& body) { + static const uint8_t ssl_hello_retry_random[] = { + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, + 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, + 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C}; + return memcmp(body.data() + 2, ssl_hello_retry_random, + sizeof(ssl_hello_retry_random)) == 0; +} + +bool TlsHandshakeFilter::IsFilteredType(const HandshakeHeader& header, + const DataBuffer& body) { + if (handshake_types_.empty()) { + return true; + } + + uint8_t type = header.handshake_type(); + if (type == kTlsHandshakeServerHello) { + if (IsHelloRetry(body)) { + type = kTlsHandshakeHelloRetryRequest; + } + } + return handshake_types_.count(type) > 0U; +} + PacketFilter::Action TlsHandshakeFilter::FilterRecord( const TlsRecordHeader& record_header, const DataBuffer& input, DataBuffer* output) { // Check that the first byte is as requested. - if (record_header.content_type() != kTlsHandshakeType) { + if ((record_header.content_type() != kTlsHandshakeType) && + (record_header.content_type() != kTlsAltHandshakeType)) { return KEEP; } @@ -239,12 +313,29 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord( while (parser.remaining()) { HandshakeHeader header; DataBuffer handshake; - if (!header.Parse(&parser, record_header, &handshake)) { + bool complete = false; + if (!header.Parse(&parser, record_header, preceding_fragment_, &handshake, + &complete)) { return KEEP; } + if (!complete) { + EXPECT_TRUE(record_header.is_dtls()); + // Save the fragment and drop it from this record. Fragments are + // coalesced with the last fragment of the handshake message. + changed = true; + preceding_fragment_.Assign(handshake); + continue; + } + preceding_fragment_.Truncate(0); + DataBuffer filtered; - PacketFilter::Action action = FilterHandshake(header, handshake, &filtered); + PacketFilter::Action action; + if (!IsFilteredType(header, handshake)) { + action = KEEP; + } else { + action = FilterHandshake(header, handshake, &filtered); + } if (action == DROP) { changed = true; std::cerr << "handshake drop: " << handshake << std::endl; @@ -258,6 +349,8 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord( std::cerr << "handshake old: " << handshake << std::endl; std::cerr << "handshake new: " << filtered << std::endl; source = &filtered; + } else if (preceding_fragment_.len()) { + changed = true; } offset = header.Write(output, offset, *source); @@ -267,12 +360,16 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord( } bool TlsHandshakeFilter::HandshakeHeader::ReadLength( - TlsParser* parser, const TlsRecordHeader& header, uint32_t* length) { - if (!parser->Read(length, 3)) { + TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset, + uint32_t* length, bool* last_fragment) { + uint32_t message_length; + if (!parser->Read(&message_length, 3)) { return false; // malformed } if (!header.is_dtls()) { + *last_fragment = true; + *length = message_length; return true; // nothing left to do } @@ -283,32 +380,50 @@ bool TlsHandshakeFilter::HandshakeHeader::ReadLength( } message_seq_ = message_seq_tmp; - uint32_t fragment_offset; - if (!parser->Read(&fragment_offset, 3)) { + uint32_t offset = 0; + if (!parser->Read(&offset, 3)) { + return false; + } + // We only parse if the fragments are all complete and in order. + if (offset != expected_offset) { + EXPECT_NE(0U, header.epoch()) + << "Received out of order handshake fragment for epoch 0"; return false; } - uint32_t fragment_length; - if (!parser->Read(&fragment_length, 3)) { + // For DTLS, we return the length of just this fragment. + if (!parser->Read(length, 3)) { return false; } - // All current tests where we are using this code don't fragment. - return (fragment_offset == 0 && fragment_length == *length); + // It's a fragment if the entire message is longer than what we have. + *last_fragment = message_length == (*length + offset); + return true; } bool TlsHandshakeFilter::HandshakeHeader::Parse( - TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body) { + TlsParser* parser, const TlsRecordHeader& record_header, + const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) { + *complete = false; + version_ = record_header.version(); if (!parser->Read(&handshake_type_)) { return false; // malformed } + uint32_t length; - if (!ReadLength(parser, record_header, &length)) { + if (!ReadLength(parser, record_header, preceding_fragment.len(), &length, + complete)) { return false; } - return parser->Read(body, length); + if (!parser->Read(body, length)) { + return false; + } + if (preceding_fragment.len()) { + body->Splice(preceding_fragment, 0); + } + return true; } size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment( @@ -337,7 +452,7 @@ size_t TlsHandshakeFilter::HandshakeHeader::Write( return offset; } -PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake( +PacketFilter::Action TlsHandshakeRecorder::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { // Only do this once. @@ -345,20 +460,23 @@ PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake( return KEEP; } - if (header.handshake_type() == handshake_type_) { - buffer_ = input; - } + buffer_ = input; return KEEP; } PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() == handshake_type_) { - *output = buffer_; - return CHANGE; - } + *output = buffer_; + return CHANGE; +} +PacketFilter::Action TlsRecordRecorder::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { + if (!filter_ || (header.content_type() == ct_)) { + records_.push_back({header, input}); + } return KEEP; } @@ -369,15 +487,30 @@ PacketFilter::Action TlsConversationRecorder::FilterRecord( return KEEP; } +PacketFilter::Action TlsHeaderRecorder::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { + headers_.push_back(header); + return KEEP; +} + +const TlsRecordHeader* TlsHeaderRecorder::header(size_t index) { + if (index > headers_.size() + 1) { + return nullptr; + } + return &headers_[index]; +} + PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input, DataBuffer* output) { DataBuffer in(input); bool changed = false; for (auto it = filters_.begin(); it != filters_.end(); ++it) { - PacketFilter::Action action = (*it)->Filter(in, output); + PacketFilter::Action action = (*it)->Process(in, output); if (action == DROP) { return DROP; } + if (action == CHANGE) { in = *output; changed = true; @@ -430,15 +563,6 @@ bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) { return true; } -static bool FindHelloRetryExtensions(TlsParser* parser, - const TlsVersioned& header) { - // TODO for -19 add cipher suite - if (!parser->Skip(2)) { // version - return false; - } - return true; -} - bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) { return true; } @@ -448,13 +572,6 @@ static bool FindCertReqExtensions(TlsParser* parser, if (!parser->SkipVariable(1)) { // request context return false; } - // TODO remove the next two for -19 - if (!parser->SkipVariable(2)) { // signature_algorithms - return false; - } - if (!parser->SkipVariable(2)) { // certificate_authorities - return false; - } return true; } @@ -478,6 +595,9 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser, if (!parser->Skip(8)) { // lifetime, age add return false; } + if (!parser->SkipVariable(1)) { // ticket_nonce + return false; + } if (!parser->SkipVariable(2)) { // ticket return false; } @@ -487,7 +607,6 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser, static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = { {kTlsHandshakeClientHello, FindClientHelloExtensions}, {kTlsHandshakeServerHello, FindServerHelloExtensions}, - {kTlsHandshakeHelloRetryRequest, FindHelloRetryExtensions}, {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions}, {kTlsHandshakeCertificateRequest, FindCertReqExtensions}, {kTlsHandshakeCertificate, FindCertificateExtensions}, @@ -505,10 +624,6 @@ bool TlsExtensionFilter::FindExtensions(TlsParser* parser, PacketFilter::Action TlsExtensionFilter::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (handshake_types_.count(header.handshake_type()) == 0) { - return KEEP; - } - TlsParser parser(input); if (!FindExtensions(&parser, header)) { return KEEP; @@ -610,13 +725,45 @@ PacketFilter::Action TlsExtensionDropper::FilterExtension( return KEEP; } +PacketFilter::Action TlsExtensionInjector::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindExtensions(&parser, header)) { + return KEEP; + } + size_t offset = parser.consumed(); + + *output = input; + + // Increase the size of the extensions. + uint16_t ext_len; + memcpy(&ext_len, output->data() + offset, sizeof(ext_len)); + ext_len = htons(ntohs(ext_len) + data_.len() + 4); + memcpy(output->data() + offset, &ext_len, sizeof(ext_len)); + + // Insert the extension type and length. + DataBuffer type_length; + type_length.Allocate(4); + type_length.Write(0, extension_, 2); + type_length.Write(2, data_.len(), 2); + output->Splice(type_length, offset + 2); + + // Insert the payload. + if (data_.len() > 0) { + output->Splice(data_, offset + 6); + } + + return CHANGE; +} + PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) { if (counter_++ == record_) { DataBuffer buf; header.Write(&buf, 0, body); - src_.lock()->SendDirect(buf); + agent()->SendDirect(buf); dest_.lock()->Handshake(); func_(); return DROP; @@ -625,13 +772,11 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, return KEEP; } -PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake( +PacketFilter::Action TlsClientHelloVersionChanger::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() == kTlsHandshakeClientKeyExchange) { - EXPECT_EQ(SECSuccess, - SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd())); - } + EXPECT_EQ(SECSuccess, + SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd())); return KEEP; } @@ -643,15 +788,49 @@ PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input, return ((1 << counter_++) & pattern_) ? DROP : KEEP; } -PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake( +PacketFilter::Action SelectiveRecordDropFilter::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& data, + DataBuffer* changed) { + if (counter_ >= 32) { + return KEEP; + } + return ((1 << counter_++) & pattern_) ? DROP : KEEP; +} + +/* static */ uint32_t SelectiveRecordDropFilter::ToPattern( + std::initializer_list<size_t> records) { + uint32_t pattern = 0; + for (auto it = records.begin(); it != records.end(); ++it) { + EXPECT_GT(32U, *it); + assert(*it < 32U); + pattern |= 1 << *it; + } + return pattern; +} + +PacketFilter::Action TlsClientHelloVersionSetter::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() == kTlsHandshakeClientHello) { - *output = input; - output->Write(0, version_, 2); - return CHANGE; - } - return KEEP; + *output = input; + output->Write(0, version_, 2); + return CHANGE; +} + +PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + *output = input; + uint32_t temp = 0; + EXPECT_TRUE(input.Read(0, 2, &temp)); + // Cipher suite is after version(2) and random(32). + size_t pos = 34; + if (temp < SSL_LIBRARY_VERSION_TLS_1_3) { + // In old versions, we have to skip a session_id too. + EXPECT_TRUE(input.Read(pos, 1, &temp)); + pos += 1 + temp; + } + output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2); + return CHANGE; } } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h index e4030e23f..1bbe190ab 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.h +++ b/security/nss/gtests/ssl_gtest/tls_filter.h @@ -13,6 +13,7 @@ #include <vector> #include "test_io.h" +#include "tls_agent.h" #include "tls_parser.h" #include "tls_protect.h" @@ -23,7 +24,6 @@ extern "C" { namespace nss_test { class TlsCipherSpec; -class TlsAgent; class TlsVersioned { public: @@ -50,10 +50,13 @@ class TlsRecordHeader : public TlsVersioned { uint8_t content_type() const { return content_type_; } uint64_t sequence_number() const { return sequence_number_; } - size_t header_length() const { return is_dtls() ? 11 : 3; } + uint16_t epoch() const { + return static_cast<uint16_t>(sequence_number_ >> 48); + } + size_t header_length() const { return is_dtls() ? 13 : 5; } // Parse the header; return true if successful; body in an outparam if OK. - bool Parse(TlsParser* parser, DataBuffer* body); + bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body); // Write the header and body to a buffer at the given offset. // Return the offset of the end of the write. size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; @@ -63,13 +66,32 @@ class TlsRecordHeader : public TlsVersioned { uint64_t sequence_number_; }; +struct TlsRecord { + const TlsRecordHeader header; + const DataBuffer buffer; +}; + +// Make a filter and install it on a TlsAgent. +template <class T, typename... Args> +inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent, + Args&&... args) { + auto filter = std::make_shared<T>(agent, std::forward<Args>(args)...); + agent->SetFilter(filter); + return filter; +} + // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: - TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {} + TlsRecordFilter(const std::shared_ptr<TlsAgent>& agent) + : agent_(agent), + count_(0), + cipher_spec_(), + dropped_record_(false), + in_sequence_number_(0), + out_sequence_number_(0) {} - void SetAgent(const TlsAgent* agent) { agent_ = agent; } - const TlsAgent* agent() const { return agent_; } + std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); } // External interface. Overrides PacketFilter. PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); @@ -112,17 +134,24 @@ class TlsRecordFilter : public PacketFilter { static void CipherSpecChanged(void* arg, PRBool sending, ssl3CipherSpec* newSpec); - const TlsAgent* agent_; + std::weak_ptr<TlsAgent> agent_; size_t count_; std::unique_ptr<TlsCipherSpec> cipher_spec_; + // Whether we dropped a record since the cipher spec changed. + bool dropped_record_; + // The sequence number we use for reading records as they are written. + uint64_t in_sequence_number_; + // The sequence number we use for writing modified records. + uint64_t out_sequence_number_; }; -inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) { +inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) { v.WriteStream(stream); return stream; } -inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { +inline std::ostream& operator<<(std::ostream& stream, + const TlsRecordHeader& hdr) { hdr.WriteStream(stream); stream << ' '; switch (hdr.content_type()) { @@ -133,13 +162,17 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { stream << "Alert"; break; case kTlsHandshakeType: + case kTlsAltHandshakeType: stream << "Handshake"; break; case kTlsApplicationDataType: stream << "Data"; break; + case kTlsAckType: + stream << "ACK"; + break; default: - stream << '<' << hdr.content_type() << '>'; + stream << '<' << static_cast<int>(hdr.content_type()) << '>'; break; } return stream << ' ' << std::hex << hdr.sequence_number() << std::dec; @@ -150,7 +183,20 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { // records and that they don't span records or anything crazy like that. class TlsHandshakeFilter : public TlsRecordFilter { public: - TlsHandshakeFilter() {} + TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent), handshake_types_(), preceding_fragment_() {} + TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent, + const std::set<uint8_t>& types) + : TlsRecordFilter(agent), + handshake_types_(types), + preceding_fragment_() {} + + // This filter can be set to be selective based on handshake message type. If + // this function isn't used (or the set is empty), then all handshake messages + // will be filtered. + void SetHandshakeTypes(const std::set<uint8_t>& types) { + handshake_types_ = types; + } class HandshakeHeader : public TlsVersioned { public: @@ -158,7 +204,8 @@ class TlsHandshakeFilter : public TlsRecordFilter { uint8_t handshake_type() const { return handshake_type_; } bool Parse(TlsParser* parser, const TlsRecordHeader& record_header, - DataBuffer* body); + const DataBuffer& preceding_fragment, DataBuffer* body, + bool* complete); size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; size_t WriteFragment(DataBuffer* buffer, size_t offset, @@ -169,7 +216,8 @@ class TlsHandshakeFilter : public TlsRecordFilter { // Reads the length from the record header. // This also reads the DTLS fragment information and checks it. bool ReadLength(TlsParser* parser, const TlsRecordHeader& header, - uint32_t* length); + uint32_t expected_offset, uint32_t* length, + bool* last_fragment); uint8_t handshake_type_; uint16_t message_seq_; @@ -185,60 +233,115 @@ class TlsHandshakeFilter : public TlsRecordFilter { DataBuffer* output) = 0; private: + bool IsFilteredType(const HandshakeHeader& header, + const DataBuffer& handshake); + + std::set<uint8_t> handshake_types_; + DataBuffer preceding_fragment_; }; // Make a copy of the first instance of a handshake message. -class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { +class TlsHandshakeRecorder : public TlsHandshakeFilter { public: - TlsInspectorRecordHandshakeMessage(uint8_t handshake_type) - : handshake_type_(handshake_type), buffer_() {} + TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type) + : TlsHandshakeFilter(agent, {handshake_type}), buffer_() {} + TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent, + const std::set<uint8_t>& handshake_types) + : TlsHandshakeFilter(agent, handshake_types), buffer_() {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); + void Reset() { buffer_.Truncate(0); } + const DataBuffer& buffer() const { return buffer_; } private: - uint8_t handshake_type_; DataBuffer buffer_; }; // Replace all instances of a handshake message. class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { public: - TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type, + TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type, const DataBuffer& replacement) - : handshake_type_(handshake_type), buffer_(replacement) {} + : TlsHandshakeFilter(agent, {handshake_type}), buffer_(replacement) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: - uint8_t handshake_type_; DataBuffer buffer_; }; +// Make a copy of each record of a given type. +class TlsRecordRecorder : public TlsRecordFilter { + public: + TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent, uint8_t ct) + : TlsRecordFilter(agent), filter_(true), ct_(ct), records_() {} + TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent), + filter_(false), + ct_(content_handshake), // dummy (<optional> is C++14) + records_() {} + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + + size_t count() const { return records_.size(); } + void Clear() { records_.clear(); } + + const TlsRecord& record(size_t i) const { return records_[i]; } + + private: + bool filter_; + uint8_t ct_; + std::vector<TlsRecord> records_; +}; + // Make a copy of the complete conversation. class TlsConversationRecorder : public TlsRecordFilter { public: - TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {} + TlsConversationRecorder(const std::shared_ptr<TlsAgent>& agent, + DataBuffer& buffer) + : TlsRecordFilter(agent), buffer_(buffer) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); private: - DataBuffer& buffer_; + DataBuffer buffer_; }; +// Make a copy of the records +class TlsHeaderRecorder : public TlsRecordFilter { + public: + TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent) {} + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + const TlsRecordHeader* header(size_t index); + + private: + std::vector<TlsRecordHeader> headers_; +}; + +typedef std::initializer_list<std::shared_ptr<PacketFilter>> + ChainedPacketFilterInit; + // Runs multiple packet filters in series. class ChainedPacketFilter : public PacketFilter { public: ChainedPacketFilter() {} ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters) : filters_(filters.begin(), filters.end()) {} + ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {} virtual ~ChainedPacketFilter() {} virtual PacketFilter::Action Filter(const DataBuffer& input, @@ -256,13 +359,15 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)> class TlsExtensionFilter : public TlsHandshakeFilter { public: - TlsExtensionFilter() : handshake_types_() { - handshake_types_.insert(kTlsHandshakeClientHello); - handshake_types_.insert(kTlsHandshakeServerHello); - } + TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent) + : TlsHandshakeFilter(agent, + {kTlsHandshakeClientHello, kTlsHandshakeServerHello, + kTlsHandshakeHelloRetryRequest, + kTlsHandshakeEncryptedExtensions}) {} - TlsExtensionFilter(const std::set<uint8_t>& types) - : handshake_types_(types) {} + TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent, + const std::set<uint8_t>& types) + : TlsHandshakeFilter(agent, types) {} static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); @@ -279,14 +384,17 @@ class TlsExtensionFilter : public TlsHandshakeFilter { PacketFilter::Action FilterExtensions(TlsParser* parser, const DataBuffer& input, DataBuffer* output); - - std::set<uint8_t> handshake_types_; }; class TlsExtensionCapture : public TlsExtensionFilter { public: - TlsExtensionCapture(uint16_t ext, bool last = false) - : extension_(ext), captured_(false), last_(last), data_() {} + TlsExtensionCapture(const std::shared_ptr<TlsAgent>& agent, uint16_t ext, + bool last = false) + : TlsExtensionFilter(agent), + extension_(ext), + captured_(false), + last_(last), + data_() {} const DataBuffer& extension() const { return data_; } bool captured() const { return captured_; } @@ -305,8 +413,9 @@ class TlsExtensionCapture : public TlsExtensionFilter { class TlsExtensionReplacer : public TlsExtensionFilter { public: - TlsExtensionReplacer(uint16_t extension, const DataBuffer& data) - : extension_(extension), data_(data) {} + TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& agent, + uint16_t extension, const DataBuffer& data) + : TlsExtensionFilter(agent), extension_(extension), data_(data) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) override; @@ -318,7 +427,9 @@ class TlsExtensionReplacer : public TlsExtensionFilter { class TlsExtensionDropper : public TlsExtensionFilter { public: - TlsExtensionDropper(uint16_t extension) : extension_(extension) {} + TlsExtensionDropper(const std::shared_ptr<TlsAgent>& agent, + uint16_t extension) + : TlsExtensionFilter(agent), extension_(extension) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer&, DataBuffer*) override; @@ -326,21 +437,41 @@ class TlsExtensionDropper : public TlsExtensionFilter { uint16_t extension_; }; +class TlsExtensionInjector : public TlsHandshakeFilter { + public: + TlsExtensionInjector(const std::shared_ptr<TlsAgent>& agent, uint16_t ext, + const DataBuffer& data) + : TlsHandshakeFilter(agent), extension_(ext), data_(data) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + private: + const uint16_t extension_; + const DataBuffer data_; +}; + class TlsAgent; typedef std::function<void(void)> VoidFunction; class AfterRecordN : public TlsRecordFilter { public: - AfterRecordN(std::shared_ptr<TlsAgent>& src, std::shared_ptr<TlsAgent>& dest, - unsigned int record, VoidFunction func) - : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {} + AfterRecordN(const std::shared_ptr<TlsAgent>& src, + const std::shared_ptr<TlsAgent>& dest, unsigned int record, + VoidFunction func) + : TlsRecordFilter(src), + dest_(dest), + record_(record), + func_(func), + counter_(0) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) override; private: - std::weak_ptr<TlsAgent> src_; std::weak_ptr<TlsAgent> dest_; unsigned int record_; VoidFunction func_; @@ -349,10 +480,12 @@ class AfterRecordN : public TlsRecordFilter { // When we see the ClientKeyExchange from |client|, increment the // ClientHelloVersion on |server|. -class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter { +class TlsClientHelloVersionChanger : public TlsHandshakeFilter { public: - TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server) - : server_(server) {} + TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client, + const std::shared_ptr<TlsAgent>& server) + : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}), + server_(server) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -377,10 +510,51 @@ class SelectiveDropFilter : public PacketFilter { uint8_t counter_; }; +// This class selectively drops complete records. The difference from +// SelectiveDropFilter is that if multiple DTLS records are in the same +// datagram, we just drop one. +class SelectiveRecordDropFilter : public TlsRecordFilter { + public: + SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent, + uint32_t pattern, bool enabled = true) + : TlsRecordFilter(agent), pattern_(pattern), counter_(0) { + if (!enabled) { + Disable(); + } + } + SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent, + std::initializer_list<size_t> records) + : SelectiveRecordDropFilter(agent, ToPattern(records), true) {} + + void Reset(uint32_t pattern) { + counter_ = 0; + PacketFilter::Enable(); + pattern_ = pattern; + } + + void Reset(std::initializer_list<size_t> records) { + Reset(ToPattern(records)); + } + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override; + + private: + static uint32_t ToPattern(std::initializer_list<size_t> records); + + uint32_t pattern_; + uint8_t counter_; +}; + // Set the version number in the ClientHello. -class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { +class TlsClientHelloVersionSetter : public TlsHandshakeFilter { public: - TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {} + TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& agent, + uint16_t version) + : TlsHandshakeFilter(agent, {kTlsHandshakeClientHello}), + version_(version) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -393,7 +567,8 @@ class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { // Damages the last byte of a handshake message. class TlsLastByteDamager : public TlsHandshakeFilter { public: - TlsLastByteDamager(uint8_t type) : type_(type) {} + TlsLastByteDamager(const std::shared_ptr<TlsAgent>& agent, uint8_t type) + : TlsHandshakeFilter(agent), type_(type) {} PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { @@ -411,6 +586,22 @@ class TlsLastByteDamager : public TlsHandshakeFilter { uint8_t type_; }; +class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { + public: + SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& agent, + uint16_t suite) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}), + cipher_suite_(suite) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + private: + uint16_t cipher_suite_; +}; + } // namespace nss_test #endif diff --git a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc index 51ff938b1..45f6cf2bd 100644 --- a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc +++ b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc @@ -241,13 +241,13 @@ TEST_P(TlsHkdfTest, HkdfExpandLabel) { {/* ssl_hash_md5 */}, {/* ssl_hash_sha1 */}, {/* ssl_hash_sha224 */}, - {0x34, 0x7c, 0x67, 0x80, 0xff, 0x0b, 0xba, 0xd7, 0x1c, 0x28, 0x3b, - 0x16, 0xeb, 0x2f, 0x9c, 0xf6, 0x2d, 0x24, 0xe6, 0xcd, 0xb6, 0x13, - 0xd5, 0x17, 0x76, 0x54, 0x8c, 0xb0, 0x7d, 0xcd, 0xe7, 0x4c}, - {0x4b, 0x1e, 0x5e, 0xc1, 0x49, 0x30, 0x78, 0xea, 0x35, 0xbd, 0x3f, 0x01, - 0x04, 0xe6, 0x1a, 0xea, 0x14, 0xcc, 0x18, 0x2a, 0xd1, 0xc4, 0x76, 0x21, - 0xc4, 0x64, 0xc0, 0x4e, 0x4b, 0x36, 0x16, 0x05, 0x6f, 0x04, 0xab, 0xe9, - 0x43, 0xb1, 0x2d, 0xa8, 0xa7, 0x17, 0x9a, 0x5f, 0x09, 0x91, 0x7d, 0x1f}}; + {0xc6, 0xdd, 0x6e, 0xc4, 0x76, 0xb8, 0x55, 0xf2, 0xa4, 0xfc, 0x59, + 0x04, 0xa4, 0x90, 0xdc, 0xa7, 0xa7, 0x0d, 0x94, 0x8f, 0xc2, 0xdc, + 0x15, 0x6d, 0x48, 0x93, 0x9d, 0x05, 0xbb, 0x9a, 0xbc, 0xc1}, + {0x41, 0xea, 0x77, 0x09, 0x8c, 0x90, 0x04, 0x10, 0xec, 0xbc, 0x37, 0xd8, + 0x5b, 0x54, 0xcd, 0x7b, 0x08, 0x15, 0x13, 0x20, 0xed, 0x1e, 0x3f, 0x54, + 0x74, 0xf7, 0x8b, 0x06, 0x38, 0x28, 0x06, 0x37, 0x75, 0x23, 0xa2, 0xb7, + 0x34, 0xb1, 0x72, 0x2e, 0x59, 0x6d, 0x5a, 0x31, 0xf5, 0x53, 0xab, 0x99}}; const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); HkdfExpandLabel(&k1_, hash_type_, kSessionHash, kHashLength[hash_type_], diff --git a/security/nss/gtests/ssl_gtest/tls_protect.cc b/security/nss/gtests/ssl_gtest/tls_protect.cc index efcd89e14..6c945f66e 100644 --- a/security/nss/gtests/ssl_gtest/tls_protect.cc +++ b/security/nss/gtests/ssl_gtest/tls_protect.cc @@ -32,7 +32,6 @@ void AeadCipher::FormatNonce(uint64_t seq, uint8_t *nonce) { } DataBuffer d(nonce, 12); - std::cerr << "Nonce " << d << std::endl; } bool AeadCipher::AeadInner(bool decrypt, void *params, size_t param_length, @@ -92,8 +91,9 @@ bool AeadCipherChacha20Poly1305::Aead(bool decrypt, uint64_t seq, in, inlen, out, outlen, maxlen); } -bool TlsCipherSpec::Init(SSLCipherAlgorithm cipher, PK11SymKey *key, - const uint8_t *iv) { +bool TlsCipherSpec::Init(uint16_t epoch, SSLCipherAlgorithm cipher, + PK11SymKey *key, const uint8_t *iv) { + epoch_ = epoch; switch (cipher) { case ssl_calg_aes_gcm: aead_.reset(new AeadCipherAesGcm()); diff --git a/security/nss/gtests/ssl_gtest/tls_protect.h b/security/nss/gtests/ssl_gtest/tls_protect.h index 4efbd6e6b..93ffd6322 100644 --- a/security/nss/gtests/ssl_gtest/tls_protect.h +++ b/security/nss/gtests/ssl_gtest/tls_protect.h @@ -20,7 +20,7 @@ class TlsRecordHeader; class AeadCipher { public: AeadCipher(CK_MECHANISM_TYPE mech) : mech_(mech), key_(nullptr) {} - ~AeadCipher(); + virtual ~AeadCipher(); bool Init(PK11SymKey *key, const uint8_t *iv); virtual bool Aead(bool decrypt, uint64_t seq, const uint8_t *in, size_t inlen, @@ -58,16 +58,19 @@ class AeadCipherAesGcm : public AeadCipher { // Our analog of ssl3CipherSpec class TlsCipherSpec { public: - TlsCipherSpec() : aead_() {} + TlsCipherSpec() : epoch_(0), aead_() {} - bool Init(SSLCipherAlgorithm cipher, PK11SymKey *key, const uint8_t *iv); + bool Init(uint16_t epoch, SSLCipherAlgorithm cipher, PK11SymKey *key, + const uint8_t *iv); bool Protect(const TlsRecordHeader &header, const DataBuffer &plaintext, DataBuffer *ciphertext); bool Unprotect(const TlsRecordHeader &header, const DataBuffer &ciphertext, DataBuffer *plaintext); + uint16_t epoch() const { return epoch_; } private: + uint16_t epoch_; std::unique_ptr<AeadCipher> aead_; }; |