diff options
author | wolfbeast <mcwerewolf@gmail.com> | 2018-02-23 11:04:39 +0100 |
---|---|---|
committer | wolfbeast <mcwerewolf@gmail.com> | 2018-06-05 22:24:08 +0200 |
commit | e10349ab8dda8a3f11be6aa19f2b6e29fe814044 (patch) | |
tree | 1a9b078b06a76af06839d407b7267880890afccc /security/nss/gtests/ssl_gtest | |
parent | 75b3dd4cbffb6e4534128278300ed6c8a3ab7506 (diff) | |
download | UXP-e10349ab8dda8a3f11be6aa19f2b6e29fe814044.tar UXP-e10349ab8dda8a3f11be6aa19f2b6e29fe814044.tar.gz UXP-e10349ab8dda8a3f11be6aa19f2b6e29fe814044.tar.lz UXP-e10349ab8dda8a3f11be6aa19f2b6e29fe814044.tar.xz UXP-e10349ab8dda8a3f11be6aa19f2b6e29fe814044.zip |
Update NSS to 3.35-RTM
Diffstat (limited to 'security/nss/gtests/ssl_gtest')
45 files changed, 4600 insertions, 928 deletions
diff --git a/security/nss/gtests/ssl_gtest/Makefile b/security/nss/gtests/ssl_gtest/Makefile index a9a9290e0..95c111aeb 100644 --- a/security/nss/gtests/ssl_gtest/Makefile +++ b/security/nss/gtests/ssl_gtest/Makefile @@ -29,10 +29,6 @@ include ../common/gtest.mk CFLAGS += -I$(CORE_DEPTH)/lib/ssl -ifdef NSS_SSL_ENABLE_ZLIB -include $(CORE_DEPTH)/coreconf/zlib.mk -endif - ifdef NSS_DISABLE_TLS_1_3 NSS_DISABLE_TLS_1_3=1 # Run parameterized tests only, for which we can easily exclude TLS 1.3 diff --git a/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc new file mode 100644 index 000000000..110cfa13a --- /dev/null +++ b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc @@ -0,0 +1,108 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +extern "C" { +#include "sslbloom.h" +} + +#include "gtest_utils.h" + +namespace nss_test { + +// Some random-ish inputs to test with. These don't result in collisions in any +// of the configurations that are tested below. +static const uint8_t kHashes1[] = { + 0x79, 0x53, 0xb8, 0xdd, 0x6b, 0x98, 0xce, 0x00, 0xb7, 0xdc, 0xe8, + 0x03, 0x70, 0x8c, 0xe3, 0xac, 0x06, 0x8b, 0x22, 0xfd, 0x0e, 0x34, + 0x48, 0xe6, 0xe5, 0xe0, 0x8a, 0xd6, 0x16, 0x18, 0xe5, 0x48}; +static const uint8_t kHashes2[] = { + 0xc6, 0xdd, 0x6e, 0xc4, 0x76, 0xb8, 0x55, 0xf2, 0xa4, 0xfc, 0x59, + 0x04, 0xa4, 0x90, 0xdc, 0xa7, 0xa7, 0x0d, 0x94, 0x8f, 0xc2, 0xdc, + 0x15, 0x6d, 0x48, 0x93, 0x9d, 0x05, 0xbb, 0x9a, 0xbc, 0xc1}; + +typedef struct { + unsigned int k; + unsigned int bits; +} BloomFilterConfig; + +class BloomFilterTest + : public ::testing::Test, + public ::testing::WithParamInterface<BloomFilterConfig> { + public: + BloomFilterTest() : filter_() {} + + void SetUp() { Init(); } + + void TearDown() { sslBloom_Destroy(&filter_); } + + protected: + void Init() { + if (filter_.filter) { + sslBloom_Destroy(&filter_); + } + ASSERT_EQ(SECSuccess, + sslBloom_Init(&filter_, GetParam().k, GetParam().bits)); + } + + bool Check(const uint8_t* hashes) { + return sslBloom_Check(&filter_, hashes) ? true : false; + } + + void Add(const uint8_t* hashes, bool expect_collision = false) { + EXPECT_EQ(expect_collision, sslBloom_Add(&filter_, hashes) ? true : false); + EXPECT_TRUE(Check(hashes)); + } + + sslBloomFilter filter_; +}; + +TEST_P(BloomFilterTest, InitOnly) {} + +TEST_P(BloomFilterTest, AddToEmpty) { + EXPECT_FALSE(Check(kHashes1)); + Add(kHashes1); +} + +TEST_P(BloomFilterTest, AddTwo) { + Add(kHashes1); + Add(kHashes2); +} + +TEST_P(BloomFilterTest, AddOneTwice) { + Add(kHashes1); + Add(kHashes1, true); +} + +TEST_P(BloomFilterTest, Zero) { + Add(kHashes1); + sslBloom_Zero(&filter_); + EXPECT_FALSE(Check(kHashes1)); + EXPECT_FALSE(Check(kHashes2)); +} + +TEST_P(BloomFilterTest, Fill) { + sslBloom_Fill(&filter_); + EXPECT_TRUE(Check(kHashes1)); + EXPECT_TRUE(Check(kHashes2)); +} + +static const BloomFilterConfig kBloomFilterConfigurations[] = { + {1, 1}, // 1 hash, 1 bit input - high chance of collision. + {1, 2}, // 1 hash, 2 bits - smaller than the basic unit size. + {1, 3}, // 1 hash, 3 bits - same as basic unit size. + {1, 4}, // 1 hash, 4 bits - 2 octets each. + {3, 10}, // 3 hashes over a reasonable number of bits. + {3, 3}, // Test that we can read multiple bits. + {4, 15}, // A credible filter. + {2, 18}, // A moderately large allocation. + {16, 16}, // Insane, use all of the bits from the hashes. + {16, 9}, // This also uses all of the bits from the hashes. +}; + +INSTANTIATE_TEST_CASE_P(BloomFilterConfigurations, BloomFilterTest, + ::testing::ValuesIn(kBloomFilterConfigurations)); + +} // namespace nspr_test diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.c b/security/nss/gtests/ssl_gtest/libssl_internals.c index 97b8354ae..887d85278 100644 --- a/security/nss/gtests/ssl_gtest/libssl_internals.c +++ b/security/nss/gtests/ssl_gtest/libssl_internals.c @@ -34,18 +34,17 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, return SECFailure; } - ssl3_InitState(ss); ssl3_RestartHandshakeHashes(ss); // Ensure we don't overrun hs.client_random. rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len); - // Zero the client_random struct. - PORT_Memset(&ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH); + // Zero the client_random. + PORT_Memset(ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH); // Copy over the challenge bytes. size_t offset = SSL3_RANDOM_LENGTH - rnd_len; - PORT_Memcpy(&ss->ssl3.hs.client_random.rand[offset], rnd, rnd_len); + PORT_Memcpy(ss->ssl3.hs.client_random + offset, rnd, rnd_len); // Rehash the SSLv2 client hello message. return ssl3_UpdateHandshakeHashes(ss, msg, msg_len); @@ -73,10 +72,11 @@ SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) { return SECFailure; } ss->ssl3.mtu = mtu; + ss->ssl3.hs.rtRetries = 0; /* Avoid DTLS shrinking the MTU any more. */ return SECSuccess; } -PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) { +PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd) { PRCList *cur_p; PRInt32 ct = 0; @@ -92,7 +92,7 @@ PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) { return ct; } -void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) { +void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd) { PRCList *cur_p; sslSocket *ss = ssl_FindSocket(fd); @@ -100,27 +100,31 @@ void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) { return; } - fprintf(stderr, "Cipher specs\n"); + fprintf(stderr, "Cipher specs for %s\n", label); for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs); cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) { ssl3CipherSpec *spec = (ssl3CipherSpec *)cur_p; - fprintf(stderr, " %s\n", spec->phase); + fprintf(stderr, " %s spec epoch=%d (%s) refct=%d\n", SPEC_DIR(spec), + spec->epoch, spec->phase, spec->refCt); } } -/* Force a timer expiry by backdating when the timer was started. - * We could set the remaining time to 0 but then backoff would not - * work properly if we decide to test it. */ -void SSLInt_ForceTimerExpiry(PRFileDesc *fd) { +/* Force a timer expiry by backdating when all active timers were started. We + * could set the remaining time to 0 but then backoff would not work properly if + * we decide to test it. */ +SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift) { + size_t i; sslSocket *ss = ssl_FindSocket(fd); if (!ss) { - return; + return SECFailure; } - if (!ss->ssl3.hs.rtTimerCb) return; - - ss->ssl3.hs.rtTimerStarted = - PR_IntervalNow() - PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs + 1); + for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) { + if (ss->ssl3.hs.timers[i].cb) { + ss->ssl3.hs.timers[i].started -= shift; + } + } + return SECSuccess; } #define CHECK_SECRET(secret) \ @@ -136,7 +140,6 @@ PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) { } CHECK_SECRET(currentSecret); - CHECK_SECRET(resumptionMasterSecret); CHECK_SECRET(dheSecret); CHECK_SECRET(clientEarlyTrafficSecret); CHECK_SECRET(clientHsTrafficSecret); @@ -226,28 +229,7 @@ PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) { return PR_TRUE; } -PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd) { - sslSocket *ss = ssl_FindSocket(fd); - if (!ss) { - return PR_FALSE; - } - - ssl_GetSSL3HandshakeLock(ss); - ssl_GetXmitBufLock(ss); - - SECStatus rv = tls13_SendNewSessionTicket(ss); - if (rv == SECSuccess) { - rv = ssl3_FlushHandshake(ss, 0); - } - - ssl_ReleaseXmitBufLock(ss); - ssl_ReleaseSSL3HandshakeLock(ss); - - return rv == SECSuccess; -} - SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { - PRUint64 epoch; sslSocket *ss; ssl3CipherSpec *spec; @@ -255,43 +237,40 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { if (!ss) { return SECFailure; } - if (to >= (1ULL << 48)) { + if (to >= RECORD_SEQ_MAX) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } ssl_GetSpecWriteLock(ss); spec = ss->ssl3.crSpec; - epoch = spec->read_seq_num >> 48; - spec->read_seq_num = (epoch << 48) | to; + spec->seqNum = to; /* For DTLS, we need to fix the record sequence number. For this, we can just * scrub the entire structure on the assumption that the new sequence number * is far enough past the last received sequence number. */ - if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { + if (spec->seqNum <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } - dtls_RecordSetRecvd(&spec->recvdRecords, to); + dtls_RecordSetRecvd(&spec->recvdRecords, spec->seqNum); ssl_ReleaseSpecWriteLock(ss); return SECSuccess; } SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) { - PRUint64 epoch; sslSocket *ss; ss = ssl_FindSocket(fd); if (!ss) { return SECFailure; } - if (to >= (1ULL << 48)) { + if (to >= RECORD_SEQ_MAX) { PORT_SetError(SEC_ERROR_INVALID_ARGS); return SECFailure; } ssl_GetSpecWriteLock(ss); - epoch = ss->ssl3.cwSpec->write_seq_num >> 48; - ss->ssl3.cwSpec->write_seq_num = (epoch << 48) | to; + ss->ssl3.cwSpec->seqNum = to; ssl_ReleaseSpecWriteLock(ss); return SECSuccess; } @@ -305,9 +284,9 @@ SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) { return SECFailure; } ssl_GetSpecReadLock(ss); - to = ss->ssl3.cwSpec->write_seq_num + DTLS_RECVD_RECORDS_WINDOW + extra; + to = ss->ssl3.cwSpec->seqNum + DTLS_RECVD_RECORDS_WINDOW + extra; ssl_ReleaseSpecReadLock(ss); - return SSLInt_AdvanceWriteSeqNum(fd, to & RECORD_SEQ_MAX); + return SSLInt_AdvanceWriteSeqNum(fd, to); } SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) { @@ -333,46 +312,20 @@ SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd, return SECSuccess; } -static ssl3KeyMaterial *GetKeyingMaterial(PRBool isServer, - ssl3CipherSpec *spec) { - return isServer ? &spec->server : &spec->client; +PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec) { + return spec->keyMaterial.key; } -PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec) { - return GetKeyingMaterial(isServer, spec)->write_key; +SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec) { + return spec->cipherDef->calg; } -SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer, - ssl3CipherSpec *spec) { - return spec->cipher_def->calg; +const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec) { + return spec->keyMaterial.iv; } -unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec) { - return GetKeyingMaterial(isServer, spec)->write_iv; -} - -SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd) { - sslSocket *ss; - - ss = ssl_FindSocket(fd); - if (!ss) { - return SECFailure; - } - - ss->opt.enableShortHeaders = PR_TRUE; - return SECSuccess; -} - -SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result) { - sslSocket *ss; - - ss = ssl_FindSocket(fd); - if (!ss) { - return SECFailure; - } - - *result = ss->ssl3.hs.shortHeaders; - return SECSuccess; +PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec) { + return spec->epoch; } void SSLInt_SetTicketLifetime(uint32_t lifetime) { @@ -405,3 +358,21 @@ SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) { return SECSuccess; } + +void SSLInt_RolloverAntiReplay(void) { + tls13_AntiReplayRollover(ssl_TimeUsec()); +} + +SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch, + PRUint16 *writeEpoch) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss || !readEpoch || !writeEpoch) { + return SECFailure; + } + + ssl_GetSpecReadLock(ss); + *readEpoch = ss->ssl3.crSpec->epoch; + *writeEpoch = ss->ssl3.cwSpec->epoch; + ssl_ReleaseSpecReadLock(ss); + return SECSuccess; +} diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.h b/security/nss/gtests/ssl_gtest/libssl_internals.h index 33709c4b4..95d4afdaf 100644 --- a/security/nss/gtests/ssl_gtest/libssl_internals.h +++ b/security/nss/gtests/ssl_gtest/libssl_internals.h @@ -24,9 +24,9 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext); void SSLInt_ClearSelfEncryptKey(); void SSLInt_SetSelfEncryptMacKey(PK11SymKey *key); -PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd); -void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd); -void SSLInt_ForceTimerExpiry(PRFileDesc *fd); +PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd); +void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd); +SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift); SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu); PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd); PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd); @@ -35,23 +35,23 @@ PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd); SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len); PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType); PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type); -PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd); SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to); SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to); SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra); SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group); +SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch, + PRUint16 *writeEpoch); SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd, sslCipherSpecChangedFunc func, void *arg); -PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec); -SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer, - ssl3CipherSpec *spec); -unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec); -SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd); -SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result); +PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec); +PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec); +SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec); +const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec); void SSLInt_SetTicketLifetime(uint32_t lifetime); void SSLInt_SetMaxEarlyDataSize(uint32_t size); SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size); +void SSLInt_RolloverAntiReplay(void); #endif // ndef libssl_internals_h_ diff --git a/security/nss/gtests/ssl_gtest/manifest.mn b/security/nss/gtests/ssl_gtest/manifest.mn index cc729c0f1..5d893bab3 100644 --- a/security/nss/gtests/ssl_gtest/manifest.mn +++ b/security/nss/gtests/ssl_gtest/manifest.mn @@ -12,11 +12,13 @@ CSRCS = \ $(NULL) CPPSRCS = \ + bloomfilter_unittest.cc \ ssl_0rtt_unittest.cc \ ssl_agent_unittest.cc \ ssl_auth_unittest.cc \ ssl_cert_ext_unittest.cc \ ssl_ciphersuite_unittest.cc \ + ssl_custext_unittest.cc \ ssl_damage_unittest.cc \ ssl_dhe_unittest.cc \ ssl_drop_unittest.cc \ @@ -29,11 +31,16 @@ CPPSRCS = \ ssl_gather_unittest.cc \ ssl_gtest.cc \ ssl_hrr_unittest.cc \ + ssl_keylog_unittest.cc \ + ssl_keyupdate_unittest.cc \ ssl_loopback_unittest.cc \ + ssl_misc_unittest.cc \ ssl_record_unittest.cc \ ssl_resumption_unittest.cc \ + ssl_renegotiation_unittest.cc \ ssl_skip_unittest.cc \ ssl_staticrsa_unittest.cc \ + ssl_tls13compat_unittest.cc \ ssl_v2_client_hello_unittest.cc \ ssl_version_unittest.cc \ ssl_versionpolicy_unittest.cc \ diff --git a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc index 85b7011a1..a60295490 100644 --- a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc @@ -7,6 +7,7 @@ #include "secerr.h" #include "ssl.h" #include "sslerr.h" +#include "sslexp.h" #include "sslproto.h" extern "C" { @@ -44,6 +45,93 @@ TEST_P(TlsConnectTls13, ZeroRttServerRejectByOption) { SendReceive(); } +TEST_P(TlsConnectTls13, ZeroRttApparentReplayAfterRestart) { + // The test fixtures call SSL_SetupAntiReplay() in SetUp(). This results in + // 0-RTT being rejected until at least one window passes. SetupFor0Rtt() + // forces a rollover of the anti-replay filters, which clears this state. + // Here, we do the setup manually here without that forced rollover. + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + StartConnect(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + +class TlsZeroRttReplayTest : public TlsConnectTls13 { + private: + class SaveFirstPacket : public PacketFilter { + public: + PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + if (!packet_.len() && input.len()) { + packet_ = input; + } + return KEEP; + } + + const DataBuffer& packet() const { return packet_; } + + private: + DataBuffer packet_; + }; + + protected: + void RunTest(bool rollover) { + // Run the initial handshake + SetupForZeroRtt(); + + // Now run a true 0-RTT handshake, but capture the first packet. + auto first_packet = std::make_shared<SaveFirstPacket>(); + client_->SetPacketFilter(first_packet); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + EXPECT_LT(0U, first_packet->packet().len()); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + + if (rollover) { + SSLInt_RolloverAntiReplay(); + } + + // Now replay that packet against the server. + Reset(); + server_->StartConnect(); + server_->Set0RttEnabled(true); + + // Capture the early_data extension, which should not appear. + auto early_data_ext = + std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn); + server_->SetPacketFilter(early_data_ext); + + // Finally, replay the ClientHello and force the server to consume it. Stop + // after the server sends its first flight; the client will not be able to + // complete this handshake. + server_->adapter()->PacketReceived(first_packet->packet()); + server_->Handshake(); + EXPECT_FALSE(early_data_ext->captured()); + } +}; + +TEST_P(TlsZeroRttReplayTest, ZeroRttReplay) { RunTest(false); } + +TEST_P(TlsZeroRttReplayTest, ZeroRttReplayAfterRollover) { RunTest(true); } + // Test that we don't try to send 0-RTT data when the server sent // us a ticket without the 0-RTT flags. TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) { @@ -52,8 +140,7 @@ TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) { SendReceive(); // Need to read so that we absorb the session ticket. CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); Reset(); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); // Now turn on 0-RTT but too late for the ticket. client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); @@ -80,8 +167,7 @@ TEST_P(TlsConnectTls13, ZeroRttServerForgetTicket) { TEST_P(TlsConnectTls13, ZeroRttServerOnly) { ExpectResumption(RESUME_NONE); server_->Set0RttEnabled(true); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); // Client sends ordinary ClientHello. client_->Handshake(); @@ -99,6 +185,61 @@ TEST_P(TlsConnectTls13, ZeroRttServerOnly) { CheckKeys(); } +// A small sleep after sending the ClientHello means that the ticket age that +// arrives at the server is too low. With a small tolerance for variation in +// ticket age (which is determined by the |window| parameter that is passed to +// SSL_SetupAntiReplay()), the server then rejects early data. +TEST_P(TlsConnectTls13, ZeroRttRejectOldTicket) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + EXPECT_EQ(SECSuccess, SSL_SetupAntiReplay(1, 1, 3)); + SSLInt_RolloverAntiReplay(); // Make sure to flush replay state. + SSLInt_RolloverAntiReplay(); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false, []() { + PR_Sleep(PR_MillisecondsToInterval(10)); + return true; + }); + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); + SendReceive(); +} + +// In this test, we falsely inflate the estimate of the RTT by delaying the +// ServerHello on the first handshake. This results in the server estimating a +// higher value of the ticket age than the client ultimately provides. Add a +// small tolerance for variation in ticket age and the ticket will appear to +// arrive prematurely, causing the server to reject early data. +TEST_P(TlsConnectTls13, ZeroRttRejectPrematureTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->Set0RttEnabled(true); + StartConnect(); + client_->Handshake(); // ClientHello + server_->Handshake(); // ServerHello + PR_Sleep(PR_MillisecondsToInterval(10)); + Handshake(); // Remainder of handshake + CheckConnected(); + SendReceive(); + CheckKeys(); + + Reset(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + EXPECT_EQ(SECSuccess, SSL_SetupAntiReplay(1, 1, 3)); + SSLInt_RolloverAntiReplay(); // Make sure to flush replay state. + SSLInt_RolloverAntiReplay(); + ExpectResumption(RESUME_TICKET); + ExpectEarlyDataAccepted(false); + StartConnect(); + ZeroRttSendReceive(true, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) { EnableAlpn(); SetupForZeroRtt(); @@ -117,6 +258,14 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) { CheckAlpn("a"); } +// NOTE: In this test and those below, the client always sends +// post-ServerHello alerts with the handshake keys, even if the server +// has accepted 0-RTT. In some cases, as with errors in +// EncryptedExtensions, the client can't know the server's behavior, +// and in others it's just simpler. What the server is expecting +// depends on whether it accepted 0-RTT or not. Eventually, we may +// make the server trial decrypt. +// // Have the server negotiate a different ALPN value, and therefore // reject 0-RTT. TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeServer) { @@ -155,12 +304,17 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnServer) { client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a"); EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, sizeof(b))); client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b"); - ExpectAlert(client_, kTlsAlertIllegalParameter); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); return true; }); - Handshake(); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + } else { + client_->Handshake(); + } client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID); - server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); } // Set up with no ALPN and then set the client so it thinks it has ALPN. @@ -175,12 +329,17 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnClient) { PRUint8 b[] = {'b'}; EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, 1)); client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b"); - ExpectAlert(client_, kTlsAlertIllegalParameter); + client_->ExpectSendAlert(kTlsAlertIllegalParameter); return true; }); - Handshake(); + if (variant_ == ssl_variant_stream) { + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + } else { + client_->Handshake(); + } client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID); - server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); } // Remove the old ALPN value and so the client will not offer early data. @@ -218,9 +377,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) { SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_2); - client_->StartConnect(); - server_->StartConnect(); - + StartConnect(); // We will send the early data xtn without sending actual early data. Thus // a 1.2 server shouldn't fail until the client sends an alert because the // client sends end_of_early_data only after reading the server's flight. @@ -261,9 +418,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) { SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_2); - client_->StartConnect(); - server_->StartConnect(); - + StartConnect(); // Send the early data xtn in the CH, followed by early app data. The server // will fail right after sending its flight, when receiving the early data. client_->Set0RttEnabled(true); @@ -310,7 +465,6 @@ TEST_P(TlsConnectTls13, SendTooMuchEarlyData) { server_->Set0RttEnabled(true); ExpectResumption(RESUME_TICKET); - ExpectAlert(client_, kTlsAlertEndOfEarlyData); client_->Handshake(); CheckEarlyDataLimit(client_, short_size); @@ -364,7 +518,6 @@ TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) { server_->Set0RttEnabled(true); ExpectResumption(RESUME_TICKET); - client_->ExpectSendAlert(kTlsAlertEndOfEarlyData); client_->Handshake(); // Send ClientHello CheckEarlyDataLimit(client_, limit); @@ -399,4 +552,86 @@ TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) { } } +class PacketCoalesceFilter : public PacketFilter { + public: + PacketCoalesceFilter() : packet_data_() {} + + void SendCoalesced(std::shared_ptr<TlsAgent> agent) { + agent->SendDirect(packet_data_); + } + + protected: + PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + packet_data_.Write(packet_data_.len(), input); + return DROP; + } + + private: + DataBuffer packet_data_; +}; + +TEST_P(TlsConnectTls13, ZeroRttOrdering) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + // Send out the ClientHello. + client_->Handshake(); + + // Now, coalesce the next three things from the client: early data, second + // flight and 1-RTT data. + auto coalesce = std::make_shared<PacketCoalesceFilter>(); + client_->SetPacketFilter(coalesce); + + // Send (and hold) early data. + static const std::vector<uint8_t> early_data = {3, 2, 1}; + EXPECT_EQ(static_cast<PRInt32>(early_data.size()), + PR_Write(client_->ssl_fd(), early_data.data(), early_data.size())); + + // Send (and hold) the second client handshake flight. + // The client sends EndOfEarlyData after seeing the server Finished. + server_->Handshake(); + client_->Handshake(); + + // Send (and hold) 1-RTT data. + static const std::vector<uint8_t> late_data = {7, 8, 9, 10}; + EXPECT_EQ(static_cast<PRInt32>(late_data.size()), + PR_Write(client_->ssl_fd(), late_data.data(), late_data.size())); + + // Now release them all at once. + coalesce->SendCoalesced(client_); + + // Now ensure that the three steps are exposed in the right order on the + // server: delivery of early data, handshake callback, delivery of 1-RTT. + size_t step = 0; + server_->SetHandshakeCallback([&step](TlsAgent*) { + EXPECT_EQ(1U, step); + ++step; + }); + + std::vector<uint8_t> buf(10); + PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.size()); + ASSERT_EQ(static_cast<PRInt32>(early_data.size()), read); + buf.resize(read); + EXPECT_EQ(early_data, buf); + EXPECT_EQ(0U, step); + ++step; + + // The third read should be after the handshake callback and should return the + // data that was sent after the handshake completed. + buf.resize(10); + read = PR_Read(server_->ssl_fd(), buf.data(), buf.size()); + ASSERT_EQ(static_cast<PRInt32>(late_data.size()), read); + buf.resize(read); + EXPECT_EQ(late_data, buf); + EXPECT_EQ(2U, step); +} + +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_CASE_P(Tls13ZeroRttReplayTest, TlsZeroRttReplayTest, + TlsConnectTestBase::kTlsVariantsAll); +#endif + } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc index 5035a338d..0aa9a4c78 100644 --- a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc @@ -31,7 +31,7 @@ const static uint8_t kCannedTls13ClientHello[] = { 0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x09, 0x00, 0x00, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x12, 0x00, 0x10, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x01, - 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x28, 0x00, + 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x33, 0x00, 0x47, 0x00, 0x45, 0x00, 0x17, 0x00, 0x41, 0x04, 0x86, 0x4a, 0xb9, 0xdc, 0x6a, 0x38, 0xa7, 0xce, 0xe7, 0xc2, 0x4f, 0xa6, 0x28, 0xb9, 0xdc, 0x65, 0xbf, 0x73, 0x47, 0x3c, 0x9c, 0x65, 0x8c, 0x47, 0x6d, 0x57, 0x22, 0x8a, @@ -44,13 +44,14 @@ const static uint8_t kCannedTls13ClientHello[] = { 0x02, 0x05, 0x02, 0x06, 0x02, 0x02, 0x02}; const static uint8_t kCannedTls13ServerHello[] = { - 0x7f, kD13, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, 0xf0, - 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, 0xdf, 0xe5, - 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, 0x08, 0x13, 0x01, - 0x00, 0x28, 0x00, 0x28, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, - 0x23, 0x17, 0x64, 0x23, 0x03, 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, - 0x24, 0xa1, 0x6c, 0xa9, 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, - 0xcb, 0xe3, 0x08, 0x84, 0xae, 0x19}; + 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, + 0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, + 0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, + 0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, + 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03, + 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9, + 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08, + 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x7f, kD13}; static const char *k0RttData = "ABCDEF"; TEST_P(TlsAgentTest, EarlyFinished) { diff --git a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc index dbcbc9aa3..dbcdd92ea 100644 --- a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc @@ -29,7 +29,25 @@ TEST_P(TlsConnectGeneric, ServerAuthBigRsa) { } TEST_P(TlsConnectGeneric, ServerAuthRsaChain) { - Reset(TlsAgent::kServerRsaChain); + Reset("rsa_chain"); + Connect(); + CheckKeys(); + size_t chain_length; + EXPECT_TRUE(client_->GetPeerChainLength(&chain_length)); + EXPECT_EQ(2UL, chain_length); +} + +TEST_P(TlsConnectGeneric, ServerAuthRsaPssChain) { + Reset("rsa_pss_chain"); + Connect(); + CheckKeys(); + size_t chain_length; + EXPECT_TRUE(client_->GetPeerChainLength(&chain_length)); + EXPECT_EQ(2UL, chain_length); +} + +TEST_P(TlsConnectGeneric, ServerAuthRsaCARsaPssChain) { + Reset("rsa_ca_rsa_pss_chain"); Connect(); CheckKeys(); size_t chain_length; @@ -141,13 +159,11 @@ TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) { class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter { public: + TlsZeroCertificateRequestSigAlgsFilter() + : TlsHandshakeFilter({kTlsHandshakeCertificateRequest}) {} virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeCertificateRequest) { - return KEEP; - } - TlsParser parser(input); std::cerr << "Zeroing CertReq.supported_signature_algorithms" << std::endl; @@ -581,8 +597,7 @@ class EnforceNoActivity : public PacketFilter { TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { client_->SetAuthCertificateCallback(AuthCompleteBlock); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->Handshake(); // Send ClientHello server_->Handshake(); // Send ServerHello client_->Handshake(); // Send ClientKeyExchange and Finished @@ -610,8 +625,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { TEST_P(TlsConnectTls13, AuthCompleteDelayed) { client_->SetAuthCertificateCallback(AuthCompleteBlock); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->Handshake(); // Send ClientHello server_->Handshake(); // Send ServerHello EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); diff --git a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc index 3463782e0..36ee104af 100644 --- a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc @@ -82,9 +82,8 @@ TEST_P(TlsConnectGenericPre13, SignedCertificateTimestampsLegacy) { ssl_kea_rsa)); EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(), &kSctItem, ssl_kea_rsa)); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, - PR_TRUE)); + + client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); SignedCertificateTimestampsExtractor timestamps_extractor(client_); Connect(); @@ -96,9 +95,7 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsSuccess) { EnsureTlsSetup(); EXPECT_TRUE( server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData)); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, - PR_TRUE)); + client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); SignedCertificateTimestampsExtractor timestamps_extractor(client_); Connect(); @@ -120,9 +117,7 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveClient) { TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveServer) { EnsureTlsSetup(); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, - PR_TRUE)); + client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); SignedCertificateTimestampsExtractor timestamps_extractor(client_); Connect(); @@ -173,16 +168,14 @@ TEST_P(TlsConnectGeneric, OcspNotRequested) { // Even if the client asks, the server has nothing unless it is configured. TEST_P(TlsConnectGeneric, OcspNotProvided) { EnsureTlsSetup(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); client_->SetAuthCertificateCallback(CheckNoOCSP); Connect(); } TEST_P(TlsConnectGenericPre13, OcspMangled) { EnsureTlsSetup(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); EXPECT_TRUE( server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData)); @@ -197,8 +190,7 @@ TEST_P(TlsConnectGenericPre13, OcspMangled) { TEST_P(TlsConnectGeneric, OcspSuccess) { EnsureTlsSetup(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); auto capture_ocsp = std::make_shared<TlsExtensionCapture>(ssl_cert_status_xtn); server_->SetPacketFilter(capture_ocsp); @@ -225,8 +217,7 @@ TEST_P(TlsConnectGeneric, OcspSuccess) { TEST_P(TlsConnectGeneric, OcspHugeSuccess) { EnsureTlsSetup(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE); uint8_t hugeOcspValue[16385]; memset(hugeOcspValue, 0xa1, sizeof(hugeOcspValue)); diff --git a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc index 85c30b2bf..810656868 100644 --- a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc @@ -31,11 +31,11 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { public: TlsCipherSuiteTestBase(SSLProtocolVariant variant, uint16_t version, uint16_t cipher_suite, SSLNamedGroup group, - SSLSignatureScheme signature_scheme) + SSLSignatureScheme sig_scheme) : TlsConnectTestBase(variant, version), cipher_suite_(cipher_suite), group_(group), - signature_scheme_(signature_scheme), + sig_scheme_(sig_scheme), csinfo_({0}) { SECStatus rv = SSL_GetCipherSuiteInfo(cipher_suite_, &csinfo_, sizeof(csinfo_)); @@ -60,14 +60,14 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { server_->ConfigNamedGroups(groups); kea_type_ = SSLInt_GetKEAType(group_); - client_->SetSignatureSchemes(&signature_scheme_, 1); - server_->SetSignatureSchemes(&signature_scheme_, 1); + client_->SetSignatureSchemes(&sig_scheme_, 1); + server_->SetSignatureSchemes(&sig_scheme_, 1); } } virtual void SetupCertificate() { if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { - switch (signature_scheme_) { + switch (sig_scheme_) { case ssl_sig_rsa_pkcs1_sha256: case ssl_sig_rsa_pkcs1_sha384: case ssl_sig_rsa_pkcs1_sha512: @@ -93,8 +93,7 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { auth_type_ = ssl_auth_ecdsa; break; default: - ASSERT_TRUE(false) << "Unsupported signature scheme: " - << signature_scheme_; + ADD_FAILURE() << "Unsupported signature scheme: " << sig_scheme_; break; } } else { @@ -187,7 +186,7 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase { SSLAuthType auth_type_; SSLKEAType kea_type_; SSLNamedGroup group_; - SSLSignatureScheme signature_scheme_; + SSLSignatureScheme sig_scheme_; SSLCipherSuiteInfo csinfo_; }; @@ -236,27 +235,29 @@ TEST_P(TlsCipherSuiteTest, ResumeCipherSuite) { ConnectAndCheckCipherSuite(); } -// This only works for stream ciphers because we modify the sequence number - -// which is included explicitly in the DTLS record header - and that trips a -// different error code. Note that the message that the client sends would not -// decrypt (the nonce/IV wouldn't match), but the record limit is hit before -// attempting to decrypt a record. TEST_P(TlsCipherSuiteTest, ReadLimit) { SetupCertificate(); EnableSingleCipher(); ConnectAndCheckCipherSuite(); - EXPECT_EQ(SECSuccess, - SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last_safe_write())); - EXPECT_EQ(SECSuccess, - SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last_safe_write())); + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + uint64_t last = last_safe_write(); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last)); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last)); - client_->SendData(10, 10); - server_->ReadBytes(); // This should be OK. + client_->SendData(10, 10); + server_->ReadBytes(); // This should be OK. + } else { + // In TLS 1.3, reading or writing triggers a KeyUpdate. That would mean + // that the sequence numbers would reset and we wouldn't hit the limit. So + // we move the sequence number to one less than the limit directly and don't + // test sending and receiving just before the limit. + uint64_t last = record_limit() - 1; + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last)); + } - // The payload needs to be big enough to pass for encrypted. In the extreme - // case (TLS 1.3), this means 1 for payload, 1 for content type and 16 for - // authentication tag. - static const uint8_t payload[18] = {6}; + // The payload needs to be big enough to pass for encrypted. The code checks + // the limit before it tries to decrypt. + static const uint8_t payload[32] = {6}; DataBuffer record; uint64_t epoch; if (variant_ == ssl_variant_datagram) { @@ -271,13 +272,17 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) { TlsAgentTestBase::MakeRecord(variant_, kTlsApplicationDataType, version_, payload, sizeof(payload), &record, (epoch << 48) | record_limit()); - server_->adapter()->PacketReceived(record); + client_->SendDirect(record); server_->ExpectReadWriteError(); server_->ReadBytes(); EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, server_->error_code()); } TEST_P(TlsCipherSuiteTest, WriteLimit) { + // This asserts in TLS 1.3 because we expect an automatic update. + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + return; + } SetupCertificate(); EnableSingleCipher(); ConnectAndCheckCipherSuite(); diff --git a/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc new file mode 100644 index 000000000..dad944a1f --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc @@ -0,0 +1,503 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "ssl3prot.h" +#include "sslerr.h" +#include "sslproto.h" +#include "sslexp.h" + +#include <memory> + +#include "tls_connect.h" + +namespace nss_test { + +static void IncrementCounterArg(void *arg) { + if (arg) { + auto *called = reinterpret_cast<size_t *>(arg); + ++*called; + } +} + +PRBool NoopExtensionWriter(PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) { + IncrementCounterArg(arg); + return PR_FALSE; +} + +PRBool EmptyExtensionWriter(PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) { + IncrementCounterArg(arg); + return PR_TRUE; +} + +SECStatus NoopExtensionHandler(PRFileDesc *fd, SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, void *arg) { + return SECSuccess; +} + +// All of the (current) set of supported extensions, plus a few extra. +static const uint16_t kManyExtensions[] = { + ssl_server_name_xtn, + ssl_cert_status_xtn, + ssl_supported_groups_xtn, + ssl_ec_point_formats_xtn, + ssl_signature_algorithms_xtn, + ssl_signature_algorithms_cert_xtn, + ssl_use_srtp_xtn, + ssl_app_layer_protocol_xtn, + ssl_signed_cert_timestamp_xtn, + ssl_padding_xtn, + ssl_extended_master_secret_xtn, + ssl_session_ticket_xtn, + ssl_tls13_key_share_xtn, + ssl_tls13_pre_shared_key_xtn, + ssl_tls13_early_data_xtn, + ssl_tls13_supported_versions_xtn, + ssl_tls13_cookie_xtn, + ssl_tls13_psk_key_exchange_modes_xtn, + ssl_tls13_ticket_early_data_info_xtn, + ssl_tls13_certificate_authorities_xtn, + ssl_next_proto_nego_xtn, + ssl_renegotiation_info_xtn, + ssl_tls13_short_header_xtn, + 1, + 0xffff}; +// The list here includes all extensions we expect to use (SSL_MAX_EXTENSIONS), +// plus the deprecated values (see sslt.h), and two extra dummy values. +PR_STATIC_ASSERT((SSL_MAX_EXTENSIONS + 5) == PR_ARRAY_SIZE(kManyExtensions)); + +void InstallManyWriters(std::shared_ptr<TlsAgent> agent, + SSLExtensionWriter writer, size_t *installed = nullptr, + size_t *called = nullptr) { + for (size_t i = 0; i < PR_ARRAY_SIZE(kManyExtensions); ++i) { + SSLExtensionSupport support = ssl_ext_none; + SECStatus rv = SSL_GetExtensionSupport(kManyExtensions[i], &support); + ASSERT_EQ(SECSuccess, rv) << "SSL_GetExtensionSupport cannot fail"; + + rv = SSL_InstallExtensionHooks(agent->ssl_fd(), kManyExtensions[i], writer, + called, NoopExtensionHandler, nullptr); + if (support == ssl_ext_native_only) { + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + } else { + if (installed) { + ++*installed; + } + EXPECT_EQ(SECSuccess, rv); + } + } +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopClient) { + EnsureTlsSetup(); + size_t installed = 0; + size_t called = 0; + InstallManyWriters(client_, NoopExtensionWriter, &installed, &called); + EXPECT_LT(0U, installed); + Connect(); + EXPECT_EQ(installed, called); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopServer) { + EnsureTlsSetup(); + size_t installed = 0; + size_t called = 0; + InstallManyWriters(server_, NoopExtensionWriter, &installed, &called); + EXPECT_LT(0U, installed); + Connect(); + // Extension writers are all called for each of ServerHello, + // EncryptedExtensions, and Certificate. + EXPECT_EQ(installed * 3, called); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterClient) { + EnsureTlsSetup(); + InstallManyWriters(client_, EmptyExtensionWriter); + InstallManyWriters(server_, EmptyExtensionWriter); + Connect(); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterServer) { + EnsureTlsSetup(); + InstallManyWriters(server_, EmptyExtensionWriter); + // Sending extensions that the client doesn't expect leads to extensions + // appearing even if the client didn't send one, or in the wrong messages. + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); +} + +// Install an writer to disable sending of a natively-supported extension. +TEST_F(TlsConnectStreamTls13, CustomExtensionWriterDisable) { + EnsureTlsSetup(); + + // This option enables sending the extension via the native support. + SECStatus rv = SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); + EXPECT_EQ(SECSuccess, rv); + + // This installs an override that doesn't do anything. You have to specify + // something; passing all nullptr values removes an existing handler. + rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NoopExtensionWriter, + nullptr, NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + auto capture = + std::make_shared<TlsExtensionCapture>(ssl_signed_cert_timestamp_xtn); + client_->SetPacketFilter(capture); + + Connect(); + // So nothing will be sent. + EXPECT_FALSE(capture->captured()); +} + +// An extension that is unlikely to be parsed as valid. +static uint8_t kNonsenseExtension[] = {91, 82, 73, 64, 55, 46, 37, 28, 19}; + +static PRBool NonsenseExtensionWriter(PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) { + TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg); + EXPECT_NE(nullptr, agent); + EXPECT_NE(nullptr, data); + EXPECT_NE(nullptr, len); + EXPECT_EQ(0U, *len); + EXPECT_LT(0U, maxLen); + EXPECT_EQ(agent->ssl_fd(), fd); + + if (message != ssl_hs_client_hello && message != ssl_hs_server_hello && + message != ssl_hs_encrypted_extensions) { + return PR_FALSE; + } + + *len = static_cast<unsigned int>(sizeof(kNonsenseExtension)); + EXPECT_GE(maxLen, *len); + if (maxLen < *len) { + return PR_FALSE; + } + PORT_Memcpy(data, kNonsenseExtension, *len); + return PR_TRUE; +} + +// Override the extension handler for an natively-supported and produce +// nonsense, which results in a handshake failure. +TEST_F(TlsConnectStreamTls13, CustomExtensionOverride) { + EnsureTlsSetup(); + + // This option enables sending the extension via the native support. + SECStatus rv = SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE); + EXPECT_EQ(SECSuccess, rv); + + // This installs an override that sends nonsense. + rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NonsenseExtensionWriter, + client_.get(), NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture it to see what we got. + auto capture = + std::make_shared<TlsExtensionCapture>(ssl_signed_cert_timestamp_xtn); + client_->SetPacketFilter(capture); + + ConnectExpectAlert(server_, kTlsAlertDecodeError); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +static SECStatus NonsenseExtensionHandler(PRFileDesc *fd, + SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, + void *arg) { + TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg); + EXPECT_EQ(agent->ssl_fd(), fd); + if (agent->role() == TlsAgent::SERVER) { + EXPECT_EQ(ssl_hs_client_hello, message); + } else { + EXPECT_TRUE(message == ssl_hs_server_hello || + message == ssl_hs_encrypted_extensions); + } + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + DataBuffer(data, len)); + EXPECT_NE(nullptr, alert); + return SECSuccess; +} + +// Send nonsense in an extension from client to server. +TEST_F(TlsConnectStreamTls13, CustomExtensionClientToServer) { + EnsureTlsSetup(); + + // This installs an override that sends nonsense. + const uint16_t extension_code = 0xffe5; + SECStatus rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), extension_code, NonsenseExtensionWriter, client_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture it to see what we got. + auto capture = std::make_shared<TlsExtensionCapture>(extension_code); + client_->SetPacketFilter(capture); + + // Handle it so that the handshake completes. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NoopExtensionWriter, nullptr, + NonsenseExtensionHandler, server_.get()); + EXPECT_EQ(SECSuccess, rv); + + Connect(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +static PRBool NonsenseExtensionWriterSH(PRFileDesc *fd, + SSLHandshakeType message, PRUint8 *data, + unsigned int *len, unsigned int maxLen, + void *arg) { + if (message == ssl_hs_server_hello) { + return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg); + } + return PR_FALSE; +} + +// Send nonsense in an extension from server to client, in ServerHello. +TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientSH) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff5e; + SECStatus rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr, + NonsenseExtensionHandler, client_.get()); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NonsenseExtensionWriterSH, server_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture the extension from the ServerHello only and check it. + auto capture = std::make_shared<TlsExtensionCapture>(extension_code); + capture->SetHandshakeTypes({kTlsHandshakeServerHello}); + server_->SetPacketFilter(capture); + + Connect(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +static PRBool NonsenseExtensionWriterEE(PRFileDesc *fd, + SSLHandshakeType message, PRUint8 *data, + unsigned int *len, unsigned int maxLen, + void *arg) { + if (message == ssl_hs_encrypted_extensions) { + return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg); + } + return PR_FALSE; +} + +// Send nonsense in an extension from server to client, in EncryptedExtensions. +TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientEE) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff5e; + SECStatus rv = SSL_InstallExtensionHooks( + client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr, + NonsenseExtensionHandler, client_.get()); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NonsenseExtensionWriterEE, server_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture the extension from the EncryptedExtensions only and check it. + auto capture = std::make_shared<TlsExtensionCapture>(extension_code); + capture->SetHandshakeTypes({kTlsHandshakeEncryptedExtensions}); + server_->SetTlsRecordFilter(capture); + + Connect(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionUnsolicitedServer) { + EnsureTlsSetup(); + + const uint16_t extension_code = 0xff5e; + SECStatus rv = SSL_InstallExtensionHooks( + server_->ssl_fd(), extension_code, NonsenseExtensionWriter, server_.get(), + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Capture it to see what we got. + auto capture = std::make_shared<TlsExtensionCapture>(extension_code); + server_->SetPacketFilter(capture); + + client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); + + EXPECT_TRUE(capture->captured()); + EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)), + capture->extension()); +} + +SECStatus RejectExtensionHandler(PRFileDesc *fd, SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, void *arg) { + return SECFailure; +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionServerReject) { + EnsureTlsSetup(); + + // This installs an override that sends nonsense. + const uint16_t extension_code = 0xffe7; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Reject the extension for no good reason. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NoopExtensionWriter, nullptr, + RejectExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); +} + +// Send nonsense in an extension from client to server. +TEST_F(TlsConnectStreamTls13, CustomExtensionClientReject) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff58; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + RejectExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + client_->ExpectSendAlert(kTlsAlertHandshakeFailure); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); +} + +static const uint8_t kCustomAlert = 0xf6; + +SECStatus AlertExtensionHandler(PRFileDesc *fd, SSLHandshakeType message, + const PRUint8 *data, unsigned int len, + SSLAlertDescription *alert, void *arg) { + *alert = kCustomAlert; + return SECFailure; +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionServerRejectAlert) { + EnsureTlsSetup(); + + // This installs an override that sends nonsense. + const uint16_t extension_code = 0xffea; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Reject the extension for no good reason. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + NoopExtensionWriter, nullptr, + AlertExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + ConnectExpectAlert(server_, kCustomAlert); +} + +// Send nonsense in an extension from client to server. +TEST_F(TlsConnectStreamTls13, CustomExtensionClientRejectAlert) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + const uint16_t extension_code = 0xff5a; + SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + AlertExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + // Have the server send nonsense. + rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code, + EmptyExtensionWriter, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + + client_->ExpectSendAlert(kCustomAlert); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); +} + +// Configure a custom extension hook badly. +TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyWriter) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + SECStatus rv = + SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6c, EmptyExtensionWriter, + nullptr, nullptr, nullptr); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyHandler) { + EnsureTlsSetup(); + + // This installs an override that sends nothing but expects nonsense. + SECStatus rv = + SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6d, nullptr, nullptr, + NoopExtensionHandler, nullptr); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectStreamTls13, CustomExtensionOverrunBuffer) { + EnsureTlsSetup(); + // This doesn't actually overrun the buffer, but it says that it does. + auto overrun_writer = [](PRFileDesc *fd, SSLHandshakeType message, + PRUint8 *data, unsigned int *len, + unsigned int maxLen, void *arg) -> PRBool { + *len = maxLen + 1; + return PR_TRUE; + }; + SECStatus rv = + SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff71, overrun_writer, + nullptr, NoopExtensionHandler, nullptr); + EXPECT_EQ(SECSuccess, rv); + client_->StartConnect(); + client_->Handshake(); + client_->CheckErrorCode(SEC_ERROR_APPLICATION_CALLBACK_ERROR); +} + +} // namespace "nss_test" diff --git a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc index 69fd00331..d1668b823 100644 --- a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc @@ -29,8 +29,7 @@ TEST_F(TlsConnectTest, DamageSecretHandleClientFinished) { SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, SSL_LIBRARY_VERSION_TLS_1_3); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->Handshake(); server_->Handshake(); std::cerr << "Damaging HS secret" << std::endl; @@ -51,16 +50,12 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) { SSL_LIBRARY_VERSION_TLS_1_3); server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, SSL_LIBRARY_VERSION_TLS_1_3); - client_->ExpectSendAlert(kTlsAlertDecryptError); - // The server can't read the client's alert, so it also sends an alert. - server_->ExpectSendAlert(kTlsAlertBadRecordMac); server_->SetPacketFilter(std::make_shared<AfterRecordN>( server_, client_, 0, // ServerHello. [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); })); - ConnectExpectFail(); + ConnectExpectAlert(client_, kTlsAlertDecryptError); client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); - server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); } TEST_P(TlsConnectGenericPre13, DamageServerSignature) { @@ -79,16 +74,7 @@ TEST_P(TlsConnectTls13, DamageServerSignature) { auto filter = std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify); server_->SetTlsRecordFilter(filter); - filter->EnableDecryption(); - client_->ExpectSendAlert(kTlsAlertDecryptError); - // The server can't read the client's alert, so it also sends an alert. - if (variant_ == ssl_variant_stream) { - server_->ExpectSendAlert(kTlsAlertBadRecordMac); - ConnectExpectFail(); - server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); - } else { - ConnectExpectFailOneSide(TlsAgent::CLIENT); - } + ConnectExpectAlert(client_, kTlsAlertDecryptError); client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); } @@ -100,11 +86,9 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) { std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify); client_->SetTlsRecordFilter(filter); server_->ExpectSendAlert(kTlsAlertDecryptError); - filter->EnableDecryption(); // Do these handshakes by hand to avoid race condition on // the client processing the server's alert. - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); server_->Handshake(); client_->Handshake(); diff --git a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc index 97943303a..4aa3bb639 100644 --- a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc @@ -59,8 +59,7 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { TEST_P(TlsConnectGeneric, ConnectFfdheClient) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); auto groups_capture = std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn); auto shares_capture = @@ -90,8 +89,7 @@ TEST_P(TlsConnectGeneric, ConnectFfdheClient) { // because the client automatically sends the supported groups extension. TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + server_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { Connect(); @@ -105,14 +103,11 @@ TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) { class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter { public: - TlsDheServerKeyExchangeDamager() {} + TlsDheServerKeyExchangeDamager() + : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {} virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - // Damage the first octet of dh_p. Anything other than the known prime will // be rejected as "weak" when we have SSL_REQUIRE_DH_NAMED_GROUPS enabled. *output = input; @@ -126,8 +121,7 @@ class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter { // the signature until everything else has been checked. TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); server_->SetPacketFilter(std::make_shared<TlsDheServerKeyExchangeDamager>()); ConnectExpectAlert(client_, kTlsAlertIllegalParameter); @@ -147,7 +141,8 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter { kYZeroPad }; - TlsDheSkeChangeY(ChangeYTo change) : change_Y_(change) {} + TlsDheSkeChangeY(uint8_t handshake_type, ChangeYTo change) + : TlsHandshakeFilter({handshake_type}), change_Y_(change) {} protected: void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset, @@ -213,7 +208,9 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter { class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { public: TlsDheSkeChangeYServer(ChangeYTo change, bool modify) - : TlsDheSkeChangeY(change), modify_(modify), p_() {} + : TlsDheSkeChangeY(kTlsHandshakeServerKeyExchange, change), + modify_(modify), + p_() {} const DataBuffer& prime() const { return p_; } @@ -221,10 +218,6 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - size_t offset = 2; // Read dh_p uint32_t dh_len = 0; @@ -254,16 +247,13 @@ class TlsDheSkeChangeYClient : public TlsDheSkeChangeY { TlsDheSkeChangeYClient( ChangeYTo change, std::shared_ptr<const TlsDheSkeChangeYServer> server_filter) - : TlsDheSkeChangeY(change), server_filter_(server_filter) {} + : TlsDheSkeChangeY(kTlsHandshakeClientKeyExchange, change), + server_filter_(server_filter) {} protected: virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeClientKeyExchange) { - return KEEP; - } - ChangeY(input, output, 0, server_filter_->prime()); return CHANGE; } @@ -289,8 +279,7 @@ class TlsDamageDHYTest TEST_P(TlsDamageDHYTest, DamageServerY) { EnableOnlyDheCiphers(); if (std::get<3>(GetParam())) { - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); } TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); server_->SetPacketFilter( @@ -320,8 +309,7 @@ TEST_P(TlsDamageDHYTest, DamageServerY) { TEST_P(TlsDamageDHYTest, DamageClientY) { EnableOnlyDheCiphers(); if (std::get<3>(GetParam())) { - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); } // The filter on the server is required to capture the prime. auto server_filter = @@ -370,13 +358,10 @@ INSTANTIATE_TEST_CASE_P( class TlsDheSkeMakePEven : public TlsHandshakeFilter { public: + TlsDheSkeMakePEven() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {} virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - // Find the end of dh_p uint32_t dh_len = 0; EXPECT_TRUE(input.Read(0, 2, &dh_len)); @@ -404,13 +389,10 @@ TEST_P(TlsConnectGenericPre13, MakeDhePEven) { class TlsDheSkeZeroPadP : public TlsHandshakeFilter { public: + TlsDheSkeZeroPadP() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {} virtual PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - *output = input; uint32_t dh_len = 0; EXPECT_TRUE(input.Read(0, 2, &dh_len)); @@ -445,8 +427,7 @@ TEST_P(TlsConnectGenericPre13, PadDheP) { // Note: This test case can take ages to generate the weak DH key. TEST_P(TlsConnectGenericPre13, WeakDHGroup) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); EXPECT_EQ(SECSuccess, SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE)); @@ -496,8 +477,7 @@ TEST_P(TlsConnectTls13, NamedGroupMismatch13) { // custom group in contrast to the previous test. TEST_P(TlsConnectGenericPre13, RequireNamedGroupsMismatchPre13) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072}; static const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048}; @@ -525,8 +505,7 @@ TEST_P(TlsConnectGenericPre13, PreferredFfdhe) { TEST_P(TlsConnectGenericPre13, MismatchDHE) { EnableOnlyDheCiphers(); - EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), - SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE); static const SSLDHEGroupType serverGroups[] = {ssl_ff_dhe_3072_group}; EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), serverGroups, PR_ARRAY_SIZE(serverGroups))); @@ -544,7 +523,8 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) { ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); Connect(); SendReceive(); // Need to read so that we absorb the session ticket. - CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); @@ -557,7 +537,8 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) { server_->SetPacketFilter(serverCapture); ExpectResumption(RESUME_TICKET); Connect(); - CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); ASSERT_LT(0UL, clientCapture->extension().len()); ASSERT_LT(0UL, serverCapture->extension().len()); } @@ -565,16 +546,15 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) { class TlsDheSkeChangeSignature : public TlsHandshakeFilter { public: TlsDheSkeChangeSignature(uint16_t version, const uint8_t* data, size_t len) - : version_(version), data_(data), len_(len) {} + : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}), + version_(version), + data_(data), + len_(len) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - TlsParser parser(input); EXPECT_TRUE(parser.SkipVariable(2)); // dh_p EXPECT_TRUE(parser.SkipVariable(2)); // dh_g diff --git a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc index 3cc3b0e62..c059e9938 100644 --- a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc @@ -6,6 +6,7 @@ #include "secerr.h" #include "ssl.h" +#include "sslexp.h" extern "C" { // This is not something that should make you happy. @@ -20,13 +21,13 @@ extern "C" { namespace nss_test { -TEST_P(TlsConnectDatagram, DropClientFirstFlightOnce) { +TEST_P(TlsConnectDatagramPre13, DropClientFirstFlightOnce) { client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1)); Connect(); SendReceive(); } -TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) { +TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) { server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1)); Connect(); SendReceive(); @@ -35,36 +36,760 @@ TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) { // This drops the first transmission from both the client and server of all // flights that they send. Note: In DTLS 1.3, the shorter handshake means that // this will also drop some application data, so we can't call SendReceive(). -TEST_P(TlsConnectDatagram, DropAllFirstTransmissions) { +TEST_P(TlsConnectDatagramPre13, DropAllFirstTransmissions) { client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x15)); server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x5)); Connect(); } // This drops the server's first flight three times. -TEST_P(TlsConnectDatagram, DropServerFirstFlightThrice) { +TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightThrice) { server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x7)); Connect(); } // This drops the client's second flight once -TEST_P(TlsConnectDatagram, DropClientSecondFlightOnce) { +TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightOnce) { client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2)); Connect(); } // This drops the client's second flight three times. -TEST_P(TlsConnectDatagram, DropClientSecondFlightThrice) { +TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightThrice) { client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe)); Connect(); } // This drops the server's second flight three times. -TEST_P(TlsConnectDatagram, DropServerSecondFlightThrice) { +TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) { server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe)); Connect(); } +class TlsDropDatagram13 : public TlsConnectDatagram13 { + public: + TlsDropDatagram13() + : client_filters_(), + server_filters_(), + expected_client_acks_(0), + expected_server_acks_(1) {} + + void SetUp() { + TlsConnectDatagram13::SetUp(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + SetFilters(); + } + + void SetFilters() { + EnsureTlsSetup(); + client_->SetPacketFilter(client_filters_.chain_); + client_filters_.ack_->SetAgent(client_.get()); + client_filters_.ack_->EnableDecryption(); + server_->SetPacketFilter(server_filters_.chain_); + server_filters_.ack_->SetAgent(server_.get()); + server_filters_.ack_->EnableDecryption(); + } + + void HandshakeAndAck(const std::shared_ptr<TlsAgent>& agent) { + agent->Handshake(); // Read flight. + ShiftDtlsTimers(); + agent->Handshake(); // Generate ACK. + } + + void ShrinkPostServerHelloMtu() { + // Abuse the custom extension mechanism to modify the MTU so that the + // Certificate message is split into two pieces. + ASSERT_EQ( + SECSuccess, + SSL_InstallExtensionHooks( + server_->ssl_fd(), 1, + [](PRFileDesc* fd, SSLHandshakeType message, PRUint8* data, + unsigned int* len, unsigned int maxLen, void* arg) -> PRBool { + SSLInt_SetMTU(fd, 500); // Splits the certificate. + return PR_FALSE; + }, + nullptr, + [](PRFileDesc* fd, SSLHandshakeType message, const PRUint8* data, + unsigned int len, SSLAlertDescription* alert, + void* arg) -> SECStatus { return SECSuccess; }, + nullptr)); + } + + protected: + class DropAckChain { + public: + DropAckChain() + : records_(std::make_shared<TlsRecordRecorder>()), + ack_(std::make_shared<TlsRecordRecorder>(content_ack)), + drop_(std::make_shared<SelectiveRecordDropFilter>(0, false)), + chain_(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({records_, ack_, drop_}))) {} + + const TlsRecord& record(size_t i) const { return records_->record(i); } + + std::shared_ptr<TlsRecordRecorder> records_; + std::shared_ptr<TlsRecordRecorder> ack_; + std::shared_ptr<SelectiveRecordDropFilter> drop_; + std::shared_ptr<PacketFilter> chain_; + }; + + void CheckAcks(const DropAckChain& chain, size_t index, + std::vector<uint64_t> acks) { + const DataBuffer& buf = chain.ack_->record(index).buffer; + size_t offset = 0; + + EXPECT_EQ(acks.size() * 8, buf.len()); + if ((acks.size() * 8) != buf.len()) { + while (offset < buf.len()) { + uint64_t ack; + ASSERT_TRUE(buf.Read(offset, 8, &ack)); + offset += 8; + std::cerr << "Ack=0x" << std::hex << ack << std::dec << std::endl; + } + return; + } + + for (size_t i = 0; i < acks.size(); ++i) { + uint64_t a = acks[i]; + uint64_t ack; + ASSERT_TRUE(buf.Read(offset, 8, &ack)); + offset += 8; + if (a != ack) { + ADD_FAILURE() << "Wrong ack " << i << " expected=0x" << std::hex << a + << " got=0x" << ack << std::dec; + } + } + } + + void CheckedHandshakeSendReceive() { + Handshake(); + CheckPostHandshake(); + } + + void CheckPostHandshake() { + CheckConnected(); + SendReceive(); + EXPECT_EQ(expected_client_acks_, client_filters_.ack_->count()); + EXPECT_EQ(expected_server_acks_, server_filters_.ack_->count()); + } + + protected: + DropAckChain client_filters_; + DropAckChain server_filters_; + size_t expected_client_acks_; + size_t expected_server_acks_; +}; + +// All of these tests produce a minimum one ACK, from the server +// to the client upon receiving the client Finished. +// Dropping complete first and second flights does not produce +// ACKs +TEST_F(TlsDropDatagram13, DropClientFirstFlightOnce) { + client_filters_.drop_->Reset({0}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + CheckedHandshakeSendReceive(); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +TEST_F(TlsDropDatagram13, DropServerFirstFlightOnce) { + server_filters_.drop_->Reset(0xff); + StartConnect(); + client_->Handshake(); + // Send the first flight, all dropped. + server_->Handshake(); + server_filters_.drop_->Disable(); + CheckedHandshakeSendReceive(); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// Dropping the server's first record also does not produce +// an ACK because the next record is ignored. +// TODO(ekr@rtfm.com): We should generate an empty ACK. +TEST_F(TlsDropDatagram13, DropServerFirstRecordOnce) { + server_filters_.drop_->Reset({0}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + Handshake(); + CheckedHandshakeSendReceive(); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// Dropping the second packet of the server's flight should +// produce an ACK. +TEST_F(TlsDropDatagram13, DropServerSecondRecordOnce) { + server_filters_.drop_->Reset({1}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + HandshakeAndAck(client_); + expected_client_acks_ = 1; + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_, 0, {0}); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// Drop the server ACK and verify that the client retransmits +// the ClientHello. +TEST_F(TlsDropDatagram13, DropServerAckOnce) { + StartConnect(); + client_->Handshake(); + server_->Handshake(); + // At this point the server has sent it's first flight, + // so make it drop the ACK. + server_filters_.drop_->Reset({0}); + client_->Handshake(); // Send the client Finished. + server_->Handshake(); // Receive the Finished and send the ACK. + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + // Wait for the DTLS timeout to make sure we retransmit the + // Finished. + ShiftDtlsTimers(); + client_->Handshake(); // Retransmit the Finished. + server_->Handshake(); // Read the Finished and send an ACK. + uint8_t buf[1]; + PRInt32 rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf)); + expected_server_acks_ = 2; + EXPECT_GT(0, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + CheckPostHandshake(); + // There should be two copies of the finished ACK + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// Drop the client certificate verify. +TEST_F(TlsDropDatagram13, DropClientCertVerify) { + StartConnect(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + client_->Handshake(); + server_->Handshake(); + // Have the client drop Cert Verify + client_filters_.drop_->Reset({1}); + expected_server_acks_ = 2; + CheckedHandshakeSendReceive(); + // Ack of the Cert. + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + // Ack of the whole client handshake. + CheckAcks( + server_filters_, 1, + {0x0002000000000000ULL, // CH (we drop everything after this on client) + 0x0002000000000003ULL, // CT (2) + 0x0002000000000004ULL} // FIN (2) + ); +} + +// Shrink the MTU down so that certs get split and drop the first piece. +TEST_F(TlsDropDatagram13, DropFirstHalfOfServerCertificate) { + server_filters_.drop_->Reset({2}); + StartConnect(); + ShrinkPostServerHelloMtu(); + client_->Handshake(); + server_->Handshake(); + // Check that things got split. + EXPECT_EQ(6UL, + server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN + size_t ct1_size = server_filters_.record(2).buffer.len(); + server_filters_.records_->Clear(); + expected_client_acks_ = 1; + HandshakeAndAck(client_); + server_->Handshake(); // Retransmit + EXPECT_EQ(3UL, server_filters_.records_->count()); // CT2, CV, FIN + // Check that the first record is CT1 (which is identical to the same + // as the previous CT1). + EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len()); + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_, 0, + {0, // SH + 0x0002000000000000ULL, // EE + 0x0002000000000002ULL} // CT2 + ); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// Shrink the MTU down so that certs get split and drop the second piece. +TEST_F(TlsDropDatagram13, DropSecondHalfOfServerCertificate) { + server_filters_.drop_->Reset({3}); + StartConnect(); + ShrinkPostServerHelloMtu(); + client_->Handshake(); + server_->Handshake(); + // Check that things got split. + EXPECT_EQ(6UL, + server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN + size_t ct1_size = server_filters_.record(3).buffer.len(); + server_filters_.records_->Clear(); + expected_client_acks_ = 1; + HandshakeAndAck(client_); + server_->Handshake(); // Retransmit + EXPECT_EQ(3UL, server_filters_.records_->count()); // CT1, CV, FIN + // Check that the first record is CT1 + EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len()); + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_, 0, + { + 0, // SH + 0x0002000000000000ULL, // EE + 0x0002000000000001ULL, // CT1 + }); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// In this test, the Certificate message is sent four times, we drop all or part +// of the first three attempts: +// 1. Without fragmentation so that we can see how big it is - we drop that. +// 2. In two pieces - we drop half AND the resulting ACK. +// 3. In three pieces - we drop the middle piece. +// +// After that we let all the ACKs through and allow the handshake to complete +// without further interference. +// +// This allows us to test that ranges of handshake messages are sent correctly +// even when there are overlapping acknowledgments; that ACKs with duplicate or +// overlapping message ranges are handled properly; and that extra +// retransmissions are handled properly. +class TlsFragmentationAndRecoveryTest : public TlsDropDatagram13 { + public: + TlsFragmentationAndRecoveryTest() : cert_len_(0) {} + + protected: + void RunTest(size_t dropped_half) { + FirstFlightDropCertificate(); + + SecondAttemptDropHalf(dropped_half); + size_t dropped_half_size = server_record_len(dropped_half); + size_t second_flight_count = server_filters_.records_->count(); + + ThirdAttemptDropMiddle(); + size_t repaired_third_size = server_record_len((dropped_half == 0) ? 0 : 2); + size_t third_flight_count = server_filters_.records_->count(); + + AckAndCompleteRetransmission(); + size_t final_server_flight_count = server_filters_.records_->count(); + EXPECT_LE(3U, final_server_flight_count); // CT(sixth), CV, Fin + CheckSizeOfSixth(dropped_half_size, repaired_third_size); + + SendDelayedAck(); + // Same number of messages as the last flight. + EXPECT_EQ(final_server_flight_count, server_filters_.records_->count()); + // Double check that the Certificate size is still correct. + CheckSizeOfSixth(dropped_half_size, repaired_third_size); + + CompleteHandshake(final_server_flight_count); + + // This is the ACK for the first attempt to send a whole certificate. + std::vector<uint64_t> client_acks = { + 0, // SH + 0x0002000000000000ULL // EE + }; + CheckAcks(client_filters_, 0, client_acks); + // And from the second attempt for the half was kept (we delayed this ACK). + client_acks.push_back(0x0002000000000000ULL + second_flight_count + + ~dropped_half % 2); + CheckAcks(client_filters_, 1, client_acks); + // And the third attempt where the first and last thirds got through. + client_acks.push_back(0x0002000000000000ULL + second_flight_count + + third_flight_count - 1); + client_acks.push_back(0x0002000000000000ULL + second_flight_count + + third_flight_count + 1); + CheckAcks(client_filters_, 2, client_acks); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + } + + private: + void FirstFlightDropCertificate() { + StartConnect(); + client_->Handshake(); + + // Note: 1 << N is the Nth packet, starting from zero. + server_filters_.drop_->Reset(1 << 2); // Drop Cert0. + server_->Handshake(); + EXPECT_EQ(5U, server_filters_.records_->count()); // SH, EE, CT, CV, Fin + cert_len_ = server_filters_.records_->record(2).buffer.len(); + + HandshakeAndAck(client_); + EXPECT_EQ(2U, client_filters_.records_->count()); + } + + // Lower the MTU so that the server has to split the certificate in two + // pieces. The server resends Certificate (in two), plus CV and Fin. + void SecondAttemptDropHalf(size_t dropped_half) { + ASSERT_LE(0U, dropped_half); + ASSERT_GT(2U, dropped_half); + server_filters_.records_->Clear(); + server_filters_.drop_->Reset({dropped_half}); // Drop Cert1[half] + SplitServerMtu(2); + server_->Handshake(); + EXPECT_LE(4U, server_filters_.records_->count()); // CT x2, CV, Fin + + // Generate and capture the ACK from the client. + client_filters_.drop_->Reset({0}); + HandshakeAndAck(client_); + EXPECT_EQ(3U, client_filters_.records_->count()); + } + + // Lower the MTU again so that the server sends Certificate cut into three + // pieces. Drop the middle piece. + void ThirdAttemptDropMiddle() { + server_filters_.records_->Clear(); + server_filters_.drop_->Reset({1}); // Drop Cert2[1] (of 3) + SplitServerMtu(3); + // Because we dropped the client ACK, the server retransmits on a timer. + ShiftDtlsTimers(); + server_->Handshake(); + EXPECT_LE(5U, server_filters_.records_->count()); // CT x3, CV, Fin + } + + void AckAndCompleteRetransmission() { + // Generate ACKs. + HandshakeAndAck(client_); + // The server should send the final sixth of the certificate: the client has + // acknowledged the first half and the last third. Also send CV and Fin. + server_filters_.records_->Clear(); + server_->Handshake(); + } + + void CheckSizeOfSixth(size_t size_of_half, size_t size_of_third) { + // Work out if the final sixth is the right size. We get the records with + // overheads added, which obscures the length of the payload. We want to + // ensure that the server only sent the missing sixth of the Certificate. + // + // We captured |size_of_half + overhead| and |size_of_third + overhead| and + // want to calculate |size_of_third - size_of_third + overhead|. We can't + // calculate |overhead|, but it is is (currently) always a handshake message + // header, a content type, and an authentication tag: + static const size_t record_overhead = 12 + 1 + 16; + EXPECT_EQ(size_of_half - size_of_third + record_overhead, + server_filters_.records_->record(0).buffer.len()); + } + + void SendDelayedAck() { + // Send the ACK we held back. The reordered ACK doesn't add new + // information, + // but triggers an extra retransmission of the missing records again (even + // though the client has all that it needs). + client_->SendRecordDirect(client_filters_.records_->record(2)); + server_filters_.records_->Clear(); + server_->Handshake(); + } + + void CompleteHandshake(size_t extra_retransmissions) { + // All this messing around shouldn't cause a failure... + Handshake(); + // ...but it leaves a mess. Add an extra few calls to Handshake() for the + // client so that it absorbs the extra retransmissions. + for (size_t i = 0; i < extra_retransmissions; ++i) { + client_->Handshake(); + } + CheckConnected(); + } + + // Split the server MTU so that the Certificate is split into |count| pieces. + // The calculation doesn't need to be perfect as long as the Certificate + // message is split into the right number of pieces. + void SplitServerMtu(size_t count) { + // Set the MTU based on the formula: + // bare_size = cert_len_ - actual_overhead + // MTU = ceil(bare_size / count) + pessimistic_overhead + // + // actual_overhead is the amount of actual overhead on the record we + // captured, which is (note that our length doesn't include the header): + static const size_t actual_overhead = 12 + // handshake message header + 1 + // content type + 16; // authentication tag + size_t bare_size = cert_len_ - actual_overhead; + + // pessimistic_overhead is the amount of expansion that NSS assumes will be + // added to each handshake record. Right now, that is DTLS_MIN_FRAGMENT: + static const size_t pessimistic_overhead = + 12 + // handshake message header + 1 + // content type + 13 + // record header length + 64; // maximum record expansion: IV, MAC and block cipher expansion + + size_t mtu = (bare_size + count - 1) / count + pessimistic_overhead; + if (g_ssl_gtest_verbose) { + std::cerr << "server: set MTU to " << mtu << std::endl; + } + EXPECT_EQ(SECSuccess, SSLInt_SetMTU(server_->ssl_fd(), mtu)); + } + + size_t server_record_len(size_t index) const { + return server_filters_.records_->record(index).buffer.len(); + } + + size_t cert_len_; +}; + +TEST_F(TlsFragmentationAndRecoveryTest, DropFirstHalf) { RunTest(0); } + +TEST_F(TlsFragmentationAndRecoveryTest, DropSecondHalf) { RunTest(1); } + +TEST_F(TlsDropDatagram13, NoDropsDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + server_filters_.drop_->Reset({1}); + ZeroRttSendReceive(true, true); + HandshakeAndAck(client_); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + CheckAcks(client_filters_, 0, {0}); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +class TlsReorderDatagram13 : public TlsDropDatagram13 { + public: + TlsReorderDatagram13() {} + + // Send records from the records buffer in the given order. + void ReSend(TlsAgent::Role side, std::vector<size_t> indices) { + std::shared_ptr<TlsAgent> agent; + std::shared_ptr<TlsRecordRecorder> records; + + if (side == TlsAgent::CLIENT) { + agent = client_; + records = client_filters_.records_; + } else { + agent = server_; + records = server_filters_.records_; + } + + for (auto i : indices) { + agent->SendRecordDirect(records->record(i)); + } + } +}; + +// Reorder the server records so that EE comes at the end +// of the flight and will still produce an ACK. +TEST_F(TlsDropDatagram13, ReorderServerEE) { + server_filters_.drop_->Reset({1}); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + // We dropped EE, now reinject. + server_->SendRecordDirect(server_filters_.record(1)); + expected_client_acks_ = 1; + HandshakeAndAck(client_); + CheckedHandshakeSendReceive(); + CheckAcks(client_filters_, 0, + { + 0, // SH + 0x0002000000000000, // EE + }); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +// The client sends an out of order non-handshake message +// but with the handshake key. +class TlsSendCipherSpecCapturer { + public: + TlsSendCipherSpecCapturer(std::shared_ptr<TlsAgent>& agent) + : send_cipher_specs_() { + SSLInt_SetCipherSpecChangeFunc(agent->ssl_fd(), CipherSpecChanged, + (void*)this); + } + + std::shared_ptr<TlsCipherSpec> spec(size_t i) { + if (i >= send_cipher_specs_.size()) { + return nullptr; + } + return send_cipher_specs_[i]; + } + + private: + static void CipherSpecChanged(void* arg, PRBool sending, + ssl3CipherSpec* newSpec) { + if (!sending) { + return; + } + + auto self = static_cast<TlsSendCipherSpecCapturer*>(arg); + + auto spec = std::make_shared<TlsCipherSpec>(); + bool ret = spec->Init(SSLInt_CipherSpecToEpoch(newSpec), + SSLInt_CipherSpecToAlgorithm(newSpec), + SSLInt_CipherSpecToKey(newSpec), + SSLInt_CipherSpecToIv(newSpec)); + EXPECT_EQ(true, ret); + self->send_cipher_specs_.push_back(spec); + } + + std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_; +}; + +TEST_F(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) { + StartConnect(); + TlsSendCipherSpecCapturer capturer(client_); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + server_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + // After the client sends Finished, inject an app data record + // with the handshake key. This should produce an alert. + uint8_t buf[] = {'a', 'b', 'c'}; + auto spec = capturer.spec(0); + ASSERT_NE(nullptr, spec.get()); + ASSERT_EQ(2, spec->epoch()); + ASSERT_TRUE(client_->SendEncryptedRecord( + spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002, + kTlsApplicationDataType, DataBuffer(buf, sizeof(buf)))); + + // Now have the server consume the bogus message. + server_->ExpectSendAlert(illegal_parameter, kTlsAlertFatal); + server_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state()); + EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, PORT_GetError()); +} + +TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) { + StartConnect(); + TlsSendCipherSpecCapturer capturer(client_); + client_->Handshake(); + server_->Handshake(); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + server_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + // Inject a new bogus handshake record, which the server responds + // to by just ACKing the original one (we ignore the contents). + uint8_t buf[] = {'a', 'b', 'c'}; + auto spec = capturer.spec(0); + ASSERT_NE(nullptr, spec.get()); + ASSERT_EQ(2, spec->epoch()); + ASSERT_TRUE(client_->SendEncryptedRecord( + spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002, + kTlsHandshakeType, DataBuffer(buf, sizeof(buf)))); + server_->Handshake(); + EXPECT_EQ(2UL, server_filters_.ack_->count()); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + CheckAcks(server_filters_, 1, {0x0002000000000000ULL}); +} + +// Shrink the MTU down so that certs get split and then swap the first and +// second pieces of the server certificate. +TEST_F(TlsReorderDatagram13, ReorderServerCertificate) { + StartConnect(); + ShrinkPostServerHelloMtu(); + client_->Handshake(); + // Drop the entire handshake flight so we can reorder. + server_filters_.drop_->Reset(0xff); + server_->Handshake(); + // Check that things got split. + EXPECT_EQ(6UL, + server_filters_.records_->count()); // CH, EE, CT1, CT2, CV, FIN + // Now re-send things in a different order. + ReSend(TlsAgent::SERVER, std::vector<size_t>{0, 1, 3, 2, 4, 5}); + // Clear. + server_filters_.drop_->Disable(); + server_filters_.records_->Clear(); + // Wait for client to send ACK. + ShiftDtlsTimers(); + CheckedHandshakeSendReceive(); + EXPECT_EQ(2UL, server_filters_.records_->count()); // ACK + Data + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); +} + +TEST_F(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + // Send the client's first flight of zero RTT data. + ZeroRttSendReceive(true, true); + // Now send another client application data record but + // capture it. + client_filters_.records_->Clear(); + client_filters_.drop_->Reset(0xff); + const char* k0RttData = "123456"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + EXPECT_EQ(k0RttDataLen, rv); + EXPECT_EQ(1UL, client_filters_.records_->count()); // data + server_->Handshake(); + client_->Handshake(); + ExpectEarlyDataAccepted(true); + // The server still hasn't received anything at this point. + EXPECT_EQ(3UL, client_filters_.records_->count()); // data, EOED, FIN + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + // Now re-send the client's messages: EOED, data, FIN + ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 0, 2})); + server_->Handshake(); + CheckConnected(); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + uint8_t buf[8]; + rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(-1, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +TEST_F(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) { + SetupForZeroRtt(); + SetFilters(); + std::cerr << "Starting second handshake" << std::endl; + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + // Send the client's first flight of zero RTT data. + ZeroRttSendReceive(true, true); + // Now send another client application data record but + // capture it. + client_filters_.records_->Clear(); + client_filters_.drop_->Reset(0xff); + const char* k0RttData = "123456"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + EXPECT_EQ(k0RttDataLen, rv); + EXPECT_EQ(1UL, client_filters_.records_->count()); // data + server_->Handshake(); + client_->Handshake(); + ExpectEarlyDataAccepted(true); + // The server still hasn't received anything at this point. + EXPECT_EQ(3UL, client_filters_.records_->count()); // EOED, FIN, Data + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + // Now re-send the client's messages: EOED, FIN, Data + ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 2, 0})); + server_->Handshake(); + CheckConnected(); + CheckAcks(server_filters_, 0, {0x0002000000000000ULL}); + uint8_t buf[8]; + rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(-1, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + static void GetCipherAndLimit(uint16_t version, uint16_t* cipher, uint64_t* limit = nullptr) { uint64_t l; @@ -111,7 +836,6 @@ TEST_P(TlsConnectDatagram12Plus, MissAWindow) { GetCipherAndLimit(version_, &cipher); server_->EnableSingleCipher(cipher); Connect(); - EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 0)); SendReceive(); } @@ -129,5 +853,7 @@ TEST_P(TlsConnectDatagram12Plus, MissAWindowAndOne) { INSTANTIATE_TEST_CASE_P(Datagram12Plus, TlsConnectDatagram12Plus, TlsConnectTestBase::kTlsV12Plus); +INSTANTIATE_TEST_CASE_P(DatagramPre13, TlsConnectDatagramPre13, + TlsConnectTestBase::kTlsV11V12); } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc index 1e406b6c2..e0f8b1f55 100644 --- a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc @@ -193,7 +193,9 @@ TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) { class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter { public: - TlsKeyExchangeGroupCapture() : group_(ssl_grp_none) {} + TlsKeyExchangeGroupCapture() + : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}), + group_(ssl_grp_none) {} SSLNamedGroup group() const { return group_; } @@ -201,10 +203,6 @@ class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter { virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, const DataBuffer &input, DataBuffer *output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - uint32_t value = 0; EXPECT_TRUE(input.Read(0, 1, &value)); EXPECT_EQ(3U, value) << "curve type has to be 3"; @@ -518,16 +516,12 @@ TEST_P(TlsKeyExchangeTest13, MultipleClientShares) { // Replace the point in the client key exchange message with an empty one class ECCClientKEXFilter : public TlsHandshakeFilter { public: - ECCClientKEXFilter() {} + ECCClientKEXFilter() : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, const DataBuffer &input, DataBuffer *output) { - if (header.handshake_type() != kTlsHandshakeClientKeyExchange) { - return KEEP; - } - // Replace the client key exchange message with an empty point output->Allocate(1); output->Write(0, 0U, 1); // set point length 0 @@ -538,20 +532,16 @@ class ECCClientKEXFilter : public TlsHandshakeFilter { // Replace the point in the server key exchange message with an empty one class ECCServerKEXFilter : public TlsHandshakeFilter { public: - ECCServerKEXFilter() {} + ECCServerKEXFilter() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {} protected: virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, const DataBuffer &input, DataBuffer *output) { - if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { - return KEEP; - } - // Replace the server key exchange message with an empty point output->Allocate(4); output->Write(0, 3U, 1); // named curve - uint32_t curve; + uint32_t curve = 0; EXPECT_TRUE(input.Read(1, 2, &curve)); // get curve id output->Write(1, curve, 2); // write curve id output->Write(3, 0U, 1); // point length 0 diff --git a/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc index be407b42e..c42883eb7 100644 --- a/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc @@ -118,7 +118,6 @@ int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr, TEST_P(TlsConnectTls13, EarlyExporter) { SetupForZeroRtt(); - ExpectAlert(client_, kTlsAlertEndOfEarlyData); client_->Set0RttEnabled(true); server_->Set0RttEnabled(true); ExpectResumption(RESUME_TICKET); diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc index d15139419..4142ab07a 100644 --- a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -61,60 +61,14 @@ class TlsExtensionDamager : public TlsExtensionFilter { size_t index_; }; -class TlsExtensionInjector : public TlsHandshakeFilter { - public: - TlsExtensionInjector(uint16_t ext, DataBuffer& data) - : extension_(ext), data_(data) {} - - virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, - const DataBuffer& input, - DataBuffer* output) { - TlsParser parser(input); - if (!TlsExtensionFilter::FindExtensions(&parser, header)) { - return KEEP; - } - size_t offset = parser.consumed(); - - *output = input; - - // Increase the size of the extensions. - uint16_t ext_len; - memcpy(&ext_len, output->data() + offset, sizeof(ext_len)); - ext_len = htons(ntohs(ext_len) + data_.len() + 4); - memcpy(output->data() + offset, &ext_len, sizeof(ext_len)); - - // Insert the extension type and length. - DataBuffer type_length; - type_length.Allocate(4); - type_length.Write(0, extension_, 2); - type_length.Write(2, data_.len(), 2); - output->Splice(type_length, offset + 2); - - // Insert the payload. - if (data_.len() > 0) { - output->Splice(data_, offset + 6); - } - - return CHANGE; - } - - private: - const uint16_t extension_; - const DataBuffer data_; -}; - class TlsExtensionAppender : public TlsHandshakeFilter { public: TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data) - : handshake_type_(handshake_type), extension_(ext), data_(data) {} + : TlsHandshakeFilter({handshake_type}), extension_(ext), data_(data) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() != handshake_type_) { - return KEEP; - } - TlsParser parser(input); if (!TlsExtensionFilter::FindExtensions(&parser, header)) { return KEEP; @@ -159,7 +113,6 @@ class TlsExtensionAppender : public TlsHandshakeFilter { return true; } - const uint8_t handshake_type_; const uint16_t extension_; const DataBuffer data_; }; @@ -200,8 +153,7 @@ class TlsExtensionTestBase : public TlsConnectTestBase { client_->ConfigNamedGroups(client_groups); server_->ConfigNamedGroups(server_groups); EnsureTlsSetup(); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); // Send ClientHello server_->Handshake(); // Send HRR. client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(type)); @@ -1009,7 +961,6 @@ class TlsBogusExtensionTest : public TlsConnectTestBase, std::make_shared<TlsExtensionAppender>(message, extension, empty); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { server_->SetTlsRecordFilter(filter); - filter->EnableDecryption(); } else { server_->SetPacketFilter(filter); } @@ -1032,17 +983,20 @@ class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest { class TlsBogusExtensionTest13 : public TlsBogusExtensionTest { protected: void ConnectAndFail(uint8_t message) override { - if (message == kTlsHandshakeHelloRetryRequest) { + if (message != kTlsHandshakeServerHello) { ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension); return; } - client_->StartConnect(); - server_->StartConnect(); + FailWithAlert(kTlsAlertUnsupportedExtension); + } + + void FailWithAlert(uint8_t alert) { + StartConnect(); client_->Handshake(); // ClientHello server_->Handshake(); // ServerHello - client_->ExpectSendAlert(kTlsAlertUnsupportedExtension); + client_->ExpectSendAlert(alert); client_->Handshake(); if (variant_ == ssl_variant_stream) { server_->ExpectSendAlert(kTlsAlertBadRecordMac); @@ -1067,9 +1021,12 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) { Run(kTlsHandshakeCertificate); } +// It's perfectly valid to set unknown extensions in CertificateRequest. TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) { server_->RequestClientAuth(false); - Run(kTlsHandshakeCertificateRequest); + AddFilter(kTlsHandshakeCertificateRequest, 0xff); + ConnectExpectAlert(client_, kTlsAlertDecryptError); + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); } TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) { @@ -1079,10 +1036,6 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) { Run(kTlsHandshakeHelloRetryRequest); } -TEST_P(TlsBogusExtensionTest13, AddVersionExtensionServerHello) { - Run(kTlsHandshakeServerHello, ssl_tls13_supported_versions_xtn); -} - TEST_P(TlsBogusExtensionTest13, AddVersionExtensionEncryptedExtensions) { Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn); } @@ -1096,13 +1049,6 @@ TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificateRequest) { Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn); } -TEST_P(TlsBogusExtensionTest13, AddVersionExtensionHelloRetryRequest) { - static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; - server_->ConfigNamedGroups(groups); - - Run(kTlsHandshakeHelloRetryRequest, ssl_tls13_supported_versions_xtn); -} - // NewSessionTicket allows unknown extensions AND it isn't protected by the // Finished. So adding an unknown extension doesn't cause an error. TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) { diff --git a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc index 44cacce46..64b824786 100644 --- a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc @@ -51,10 +51,16 @@ class RecordFragmenter : public PacketFilter { while (parser.remaining()) { TlsHandshakeFilter::HandshakeHeader handshake_header; DataBuffer handshake_body; - if (!handshake_header.Parse(&parser, record_header, &handshake_body)) { + bool complete = false; + if (!handshake_header.Parse(&parser, record_header, DataBuffer(), + &handshake_body, &complete)) { ADD_FAILURE() << "couldn't parse handshake header"; return false; } + if (!complete) { + ADD_FAILURE() << "don't want to deal with fragmented messages"; + return false; + } DataBuffer record_fragment; // We can't fragment handshake records that are too small. @@ -82,7 +88,7 @@ class RecordFragmenter : public PacketFilter { while (parser.remaining()) { TlsRecordHeader header; DataBuffer record; - if (!header.Parse(&parser, &record)) { + if (!header.Parse(0, &parser, &record)) { ADD_FAILURE() << "bad record header"; return false; } diff --git a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc index 1587b66de..ab4c0eab7 100644 --- a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc @@ -47,9 +47,9 @@ class TlsApplicationDataRecorder : public TlsRecordFilter { // Ensure that ssl_Time() returns a constant value. FUZZ_F(TlsFuzzTest, SSL_Time_Constant) { - PRUint32 now = ssl_Time(); + PRUint32 now = ssl_TimeSec(); PR_Sleep(PR_SecondsToInterval(2)); - EXPECT_EQ(ssl_Time(), now); + EXPECT_EQ(ssl_TimeSec(), now); } // Check that due to the deterministic PRNG we derive @@ -215,58 +215,6 @@ FUZZ_P(TlsConnectGeneric, SessionTicketResumption) { SendReceive(); } -class TlsSessionTicketMacDamager : public TlsExtensionFilter { - public: - TlsSessionTicketMacDamager() {} - virtual PacketFilter::Action FilterExtension(uint16_t extension_type, - const DataBuffer& input, - DataBuffer* output) { - if (extension_type != ssl_session_ticket_xtn && - extension_type != ssl_tls13_pre_shared_key_xtn) { - return KEEP; - } - - *output = input; - - // Handle everything before TLS 1.3. - if (extension_type == ssl_session_ticket_xtn) { - // Modify the last byte of the MAC. - output->data()[output->len() - 1] ^= 0xff; - } - - // Handle TLS 1.3. - if (extension_type == ssl_tls13_pre_shared_key_xtn) { - TlsParser parser(input); - - uint32_t ids_len; - EXPECT_TRUE(parser.Read(&ids_len, 2) && ids_len > 0); - - uint32_t ticket_len; - EXPECT_TRUE(parser.Read(&ticket_len, 2) && ticket_len > 0); - - // Modify the last byte of the MAC. - output->data()[2 + 2 + ticket_len - 1] ^= 0xff; - } - - return CHANGE; - } -}; - -// Check that session ticket resumption works with a bad MAC. -FUZZ_P(TlsConnectGeneric, SessionTicketResumptionBadMac) { - ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - Connect(); - SendReceive(); - - Reset(); - ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - ExpectResumption(RESUME_TICKET); - - client_->SetPacketFilter(std::make_shared<TlsSessionTicketMacDamager>()); - Connect(); - SendReceive(); -} - // Check that session tickets are not encrypted. FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) { ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); @@ -276,10 +224,13 @@ FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) { server_->SetPacketFilter(i1); Connect(); + std::cerr << "ticket" << i1->buffer() << std::endl; size_t offset = 4; /* lifetime */ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { - offset += 1 + 1 + /* ke_modes */ - 1 + 1; /* auth_modes */ + offset += 4; /* ticket_age_add */ + uint32_t nonce_len = 0; + EXPECT_TRUE(i1->buffer().Read(offset, 1, &nonce_len)); + offset += 1 + nonce_len; } offset += 2 + /* ticket length */ 2; /* TLS_EX_SESS_TICKET_VERSION */ diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.cc b/security/nss/gtests/ssl_gtest/ssl_gtest.cc index cd10076b8..2fff9d7cb 100644 --- a/security/nss/gtests/ssl_gtest/ssl_gtest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_gtest.cc @@ -6,6 +6,7 @@ #include <cstdlib> #include "test_io.h" +#include "databuffer.h" #define GTEST_HAS_RTTI 0 #include "gtest/gtest.h" @@ -28,6 +29,7 @@ int main(int argc, char** argv) { ++i; } else if (!strcmp(argv[i], "-v")) { g_ssl_gtest_verbose = true; + nss_test::DataBuffer::SetLogLimit(16384); } } diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp index 8cd7d1009..e2a8d830a 100644 --- a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp +++ b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp @@ -11,6 +11,7 @@ 'target_name': 'ssl_gtest', 'type': 'executable', 'sources': [ + 'bloomfilter_unittest.cc', 'libssl_internals.c', 'selfencrypt_unittest.cc', 'ssl_0rtt_unittest.cc', @@ -18,6 +19,7 @@ 'ssl_auth_unittest.cc', 'ssl_cert_ext_unittest.cc', 'ssl_ciphersuite_unittest.cc', + 'ssl_custext_unittest.cc', 'ssl_damage_unittest.cc', 'ssl_dhe_unittest.cc', 'ssl_drop_unittest.cc', @@ -30,11 +32,16 @@ 'ssl_gather_unittest.cc', 'ssl_gtest.cc', 'ssl_hrr_unittest.cc', + 'ssl_keylog_unittest.cc', + 'ssl_keyupdate_unittest.cc', 'ssl_loopback_unittest.cc', + 'ssl_misc_unittest.cc', 'ssl_record_unittest.cc', 'ssl_resumption_unittest.cc', + 'ssl_renegotiation_unittest.cc', 'ssl_skip_unittest.cc', 'ssl_staticrsa_unittest.cc', + 'ssl_tls13compat_unittest.cc', 'ssl_v2_client_hello_unittest.cc', 'ssl_version_unittest.cc', 'ssl_versionpolicy_unittest.cc', diff --git a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc index 39055f641..93e19a720 100644 --- a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc @@ -187,6 +187,590 @@ TEST_P(TlsConnectTls13, RetryWithSameKeyShare) { EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code()); } +// Here we modify the second ClientHello so that the client retries with the +// same shares, even though the server wanted something else. +TEST_P(TlsConnectTls13, RetryWithTwoShares) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + client_->SetPacketFilter(std::make_shared<KeyShareReplayer>()); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(groups); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code()); + EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code()); +} + +TEST_P(TlsConnectTls13, RetryCallbackAccept) { + EnsureTlsSetup(); + + auto accept_hello = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<bool*>(arg); + *called = true; + + EXPECT_TRUE(firstHello); + EXPECT_EQ(0U, clientTokenLen); + return ssl_hello_retry_accept; + }; + + bool cb_run = false; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + accept_hello, &cb_run)); + Connect(); + EXPECT_TRUE(cb_run); +} + +TEST_P(TlsConnectTls13, RetryCallbackAcceptGroupMismatch) { + EnsureTlsSetup(); + + auto accept_hello_twice = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, + unsigned int appTokenMax, void* arg) { + auto* called = reinterpret_cast<size_t*>(arg); + ++*called; + + EXPECT_EQ(0U, clientTokenLen); + return ssl_hello_retry_accept; + }; + + auto capture = std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn); + capture->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + server_->SetPacketFilter(capture); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + size_t cb_run = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), accept_hello_twice, &cb_run)); + Connect(); + EXPECT_EQ(2U, cb_run); + EXPECT_TRUE(capture->captured()) << "expected a cookie in HelloRetryRequest"; +} + +TEST_P(TlsConnectTls13, RetryCallbackFail) { + EnsureTlsSetup(); + + auto fail_hello = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<bool*>(arg); + *called = true; + + EXPECT_TRUE(firstHello); + EXPECT_EQ(0U, clientTokenLen); + return ssl_hello_retry_fail; + }; + + bool cb_run = false; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + fail_hello, &cb_run)); + ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); + server_->CheckErrorCode(SSL_ERROR_APPLICATION_ABORT); + EXPECT_TRUE(cb_run); +} + +// Asking for retry twice isn't allowed. +TEST_P(TlsConnectTls13, RetryCallbackRequestHrrTwice) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + return ssl_hello_retry_request; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +// Accepting the CH and modifying the token isn't allowed. +TEST_P(TlsConnectTls13, RetryCallbackAcceptAndSetToken) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + *appTokenLen = 1; + return ssl_hello_retry_accept; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +// As above, but with reject. +TEST_P(TlsConnectTls13, RetryCallbackRejectAndSetToken) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + *appTokenLen = 1; + return ssl_hello_retry_fail; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +// This is a (pretend) buffer overflow. +TEST_P(TlsConnectTls13, RetryCallbackSetTooLargeToken) { + EnsureTlsSetup(); + + auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken, + unsigned int clientTokenLen, PRUint8* appToken, + unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) -> SSLHelloRetryRequestAction { + *appTokenLen = appTokenMax + 1; + return ssl_hello_retry_accept; + }; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + bad_callback, NULL)); + ConnectExpectAlert(server_, kTlsAlertInternalError); + server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR); +} + +SSLHelloRetryRequestAction RetryHello(PRBool firstHello, + const PRUint8* clientToken, + unsigned int clientTokenLen, + PRUint8* appToken, + unsigned int* appTokenLen, + unsigned int appTokenMax, void* arg) { + auto* called = reinterpret_cast<size_t*>(arg); + ++*called; + + EXPECT_EQ(0U, clientTokenLen); + return firstHello ? ssl_hello_retry_request : ssl_hello_retry_accept; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetry) { + EnsureTlsSetup(); + + auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>( + ssl_hs_hello_retry_request); + auto capture_key_share = + std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + std::vector<std::shared_ptr<PacketFilter>> chain = {capture_hrr, + capture_key_share}; + server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(chain)); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + + // Do the first message exchange. + StartConnect(); + client_->Handshake(); + server_->Handshake(); + + EXPECT_EQ(1U, cb_called) << "callback should be called once here"; + EXPECT_LT(0U, capture_hrr->buffer().len()) << "HelloRetryRequest expected"; + EXPECT_FALSE(capture_key_share->captured()) + << "no key_share extension expected"; + + auto capture_cookie = + std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn); + client_->SetPacketFilter(capture_cookie); + + Handshake(); + CheckConnected(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_cookie->captured()) << "should have a cookie"; +} + +static size_t CountShares(const DataBuffer& key_share) { + size_t count = 0; + uint32_t len = 0; + size_t offset = 2; + + EXPECT_TRUE(key_share.Read(0, 2, &len)); + EXPECT_EQ(key_share.len() - 2, len); + while (offset < key_share.len()) { + offset += 2; // Skip KeyShareEntry.group + EXPECT_TRUE(key_share.Read(offset, 2, &len)); + offset += 2 + len; // Skip KeyShareEntry.key_exchange + ++count; + } + return count; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + + auto capture_server = + std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + server_->SetPacketFilter(capture_server); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + + // Do the first message exchange. + StartConnect(); + client_->Handshake(); + server_->Handshake(); + + EXPECT_EQ(1U, cb_called) << "callback should be called once here"; + EXPECT_FALSE(capture_server->captured()) + << "no key_share extension expected from server"; + + auto capture_client_2nd = + std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + client_->SetPacketFilter(capture_client_2nd); + + Handshake(); + CheckConnected(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_client_2nd->captured()) << "client should send key_share"; + EXPECT_EQ(2U, CountShares(capture_client_2nd->extension())) + << "client should still send two shares"; +} + +// The callback should be run even if we have another reason to send +// HelloRetryRequest. In this case, the server sends HRR because the server +// wants a P-384 key share and the client didn't offer one. +TEST_P(TlsConnectTls13, RetryCallbackRetryWithGroupMismatch) { + EnsureTlsSetup(); + + auto capture_cookie = + std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn); + capture_cookie->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + auto capture_key_share = + std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit{capture_cookie, capture_key_share})); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + Connect(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_cookie->captured()) << "cookie expected"; + EXPECT_TRUE(capture_key_share->captured()) << "key_share expected"; +} + +static const uint8_t kApplicationToken[] = {0x92, 0x44, 0x00}; + +SSLHelloRetryRequestAction RetryHelloWithToken( + PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen, + PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<size_t*>(arg); + ++*called; + + if (firstHello) { + memcpy(appToken, kApplicationToken, sizeof(kApplicationToken)); + *appTokenLen = sizeof(kApplicationToken); + return ssl_hello_retry_request; + } + + EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)), + DataBuffer(clientToken, static_cast<size_t>(clientTokenLen))); + return ssl_hello_retry_accept; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetryWithToken) { + EnsureTlsSetup(); + + auto capture_key_share = + std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + server_->SetPacketFilter(capture_key_share); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, + SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHelloWithToken, &cb_called)); + Connect(); + EXPECT_EQ(2U, cb_called); + EXPECT_FALSE(capture_key_share->captured()) << "no key share expected"; +} + +TEST_P(TlsConnectTls13, RetryCallbackRetryWithTokenAndGroupMismatch) { + EnsureTlsSetup(); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + auto capture_key_share = + std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + server_->SetPacketFilter(capture_key_share); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, + SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHelloWithToken, &cb_called)); + Connect(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(capture_key_share->captured()) << "key share expected"; +} + +SSLHelloRetryRequestAction CheckTicketToken( + PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen, + PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax, + void* arg) { + auto* called = reinterpret_cast<bool*>(arg); + *called = true; + + EXPECT_TRUE(firstHello); + EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)), + DataBuffer(clientToken, static_cast<size_t>(clientTokenLen))); + return ssl_hello_retry_accept; +} + +// Stream because SSL_SendSessionTicket only supports that. +TEST_F(TlsConnectStreamTls13, RetryCallbackWithSessionTicketToken) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + EXPECT_EQ(SECSuccess, + SSL_SendSessionTicket(server_->ssl_fd(), kApplicationToken, + sizeof(kApplicationToken))); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + + bool cb_run = false; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback( + server_->ssl_fd(), CheckTicketToken, &cb_run)); + Connect(); + EXPECT_TRUE(cb_run); +} + +void TriggerHelloRetryRequest(std::shared_ptr<TlsAgent>& client, + std::shared_ptr<TlsAgent>& server) { + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server->ssl_fd(), + RetryHello, &cb_called)); + + // Start the handshake. + client->StartConnect(); + server->StartConnect(); + client->Handshake(); + server->Handshake(); + EXPECT_EQ(1U, cb_called); +} + +TEST_P(TlsConnectTls13, RetryStateless) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + Handshake(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, RetryStatefulDropCookie) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + client_->SetPacketFilter( + std::make_shared<TlsExtensionDropper>(ssl_tls13_cookie_xtn)); + + ExpectAlert(server_, kTlsAlertMissingExtension); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_MISSING_COOKIE_EXTENSION); +} + +// Stream only because DTLS drops bad packets. +TEST_F(TlsConnectStreamTls13, RetryStatelessDamageFirstClientHello) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + auto damage_ch = std::make_shared<TlsExtensionInjector>(0xfff3, DataBuffer()); + client_->SetPacketFilter(damage_ch); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + // Key exchange fails when the handshake continues because client and server + // disagree about the transcript. + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + +TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + auto damage_ch = std::make_shared<TlsExtensionInjector>(0xfff3, DataBuffer()); + client_->SetPacketFilter(damage_ch); + + // Key exchange fails when the handshake continues because client and server + // disagree about the transcript. + client_->ExpectSendAlert(kTlsAlertBadRecordMac); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + +// Read the cipher suite from the HRR and disable it on the identified agent. +static void DisableSuiteFromHrr( + std::shared_ptr<TlsAgent>& agent, + std::shared_ptr<TlsInspectorRecordHandshakeMessage>& capture_hrr) { + uint32_t tmp; + size_t offset = 2 + 32; // skip version + server_random + ASSERT_TRUE( + capture_hrr->buffer().Read(offset, 1, &tmp)); // session_id length + EXPECT_EQ(0U, tmp); + offset += 1 + tmp; + ASSERT_TRUE(capture_hrr->buffer().Read(offset, 2, &tmp)); // suite + EXPECT_EQ( + SECSuccess, + SSL_CipherPrefSet(agent->ssl_fd(), static_cast<uint16_t>(tmp), PR_FALSE)); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteClient) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>( + ssl_hs_hello_retry_request); + server_->SetPacketFilter(capture_hrr); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + DisableSuiteFromHrr(client_, capture_hrr); + + // The client thinks that the HelloRetryRequest is bad, even though its + // because it changed its mind about the cipher suite. + ExpectAlert(client_, kTlsAlertIllegalParameter); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteServer) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>( + ssl_hs_hello_retry_request); + server_->SetPacketFilter(capture_hrr); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + DisableSuiteFromHrr(server_, capture_hrr); + + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableGroupClient) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + client_->ConfigNamedGroups(groups); + + // We're into undefined behavior on the client side, but - at the point this + // test was written - the client here doesn't amend its key shares because the + // server doesn't ask it to. The server notices that the key share (x25519) + // doesn't match the negotiated group (P-384) and objects. + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessDisableGroupServer) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + MakeNewServer(); + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectTls13, RetryStatelessBadCookie) { + ConfigureSelfEncrypt(); + EnsureTlsSetup(); + + TriggerHelloRetryRequest(client_, server_); + + // Now replace the self-encrypt MAC key with a garbage key. + static const uint8_t bad_hmac_key[32] = {0}; + SECItem key_item = {siBuffer, const_cast<uint8_t*>(bad_hmac_key), + sizeof(bad_hmac_key)}; + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + PK11SymKey* hmac_key = + PK11_ImportSymKey(slot.get(), CKM_SHA256_HMAC, PK11_OriginUnwrap, + CKA_SIGN, &key_item, nullptr); + ASSERT_NE(nullptr, hmac_key); + SSLInt_SetSelfEncryptMacKey(hmac_key); // Passes ownership. + + MakeNewServer(); + + ExpectAlert(server_, kTlsAlertIllegalParameter); + Handshake(); + server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Stream because the server doesn't consume the alert and terminate. +TEST_F(TlsConnectStreamTls13, RetryWithDifferentCipherSuite) { + EnsureTlsSetup(); + // Force a HelloRetryRequest. + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + // Then switch out the default suite (TLS_AES_128_GCM_SHA256). + server_->SetPacketFilter(std::make_shared<SelectedCipherSuiteReplacer>( + TLS_CHACHA20_POLY1305_SHA256)); + + client_->ExpectSendAlert(kTlsAlertIllegalParameter); + server_->ExpectSendAlert(kTlsAlertBadRecordMac); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); + EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code()); +} + // This tests that the second attempt at sending a ClientHello (after receiving // a HelloRetryRequest) is correctly retransmitted. TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) { @@ -233,6 +817,54 @@ TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrrExtraShares) { CheckKEXDetails(client_groups, client_groups); } +// The callback should be run even if we have another reason to send +// HelloRetryRequest. In this case, the server sends HRR because the server +// wants an X25519 key share and the client didn't offer one. +TEST_P(TlsKeyExchange13, + RetryCallbackRetryWithGroupMismatchAndAdditionalShares) { + EnsureKeyShareSetup(); + + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_curve25519}; + server_->ConfigNamedGroups(server_groups); + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + + auto capture_server = + std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn); + capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest}); + server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit{capture_hrr_, capture_server})); + + size_t cb_called = 0; + EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(), + RetryHello, &cb_called)); + + // Do the first message exchange. + StartConnect(); + client_->Handshake(); + server_->Handshake(); + + EXPECT_EQ(1U, cb_called) << "callback should be called once here"; + EXPECT_TRUE(capture_server->captured()) << "key_share extension expected"; + + uint32_t server_group = 0; + EXPECT_TRUE(capture_server->extension().Read(0, 2, &server_group)); + EXPECT_EQ(ssl_grp_ec_curve25519, static_cast<SSLNamedGroup>(server_group)); + + Handshake(); + CheckConnected(); + EXPECT_EQ(2U, cb_called); + EXPECT_TRUE(shares_capture2_->captured()) << "client should send shares"; + + CheckKeys(); + static const std::vector<SSLNamedGroup> client_shares( + client_groups.begin(), client_groups.begin() + 2); + CheckKEXDetails(client_groups, client_shares, server_groups[0]); +} + TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) { EnsureTlsSetup(); client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, @@ -245,8 +877,7 @@ TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) { static const std::vector<SSLNamedGroup> server_groups = { ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1}; server_->ConfigNamedGroups(server_groups); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); server_->Handshake(); @@ -276,15 +907,30 @@ class HelloRetryRequestAgentTest : public TlsAgentTestClient { void MakeCannedHrr(const uint8_t* body, size_t len, DataBuffer* hrr_record, uint32_t seq_num = 0) const { DataBuffer hrr_data; - hrr_data.Allocate(len + 4); + const uint8_t ssl_hello_retry_random[] = { + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, + 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, + 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C}; + + hrr_data.Allocate(len + 6); size_t i = 0; + i = hrr_data.Write(i, 0x0303, 2); + i = hrr_data.Write(i, ssl_hello_retry_random, + sizeof(ssl_hello_retry_random)); + i = hrr_data.Write(i, static_cast<uint32_t>(0), 1); // session_id + i = hrr_data.Write(i, TLS_AES_128_GCM_SHA256, 2); + i = hrr_data.Write(i, ssl_compression_null, 1); + // Add extensions. First a length, which includes the supported version. + i = hrr_data.Write(i, static_cast<uint32_t>(len) + 6, 2); + // Now the supported version. + i = hrr_data.Write(i, ssl_tls13_supported_versions_xtn, 2); + i = hrr_data.Write(i, 2, 2); i = hrr_data.Write(i, 0x7f00 | TLS_1_3_DRAFT_VERSION, 2); - i = hrr_data.Write(i, static_cast<uint32_t>(len), 2); if (len) { hrr_data.Write(i, body, len); } DataBuffer hrr; - MakeHandshakeMessage(kTlsHandshakeHelloRetryRequest, hrr_data.data(), + MakeHandshakeMessage(kTlsHandshakeServerHello, hrr_data.data(), hrr_data.len(), &hrr, seq_num); MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, hrr.data(), hrr.len(), hrr_record, seq_num); @@ -334,28 +980,6 @@ TEST_P(HelloRetryRequestAgentTest, HandleNoopHelloRetryRequest) { SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST); } -TEST_P(HelloRetryRequestAgentTest, HandleHelloRetryRequestCookie) { - const uint8_t canned_cookie_hrr[] = { - static_cast<uint8_t>(ssl_tls13_cookie_xtn >> 8), - static_cast<uint8_t>(ssl_tls13_cookie_xtn), - 0, - 5, // length of cookie extension - 0, - 3, // cookie value length - 0xc0, - 0x0c, - 0x13}; - DataBuffer hrr; - MakeCannedHrr(canned_cookie_hrr, sizeof(canned_cookie_hrr), &hrr); - auto capture = std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn); - agent_->SetPacketFilter(capture); - ProcessMessage(hrr, TlsAgent::STATE_CONNECTING); - const size_t cookie_pos = 2 + 2; // cookie_xtn, extension len - DataBuffer cookie(canned_cookie_hrr + cookie_pos, - sizeof(canned_cookie_hrr) - cookie_pos); - EXPECT_EQ(cookie, capture->extension()); -} - INSTANTIATE_TEST_CASE_P(HelloRetryRequestAgentTests, HelloRetryRequestAgentTest, ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV13)); diff --git a/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc new file mode 100644 index 000000000..8ed342305 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc @@ -0,0 +1,118 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifdef NSS_ALLOW_SSLKEYLOGFILE + +#include <cstdlib> +#include <fstream> +#include <sstream> + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +static const std::string keylog_file_path = "keylog.txt"; +static const std::string keylog_env = "SSLKEYLOGFILE=" + keylog_file_path; + +class KeyLogFileTest : public TlsConnectGeneric { + public: + void SetUp() { + TlsConnectTestBase::SetUp(); + // Remove previous results (if any). + (void)remove(keylog_file_path.c_str()); + PR_SetEnv(keylog_env.c_str()); + } + + void CheckKeyLog() { + std::ifstream f(keylog_file_path); + std::map<std::string, size_t> labels; + std::set<std::string> client_randoms; + for (std::string line; std::getline(f, line);) { + if (line[0] == '#') { + continue; + } + + std::istringstream iss(line); + std::string label, client_random, secret; + iss >> label >> client_random >> secret; + + ASSERT_EQ(64U, client_random.size()); + client_randoms.insert(client_random); + labels[label]++; + } + + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_EQ(1U, client_randoms.size()); + } else { + /* two handshakes for 0-RTT */ + ASSERT_EQ(2U, client_randoms.size()); + } + + // Every entry occurs twice (one log from server, one from client). + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_EQ(2U, labels["CLIENT_RANDOM"]); + } else { + ASSERT_EQ(2U, labels["CLIENT_EARLY_TRAFFIC_SECRET"]); + ASSERT_EQ(2U, labels["EARLY_EXPORTER_SECRET"]); + ASSERT_EQ(4U, labels["CLIENT_HANDSHAKE_TRAFFIC_SECRET"]); + ASSERT_EQ(4U, labels["SERVER_HANDSHAKE_TRAFFIC_SECRET"]); + ASSERT_EQ(4U, labels["CLIENT_TRAFFIC_SECRET_0"]); + ASSERT_EQ(4U, labels["SERVER_TRAFFIC_SECRET_0"]); + ASSERT_EQ(4U, labels["EXPORTER_SECRET"]); + } + } + + void ConnectAndCheck() { + // This is a child process, ensure that error messages immediately + // propagate or else it will not be visible. + ::testing::GTEST_FLAG(throw_on_failure) = true; + + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); + } else { + Connect(); + } + CheckKeyLog(); + _exit(0); + } +}; + +// Tests are run in a separate process to ensure that NSS is not initialized yet +// and can process the SSLKEYLOGFILE environment variable. + +TEST_P(KeyLogFileTest, KeyLogFile) { + testing::GTEST_FLAG(death_test_style) = "threadsafe"; + + ASSERT_EXIT(ConnectAndCheck(), ::testing::ExitedWithCode(0), ""); +} + +INSTANTIATE_TEST_CASE_P( + KeyLogFileDTLS12, KeyLogFileTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11V12)); +INSTANTIATE_TEST_CASE_P( + KeyLogFileTLS12, KeyLogFileTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV10ToV12)); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_CASE_P( + KeyLogFileTLS13, KeyLogFileTest, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsV13)); +#endif + +} // namespace nss_test + +#endif // NSS_ALLOW_SSLKEYLOGFILE diff --git a/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc new file mode 100644 index 000000000..d03775c25 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc @@ -0,0 +1,178 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +// All stream only tests; DTLS isn't supported yet. + +TEST_F(TlsConnectTest, KeyUpdateClient) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(4, 3); +} + +TEST_F(TlsConnectTest, KeyUpdateClientRequestUpdate) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE)); + // SendReceive() only gives each peer one chance to read. This isn't enough + // when the read on one side generates another handshake message. A second + // read gives each peer an extra chance to consume the KeyUpdate. + SendReceive(50); + SendReceive(60); // Cumulative count. + CheckEpochs(4, 4); +} + +TEST_F(TlsConnectTest, KeyUpdateServer) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(3, 4); +} + +TEST_F(TlsConnectTest, KeyUpdateServerRequestUpdate) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(4, 4); +} + +TEST_F(TlsConnectTest, KeyUpdateConsecutiveRequests) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + // The server should have updated twice, but the client should have declined + // to respond to the second request from the server, since it doesn't send + // anything in between those two requests. + CheckEpochs(4, 5); +} + +// Check that a local update can be immediately followed by a remotely triggered +// update even if there is no use of the keys. +TEST_F(TlsConnectTest, KeyUpdateLocalUpdateThenConsecutiveRequests) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + // This should trigger an update on the client. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE)); + // The client should update for the first request. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + // ...but not the second. + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + // Both should have updated twice. + CheckEpochs(5, 5); +} + +TEST_F(TlsConnectTest, KeyUpdateMultiple) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(5, 6); +} + +// Both ask the other for an update, and both should react. +TEST_F(TlsConnectTest, KeyUpdateBothRequest) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE)); + SendReceive(50); + SendReceive(60); + CheckEpochs(5, 5); +} + +// If the sequence number exceeds the number of writes before an automatic +// update (currently 3/4 of the max records for the cipher suite), then the +// stack should send an update automatically (but not request one). +TEST_F(TlsConnectTest, KeyUpdateAutomaticOnWrite) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256); + + // Set this to one below the write threshold. + uint64_t threshold = (0x5aULL << 28) * 3 / 4; + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold)); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold)); + + // This should be OK. + client_->SendData(10); + server_->ReadBytes(); + + // This should cause the client to update. + client_->SendData(10); + server_->ReadBytes(); + + SendReceive(100); + CheckEpochs(4, 3); +} + +// If the sequence number exceeds a certain number of reads (currently 7/8 of +// the max records for the cipher suite), then the stack should send AND request +// an update automatically. However, the sender (client) will be above its +// automatic update threshold, so the KeyUpdate - that it sends with the old +// cipher spec - will exceed the receiver (server) automatic update threshold. +// The receiver gets a packet with a sequence number over its automatic read +// update threshold. Even though the sender has updated, the code that checks +// the sequence numbers at the receiver doesn't know this and it will request an +// update. This causes two updates: one from the sender (without requesting a +// response) and one from the receiver (which does request a response). +TEST_F(TlsConnectTest, KeyUpdateAutomaticOnRead) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256); + + // Move to right at the read threshold. Unlike the write test, we can't send + // packets because that would cause the client to update, which would spoil + // the test. + uint64_t threshold = ((0x5aULL << 28) * 7 / 8) + 1; + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold)); + EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold)); + + // This should cause the client to update, but not early enough to prevent the + // server from updating also. + client_->SendData(10); + server_->ReadBytes(); + + // Need two SendReceive() calls to ensure that the update that the server + // requested is properly generated and consumed. + SendReceive(70); + SendReceive(80); + CheckEpochs(5, 4); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc index 77703dd8e..4bc6e60ab 100644 --- a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc @@ -6,6 +6,7 @@ #include <functional> #include <memory> +#include <vector> #include "secerr.h" #include "ssl.h" #include "sslerr.h" @@ -84,13 +85,13 @@ class TlsAlertRecorder : public TlsRecordFilter { }; class HelloTruncator : public TlsHandshakeFilter { + public: + HelloTruncator() + : TlsHandshakeFilter( + {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {} PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeClientHello && - header.handshake_type() != kTlsHandshakeServerHello) { - return KEEP; - } output->Assign(input.data(), input.len() - 1); return CHANGE; } @@ -102,9 +103,9 @@ TEST_P(TlsConnectGeneric, CaptureAlertServer) { auto alert_recorder = std::make_shared<TlsAlertRecorder>(); server_->SetPacketFilter(alert_recorder); - ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + ConnectExpectAlert(server_, kTlsAlertDecodeError); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description()); + EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); } TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { @@ -123,8 +124,7 @@ TEST_P(TlsConnectTls13, CaptureAlertClient) { auto alert_recorder = std::make_shared<TlsAlertRecorder>(); client_->SetPacketFilter(alert_recorder); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->Handshake(); client_->ExpectSendAlert(kTlsAlertDecodeError); @@ -166,26 +166,107 @@ TEST_P(TlsConnectDatagram, ConnectSrtp) { SendReceive(); } -// 1.3 is disabled in the next few tests because we don't -// presently support resumption in 1.3. -TEST_P(TlsConnectStreamPre13, ConnectAndClientRenegotiate) { +TEST_P(TlsConnectGeneric, ConnectSendReceive) { Connect(); - server_->PrepareForRenegotiate(); - client_->StartRenegotiate(); - Handshake(); - CheckConnected(); + SendReceive(); +} + +class SaveTlsRecord : public TlsRecordFilter { + public: + SaveTlsRecord(size_t index) : index_(index), count_(0), contents_() {} + + const DataBuffer& contents() const { return contents_; } + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (count_++ == index_) { + contents_ = data; + } + return KEEP; + } + + private: + const size_t index_; + size_t count_; + DataBuffer contents_; +}; + +// Check that decrypting filters work and can read any record. +// This test (currently) only works in TLS 1.3 where we can decrypt. +TEST_F(TlsConnectStreamTls13, DecryptRecordClient) { + EnsureTlsSetup(); + // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer + auto saved = std::make_shared<SaveTlsRecord>(3); + client_->SetTlsRecordFilter(saved); + Connect(); + SendReceive(); + + static const uint8_t data[] = {0xde, 0xad, 0xdc}; + DataBuffer buf(data, sizeof(data)); + client_->SendBuffer(buf); + EXPECT_EQ(buf, saved->contents()); } -TEST_P(TlsConnectStreamPre13, ConnectAndServerRenegotiate) { +TEST_F(TlsConnectStreamTls13, DecryptRecordServer) { + EnsureTlsSetup(); + // Disable tickets so that we are sure to not get NewSessionTicket. + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer + auto saved = std::make_shared<SaveTlsRecord>(3); + server_->SetTlsRecordFilter(saved); Connect(); - client_->PrepareForRenegotiate(); - server_->StartRenegotiate(); - Handshake(); - CheckConnected(); + SendReceive(); + + static const uint8_t data[] = {0xde, 0xad, 0xd5}; + DataBuffer buf(data, sizeof(data)); + server_->SendBuffer(buf); + EXPECT_EQ(buf, saved->contents()); } -TEST_P(TlsConnectGeneric, ConnectSendReceive) { +class DropTlsRecord : public TlsRecordFilter { + public: + DropTlsRecord(size_t index) : index_(index), count_(0) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (count_++ == index_) { + return DROP; + } + return KEEP; + } + + private: + const size_t index_; + size_t count_; +}; + +// Test that decrypting filters work correctly and are able to drop records. +TEST_F(TlsConnectStreamTls13, DropRecordServer) { + EnsureTlsSetup(); + // Disable session tickets so that the server doesn't send an extra record. + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + + // 0 = ServerHello, 1 = other handshake, 2 = first write + server_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2)); + Connect(); + server_->SendData(23, 23); // This should be dropped, so it won't be counted. + server_->ResetSentBytes(); + SendReceive(); +} + +TEST_F(TlsConnectStreamTls13, DropRecordClient) { + EnsureTlsSetup(); + // 0 = ClientHello, 1 = Finished, 2 = first write + client_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2)); Connect(); + client_->SendData(26, 26); // This should be dropped, so it won't be counted. + client_->ResetSentBytes(); SendReceive(); } @@ -224,29 +305,70 @@ TEST_P(TlsConnectStream, ShortRead) { ASSERT_EQ(50U, client_->received_bytes()); } -TEST_P(TlsConnectGeneric, ConnectWithCompressionMaybe) { +// We enable compression via the API but it's disabled internally, +// so we should never get it. +TEST_P(TlsConnectGeneric, ConnectWithCompressionEnabled) { EnsureTlsSetup(); - client_->EnableCompression(); - server_->EnableCompression(); + client_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); + server_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); Connect(); - EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && - variant_ != ssl_variant_datagram, - client_->is_compressed()); + EXPECT_FALSE(client_->is_compressed()); SendReceive(); } -TEST_P(TlsConnectDatagram, TestDtlsHolddownExpiry) { +class TlsHolddownTest : public TlsConnectDatagram { + protected: + // This causes all timers to run to completion. It advances the clock and + // handshakes on both peers until both peers have no more timers pending, + // which should happen at the end of a handshake. This is necessary to ensure + // that the relatively long holddown timer expires, but that any other timers + // also expire and run correctly. + void RunAllTimersDown() { + while (true) { + PRIntervalTime time; + SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); + if (rv != SECSuccess) { + rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); + if (rv != SECSuccess) { + break; // Neither peer has an outstanding timer. + } + } + + if (g_ssl_gtest_verbose) { + std::cerr << "Shifting timers" << std::endl; + } + ShiftDtlsTimers(); + Handshake(); + } + } +}; + +TEST_P(TlsHolddownTest, TestDtlsHolddownExpiry) { Connect(); - std::cerr << "Expiring holddown timer\n"; - SSLInt_ForceTimerExpiry(client_->ssl_fd()); - SSLInt_ForceTimerExpiry(server_->ssl_fd()); + std::cerr << "Expiring holddown timer" << std::endl; + RunAllTimersDown(); SendReceive(); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { // One for send, one for receive. - EXPECT_EQ(2, SSLInt_CountTls13CipherSpecs(client_->ssl_fd())); + EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); } } +TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + Connect(); + RunAllTimersDown(); + SendReceive(); + // One for send, one for receive. + EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); +} + class TlsPreCCSHeaderInjector : public TlsRecordFilter { public: TlsPreCCSHeaderInjector() {} @@ -274,8 +396,7 @@ TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) { TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) { server_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>()); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); ExpectAlert(client_, kTlsAlertUnexpectedMessage); Handshake(); EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); @@ -306,21 +427,65 @@ TEST_P(TlsConnectTls13, AlertWrongLevel) { TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) { EnsureTlsSetup(); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); server_->Handshake(); // Send first flight. - client_->adapter()->CloseWrites(); + client_->adapter()->SetWriteError(PR_IO_ERROR); client_->Handshake(); // This will get an error, but shouldn't crash. client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE); } -TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) { - client_->SetShortHeadersEnabled(); - server_->SetShortHeadersEnabled(); - client_->ExpectShortHeaders(); - server_->ExpectShortHeaders(); +TEST_P(TlsConnectDatagram, BlockedWrite) { + Connect(); + + // Mark the socket as blocked. + client_->adapter()->SetWriteError(PR_WOULD_BLOCK_ERROR); + static const uint8_t data[] = {1, 2, 3}; + int32_t rv = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_GT(0, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Remove the write error and though the previous write failed, future reads + // and writes should just work as if it never happened. + client_->adapter()->SetWriteError(0); + SendReceive(); +} + +TEST_F(TlsConnectTest, ConnectSSLv3) { + ConfigureVersion(SSL_LIBRARY_VERSION_3_0); + EnableOnlyStaticRsaCiphers(); Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +TEST_F(TlsConnectTest, ConnectSSLv3ClientAuth) { + ConfigureVersion(SSL_LIBRARY_VERSION_3_0); + EnableOnlyStaticRsaCiphers(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +static size_t ExpectedCbcLen(size_t in, size_t hmac = 20, size_t block = 16) { + // MAC-then-Encrypt expansion formula: + return ((in + hmac + (block - 1)) / block) * block; +} + +TEST_F(TlsConnectTest, OneNRecordSplitting) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0); + EnsureTlsSetup(); + ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA); + auto records = std::make_shared<TlsRecordRecorder>(); + server_->SetPacketFilter(records); + // This should be split into 1, 16384 and 20. + DataBuffer big_buffer; + big_buffer.Allocate(1 + 16384 + 20); + server_->SendBuffer(big_buffer); + ASSERT_EQ(3U, records->count()); + EXPECT_EQ(ExpectedCbcLen(1), records->record(0).buffer.len()); + EXPECT_EQ(ExpectedCbcLen(16384), records->record(1).buffer.len()); + EXPECT_EQ(ExpectedCbcLen(20), records->record(2).buffer.len()); } INSTANTIATE_TEST_CASE_P( @@ -336,6 +501,8 @@ INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream, TlsConnectTestBase::kTlsVAll); INSTANTIATE_TEST_CASE_P(DatagramOnly, TlsConnectDatagram, TlsConnectTestBase::kTlsV11Plus); +INSTANTIATE_TEST_CASE_P(DatagramHolddown, TlsHolddownTest, + TlsConnectTestBase::kTlsV11Plus); INSTANTIATE_TEST_CASE_P( Pre12Stream, TlsConnectPre12, diff --git a/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc new file mode 100644 index 000000000..2b1b92dcd --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc @@ -0,0 +1,20 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "sslexp.h" + +#include "gtest_utils.h" + +namespace nss_test { + +class MiscTest : public ::testing::Test {}; + +TEST_F(MiscTest, NonExistentExperimentalAPI) { + EXPECT_EQ(nullptr, SSL_GetExperimentalAPI("blah")); + EXPECT_EQ(SSL_ERROR_UNSUPPORTED_EXPERIMENTAL_API, PORT_GetError()); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc index ef81b222c..d1d496f49 100644 --- a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc @@ -10,6 +10,8 @@ #include "databuffer.h" #include "gtest_utils.h" +#include "tls_connect.h" +#include "tls_filter.h" namespace nss_test { @@ -51,8 +53,8 @@ class TlsPaddingTest << " total length=" << plaintext_.len() << std::endl; std::cerr << "Plaintext: " << plaintext_ << std::endl; sslBuffer s; - s.buf = const_cast<unsigned char *>( - static_cast<const unsigned char *>(plaintext_.data())); + s.buf = const_cast<unsigned char*>( + static_cast<const unsigned char*>(plaintext_.data())); s.len = plaintext_.len(); SECStatus rv = ssl_RemoveTLSCBCPadding(&s, kMacSize); if (expect_success) { @@ -99,6 +101,73 @@ TEST_P(TlsPaddingTest, LastByteOfPadWrong) { } } +class RecordReplacer : public TlsRecordFilter { + public: + RecordReplacer(size_t size) + : TlsRecordFilter(), enabled_(false), size_(size) {} + + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (!enabled_) { + return KEEP; + } + + EXPECT_EQ(kTlsApplicationDataType, header.content_type()); + changed->Allocate(size_); + + for (size_t i = 0; i < size_; ++i) { + changed->data()[i] = i & 0xff; + } + + enabled_ = false; + return CHANGE; + } + + void Enable() { enabled_ = true; } + + private: + bool enabled_; + size_t size_; +}; + +TEST_F(TlsConnectStreamTls13, LargeRecord) { + EnsureTlsSetup(); + + const size_t record_limit = 16384; + auto replacer = std::make_shared<RecordReplacer>(record_limit); + client_->SetTlsRecordFilter(replacer); + Connect(); + + replacer->Enable(); + client_->SendData(10); + WAIT_(server_->received_bytes() == record_limit, 2000); + ASSERT_EQ(record_limit, server_->received_bytes()); +} + +TEST_F(TlsConnectStreamTls13, TooLargeRecord) { + EnsureTlsSetup(); + + const size_t record_limit = 16384; + auto replacer = std::make_shared<RecordReplacer>(record_limit + 1); + client_->SetTlsRecordFilter(replacer); + Connect(); + + replacer->Enable(); + ExpectAlert(server_, kTlsAlertRecordOverflow); + client_->SendData(10); // This is expanded. + + uint8_t buf[record_limit + 2]; + PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, PORT_GetError()); + + // Read the server alert. + rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf)); + EXPECT_GT(0, rv); + EXPECT_EQ(SSL_ERROR_RECORD_OVERFLOW_ALERT, PORT_GetError()); +} + const static size_t kContentSizesArr[] = { 1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288}; diff --git a/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc new file mode 100644 index 000000000..a902a5f7f --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc @@ -0,0 +1,212 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +// 1.3 is disabled in the next few tests because we don't +// presently support resumption in 1.3. +TEST_P(TlsConnectStreamPre13, RenegotiateClient) { + Connect(); + server_->PrepareForRenegotiate(); + client_->StartRenegotiate(); + Handshake(); + CheckConnected(); +} + +TEST_P(TlsConnectStreamPre13, RenegotiateServer) { + Connect(); + client_->PrepareForRenegotiate(); + server_->StartRenegotiate(); + Handshake(); + CheckConnected(); +} + +// The renegotiation options shouldn't cause an error if TLS 1.3 is chosen. +TEST_F(TlsConnectTest, RenegotiationConfigTls13) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetOption(SSL_ENABLE_RENEGOTIATION, SSL_RENEGOTIATE_UNRESTRICTED); + server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); + Connect(); + SendReceive(); + CheckKeys(); +} + +TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + // Set the client so it will accept any version from 1.0 + // to |version_|. + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // Reset version so that the checks succeed. + uint16_t test_version = version_; + version_ = SSL_LIBRARY_VERSION_TLS_1_0; + Connect(); + + // Now renegotiate, with the server being set to do + // |version_|. + client_->PrepareForRenegotiate(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + server_->StartRenegotiate(); + + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + } else { + ExpectAlert(server_, kTlsAlertProtocolVersion); + } + + Handshake(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + // In TLS 1.3, the server detects this problem. + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); + } else { + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); + } +} + +TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + // Set the client so it will accept any version from 1.0 + // to |version_|. + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // Reset version so that the checks succeed. + uint16_t test_version = version_; + version_ = SSL_LIBRARY_VERSION_TLS_1_0; + Connect(); + + // Now renegotiate, with the server being set to do + // |version_|. + server_->PrepareForRenegotiate(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + client_->StartRenegotiate(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + ExpectAlert(server_, kTlsAlertUnexpectedMessage); + } else { + ExpectAlert(server_, kTlsAlertProtocolVersion); + } + Handshake(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + // In TLS 1.3, the server detects this problem. + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); + } else { + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); + } +} + +TEST_P(TlsConnectStream, ConnectAndServerRenegotiateLower) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + Connect(); + + // Now renegotiate with the server set to TLS 1.0. + client_->PrepareForRenegotiate(); + server_->PrepareForRenegotiate(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + + SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + EXPECT_EQ(SECFailure, rv); + return; + } + ASSERT_EQ(SECSuccess, rv); + + // Now, before handshaking, tweak the server configuration. + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + + // The server should catch the own error. + ExpectAlert(server_, kTlsAlertProtocolVersion); + + Handshake(); + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); +} + +TEST_P(TlsConnectStream, ConnectAndServerWontRenegotiateLower) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + Connect(); + + // Now renegotiate with the server set to TLS 1.0. + client_->PrepareForRenegotiate(); + server_->PrepareForRenegotiate(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + + EXPECT_EQ(SECFailure, SSL_ReHandshake(server_->ssl_fd(), PR_TRUE)); +} + +TEST_P(TlsConnectStream, ConnectAndClientWontRenegotiateLower) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + Connect(); + + // Now renegotiate with the client set to TLS 1.0. + client_->PrepareForRenegotiate(); + server_->PrepareForRenegotiate(); + server_->ResetPreliminaryInfo(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // The client will refuse to renegotiate down. + EXPECT_EQ(SECFailure, SSL_ReHandshake(client_->ssl_fd(), PR_TRUE)); +} + +TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); +} + +TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc index ce0e3ca8d..a413caf2c 100644 --- a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc @@ -9,6 +9,7 @@ #include "secerr.h" #include "ssl.h" #include "sslerr.h" +#include "sslexp.h" #include "sslproto.h" extern "C" { @@ -246,8 +247,7 @@ TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) { : ssl_session_ticket_xtn; auto capture = std::make_shared<TlsExtensionCapture>(xtn); client_->SetPacketFilter(capture); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); EXPECT_TRUE(capture->captured()); EXPECT_LT(0U, capture->extension().len()); @@ -355,10 +355,7 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) { // This test parses the ServerKeyExchange, which isn't in 1.3 TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) { - server_->EnsureTlsSetup(); - SECStatus rv = - SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); - EXPECT_EQ(SECSuccess, rv); + server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>( kTlsHandshakeServerKeyExchange); server_->SetPacketFilter(i1); @@ -369,9 +366,7 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) { // Restart Reset(); - server_->EnsureTlsSetup(); - rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); - EXPECT_EQ(SECSuccess, rv); + server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>( kTlsHandshakeServerKeyExchange); server_->SetPacketFilter(i2); @@ -401,7 +396,8 @@ TEST_P(TlsConnectTls13, TestTls13ResumeDifferentGroup) { client_->ConfigNamedGroups(kFFDHEGroups); server_->ConfigNamedGroups(kFFDHEGroups); Connect(); - CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); } // We need to enable different cipher suites at different times in the following @@ -461,36 +457,6 @@ TEST_P(TlsConnectGeneric, TestResumeServerDifferentCipher) { CheckKeys(); } -class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { - public: - SelectedCipherSuiteReplacer(uint16_t suite) : cipher_suite_(suite) {} - - protected: - PacketFilter::Action FilterHandshake(const HandshakeHeader& header, - const DataBuffer& input, - DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeServerHello) { - return KEEP; - } - - *output = input; - uint32_t temp = 0; - EXPECT_TRUE(input.Read(0, 2, &temp)); - // Cipher suite is after version(2) and random(32). - size_t pos = 34; - if (temp < SSL_LIBRARY_VERSION_TLS_1_3) { - // In old versions, we have to skip a session_id too. - EXPECT_TRUE(input.Read(pos, 1, &temp)); - pos += 1 + temp; - } - output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2); - return CHANGE; - } - - private: - uint16_t cipher_suite_; -}; - // Test that the client doesn't tolerate the server picking a different cipher // suite for resumption. TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { @@ -524,16 +490,13 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { class SelectedVersionReplacer : public TlsHandshakeFilter { public: - SelectedVersionReplacer(uint16_t version) : version_(version) {} + SelectedVersionReplacer(uint16_t version) + : TlsHandshakeFilter({kTlsHandshakeServerHello}), version_(version) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeServerHello) { - return KEEP; - } - *output = input; output->Write(0, static_cast<uint32_t>(version_), 2); return CHANGE; @@ -609,7 +572,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { Connect(); SendReceive(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_none); + ssl_sig_rsa_pss_sha256); // The filter will go away when we reset, so save the captured extension. DataBuffer initialTicket(c1->extension()); ASSERT_LT(0U, initialTicket.len()); @@ -627,7 +590,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { Connect(); SendReceive(); CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, - ssl_sig_none); + ssl_sig_rsa_pss_sha256); ASSERT_LT(0U, c2->extension().len()); ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); @@ -652,18 +615,158 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) { // Clear the session ticket keys to invalidate the old ticket. SSLInt_ClearSelfEncryptKey(); - SSLInt_SendNewSessionTicket(server_->ssl_fd()); + SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + // Resume the connection. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +// Check that the value captured in a NewSessionTicket message matches the value +// captured from a pre_shared_key extension. +void NstTicketMatchesPskIdentity(const DataBuffer& nst, const DataBuffer& psk) { + uint32_t len; + + size_t offset = 4 + 4; // Skip ticket_lifetime and ticket_age_add. + ASSERT_TRUE(nst.Read(offset, 1, &len)); + offset += 1 + len; // Skip ticket_nonce. + + ASSERT_TRUE(nst.Read(offset, 2, &len)); + offset += 2; // Skip the ticket length. + ASSERT_LE(offset + len, nst.len()); + DataBuffer nst_ticket(nst.data() + offset, static_cast<size_t>(len)); + + offset = 2; // Skip the identities length. + ASSERT_TRUE(psk.Read(offset, 2, &len)); + offset += 2; // Skip the identity length. + ASSERT_LE(offset + len, psk.len()); + DataBuffer psk_ticket(psk.data() + offset, static_cast<size_t>(len)); + + EXPECT_EQ(nst_ticket, psk_ticket); +} + +TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + auto nst_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>( + ssl_hs_new_session_ticket); + server_->SetTlsRecordFilter(nst_capture); + Connect(); + + // Clear the session ticket keys to invalidate the old ticket. + SSLInt_ClearSelfEncryptKey(); + nst_capture->Reset(); + uint8_t token[] = {0x20, 0x20, 0xff, 0x00}; + EXPECT_EQ(SECSuccess, + SSL_SendSessionTicket(server_->ssl_fd(), token, sizeof(token))); SendReceive(); // Need to read so that we absorb the session tickets. CheckKeys(); + EXPECT_LT(0U, nst_capture->buffer().len()); // Resume the connection. Reset(); ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); ExpectResumption(RESUME_TICKET); + + auto psk_capture = + std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn); + client_->SetPacketFilter(psk_capture); Connect(); SendReceive(); + + NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension()); +} + +// Disable SSL_ENABLE_SESSION_TICKETS but ensure that tickets can still be sent +// by invoking SSL_SendSessionTicket directly (and that the ticket is usable). +TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + + auto nst_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>( + ssl_hs_new_session_ticket); + server_->SetTlsRecordFilter(nst_capture); + Connect(); + + EXPECT_EQ(0U, nst_capture->buffer().len()) << "expect nothing captured yet"; + + EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)); + EXPECT_LT(0U, nst_capture->buffer().len()) << "should capture now"; + + SendReceive(); // Ensure that the client reads the ticket. + + // Resume the connection. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + + auto psk_capture = + std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn); + client_->SetPacketFilter(psk_capture); + Connect(); + SendReceive(); + + NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension()); +} + +// Test calling SSL_SendSessionTicket in inappropriate conditions. +TEST_F(TlsConnectTest, SendSessionTicketInappropriate) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2); + + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(client_->ssl_fd(), NULL, 0)) + << "clients can't send tickets"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + StartConnect(); + + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)) + << "no ticket before the handshake has started"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + Handshake(); + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)) + << "no special tickets in TLS 1.2"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectTest, SendSessionTicketMassiveToken) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + // It should be safe to set length with a NULL token because the length should + // be checked before reading token. + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0x1ffff)) + << "this is clearly too big"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); + + static const uint8_t big_token[0xffff] = {1}; + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), big_token, + sizeof(big_token))) + << "this is too big, but that's not immediately obvious"; + EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError()); +} + +TEST_F(TlsConnectDatagram13, SendSessionTicketDtls) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0)) + << "no extra tickets in DTLS until we have Ack support"; + EXPECT_EQ(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_VERSION, PORT_GetError()); } TEST_F(TlsConnectTest, TestTls13ResumptionDowngrade) { @@ -719,13 +822,84 @@ TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) { TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)); filters.push_back( std::make_shared<SelectedVersionReplacer>(SSL_LIBRARY_VERSION_TLS_1_2)); + + // Drop a bunch of extensions so that we get past the SH processing. The + // version extension says TLS 1.3, which is counter to our goal, the others + // are not permitted in TLS 1.2 handshakes. + filters.push_back( + std::make_shared<TlsExtensionDropper>(ssl_tls13_supported_versions_xtn)); + filters.push_back( + std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn)); + filters.push_back( + std::make_shared<TlsExtensionDropper>(ssl_tls13_pre_shared_key_xtn)); server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(filters)); - client_->ExpectSendAlert(kTlsAlertDecodeError); + // The client here generates an unexpected_message alert when it receives an + // encrypted handshake message from the server (EncryptedExtension). The + // client expects to receive an unencrypted TLS 1.2 Certificate message. + // The server can't decrypt the alert. + client_->ExpectSendAlert(kTlsAlertUnexpectedMessage); server_->ExpectSendAlert(kTlsAlertBadRecordMac); // Server can't read ConnectExpectFail(); - client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA); server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); } +TEST_P(TlsConnectGeneric, ReConnectTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + // Resume + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + Connect(); + // Only the client knows this. + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); +} + +TEST_P(TlsConnectGenericPre13, ReConnectCache) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + // Resume + Reset(); + ExpectResumption(RESUME_SESSIONID); + Connect(); + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); +} + +TEST_P(TlsConnectGeneric, ReConnectAgainTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + // Resume + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + Connect(); + // Only the client knows this. + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); + // Resume connection again + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET, 2); + Connect(); + // Only the client knows this. + CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519, + ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); +} + } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc index a130ef77f..335bfecfa 100644 --- a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc @@ -43,7 +43,14 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter { size_t start = parser.consumed(); TlsHandshakeFilter::HandshakeHeader header; DataBuffer ignored; - if (!header.Parse(&parser, record_header, &ignored)) { + bool complete = false; + if (!header.Parse(&parser, record_header, DataBuffer(), &ignored, + &complete)) { + ADD_FAILURE() << "Error parsing handshake header"; + return KEEP; + } + if (!complete) { + ADD_FAILURE() << "Don't want to deal with fragmented input"; return KEEP; } @@ -101,26 +108,15 @@ class Tls13SkipTest : public TlsConnectTestBase, void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { EnsureTlsSetup(); server_->SetTlsRecordFilter(filter); - filter->EnableDecryption(); - client_->ExpectSendAlert(kTlsAlertUnexpectedMessage); - if (variant_ == ssl_variant_stream) { - server_->ExpectSendAlert(kTlsAlertBadRecordMac); - ConnectExpectFail(); - } else { - ConnectExpectFailOneSide(TlsAgent::CLIENT); - } + ExpectAlert(client_, kTlsAlertUnexpectedMessage); + ConnectExpectFail(); client_->CheckErrorCode(error); - if (variant_ == ssl_variant_stream) { - server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); - } else { - ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); - } + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); } void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) { EnsureTlsSetup(); client_->SetTlsRecordFilter(filter); - filter->EnableDecryption(); server_->ExpectSendAlert(kTlsAlertUnexpectedMessage); ConnectExpectFailOneSide(TlsAgent::SERVER); @@ -171,11 +167,10 @@ TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) { } TEST_P(TlsSkipTest, SkipCertAndKeyExch) { - auto chain = std::make_shared<ChainedPacketFilter>(); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate)); - chain->Add( - std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange)); + auto chain = std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit{ + std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate), + std::make_shared<TlsHandshakeSkipFilter>( + kTlsHandshakeServerKeyExchange)}); ServerSkipTest(chain); client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); } diff --git a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc index 8db1f30e1..e7fe44d92 100644 --- a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc @@ -71,7 +71,7 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) { EnableOnlyStaticRsaCiphers(); client_->SetPacketFilter( std::make_shared<TlsInspectorClientHelloVersionChanger>(server_)); - server_->DisableRollbackDetection(); + server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE); Connect(); } @@ -102,7 +102,7 @@ TEST_P(TlsConnectStreamPre13, EnableExtendedMasterSecret(); client_->SetPacketFilter( std::make_shared<TlsInspectorClientHelloVersionChanger>(server_)); - server_->DisableRollbackDetection(); + server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE); Connect(); } diff --git a/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc new file mode 100644 index 000000000..75cee52fc --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc @@ -0,0 +1,337 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <memory> +#include <vector> +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +class Tls13CompatTest : public TlsConnectStreamTls13 { + protected: + void EnableCompatMode() { + client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE); + } + + void InstallFilters() { + EnsureTlsSetup(); + client_recorders_.Install(client_); + server_recorders_.Install(server_); + } + + void CheckRecordVersions() { + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_0, + client_recorders_.records_->record(0).header.version()); + CheckRecordsAreTls12("client", client_recorders_.records_, 1); + CheckRecordsAreTls12("server", server_recorders_.records_, 0); + } + + void CheckHelloVersions() { + uint32_t ver; + ASSERT_TRUE(server_recorders_.hello_->buffer().Read(0, 2, &ver)); + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver)); + ASSERT_TRUE(client_recorders_.hello_->buffer().Read(0, 2, &ver)); + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver)); + } + + void CheckForCCS(bool expected_client, bool expected_server) { + client_recorders_.CheckForCCS(expected_client); + server_recorders_.CheckForCCS(expected_server); + } + + void CheckForRegularHandshake() { + CheckRecordVersions(); + CheckHelloVersions(); + EXPECT_EQ(0U, client_recorders_.session_id_length()); + EXPECT_EQ(0U, server_recorders_.session_id_length()); + CheckForCCS(false, false); + } + + void CheckForCompatHandshake() { + CheckRecordVersions(); + CheckHelloVersions(); + EXPECT_EQ(32U, client_recorders_.session_id_length()); + EXPECT_EQ(32U, server_recorders_.session_id_length()); + CheckForCCS(true, true); + } + + private: + struct Recorders { + Recorders() + : records_(new TlsRecordRecorder()), + hello_(new TlsInspectorRecordHandshakeMessage(std::set<uint8_t>( + {kTlsHandshakeClientHello, kTlsHandshakeServerHello}))) {} + + uint8_t session_id_length() const { + // session_id is always after version (2) and random (32). + uint32_t len = 0; + EXPECT_TRUE(hello_->buffer().Read(2 + 32, 1, &len)); + return static_cast<uint8_t>(len); + } + + void CheckForCCS(bool expected) const { + EXPECT_LT(0U, records_->count()); + for (size_t i = 0; i < records_->count(); ++i) { + // Only the second record can be a CCS. + bool expected_match = expected && (i == 1); + EXPECT_EQ(expected_match, + kTlsChangeCipherSpecType == + records_->record(i).header.content_type()); + } + } + + void Install(std::shared_ptr<TlsAgent>& agent) { + agent->SetPacketFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({records_, hello_}))); + } + + std::shared_ptr<TlsRecordRecorder> records_; + std::shared_ptr<TlsInspectorRecordHandshakeMessage> hello_; + }; + + void CheckRecordsAreTls12(const std::string& agent, + const std::shared_ptr<TlsRecordRecorder>& records, + size_t start) { + EXPECT_LE(start, records->count()); + for (size_t i = start; i < records->count(); ++i) { + EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, + records->record(i).header.version()) + << agent << ": record " << i << " has wrong version"; + } + } + + Recorders client_recorders_; + Recorders server_recorders_; +}; + +TEST_F(Tls13CompatTest, Disabled) { + InstallFilters(); + Connect(); + CheckForRegularHandshake(); +} + +TEST_F(Tls13CompatTest, Enabled) { + EnableCompatMode(); + InstallFilters(); + Connect(); + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledZeroRtt) { + SetupForZeroRtt(); + EnableCompatMode(); + InstallFilters(); + + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + CheckForCCS(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledHrr) { + EnableCompatMode(); + InstallFilters(); + + // Force a HelloRetryRequest. The server sends CCS immediately. + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + CheckForCCS(false, true); + + Handshake(); + CheckConnected(); + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledStatelessHrr) { + EnableCompatMode(); + InstallFilters(); + + // Force a HelloRetryRequest + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + CheckForCCS(false, true); + + // A new server should just work, but not send another CCS. + MakeNewServer(); + InstallFilters(); + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + + Handshake(); + CheckConnected(); + CheckForCompatHandshake(); +} + +TEST_F(Tls13CompatTest, EnabledHrrZeroRtt) { + SetupForZeroRtt(); + EnableCompatMode(); + InstallFilters(); + server_->ConfigNamedGroups({ssl_grp_ec_secp384r1}); + + // With 0-RTT, the client sends CCS immediately. With HRR, the server sends + // CCS immediately too. + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false); + CheckForCCS(true, true); + + Handshake(); + ExpectEarlyDataAccepted(false); + CheckConnected(); + CheckForCompatHandshake(); +} + +static const uint8_t kCannedCcs[] = { + kTlsChangeCipherSpecType, + SSL_LIBRARY_VERSION_TLS_1_2 >> 8, + SSL_LIBRARY_VERSION_TLS_1_2 & 0xff, + 0, + 1, // length + 1 // change_cipher_spec_choice +}; + +// A ChangeCipherSpec is ignored by a server because we have to tolerate it for +// compatibility mode. That doesn't mean that we have to tolerate it +// unconditionally. If we negotiate 1.3, we expect to see a cookie extension. +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello13) { + EnsureTlsSetup(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + // Client sends CCS before starting the handshake. + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +// A ChangeCipherSpec is ignored by a server because we have to tolerate it for +// compatibility mode. That doesn't mean that we have to tolerate it +// unconditionally. If we negotiate 1.3, we expect to see a cookie extension. +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHelloTwice) { + EnsureTlsSetup(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + // Client sends CCS before starting the handshake. + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +// If we negotiate 1.2, we abort. +TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello12) { + EnsureTlsSetup(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + // Client sends CCS before starting the handshake. + client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs))); + ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); +} + +TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) { + EnsureTlsSetup(); + client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE); + auto client_records = std::make_shared<TlsRecordRecorder>(); + client_->SetPacketFilter(client_records); + auto server_records = std::make_shared<TlsRecordRecorder>(); + server_->SetPacketFilter(server_records); + Connect(); + + ASSERT_EQ(2U, client_records->count()); // CH, Fin + EXPECT_EQ(kTlsHandshakeType, client_records->record(0).header.content_type()); + EXPECT_EQ(kTlsApplicationDataType, + client_records->record(1).header.content_type()); + + ASSERT_EQ(6U, server_records->count()); // SH, EE, CT, CV, Fin, Ack + EXPECT_EQ(kTlsHandshakeType, server_records->record(0).header.content_type()); + for (size_t i = 1; i < server_records->count(); ++i) { + EXPECT_EQ(kTlsApplicationDataType, + server_records->record(i).header.content_type()); + } +} + +class AddSessionIdFilter : public TlsHandshakeFilter { + public: + AddSessionIdFilter() : TlsHandshakeFilter({ssl_hs_client_hello}) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + uint32_t session_id_len = 0; + EXPECT_TRUE(input.Read(2 + 32, 1, &session_id_len)); + EXPECT_EQ(0U, session_id_len); + uint8_t session_id[33] = {32}; // 32 for length, the rest zero. + *output = input; + output->Splice(session_id, sizeof(session_id), 34, 1); + return CHANGE; + } +}; + +// Adding a session ID to a DTLS ClientHello should not trigger compatibility +// mode. It should be ignored instead. +TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) { + EnsureTlsSetup(); + auto client_records = std::make_shared<TlsRecordRecorder>(); + client_->SetPacketFilter( + std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit( + {client_records, std::make_shared<AddSessionIdFilter>()}))); + auto server_hello = std::make_shared<TlsInspectorRecordHandshakeMessage>( + kTlsHandshakeServerHello); + auto server_records = std::make_shared<TlsRecordRecorder>(); + server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>( + ChainedPacketFilterInit({server_records, server_hello}))); + StartConnect(); + client_->Handshake(); + server_->Handshake(); + // The client will consume the ServerHello, but discard everything else + // because it doesn't decrypt. And don't wait around for the client to ACK. + client_->Handshake(); + + ASSERT_EQ(1U, client_records->count()); + EXPECT_EQ(kTlsHandshakeType, client_records->record(0).header.content_type()); + + ASSERT_EQ(5U, server_records->count()); // SH, EE, CT, CV, Fin + EXPECT_EQ(kTlsHandshakeType, server_records->record(0).header.content_type()); + for (size_t i = 1; i < server_records->count(); ++i) { + EXPECT_EQ(kTlsApplicationDataType, + server_records->record(i).header.content_type()); + } + + uint32_t session_id_len = 0; + EXPECT_TRUE(server_hello->buffer().Read(2 + 32, 1, &session_id_len)); + EXPECT_EQ(0U, session_id_len); +} + +} // nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc index 110e3e0b6..2f8ddd6fe 100644 --- a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc @@ -153,13 +153,6 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase { client_->SetPacketFilter(filter_); } - void RequireSafeRenegotiation() { - server_->EnsureTlsSetup(); - SECStatus rv = - SSL_OptionSet(server_->ssl_fd(), SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); - EXPECT_EQ(rv, SECSuccess); - } - void SetExpectedVersion(uint16_t version) { TlsConnectTestBase::SetExpectedVersion(version); filter_->SetVersion(version); @@ -319,7 +312,7 @@ TEST_P(SSLv2ClientHelloTest, BigClientRandom) { // Connection must fail if we require safe renegotiation but the client doesn't // include TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites. TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) { - RequireSafeRenegotiation(); + server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); ConnectExpectAlert(server_, kTlsAlertHandshakeFailure); EXPECT_EQ(SSL_ERROR_UNSAFE_NEGOTIATION, server_->error_code()); @@ -328,7 +321,7 @@ TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) { // Connection must succeed when requiring safe renegotiation and the client // includes TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites. TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiationWithSCSV) { - RequireSafeRenegotiation(); + server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA, TLS_EMPTY_RENEGOTIATION_INFO_SCSV}; SetAvailableCipherSuites(cipher_suites); diff --git a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc index 379a67e35..9db293b07 100644 --- a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc @@ -128,12 +128,12 @@ TEST_F(TlsConnectTest, TestFallbackFromTls13) { #endif TEST_P(TlsConnectGeneric, TestFallbackSCSVVersionMatch) { - client_->SetFallbackSCSVEnabled(true); + client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE); Connect(); } TEST_P(TlsConnectGenericPre13, TestFallbackSCSVVersionMismatch) { - client_->SetFallbackSCSVEnabled(true); + client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE); server_->SetVersionRange(version_, version_ + 1); ConnectExpectAlert(server_, kTlsAlertInappropriateFallback); client_->CheckErrorCode(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT); @@ -155,107 +155,10 @@ TEST_F(TlsConnectTest, DisallowSSLv3HelloWithTLSv13Enabled) { EXPECT_EQ(SECFailure, rv); } -TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) { - if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { - return; - } - // Set the client so it will accept any version from 1.0 - // to |version_|. - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, - SSL_LIBRARY_VERSION_TLS_1_0); - // Reset version so that the checks succeed. - uint16_t test_version = version_; - version_ = SSL_LIBRARY_VERSION_TLS_1_0; - Connect(); - - // Now renegotiate, with the server being set to do - // |version_|. - client_->PrepareForRenegotiate(); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); - // Reset version and cipher suite so that the preinfo callback - // doesn't fail. - server_->ResetPreliminaryInfo(); - server_->StartRenegotiate(); - - if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { - ExpectAlert(server_, kTlsAlertUnexpectedMessage); - } else { - ExpectAlert(client_, kTlsAlertIllegalParameter); - } - - Handshake(); - if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { - // In TLS 1.3, the server detects this problem. - client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); - server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); - } else { - client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); - server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); - } -} - -TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) { - if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { - return; - } - // Set the client so it will accept any version from 1.0 - // to |version_|. - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, - SSL_LIBRARY_VERSION_TLS_1_0); - // Reset version so that the checks succeed. - uint16_t test_version = version_; - version_ = SSL_LIBRARY_VERSION_TLS_1_0; - Connect(); - - // Now renegotiate, with the server being set to do - // |version_|. - server_->PrepareForRenegotiate(); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); - // Reset version and cipher suite so that the preinfo callback - // doesn't fail. - server_->ResetPreliminaryInfo(); - client_->StartRenegotiate(); - if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { - ExpectAlert(server_, kTlsAlertUnexpectedMessage); - } else { - ExpectAlert(client_, kTlsAlertIllegalParameter); - } - Handshake(); - if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { - // In TLS 1.3, the server detects this problem. - client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); - server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); - } else { - client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); - server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); - } -} - -TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) { - EnsureTlsSetup(); - ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); - Connect(); - SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE); - EXPECT_EQ(SECFailure, rv); - EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); -} - -TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) { - EnsureTlsSetup(); - ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); - Connect(); - SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE); - EXPECT_EQ(SECFailure, rv); - EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); -} - TEST_P(TlsConnectGeneric, AlertBeforeServerHello) { EnsureTlsSetup(); client_->ExpectReceiveAlert(kTlsAlertUnrecognizedName, kTlsAlertWarning); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); // Send ClientHello. static const uint8_t kWarningAlert[] = {kTlsAlertWarning, kTlsAlertUnrecognizedName}; @@ -314,20 +217,20 @@ TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) { client_->SetPacketFilter( std::make_shared<TlsInspectorClientHelloVersionSetter>( SSL_LIBRARY_VERSION_TLS_1_3 + 1)); - auto capture = std::make_shared<TlsInspectorRecordHandshakeMessage>( - kTlsHandshakeServerHello); + auto capture = + std::make_shared<TlsExtensionCapture>(ssl_tls13_supported_versions_xtn); server_->SetPacketFilter(capture); client_->ExpectSendAlert(kTlsAlertBadRecordMac); server_->ExpectSendAlert(kTlsAlertBadRecordMac); ConnectExpectFail(); client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); - const DataBuffer& server_hello = capture->buffer(); - ASSERT_GT(server_hello.len(), 2U); - uint32_t ver; - ASSERT_TRUE(server_hello.Read(0, 2, &ver)); + + ASSERT_EQ(2U, capture->extension().len()); + uint32_t version = 0; + ASSERT_TRUE(capture->extension().Read(0, 2, &version)); // This way we don't need to change with new draft version. - ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), ver); + ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), version); } } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc index b9f0c672e..adcdbfbaf 100644 --- a/security/nss/gtests/ssl_gtest/test_io.cc +++ b/security/nss/gtests/ssl_gtest/test_io.cc @@ -98,8 +98,13 @@ int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen, } int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) { + if (write_error_) { + PR_SetError(write_error_, 0); + return -1; + } + auto peer = peer_.lock(); - if (!peer || !writeable_) { + if (!peer) { PR_SetError(PR_IO_ERROR, 0); return -1; } @@ -109,7 +114,7 @@ int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) { DataBuffer filtered; PacketFilter::Action action = PacketFilter::KEEP; if (filter_) { - action = filter_->Filter(packet, &filtered); + action = filter_->Process(packet, &filtered); } switch (action) { case PacketFilter::CHANGE: diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h index ac2497222..469d90a7c 100644 --- a/security/nss/gtests/ssl_gtest/test_io.h +++ b/security/nss/gtests/ssl_gtest/test_io.h @@ -33,9 +33,18 @@ class PacketFilter { CHANGE, // change the packet to a different value DROP // drop the packet }; - + PacketFilter(bool enabled = true) : enabled_(enabled) {} virtual ~PacketFilter() {} + virtual Action Process(const DataBuffer& input, DataBuffer* output) { + if (!enabled_) { + return KEEP; + } + return Filter(input, output); + } + void Enable() { enabled_ = true; } + void Disable() { enabled_ = false; } + // The packet filter takes input and has the option of mutating it. // // A filter that modifies the data places the modified data in *output and @@ -43,6 +52,9 @@ class PacketFilter { // case the value in *output is ignored. A Filter can return DROP, in which // case the packet is dropped (and *output is ignored). virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0; + + private: + bool enabled_; }; class DummyPrSocket : public DummyIOLayerMethods { @@ -53,7 +65,7 @@ class DummyPrSocket : public DummyIOLayerMethods { peer_(), input_(), filter_(nullptr), - writeable_(true) {} + write_error_(0) {} virtual ~DummyPrSocket() {} // Create a file descriptor that will reference this object. The fd must not @@ -71,7 +83,7 @@ class DummyPrSocket : public DummyIOLayerMethods { int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags, PRIntervalTime to) override; int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override; - void CloseWrites() { writeable_ = false; } + void SetWriteError(PRErrorCode code) { write_error_ = code; } SSLProtocolVariant variant() const { return variant_; } bool readable() const { return !input_.empty(); } @@ -98,7 +110,7 @@ class DummyPrSocket : public DummyIOLayerMethods { std::weak_ptr<DummyPrSocket> peer_; std::queue<Packet> input_; std::shared_ptr<PacketFilter> filter_; - bool writeable_; + PRErrorCode write_error_; }; // Marker interface. diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc index d6d91f7f7..3b939bba8 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.cc +++ b/security/nss/gtests/ssl_gtest/tls_agent.cc @@ -10,6 +10,7 @@ #include "pk11func.h" #include "ssl.h" #include "sslerr.h" +#include "sslexp.h" #include "sslproto.h" #include "tls_parser.h" @@ -35,7 +36,6 @@ const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt const std::string TlsAgent::kServerRsaSign = "rsa_sign"; const std::string TlsAgent::kServerRsaPss = "rsa_pss"; const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt"; -const std::string TlsAgent::kServerRsaChain = "rsa_chain"; const std::string TlsAgent::kServerEcdsa256 = "ecdsa256"; const std::string TlsAgent::kServerEcdsa384 = "ecdsa384"; const std::string TlsAgent::kServerEcdsa521 = "ecdsa521"; @@ -73,7 +73,6 @@ TlsAgent::TlsAgent(const std::string& name, Role role, handshake_callback_(), auth_certificate_callback_(), sni_callback_(), - expect_short_headers_(false), skip_version_checks_(false) { memset(&info_, 0, sizeof(info_)); memset(&csinfo_, 0, sizeof(csinfo_)); @@ -93,11 +92,11 @@ TlsAgent::~TlsAgent() { // Add failures manually, if any, so we don't throw in a destructor. if (expected_received_alert_ != kTlsAlertCloseNotify || expected_received_alert_level_ != kTlsAlertWarning) { - ADD_FAILURE() << "Wrong expected_received_alert status"; + ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str(); } if (expected_sent_alert_ != kTlsAlertCloseNotify || expected_sent_alert_level_ != kTlsAlertWarning) { - ADD_FAILURE() << "Wrong expected_sent_alert status"; + ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str(); } } @@ -258,13 +257,10 @@ void TlsAgent::CheckCipherSuite(uint16_t cipher_suite) { } void TlsAgent::RequestClientAuth(bool requireAuth) { - EXPECT_TRUE(EnsureTlsSetup()); ASSERT_EQ(SERVER, role_); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(ssl_fd(), SSL_REQUEST_CERTIFICATE, PR_TRUE)); - EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_REQUIRE_CERTIFICATE, - requireAuth ? PR_TRUE : PR_FALSE)); + SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE); + SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook( ssl_fd(), &TlsAgent::ClientAuthenticated, this)); @@ -376,42 +372,8 @@ void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) { EXPECT_EQ(SECSuccess, rv); } -void TlsAgent::SetSessionTicketsEnabled(bool en) { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS, - en ? PR_TRUE : PR_FALSE); - EXPECT_EQ(SECSuccess, rv); -} - -void TlsAgent::SetSessionCacheEnabled(bool en) { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE); - EXPECT_EQ(SECSuccess, rv); -} - void TlsAgent::Set0RttEnabled(bool en) { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = - SSL_OptionSet(ssl_fd(), SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); - EXPECT_EQ(SECSuccess, rv); -} - -void TlsAgent::SetFallbackSCSVEnabled(bool en) { - EXPECT_TRUE(role_ == CLIENT && EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALLBACK_SCSV, - en ? PR_TRUE : PR_FALSE); - EXPECT_EQ(SECSuccess, rv); -} - -void TlsAgent::SetShortHeadersEnabled() { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd()); - EXPECT_EQ(SECSuccess, rv); + SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); } void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { @@ -437,8 +399,6 @@ void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; } void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; } -void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; } - void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; } void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes, @@ -517,6 +477,12 @@ void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group, } } +void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const { + if (kea_group != ssl_grp_ffdhe_custom) { + EXPECT_EQ(kea_group, info_.originalKeaGroup); + } +} + void TlsAgent::CheckAuthType(SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const { EXPECT_EQ(STATE_CONNECTED, state_); @@ -569,8 +535,7 @@ void TlsAgent::EnableFalseStart() { falsestart_enabled_ = true; EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback( ssl_fd(), CanFalseStartCallback, this)); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALSE_START, PR_TRUE)); + SetOption(SSL_ENABLE_FALSE_START, PR_TRUE); } void TlsAgent::ExpectResumption() { expect_resumption_ = true; } @@ -578,7 +543,7 @@ void TlsAgent::ExpectResumption() { expect_resumption_ = true; } void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { EXPECT_TRUE(EnsureTlsSetup()); - EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_ENABLE_ALPN, PR_TRUE)); + SetOption(SSL_ENABLE_ALPN, PR_TRUE); EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len)); } @@ -622,12 +587,8 @@ void TlsAgent::CheckErrorCode(int32_t expected) const { } static uint8_t GetExpectedAlertLevel(uint8_t alert) { - switch (alert) { - case kTlsAlertCloseNotify: - case kTlsAlertEndOfEarlyData: - return kTlsAlertWarning; - default: - break; + if (alert == kTlsAlertCloseNotify) { + return kTlsAlertWarning; } return kTlsAlertFatal; } @@ -730,6 +691,50 @@ void TlsAgent::ResetPreliminaryInfo() { expected_cipher_suite_ = 0; } +void TlsAgent::ValidateCipherSpecs() { + PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd()); + // We use one ciphersuite in each direction. + PRInt32 expected = 2; + if (variant_ == ssl_variant_datagram) { + // For DTLS 1.3, the client retains the cipher spec for early data and the + // handshake so that it can retransmit EndOfEarlyData and its final flight. + // It also retains the handshake read cipher spec so that it can read ACKs + // from the server. The server retains the handshake read cipher spec so it + // can read the client's retransmitted Finished. + if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + if (role_ == CLIENT) { + expected = info_.earlyDataAccepted ? 5 : 4; + } else { + expected = 3; + } + } else { + // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec + // until the holddown timer runs down. + if (expect_resumption_) { + if (role_ == CLIENT) { + expected = 3; + } + } else { + if (role_ == SERVER) { + expected = 3; + } + } + } + } + // This function will be run before the handshake completes if false start is + // enabled. In that case, the client will still be reading cleartext, but + // will have a spec prepared for reading ciphertext. With DTLS, the client + // will also have a spec retained for retransmission of handshake messages. + if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) { + EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_); + expected = (variant_ == ssl_variant_datagram) ? 4 : 3; + } + EXPECT_EQ(expected, cipherSpecs); + if (expected != cipherSpecs) { + SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd()); + } +} + void TlsAgent::Connected() { if (state_ == STATE_CONNECTED) { return; @@ -743,6 +748,8 @@ void TlsAgent::Connected() { EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(info_), info_.length); + EXPECT_EQ(expect_resumption_, info_.resumed == PR_TRUE); + // Preliminary values are exposed through callbacks during the handshake. // If either expected values were set or the callbacks were called, check // that the final values are correct. @@ -753,32 +760,13 @@ void TlsAgent::Connected() { EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(csinfo_), csinfo_.length); - if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { - PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd()); - // We use one ciphersuite in each direction, plus one that's kept around - // by DTLS for retransmission. - PRInt32 expected = - ((variant_ == ssl_variant_datagram) && (role_ == CLIENT)) ? 3 : 2; - EXPECT_EQ(expected, cipherSuites); - if (expected != cipherSuites) { - SSLInt_PrintTls13CipherSpecs(ssl_fd()); - } - } + ValidateCipherSpecs(); - PRBool short_headers; - rv = SSLInt_UsingShortHeaders(ssl_fd(), &short_headers); - EXPECT_EQ(SECSuccess, rv); - EXPECT_EQ((PRBool)expect_short_headers_, short_headers); SetState(STATE_CONNECTED); } void TlsAgent::EnableExtendedMasterSecret() { - ASSERT_TRUE(EnsureTlsSetup()); - - SECStatus rv = - SSL_OptionSet(ssl_fd(), SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); - - ASSERT_EQ(SECSuccess, rv); + SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); } void TlsAgent::CheckExtendedMasterSecret(bool expected) { @@ -801,21 +789,6 @@ void TlsAgent::CheckSecretsDestroyed() { ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd())); } -void TlsAgent::DisableRollbackDetection() { - ASSERT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ROLLBACK_DETECTION, PR_FALSE); - - ASSERT_EQ(SECSuccess, rv); -} - -void TlsAgent::EnableCompression() { - ASSERT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_DEFLATE, PR_TRUE); - ASSERT_EQ(SECSuccess, rv); -} - void TlsAgent::SetDowngradeCheckVersion(uint16_t version) { ASSERT_TRUE(EnsureTlsSetup()); @@ -883,6 +856,14 @@ void TlsAgent::SendDirect(const DataBuffer& buf) { } } +void TlsAgent::SendRecordDirect(const TlsRecord& record) { + DataBuffer buf; + + auto rv = record.header.Write(&buf, 0, record.buffer); + EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv); + SendDirect(buf); +} + static bool ErrorIsNonFatal(PRErrorCode code) { return code == PR_WOULD_BLOCK_ERROR || code == SSL_ERROR_RX_SHORT_DTLS_READ; } @@ -918,6 +899,27 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) { } } +bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, + uint16_t wireVersion, uint64_t seq, + uint8_t ct, const DataBuffer& buf) { + LOGV("Writing " << buf.len() << " bytes"); + // Ensure we are a TLS 1.3 cipher agent. + EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3); + TlsRecordHeader header(wireVersion, kTlsApplicationDataType, seq); + DataBuffer padded = buf; + padded.Write(padded.len(), ct, 1); + DataBuffer ciphertext; + if (!spec->Protect(header, padded, &ciphertext)) { + return false; + } + + DataBuffer record; + auto rv = header.Write(&record, 0, ciphertext); + EXPECT_EQ(header.header_length() + ciphertext.len(), rv); + SendDirect(record); + return true; +} + void TlsAgent::ReadBytes(size_t amount) { uint8_t block[16384]; @@ -951,23 +953,20 @@ void TlsAgent::ReadBytes(size_t amount) { void TlsAgent::ResetSentBytes() { send_ctr_ = 0; } -void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, - mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); - EXPECT_EQ(SECSuccess, rv); +void TlsAgent::SetOption(int32_t option, int value) { + ASSERT_TRUE(EnsureTlsSetup()); + EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value)); +} - rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS, - mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); - EXPECT_EQ(SECSuccess, rv); +void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { + SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); + SetOption(SSL_ENABLE_SESSION_TICKETS, + mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); } void TlsAgent::DisableECDHEServerKeyReuse() { - ASSERT_TRUE(EnsureTlsSetup()); ASSERT_EQ(TlsAgent::SERVER, role_); - SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); - EXPECT_EQ(SECSuccess, rv); + SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); } static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"}; diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h index 4bccb9a84..b3fd892ae 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.h +++ b/security/nss/gtests/ssl_gtest/tls_agent.h @@ -66,7 +66,6 @@ class TlsAgent : public PollTarget { static const std::string kServerRsaSign; static const std::string kServerRsaPss; static const std::string kServerRsaDecrypt; - static const std::string kServerRsaChain; // A cert that requires a chain. static const std::string kServerEcdsa256; static const std::string kServerEcdsa384; static const std::string kServerEcdsa521; @@ -81,9 +80,11 @@ class TlsAgent : public PollTarget { adapter_->SetPeer(peer->adapter_); } + // Set a filter that can access plaintext (TLS 1.3 only). void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) { filter->SetAgent(this); adapter_->SetPacketFilter(filter); + filter->EnableDecryption(); } void SetPacketFilter(std::shared_ptr<PacketFilter> filter) { @@ -95,6 +96,7 @@ class TlsAgent : public PollTarget { void StartConnect(PRFileDesc* model = nullptr); void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group, size_t kea_size = 0) const; + void CheckOriginalKEA(SSLNamedGroup kea_group) const; void CheckAuthType(SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const; @@ -121,12 +123,10 @@ class TlsAgent : public PollTarget { void SetupClientAuth(); void RequestClientAuth(bool requireAuth); + void SetOption(int32_t option, int value); void ConfigureSessionCache(SessionResumptionMode mode); - void SetSessionTicketsEnabled(bool en); - void SetSessionCacheEnabled(bool en); void Set0RttEnabled(bool en); void SetFallbackSCSVEnabled(bool en); - void SetShortHeadersEnabled(); void SetVersionRange(uint16_t minver, uint16_t maxver); void GetVersionRange(uint16_t* minver, uint16_t* maxver); void CheckPreliminaryInfo(); @@ -136,7 +136,6 @@ class TlsAgent : public PollTarget { void ExpectReadWriteError(); void EnableFalseStart(); void ExpectResumption(); - void ExpectShortHeaders(); void SkipVersionChecks(); void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count); void EnableAlpn(const uint8_t* val, size_t len); @@ -149,15 +148,17 @@ class TlsAgent : public PollTarget { // Send data on the socket, encrypting it. void SendData(size_t bytes, size_t blocksize = 1024); void SendBuffer(const DataBuffer& buf); + bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, + uint16_t wireVersion, uint64_t seq, uint8_t ct, + const DataBuffer& buf); // Send data directly to the underlying socket, skipping the TLS layer. void SendDirect(const DataBuffer& buf); + void SendRecordDirect(const TlsRecord& record); void ReadBytes(size_t max = 16384U); void ResetSentBytes(); // Hack to test drops. void EnableExtendedMasterSecret(); void CheckExtendedMasterSecret(bool expected); void CheckEarlyDataAccepted(bool expected); - void DisableRollbackDetection(); - void EnableCompression(); void SetDowngradeCheckVersion(uint16_t version); void CheckSecretsDestroyed(); void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); @@ -170,6 +171,8 @@ class TlsAgent : public PollTarget { Role role() const { return role_; } std::string role_str() const { return role_ == SERVER ? "server" : "client"; } + SSLProtocolVariant variant() const { return variant_; } + State state() const { return state_; } const CERTCertificate* peer_cert() const { @@ -253,6 +256,7 @@ class TlsAgent : public PollTarget { const static char* states[]; void SetState(State state); + void ValidateCipherSpecs(); // Dummy auth certificate hook. static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, @@ -388,7 +392,6 @@ class TlsAgent : public PollTarget { HandshakeCallbackFunction handshake_callback_; AuthCertificateCallbackFunction auth_certificate_callback_; SniCallbackFunction sni_callback_; - bool expect_short_headers_; bool skip_version_checks_; }; diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc index c8de5a1fe..0af5123e9 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.cc +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -5,6 +5,7 @@ * You can obtain one at http://mozilla.org/MPL/2.0/. */ #include "tls_connect.h" +#include "sslexp.h" extern "C" { #include "libssl_internals.h" } @@ -88,6 +89,8 @@ std::string VersionString(uint16_t version) { switch (version) { case 0: return "(no version)"; + case SSL_LIBRARY_VERSION_3_0: + return "1.0"; case SSL_LIBRARY_VERSION_TLS_1_0: return "1.0"; case SSL_LIBRARY_VERSION_TLS_1_1: @@ -112,6 +115,7 @@ TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant, server_model_(nullptr), version_(version), expected_resumption_mode_(RESUME_NONE), + expected_resumptions_(0), session_ids_(), expect_extended_master_secret_(false), expect_early_data_accepted_(false), @@ -161,6 +165,22 @@ void TlsConnectTestBase::CheckShares( EXPECT_EQ(shares.len(), i); } +void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch, + uint16_t server_epoch) const { + uint16_t read_epoch = 0; + uint16_t write_epoch = 0; + + EXPECT_EQ(SECSuccess, + SSLInt_GetEpochs(client_->ssl_fd(), &read_epoch, &write_epoch)); + EXPECT_EQ(server_epoch, read_epoch) << "client read epoch"; + EXPECT_EQ(client_epoch, write_epoch) << "client write epoch"; + + EXPECT_EQ(SECSuccess, + SSLInt_GetEpochs(server_->ssl_fd(), &read_epoch, &write_epoch)); + EXPECT_EQ(client_epoch, read_epoch) << "server read epoch"; + EXPECT_EQ(server_epoch, write_epoch) << "server write epoch"; +} + void TlsConnectTestBase::ClearStats() { // Clear statistics. SSL3Statistics* stats = SSL_GetStatistics(); @@ -178,6 +198,7 @@ void TlsConnectTestBase::SetUp() { SSLInt_ClearSelfEncryptKey(); SSLInt_SetTicketLifetime(30); SSLInt_SetMaxEarlyDataSize(1024); + SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3); ClearStats(); Init(); } @@ -219,12 +240,27 @@ void TlsConnectTestBase::Reset(const std::string& server_name, Init(); } -void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) { +void TlsConnectTestBase::MakeNewServer() { + auto replacement = std::make_shared<TlsAgent>( + server_->name(), TlsAgent::SERVER, server_->variant()); + server_ = replacement; + if (version_) { + server_->SetVersionRange(version_, version_); + } + client_->SetPeer(server_); + server_->SetPeer(client_); + server_->StartConnect(); +} + +void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected, + uint8_t num_resumptions) { expected_resumption_mode_ = expected; if (expected != RESUME_NONE) { client_->ExpectResumption(); server_->ExpectResumption(); + expected_resumptions_ = num_resumptions; } + EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE); } void TlsConnectTestBase::EnsureTlsSetup() { @@ -258,6 +294,11 @@ void TlsConnectTestBase::Connect() { CheckConnected(); } +void TlsConnectTestBase::StartConnect() { + server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr); + client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr); +} + void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { EnsureTlsSetup(); client_->EnableSingleCipher(cipher_suite); @@ -274,6 +315,19 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { } void TlsConnectTestBase::CheckConnected() { + // Have the client read handshake twice to make sure we get the + // NST and the ACK. + if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 && + variant_ == ssl_variant_datagram) { + client_->Handshake(); + client_->Handshake(); + auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd()); + // Verify that we dropped the client's retransmission cipher suites. + EXPECT_EQ(2, suites) << "Client has the wrong number of suites"; + if (suites != 2) { + SSLInt_PrintCipherSpecs("client", client_->ssl_fd()); + } + } EXPECT_EQ(client_->version(), server_->version()); if (!skip_version_checks_) { // Check the version is as expected @@ -314,10 +368,12 @@ void TlsConnectTestBase::CheckConnected() { void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const { - client_->CheckKEA(kea_type, kea_group); - server_->CheckKEA(kea_type, kea_group); - client_->CheckAuthType(auth_type, sig_scheme); + if (kea_group != ssl_grp_none) { + client_->CheckKEA(kea_type, kea_group); + server_->CheckKEA(kea_type, kea_group); + } server_->CheckAuthType(auth_type, sig_scheme); + client_->CheckAuthType(auth_type, sig_scheme); } void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, @@ -372,9 +428,19 @@ void TlsConnectTestBase::CheckKeys() const { CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); } +void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type, + SSLNamedGroup kea_group, + SSLNamedGroup original_kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) { + CheckKeys(kea_type, kea_group, auth_type, sig_scheme); + EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE); + client_->CheckOriginalKEA(original_kea_group); + server_->CheckOriginalKEA(original_kea_group); +} + void TlsConnectTestBase::ConnectExpectFail() { - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); Handshake(); ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); @@ -395,8 +461,7 @@ void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, } void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) { - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->SetServerKeyBits(server_->server_key_bits()); client_->Handshake(); server_->Handshake(); @@ -455,29 +520,33 @@ void TlsConnectTestBase::EnableSomeEcdhCiphers() { } } +void TlsConnectTestBase::ConfigureSelfEncrypt() { + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey privKey; + ASSERT_TRUE( + TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey)); + + ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); + ASSERT_TRUE(pubKey); + + EXPECT_EQ(SECSuccess, + SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); +} + void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, SessionResumptionMode server) { client_->ConfigureSessionCache(client); server_->ConfigureSessionCache(server); if ((server & RESUME_TICKET) != 0) { - ScopedCERTCertificate cert; - ScopedSECKEYPrivateKey privKey; - ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, - &privKey)); - - ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get())); - ASSERT_TRUE(pubKey); - - EXPECT_EQ(SECSuccess, - SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get())); + ConfigureSelfEncrypt(); } } void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { EXPECT_NE(RESUME_BOTH, expected); - int resume_count = expected ? 1 : 0; - int stateless_count = (expected & RESUME_TICKET) ? 1 : 0; + int resume_count = expected ? expected_resumptions_ : 0; + int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0; // Note: hch == server counter; hsh == client counter. SSL3Statistics* stats = SSL_GetStatistics(); @@ -490,7 +559,7 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { if (expected != RESUME_NONE) { if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) { // Check that the last two session ids match. - ASSERT_EQ(2U, session_ids_.size()); + ASSERT_EQ(1U + expected_resumptions_, session_ids_.size()); EXPECT_EQ(session_ids_[session_ids_.size() - 1], session_ids_[session_ids_.size() - 2]); } else { @@ -540,31 +609,28 @@ void TlsConnectTestBase::CheckSrtp() const { server_->CheckSrtp(); } -void TlsConnectTestBase::SendReceive() { - client_->SendData(50); - server_->SendData(50); - Receive(50); +void TlsConnectTestBase::SendReceive(size_t total) { + ASSERT_GT(total, client_->received_bytes()); + ASSERT_GT(total, server_->received_bytes()); + client_->SendData(total - server_->received_bytes()); + server_->SendData(total - client_->received_bytes()); + Receive(total); // Receive() is cumulative } // Do a first connection so we can do 0-RTT on the second one. void TlsConnectTestBase::SetupForZeroRtt() { + // If we don't do this, then all 0-RTT attempts will be rejected. + SSLInt_RolloverAntiReplay(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. Connect(); SendReceive(); // Need to read so that we absorb the session ticket. CheckKeys(); Reset(); - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_3); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); } // Do a first connection so we can do resumption @@ -584,10 +650,6 @@ void TlsConnectTestBase::ZeroRttSendReceive( const char* k0RttData = "ABCDEF"; const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); - if (expect_writable && expect_readable) { - ExpectAlert(client_, kTlsAlertEndOfEarlyData); - } - client_->Handshake(); // Send ClientHello. if (post_clienthello_check) { if (!post_clienthello_check()) return; @@ -599,7 +661,7 @@ void TlsConnectTestBase::ZeroRttSendReceive( } else { EXPECT_EQ(SECFailure, rv); } - server_->Handshake(); // Consume ClientHello, EE, Finished. + server_->Handshake(); // Consume ClientHello std::vector<uint8_t> buf(k0RttDataLen); rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read @@ -653,6 +715,30 @@ void TlsConnectTestBase::SkipVersionChecks() { server_->SkipVersionChecks(); } +// Shift the DTLS timers, to the minimum time necessary to let the next timer +// run on either client or server. This allows tests to skip waiting without +// having timers run out of order. +void TlsConnectTestBase::ShiftDtlsTimers() { + PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT; + PRIntervalTime time; + SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); + if (rv == SECSuccess) { + time_shift = time; + } + rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); + if (rv == SECSuccess && + (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) { + time_shift = time; + } + + if (time_shift == PR_INTERVAL_NO_TIMEOUT) { + return; + } + + EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift)); + EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift)); +} + TlsConnectGeneric::TlsConnectGeneric() : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} @@ -691,11 +777,15 @@ void TlsKeyExchangeTest::ConfigNamedGroups( } std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( - const DataBuffer& ext) { + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); EXPECT_TRUE(ext.len() % 2 == 0); + std::vector<SSLNamedGroup> groups; for (size_t i = 1; i < ext.len() / 2; i += 1) { EXPECT_TRUE(ext.Read(2 * i, 2, &tmp)); @@ -705,10 +795,14 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( } std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( - const DataBuffer& ext) { + const std::shared_ptr<TlsExtensionCapture>& capture) { + EXPECT_TRUE(capture->captured()); + const DataBuffer& ext = capture->extension(); + uint32_t tmp = 0; EXPECT_TRUE(ext.Read(0, 2, &tmp)); EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); + std::vector<SSLNamedGroup> shares; size_t i = 2; while (i < ext.len()) { @@ -724,17 +818,15 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( void TlsKeyExchangeTest::CheckKEXDetails( const std::vector<SSLNamedGroup>& expected_groups, const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) { - std::vector<SSLNamedGroup> groups = - GetGroupDetails(groups_capture_->extension()); + std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_); EXPECT_EQ(expected_groups, groups); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { ASSERT_LT(0U, expected_shares.size()); - std::vector<SSLNamedGroup> shares = - GetShareDetails(shares_capture_->extension()); + std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_); EXPECT_EQ(expected_shares, shares); } else { - EXPECT_EQ(0U, shares_capture_->extension().len()); + EXPECT_FALSE(shares_capture_->captured()); } EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0); @@ -756,8 +848,6 @@ void TlsKeyExchangeTest::CheckKEXDetails( EXPECT_NE(expected_share2, it); } std::vector<SSLNamedGroup> expected_shares2 = {expected_share2}; - std::vector<SSLNamedGroup> shares = - GetShareDetails(shares_capture2_->extension()); - EXPECT_EQ(expected_shares2, shares); + EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_)); } } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h index 73e8dc81a..c650dda1d 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.h +++ b/security/nss/gtests/ssl_gtest/tls_connect.h @@ -61,7 +61,11 @@ class TlsConnectTestBase : public ::testing::Test { // Reset, and update the certificate names on both peers void Reset(const std::string& server_name, const std::string& client_name = "client"); + // Replace the server. + void MakeNewServer(); + // Set up + void StartConnect(); // Run the handshake. void Handshake(); // Connect and check that it works. @@ -81,20 +85,28 @@ class TlsConnectTestBase : public ::testing::Test { void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const; // This version assumes defaults. void CheckKeys() const; + // Check that keys on resumed sessions. + void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group, + SSLNamedGroup original_kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme); void CheckGroups(const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group); void CheckShares(const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group); + void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const; void ConfigureVersion(uint16_t version); void SetExpectedVersion(uint16_t version); // Expect resumption of a particular type. - void ExpectResumption(SessionResumptionMode expected); + void ExpectResumption(SessionResumptionMode expected, + uint8_t num_resumed = 1); void DisableAllCiphers(); void EnableOnlyStaticRsaCiphers(); void EnableOnlyDheCiphers(); void EnableSomeEcdhCiphers(); void EnableExtendedMasterSecret(); + void ConfigureSelfEncrypt(); void ConfigureSessionCache(SessionResumptionMode client, SessionResumptionMode server); void EnableAlpn(); @@ -103,7 +115,7 @@ class TlsConnectTestBase : public ::testing::Test { void CheckAlpn(const std::string& val); void EnableSrtp(); void CheckSrtp() const; - void SendReceive(); + void SendReceive(size_t total = 50); void SetupForZeroRtt(); void SetupForResume(); void ZeroRttSendReceive( @@ -115,6 +127,9 @@ class TlsConnectTestBase : public ::testing::Test { void DisableECDHEServerKeyReuse(); void SkipVersionChecks(); + // Move the DTLS timers for both endpoints to pop the next timer. + void ShiftDtlsTimers(); + protected: SSLProtocolVariant variant_; std::shared_ptr<TlsAgent> client_; @@ -123,6 +138,7 @@ class TlsConnectTestBase : public ::testing::Test { std::unique_ptr<TlsAgent> server_model_; uint16_t version_; SessionResumptionMode expected_resumption_mode_; + uint8_t expected_resumptions_; std::vector<std::vector<uint8_t>> session_ids_; // A simple value of "a", "b". Note that the preferred value of "a" is placed @@ -244,6 +260,11 @@ class TlsConnectDatagram13 : public TlsConnectTestBase { : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {} }; +class TlsConnectDatagramPre13 : public TlsConnectDatagram { + public: + TlsConnectDatagramPre13() {} +}; + // A variant that is used only with Pre13. class TlsConnectGenericPre13 : public TlsConnectGeneric {}; @@ -256,8 +277,10 @@ class TlsKeyExchangeTest : public TlsConnectGeneric { void EnsureKeyShareSetup(); void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); - std::vector<SSLNamedGroup> GetGroupDetails(const DataBuffer& ext); - std::vector<SSLNamedGroup> GetShareDetails(const DataBuffer& ext); + std::vector<SSLNamedGroup> GetGroupDetails( + const std::shared_ptr<TlsExtensionCapture>& capture); + std::vector<SSLNamedGroup> GetShareDetails( + const std::shared_ptr<TlsExtensionCapture>& capture); void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, const std::vector<SSLNamedGroup>& expectedShares); void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc index 76d9aaaff..89f201295 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.cc +++ b/security/nss/gtests/ssl_gtest/tls_filter.cc @@ -12,6 +12,7 @@ extern "C" { #include "libssl_internals.h" } +#include <cassert> #include <iostream> #include "gtest_utils.h" #include "tls_agent.h" @@ -57,17 +58,22 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending, PRBool isServer = self->agent()->role() == TlsAgent::SERVER; if (g_ssl_gtest_verbose) { - std::cerr << "Cipher spec changed. Role=" - << (isServer ? "server" : "client") - << " direction=" << (sending ? "send" : "receive") << std::endl; + std::cerr << (isServer ? "server" : "client") << ": " + << (sending ? "send" : "receive") + << " cipher spec changed: " << newSpec->epoch << " (" + << newSpec->phase << ")" << std::endl; + } + if (!sending) { + return; } - if (!sending) return; + self->in_sequence_number_ = 0; + self->out_sequence_number_ = 0; + self->dropped_record_ = false; self->cipher_spec_.reset(new TlsCipherSpec()); - bool ret = - self->cipher_spec_->Init(SSLInt_CipherSpecToAlgorithm(isServer, newSpec), - SSLInt_CipherSpecToKey(isServer, newSpec), - SSLInt_CipherSpecToIv(isServer, newSpec)); + bool ret = self->cipher_spec_->Init( + SSLInt_CipherSpecToEpoch(newSpec), SSLInt_CipherSpecToAlgorithm(newSpec), + SSLInt_CipherSpecToKey(newSpec), SSLInt_CipherSpecToIv(newSpec)); EXPECT_EQ(true, ret); } @@ -83,11 +89,23 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, TlsRecordHeader header; DataBuffer record; - if (!header.Parse(&parser, &record)) { + if (!header.Parse(in_sequence_number_, &parser, &record)) { ADD_FAILURE() << "not a valid record"; return KEEP; } + // Track the sequence number, which is necessary for stream mode (the + // sequence number is in the header for datagram). + // + // This isn't perfectly robust. If there is a change from an active cipher + // spec to another active cipher spec (KeyUpdate for instance) AND writes + // are consolidated across that change AND packets were dropped from the + // older epoch, we will not correctly re-encrypt records in the old epoch to + // update their sequence numbers. + if (cipher_spec_ && header.content_type() == kTlsApplicationDataType) { + ++in_sequence_number_; + } + if (FilterRecord(header, record, &offset, output) != KEEP) { changed = true; } else { @@ -120,30 +138,49 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( header.sequence_number()}; PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered); + // In stream mode, even if something doesn't change we need to re-encrypt if + // previous packets were dropped. if (action == KEEP) { - return KEEP; + if (header.is_dtls() || !dropped_record_) { + return KEEP; + } + filtered = plaintext; } if (action == DROP) { - std::cerr << "record drop: " << record << std::endl; + std::cerr << "record drop: " << header << ":" << record << std::endl; + dropped_record_ = true; return DROP; } EXPECT_GT(0x10000U, filtered.len()); - std::cerr << "record old: " << plaintext << std::endl; - std::cerr << "record new: " << filtered << std::endl; + if (action != KEEP) { + std::cerr << "record old: " << plaintext << std::endl; + std::cerr << "record new: " << filtered << std::endl; + } + + uint64_t seq_num; + if (header.is_dtls() || !cipher_spec_ || + header.content_type() != kTlsApplicationDataType) { + seq_num = header.sequence_number(); + } else { + seq_num = out_sequence_number_++; + } + TlsRecordHeader out_header = {header.version(), header.content_type(), + seq_num}; DataBuffer ciphertext; - bool rv = Protect(header, inner_content_type, filtered, &ciphertext); + bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext); EXPECT_TRUE(rv); if (!rv) { return KEEP; } - *offset = header.Write(output, *offset, ciphertext); + *offset = out_header.Write(output, *offset, ciphertext); return CHANGE; } -bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) { +bool TlsRecordHeader::Parse(uint64_t sequence_number, TlsParser* parser, + DataBuffer* body) { if (!parser->Read(&content_type_)) { return false; } @@ -154,7 +191,7 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) { } version_ = version; - sequence_number_ = 0; + // If this is DTLS, overwrite the sequence number. if (IsDtls(version)) { uint32_t tmp; if (!parser->Read(&tmp, 4)) { @@ -165,6 +202,8 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) { return false; } sequence_number_ |= static_cast<uint64_t>(tmp); + } else { + sequence_number_ = sequence_number; } return parser->ReadVariable(body, 2); } @@ -193,7 +232,9 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, return true; } - if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) return false; + if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) { + return false; + } size_t len = plaintext->len(); while (len > 0 && !plaintext->data()[len - 1]) { @@ -206,6 +247,11 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, *inner_content_type = plaintext->data()[len - 1]; plaintext->Truncate(len - 1); + if (g_ssl_gtest_verbose) { + std::cerr << "unprotect: " << std::hex << header.sequence_number() + << std::dec << " type=" << static_cast<int>(*inner_content_type) + << " " << *plaintext << std::endl; + } return true; } @@ -218,16 +264,44 @@ bool TlsRecordFilter::Protect(const TlsRecordHeader& header, *ciphertext = plaintext; return true; } + if (g_ssl_gtest_verbose) { + std::cerr << "protect: " << header.sequence_number() << std::endl; + } DataBuffer padded = plaintext; padded.Write(padded.len(), inner_content_type, 1); return cipher_spec_->Protect(header, padded, ciphertext); } +bool IsHelloRetry(const DataBuffer& body) { + static const uint8_t ssl_hello_retry_random[] = { + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, + 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, + 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C}; + return memcmp(body.data() + 2, ssl_hello_retry_random, + sizeof(ssl_hello_retry_random)) == 0; +} + +bool TlsHandshakeFilter::IsFilteredType(const HandshakeHeader& header, + const DataBuffer& body) { + if (handshake_types_.empty()) { + return true; + } + + uint8_t type = header.handshake_type(); + if (type == kTlsHandshakeServerHello) { + if (IsHelloRetry(body)) { + type = kTlsHandshakeHelloRetryRequest; + } + } + return handshake_types_.count(type) > 0U; +} + PacketFilter::Action TlsHandshakeFilter::FilterRecord( const TlsRecordHeader& record_header, const DataBuffer& input, DataBuffer* output) { // Check that the first byte is as requested. - if (record_header.content_type() != kTlsHandshakeType) { + if ((record_header.content_type() != kTlsHandshakeType) && + (record_header.content_type() != kTlsAltHandshakeType)) { return KEEP; } @@ -239,12 +313,29 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord( while (parser.remaining()) { HandshakeHeader header; DataBuffer handshake; - if (!header.Parse(&parser, record_header, &handshake)) { + bool complete = false; + if (!header.Parse(&parser, record_header, preceding_fragment_, &handshake, + &complete)) { return KEEP; } + if (!complete) { + EXPECT_TRUE(record_header.is_dtls()); + // Save the fragment and drop it from this record. Fragments are + // coalesced with the last fragment of the handshake message. + changed = true; + preceding_fragment_.Assign(handshake); + continue; + } + preceding_fragment_.Truncate(0); + DataBuffer filtered; - PacketFilter::Action action = FilterHandshake(header, handshake, &filtered); + PacketFilter::Action action; + if (!IsFilteredType(header, handshake)) { + action = KEEP; + } else { + action = FilterHandshake(header, handshake, &filtered); + } if (action == DROP) { changed = true; std::cerr << "handshake drop: " << handshake << std::endl; @@ -258,6 +349,8 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord( std::cerr << "handshake old: " << handshake << std::endl; std::cerr << "handshake new: " << filtered << std::endl; source = &filtered; + } else if (preceding_fragment_.len()) { + changed = true; } offset = header.Write(output, offset, *source); @@ -267,12 +360,16 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord( } bool TlsHandshakeFilter::HandshakeHeader::ReadLength( - TlsParser* parser, const TlsRecordHeader& header, uint32_t* length) { - if (!parser->Read(length, 3)) { + TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset, + uint32_t* length, bool* last_fragment) { + uint32_t message_length; + if (!parser->Read(&message_length, 3)) { return false; // malformed } if (!header.is_dtls()) { + *last_fragment = true; + *length = message_length; return true; // nothing left to do } @@ -283,32 +380,50 @@ bool TlsHandshakeFilter::HandshakeHeader::ReadLength( } message_seq_ = message_seq_tmp; - uint32_t fragment_offset; - if (!parser->Read(&fragment_offset, 3)) { + uint32_t offset = 0; + if (!parser->Read(&offset, 3)) { + return false; + } + // We only parse if the fragments are all complete and in order. + if (offset != expected_offset) { + EXPECT_NE(0U, header.epoch()) + << "Received out of order handshake fragment for epoch 0"; return false; } - uint32_t fragment_length; - if (!parser->Read(&fragment_length, 3)) { + // For DTLS, we return the length of just this fragment. + if (!parser->Read(length, 3)) { return false; } - // All current tests where we are using this code don't fragment. - return (fragment_offset == 0 && fragment_length == *length); + // It's a fragment if the entire message is longer than what we have. + *last_fragment = message_length == (*length + offset); + return true; } bool TlsHandshakeFilter::HandshakeHeader::Parse( - TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body) { + TlsParser* parser, const TlsRecordHeader& record_header, + const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) { + *complete = false; + version_ = record_header.version(); if (!parser->Read(&handshake_type_)) { return false; // malformed } + uint32_t length; - if (!ReadLength(parser, record_header, &length)) { + if (!ReadLength(parser, record_header, preceding_fragment.len(), &length, + complete)) { return false; } - return parser->Read(body, length); + if (!parser->Read(body, length)) { + return false; + } + if (preceding_fragment.len()) { + body->Splice(preceding_fragment, 0); + } + return true; } size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment( @@ -345,20 +460,23 @@ PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake( return KEEP; } - if (header.handshake_type() == handshake_type_) { - buffer_ = input; - } + buffer_ = input; return KEEP; } PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() == handshake_type_) { - *output = buffer_; - return CHANGE; - } + *output = buffer_; + return CHANGE; +} +PacketFilter::Action TlsRecordRecorder::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { + if (!filter_ || (header.content_type() == ct_)) { + records_.push_back({header, input}); + } return KEEP; } @@ -369,15 +487,30 @@ PacketFilter::Action TlsConversationRecorder::FilterRecord( return KEEP; } +PacketFilter::Action TlsHeaderRecorder::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { + headers_.push_back(header); + return KEEP; +} + +const TlsRecordHeader* TlsHeaderRecorder::header(size_t index) { + if (index > headers_.size() + 1) { + return nullptr; + } + return &headers_[index]; +} + PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input, DataBuffer* output) { DataBuffer in(input); bool changed = false; for (auto it = filters_.begin(); it != filters_.end(); ++it) { - PacketFilter::Action action = (*it)->Filter(in, output); + PacketFilter::Action action = (*it)->Process(in, output); if (action == DROP) { return DROP; } + if (action == CHANGE) { in = *output; changed = true; @@ -430,15 +563,6 @@ bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) { return true; } -static bool FindHelloRetryExtensions(TlsParser* parser, - const TlsVersioned& header) { - // TODO for -19 add cipher suite - if (!parser->Skip(2)) { // version - return false; - } - return true; -} - bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) { return true; } @@ -448,13 +572,6 @@ static bool FindCertReqExtensions(TlsParser* parser, if (!parser->SkipVariable(1)) { // request context return false; } - // TODO remove the next two for -19 - if (!parser->SkipVariable(2)) { // signature_algorithms - return false; - } - if (!parser->SkipVariable(2)) { // certificate_authorities - return false; - } return true; } @@ -478,6 +595,9 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser, if (!parser->Skip(8)) { // lifetime, age add return false; } + if (!parser->SkipVariable(1)) { // ticket_nonce + return false; + } if (!parser->SkipVariable(2)) { // ticket return false; } @@ -487,7 +607,6 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser, static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = { {kTlsHandshakeClientHello, FindClientHelloExtensions}, {kTlsHandshakeServerHello, FindServerHelloExtensions}, - {kTlsHandshakeHelloRetryRequest, FindHelloRetryExtensions}, {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions}, {kTlsHandshakeCertificateRequest, FindCertReqExtensions}, {kTlsHandshakeCertificate, FindCertificateExtensions}, @@ -505,10 +624,6 @@ bool TlsExtensionFilter::FindExtensions(TlsParser* parser, PacketFilter::Action TlsExtensionFilter::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (handshake_types_.count(header.handshake_type()) == 0) { - return KEEP; - } - TlsParser parser(input); if (!FindExtensions(&parser, header)) { return KEEP; @@ -610,6 +725,38 @@ PacketFilter::Action TlsExtensionDropper::FilterExtension( return KEEP; } +PacketFilter::Action TlsExtensionInjector::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindExtensions(&parser, header)) { + return KEEP; + } + size_t offset = parser.consumed(); + + *output = input; + + // Increase the size of the extensions. + uint16_t ext_len; + memcpy(&ext_len, output->data() + offset, sizeof(ext_len)); + ext_len = htons(ntohs(ext_len) + data_.len() + 4); + memcpy(output->data() + offset, &ext_len, sizeof(ext_len)); + + // Insert the extension type and length. + DataBuffer type_length; + type_length.Allocate(4); + type_length.Write(0, extension_, 2); + type_length.Write(2, data_.len(), 2); + output->Splice(type_length, offset + 2); + + // Insert the payload. + if (data_.len() > 0) { + output->Splice(data_, offset + 6); + } + + return CHANGE; +} + PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) { @@ -628,10 +775,8 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() == kTlsHandshakeClientKeyExchange) { - EXPECT_EQ(SECSuccess, - SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd())); - } + EXPECT_EQ(SECSuccess, + SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd())); return KEEP; } @@ -643,15 +788,49 @@ PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input, return ((1 << counter_++) & pattern_) ? DROP : KEEP; } +PacketFilter::Action SelectiveRecordDropFilter::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& data, + DataBuffer* changed) { + if (counter_ >= 32) { + return KEEP; + } + return ((1 << counter_++) & pattern_) ? DROP : KEEP; +} + +/* static */ uint32_t SelectiveRecordDropFilter::ToPattern( + std::initializer_list<size_t> records) { + uint32_t pattern = 0; + for (auto it = records.begin(); it != records.end(); ++it) { + EXPECT_GT(32U, *it); + assert(*it < 32U); + pattern |= 1 << *it; + } + return pattern; +} + PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() == kTlsHandshakeClientHello) { - *output = input; - output->Write(0, version_, 2); - return CHANGE; - } - return KEEP; + *output = input; + output->Write(0, version_, 2); + return CHANGE; +} + +PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + *output = input; + uint32_t temp = 0; + EXPECT_TRUE(input.Read(0, 2, &temp)); + // Cipher suite is after version(2) and random(32). + size_t pos = 34; + if (temp < SSL_LIBRARY_VERSION_TLS_1_3) { + // In old versions, we have to skip a session_id too. + EXPECT_TRUE(input.Read(pos, 1, &temp)); + pos += 1 + temp; + } + output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2); + return CHANGE; } } // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h index e4030e23f..1db3b90f6 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.h +++ b/security/nss/gtests/ssl_gtest/tls_filter.h @@ -50,10 +50,13 @@ class TlsRecordHeader : public TlsVersioned { uint8_t content_type() const { return content_type_; } uint64_t sequence_number() const { return sequence_number_; } - size_t header_length() const { return is_dtls() ? 11 : 3; } + uint16_t epoch() const { + return static_cast<uint16_t>(sequence_number_ >> 48); + } + size_t header_length() const { return is_dtls() ? 13 : 5; } // Parse the header; return true if successful; body in an outparam if OK. - bool Parse(TlsParser* parser, DataBuffer* body); + bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body); // Write the header and body to a buffer at the given offset. // Return the offset of the end of the write. size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; @@ -63,10 +66,21 @@ class TlsRecordHeader : public TlsVersioned { uint64_t sequence_number_; }; +struct TlsRecord { + const TlsRecordHeader header; + const DataBuffer buffer; +}; + // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: - TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {} + TlsRecordFilter() + : agent_(nullptr), + count_(0), + cipher_spec_(), + dropped_record_(false), + in_sequence_number_(0), + out_sequence_number_(0) {} void SetAgent(const TlsAgent* agent) { agent_ = agent; } const TlsAgent* agent() const { return agent_; } @@ -115,14 +129,21 @@ class TlsRecordFilter : public PacketFilter { const TlsAgent* agent_; size_t count_; std::unique_ptr<TlsCipherSpec> cipher_spec_; + // Whether we dropped a record since the cipher spec changed. + bool dropped_record_; + // The sequence number we use for reading records as they are written. + uint64_t in_sequence_number_; + // The sequence number we use for writing modified records. + uint64_t out_sequence_number_; }; -inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) { +inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) { v.WriteStream(stream); return stream; } -inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { +inline std::ostream& operator<<(std::ostream& stream, + const TlsRecordHeader& hdr) { hdr.WriteStream(stream); stream << ' '; switch (hdr.content_type()) { @@ -133,13 +154,17 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { stream << "Alert"; break; case kTlsHandshakeType: + case kTlsAltHandshakeType: stream << "Handshake"; break; case kTlsApplicationDataType: stream << "Data"; break; + case kTlsAckType: + stream << "ACK"; + break; default: - stream << '<' << hdr.content_type() << '>'; + stream << '<' << static_cast<int>(hdr.content_type()) << '>'; break; } return stream << ' ' << std::hex << hdr.sequence_number() << std::dec; @@ -150,7 +175,16 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { // records and that they don't span records or anything crazy like that. class TlsHandshakeFilter : public TlsRecordFilter { public: - TlsHandshakeFilter() {} + TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {} + TlsHandshakeFilter(const std::set<uint8_t>& types) + : handshake_types_(types), preceding_fragment_() {} + + // This filter can be set to be selective based on handshake message type. If + // this function isn't used (or the set is empty), then all handshake messages + // will be filtered. + void SetHandshakeTypes(const std::set<uint8_t>& types) { + handshake_types_ = types; + } class HandshakeHeader : public TlsVersioned { public: @@ -158,7 +192,8 @@ class TlsHandshakeFilter : public TlsRecordFilter { uint8_t handshake_type() const { return handshake_type_; } bool Parse(TlsParser* parser, const TlsRecordHeader& record_header, - DataBuffer* body); + const DataBuffer& preceding_fragment, DataBuffer* body, + bool* complete); size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; size_t WriteFragment(DataBuffer* buffer, size_t offset, @@ -169,7 +204,8 @@ class TlsHandshakeFilter : public TlsRecordFilter { // Reads the length from the record header. // This also reads the DTLS fragment information and checks it. bool ReadLength(TlsParser* parser, const TlsRecordHeader& header, - uint32_t* length); + uint32_t expected_offset, uint32_t* length, + bool* last_fragment); uint8_t handshake_type_; uint16_t message_seq_; @@ -185,22 +221,30 @@ class TlsHandshakeFilter : public TlsRecordFilter { DataBuffer* output) = 0; private: + bool IsFilteredType(const HandshakeHeader& header, + const DataBuffer& handshake); + + std::set<uint8_t> handshake_types_; + DataBuffer preceding_fragment_; }; // Make a copy of the first instance of a handshake message. class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { public: TlsInspectorRecordHandshakeMessage(uint8_t handshake_type) - : handshake_type_(handshake_type), buffer_() {} + : TlsHandshakeFilter({handshake_type}), buffer_() {} + TlsInspectorRecordHandshakeMessage(const std::set<uint8_t>& handshake_types) + : TlsHandshakeFilter(handshake_types), buffer_() {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); + void Reset() { buffer_.Truncate(0); } + const DataBuffer& buffer() const { return buffer_; } private: - uint8_t handshake_type_; DataBuffer buffer_; }; @@ -209,17 +253,39 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { public: TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type, const DataBuffer& replacement) - : handshake_type_(handshake_type), buffer_(replacement) {} + : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: - uint8_t handshake_type_; DataBuffer buffer_; }; +// Make a copy of each record of a given type. +class TlsRecordRecorder : public TlsRecordFilter { + public: + TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {} + TlsRecordRecorder() + : filter_(false), + ct_(content_handshake), // dummy (<optional> is C++14) + records_() {} + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + + size_t count() const { return records_.size(); } + void Clear() { records_.clear(); } + + const TlsRecord& record(size_t i) const { return records_[i]; } + + private: + bool filter_; + uint8_t ct_; + std::vector<TlsRecord> records_; +}; + // Make a copy of the complete conversation. class TlsConversationRecorder : public TlsRecordFilter { public: @@ -230,15 +296,31 @@ class TlsConversationRecorder : public TlsRecordFilter { DataBuffer* output); private: - DataBuffer& buffer_; + DataBuffer buffer_; }; +// Make a copy of the records +class TlsHeaderRecorder : public TlsRecordFilter { + public: + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + const TlsRecordHeader* header(size_t index); + + private: + std::vector<TlsRecordHeader> headers_; +}; + +typedef std::initializer_list<std::shared_ptr<PacketFilter>> + ChainedPacketFilterInit; + // Runs multiple packet filters in series. class ChainedPacketFilter : public PacketFilter { public: ChainedPacketFilter() {} ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters) : filters_(filters.begin(), filters.end()) {} + ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {} virtual ~ChainedPacketFilter() {} virtual PacketFilter::Action Filter(const DataBuffer& input, @@ -256,13 +338,13 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)> class TlsExtensionFilter : public TlsHandshakeFilter { public: - TlsExtensionFilter() : handshake_types_() { - handshake_types_.insert(kTlsHandshakeClientHello); - handshake_types_.insert(kTlsHandshakeServerHello); - } + TlsExtensionFilter() + : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello, + kTlsHandshakeHelloRetryRequest, + kTlsHandshakeEncryptedExtensions}) {} TlsExtensionFilter(const std::set<uint8_t>& types) - : handshake_types_(types) {} + : TlsHandshakeFilter(types) {} static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); @@ -279,8 +361,6 @@ class TlsExtensionFilter : public TlsHandshakeFilter { PacketFilter::Action FilterExtensions(TlsParser* parser, const DataBuffer& input, DataBuffer* output); - - std::set<uint8_t> handshake_types_; }; class TlsExtensionCapture : public TlsExtensionFilter { @@ -326,6 +406,21 @@ class TlsExtensionDropper : public TlsExtensionFilter { uint16_t extension_; }; +class TlsExtensionInjector : public TlsHandshakeFilter { + public: + TlsExtensionInjector(uint16_t ext, const DataBuffer& data) + : extension_(ext), data_(data) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + private: + const uint16_t extension_; + const DataBuffer data_; +}; + class TlsAgent; typedef std::function<void(void)> VoidFunction; @@ -352,7 +447,7 @@ class AfterRecordN : public TlsRecordFilter { class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter { public: TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server) - : server_(server) {} + : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -377,10 +472,47 @@ class SelectiveDropFilter : public PacketFilter { uint8_t counter_; }; +// This class selectively drops complete records. The difference from +// SelectiveDropFilter is that if multiple DTLS records are in the same +// datagram, we just drop one. +class SelectiveRecordDropFilter : public TlsRecordFilter { + public: + SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true) + : pattern_(pattern), counter_(0) { + if (!enabled) { + Disable(); + } + } + SelectiveRecordDropFilter(std::initializer_list<size_t> records) + : SelectiveRecordDropFilter(ToPattern(records), true) {} + + void Reset(uint32_t pattern) { + counter_ = 0; + PacketFilter::Enable(); + pattern_ = pattern; + } + + void Reset(std::initializer_list<size_t> records) { + Reset(ToPattern(records)); + } + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override; + + private: + static uint32_t ToPattern(std::initializer_list<size_t> records); + + uint32_t pattern_; + uint8_t counter_; +}; + // Set the version number in the ClientHello. class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { public: - TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {} + TlsInspectorClientHelloVersionSetter(uint16_t version) + : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -411,6 +543,20 @@ class TlsLastByteDamager : public TlsHandshakeFilter { uint8_t type_; }; +class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { + public: + SelectedCipherSuiteReplacer(uint16_t suite) + : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + private: + uint16_t cipher_suite_; +}; + } // namespace nss_test #endif diff --git a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc index 51ff938b1..45f6cf2bd 100644 --- a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc +++ b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc @@ -241,13 +241,13 @@ TEST_P(TlsHkdfTest, HkdfExpandLabel) { {/* ssl_hash_md5 */}, {/* ssl_hash_sha1 */}, {/* ssl_hash_sha224 */}, - {0x34, 0x7c, 0x67, 0x80, 0xff, 0x0b, 0xba, 0xd7, 0x1c, 0x28, 0x3b, - 0x16, 0xeb, 0x2f, 0x9c, 0xf6, 0x2d, 0x24, 0xe6, 0xcd, 0xb6, 0x13, - 0xd5, 0x17, 0x76, 0x54, 0x8c, 0xb0, 0x7d, 0xcd, 0xe7, 0x4c}, - {0x4b, 0x1e, 0x5e, 0xc1, 0x49, 0x30, 0x78, 0xea, 0x35, 0xbd, 0x3f, 0x01, - 0x04, 0xe6, 0x1a, 0xea, 0x14, 0xcc, 0x18, 0x2a, 0xd1, 0xc4, 0x76, 0x21, - 0xc4, 0x64, 0xc0, 0x4e, 0x4b, 0x36, 0x16, 0x05, 0x6f, 0x04, 0xab, 0xe9, - 0x43, 0xb1, 0x2d, 0xa8, 0xa7, 0x17, 0x9a, 0x5f, 0x09, 0x91, 0x7d, 0x1f}}; + {0xc6, 0xdd, 0x6e, 0xc4, 0x76, 0xb8, 0x55, 0xf2, 0xa4, 0xfc, 0x59, + 0x04, 0xa4, 0x90, 0xdc, 0xa7, 0xa7, 0x0d, 0x94, 0x8f, 0xc2, 0xdc, + 0x15, 0x6d, 0x48, 0x93, 0x9d, 0x05, 0xbb, 0x9a, 0xbc, 0xc1}, + {0x41, 0xea, 0x77, 0x09, 0x8c, 0x90, 0x04, 0x10, 0xec, 0xbc, 0x37, 0xd8, + 0x5b, 0x54, 0xcd, 0x7b, 0x08, 0x15, 0x13, 0x20, 0xed, 0x1e, 0x3f, 0x54, + 0x74, 0xf7, 0x8b, 0x06, 0x38, 0x28, 0x06, 0x37, 0x75, 0x23, 0xa2, 0xb7, + 0x34, 0xb1, 0x72, 0x2e, 0x59, 0x6d, 0x5a, 0x31, 0xf5, 0x53, 0xab, 0x99}}; const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); HkdfExpandLabel(&k1_, hash_type_, kSessionHash, kHashLength[hash_type_], diff --git a/security/nss/gtests/ssl_gtest/tls_protect.cc b/security/nss/gtests/ssl_gtest/tls_protect.cc index efcd89e14..6c945f66e 100644 --- a/security/nss/gtests/ssl_gtest/tls_protect.cc +++ b/security/nss/gtests/ssl_gtest/tls_protect.cc @@ -32,7 +32,6 @@ void AeadCipher::FormatNonce(uint64_t seq, uint8_t *nonce) { } DataBuffer d(nonce, 12); - std::cerr << "Nonce " << d << std::endl; } bool AeadCipher::AeadInner(bool decrypt, void *params, size_t param_length, @@ -92,8 +91,9 @@ bool AeadCipherChacha20Poly1305::Aead(bool decrypt, uint64_t seq, in, inlen, out, outlen, maxlen); } -bool TlsCipherSpec::Init(SSLCipherAlgorithm cipher, PK11SymKey *key, - const uint8_t *iv) { +bool TlsCipherSpec::Init(uint16_t epoch, SSLCipherAlgorithm cipher, + PK11SymKey *key, const uint8_t *iv) { + epoch_ = epoch; switch (cipher) { case ssl_calg_aes_gcm: aead_.reset(new AeadCipherAesGcm()); diff --git a/security/nss/gtests/ssl_gtest/tls_protect.h b/security/nss/gtests/ssl_gtest/tls_protect.h index 4efbd6e6b..93ffd6322 100644 --- a/security/nss/gtests/ssl_gtest/tls_protect.h +++ b/security/nss/gtests/ssl_gtest/tls_protect.h @@ -20,7 +20,7 @@ class TlsRecordHeader; class AeadCipher { public: AeadCipher(CK_MECHANISM_TYPE mech) : mech_(mech), key_(nullptr) {} - ~AeadCipher(); + virtual ~AeadCipher(); bool Init(PK11SymKey *key, const uint8_t *iv); virtual bool Aead(bool decrypt, uint64_t seq, const uint8_t *in, size_t inlen, @@ -58,16 +58,19 @@ class AeadCipherAesGcm : public AeadCipher { // Our analog of ssl3CipherSpec class TlsCipherSpec { public: - TlsCipherSpec() : aead_() {} + TlsCipherSpec() : epoch_(0), aead_() {} - bool Init(SSLCipherAlgorithm cipher, PK11SymKey *key, const uint8_t *iv); + bool Init(uint16_t epoch, SSLCipherAlgorithm cipher, PK11SymKey *key, + const uint8_t *iv); bool Protect(const TlsRecordHeader &header, const DataBuffer &plaintext, DataBuffer *ciphertext); bool Unprotect(const TlsRecordHeader &header, const DataBuffer &ciphertext, DataBuffer *plaintext); + uint16_t epoch() const { return epoch_; } private: + uint16_t epoch_; std::unique_ptr<AeadCipher> aead_; }; |