summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest
diff options
context:
space:
mode:
authorwolfbeast <mcwerewolf@gmail.com>2018-06-06 21:27:04 +0200
committerwolfbeast <mcwerewolf@gmail.com>2018-06-06 21:27:04 +0200
commit4a71b30364a4b6d1eaf16fcfdc8e873e6697f293 (patch)
treea47014077c14579249859ad34afcc5a8f2f0730a /security/nss/gtests/ssl_gtest
parentd7da72799521386c110dbba73b1e483b00a0a56a (diff)
parent2dad0ec41d0b69c0a815012e6ea4bdde81b2875b (diff)
downloadUXP-4a71b30364a4b6d1eaf16fcfdc8e873e6697f293.tar
UXP-4a71b30364a4b6d1eaf16fcfdc8e873e6697f293.tar.gz
UXP-4a71b30364a4b6d1eaf16fcfdc8e873e6697f293.tar.lz
UXP-4a71b30364a4b6d1eaf16fcfdc8e873e6697f293.tar.xz
UXP-4a71b30364a4b6d1eaf16fcfdc8e873e6697f293.zip
Merge branch 'NSS-335'
Diffstat (limited to 'security/nss/gtests/ssl_gtest')
-rw-r--r--security/nss/gtests/ssl_gtest/Makefile4
-rw-r--r--security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc108
-rw-r--r--security/nss/gtests/ssl_gtest/libssl_internals.c139
-rw-r--r--security/nss/gtests/ssl_gtest/libssl_internals.h20
-rw-r--r--security/nss/gtests/ssl_gtest/manifest.mn7
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc271
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc17
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc32
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc25
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc53
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc503
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc24
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc74
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc742
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc22
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc1
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc80
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc10
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc63
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_gtest.cc2
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_gtest.gyp7
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc678
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc118
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc178
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc251
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc20
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_record_unittest.cc73
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc212
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc274
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc35
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc4
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc337
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc11
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_version_unittest.cc117
-rw-r--r--security/nss/gtests/ssl_gtest/test_io.cc9
-rw-r--r--security/nss/gtests/ssl_gtest/test_io.h20
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.cc205
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.h19
-rw-r--r--security/nss/gtests/ssl_gtest/tls_connect.cc190
-rw-r--r--security/nss/gtests/ssl_gtest/tls_connect.h31
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.cc321
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.h192
-rw-r--r--security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc14
-rw-r--r--security/nss/gtests/ssl_gtest/tls_protect.cc6
-rw-r--r--security/nss/gtests/ssl_gtest/tls_protect.h9
45 files changed, 4600 insertions, 928 deletions
diff --git a/security/nss/gtests/ssl_gtest/Makefile b/security/nss/gtests/ssl_gtest/Makefile
index a9a9290e0..95c111aeb 100644
--- a/security/nss/gtests/ssl_gtest/Makefile
+++ b/security/nss/gtests/ssl_gtest/Makefile
@@ -29,10 +29,6 @@ include ../common/gtest.mk
CFLAGS += -I$(CORE_DEPTH)/lib/ssl
-ifdef NSS_SSL_ENABLE_ZLIB
-include $(CORE_DEPTH)/coreconf/zlib.mk
-endif
-
ifdef NSS_DISABLE_TLS_1_3
NSS_DISABLE_TLS_1_3=1
# Run parameterized tests only, for which we can easily exclude TLS 1.3
diff --git a/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc
new file mode 100644
index 000000000..110cfa13a
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc
@@ -0,0 +1,108 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+extern "C" {
+#include "sslbloom.h"
+}
+
+#include "gtest_utils.h"
+
+namespace nss_test {
+
+// Some random-ish inputs to test with. These don't result in collisions in any
+// of the configurations that are tested below.
+static const uint8_t kHashes1[] = {
+ 0x79, 0x53, 0xb8, 0xdd, 0x6b, 0x98, 0xce, 0x00, 0xb7, 0xdc, 0xe8,
+ 0x03, 0x70, 0x8c, 0xe3, 0xac, 0x06, 0x8b, 0x22, 0xfd, 0x0e, 0x34,
+ 0x48, 0xe6, 0xe5, 0xe0, 0x8a, 0xd6, 0x16, 0x18, 0xe5, 0x48};
+static const uint8_t kHashes2[] = {
+ 0xc6, 0xdd, 0x6e, 0xc4, 0x76, 0xb8, 0x55, 0xf2, 0xa4, 0xfc, 0x59,
+ 0x04, 0xa4, 0x90, 0xdc, 0xa7, 0xa7, 0x0d, 0x94, 0x8f, 0xc2, 0xdc,
+ 0x15, 0x6d, 0x48, 0x93, 0x9d, 0x05, 0xbb, 0x9a, 0xbc, 0xc1};
+
+typedef struct {
+ unsigned int k;
+ unsigned int bits;
+} BloomFilterConfig;
+
+class BloomFilterTest
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<BloomFilterConfig> {
+ public:
+ BloomFilterTest() : filter_() {}
+
+ void SetUp() { Init(); }
+
+ void TearDown() { sslBloom_Destroy(&filter_); }
+
+ protected:
+ void Init() {
+ if (filter_.filter) {
+ sslBloom_Destroy(&filter_);
+ }
+ ASSERT_EQ(SECSuccess,
+ sslBloom_Init(&filter_, GetParam().k, GetParam().bits));
+ }
+
+ bool Check(const uint8_t* hashes) {
+ return sslBloom_Check(&filter_, hashes) ? true : false;
+ }
+
+ void Add(const uint8_t* hashes, bool expect_collision = false) {
+ EXPECT_EQ(expect_collision, sslBloom_Add(&filter_, hashes) ? true : false);
+ EXPECT_TRUE(Check(hashes));
+ }
+
+ sslBloomFilter filter_;
+};
+
+TEST_P(BloomFilterTest, InitOnly) {}
+
+TEST_P(BloomFilterTest, AddToEmpty) {
+ EXPECT_FALSE(Check(kHashes1));
+ Add(kHashes1);
+}
+
+TEST_P(BloomFilterTest, AddTwo) {
+ Add(kHashes1);
+ Add(kHashes2);
+}
+
+TEST_P(BloomFilterTest, AddOneTwice) {
+ Add(kHashes1);
+ Add(kHashes1, true);
+}
+
+TEST_P(BloomFilterTest, Zero) {
+ Add(kHashes1);
+ sslBloom_Zero(&filter_);
+ EXPECT_FALSE(Check(kHashes1));
+ EXPECT_FALSE(Check(kHashes2));
+}
+
+TEST_P(BloomFilterTest, Fill) {
+ sslBloom_Fill(&filter_);
+ EXPECT_TRUE(Check(kHashes1));
+ EXPECT_TRUE(Check(kHashes2));
+}
+
+static const BloomFilterConfig kBloomFilterConfigurations[] = {
+ {1, 1}, // 1 hash, 1 bit input - high chance of collision.
+ {1, 2}, // 1 hash, 2 bits - smaller than the basic unit size.
+ {1, 3}, // 1 hash, 3 bits - same as basic unit size.
+ {1, 4}, // 1 hash, 4 bits - 2 octets each.
+ {3, 10}, // 3 hashes over a reasonable number of bits.
+ {3, 3}, // Test that we can read multiple bits.
+ {4, 15}, // A credible filter.
+ {2, 18}, // A moderately large allocation.
+ {16, 16}, // Insane, use all of the bits from the hashes.
+ {16, 9}, // This also uses all of the bits from the hashes.
+};
+
+INSTANTIATE_TEST_CASE_P(BloomFilterConfigurations, BloomFilterTest,
+ ::testing::ValuesIn(kBloomFilterConfigurations));
+
+} // namespace nspr_test
diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.c b/security/nss/gtests/ssl_gtest/libssl_internals.c
index 97b8354ae..887d85278 100644
--- a/security/nss/gtests/ssl_gtest/libssl_internals.c
+++ b/security/nss/gtests/ssl_gtest/libssl_internals.c
@@ -34,18 +34,17 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
return SECFailure;
}
- ssl3_InitState(ss);
ssl3_RestartHandshakeHashes(ss);
// Ensure we don't overrun hs.client_random.
rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len);
- // Zero the client_random struct.
- PORT_Memset(&ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH);
+ // Zero the client_random.
+ PORT_Memset(ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH);
// Copy over the challenge bytes.
size_t offset = SSL3_RANDOM_LENGTH - rnd_len;
- PORT_Memcpy(&ss->ssl3.hs.client_random.rand[offset], rnd, rnd_len);
+ PORT_Memcpy(ss->ssl3.hs.client_random + offset, rnd, rnd_len);
// Rehash the SSLv2 client hello message.
return ssl3_UpdateHandshakeHashes(ss, msg, msg_len);
@@ -73,10 +72,11 @@ SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) {
return SECFailure;
}
ss->ssl3.mtu = mtu;
+ ss->ssl3.hs.rtRetries = 0; /* Avoid DTLS shrinking the MTU any more. */
return SECSuccess;
}
-PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) {
+PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd) {
PRCList *cur_p;
PRInt32 ct = 0;
@@ -92,7 +92,7 @@ PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) {
return ct;
}
-void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) {
+void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd) {
PRCList *cur_p;
sslSocket *ss = ssl_FindSocket(fd);
@@ -100,27 +100,31 @@ void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) {
return;
}
- fprintf(stderr, "Cipher specs\n");
+ fprintf(stderr, "Cipher specs for %s\n", label);
for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs);
cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) {
ssl3CipherSpec *spec = (ssl3CipherSpec *)cur_p;
- fprintf(stderr, " %s\n", spec->phase);
+ fprintf(stderr, " %s spec epoch=%d (%s) refct=%d\n", SPEC_DIR(spec),
+ spec->epoch, spec->phase, spec->refCt);
}
}
-/* Force a timer expiry by backdating when the timer was started.
- * We could set the remaining time to 0 but then backoff would not
- * work properly if we decide to test it. */
-void SSLInt_ForceTimerExpiry(PRFileDesc *fd) {
+/* Force a timer expiry by backdating when all active timers were started. We
+ * could set the remaining time to 0 but then backoff would not work properly if
+ * we decide to test it. */
+SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift) {
+ size_t i;
sslSocket *ss = ssl_FindSocket(fd);
if (!ss) {
- return;
+ return SECFailure;
}
- if (!ss->ssl3.hs.rtTimerCb) return;
-
- ss->ssl3.hs.rtTimerStarted =
- PR_IntervalNow() - PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs + 1);
+ for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) {
+ if (ss->ssl3.hs.timers[i].cb) {
+ ss->ssl3.hs.timers[i].started -= shift;
+ }
+ }
+ return SECSuccess;
}
#define CHECK_SECRET(secret) \
@@ -136,7 +140,6 @@ PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) {
}
CHECK_SECRET(currentSecret);
- CHECK_SECRET(resumptionMasterSecret);
CHECK_SECRET(dheSecret);
CHECK_SECRET(clientEarlyTrafficSecret);
CHECK_SECRET(clientHsTrafficSecret);
@@ -226,28 +229,7 @@ PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) {
return PR_TRUE;
}
-PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd) {
- sslSocket *ss = ssl_FindSocket(fd);
- if (!ss) {
- return PR_FALSE;
- }
-
- ssl_GetSSL3HandshakeLock(ss);
- ssl_GetXmitBufLock(ss);
-
- SECStatus rv = tls13_SendNewSessionTicket(ss);
- if (rv == SECSuccess) {
- rv = ssl3_FlushHandshake(ss, 0);
- }
-
- ssl_ReleaseXmitBufLock(ss);
- ssl_ReleaseSSL3HandshakeLock(ss);
-
- return rv == SECSuccess;
-}
-
SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) {
- PRUint64 epoch;
sslSocket *ss;
ssl3CipherSpec *spec;
@@ -255,43 +237,40 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) {
if (!ss) {
return SECFailure;
}
- if (to >= (1ULL << 48)) {
+ if (to >= RECORD_SEQ_MAX) {
PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
ssl_GetSpecWriteLock(ss);
spec = ss->ssl3.crSpec;
- epoch = spec->read_seq_num >> 48;
- spec->read_seq_num = (epoch << 48) | to;
+ spec->seqNum = to;
/* For DTLS, we need to fix the record sequence number. For this, we can just
* scrub the entire structure on the assumption that the new sequence number
* is far enough past the last received sequence number. */
- if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
+ if (spec->seqNum <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
- dtls_RecordSetRecvd(&spec->recvdRecords, to);
+ dtls_RecordSetRecvd(&spec->recvdRecords, spec->seqNum);
ssl_ReleaseSpecWriteLock(ss);
return SECSuccess;
}
SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) {
- PRUint64 epoch;
sslSocket *ss;
ss = ssl_FindSocket(fd);
if (!ss) {
return SECFailure;
}
- if (to >= (1ULL << 48)) {
+ if (to >= RECORD_SEQ_MAX) {
PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
ssl_GetSpecWriteLock(ss);
- epoch = ss->ssl3.cwSpec->write_seq_num >> 48;
- ss->ssl3.cwSpec->write_seq_num = (epoch << 48) | to;
+ ss->ssl3.cwSpec->seqNum = to;
ssl_ReleaseSpecWriteLock(ss);
return SECSuccess;
}
@@ -305,9 +284,9 @@ SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) {
return SECFailure;
}
ssl_GetSpecReadLock(ss);
- to = ss->ssl3.cwSpec->write_seq_num + DTLS_RECVD_RECORDS_WINDOW + extra;
+ to = ss->ssl3.cwSpec->seqNum + DTLS_RECVD_RECORDS_WINDOW + extra;
ssl_ReleaseSpecReadLock(ss);
- return SSLInt_AdvanceWriteSeqNum(fd, to & RECORD_SEQ_MAX);
+ return SSLInt_AdvanceWriteSeqNum(fd, to);
}
SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) {
@@ -333,46 +312,20 @@ SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd,
return SECSuccess;
}
-static ssl3KeyMaterial *GetKeyingMaterial(PRBool isServer,
- ssl3CipherSpec *spec) {
- return isServer ? &spec->server : &spec->client;
+PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec) {
+ return spec->keyMaterial.key;
}
-PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec) {
- return GetKeyingMaterial(isServer, spec)->write_key;
+SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec) {
+ return spec->cipherDef->calg;
}
-SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer,
- ssl3CipherSpec *spec) {
- return spec->cipher_def->calg;
+const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec) {
+ return spec->keyMaterial.iv;
}
-unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec) {
- return GetKeyingMaterial(isServer, spec)->write_iv;
-}
-
-SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd) {
- sslSocket *ss;
-
- ss = ssl_FindSocket(fd);
- if (!ss) {
- return SECFailure;
- }
-
- ss->opt.enableShortHeaders = PR_TRUE;
- return SECSuccess;
-}
-
-SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result) {
- sslSocket *ss;
-
- ss = ssl_FindSocket(fd);
- if (!ss) {
- return SECFailure;
- }
-
- *result = ss->ssl3.hs.shortHeaders;
- return SECSuccess;
+PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec) {
+ return spec->epoch;
}
void SSLInt_SetTicketLifetime(uint32_t lifetime) {
@@ -405,3 +358,21 @@ SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) {
return SECSuccess;
}
+
+void SSLInt_RolloverAntiReplay(void) {
+ tls13_AntiReplayRollover(ssl_TimeUsec());
+}
+
+SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch,
+ PRUint16 *writeEpoch) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss || !readEpoch || !writeEpoch) {
+ return SECFailure;
+ }
+
+ ssl_GetSpecReadLock(ss);
+ *readEpoch = ss->ssl3.crSpec->epoch;
+ *writeEpoch = ss->ssl3.cwSpec->epoch;
+ ssl_ReleaseSpecReadLock(ss);
+ return SECSuccess;
+}
diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.h b/security/nss/gtests/ssl_gtest/libssl_internals.h
index 33709c4b4..95d4afdaf 100644
--- a/security/nss/gtests/ssl_gtest/libssl_internals.h
+++ b/security/nss/gtests/ssl_gtest/libssl_internals.h
@@ -24,9 +24,9 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext);
void SSLInt_ClearSelfEncryptKey();
void SSLInt_SetSelfEncryptMacKey(PK11SymKey *key);
-PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd);
-void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd);
-void SSLInt_ForceTimerExpiry(PRFileDesc *fd);
+PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd);
+void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd);
+SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift);
SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu);
PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd);
PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd);
@@ -35,23 +35,23 @@ PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd);
SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len);
PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType);
PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type);
-PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd);
SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to);
SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to);
SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra);
SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group);
+SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch,
+ PRUint16 *writeEpoch);
SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd,
sslCipherSpecChangedFunc func,
void *arg);
-PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec);
-SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer,
- ssl3CipherSpec *spec);
-unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec);
-SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd);
-SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result);
+PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec);
+PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec);
+SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec);
+const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec);
void SSLInt_SetTicketLifetime(uint32_t lifetime);
void SSLInt_SetMaxEarlyDataSize(uint32_t size);
SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size);
+void SSLInt_RolloverAntiReplay(void);
#endif // ndef libssl_internals_h_
diff --git a/security/nss/gtests/ssl_gtest/manifest.mn b/security/nss/gtests/ssl_gtest/manifest.mn
index cc729c0f1..5d893bab3 100644
--- a/security/nss/gtests/ssl_gtest/manifest.mn
+++ b/security/nss/gtests/ssl_gtest/manifest.mn
@@ -12,11 +12,13 @@ CSRCS = \
$(NULL)
CPPSRCS = \
+ bloomfilter_unittest.cc \
ssl_0rtt_unittest.cc \
ssl_agent_unittest.cc \
ssl_auth_unittest.cc \
ssl_cert_ext_unittest.cc \
ssl_ciphersuite_unittest.cc \
+ ssl_custext_unittest.cc \
ssl_damage_unittest.cc \
ssl_dhe_unittest.cc \
ssl_drop_unittest.cc \
@@ -29,11 +31,16 @@ CPPSRCS = \
ssl_gather_unittest.cc \
ssl_gtest.cc \
ssl_hrr_unittest.cc \
+ ssl_keylog_unittest.cc \
+ ssl_keyupdate_unittest.cc \
ssl_loopback_unittest.cc \
+ ssl_misc_unittest.cc \
ssl_record_unittest.cc \
ssl_resumption_unittest.cc \
+ ssl_renegotiation_unittest.cc \
ssl_skip_unittest.cc \
ssl_staticrsa_unittest.cc \
+ ssl_tls13compat_unittest.cc \
ssl_v2_client_hello_unittest.cc \
ssl_version_unittest.cc \
ssl_versionpolicy_unittest.cc \
diff --git a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
index 85b7011a1..a60295490 100644
--- a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
@@ -7,6 +7,7 @@
#include "secerr.h"
#include "ssl.h"
#include "sslerr.h"
+#include "sslexp.h"
#include "sslproto.h"
extern "C" {
@@ -44,6 +45,93 @@ TEST_P(TlsConnectTls13, ZeroRttServerRejectByOption) {
SendReceive();
}
+TEST_P(TlsConnectTls13, ZeroRttApparentReplayAfterRestart) {
+ // The test fixtures call SSL_SetupAntiReplay() in SetUp(). This results in
+ // 0-RTT being rejected until at least one window passes. SetupFor0Rtt()
+ // forces a rollover of the anti-replay filters, which clears this state.
+ // Here, we do the setup manually here without that forced rollover.
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT.
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ StartConnect();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+class TlsZeroRttReplayTest : public TlsConnectTls13 {
+ private:
+ class SaveFirstPacket : public PacketFilter {
+ public:
+ PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override {
+ if (!packet_.len() && input.len()) {
+ packet_ = input;
+ }
+ return KEEP;
+ }
+
+ const DataBuffer& packet() const { return packet_; }
+
+ private:
+ DataBuffer packet_;
+ };
+
+ protected:
+ void RunTest(bool rollover) {
+ // Run the initial handshake
+ SetupForZeroRtt();
+
+ // Now run a true 0-RTT handshake, but capture the first packet.
+ auto first_packet = std::make_shared<SaveFirstPacket>();
+ client_->SetPacketFilter(first_packet);
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ EXPECT_LT(0U, first_packet->packet().len());
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+
+ if (rollover) {
+ SSLInt_RolloverAntiReplay();
+ }
+
+ // Now replay that packet against the server.
+ Reset();
+ server_->StartConnect();
+ server_->Set0RttEnabled(true);
+
+ // Capture the early_data extension, which should not appear.
+ auto early_data_ext =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
+ server_->SetPacketFilter(early_data_ext);
+
+ // Finally, replay the ClientHello and force the server to consume it. Stop
+ // after the server sends its first flight; the client will not be able to
+ // complete this handshake.
+ server_->adapter()->PacketReceived(first_packet->packet());
+ server_->Handshake();
+ EXPECT_FALSE(early_data_ext->captured());
+ }
+};
+
+TEST_P(TlsZeroRttReplayTest, ZeroRttReplay) { RunTest(false); }
+
+TEST_P(TlsZeroRttReplayTest, ZeroRttReplayAfterRollover) { RunTest(true); }
+
// Test that we don't try to send 0-RTT data when the server sent
// us a ticket without the 0-RTT flags.
TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) {
@@ -52,8 +140,7 @@ TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) {
SendReceive(); // Need to read so that we absorb the session ticket.
CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
Reset();
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
// Now turn on 0-RTT but too late for the ticket.
client_->Set0RttEnabled(true);
server_->Set0RttEnabled(true);
@@ -80,8 +167,7 @@ TEST_P(TlsConnectTls13, ZeroRttServerForgetTicket) {
TEST_P(TlsConnectTls13, ZeroRttServerOnly) {
ExpectResumption(RESUME_NONE);
server_->Set0RttEnabled(true);
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
// Client sends ordinary ClientHello.
client_->Handshake();
@@ -99,6 +185,61 @@ TEST_P(TlsConnectTls13, ZeroRttServerOnly) {
CheckKeys();
}
+// A small sleep after sending the ClientHello means that the ticket age that
+// arrives at the server is too low. With a small tolerance for variation in
+// ticket age (which is determined by the |window| parameter that is passed to
+// SSL_SetupAntiReplay()), the server then rejects early data.
+TEST_P(TlsConnectTls13, ZeroRttRejectOldTicket) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ EXPECT_EQ(SECSuccess, SSL_SetupAntiReplay(1, 1, 3));
+ SSLInt_RolloverAntiReplay(); // Make sure to flush replay state.
+ SSLInt_RolloverAntiReplay();
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false, []() {
+ PR_Sleep(PR_MillisecondsToInterval(10));
+ return true;
+ });
+ Handshake();
+ ExpectEarlyDataAccepted(false);
+ CheckConnected();
+ SendReceive();
+}
+
+// In this test, we falsely inflate the estimate of the RTT by delaying the
+// ServerHello on the first handshake. This results in the server estimating a
+// higher value of the ticket age than the client ultimately provides. Add a
+// small tolerance for variation in ticket age and the ticket will appear to
+// arrive prematurely, causing the server to reject early data.
+TEST_P(TlsConnectTls13, ZeroRttRejectPrematureTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->Set0RttEnabled(true);
+ StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+ PR_Sleep(PR_MillisecondsToInterval(10));
+ Handshake(); // Remainder of handshake
+ CheckConnected();
+ SendReceive();
+ CheckKeys();
+
+ Reset();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ EXPECT_EQ(SECSuccess, SSL_SetupAntiReplay(1, 1, 3));
+ SSLInt_RolloverAntiReplay(); // Make sure to flush replay state.
+ SSLInt_RolloverAntiReplay();
+ ExpectResumption(RESUME_TICKET);
+ ExpectEarlyDataAccepted(false);
+ StartConnect();
+ ZeroRttSendReceive(true, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) {
EnableAlpn();
SetupForZeroRtt();
@@ -117,6 +258,14 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) {
CheckAlpn("a");
}
+// NOTE: In this test and those below, the client always sends
+// post-ServerHello alerts with the handshake keys, even if the server
+// has accepted 0-RTT. In some cases, as with errors in
+// EncryptedExtensions, the client can't know the server's behavior,
+// and in others it's just simpler. What the server is expecting
+// depends on whether it accepted 0-RTT or not. Eventually, we may
+// make the server trial decrypt.
+//
// Have the server negotiate a different ALPN value, and therefore
// reject 0-RTT.
TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeServer) {
@@ -155,12 +304,17 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnServer) {
client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a");
EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, sizeof(b)));
client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
- ExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
return true;
});
- Handshake();
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ } else {
+ client_->Handshake();
+ }
client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID);
- server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
}
// Set up with no ALPN and then set the client so it thinks it has ALPN.
@@ -175,12 +329,17 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnClient) {
PRUint8 b[] = {'b'};
EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, 1));
client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
- ExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
return true;
});
- Handshake();
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ } else {
+ client_->Handshake();
+ }
client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID);
- server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
}
// Remove the old ALPN value and so the client will not offer early data.
@@ -218,9 +377,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
- client_->StartConnect();
- server_->StartConnect();
-
+ StartConnect();
// We will send the early data xtn without sending actual early data. Thus
// a 1.2 server shouldn't fail until the client sends an alert because the
// client sends end_of_early_data only after reading the server's flight.
@@ -261,9 +418,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
- client_->StartConnect();
- server_->StartConnect();
-
+ StartConnect();
// Send the early data xtn in the CH, followed by early app data. The server
// will fail right after sending its flight, when receiving the early data.
client_->Set0RttEnabled(true);
@@ -310,7 +465,6 @@ TEST_P(TlsConnectTls13, SendTooMuchEarlyData) {
server_->Set0RttEnabled(true);
ExpectResumption(RESUME_TICKET);
- ExpectAlert(client_, kTlsAlertEndOfEarlyData);
client_->Handshake();
CheckEarlyDataLimit(client_, short_size);
@@ -364,7 +518,6 @@ TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) {
server_->Set0RttEnabled(true);
ExpectResumption(RESUME_TICKET);
- client_->ExpectSendAlert(kTlsAlertEndOfEarlyData);
client_->Handshake(); // Send ClientHello
CheckEarlyDataLimit(client_, limit);
@@ -399,4 +552,86 @@ TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) {
}
}
+class PacketCoalesceFilter : public PacketFilter {
+ public:
+ PacketCoalesceFilter() : packet_data_() {}
+
+ void SendCoalesced(std::shared_ptr<TlsAgent> agent) {
+ agent->SendDirect(packet_data_);
+ }
+
+ protected:
+ PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override {
+ packet_data_.Write(packet_data_.len(), input);
+ return DROP;
+ }
+
+ private:
+ DataBuffer packet_data_;
+};
+
+TEST_P(TlsConnectTls13, ZeroRttOrdering) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ // Send out the ClientHello.
+ client_->Handshake();
+
+ // Now, coalesce the next three things from the client: early data, second
+ // flight and 1-RTT data.
+ auto coalesce = std::make_shared<PacketCoalesceFilter>();
+ client_->SetPacketFilter(coalesce);
+
+ // Send (and hold) early data.
+ static const std::vector<uint8_t> early_data = {3, 2, 1};
+ EXPECT_EQ(static_cast<PRInt32>(early_data.size()),
+ PR_Write(client_->ssl_fd(), early_data.data(), early_data.size()));
+
+ // Send (and hold) the second client handshake flight.
+ // The client sends EndOfEarlyData after seeing the server Finished.
+ server_->Handshake();
+ client_->Handshake();
+
+ // Send (and hold) 1-RTT data.
+ static const std::vector<uint8_t> late_data = {7, 8, 9, 10};
+ EXPECT_EQ(static_cast<PRInt32>(late_data.size()),
+ PR_Write(client_->ssl_fd(), late_data.data(), late_data.size()));
+
+ // Now release them all at once.
+ coalesce->SendCoalesced(client_);
+
+ // Now ensure that the three steps are exposed in the right order on the
+ // server: delivery of early data, handshake callback, delivery of 1-RTT.
+ size_t step = 0;
+ server_->SetHandshakeCallback([&step](TlsAgent*) {
+ EXPECT_EQ(1U, step);
+ ++step;
+ });
+
+ std::vector<uint8_t> buf(10);
+ PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.size());
+ ASSERT_EQ(static_cast<PRInt32>(early_data.size()), read);
+ buf.resize(read);
+ EXPECT_EQ(early_data, buf);
+ EXPECT_EQ(0U, step);
+ ++step;
+
+ // The third read should be after the handshake callback and should return the
+ // data that was sent after the handshake completed.
+ buf.resize(10);
+ read = PR_Read(server_->ssl_fd(), buf.data(), buf.size());
+ ASSERT_EQ(static_cast<PRInt32>(late_data.size()), read);
+ buf.resize(read);
+ EXPECT_EQ(late_data, buf);
+ EXPECT_EQ(2U, step);
+}
+
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_CASE_P(Tls13ZeroRttReplayTest, TlsZeroRttReplayTest,
+ TlsConnectTestBase::kTlsVariantsAll);
+#endif
+
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
index 5035a338d..0aa9a4c78 100644
--- a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
@@ -31,7 +31,7 @@ const static uint8_t kCannedTls13ClientHello[] = {
0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x09, 0x00, 0x00, 0x06,
0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00,
0x0a, 0x00, 0x12, 0x00, 0x10, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x01,
- 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x28, 0x00,
+ 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x33, 0x00,
0x47, 0x00, 0x45, 0x00, 0x17, 0x00, 0x41, 0x04, 0x86, 0x4a, 0xb9, 0xdc,
0x6a, 0x38, 0xa7, 0xce, 0xe7, 0xc2, 0x4f, 0xa6, 0x28, 0xb9, 0xdc, 0x65,
0xbf, 0x73, 0x47, 0x3c, 0x9c, 0x65, 0x8c, 0x47, 0x6d, 0x57, 0x22, 0x8a,
@@ -44,13 +44,14 @@ const static uint8_t kCannedTls13ClientHello[] = {
0x02, 0x05, 0x02, 0x06, 0x02, 0x02, 0x02};
const static uint8_t kCannedTls13ServerHello[] = {
- 0x7f, kD13, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, 0xf0,
- 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, 0xdf, 0xe5,
- 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, 0x08, 0x13, 0x01,
- 0x00, 0x28, 0x00, 0x28, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf,
- 0x23, 0x17, 0x64, 0x23, 0x03, 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65,
- 0x24, 0xa1, 0x6c, 0xa9, 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a,
- 0xcb, 0xe3, 0x08, 0x84, 0xae, 0x19};
+ 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3,
+ 0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b,
+ 0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76,
+ 0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24,
+ 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03,
+ 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9,
+ 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08,
+ 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x7f, kD13};
static const char *k0RttData = "ABCDEF";
TEST_P(TlsAgentTest, EarlyFinished) {
diff --git a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
index dbcbc9aa3..dbcdd92ea 100644
--- a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
@@ -29,7 +29,25 @@ TEST_P(TlsConnectGeneric, ServerAuthBigRsa) {
}
TEST_P(TlsConnectGeneric, ServerAuthRsaChain) {
- Reset(TlsAgent::kServerRsaChain);
+ Reset("rsa_chain");
+ Connect();
+ CheckKeys();
+ size_t chain_length;
+ EXPECT_TRUE(client_->GetPeerChainLength(&chain_length));
+ EXPECT_EQ(2UL, chain_length);
+}
+
+TEST_P(TlsConnectGeneric, ServerAuthRsaPssChain) {
+ Reset("rsa_pss_chain");
+ Connect();
+ CheckKeys();
+ size_t chain_length;
+ EXPECT_TRUE(client_->GetPeerChainLength(&chain_length));
+ EXPECT_EQ(2UL, chain_length);
+}
+
+TEST_P(TlsConnectGeneric, ServerAuthRsaCARsaPssChain) {
+ Reset("rsa_ca_rsa_pss_chain");
Connect();
CheckKeys();
size_t chain_length;
@@ -141,13 +159,11 @@ TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) {
class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter {
public:
+ TlsZeroCertificateRequestSigAlgsFilter()
+ : TlsHandshakeFilter({kTlsHandshakeCertificateRequest}) {}
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
- if (header.handshake_type() != kTlsHandshakeCertificateRequest) {
- return KEEP;
- }
-
TlsParser parser(input);
std::cerr << "Zeroing CertReq.supported_signature_algorithms" << std::endl;
@@ -581,8 +597,7 @@ class EnforceNoActivity : public PacketFilter {
TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
client_->Handshake(); // Send ClientHello
server_->Handshake(); // Send ServerHello
client_->Handshake(); // Send ClientKeyExchange and Finished
@@ -610,8 +625,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
TEST_P(TlsConnectTls13, AuthCompleteDelayed) {
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
client_->Handshake(); // Send ClientHello
server_->Handshake(); // Send ServerHello
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
diff --git a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
index 3463782e0..36ee104af 100644
--- a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
@@ -82,9 +82,8 @@ TEST_P(TlsConnectGenericPre13, SignedCertificateTimestampsLegacy) {
ssl_kea_rsa));
EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(),
&kSctItem, ssl_kea_rsa));
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS,
- PR_TRUE));
+
+ client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
SignedCertificateTimestampsExtractor timestamps_extractor(client_);
Connect();
@@ -96,9 +95,7 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsSuccess) {
EnsureTlsSetup();
EXPECT_TRUE(
server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData));
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS,
- PR_TRUE));
+ client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
SignedCertificateTimestampsExtractor timestamps_extractor(client_);
Connect();
@@ -120,9 +117,7 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveClient) {
TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveServer) {
EnsureTlsSetup();
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS,
- PR_TRUE));
+ client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
SignedCertificateTimestampsExtractor timestamps_extractor(client_);
Connect();
@@ -173,16 +168,14 @@ TEST_P(TlsConnectGeneric, OcspNotRequested) {
// Even if the client asks, the server has nothing unless it is configured.
TEST_P(TlsConnectGeneric, OcspNotProvided) {
EnsureTlsSetup();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
client_->SetAuthCertificateCallback(CheckNoOCSP);
Connect();
}
TEST_P(TlsConnectGenericPre13, OcspMangled) {
EnsureTlsSetup();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
EXPECT_TRUE(
server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData));
@@ -197,8 +190,7 @@ TEST_P(TlsConnectGenericPre13, OcspMangled) {
TEST_P(TlsConnectGeneric, OcspSuccess) {
EnsureTlsSetup();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
auto capture_ocsp =
std::make_shared<TlsExtensionCapture>(ssl_cert_status_xtn);
server_->SetPacketFilter(capture_ocsp);
@@ -225,8 +217,7 @@ TEST_P(TlsConnectGeneric, OcspSuccess) {
TEST_P(TlsConnectGeneric, OcspHugeSuccess) {
EnsureTlsSetup();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
uint8_t hugeOcspValue[16385];
memset(hugeOcspValue, 0xa1, sizeof(hugeOcspValue));
diff --git a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
index 85c30b2bf..810656868 100644
--- a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
@@ -31,11 +31,11 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase {
public:
TlsCipherSuiteTestBase(SSLProtocolVariant variant, uint16_t version,
uint16_t cipher_suite, SSLNamedGroup group,
- SSLSignatureScheme signature_scheme)
+ SSLSignatureScheme sig_scheme)
: TlsConnectTestBase(variant, version),
cipher_suite_(cipher_suite),
group_(group),
- signature_scheme_(signature_scheme),
+ sig_scheme_(sig_scheme),
csinfo_({0}) {
SECStatus rv =
SSL_GetCipherSuiteInfo(cipher_suite_, &csinfo_, sizeof(csinfo_));
@@ -60,14 +60,14 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase {
server_->ConfigNamedGroups(groups);
kea_type_ = SSLInt_GetKEAType(group_);
- client_->SetSignatureSchemes(&signature_scheme_, 1);
- server_->SetSignatureSchemes(&signature_scheme_, 1);
+ client_->SetSignatureSchemes(&sig_scheme_, 1);
+ server_->SetSignatureSchemes(&sig_scheme_, 1);
}
}
virtual void SetupCertificate() {
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
- switch (signature_scheme_) {
+ switch (sig_scheme_) {
case ssl_sig_rsa_pkcs1_sha256:
case ssl_sig_rsa_pkcs1_sha384:
case ssl_sig_rsa_pkcs1_sha512:
@@ -93,8 +93,7 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase {
auth_type_ = ssl_auth_ecdsa;
break;
default:
- ASSERT_TRUE(false) << "Unsupported signature scheme: "
- << signature_scheme_;
+ ADD_FAILURE() << "Unsupported signature scheme: " << sig_scheme_;
break;
}
} else {
@@ -187,7 +186,7 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase {
SSLAuthType auth_type_;
SSLKEAType kea_type_;
SSLNamedGroup group_;
- SSLSignatureScheme signature_scheme_;
+ SSLSignatureScheme sig_scheme_;
SSLCipherSuiteInfo csinfo_;
};
@@ -236,27 +235,29 @@ TEST_P(TlsCipherSuiteTest, ResumeCipherSuite) {
ConnectAndCheckCipherSuite();
}
-// This only works for stream ciphers because we modify the sequence number -
-// which is included explicitly in the DTLS record header - and that trips a
-// different error code. Note that the message that the client sends would not
-// decrypt (the nonce/IV wouldn't match), but the record limit is hit before
-// attempting to decrypt a record.
TEST_P(TlsCipherSuiteTest, ReadLimit) {
SetupCertificate();
EnableSingleCipher();
ConnectAndCheckCipherSuite();
- EXPECT_EQ(SECSuccess,
- SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last_safe_write()));
- EXPECT_EQ(SECSuccess,
- SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last_safe_write()));
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ uint64_t last = last_safe_write();
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last));
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last));
- client_->SendData(10, 10);
- server_->ReadBytes(); // This should be OK.
+ client_->SendData(10, 10);
+ server_->ReadBytes(); // This should be OK.
+ } else {
+ // In TLS 1.3, reading or writing triggers a KeyUpdate. That would mean
+ // that the sequence numbers would reset and we wouldn't hit the limit. So
+ // we move the sequence number to one less than the limit directly and don't
+ // test sending and receiving just before the limit.
+ uint64_t last = record_limit() - 1;
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last));
+ }
- // The payload needs to be big enough to pass for encrypted. In the extreme
- // case (TLS 1.3), this means 1 for payload, 1 for content type and 16 for
- // authentication tag.
- static const uint8_t payload[18] = {6};
+ // The payload needs to be big enough to pass for encrypted. The code checks
+ // the limit before it tries to decrypt.
+ static const uint8_t payload[32] = {6};
DataBuffer record;
uint64_t epoch;
if (variant_ == ssl_variant_datagram) {
@@ -271,13 +272,17 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) {
TlsAgentTestBase::MakeRecord(variant_, kTlsApplicationDataType, version_,
payload, sizeof(payload), &record,
(epoch << 48) | record_limit());
- server_->adapter()->PacketReceived(record);
+ client_->SendDirect(record);
server_->ExpectReadWriteError();
server_->ReadBytes();
EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, server_->error_code());
}
TEST_P(TlsCipherSuiteTest, WriteLimit) {
+ // This asserts in TLS 1.3 because we expect an automatic update.
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ return;
+ }
SetupCertificate();
EnableSingleCipher();
ConnectAndCheckCipherSuite();
diff --git a/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc
new file mode 100644
index 000000000..dad944a1f
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc
@@ -0,0 +1,503 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "ssl3prot.h"
+#include "sslerr.h"
+#include "sslproto.h"
+#include "sslexp.h"
+
+#include <memory>
+
+#include "tls_connect.h"
+
+namespace nss_test {
+
+static void IncrementCounterArg(void *arg) {
+ if (arg) {
+ auto *called = reinterpret_cast<size_t *>(arg);
+ ++*called;
+ }
+}
+
+PRBool NoopExtensionWriter(PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) {
+ IncrementCounterArg(arg);
+ return PR_FALSE;
+}
+
+PRBool EmptyExtensionWriter(PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) {
+ IncrementCounterArg(arg);
+ return PR_TRUE;
+}
+
+SECStatus NoopExtensionHandler(PRFileDesc *fd, SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert, void *arg) {
+ return SECSuccess;
+}
+
+// All of the (current) set of supported extensions, plus a few extra.
+static const uint16_t kManyExtensions[] = {
+ ssl_server_name_xtn,
+ ssl_cert_status_xtn,
+ ssl_supported_groups_xtn,
+ ssl_ec_point_formats_xtn,
+ ssl_signature_algorithms_xtn,
+ ssl_signature_algorithms_cert_xtn,
+ ssl_use_srtp_xtn,
+ ssl_app_layer_protocol_xtn,
+ ssl_signed_cert_timestamp_xtn,
+ ssl_padding_xtn,
+ ssl_extended_master_secret_xtn,
+ ssl_session_ticket_xtn,
+ ssl_tls13_key_share_xtn,
+ ssl_tls13_pre_shared_key_xtn,
+ ssl_tls13_early_data_xtn,
+ ssl_tls13_supported_versions_xtn,
+ ssl_tls13_cookie_xtn,
+ ssl_tls13_psk_key_exchange_modes_xtn,
+ ssl_tls13_ticket_early_data_info_xtn,
+ ssl_tls13_certificate_authorities_xtn,
+ ssl_next_proto_nego_xtn,
+ ssl_renegotiation_info_xtn,
+ ssl_tls13_short_header_xtn,
+ 1,
+ 0xffff};
+// The list here includes all extensions we expect to use (SSL_MAX_EXTENSIONS),
+// plus the deprecated values (see sslt.h), and two extra dummy values.
+PR_STATIC_ASSERT((SSL_MAX_EXTENSIONS + 5) == PR_ARRAY_SIZE(kManyExtensions));
+
+void InstallManyWriters(std::shared_ptr<TlsAgent> agent,
+ SSLExtensionWriter writer, size_t *installed = nullptr,
+ size_t *called = nullptr) {
+ for (size_t i = 0; i < PR_ARRAY_SIZE(kManyExtensions); ++i) {
+ SSLExtensionSupport support = ssl_ext_none;
+ SECStatus rv = SSL_GetExtensionSupport(kManyExtensions[i], &support);
+ ASSERT_EQ(SECSuccess, rv) << "SSL_GetExtensionSupport cannot fail";
+
+ rv = SSL_InstallExtensionHooks(agent->ssl_fd(), kManyExtensions[i], writer,
+ called, NoopExtensionHandler, nullptr);
+ if (support == ssl_ext_native_only) {
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ } else {
+ if (installed) {
+ ++*installed;
+ }
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ }
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopClient) {
+ EnsureTlsSetup();
+ size_t installed = 0;
+ size_t called = 0;
+ InstallManyWriters(client_, NoopExtensionWriter, &installed, &called);
+ EXPECT_LT(0U, installed);
+ Connect();
+ EXPECT_EQ(installed, called);
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopServer) {
+ EnsureTlsSetup();
+ size_t installed = 0;
+ size_t called = 0;
+ InstallManyWriters(server_, NoopExtensionWriter, &installed, &called);
+ EXPECT_LT(0U, installed);
+ Connect();
+ // Extension writers are all called for each of ServerHello,
+ // EncryptedExtensions, and Certificate.
+ EXPECT_EQ(installed * 3, called);
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterClient) {
+ EnsureTlsSetup();
+ InstallManyWriters(client_, EmptyExtensionWriter);
+ InstallManyWriters(server_, EmptyExtensionWriter);
+ Connect();
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterServer) {
+ EnsureTlsSetup();
+ InstallManyWriters(server_, EmptyExtensionWriter);
+ // Sending extensions that the client doesn't expect leads to extensions
+ // appearing even if the client didn't send one, or in the wrong messages.
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+}
+
+// Install an writer to disable sending of a natively-supported extension.
+TEST_F(TlsConnectStreamTls13, CustomExtensionWriterDisable) {
+ EnsureTlsSetup();
+
+ // This option enables sending the extension via the native support.
+ SECStatus rv = SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // This installs an override that doesn't do anything. You have to specify
+ // something; passing all nullptr values removes an existing handler.
+ rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NoopExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+ auto capture =
+ std::make_shared<TlsExtensionCapture>(ssl_signed_cert_timestamp_xtn);
+ client_->SetPacketFilter(capture);
+
+ Connect();
+ // So nothing will be sent.
+ EXPECT_FALSE(capture->captured());
+}
+
+// An extension that is unlikely to be parsed as valid.
+static uint8_t kNonsenseExtension[] = {91, 82, 73, 64, 55, 46, 37, 28, 19};
+
+static PRBool NonsenseExtensionWriter(PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) {
+ TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg);
+ EXPECT_NE(nullptr, agent);
+ EXPECT_NE(nullptr, data);
+ EXPECT_NE(nullptr, len);
+ EXPECT_EQ(0U, *len);
+ EXPECT_LT(0U, maxLen);
+ EXPECT_EQ(agent->ssl_fd(), fd);
+
+ if (message != ssl_hs_client_hello && message != ssl_hs_server_hello &&
+ message != ssl_hs_encrypted_extensions) {
+ return PR_FALSE;
+ }
+
+ *len = static_cast<unsigned int>(sizeof(kNonsenseExtension));
+ EXPECT_GE(maxLen, *len);
+ if (maxLen < *len) {
+ return PR_FALSE;
+ }
+ PORT_Memcpy(data, kNonsenseExtension, *len);
+ return PR_TRUE;
+}
+
+// Override the extension handler for an natively-supported and produce
+// nonsense, which results in a handshake failure.
+TEST_F(TlsConnectStreamTls13, CustomExtensionOverride) {
+ EnsureTlsSetup();
+
+ // This option enables sending the extension via the native support.
+ SECStatus rv = SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // This installs an override that sends nonsense.
+ rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NonsenseExtensionWriter,
+ client_.get(), NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture it to see what we got.
+ auto capture =
+ std::make_shared<TlsExtensionCapture>(ssl_signed_cert_timestamp_xtn);
+ client_->SetPacketFilter(capture);
+
+ ConnectExpectAlert(server_, kTlsAlertDecodeError);
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+static SECStatus NonsenseExtensionHandler(PRFileDesc *fd,
+ SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert,
+ void *arg) {
+ TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg);
+ EXPECT_EQ(agent->ssl_fd(), fd);
+ if (agent->role() == TlsAgent::SERVER) {
+ EXPECT_EQ(ssl_hs_client_hello, message);
+ } else {
+ EXPECT_TRUE(message == ssl_hs_server_hello ||
+ message == ssl_hs_encrypted_extensions);
+ }
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ DataBuffer(data, len));
+ EXPECT_NE(nullptr, alert);
+ return SECSuccess;
+}
+
+// Send nonsense in an extension from client to server.
+TEST_F(TlsConnectStreamTls13, CustomExtensionClientToServer) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nonsense.
+ const uint16_t extension_code = 0xffe5;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), extension_code, NonsenseExtensionWriter, client_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture it to see what we got.
+ auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
+ client_->SetPacketFilter(capture);
+
+ // Handle it so that the handshake completes.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NoopExtensionWriter, nullptr,
+ NonsenseExtensionHandler, server_.get());
+ EXPECT_EQ(SECSuccess, rv);
+
+ Connect();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+static PRBool NonsenseExtensionWriterSH(PRFileDesc *fd,
+ SSLHandshakeType message, PRUint8 *data,
+ unsigned int *len, unsigned int maxLen,
+ void *arg) {
+ if (message == ssl_hs_server_hello) {
+ return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg);
+ }
+ return PR_FALSE;
+}
+
+// Send nonsense in an extension from server to client, in ServerHello.
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientSH) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff5e;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr,
+ NonsenseExtensionHandler, client_.get());
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NonsenseExtensionWriterSH, server_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture the extension from the ServerHello only and check it.
+ auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
+ capture->SetHandshakeTypes({kTlsHandshakeServerHello});
+ server_->SetPacketFilter(capture);
+
+ Connect();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+static PRBool NonsenseExtensionWriterEE(PRFileDesc *fd,
+ SSLHandshakeType message, PRUint8 *data,
+ unsigned int *len, unsigned int maxLen,
+ void *arg) {
+ if (message == ssl_hs_encrypted_extensions) {
+ return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg);
+ }
+ return PR_FALSE;
+}
+
+// Send nonsense in an extension from server to client, in EncryptedExtensions.
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientEE) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff5e;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr,
+ NonsenseExtensionHandler, client_.get());
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NonsenseExtensionWriterEE, server_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture the extension from the EncryptedExtensions only and check it.
+ auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
+ capture->SetHandshakeTypes({kTlsHandshakeEncryptedExtensions});
+ server_->SetTlsRecordFilter(capture);
+
+ Connect();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionUnsolicitedServer) {
+ EnsureTlsSetup();
+
+ const uint16_t extension_code = 0xff5e;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ server_->ssl_fd(), extension_code, NonsenseExtensionWriter, server_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture it to see what we got.
+ auto capture = std::make_shared<TlsExtensionCapture>(extension_code);
+ server_->SetPacketFilter(capture);
+
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+SECStatus RejectExtensionHandler(PRFileDesc *fd, SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert, void *arg) {
+ return SECFailure;
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerReject) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nonsense.
+ const uint16_t extension_code = 0xffe7;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Reject the extension for no good reason.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NoopExtensionWriter, nullptr,
+ RejectExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+}
+
+// Send nonsense in an extension from client to server.
+TEST_F(TlsConnectStreamTls13, CustomExtensionClientReject) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff58;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ RejectExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ client_->ExpectSendAlert(kTlsAlertHandshakeFailure);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+}
+
+static const uint8_t kCustomAlert = 0xf6;
+
+SECStatus AlertExtensionHandler(PRFileDesc *fd, SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert, void *arg) {
+ *alert = kCustomAlert;
+ return SECFailure;
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerRejectAlert) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nonsense.
+ const uint16_t extension_code = 0xffea;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Reject the extension for no good reason.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NoopExtensionWriter, nullptr,
+ AlertExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ ConnectExpectAlert(server_, kCustomAlert);
+}
+
+// Send nonsense in an extension from client to server.
+TEST_F(TlsConnectStreamTls13, CustomExtensionClientRejectAlert) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff5a;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ AlertExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ client_->ExpectSendAlert(kCustomAlert);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+}
+
+// Configure a custom extension hook badly.
+TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyWriter) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ SECStatus rv =
+ SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6c, EmptyExtensionWriter,
+ nullptr, nullptr, nullptr);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyHandler) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ SECStatus rv =
+ SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6d, nullptr, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionOverrunBuffer) {
+ EnsureTlsSetup();
+ // This doesn't actually overrun the buffer, but it says that it does.
+ auto overrun_writer = [](PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) -> PRBool {
+ *len = maxLen + 1;
+ return PR_TRUE;
+ };
+ SECStatus rv =
+ SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff71, overrun_writer,
+ nullptr, NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+ client_->StartConnect();
+ client_->Handshake();
+ client_->CheckErrorCode(SEC_ERROR_APPLICATION_CALLBACK_ERROR);
+}
+
+} // namespace "nss_test"
diff --git a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
index 69fd00331..d1668b823 100644
--- a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
@@ -29,8 +29,7 @@ TEST_F(TlsConnectTest, DamageSecretHandleClientFinished) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_3);
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
client_->Handshake();
server_->Handshake();
std::cerr << "Damaging HS secret" << std::endl;
@@ -51,16 +50,12 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_3);
- client_->ExpectSendAlert(kTlsAlertDecryptError);
- // The server can't read the client's alert, so it also sends an alert.
- server_->ExpectSendAlert(kTlsAlertBadRecordMac);
server_->SetPacketFilter(std::make_shared<AfterRecordN>(
server_, client_,
0, // ServerHello.
[this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); }));
- ConnectExpectFail();
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
- server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
}
TEST_P(TlsConnectGenericPre13, DamageServerSignature) {
@@ -79,16 +74,7 @@ TEST_P(TlsConnectTls13, DamageServerSignature) {
auto filter =
std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
server_->SetTlsRecordFilter(filter);
- filter->EnableDecryption();
- client_->ExpectSendAlert(kTlsAlertDecryptError);
- // The server can't read the client's alert, so it also sends an alert.
- if (variant_ == ssl_variant_stream) {
- server_->ExpectSendAlert(kTlsAlertBadRecordMac);
- ConnectExpectFail();
- server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
- } else {
- ConnectExpectFailOneSide(TlsAgent::CLIENT);
- }
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
}
@@ -100,11 +86,9 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) {
std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
client_->SetTlsRecordFilter(filter);
server_->ExpectSendAlert(kTlsAlertDecryptError);
- filter->EnableDecryption();
// Do these handshakes by hand to avoid race condition on
// the client processing the server's alert.
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake();
server_->Handshake();
client_->Handshake();
diff --git a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
index 97943303a..4aa3bb639 100644
--- a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
@@ -59,8 +59,7 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) {
TEST_P(TlsConnectGeneric, ConnectFfdheClient) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
auto groups_capture =
std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
auto shares_capture =
@@ -90,8 +89,7 @@ TEST_P(TlsConnectGeneric, ConnectFfdheClient) {
// because the client automatically sends the supported groups extension.
TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ server_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
Connect();
@@ -105,14 +103,11 @@ TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) {
class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter {
public:
- TlsDheServerKeyExchangeDamager() {}
+ TlsDheServerKeyExchangeDamager()
+ : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
// Damage the first octet of dh_p. Anything other than the known prime will
// be rejected as "weak" when we have SSL_REQUIRE_DH_NAMED_GROUPS enabled.
*output = input;
@@ -126,8 +121,7 @@ class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter {
// the signature until everything else has been checked.
TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
server_->SetPacketFilter(std::make_shared<TlsDheServerKeyExchangeDamager>());
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
@@ -147,7 +141,8 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter {
kYZeroPad
};
- TlsDheSkeChangeY(ChangeYTo change) : change_Y_(change) {}
+ TlsDheSkeChangeY(uint8_t handshake_type, ChangeYTo change)
+ : TlsHandshakeFilter({handshake_type}), change_Y_(change) {}
protected:
void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset,
@@ -213,7 +208,9 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter {
class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
public:
TlsDheSkeChangeYServer(ChangeYTo change, bool modify)
- : TlsDheSkeChangeY(change), modify_(modify), p_() {}
+ : TlsDheSkeChangeY(kTlsHandshakeServerKeyExchange, change),
+ modify_(modify),
+ p_() {}
const DataBuffer& prime() const { return p_; }
@@ -221,10 +218,6 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) override {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
size_t offset = 2;
// Read dh_p
uint32_t dh_len = 0;
@@ -254,16 +247,13 @@ class TlsDheSkeChangeYClient : public TlsDheSkeChangeY {
TlsDheSkeChangeYClient(
ChangeYTo change,
std::shared_ptr<const TlsDheSkeChangeYServer> server_filter)
- : TlsDheSkeChangeY(change), server_filter_(server_filter) {}
+ : TlsDheSkeChangeY(kTlsHandshakeClientKeyExchange, change),
+ server_filter_(server_filter) {}
protected:
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) override {
- if (header.handshake_type() != kTlsHandshakeClientKeyExchange) {
- return KEEP;
- }
-
ChangeY(input, output, 0, server_filter_->prime());
return CHANGE;
}
@@ -289,8 +279,7 @@ class TlsDamageDHYTest
TEST_P(TlsDamageDHYTest, DamageServerY) {
EnableOnlyDheCiphers();
if (std::get<3>(GetParam())) {
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
}
TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
server_->SetPacketFilter(
@@ -320,8 +309,7 @@ TEST_P(TlsDamageDHYTest, DamageServerY) {
TEST_P(TlsDamageDHYTest, DamageClientY) {
EnableOnlyDheCiphers();
if (std::get<3>(GetParam())) {
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
}
// The filter on the server is required to capture the prime.
auto server_filter =
@@ -370,13 +358,10 @@ INSTANTIATE_TEST_CASE_P(
class TlsDheSkeMakePEven : public TlsHandshakeFilter {
public:
+ TlsDheSkeMakePEven() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
// Find the end of dh_p
uint32_t dh_len = 0;
EXPECT_TRUE(input.Read(0, 2, &dh_len));
@@ -404,13 +389,10 @@ TEST_P(TlsConnectGenericPre13, MakeDhePEven) {
class TlsDheSkeZeroPadP : public TlsHandshakeFilter {
public:
+ TlsDheSkeZeroPadP() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
*output = input;
uint32_t dh_len = 0;
EXPECT_TRUE(input.Read(0, 2, &dh_len));
@@ -445,8 +427,7 @@ TEST_P(TlsConnectGenericPre13, PadDheP) {
// Note: This test case can take ages to generate the weak DH key.
TEST_P(TlsConnectGenericPre13, WeakDHGroup) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
EXPECT_EQ(SECSuccess,
SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE));
@@ -496,8 +477,7 @@ TEST_P(TlsConnectTls13, NamedGroupMismatch13) {
// custom group in contrast to the previous test.
TEST_P(TlsConnectGenericPre13, RequireNamedGroupsMismatchPre13) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072};
static const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1,
ssl_grp_ffdhe_2048};
@@ -525,8 +505,7 @@ TEST_P(TlsConnectGenericPre13, PreferredFfdhe) {
TEST_P(TlsConnectGenericPre13, MismatchDHE) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
static const SSLDHEGroupType serverGroups[] = {ssl_ff_dhe_3072_group};
EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), serverGroups,
PR_ARRAY_SIZE(serverGroups)));
@@ -544,7 +523,8 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
Connect();
SendReceive(); // Need to read so that we absorb the session ticket.
- CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign);
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
Reset();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
@@ -557,7 +537,8 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) {
server_->SetPacketFilter(serverCapture);
ExpectResumption(RESUME_TICKET);
Connect();
- CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none);
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
ASSERT_LT(0UL, clientCapture->extension().len());
ASSERT_LT(0UL, serverCapture->extension().len());
}
@@ -565,16 +546,15 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) {
class TlsDheSkeChangeSignature : public TlsHandshakeFilter {
public:
TlsDheSkeChangeSignature(uint16_t version, const uint8_t* data, size_t len)
- : version_(version), data_(data), len_(len) {}
+ : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}),
+ version_(version),
+ data_(data),
+ len_(len) {}
protected:
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
TlsParser parser(input);
EXPECT_TRUE(parser.SkipVariable(2)); // dh_p
EXPECT_TRUE(parser.SkipVariable(2)); // dh_g
diff --git a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
index 3cc3b0e62..c059e9938 100644
--- a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
@@ -6,6 +6,7 @@
#include "secerr.h"
#include "ssl.h"
+#include "sslexp.h"
extern "C" {
// This is not something that should make you happy.
@@ -20,13 +21,13 @@ extern "C" {
namespace nss_test {
-TEST_P(TlsConnectDatagram, DropClientFirstFlightOnce) {
+TEST_P(TlsConnectDatagramPre13, DropClientFirstFlightOnce) {
client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
}
-TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) {
+TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) {
server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
@@ -35,36 +36,760 @@ TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) {
// This drops the first transmission from both the client and server of all
// flights that they send. Note: In DTLS 1.3, the shorter handshake means that
// this will also drop some application data, so we can't call SendReceive().
-TEST_P(TlsConnectDatagram, DropAllFirstTransmissions) {
+TEST_P(TlsConnectDatagramPre13, DropAllFirstTransmissions) {
client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x15));
server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x5));
Connect();
}
// This drops the server's first flight three times.
-TEST_P(TlsConnectDatagram, DropServerFirstFlightThrice) {
+TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightThrice) {
server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x7));
Connect();
}
// This drops the client's second flight once
-TEST_P(TlsConnectDatagram, DropClientSecondFlightOnce) {
+TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightOnce) {
client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2));
Connect();
}
// This drops the client's second flight three times.
-TEST_P(TlsConnectDatagram, DropClientSecondFlightThrice) {
+TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightThrice) {
client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}
// This drops the server's second flight three times.
-TEST_P(TlsConnectDatagram, DropServerSecondFlightThrice) {
+TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) {
server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}
+class TlsDropDatagram13 : public TlsConnectDatagram13 {
+ public:
+ TlsDropDatagram13()
+ : client_filters_(),
+ server_filters_(),
+ expected_client_acks_(0),
+ expected_server_acks_(1) {}
+
+ void SetUp() {
+ TlsConnectDatagram13::SetUp();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ SetFilters();
+ }
+
+ void SetFilters() {
+ EnsureTlsSetup();
+ client_->SetPacketFilter(client_filters_.chain_);
+ client_filters_.ack_->SetAgent(client_.get());
+ client_filters_.ack_->EnableDecryption();
+ server_->SetPacketFilter(server_filters_.chain_);
+ server_filters_.ack_->SetAgent(server_.get());
+ server_filters_.ack_->EnableDecryption();
+ }
+
+ void HandshakeAndAck(const std::shared_ptr<TlsAgent>& agent) {
+ agent->Handshake(); // Read flight.
+ ShiftDtlsTimers();
+ agent->Handshake(); // Generate ACK.
+ }
+
+ void ShrinkPostServerHelloMtu() {
+ // Abuse the custom extension mechanism to modify the MTU so that the
+ // Certificate message is split into two pieces.
+ ASSERT_EQ(
+ SECSuccess,
+ SSL_InstallExtensionHooks(
+ server_->ssl_fd(), 1,
+ [](PRFileDesc* fd, SSLHandshakeType message, PRUint8* data,
+ unsigned int* len, unsigned int maxLen, void* arg) -> PRBool {
+ SSLInt_SetMTU(fd, 500); // Splits the certificate.
+ return PR_FALSE;
+ },
+ nullptr,
+ [](PRFileDesc* fd, SSLHandshakeType message, const PRUint8* data,
+ unsigned int len, SSLAlertDescription* alert,
+ void* arg) -> SECStatus { return SECSuccess; },
+ nullptr));
+ }
+
+ protected:
+ class DropAckChain {
+ public:
+ DropAckChain()
+ : records_(std::make_shared<TlsRecordRecorder>()),
+ ack_(std::make_shared<TlsRecordRecorder>(content_ack)),
+ drop_(std::make_shared<SelectiveRecordDropFilter>(0, false)),
+ chain_(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({records_, ack_, drop_}))) {}
+
+ const TlsRecord& record(size_t i) const { return records_->record(i); }
+
+ std::shared_ptr<TlsRecordRecorder> records_;
+ std::shared_ptr<TlsRecordRecorder> ack_;
+ std::shared_ptr<SelectiveRecordDropFilter> drop_;
+ std::shared_ptr<PacketFilter> chain_;
+ };
+
+ void CheckAcks(const DropAckChain& chain, size_t index,
+ std::vector<uint64_t> acks) {
+ const DataBuffer& buf = chain.ack_->record(index).buffer;
+ size_t offset = 0;
+
+ EXPECT_EQ(acks.size() * 8, buf.len());
+ if ((acks.size() * 8) != buf.len()) {
+ while (offset < buf.len()) {
+ uint64_t ack;
+ ASSERT_TRUE(buf.Read(offset, 8, &ack));
+ offset += 8;
+ std::cerr << "Ack=0x" << std::hex << ack << std::dec << std::endl;
+ }
+ return;
+ }
+
+ for (size_t i = 0; i < acks.size(); ++i) {
+ uint64_t a = acks[i];
+ uint64_t ack;
+ ASSERT_TRUE(buf.Read(offset, 8, &ack));
+ offset += 8;
+ if (a != ack) {
+ ADD_FAILURE() << "Wrong ack " << i << " expected=0x" << std::hex << a
+ << " got=0x" << ack << std::dec;
+ }
+ }
+ }
+
+ void CheckedHandshakeSendReceive() {
+ Handshake();
+ CheckPostHandshake();
+ }
+
+ void CheckPostHandshake() {
+ CheckConnected();
+ SendReceive();
+ EXPECT_EQ(expected_client_acks_, client_filters_.ack_->count());
+ EXPECT_EQ(expected_server_acks_, server_filters_.ack_->count());
+ }
+
+ protected:
+ DropAckChain client_filters_;
+ DropAckChain server_filters_;
+ size_t expected_client_acks_;
+ size_t expected_server_acks_;
+};
+
+// All of these tests produce a minimum one ACK, from the server
+// to the client upon receiving the client Finished.
+// Dropping complete first and second flights does not produce
+// ACKs
+TEST_F(TlsDropDatagram13, DropClientFirstFlightOnce) {
+ client_filters_.drop_->Reset({0});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ CheckedHandshakeSendReceive();
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+TEST_F(TlsDropDatagram13, DropServerFirstFlightOnce) {
+ server_filters_.drop_->Reset(0xff);
+ StartConnect();
+ client_->Handshake();
+ // Send the first flight, all dropped.
+ server_->Handshake();
+ server_filters_.drop_->Disable();
+ CheckedHandshakeSendReceive();
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// Dropping the server's first record also does not produce
+// an ACK because the next record is ignored.
+// TODO(ekr@rtfm.com): We should generate an empty ACK.
+TEST_F(TlsDropDatagram13, DropServerFirstRecordOnce) {
+ server_filters_.drop_->Reset({0});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ Handshake();
+ CheckedHandshakeSendReceive();
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// Dropping the second packet of the server's flight should
+// produce an ACK.
+TEST_F(TlsDropDatagram13, DropServerSecondRecordOnce) {
+ server_filters_.drop_->Reset({1});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ HandshakeAndAck(client_);
+ expected_client_acks_ = 1;
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_, 0, {0});
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// Drop the server ACK and verify that the client retransmits
+// the ClientHello.
+TEST_F(TlsDropDatagram13, DropServerAckOnce) {
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ // At this point the server has sent it's first flight,
+ // so make it drop the ACK.
+ server_filters_.drop_->Reset({0});
+ client_->Handshake(); // Send the client Finished.
+ server_->Handshake(); // Receive the Finished and send the ACK.
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ // Wait for the DTLS timeout to make sure we retransmit the
+ // Finished.
+ ShiftDtlsTimers();
+ client_->Handshake(); // Retransmit the Finished.
+ server_->Handshake(); // Read the Finished and send an ACK.
+ uint8_t buf[1];
+ PRInt32 rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf));
+ expected_server_acks_ = 2;
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+ CheckPostHandshake();
+ // There should be two copies of the finished ACK
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// Drop the client certificate verify.
+TEST_F(TlsDropDatagram13, DropClientCertVerify) {
+ StartConnect();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->Handshake();
+ server_->Handshake();
+ // Have the client drop Cert Verify
+ client_filters_.drop_->Reset({1});
+ expected_server_acks_ = 2;
+ CheckedHandshakeSendReceive();
+ // Ack of the Cert.
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ // Ack of the whole client handshake.
+ CheckAcks(
+ server_filters_, 1,
+ {0x0002000000000000ULL, // CH (we drop everything after this on client)
+ 0x0002000000000003ULL, // CT (2)
+ 0x0002000000000004ULL} // FIN (2)
+ );
+}
+
+// Shrink the MTU down so that certs get split and drop the first piece.
+TEST_F(TlsDropDatagram13, DropFirstHalfOfServerCertificate) {
+ server_filters_.drop_->Reset({2});
+ StartConnect();
+ ShrinkPostServerHelloMtu();
+ client_->Handshake();
+ server_->Handshake();
+ // Check that things got split.
+ EXPECT_EQ(6UL,
+ server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN
+ size_t ct1_size = server_filters_.record(2).buffer.len();
+ server_filters_.records_->Clear();
+ expected_client_acks_ = 1;
+ HandshakeAndAck(client_);
+ server_->Handshake(); // Retransmit
+ EXPECT_EQ(3UL, server_filters_.records_->count()); // CT2, CV, FIN
+ // Check that the first record is CT1 (which is identical to the same
+ // as the previous CT1).
+ EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len());
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_, 0,
+ {0, // SH
+ 0x0002000000000000ULL, // EE
+ 0x0002000000000002ULL} // CT2
+ );
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// Shrink the MTU down so that certs get split and drop the second piece.
+TEST_F(TlsDropDatagram13, DropSecondHalfOfServerCertificate) {
+ server_filters_.drop_->Reset({3});
+ StartConnect();
+ ShrinkPostServerHelloMtu();
+ client_->Handshake();
+ server_->Handshake();
+ // Check that things got split.
+ EXPECT_EQ(6UL,
+ server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN
+ size_t ct1_size = server_filters_.record(3).buffer.len();
+ server_filters_.records_->Clear();
+ expected_client_acks_ = 1;
+ HandshakeAndAck(client_);
+ server_->Handshake(); // Retransmit
+ EXPECT_EQ(3UL, server_filters_.records_->count()); // CT1, CV, FIN
+ // Check that the first record is CT1
+ EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len());
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_, 0,
+ {
+ 0, // SH
+ 0x0002000000000000ULL, // EE
+ 0x0002000000000001ULL, // CT1
+ });
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// In this test, the Certificate message is sent four times, we drop all or part
+// of the first three attempts:
+// 1. Without fragmentation so that we can see how big it is - we drop that.
+// 2. In two pieces - we drop half AND the resulting ACK.
+// 3. In three pieces - we drop the middle piece.
+//
+// After that we let all the ACKs through and allow the handshake to complete
+// without further interference.
+//
+// This allows us to test that ranges of handshake messages are sent correctly
+// even when there are overlapping acknowledgments; that ACKs with duplicate or
+// overlapping message ranges are handled properly; and that extra
+// retransmissions are handled properly.
+class TlsFragmentationAndRecoveryTest : public TlsDropDatagram13 {
+ public:
+ TlsFragmentationAndRecoveryTest() : cert_len_(0) {}
+
+ protected:
+ void RunTest(size_t dropped_half) {
+ FirstFlightDropCertificate();
+
+ SecondAttemptDropHalf(dropped_half);
+ size_t dropped_half_size = server_record_len(dropped_half);
+ size_t second_flight_count = server_filters_.records_->count();
+
+ ThirdAttemptDropMiddle();
+ size_t repaired_third_size = server_record_len((dropped_half == 0) ? 0 : 2);
+ size_t third_flight_count = server_filters_.records_->count();
+
+ AckAndCompleteRetransmission();
+ size_t final_server_flight_count = server_filters_.records_->count();
+ EXPECT_LE(3U, final_server_flight_count); // CT(sixth), CV, Fin
+ CheckSizeOfSixth(dropped_half_size, repaired_third_size);
+
+ SendDelayedAck();
+ // Same number of messages as the last flight.
+ EXPECT_EQ(final_server_flight_count, server_filters_.records_->count());
+ // Double check that the Certificate size is still correct.
+ CheckSizeOfSixth(dropped_half_size, repaired_third_size);
+
+ CompleteHandshake(final_server_flight_count);
+
+ // This is the ACK for the first attempt to send a whole certificate.
+ std::vector<uint64_t> client_acks = {
+ 0, // SH
+ 0x0002000000000000ULL // EE
+ };
+ CheckAcks(client_filters_, 0, client_acks);
+ // And from the second attempt for the half was kept (we delayed this ACK).
+ client_acks.push_back(0x0002000000000000ULL + second_flight_count +
+ ~dropped_half % 2);
+ CheckAcks(client_filters_, 1, client_acks);
+ // And the third attempt where the first and last thirds got through.
+ client_acks.push_back(0x0002000000000000ULL + second_flight_count +
+ third_flight_count - 1);
+ client_acks.push_back(0x0002000000000000ULL + second_flight_count +
+ third_flight_count + 1);
+ CheckAcks(client_filters_, 2, client_acks);
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ }
+
+ private:
+ void FirstFlightDropCertificate() {
+ StartConnect();
+ client_->Handshake();
+
+ // Note: 1 << N is the Nth packet, starting from zero.
+ server_filters_.drop_->Reset(1 << 2); // Drop Cert0.
+ server_->Handshake();
+ EXPECT_EQ(5U, server_filters_.records_->count()); // SH, EE, CT, CV, Fin
+ cert_len_ = server_filters_.records_->record(2).buffer.len();
+
+ HandshakeAndAck(client_);
+ EXPECT_EQ(2U, client_filters_.records_->count());
+ }
+
+ // Lower the MTU so that the server has to split the certificate in two
+ // pieces. The server resends Certificate (in two), plus CV and Fin.
+ void SecondAttemptDropHalf(size_t dropped_half) {
+ ASSERT_LE(0U, dropped_half);
+ ASSERT_GT(2U, dropped_half);
+ server_filters_.records_->Clear();
+ server_filters_.drop_->Reset({dropped_half}); // Drop Cert1[half]
+ SplitServerMtu(2);
+ server_->Handshake();
+ EXPECT_LE(4U, server_filters_.records_->count()); // CT x2, CV, Fin
+
+ // Generate and capture the ACK from the client.
+ client_filters_.drop_->Reset({0});
+ HandshakeAndAck(client_);
+ EXPECT_EQ(3U, client_filters_.records_->count());
+ }
+
+ // Lower the MTU again so that the server sends Certificate cut into three
+ // pieces. Drop the middle piece.
+ void ThirdAttemptDropMiddle() {
+ server_filters_.records_->Clear();
+ server_filters_.drop_->Reset({1}); // Drop Cert2[1] (of 3)
+ SplitServerMtu(3);
+ // Because we dropped the client ACK, the server retransmits on a timer.
+ ShiftDtlsTimers();
+ server_->Handshake();
+ EXPECT_LE(5U, server_filters_.records_->count()); // CT x3, CV, Fin
+ }
+
+ void AckAndCompleteRetransmission() {
+ // Generate ACKs.
+ HandshakeAndAck(client_);
+ // The server should send the final sixth of the certificate: the client has
+ // acknowledged the first half and the last third. Also send CV and Fin.
+ server_filters_.records_->Clear();
+ server_->Handshake();
+ }
+
+ void CheckSizeOfSixth(size_t size_of_half, size_t size_of_third) {
+ // Work out if the final sixth is the right size. We get the records with
+ // overheads added, which obscures the length of the payload. We want to
+ // ensure that the server only sent the missing sixth of the Certificate.
+ //
+ // We captured |size_of_half + overhead| and |size_of_third + overhead| and
+ // want to calculate |size_of_third - size_of_third + overhead|. We can't
+ // calculate |overhead|, but it is is (currently) always a handshake message
+ // header, a content type, and an authentication tag:
+ static const size_t record_overhead = 12 + 1 + 16;
+ EXPECT_EQ(size_of_half - size_of_third + record_overhead,
+ server_filters_.records_->record(0).buffer.len());
+ }
+
+ void SendDelayedAck() {
+ // Send the ACK we held back. The reordered ACK doesn't add new
+ // information,
+ // but triggers an extra retransmission of the missing records again (even
+ // though the client has all that it needs).
+ client_->SendRecordDirect(client_filters_.records_->record(2));
+ server_filters_.records_->Clear();
+ server_->Handshake();
+ }
+
+ void CompleteHandshake(size_t extra_retransmissions) {
+ // All this messing around shouldn't cause a failure...
+ Handshake();
+ // ...but it leaves a mess. Add an extra few calls to Handshake() for the
+ // client so that it absorbs the extra retransmissions.
+ for (size_t i = 0; i < extra_retransmissions; ++i) {
+ client_->Handshake();
+ }
+ CheckConnected();
+ }
+
+ // Split the server MTU so that the Certificate is split into |count| pieces.
+ // The calculation doesn't need to be perfect as long as the Certificate
+ // message is split into the right number of pieces.
+ void SplitServerMtu(size_t count) {
+ // Set the MTU based on the formula:
+ // bare_size = cert_len_ - actual_overhead
+ // MTU = ceil(bare_size / count) + pessimistic_overhead
+ //
+ // actual_overhead is the amount of actual overhead on the record we
+ // captured, which is (note that our length doesn't include the header):
+ static const size_t actual_overhead = 12 + // handshake message header
+ 1 + // content type
+ 16; // authentication tag
+ size_t bare_size = cert_len_ - actual_overhead;
+
+ // pessimistic_overhead is the amount of expansion that NSS assumes will be
+ // added to each handshake record. Right now, that is DTLS_MIN_FRAGMENT:
+ static const size_t pessimistic_overhead =
+ 12 + // handshake message header
+ 1 + // content type
+ 13 + // record header length
+ 64; // maximum record expansion: IV, MAC and block cipher expansion
+
+ size_t mtu = (bare_size + count - 1) / count + pessimistic_overhead;
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "server: set MTU to " << mtu << std::endl;
+ }
+ EXPECT_EQ(SECSuccess, SSLInt_SetMTU(server_->ssl_fd(), mtu));
+ }
+
+ size_t server_record_len(size_t index) const {
+ return server_filters_.records_->record(index).buffer.len();
+ }
+
+ size_t cert_len_;
+};
+
+TEST_F(TlsFragmentationAndRecoveryTest, DropFirstHalf) { RunTest(0); }
+
+TEST_F(TlsFragmentationAndRecoveryTest, DropSecondHalf) { RunTest(1); }
+
+TEST_F(TlsDropDatagram13, NoDropsDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ server_filters_.drop_->Reset({1});
+ ZeroRttSendReceive(true, true);
+ HandshakeAndAck(client_);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+ CheckAcks(client_filters_, 0, {0});
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+class TlsReorderDatagram13 : public TlsDropDatagram13 {
+ public:
+ TlsReorderDatagram13() {}
+
+ // Send records from the records buffer in the given order.
+ void ReSend(TlsAgent::Role side, std::vector<size_t> indices) {
+ std::shared_ptr<TlsAgent> agent;
+ std::shared_ptr<TlsRecordRecorder> records;
+
+ if (side == TlsAgent::CLIENT) {
+ agent = client_;
+ records = client_filters_.records_;
+ } else {
+ agent = server_;
+ records = server_filters_.records_;
+ }
+
+ for (auto i : indices) {
+ agent->SendRecordDirect(records->record(i));
+ }
+ }
+};
+
+// Reorder the server records so that EE comes at the end
+// of the flight and will still produce an ACK.
+TEST_F(TlsDropDatagram13, ReorderServerEE) {
+ server_filters_.drop_->Reset({1});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ // We dropped EE, now reinject.
+ server_->SendRecordDirect(server_filters_.record(1));
+ expected_client_acks_ = 1;
+ HandshakeAndAck(client_);
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_, 0,
+ {
+ 0, // SH
+ 0x0002000000000000, // EE
+ });
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// The client sends an out of order non-handshake message
+// but with the handshake key.
+class TlsSendCipherSpecCapturer {
+ public:
+ TlsSendCipherSpecCapturer(std::shared_ptr<TlsAgent>& agent)
+ : send_cipher_specs_() {
+ SSLInt_SetCipherSpecChangeFunc(agent->ssl_fd(), CipherSpecChanged,
+ (void*)this);
+ }
+
+ std::shared_ptr<TlsCipherSpec> spec(size_t i) {
+ if (i >= send_cipher_specs_.size()) {
+ return nullptr;
+ }
+ return send_cipher_specs_[i];
+ }
+
+ private:
+ static void CipherSpecChanged(void* arg, PRBool sending,
+ ssl3CipherSpec* newSpec) {
+ if (!sending) {
+ return;
+ }
+
+ auto self = static_cast<TlsSendCipherSpecCapturer*>(arg);
+
+ auto spec = std::make_shared<TlsCipherSpec>();
+ bool ret = spec->Init(SSLInt_CipherSpecToEpoch(newSpec),
+ SSLInt_CipherSpecToAlgorithm(newSpec),
+ SSLInt_CipherSpecToKey(newSpec),
+ SSLInt_CipherSpecToIv(newSpec));
+ EXPECT_EQ(true, ret);
+ self->send_cipher_specs_.push_back(spec);
+ }
+
+ std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_;
+};
+
+TEST_F(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) {
+ StartConnect();
+ TlsSendCipherSpecCapturer capturer(client_);
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ server_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ // After the client sends Finished, inject an app data record
+ // with the handshake key. This should produce an alert.
+ uint8_t buf[] = {'a', 'b', 'c'};
+ auto spec = capturer.spec(0);
+ ASSERT_NE(nullptr, spec.get());
+ ASSERT_EQ(2, spec->epoch());
+ ASSERT_TRUE(client_->SendEncryptedRecord(
+ spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002,
+ kTlsApplicationDataType, DataBuffer(buf, sizeof(buf))));
+
+ // Now have the server consume the bogus message.
+ server_->ExpectSendAlert(illegal_parameter, kTlsAlertFatal);
+ server_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state());
+ EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, PORT_GetError());
+}
+
+TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) {
+ StartConnect();
+ TlsSendCipherSpecCapturer capturer(client_);
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ server_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ // Inject a new bogus handshake record, which the server responds
+ // to by just ACKing the original one (we ignore the contents).
+ uint8_t buf[] = {'a', 'b', 'c'};
+ auto spec = capturer.spec(0);
+ ASSERT_NE(nullptr, spec.get());
+ ASSERT_EQ(2, spec->epoch());
+ ASSERT_TRUE(client_->SendEncryptedRecord(
+ spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002,
+ kTlsHandshakeType, DataBuffer(buf, sizeof(buf))));
+ server_->Handshake();
+ EXPECT_EQ(2UL, server_filters_.ack_->count());
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ CheckAcks(server_filters_, 1, {0x0002000000000000ULL});
+}
+
+// Shrink the MTU down so that certs get split and then swap the first and
+// second pieces of the server certificate.
+TEST_F(TlsReorderDatagram13, ReorderServerCertificate) {
+ StartConnect();
+ ShrinkPostServerHelloMtu();
+ client_->Handshake();
+ // Drop the entire handshake flight so we can reorder.
+ server_filters_.drop_->Reset(0xff);
+ server_->Handshake();
+ // Check that things got split.
+ EXPECT_EQ(6UL,
+ server_filters_.records_->count()); // CH, EE, CT1, CT2, CV, FIN
+ // Now re-send things in a different order.
+ ReSend(TlsAgent::SERVER, std::vector<size_t>{0, 1, 3, 2, 4, 5});
+ // Clear.
+ server_filters_.drop_->Disable();
+ server_filters_.records_->Clear();
+ // Wait for client to send ACK.
+ ShiftDtlsTimers();
+ CheckedHandshakeSendReceive();
+ EXPECT_EQ(2UL, server_filters_.records_->count()); // ACK + Data
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+TEST_F(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ // Send the client's first flight of zero RTT data.
+ ZeroRttSendReceive(true, true);
+ // Now send another client application data record but
+ // capture it.
+ client_filters_.records_->Clear();
+ client_filters_.drop_->Reset(0xff);
+ const char* k0RttData = "123456";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+ PRInt32 rv =
+ PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write.
+ EXPECT_EQ(k0RttDataLen, rv);
+ EXPECT_EQ(1UL, client_filters_.records_->count()); // data
+ server_->Handshake();
+ client_->Handshake();
+ ExpectEarlyDataAccepted(true);
+ // The server still hasn't received anything at this point.
+ EXPECT_EQ(3UL, client_filters_.records_->count()); // data, EOED, FIN
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+ // Now re-send the client's messages: EOED, data, FIN
+ ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 0, 2}));
+ server_->Handshake();
+ CheckConnected();
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ uint8_t buf[8];
+ rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_EQ(-1, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
+TEST_F(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ // Send the client's first flight of zero RTT data.
+ ZeroRttSendReceive(true, true);
+ // Now send another client application data record but
+ // capture it.
+ client_filters_.records_->Clear();
+ client_filters_.drop_->Reset(0xff);
+ const char* k0RttData = "123456";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+ PRInt32 rv =
+ PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write.
+ EXPECT_EQ(k0RttDataLen, rv);
+ EXPECT_EQ(1UL, client_filters_.records_->count()); // data
+ server_->Handshake();
+ client_->Handshake();
+ ExpectEarlyDataAccepted(true);
+ // The server still hasn't received anything at this point.
+ EXPECT_EQ(3UL, client_filters_.records_->count()); // EOED, FIN, Data
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+ // Now re-send the client's messages: EOED, FIN, Data
+ ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 2, 0}));
+ server_->Handshake();
+ CheckConnected();
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ uint8_t buf[8];
+ rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_EQ(-1, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
static void GetCipherAndLimit(uint16_t version, uint16_t* cipher,
uint64_t* limit = nullptr) {
uint64_t l;
@@ -111,7 +836,6 @@ TEST_P(TlsConnectDatagram12Plus, MissAWindow) {
GetCipherAndLimit(version_, &cipher);
server_->EnableSingleCipher(cipher);
Connect();
-
EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 0));
SendReceive();
}
@@ -129,5 +853,7 @@ TEST_P(TlsConnectDatagram12Plus, MissAWindowAndOne) {
INSTANTIATE_TEST_CASE_P(Datagram12Plus, TlsConnectDatagram12Plus,
TlsConnectTestBase::kTlsV12Plus);
+INSTANTIATE_TEST_CASE_P(DatagramPre13, TlsConnectDatagramPre13,
+ TlsConnectTestBase::kTlsV11V12);
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
index 1e406b6c2..e0f8b1f55 100644
--- a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
@@ -193,7 +193,9 @@ TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) {
class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter {
public:
- TlsKeyExchangeGroupCapture() : group_(ssl_grp_none) {}
+ TlsKeyExchangeGroupCapture()
+ : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}),
+ group_(ssl_grp_none) {}
SSLNamedGroup group() const { return group_; }
@@ -201,10 +203,6 @@ class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter {
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
const DataBuffer &input,
DataBuffer *output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
uint32_t value = 0;
EXPECT_TRUE(input.Read(0, 1, &value));
EXPECT_EQ(3U, value) << "curve type has to be 3";
@@ -518,16 +516,12 @@ TEST_P(TlsKeyExchangeTest13, MultipleClientShares) {
// Replace the point in the client key exchange message with an empty one
class ECCClientKEXFilter : public TlsHandshakeFilter {
public:
- ECCClientKEXFilter() {}
+ ECCClientKEXFilter() : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}) {}
protected:
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
const DataBuffer &input,
DataBuffer *output) {
- if (header.handshake_type() != kTlsHandshakeClientKeyExchange) {
- return KEEP;
- }
-
// Replace the client key exchange message with an empty point
output->Allocate(1);
output->Write(0, 0U, 1); // set point length 0
@@ -538,20 +532,16 @@ class ECCClientKEXFilter : public TlsHandshakeFilter {
// Replace the point in the server key exchange message with an empty one
class ECCServerKEXFilter : public TlsHandshakeFilter {
public:
- ECCServerKEXFilter() {}
+ ECCServerKEXFilter() : TlsHandshakeFilter({kTlsHandshakeServerKeyExchange}) {}
protected:
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
const DataBuffer &input,
DataBuffer *output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
// Replace the server key exchange message with an empty point
output->Allocate(4);
output->Write(0, 3U, 1); // named curve
- uint32_t curve;
+ uint32_t curve = 0;
EXPECT_TRUE(input.Read(1, 2, &curve)); // get curve id
output->Write(1, curve, 2); // write curve id
output->Write(3, 0U, 1); // point length 0
diff --git a/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
index be407b42e..c42883eb7 100644
--- a/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
@@ -118,7 +118,6 @@ int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr,
TEST_P(TlsConnectTls13, EarlyExporter) {
SetupForZeroRtt();
- ExpectAlert(client_, kTlsAlertEndOfEarlyData);
client_->Set0RttEnabled(true);
server_->Set0RttEnabled(true);
ExpectResumption(RESUME_TICKET);
diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
index d15139419..4142ab07a 100644
--- a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -61,60 +61,14 @@ class TlsExtensionDamager : public TlsExtensionFilter {
size_t index_;
};
-class TlsExtensionInjector : public TlsHandshakeFilter {
- public:
- TlsExtensionInjector(uint16_t ext, DataBuffer& data)
- : extension_(ext), data_(data) {}
-
- virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
- const DataBuffer& input,
- DataBuffer* output) {
- TlsParser parser(input);
- if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
- return KEEP;
- }
- size_t offset = parser.consumed();
-
- *output = input;
-
- // Increase the size of the extensions.
- uint16_t ext_len;
- memcpy(&ext_len, output->data() + offset, sizeof(ext_len));
- ext_len = htons(ntohs(ext_len) + data_.len() + 4);
- memcpy(output->data() + offset, &ext_len, sizeof(ext_len));
-
- // Insert the extension type and length.
- DataBuffer type_length;
- type_length.Allocate(4);
- type_length.Write(0, extension_, 2);
- type_length.Write(2, data_.len(), 2);
- output->Splice(type_length, offset + 2);
-
- // Insert the payload.
- if (data_.len() > 0) {
- output->Splice(data_, offset + 6);
- }
-
- return CHANGE;
- }
-
- private:
- const uint16_t extension_;
- const DataBuffer data_;
-};
-
class TlsExtensionAppender : public TlsHandshakeFilter {
public:
TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data)
- : handshake_type_(handshake_type), extension_(ext), data_(data) {}
+ : TlsHandshakeFilter({handshake_type}), extension_(ext), data_(data) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() != handshake_type_) {
- return KEEP;
- }
-
TlsParser parser(input);
if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
return KEEP;
@@ -159,7 +113,6 @@ class TlsExtensionAppender : public TlsHandshakeFilter {
return true;
}
- const uint8_t handshake_type_;
const uint16_t extension_;
const DataBuffer data_;
};
@@ -200,8 +153,7 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
client_->ConfigNamedGroups(client_groups);
server_->ConfigNamedGroups(server_groups);
EnsureTlsSetup();
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake(); // Send ClientHello
server_->Handshake(); // Send HRR.
client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(type));
@@ -1009,7 +961,6 @@ class TlsBogusExtensionTest : public TlsConnectTestBase,
std::make_shared<TlsExtensionAppender>(message, extension, empty);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
server_->SetTlsRecordFilter(filter);
- filter->EnableDecryption();
} else {
server_->SetPacketFilter(filter);
}
@@ -1032,17 +983,20 @@ class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest {
class TlsBogusExtensionTest13 : public TlsBogusExtensionTest {
protected:
void ConnectAndFail(uint8_t message) override {
- if (message == kTlsHandshakeHelloRetryRequest) {
+ if (message != kTlsHandshakeServerHello) {
ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
return;
}
- client_->StartConnect();
- server_->StartConnect();
+ FailWithAlert(kTlsAlertUnsupportedExtension);
+ }
+
+ void FailWithAlert(uint8_t alert) {
+ StartConnect();
client_->Handshake(); // ClientHello
server_->Handshake(); // ServerHello
- client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ client_->ExpectSendAlert(alert);
client_->Handshake();
if (variant_ == ssl_variant_stream) {
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
@@ -1067,9 +1021,12 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) {
Run(kTlsHandshakeCertificate);
}
+// It's perfectly valid to set unknown extensions in CertificateRequest.
TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) {
server_->RequestClientAuth(false);
- Run(kTlsHandshakeCertificateRequest);
+ AddFilter(kTlsHandshakeCertificateRequest, 0xff);
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
}
TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) {
@@ -1079,10 +1036,6 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) {
Run(kTlsHandshakeHelloRetryRequest);
}
-TEST_P(TlsBogusExtensionTest13, AddVersionExtensionServerHello) {
- Run(kTlsHandshakeServerHello, ssl_tls13_supported_versions_xtn);
-}
-
TEST_P(TlsBogusExtensionTest13, AddVersionExtensionEncryptedExtensions) {
Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn);
}
@@ -1096,13 +1049,6 @@ TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificateRequest) {
Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn);
}
-TEST_P(TlsBogusExtensionTest13, AddVersionExtensionHelloRetryRequest) {
- static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
- server_->ConfigNamedGroups(groups);
-
- Run(kTlsHandshakeHelloRetryRequest, ssl_tls13_supported_versions_xtn);
-}
-
// NewSessionTicket allows unknown extensions AND it isn't protected by the
// Finished. So adding an unknown extension doesn't cause an error.
TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) {
diff --git a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc
index 44cacce46..64b824786 100644
--- a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc
@@ -51,10 +51,16 @@ class RecordFragmenter : public PacketFilter {
while (parser.remaining()) {
TlsHandshakeFilter::HandshakeHeader handshake_header;
DataBuffer handshake_body;
- if (!handshake_header.Parse(&parser, record_header, &handshake_body)) {
+ bool complete = false;
+ if (!handshake_header.Parse(&parser, record_header, DataBuffer(),
+ &handshake_body, &complete)) {
ADD_FAILURE() << "couldn't parse handshake header";
return false;
}
+ if (!complete) {
+ ADD_FAILURE() << "don't want to deal with fragmented messages";
+ return false;
+ }
DataBuffer record_fragment;
// We can't fragment handshake records that are too small.
@@ -82,7 +88,7 @@ class RecordFragmenter : public PacketFilter {
while (parser.remaining()) {
TlsRecordHeader header;
DataBuffer record;
- if (!header.Parse(&parser, &record)) {
+ if (!header.Parse(0, &parser, &record)) {
ADD_FAILURE() << "bad record header";
return false;
}
diff --git a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
index 1587b66de..ab4c0eab7 100644
--- a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
@@ -47,9 +47,9 @@ class TlsApplicationDataRecorder : public TlsRecordFilter {
// Ensure that ssl_Time() returns a constant value.
FUZZ_F(TlsFuzzTest, SSL_Time_Constant) {
- PRUint32 now = ssl_Time();
+ PRUint32 now = ssl_TimeSec();
PR_Sleep(PR_SecondsToInterval(2));
- EXPECT_EQ(ssl_Time(), now);
+ EXPECT_EQ(ssl_TimeSec(), now);
}
// Check that due to the deterministic PRNG we derive
@@ -215,58 +215,6 @@ FUZZ_P(TlsConnectGeneric, SessionTicketResumption) {
SendReceive();
}
-class TlsSessionTicketMacDamager : public TlsExtensionFilter {
- public:
- TlsSessionTicketMacDamager() {}
- virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
- const DataBuffer& input,
- DataBuffer* output) {
- if (extension_type != ssl_session_ticket_xtn &&
- extension_type != ssl_tls13_pre_shared_key_xtn) {
- return KEEP;
- }
-
- *output = input;
-
- // Handle everything before TLS 1.3.
- if (extension_type == ssl_session_ticket_xtn) {
- // Modify the last byte of the MAC.
- output->data()[output->len() - 1] ^= 0xff;
- }
-
- // Handle TLS 1.3.
- if (extension_type == ssl_tls13_pre_shared_key_xtn) {
- TlsParser parser(input);
-
- uint32_t ids_len;
- EXPECT_TRUE(parser.Read(&ids_len, 2) && ids_len > 0);
-
- uint32_t ticket_len;
- EXPECT_TRUE(parser.Read(&ticket_len, 2) && ticket_len > 0);
-
- // Modify the last byte of the MAC.
- output->data()[2 + 2 + ticket_len - 1] ^= 0xff;
- }
-
- return CHANGE;
- }
-};
-
-// Check that session ticket resumption works with a bad MAC.
-FUZZ_P(TlsConnectGeneric, SessionTicketResumptionBadMac) {
- ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- Connect();
- SendReceive();
-
- Reset();
- ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- ExpectResumption(RESUME_TICKET);
-
- client_->SetPacketFilter(std::make_shared<TlsSessionTicketMacDamager>());
- Connect();
- SendReceive();
-}
-
// Check that session tickets are not encrypted.
FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) {
ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
@@ -276,10 +224,13 @@ FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) {
server_->SetPacketFilter(i1);
Connect();
+ std::cerr << "ticket" << i1->buffer() << std::endl;
size_t offset = 4; /* lifetime */
if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
- offset += 1 + 1 + /* ke_modes */
- 1 + 1; /* auth_modes */
+ offset += 4; /* ticket_age_add */
+ uint32_t nonce_len = 0;
+ EXPECT_TRUE(i1->buffer().Read(offset, 1, &nonce_len));
+ offset += 1 + nonce_len;
}
offset += 2 + /* ticket length */
2; /* TLS_EX_SESS_TICKET_VERSION */
diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.cc b/security/nss/gtests/ssl_gtest/ssl_gtest.cc
index cd10076b8..2fff9d7cb 100644
--- a/security/nss/gtests/ssl_gtest/ssl_gtest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_gtest.cc
@@ -6,6 +6,7 @@
#include <cstdlib>
#include "test_io.h"
+#include "databuffer.h"
#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
@@ -28,6 +29,7 @@ int main(int argc, char** argv) {
++i;
} else if (!strcmp(argv[i], "-v")) {
g_ssl_gtest_verbose = true;
+ nss_test::DataBuffer::SetLogLimit(16384);
}
}
diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp
index 8cd7d1009..e2a8d830a 100644
--- a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp
+++ b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp
@@ -11,6 +11,7 @@
'target_name': 'ssl_gtest',
'type': 'executable',
'sources': [
+ 'bloomfilter_unittest.cc',
'libssl_internals.c',
'selfencrypt_unittest.cc',
'ssl_0rtt_unittest.cc',
@@ -18,6 +19,7 @@
'ssl_auth_unittest.cc',
'ssl_cert_ext_unittest.cc',
'ssl_ciphersuite_unittest.cc',
+ 'ssl_custext_unittest.cc',
'ssl_damage_unittest.cc',
'ssl_dhe_unittest.cc',
'ssl_drop_unittest.cc',
@@ -30,11 +32,16 @@
'ssl_gather_unittest.cc',
'ssl_gtest.cc',
'ssl_hrr_unittest.cc',
+ 'ssl_keylog_unittest.cc',
+ 'ssl_keyupdate_unittest.cc',
'ssl_loopback_unittest.cc',
+ 'ssl_misc_unittest.cc',
'ssl_record_unittest.cc',
'ssl_resumption_unittest.cc',
+ 'ssl_renegotiation_unittest.cc',
'ssl_skip_unittest.cc',
'ssl_staticrsa_unittest.cc',
+ 'ssl_tls13compat_unittest.cc',
'ssl_v2_client_hello_unittest.cc',
'ssl_version_unittest.cc',
'ssl_versionpolicy_unittest.cc',
diff --git a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
index 39055f641..93e19a720 100644
--- a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
@@ -187,6 +187,590 @@ TEST_P(TlsConnectTls13, RetryWithSameKeyShare) {
EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
}
+// Here we modify the second ClientHello so that the client retries with the
+// same shares, even though the server wanted something else.
+TEST_P(TlsConnectTls13, RetryWithTwoShares) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+ client_->SetPacketFilter(std::make_shared<KeyShareReplayer>());
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(groups);
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code());
+ EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackAccept) {
+ EnsureTlsSetup();
+
+ auto accept_hello = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<bool*>(arg);
+ *called = true;
+
+ EXPECT_TRUE(firstHello);
+ EXPECT_EQ(0U, clientTokenLen);
+ return ssl_hello_retry_accept;
+ };
+
+ bool cb_run = false;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ accept_hello, &cb_run));
+ Connect();
+ EXPECT_TRUE(cb_run);
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackAcceptGroupMismatch) {
+ EnsureTlsSetup();
+
+ auto accept_hello_twice = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen,
+ unsigned int appTokenMax, void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+
+ EXPECT_EQ(0U, clientTokenLen);
+ return ssl_hello_retry_accept;
+ };
+
+ auto capture = std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
+ capture->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ server_->SetPacketFilter(capture);
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ size_t cb_run = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), accept_hello_twice, &cb_run));
+ Connect();
+ EXPECT_EQ(2U, cb_run);
+ EXPECT_TRUE(capture->captured()) << "expected a cookie in HelloRetryRequest";
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackFail) {
+ EnsureTlsSetup();
+
+ auto fail_hello = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<bool*>(arg);
+ *called = true;
+
+ EXPECT_TRUE(firstHello);
+ EXPECT_EQ(0U, clientTokenLen);
+ return ssl_hello_retry_fail;
+ };
+
+ bool cb_run = false;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ fail_hello, &cb_run));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_APPLICATION_ABORT);
+ EXPECT_TRUE(cb_run);
+}
+
+// Asking for retry twice isn't allowed.
+TEST_P(TlsConnectTls13, RetryCallbackRequestHrrTwice) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ return ssl_hello_retry_request;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+// Accepting the CH and modifying the token isn't allowed.
+TEST_P(TlsConnectTls13, RetryCallbackAcceptAndSetToken) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ *appTokenLen = 1;
+ return ssl_hello_retry_accept;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+// As above, but with reject.
+TEST_P(TlsConnectTls13, RetryCallbackRejectAndSetToken) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ *appTokenLen = 1;
+ return ssl_hello_retry_fail;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+// This is a (pretend) buffer overflow.
+TEST_P(TlsConnectTls13, RetryCallbackSetTooLargeToken) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ *appTokenLen = appTokenMax + 1;
+ return ssl_hello_retry_accept;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+SSLHelloRetryRequestAction RetryHello(PRBool firstHello,
+ const PRUint8* clientToken,
+ unsigned int clientTokenLen,
+ PRUint8* appToken,
+ unsigned int* appTokenLen,
+ unsigned int appTokenMax, void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+
+ EXPECT_EQ(0U, clientTokenLen);
+ return firstHello ? ssl_hello_retry_request : ssl_hello_retry_accept;
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetry) {
+ EnsureTlsSetup();
+
+ auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ ssl_hs_hello_retry_request);
+ auto capture_key_share =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ std::vector<std::shared_ptr<PacketFilter>> chain = {capture_hrr,
+ capture_key_share};
+ server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(chain));
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Do the first message exchange.
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+ EXPECT_LT(0U, capture_hrr->buffer().len()) << "HelloRetryRequest expected";
+ EXPECT_FALSE(capture_key_share->captured())
+ << "no key_share extension expected";
+
+ auto capture_cookie =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
+ client_->SetPacketFilter(capture_cookie);
+
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_cookie->captured()) << "should have a cookie";
+}
+
+static size_t CountShares(const DataBuffer& key_share) {
+ size_t count = 0;
+ uint32_t len = 0;
+ size_t offset = 2;
+
+ EXPECT_TRUE(key_share.Read(0, 2, &len));
+ EXPECT_EQ(key_share.len() - 2, len);
+ while (offset < key_share.len()) {
+ offset += 2; // Skip KeyShareEntry.group
+ EXPECT_TRUE(key_share.Read(offset, 2, &len));
+ offset += 2 + len; // Skip KeyShareEntry.key_exchange
+ ++count;
+ }
+ return count;
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+ auto capture_server =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ server_->SetPacketFilter(capture_server);
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Do the first message exchange.
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+ EXPECT_FALSE(capture_server->captured())
+ << "no key_share extension expected from server";
+
+ auto capture_client_2nd =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ client_->SetPacketFilter(capture_client_2nd);
+
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_client_2nd->captured()) << "client should send key_share";
+ EXPECT_EQ(2U, CountShares(capture_client_2nd->extension()))
+ << "client should still send two shares";
+}
+
+// The callback should be run even if we have another reason to send
+// HelloRetryRequest. In this case, the server sends HRR because the server
+// wants a P-384 key share and the client didn't offer one.
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithGroupMismatch) {
+ EnsureTlsSetup();
+
+ auto capture_cookie =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
+ capture_cookie->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ auto capture_key_share =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{capture_cookie, capture_key_share}));
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+ Connect();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_cookie->captured()) << "cookie expected";
+ EXPECT_TRUE(capture_key_share->captured()) << "key_share expected";
+}
+
+static const uint8_t kApplicationToken[] = {0x92, 0x44, 0x00};
+
+SSLHelloRetryRequestAction RetryHelloWithToken(
+ PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen,
+ PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+
+ if (firstHello) {
+ memcpy(appToken, kApplicationToken, sizeof(kApplicationToken));
+ *appTokenLen = sizeof(kApplicationToken);
+ return ssl_hello_retry_request;
+ }
+
+ EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)),
+ DataBuffer(clientToken, static_cast<size_t>(clientTokenLen)));
+ return ssl_hello_retry_accept;
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithToken) {
+ EnsureTlsSetup();
+
+ auto capture_key_share =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ server_->SetPacketFilter(capture_key_share);
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess,
+ SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHelloWithToken, &cb_called));
+ Connect();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_FALSE(capture_key_share->captured()) << "no key share expected";
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithTokenAndGroupMismatch) {
+ EnsureTlsSetup();
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ auto capture_key_share =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ server_->SetPacketFilter(capture_key_share);
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess,
+ SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHelloWithToken, &cb_called));
+ Connect();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_key_share->captured()) << "key share expected";
+}
+
+SSLHelloRetryRequestAction CheckTicketToken(
+ PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen,
+ PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<bool*>(arg);
+ *called = true;
+
+ EXPECT_TRUE(firstHello);
+ EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)),
+ DataBuffer(clientToken, static_cast<size_t>(clientTokenLen)));
+ return ssl_hello_retry_accept;
+}
+
+// Stream because SSL_SendSessionTicket only supports that.
+TEST_F(TlsConnectStreamTls13, RetryCallbackWithSessionTicketToken) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ EXPECT_EQ(SECSuccess,
+ SSL_SendSessionTicket(server_->ssl_fd(), kApplicationToken,
+ sizeof(kApplicationToken)));
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+
+ bool cb_run = false;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), CheckTicketToken, &cb_run));
+ Connect();
+ EXPECT_TRUE(cb_run);
+}
+
+void TriggerHelloRetryRequest(std::shared_ptr<TlsAgent>& client,
+ std::shared_ptr<TlsAgent>& server) {
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Start the handshake.
+ client->StartConnect();
+ server->StartConnect();
+ client->Handshake();
+ server->Handshake();
+ EXPECT_EQ(1U, cb_called);
+}
+
+TEST_P(TlsConnectTls13, RetryStateless) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ Handshake();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, RetryStatefulDropCookie) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ client_->SetPacketFilter(
+ std::make_shared<TlsExtensionDropper>(ssl_tls13_cookie_xtn));
+
+ ExpectAlert(server_, kTlsAlertMissingExtension);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_COOKIE_EXTENSION);
+}
+
+// Stream only because DTLS drops bad packets.
+TEST_F(TlsConnectStreamTls13, RetryStatelessDamageFirstClientHello) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ auto damage_ch = std::make_shared<TlsExtensionInjector>(0xfff3, DataBuffer());
+ client_->SetPacketFilter(damage_ch);
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ // Key exchange fails when the handshake continues because client and server
+ // disagree about the transcript.
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ auto damage_ch = std::make_shared<TlsExtensionInjector>(0xfff3, DataBuffer());
+ client_->SetPacketFilter(damage_ch);
+
+ // Key exchange fails when the handshake continues because client and server
+ // disagree about the transcript.
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+// Read the cipher suite from the HRR and disable it on the identified agent.
+static void DisableSuiteFromHrr(
+ std::shared_ptr<TlsAgent>& agent,
+ std::shared_ptr<TlsInspectorRecordHandshakeMessage>& capture_hrr) {
+ uint32_t tmp;
+ size_t offset = 2 + 32; // skip version + server_random
+ ASSERT_TRUE(
+ capture_hrr->buffer().Read(offset, 1, &tmp)); // session_id length
+ EXPECT_EQ(0U, tmp);
+ offset += 1 + tmp;
+ ASSERT_TRUE(capture_hrr->buffer().Read(offset, 2, &tmp)); // suite
+ EXPECT_EQ(
+ SECSuccess,
+ SSL_CipherPrefSet(agent->ssl_fd(), static_cast<uint16_t>(tmp), PR_FALSE));
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteClient) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ ssl_hs_hello_retry_request);
+ server_->SetPacketFilter(capture_hrr);
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ DisableSuiteFromHrr(client_, capture_hrr);
+
+ // The client thinks that the HelloRetryRequest is bad, even though its
+ // because it changed its mind about the cipher suite.
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteServer) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ auto capture_hrr = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ ssl_hs_hello_retry_request);
+ server_->SetPacketFilter(capture_hrr);
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ DisableSuiteFromHrr(server_, capture_hrr);
+
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableGroupClient) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ client_->ConfigNamedGroups(groups);
+
+ // We're into undefined behavior on the client side, but - at the point this
+ // test was written - the client here doesn't amend its key shares because the
+ // server doesn't ask it to. The server notices that the key share (x25519)
+ // doesn't match the negotiated group (P-384) and objects.
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableGroupServer) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessBadCookie) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+
+ // Now replace the self-encrypt MAC key with a garbage key.
+ static const uint8_t bad_hmac_key[32] = {0};
+ SECItem key_item = {siBuffer, const_cast<uint8_t*>(bad_hmac_key),
+ sizeof(bad_hmac_key)};
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ PK11SymKey* hmac_key =
+ PK11_ImportSymKey(slot.get(), CKM_SHA256_HMAC, PK11_OriginUnwrap,
+ CKA_SIGN, &key_item, nullptr);
+ ASSERT_NE(nullptr, hmac_key);
+ SSLInt_SetSelfEncryptMacKey(hmac_key); // Passes ownership.
+
+ MakeNewServer();
+
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Stream because the server doesn't consume the alert and terminate.
+TEST_F(TlsConnectStreamTls13, RetryWithDifferentCipherSuite) {
+ EnsureTlsSetup();
+ // Force a HelloRetryRequest.
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+ // Then switch out the default suite (TLS_AES_128_GCM_SHA256).
+ server_->SetPacketFilter(std::make_shared<SelectedCipherSuiteReplacer>(
+ TLS_CHACHA20_POLY1305_SHA256));
+
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
+}
+
// This tests that the second attempt at sending a ClientHello (after receiving
// a HelloRetryRequest) is correctly retransmitted.
TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) {
@@ -233,6 +817,54 @@ TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrrExtraShares) {
CheckKEXDetails(client_groups, client_groups);
}
+// The callback should be run even if we have another reason to send
+// HelloRetryRequest. In this case, the server sends HRR because the server
+// wants an X25519 key share and the client didn't offer one.
+TEST_P(TlsKeyExchange13,
+ RetryCallbackRetryWithGroupMismatchAndAdditionalShares) {
+ EnsureKeyShareSetup();
+
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_curve25519};
+ server_->ConfigNamedGroups(server_groups);
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+ auto capture_server =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{capture_hrr_, capture_server}));
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Do the first message exchange.
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+ EXPECT_TRUE(capture_server->captured()) << "key_share extension expected";
+
+ uint32_t server_group = 0;
+ EXPECT_TRUE(capture_server->extension().Read(0, 2, &server_group));
+ EXPECT_EQ(ssl_grp_ec_curve25519, static_cast<SSLNamedGroup>(server_group));
+
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(shares_capture2_->captured()) << "client should send shares";
+
+ CheckKeys();
+ static const std::vector<SSLNamedGroup> client_shares(
+ client_groups.begin(), client_groups.begin() + 2);
+ CheckKEXDetails(client_groups, client_shares, server_groups[0]);
+}
+
TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) {
EnsureTlsSetup();
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
@@ -245,8 +877,7 @@ TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) {
static const std::vector<SSLNamedGroup> server_groups = {
ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1};
server_->ConfigNamedGroups(server_groups);
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake();
server_->Handshake();
@@ -276,15 +907,30 @@ class HelloRetryRequestAgentTest : public TlsAgentTestClient {
void MakeCannedHrr(const uint8_t* body, size_t len, DataBuffer* hrr_record,
uint32_t seq_num = 0) const {
DataBuffer hrr_data;
- hrr_data.Allocate(len + 4);
+ const uint8_t ssl_hello_retry_random[] = {
+ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C,
+ 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB,
+ 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C};
+
+ hrr_data.Allocate(len + 6);
size_t i = 0;
+ i = hrr_data.Write(i, 0x0303, 2);
+ i = hrr_data.Write(i, ssl_hello_retry_random,
+ sizeof(ssl_hello_retry_random));
+ i = hrr_data.Write(i, static_cast<uint32_t>(0), 1); // session_id
+ i = hrr_data.Write(i, TLS_AES_128_GCM_SHA256, 2);
+ i = hrr_data.Write(i, ssl_compression_null, 1);
+ // Add extensions. First a length, which includes the supported version.
+ i = hrr_data.Write(i, static_cast<uint32_t>(len) + 6, 2);
+ // Now the supported version.
+ i = hrr_data.Write(i, ssl_tls13_supported_versions_xtn, 2);
+ i = hrr_data.Write(i, 2, 2);
i = hrr_data.Write(i, 0x7f00 | TLS_1_3_DRAFT_VERSION, 2);
- i = hrr_data.Write(i, static_cast<uint32_t>(len), 2);
if (len) {
hrr_data.Write(i, body, len);
}
DataBuffer hrr;
- MakeHandshakeMessage(kTlsHandshakeHelloRetryRequest, hrr_data.data(),
+ MakeHandshakeMessage(kTlsHandshakeServerHello, hrr_data.data(),
hrr_data.len(), &hrr, seq_num);
MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, hrr.data(),
hrr.len(), hrr_record, seq_num);
@@ -334,28 +980,6 @@ TEST_P(HelloRetryRequestAgentTest, HandleNoopHelloRetryRequest) {
SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST);
}
-TEST_P(HelloRetryRequestAgentTest, HandleHelloRetryRequestCookie) {
- const uint8_t canned_cookie_hrr[] = {
- static_cast<uint8_t>(ssl_tls13_cookie_xtn >> 8),
- static_cast<uint8_t>(ssl_tls13_cookie_xtn),
- 0,
- 5, // length of cookie extension
- 0,
- 3, // cookie value length
- 0xc0,
- 0x0c,
- 0x13};
- DataBuffer hrr;
- MakeCannedHrr(canned_cookie_hrr, sizeof(canned_cookie_hrr), &hrr);
- auto capture = std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
- agent_->SetPacketFilter(capture);
- ProcessMessage(hrr, TlsAgent::STATE_CONNECTING);
- const size_t cookie_pos = 2 + 2; // cookie_xtn, extension len
- DataBuffer cookie(canned_cookie_hrr + cookie_pos,
- sizeof(canned_cookie_hrr) - cookie_pos);
- EXPECT_EQ(cookie, capture->extension());
-}
-
INSTANTIATE_TEST_CASE_P(HelloRetryRequestAgentTests, HelloRetryRequestAgentTest,
::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV13));
diff --git a/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc
new file mode 100644
index 000000000..8ed342305
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc
@@ -0,0 +1,118 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifdef NSS_ALLOW_SSLKEYLOGFILE
+
+#include <cstdlib>
+#include <fstream>
+#include <sstream>
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+static const std::string keylog_file_path = "keylog.txt";
+static const std::string keylog_env = "SSLKEYLOGFILE=" + keylog_file_path;
+
+class KeyLogFileTest : public TlsConnectGeneric {
+ public:
+ void SetUp() {
+ TlsConnectTestBase::SetUp();
+ // Remove previous results (if any).
+ (void)remove(keylog_file_path.c_str());
+ PR_SetEnv(keylog_env.c_str());
+ }
+
+ void CheckKeyLog() {
+ std::ifstream f(keylog_file_path);
+ std::map<std::string, size_t> labels;
+ std::set<std::string> client_randoms;
+ for (std::string line; std::getline(f, line);) {
+ if (line[0] == '#') {
+ continue;
+ }
+
+ std::istringstream iss(line);
+ std::string label, client_random, secret;
+ iss >> label >> client_random >> secret;
+
+ ASSERT_EQ(64U, client_random.size());
+ client_randoms.insert(client_random);
+ labels[label]++;
+ }
+
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ ASSERT_EQ(1U, client_randoms.size());
+ } else {
+ /* two handshakes for 0-RTT */
+ ASSERT_EQ(2U, client_randoms.size());
+ }
+
+ // Every entry occurs twice (one log from server, one from client).
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ ASSERT_EQ(2U, labels["CLIENT_RANDOM"]);
+ } else {
+ ASSERT_EQ(2U, labels["CLIENT_EARLY_TRAFFIC_SECRET"]);
+ ASSERT_EQ(2U, labels["EARLY_EXPORTER_SECRET"]);
+ ASSERT_EQ(4U, labels["CLIENT_HANDSHAKE_TRAFFIC_SECRET"]);
+ ASSERT_EQ(4U, labels["SERVER_HANDSHAKE_TRAFFIC_SECRET"]);
+ ASSERT_EQ(4U, labels["CLIENT_TRAFFIC_SECRET_0"]);
+ ASSERT_EQ(4U, labels["SERVER_TRAFFIC_SECRET_0"]);
+ ASSERT_EQ(4U, labels["EXPORTER_SECRET"]);
+ }
+ }
+
+ void ConnectAndCheck() {
+ // This is a child process, ensure that error messages immediately
+ // propagate or else it will not be visible.
+ ::testing::GTEST_FLAG(throw_on_failure) = true;
+
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+ } else {
+ Connect();
+ }
+ CheckKeyLog();
+ _exit(0);
+ }
+};
+
+// Tests are run in a separate process to ensure that NSS is not initialized yet
+// and can process the SSLKEYLOGFILE environment variable.
+
+TEST_P(KeyLogFileTest, KeyLogFile) {
+ testing::GTEST_FLAG(death_test_style) = "threadsafe";
+
+ ASSERT_EXIT(ConnectAndCheck(), ::testing::ExitedWithCode(0), "");
+}
+
+INSTANTIATE_TEST_CASE_P(
+ KeyLogFileDTLS12, KeyLogFileTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11V12));
+INSTANTIATE_TEST_CASE_P(
+ KeyLogFileTLS12, KeyLogFileTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_CASE_P(
+ KeyLogFileTLS13, KeyLogFileTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV13));
+#endif
+
+} // namespace nss_test
+
+#endif // NSS_ALLOW_SSLKEYLOGFILE
diff --git a/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc
new file mode 100644
index 000000000..d03775c25
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc
@@ -0,0 +1,178 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// All stream only tests; DTLS isn't supported yet.
+
+TEST_F(TlsConnectTest, KeyUpdateClient) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(4, 3);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateClientRequestUpdate) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE));
+ // SendReceive() only gives each peer one chance to read. This isn't enough
+ // when the read on one side generates another handshake message. A second
+ // read gives each peer an extra chance to consume the KeyUpdate.
+ SendReceive(50);
+ SendReceive(60); // Cumulative count.
+ CheckEpochs(4, 4);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateServer) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(3, 4);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateServerRequestUpdate) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(4, 4);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateConsecutiveRequests) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ // The server should have updated twice, but the client should have declined
+ // to respond to the second request from the server, since it doesn't send
+ // anything in between those two requests.
+ CheckEpochs(4, 5);
+}
+
+// Check that a local update can be immediately followed by a remotely triggered
+// update even if there is no use of the keys.
+TEST_F(TlsConnectTest, KeyUpdateLocalUpdateThenConsecutiveRequests) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ // This should trigger an update on the client.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
+ // The client should update for the first request.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ // ...but not the second.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ // Both should have updated twice.
+ CheckEpochs(5, 5);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateMultiple) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(5, 6);
+}
+
+// Both ask the other for an update, and both should react.
+TEST_F(TlsConnectTest, KeyUpdateBothRequest) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(5, 5);
+}
+
+// If the sequence number exceeds the number of writes before an automatic
+// update (currently 3/4 of the max records for the cipher suite), then the
+// stack should send an update automatically (but not request one).
+TEST_F(TlsConnectTest, KeyUpdateAutomaticOnWrite) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256);
+
+ // Set this to one below the write threshold.
+ uint64_t threshold = (0x5aULL << 28) * 3 / 4;
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold));
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold));
+
+ // This should be OK.
+ client_->SendData(10);
+ server_->ReadBytes();
+
+ // This should cause the client to update.
+ client_->SendData(10);
+ server_->ReadBytes();
+
+ SendReceive(100);
+ CheckEpochs(4, 3);
+}
+
+// If the sequence number exceeds a certain number of reads (currently 7/8 of
+// the max records for the cipher suite), then the stack should send AND request
+// an update automatically. However, the sender (client) will be above its
+// automatic update threshold, so the KeyUpdate - that it sends with the old
+// cipher spec - will exceed the receiver (server) automatic update threshold.
+// The receiver gets a packet with a sequence number over its automatic read
+// update threshold. Even though the sender has updated, the code that checks
+// the sequence numbers at the receiver doesn't know this and it will request an
+// update. This causes two updates: one from the sender (without requesting a
+// response) and one from the receiver (which does request a response).
+TEST_F(TlsConnectTest, KeyUpdateAutomaticOnRead) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256);
+
+ // Move to right at the read threshold. Unlike the write test, we can't send
+ // packets because that would cause the client to update, which would spoil
+ // the test.
+ uint64_t threshold = ((0x5aULL << 28) * 7 / 8) + 1;
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold));
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold));
+
+ // This should cause the client to update, but not early enough to prevent the
+ // server from updating also.
+ client_->SendData(10);
+ server_->ReadBytes();
+
+ // Need two SendReceive() calls to ensure that the update that the server
+ // requested is properly generated and consumed.
+ SendReceive(70);
+ SendReceive(80);
+ CheckEpochs(5, 4);
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
index 77703dd8e..4bc6e60ab 100644
--- a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -6,6 +6,7 @@
#include <functional>
#include <memory>
+#include <vector>
#include "secerr.h"
#include "ssl.h"
#include "sslerr.h"
@@ -84,13 +85,13 @@ class TlsAlertRecorder : public TlsRecordFilter {
};
class HelloTruncator : public TlsHandshakeFilter {
+ public:
+ HelloTruncator()
+ : TlsHandshakeFilter(
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {}
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) override {
- if (header.handshake_type() != kTlsHandshakeClientHello &&
- header.handshake_type() != kTlsHandshakeServerHello) {
- return KEEP;
- }
output->Assign(input.data(), input.len() - 1);
return CHANGE;
}
@@ -102,9 +103,9 @@ TEST_P(TlsConnectGeneric, CaptureAlertServer) {
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
server_->SetPacketFilter(alert_recorder);
- ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ ConnectExpectAlert(server_, kTlsAlertDecodeError);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
}
TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
@@ -123,8 +124,7 @@ TEST_P(TlsConnectTls13, CaptureAlertClient) {
auto alert_recorder = std::make_shared<TlsAlertRecorder>();
client_->SetPacketFilter(alert_recorder);
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
client_->Handshake();
client_->ExpectSendAlert(kTlsAlertDecodeError);
@@ -166,26 +166,107 @@ TEST_P(TlsConnectDatagram, ConnectSrtp) {
SendReceive();
}
-// 1.3 is disabled in the next few tests because we don't
-// presently support resumption in 1.3.
-TEST_P(TlsConnectStreamPre13, ConnectAndClientRenegotiate) {
+TEST_P(TlsConnectGeneric, ConnectSendReceive) {
Connect();
- server_->PrepareForRenegotiate();
- client_->StartRenegotiate();
- Handshake();
- CheckConnected();
+ SendReceive();
+}
+
+class SaveTlsRecord : public TlsRecordFilter {
+ public:
+ SaveTlsRecord(size_t index) : index_(index), count_(0), contents_() {}
+
+ const DataBuffer& contents() const { return contents_; }
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ if (count_++ == index_) {
+ contents_ = data;
+ }
+ return KEEP;
+ }
+
+ private:
+ const size_t index_;
+ size_t count_;
+ DataBuffer contents_;
+};
+
+// Check that decrypting filters work and can read any record.
+// This test (currently) only works in TLS 1.3 where we can decrypt.
+TEST_F(TlsConnectStreamTls13, DecryptRecordClient) {
+ EnsureTlsSetup();
+ // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer
+ auto saved = std::make_shared<SaveTlsRecord>(3);
+ client_->SetTlsRecordFilter(saved);
+ Connect();
+ SendReceive();
+
+ static const uint8_t data[] = {0xde, 0xad, 0xdc};
+ DataBuffer buf(data, sizeof(data));
+ client_->SendBuffer(buf);
+ EXPECT_EQ(buf, saved->contents());
}
-TEST_P(TlsConnectStreamPre13, ConnectAndServerRenegotiate) {
+TEST_F(TlsConnectStreamTls13, DecryptRecordServer) {
+ EnsureTlsSetup();
+ // Disable tickets so that we are sure to not get NewSessionTicket.
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
+ // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer
+ auto saved = std::make_shared<SaveTlsRecord>(3);
+ server_->SetTlsRecordFilter(saved);
Connect();
- client_->PrepareForRenegotiate();
- server_->StartRenegotiate();
- Handshake();
- CheckConnected();
+ SendReceive();
+
+ static const uint8_t data[] = {0xde, 0xad, 0xd5};
+ DataBuffer buf(data, sizeof(data));
+ server_->SendBuffer(buf);
+ EXPECT_EQ(buf, saved->contents());
}
-TEST_P(TlsConnectGeneric, ConnectSendReceive) {
+class DropTlsRecord : public TlsRecordFilter {
+ public:
+ DropTlsRecord(size_t index) : index_(index), count_(0) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ if (count_++ == index_) {
+ return DROP;
+ }
+ return KEEP;
+ }
+
+ private:
+ const size_t index_;
+ size_t count_;
+};
+
+// Test that decrypting filters work correctly and are able to drop records.
+TEST_F(TlsConnectStreamTls13, DropRecordServer) {
+ EnsureTlsSetup();
+ // Disable session tickets so that the server doesn't send an extra record.
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
+
+ // 0 = ServerHello, 1 = other handshake, 2 = first write
+ server_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2));
+ Connect();
+ server_->SendData(23, 23); // This should be dropped, so it won't be counted.
+ server_->ResetSentBytes();
+ SendReceive();
+}
+
+TEST_F(TlsConnectStreamTls13, DropRecordClient) {
+ EnsureTlsSetup();
+ // 0 = ClientHello, 1 = Finished, 2 = first write
+ client_->SetTlsRecordFilter(std::make_shared<DropTlsRecord>(2));
Connect();
+ client_->SendData(26, 26); // This should be dropped, so it won't be counted.
+ client_->ResetSentBytes();
SendReceive();
}
@@ -224,29 +305,70 @@ TEST_P(TlsConnectStream, ShortRead) {
ASSERT_EQ(50U, client_->received_bytes());
}
-TEST_P(TlsConnectGeneric, ConnectWithCompressionMaybe) {
+// We enable compression via the API but it's disabled internally,
+// so we should never get it.
+TEST_P(TlsConnectGeneric, ConnectWithCompressionEnabled) {
EnsureTlsSetup();
- client_->EnableCompression();
- server_->EnableCompression();
+ client_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE);
+ server_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE);
Connect();
- EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 &&
- variant_ != ssl_variant_datagram,
- client_->is_compressed());
+ EXPECT_FALSE(client_->is_compressed());
SendReceive();
}
-TEST_P(TlsConnectDatagram, TestDtlsHolddownExpiry) {
+class TlsHolddownTest : public TlsConnectDatagram {
+ protected:
+ // This causes all timers to run to completion. It advances the clock and
+ // handshakes on both peers until both peers have no more timers pending,
+ // which should happen at the end of a handshake. This is necessary to ensure
+ // that the relatively long holddown timer expires, but that any other timers
+ // also expire and run correctly.
+ void RunAllTimersDown() {
+ while (true) {
+ PRIntervalTime time;
+ SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time);
+ if (rv != SECSuccess) {
+ rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time);
+ if (rv != SECSuccess) {
+ break; // Neither peer has an outstanding timer.
+ }
+ }
+
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "Shifting timers" << std::endl;
+ }
+ ShiftDtlsTimers();
+ Handshake();
+ }
+ }
+};
+
+TEST_P(TlsHolddownTest, TestDtlsHolddownExpiry) {
Connect();
- std::cerr << "Expiring holddown timer\n";
- SSLInt_ForceTimerExpiry(client_->ssl_fd());
- SSLInt_ForceTimerExpiry(server_->ssl_fd());
+ std::cerr << "Expiring holddown timer" << std::endl;
+ RunAllTimersDown();
SendReceive();
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
// One for send, one for receive.
- EXPECT_EQ(2, SSLInt_CountTls13CipherSpecs(client_->ssl_fd()));
+ EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd()));
}
}
+TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ RunAllTimersDown();
+ SendReceive();
+ // One for send, one for receive.
+ EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd()));
+}
+
class TlsPreCCSHeaderInjector : public TlsRecordFilter {
public:
TlsPreCCSHeaderInjector() {}
@@ -274,8 +396,7 @@ TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) {
TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) {
server_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>());
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
ExpectAlert(client_, kTlsAlertUnexpectedMessage);
Handshake();
EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
@@ -306,21 +427,65 @@ TEST_P(TlsConnectTls13, AlertWrongLevel) {
TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) {
EnsureTlsSetup();
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake();
server_->Handshake(); // Send first flight.
- client_->adapter()->CloseWrites();
+ client_->adapter()->SetWriteError(PR_IO_ERROR);
client_->Handshake(); // This will get an error, but shouldn't crash.
client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE);
}
-TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) {
- client_->SetShortHeadersEnabled();
- server_->SetShortHeadersEnabled();
- client_->ExpectShortHeaders();
- server_->ExpectShortHeaders();
+TEST_P(TlsConnectDatagram, BlockedWrite) {
+ Connect();
+
+ // Mark the socket as blocked.
+ client_->adapter()->SetWriteError(PR_WOULD_BLOCK_ERROR);
+ static const uint8_t data[] = {1, 2, 3};
+ int32_t rv = PR_Write(client_->ssl_fd(), data, sizeof(data));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Remove the write error and though the previous write failed, future reads
+ // and writes should just work as if it never happened.
+ client_->adapter()->SetWriteError(0);
+ SendReceive();
+}
+
+TEST_F(TlsConnectTest, ConnectSSLv3) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_3_0);
+ EnableOnlyStaticRsaCiphers();
Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+TEST_F(TlsConnectTest, ConnectSSLv3ClientAuth) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_3_0);
+ EnableOnlyStaticRsaCiphers();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+static size_t ExpectedCbcLen(size_t in, size_t hmac = 20, size_t block = 16) {
+ // MAC-then-Encrypt expansion formula:
+ return ((in + hmac + (block - 1)) / block) * block;
+}
+
+TEST_F(TlsConnectTest, OneNRecordSplitting) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0);
+ EnsureTlsSetup();
+ ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA);
+ auto records = std::make_shared<TlsRecordRecorder>();
+ server_->SetPacketFilter(records);
+ // This should be split into 1, 16384 and 20.
+ DataBuffer big_buffer;
+ big_buffer.Allocate(1 + 16384 + 20);
+ server_->SendBuffer(big_buffer);
+ ASSERT_EQ(3U, records->count());
+ EXPECT_EQ(ExpectedCbcLen(1), records->record(0).buffer.len());
+ EXPECT_EQ(ExpectedCbcLen(16384), records->record(1).buffer.len());
+ EXPECT_EQ(ExpectedCbcLen(20), records->record(2).buffer.len());
}
INSTANTIATE_TEST_CASE_P(
@@ -336,6 +501,8 @@ INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream,
TlsConnectTestBase::kTlsVAll);
INSTANTIATE_TEST_CASE_P(DatagramOnly, TlsConnectDatagram,
TlsConnectTestBase::kTlsV11Plus);
+INSTANTIATE_TEST_CASE_P(DatagramHolddown, TlsHolddownTest,
+ TlsConnectTestBase::kTlsV11Plus);
INSTANTIATE_TEST_CASE_P(
Pre12Stream, TlsConnectPre12,
diff --git a/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc
new file mode 100644
index 000000000..2b1b92dcd
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc
@@ -0,0 +1,20 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "sslexp.h"
+
+#include "gtest_utils.h"
+
+namespace nss_test {
+
+class MiscTest : public ::testing::Test {};
+
+TEST_F(MiscTest, NonExistentExperimentalAPI) {
+ EXPECT_EQ(nullptr, SSL_GetExperimentalAPI("blah"));
+ EXPECT_EQ(SSL_ERROR_UNSUPPORTED_EXPERIMENTAL_API, PORT_GetError());
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc
index ef81b222c..d1d496f49 100644
--- a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc
@@ -10,6 +10,8 @@
#include "databuffer.h"
#include "gtest_utils.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
namespace nss_test {
@@ -51,8 +53,8 @@ class TlsPaddingTest
<< " total length=" << plaintext_.len() << std::endl;
std::cerr << "Plaintext: " << plaintext_ << std::endl;
sslBuffer s;
- s.buf = const_cast<unsigned char *>(
- static_cast<const unsigned char *>(plaintext_.data()));
+ s.buf = const_cast<unsigned char*>(
+ static_cast<const unsigned char*>(plaintext_.data()));
s.len = plaintext_.len();
SECStatus rv = ssl_RemoveTLSCBCPadding(&s, kMacSize);
if (expect_success) {
@@ -99,6 +101,73 @@ TEST_P(TlsPaddingTest, LastByteOfPadWrong) {
}
}
+class RecordReplacer : public TlsRecordFilter {
+ public:
+ RecordReplacer(size_t size)
+ : TlsRecordFilter(), enabled_(false), size_(size) {}
+
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ if (!enabled_) {
+ return KEEP;
+ }
+
+ EXPECT_EQ(kTlsApplicationDataType, header.content_type());
+ changed->Allocate(size_);
+
+ for (size_t i = 0; i < size_; ++i) {
+ changed->data()[i] = i & 0xff;
+ }
+
+ enabled_ = false;
+ return CHANGE;
+ }
+
+ void Enable() { enabled_ = true; }
+
+ private:
+ bool enabled_;
+ size_t size_;
+};
+
+TEST_F(TlsConnectStreamTls13, LargeRecord) {
+ EnsureTlsSetup();
+
+ const size_t record_limit = 16384;
+ auto replacer = std::make_shared<RecordReplacer>(record_limit);
+ client_->SetTlsRecordFilter(replacer);
+ Connect();
+
+ replacer->Enable();
+ client_->SendData(10);
+ WAIT_(server_->received_bytes() == record_limit, 2000);
+ ASSERT_EQ(record_limit, server_->received_bytes());
+}
+
+TEST_F(TlsConnectStreamTls13, TooLargeRecord) {
+ EnsureTlsSetup();
+
+ const size_t record_limit = 16384;
+ auto replacer = std::make_shared<RecordReplacer>(record_limit + 1);
+ client_->SetTlsRecordFilter(replacer);
+ Connect();
+
+ replacer->Enable();
+ ExpectAlert(server_, kTlsAlertRecordOverflow);
+ client_->SendData(10); // This is expanded.
+
+ uint8_t buf[record_limit + 2];
+ PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, PORT_GetError());
+
+ // Read the server alert.
+ rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(SSL_ERROR_RECORD_OVERFLOW_ALERT, PORT_GetError());
+}
+
const static size_t kContentSizesArr[] = {
1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288};
diff --git a/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc
new file mode 100644
index 000000000..a902a5f7f
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc
@@ -0,0 +1,212 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+// 1.3 is disabled in the next few tests because we don't
+// presently support resumption in 1.3.
+TEST_P(TlsConnectStreamPre13, RenegotiateClient) {
+ Connect();
+ server_->PrepareForRenegotiate();
+ client_->StartRenegotiate();
+ Handshake();
+ CheckConnected();
+}
+
+TEST_P(TlsConnectStreamPre13, RenegotiateServer) {
+ Connect();
+ client_->PrepareForRenegotiate();
+ server_->StartRenegotiate();
+ Handshake();
+ CheckConnected();
+}
+
+// The renegotiation options shouldn't cause an error if TLS 1.3 is chosen.
+TEST_F(TlsConnectTest, RenegotiationConfigTls13) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetOption(SSL_ENABLE_RENEGOTIATION, SSL_RENEGOTIATE_UNRESTRICTED);
+ server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
+ Connect();
+ SendReceive();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ // Set the client so it will accept any version from 1.0
+ // to |version_|.
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // Reset version so that the checks succeed.
+ uint16_t test_version = version_;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_0;
+ Connect();
+
+ // Now renegotiate, with the server being set to do
+ // |version_|.
+ client_->PrepareForRenegotiate();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+ server_->StartRenegotiate();
+
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ } else {
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+ }
+
+ Handshake();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In TLS 1.3, the server detects this problem.
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+ }
+}
+
+TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ // Set the client so it will accept any version from 1.0
+ // to |version_|.
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // Reset version so that the checks succeed.
+ uint16_t test_version = version_;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_0;
+ Connect();
+
+ // Now renegotiate, with the server being set to do
+ // |version_|.
+ server_->PrepareForRenegotiate();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+ client_->StartRenegotiate();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ } else {
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+ }
+ Handshake();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In TLS 1.3, the server detects this problem.
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+ }
+}
+
+TEST_P(TlsConnectStream, ConnectAndServerRenegotiateLower) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ Connect();
+
+ // Now renegotiate with the server set to TLS 1.0.
+ client_->PrepareForRenegotiate();
+ server_->PrepareForRenegotiate();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+
+ SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ EXPECT_EQ(SECFailure, rv);
+ return;
+ }
+ ASSERT_EQ(SECSuccess, rv);
+
+ // Now, before handshaking, tweak the server configuration.
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+
+ // The server should catch the own error.
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+}
+
+TEST_P(TlsConnectStream, ConnectAndServerWontRenegotiateLower) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ Connect();
+
+ // Now renegotiate with the server set to TLS 1.0.
+ client_->PrepareForRenegotiate();
+ server_->PrepareForRenegotiate();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+
+ EXPECT_EQ(SECFailure, SSL_ReHandshake(server_->ssl_fd(), PR_TRUE));
+}
+
+TEST_P(TlsConnectStream, ConnectAndClientWontRenegotiateLower) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ Connect();
+
+ // Now renegotiate with the client set to TLS 1.0.
+ client_->PrepareForRenegotiate();
+ server_->PrepareForRenegotiate();
+ server_->ResetPreliminaryInfo();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // The client will refuse to renegotiate down.
+ EXPECT_EQ(SECFailure, SSL_ReHandshake(client_->ssl_fd(), PR_TRUE));
+}
+
+TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
+}
+
+TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
index ce0e3ca8d..a413caf2c 100644
--- a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
@@ -9,6 +9,7 @@
#include "secerr.h"
#include "ssl.h"
#include "sslerr.h"
+#include "sslexp.h"
#include "sslproto.h"
extern "C" {
@@ -246,8 +247,7 @@ TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) {
: ssl_session_ticket_xtn;
auto capture = std::make_shared<TlsExtensionCapture>(xtn);
client_->SetPacketFilter(capture);
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake();
EXPECT_TRUE(capture->captured());
EXPECT_LT(0U, capture->extension().len());
@@ -355,10 +355,7 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) {
// This test parses the ServerKeyExchange, which isn't in 1.3
TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) {
- server_->EnsureTlsSetup();
- SECStatus rv =
- SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
+ server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(i1);
@@ -369,9 +366,7 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) {
// Restart
Reset();
- server_->EnsureTlsSetup();
- rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
+ server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(i2);
@@ -401,7 +396,8 @@ TEST_P(TlsConnectTls13, TestTls13ResumeDifferentGroup) {
client_->ConfigNamedGroups(kFFDHEGroups);
server_->ConfigNamedGroups(kFFDHEGroups);
Connect();
- CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none);
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
}
// We need to enable different cipher suites at different times in the following
@@ -461,36 +457,6 @@ TEST_P(TlsConnectGeneric, TestResumeServerDifferentCipher) {
CheckKeys();
}
-class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
- public:
- SelectedCipherSuiteReplacer(uint16_t suite) : cipher_suite_(suite) {}
-
- protected:
- PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
- const DataBuffer& input,
- DataBuffer* output) override {
- if (header.handshake_type() != kTlsHandshakeServerHello) {
- return KEEP;
- }
-
- *output = input;
- uint32_t temp = 0;
- EXPECT_TRUE(input.Read(0, 2, &temp));
- // Cipher suite is after version(2) and random(32).
- size_t pos = 34;
- if (temp < SSL_LIBRARY_VERSION_TLS_1_3) {
- // In old versions, we have to skip a session_id too.
- EXPECT_TRUE(input.Read(pos, 1, &temp));
- pos += 1 + temp;
- }
- output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
- return CHANGE;
- }
-
- private:
- uint16_t cipher_suite_;
-};
-
// Test that the client doesn't tolerate the server picking a different cipher
// suite for resumption.
TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
@@ -524,16 +490,13 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
class SelectedVersionReplacer : public TlsHandshakeFilter {
public:
- SelectedVersionReplacer(uint16_t version) : version_(version) {}
+ SelectedVersionReplacer(uint16_t version)
+ : TlsHandshakeFilter({kTlsHandshakeServerHello}), version_(version) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) override {
- if (header.handshake_type() != kTlsHandshakeServerHello) {
- return KEEP;
- }
-
*output = input;
output->Write(0, static_cast<uint32_t>(version_), 2);
return CHANGE;
@@ -609,7 +572,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
Connect();
SendReceive();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
- ssl_sig_none);
+ ssl_sig_rsa_pss_sha256);
// The filter will go away when we reset, so save the captured extension.
DataBuffer initialTicket(c1->extension());
ASSERT_LT(0U, initialTicket.len());
@@ -627,7 +590,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
Connect();
SendReceive();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
- ssl_sig_none);
+ ssl_sig_rsa_pss_sha256);
ASSERT_LT(0U, c2->extension().len());
ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd()));
@@ -652,18 +615,158 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) {
// Clear the session ticket keys to invalidate the old ticket.
SSLInt_ClearSelfEncryptKey();
- SSLInt_SendNewSessionTicket(server_->ssl_fd());
+ SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0);
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ // Resume the connection.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+// Check that the value captured in a NewSessionTicket message matches the value
+// captured from a pre_shared_key extension.
+void NstTicketMatchesPskIdentity(const DataBuffer& nst, const DataBuffer& psk) {
+ uint32_t len;
+
+ size_t offset = 4 + 4; // Skip ticket_lifetime and ticket_age_add.
+ ASSERT_TRUE(nst.Read(offset, 1, &len));
+ offset += 1 + len; // Skip ticket_nonce.
+
+ ASSERT_TRUE(nst.Read(offset, 2, &len));
+ offset += 2; // Skip the ticket length.
+ ASSERT_LE(offset + len, nst.len());
+ DataBuffer nst_ticket(nst.data() + offset, static_cast<size_t>(len));
+
+ offset = 2; // Skip the identities length.
+ ASSERT_TRUE(psk.Read(offset, 2, &len));
+ offset += 2; // Skip the identity length.
+ ASSERT_LE(offset + len, psk.len());
+ DataBuffer psk_ticket(psk.data() + offset, static_cast<size_t>(len));
+
+ EXPECT_EQ(nst_ticket, psk_ticket);
+}
+
+TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ auto nst_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ ssl_hs_new_session_ticket);
+ server_->SetTlsRecordFilter(nst_capture);
+ Connect();
+
+ // Clear the session ticket keys to invalidate the old ticket.
+ SSLInt_ClearSelfEncryptKey();
+ nst_capture->Reset();
+ uint8_t token[] = {0x20, 0x20, 0xff, 0x00};
+ EXPECT_EQ(SECSuccess,
+ SSL_SendSessionTicket(server_->ssl_fd(), token, sizeof(token)));
SendReceive(); // Need to read so that we absorb the session tickets.
CheckKeys();
+ EXPECT_LT(0U, nst_capture->buffer().len());
// Resume the connection.
Reset();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
ExpectResumption(RESUME_TICKET);
+
+ auto psk_capture =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
+ client_->SetPacketFilter(psk_capture);
Connect();
SendReceive();
+
+ NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension());
+}
+
+// Disable SSL_ENABLE_SESSION_TICKETS but ensure that tickets can still be sent
+// by invoking SSL_SendSessionTicket directly (and that the ticket is usable).
+TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
+
+ auto nst_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ ssl_hs_new_session_ticket);
+ server_->SetTlsRecordFilter(nst_capture);
+ Connect();
+
+ EXPECT_EQ(0U, nst_capture->buffer().len()) << "expect nothing captured yet";
+
+ EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0));
+ EXPECT_LT(0U, nst_capture->buffer().len()) << "should capture now";
+
+ SendReceive(); // Ensure that the client reads the ticket.
+
+ // Resume the connection.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+
+ auto psk_capture =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
+ client_->SetPacketFilter(psk_capture);
+ Connect();
+ SendReceive();
+
+ NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension());
+}
+
+// Test calling SSL_SendSessionTicket in inappropriate conditions.
+TEST_F(TlsConnectTest, SendSessionTicketInappropriate) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(client_->ssl_fd(), NULL, 0))
+ << "clients can't send tickets";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ StartConnect();
+
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0))
+ << "no ticket before the handshake has started";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ Handshake();
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0))
+ << "no special tickets in TLS 1.2";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectTest, SendSessionTicketMassiveToken) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ // It should be safe to set length with a NULL token because the length should
+ // be checked before reading token.
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0x1ffff))
+ << "this is clearly too big";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ static const uint8_t big_token[0xffff] = {1};
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), big_token,
+ sizeof(big_token)))
+ << "this is too big, but that's not immediately obvious";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectDatagram13, SendSessionTicketDtls) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0))
+ << "no extra tickets in DTLS until we have Ack support";
+ EXPECT_EQ(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_VERSION, PORT_GetError());
}
TEST_F(TlsConnectTest, TestTls13ResumptionDowngrade) {
@@ -719,13 +822,84 @@ TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) {
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256));
filters.push_back(
std::make_shared<SelectedVersionReplacer>(SSL_LIBRARY_VERSION_TLS_1_2));
+
+ // Drop a bunch of extensions so that we get past the SH processing. The
+ // version extension says TLS 1.3, which is counter to our goal, the others
+ // are not permitted in TLS 1.2 handshakes.
+ filters.push_back(
+ std::make_shared<TlsExtensionDropper>(ssl_tls13_supported_versions_xtn));
+ filters.push_back(
+ std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn));
+ filters.push_back(
+ std::make_shared<TlsExtensionDropper>(ssl_tls13_pre_shared_key_xtn));
server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(filters));
- client_->ExpectSendAlert(kTlsAlertDecodeError);
+ // The client here generates an unexpected_message alert when it receives an
+ // encrypted handshake message from the server (EncryptedExtension). The
+ // client expects to receive an unencrypted TLS 1.2 Certificate message.
+ // The server can't decrypt the alert.
+ client_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
server_->ExpectSendAlert(kTlsAlertBadRecordMac); // Server can't read
ConnectExpectFail();
- client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA);
server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
}
+TEST_P(TlsConnectGeneric, ReConnectTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ // Resume
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ // Only the client knows this.
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256);
+}
+
+TEST_P(TlsConnectGenericPre13, ReConnectCache) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ // Resume
+ Reset();
+ ExpectResumption(RESUME_SESSIONID);
+ Connect();
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256);
+}
+
+TEST_P(TlsConnectGeneric, ReConnectAgainTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ // Resume
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ // Only the client knows this.
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256);
+ // Resume connection again
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET, 2);
+ Connect();
+ // Only the client knows this.
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256);
+}
+
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
index a130ef77f..335bfecfa 100644
--- a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -43,7 +43,14 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter {
size_t start = parser.consumed();
TlsHandshakeFilter::HandshakeHeader header;
DataBuffer ignored;
- if (!header.Parse(&parser, record_header, &ignored)) {
+ bool complete = false;
+ if (!header.Parse(&parser, record_header, DataBuffer(), &ignored,
+ &complete)) {
+ ADD_FAILURE() << "Error parsing handshake header";
+ return KEEP;
+ }
+ if (!complete) {
+ ADD_FAILURE() << "Don't want to deal with fragmented input";
return KEEP;
}
@@ -101,26 +108,15 @@ class Tls13SkipTest : public TlsConnectTestBase,
void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
EnsureTlsSetup();
server_->SetTlsRecordFilter(filter);
- filter->EnableDecryption();
- client_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
- if (variant_ == ssl_variant_stream) {
- server_->ExpectSendAlert(kTlsAlertBadRecordMac);
- ConnectExpectFail();
- } else {
- ConnectExpectFailOneSide(TlsAgent::CLIENT);
- }
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
client_->CheckErrorCode(error);
- if (variant_ == ssl_variant_stream) {
- server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
- } else {
- ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
- }
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
}
void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
EnsureTlsSetup();
client_->SetTlsRecordFilter(filter);
- filter->EnableDecryption();
server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
ConnectExpectFailOneSide(TlsAgent::SERVER);
@@ -171,11 +167,10 @@ TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
}
TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
- auto chain = std::make_shared<ChainedPacketFilter>();
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ auto chain = std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit{
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
+ std::make_shared<TlsHandshakeSkipFilter>(
+ kTlsHandshakeServerKeyExchange)});
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
diff --git a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
index 8db1f30e1..e7fe44d92 100644
--- a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
@@ -71,7 +71,7 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
EnableOnlyStaticRsaCiphers();
client_->SetPacketFilter(
std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
- server_->DisableRollbackDetection();
+ server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE);
Connect();
}
@@ -102,7 +102,7 @@ TEST_P(TlsConnectStreamPre13,
EnableExtendedMasterSecret();
client_->SetPacketFilter(
std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
- server_->DisableRollbackDetection();
+ server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE);
Connect();
}
diff --git a/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
new file mode 100644
index 000000000..75cee52fc
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
@@ -0,0 +1,337 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <memory>
+#include <vector>
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+class Tls13CompatTest : public TlsConnectStreamTls13 {
+ protected:
+ void EnableCompatMode() {
+ client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE);
+ }
+
+ void InstallFilters() {
+ EnsureTlsSetup();
+ client_recorders_.Install(client_);
+ server_recorders_.Install(server_);
+ }
+
+ void CheckRecordVersions() {
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_0,
+ client_recorders_.records_->record(0).header.version());
+ CheckRecordsAreTls12("client", client_recorders_.records_, 1);
+ CheckRecordsAreTls12("server", server_recorders_.records_, 0);
+ }
+
+ void CheckHelloVersions() {
+ uint32_t ver;
+ ASSERT_TRUE(server_recorders_.hello_->buffer().Read(0, 2, &ver));
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver));
+ ASSERT_TRUE(client_recorders_.hello_->buffer().Read(0, 2, &ver));
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver));
+ }
+
+ void CheckForCCS(bool expected_client, bool expected_server) {
+ client_recorders_.CheckForCCS(expected_client);
+ server_recorders_.CheckForCCS(expected_server);
+ }
+
+ void CheckForRegularHandshake() {
+ CheckRecordVersions();
+ CheckHelloVersions();
+ EXPECT_EQ(0U, client_recorders_.session_id_length());
+ EXPECT_EQ(0U, server_recorders_.session_id_length());
+ CheckForCCS(false, false);
+ }
+
+ void CheckForCompatHandshake() {
+ CheckRecordVersions();
+ CheckHelloVersions();
+ EXPECT_EQ(32U, client_recorders_.session_id_length());
+ EXPECT_EQ(32U, server_recorders_.session_id_length());
+ CheckForCCS(true, true);
+ }
+
+ private:
+ struct Recorders {
+ Recorders()
+ : records_(new TlsRecordRecorder()),
+ hello_(new TlsInspectorRecordHandshakeMessage(std::set<uint8_t>(
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello}))) {}
+
+ uint8_t session_id_length() const {
+ // session_id is always after version (2) and random (32).
+ uint32_t len = 0;
+ EXPECT_TRUE(hello_->buffer().Read(2 + 32, 1, &len));
+ return static_cast<uint8_t>(len);
+ }
+
+ void CheckForCCS(bool expected) const {
+ EXPECT_LT(0U, records_->count());
+ for (size_t i = 0; i < records_->count(); ++i) {
+ // Only the second record can be a CCS.
+ bool expected_match = expected && (i == 1);
+ EXPECT_EQ(expected_match,
+ kTlsChangeCipherSpecType ==
+ records_->record(i).header.content_type());
+ }
+ }
+
+ void Install(std::shared_ptr<TlsAgent>& agent) {
+ agent->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({records_, hello_})));
+ }
+
+ std::shared_ptr<TlsRecordRecorder> records_;
+ std::shared_ptr<TlsInspectorRecordHandshakeMessage> hello_;
+ };
+
+ void CheckRecordsAreTls12(const std::string& agent,
+ const std::shared_ptr<TlsRecordRecorder>& records,
+ size_t start) {
+ EXPECT_LE(start, records->count());
+ for (size_t i = start; i < records->count(); ++i) {
+ EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_2,
+ records->record(i).header.version())
+ << agent << ": record " << i << " has wrong version";
+ }
+ }
+
+ Recorders client_recorders_;
+ Recorders server_recorders_;
+};
+
+TEST_F(Tls13CompatTest, Disabled) {
+ InstallFilters();
+ Connect();
+ CheckForRegularHandshake();
+}
+
+TEST_F(Tls13CompatTest, Enabled) {
+ EnableCompatMode();
+ InstallFilters();
+ Connect();
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledZeroRtt) {
+ SetupForZeroRtt();
+ EnableCompatMode();
+ InstallFilters();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ CheckForCCS(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledHrr) {
+ EnableCompatMode();
+ InstallFilters();
+
+ // Force a HelloRetryRequest. The server sends CCS immediately.
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ CheckForCCS(false, true);
+
+ Handshake();
+ CheckConnected();
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledStatelessHrr) {
+ EnableCompatMode();
+ InstallFilters();
+
+ // Force a HelloRetryRequest
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ CheckForCCS(false, true);
+
+ // A new server should just work, but not send another CCS.
+ MakeNewServer();
+ InstallFilters();
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+
+ Handshake();
+ CheckConnected();
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledHrrZeroRtt) {
+ SetupForZeroRtt();
+ EnableCompatMode();
+ InstallFilters();
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+
+ // With 0-RTT, the client sends CCS immediately. With HRR, the server sends
+ // CCS immediately too.
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false);
+ CheckForCCS(true, true);
+
+ Handshake();
+ ExpectEarlyDataAccepted(false);
+ CheckConnected();
+ CheckForCompatHandshake();
+}
+
+static const uint8_t kCannedCcs[] = {
+ kTlsChangeCipherSpecType,
+ SSL_LIBRARY_VERSION_TLS_1_2 >> 8,
+ SSL_LIBRARY_VERSION_TLS_1_2 & 0xff,
+ 0,
+ 1, // length
+ 1 // change_cipher_spec_choice
+};
+
+// A ChangeCipherSpec is ignored by a server because we have to tolerate it for
+// compatibility mode. That doesn't mean that we have to tolerate it
+// unconditionally. If we negotiate 1.3, we expect to see a cookie extension.
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello13) {
+ EnsureTlsSetup();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ // Client sends CCS before starting the handshake.
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+// A ChangeCipherSpec is ignored by a server because we have to tolerate it for
+// compatibility mode. That doesn't mean that we have to tolerate it
+// unconditionally. If we negotiate 1.3, we expect to see a cookie extension.
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHelloTwice) {
+ EnsureTlsSetup();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ // Client sends CCS before starting the handshake.
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+// If we negotiate 1.2, we abort.
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello12) {
+ EnsureTlsSetup();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ // Client sends CCS before starting the handshake.
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE);
+ auto client_records = std::make_shared<TlsRecordRecorder>();
+ client_->SetPacketFilter(client_records);
+ auto server_records = std::make_shared<TlsRecordRecorder>();
+ server_->SetPacketFilter(server_records);
+ Connect();
+
+ ASSERT_EQ(2U, client_records->count()); // CH, Fin
+ EXPECT_EQ(kTlsHandshakeType, client_records->record(0).header.content_type());
+ EXPECT_EQ(kTlsApplicationDataType,
+ client_records->record(1).header.content_type());
+
+ ASSERT_EQ(6U, server_records->count()); // SH, EE, CT, CV, Fin, Ack
+ EXPECT_EQ(kTlsHandshakeType, server_records->record(0).header.content_type());
+ for (size_t i = 1; i < server_records->count(); ++i) {
+ EXPECT_EQ(kTlsApplicationDataType,
+ server_records->record(i).header.content_type());
+ }
+}
+
+class AddSessionIdFilter : public TlsHandshakeFilter {
+ public:
+ AddSessionIdFilter() : TlsHandshakeFilter({ssl_hs_client_hello}) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ uint32_t session_id_len = 0;
+ EXPECT_TRUE(input.Read(2 + 32, 1, &session_id_len));
+ EXPECT_EQ(0U, session_id_len);
+ uint8_t session_id[33] = {32}; // 32 for length, the rest zero.
+ *output = input;
+ output->Splice(session_id, sizeof(session_id), 34, 1);
+ return CHANGE;
+ }
+};
+
+// Adding a session ID to a DTLS ClientHello should not trigger compatibility
+// mode. It should be ignored instead.
+TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) {
+ EnsureTlsSetup();
+ auto client_records = std::make_shared<TlsRecordRecorder>();
+ client_->SetPacketFilter(
+ std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit(
+ {client_records, std::make_shared<AddSessionIdFilter>()})));
+ auto server_hello = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeServerHello);
+ auto server_records = std::make_shared<TlsRecordRecorder>();
+ server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({server_records, server_hello})));
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ // The client will consume the ServerHello, but discard everything else
+ // because it doesn't decrypt. And don't wait around for the client to ACK.
+ client_->Handshake();
+
+ ASSERT_EQ(1U, client_records->count());
+ EXPECT_EQ(kTlsHandshakeType, client_records->record(0).header.content_type());
+
+ ASSERT_EQ(5U, server_records->count()); // SH, EE, CT, CV, Fin
+ EXPECT_EQ(kTlsHandshakeType, server_records->record(0).header.content_type());
+ for (size_t i = 1; i < server_records->count(); ++i) {
+ EXPECT_EQ(kTlsApplicationDataType,
+ server_records->record(i).header.content_type());
+ }
+
+ uint32_t session_id_len = 0;
+ EXPECT_TRUE(server_hello->buffer().Read(2 + 32, 1, &session_id_len));
+ EXPECT_EQ(0U, session_id_len);
+}
+
+} // nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
index 110e3e0b6..2f8ddd6fe 100644
--- a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
@@ -153,13 +153,6 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase {
client_->SetPacketFilter(filter_);
}
- void RequireSafeRenegotiation() {
- server_->EnsureTlsSetup();
- SECStatus rv =
- SSL_OptionSet(server_->ssl_fd(), SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
- EXPECT_EQ(rv, SECSuccess);
- }
-
void SetExpectedVersion(uint16_t version) {
TlsConnectTestBase::SetExpectedVersion(version);
filter_->SetVersion(version);
@@ -319,7 +312,7 @@ TEST_P(SSLv2ClientHelloTest, BigClientRandom) {
// Connection must fail if we require safe renegotiation but the client doesn't
// include TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites.
TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) {
- RequireSafeRenegotiation();
+ server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
EXPECT_EQ(SSL_ERROR_UNSAFE_NEGOTIATION, server_->error_code());
@@ -328,7 +321,7 @@ TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) {
// Connection must succeed when requiring safe renegotiation and the client
// includes TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites.
TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiationWithSCSV) {
- RequireSafeRenegotiation();
+ server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
TLS_EMPTY_RENEGOTIATION_INFO_SCSV};
SetAvailableCipherSuites(cipher_suites);
diff --git a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc
index 379a67e35..9db293b07 100644
--- a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc
@@ -128,12 +128,12 @@ TEST_F(TlsConnectTest, TestFallbackFromTls13) {
#endif
TEST_P(TlsConnectGeneric, TestFallbackSCSVVersionMatch) {
- client_->SetFallbackSCSVEnabled(true);
+ client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE);
Connect();
}
TEST_P(TlsConnectGenericPre13, TestFallbackSCSVVersionMismatch) {
- client_->SetFallbackSCSVEnabled(true);
+ client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE);
server_->SetVersionRange(version_, version_ + 1);
ConnectExpectAlert(server_, kTlsAlertInappropriateFallback);
client_->CheckErrorCode(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT);
@@ -155,107 +155,10 @@ TEST_F(TlsConnectTest, DisallowSSLv3HelloWithTLSv13Enabled) {
EXPECT_EQ(SECFailure, rv);
}
-TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) {
- if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
- return;
- }
- // Set the client so it will accept any version from 1.0
- // to |version_|.
- client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
- SSL_LIBRARY_VERSION_TLS_1_0);
- // Reset version so that the checks succeed.
- uint16_t test_version = version_;
- version_ = SSL_LIBRARY_VERSION_TLS_1_0;
- Connect();
-
- // Now renegotiate, with the server being set to do
- // |version_|.
- client_->PrepareForRenegotiate();
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
- // Reset version and cipher suite so that the preinfo callback
- // doesn't fail.
- server_->ResetPreliminaryInfo();
- server_->StartRenegotiate();
-
- if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
- ExpectAlert(server_, kTlsAlertUnexpectedMessage);
- } else {
- ExpectAlert(client_, kTlsAlertIllegalParameter);
- }
-
- Handshake();
- if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
- // In TLS 1.3, the server detects this problem.
- client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
- server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
- } else {
- client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
- server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
- }
-}
-
-TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) {
- if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
- return;
- }
- // Set the client so it will accept any version from 1.0
- // to |version_|.
- client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
- SSL_LIBRARY_VERSION_TLS_1_0);
- // Reset version so that the checks succeed.
- uint16_t test_version = version_;
- version_ = SSL_LIBRARY_VERSION_TLS_1_0;
- Connect();
-
- // Now renegotiate, with the server being set to do
- // |version_|.
- server_->PrepareForRenegotiate();
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
- // Reset version and cipher suite so that the preinfo callback
- // doesn't fail.
- server_->ResetPreliminaryInfo();
- client_->StartRenegotiate();
- if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
- ExpectAlert(server_, kTlsAlertUnexpectedMessage);
- } else {
- ExpectAlert(client_, kTlsAlertIllegalParameter);
- }
- Handshake();
- if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
- // In TLS 1.3, the server detects this problem.
- client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
- server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
- } else {
- client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
- server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
- }
-}
-
-TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) {
- EnsureTlsSetup();
- ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
- Connect();
- SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE);
- EXPECT_EQ(SECFailure, rv);
- EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
-}
-
-TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) {
- EnsureTlsSetup();
- ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
- Connect();
- SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE);
- EXPECT_EQ(SECFailure, rv);
- EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
-}
-
TEST_P(TlsConnectGeneric, AlertBeforeServerHello) {
EnsureTlsSetup();
client_->ExpectReceiveAlert(kTlsAlertUnrecognizedName, kTlsAlertWarning);
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake(); // Send ClientHello.
static const uint8_t kWarningAlert[] = {kTlsAlertWarning,
kTlsAlertUnrecognizedName};
@@ -314,20 +217,20 @@ TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) {
client_->SetPacketFilter(
std::make_shared<TlsInspectorClientHelloVersionSetter>(
SSL_LIBRARY_VERSION_TLS_1_3 + 1));
- auto capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerHello);
+ auto capture =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_supported_versions_xtn);
server_->SetPacketFilter(capture);
client_->ExpectSendAlert(kTlsAlertBadRecordMac);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
- const DataBuffer& server_hello = capture->buffer();
- ASSERT_GT(server_hello.len(), 2U);
- uint32_t ver;
- ASSERT_TRUE(server_hello.Read(0, 2, &ver));
+
+ ASSERT_EQ(2U, capture->extension().len());
+ uint32_t version = 0;
+ ASSERT_TRUE(capture->extension().Read(0, 2, &version));
// This way we don't need to change with new draft version.
- ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), ver);
+ ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), version);
}
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc
index b9f0c672e..adcdbfbaf 100644
--- a/security/nss/gtests/ssl_gtest/test_io.cc
+++ b/security/nss/gtests/ssl_gtest/test_io.cc
@@ -98,8 +98,13 @@ int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen,
}
int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) {
+ if (write_error_) {
+ PR_SetError(write_error_, 0);
+ return -1;
+ }
+
auto peer = peer_.lock();
- if (!peer || !writeable_) {
+ if (!peer) {
PR_SetError(PR_IO_ERROR, 0);
return -1;
}
@@ -109,7 +114,7 @@ int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) {
DataBuffer filtered;
PacketFilter::Action action = PacketFilter::KEEP;
if (filter_) {
- action = filter_->Filter(packet, &filtered);
+ action = filter_->Process(packet, &filtered);
}
switch (action) {
case PacketFilter::CHANGE:
diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h
index ac2497222..469d90a7c 100644
--- a/security/nss/gtests/ssl_gtest/test_io.h
+++ b/security/nss/gtests/ssl_gtest/test_io.h
@@ -33,9 +33,18 @@ class PacketFilter {
CHANGE, // change the packet to a different value
DROP // drop the packet
};
-
+ PacketFilter(bool enabled = true) : enabled_(enabled) {}
virtual ~PacketFilter() {}
+ virtual Action Process(const DataBuffer& input, DataBuffer* output) {
+ if (!enabled_) {
+ return KEEP;
+ }
+ return Filter(input, output);
+ }
+ void Enable() { enabled_ = true; }
+ void Disable() { enabled_ = false; }
+
// The packet filter takes input and has the option of mutating it.
//
// A filter that modifies the data places the modified data in *output and
@@ -43,6 +52,9 @@ class PacketFilter {
// case the value in *output is ignored. A Filter can return DROP, in which
// case the packet is dropped (and *output is ignored).
virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0;
+
+ private:
+ bool enabled_;
};
class DummyPrSocket : public DummyIOLayerMethods {
@@ -53,7 +65,7 @@ class DummyPrSocket : public DummyIOLayerMethods {
peer_(),
input_(),
filter_(nullptr),
- writeable_(true) {}
+ write_error_(0) {}
virtual ~DummyPrSocket() {}
// Create a file descriptor that will reference this object. The fd must not
@@ -71,7 +83,7 @@ class DummyPrSocket : public DummyIOLayerMethods {
int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags,
PRIntervalTime to) override;
int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override;
- void CloseWrites() { writeable_ = false; }
+ void SetWriteError(PRErrorCode code) { write_error_ = code; }
SSLProtocolVariant variant() const { return variant_; }
bool readable() const { return !input_.empty(); }
@@ -98,7 +110,7 @@ class DummyPrSocket : public DummyIOLayerMethods {
std::weak_ptr<DummyPrSocket> peer_;
std::queue<Packet> input_;
std::shared_ptr<PacketFilter> filter_;
- bool writeable_;
+ PRErrorCode write_error_;
};
// Marker interface.
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc
index d6d91f7f7..3b939bba8 100644
--- a/security/nss/gtests/ssl_gtest/tls_agent.cc
+++ b/security/nss/gtests/ssl_gtest/tls_agent.cc
@@ -10,6 +10,7 @@
#include "pk11func.h"
#include "ssl.h"
#include "sslerr.h"
+#include "sslexp.h"
#include "sslproto.h"
#include "tls_parser.h"
@@ -35,7 +36,6 @@ const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt
const std::string TlsAgent::kServerRsaSign = "rsa_sign";
const std::string TlsAgent::kServerRsaPss = "rsa_pss";
const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt";
-const std::string TlsAgent::kServerRsaChain = "rsa_chain";
const std::string TlsAgent::kServerEcdsa256 = "ecdsa256";
const std::string TlsAgent::kServerEcdsa384 = "ecdsa384";
const std::string TlsAgent::kServerEcdsa521 = "ecdsa521";
@@ -73,7 +73,6 @@ TlsAgent::TlsAgent(const std::string& name, Role role,
handshake_callback_(),
auth_certificate_callback_(),
sni_callback_(),
- expect_short_headers_(false),
skip_version_checks_(false) {
memset(&info_, 0, sizeof(info_));
memset(&csinfo_, 0, sizeof(csinfo_));
@@ -93,11 +92,11 @@ TlsAgent::~TlsAgent() {
// Add failures manually, if any, so we don't throw in a destructor.
if (expected_received_alert_ != kTlsAlertCloseNotify ||
expected_received_alert_level_ != kTlsAlertWarning) {
- ADD_FAILURE() << "Wrong expected_received_alert status";
+ ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str();
}
if (expected_sent_alert_ != kTlsAlertCloseNotify ||
expected_sent_alert_level_ != kTlsAlertWarning) {
- ADD_FAILURE() << "Wrong expected_sent_alert status";
+ ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str();
}
}
@@ -258,13 +257,10 @@ void TlsAgent::CheckCipherSuite(uint16_t cipher_suite) {
}
void TlsAgent::RequestClientAuth(bool requireAuth) {
- EXPECT_TRUE(EnsureTlsSetup());
ASSERT_EQ(SERVER, role_);
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(ssl_fd(), SSL_REQUEST_CERTIFICATE, PR_TRUE));
- EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_REQUIRE_CERTIFICATE,
- requireAuth ? PR_TRUE : PR_FALSE));
+ SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE);
+ SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE);
EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
ssl_fd(), &TlsAgent::ClientAuthenticated, this));
@@ -376,42 +372,8 @@ void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
EXPECT_EQ(SECSuccess, rv);
}
-void TlsAgent::SetSessionTicketsEnabled(bool en) {
- EXPECT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS,
- en ? PR_TRUE : PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
-}
-
-void TlsAgent::SetSessionCacheEnabled(bool en) {
- EXPECT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE);
- EXPECT_EQ(SECSuccess, rv);
-}
-
void TlsAgent::Set0RttEnabled(bool en) {
- EXPECT_TRUE(EnsureTlsSetup());
-
- SECStatus rv =
- SSL_OptionSet(ssl_fd(), SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
-}
-
-void TlsAgent::SetFallbackSCSVEnabled(bool en) {
- EXPECT_TRUE(role_ == CLIENT && EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALLBACK_SCSV,
- en ? PR_TRUE : PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
-}
-
-void TlsAgent::SetShortHeadersEnabled() {
- EXPECT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd());
- EXPECT_EQ(SECSuccess, rv);
+ SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
}
void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
@@ -437,8 +399,6 @@ void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; }
void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; }
-void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; }
-
void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; }
void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
@@ -517,6 +477,12 @@ void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group,
}
}
+void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const {
+ if (kea_group != ssl_grp_ffdhe_custom) {
+ EXPECT_EQ(kea_group, info_.originalKeaGroup);
+ }
+}
+
void TlsAgent::CheckAuthType(SSLAuthType auth_type,
SSLSignatureScheme sig_scheme) const {
EXPECT_EQ(STATE_CONNECTED, state_);
@@ -569,8 +535,7 @@ void TlsAgent::EnableFalseStart() {
falsestart_enabled_ = true;
EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback(
ssl_fd(), CanFalseStartCallback, this));
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALSE_START, PR_TRUE));
+ SetOption(SSL_ENABLE_FALSE_START, PR_TRUE);
}
void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
@@ -578,7 +543,7 @@ void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
EXPECT_TRUE(EnsureTlsSetup());
- EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_ENABLE_ALPN, PR_TRUE));
+ SetOption(SSL_ENABLE_ALPN, PR_TRUE);
EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len));
}
@@ -622,12 +587,8 @@ void TlsAgent::CheckErrorCode(int32_t expected) const {
}
static uint8_t GetExpectedAlertLevel(uint8_t alert) {
- switch (alert) {
- case kTlsAlertCloseNotify:
- case kTlsAlertEndOfEarlyData:
- return kTlsAlertWarning;
- default:
- break;
+ if (alert == kTlsAlertCloseNotify) {
+ return kTlsAlertWarning;
}
return kTlsAlertFatal;
}
@@ -730,6 +691,50 @@ void TlsAgent::ResetPreliminaryInfo() {
expected_cipher_suite_ = 0;
}
+void TlsAgent::ValidateCipherSpecs() {
+ PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd());
+ // We use one ciphersuite in each direction.
+ PRInt32 expected = 2;
+ if (variant_ == ssl_variant_datagram) {
+ // For DTLS 1.3, the client retains the cipher spec for early data and the
+ // handshake so that it can retransmit EndOfEarlyData and its final flight.
+ // It also retains the handshake read cipher spec so that it can read ACKs
+ // from the server. The server retains the handshake read cipher spec so it
+ // can read the client's retransmitted Finished.
+ if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ if (role_ == CLIENT) {
+ expected = info_.earlyDataAccepted ? 5 : 4;
+ } else {
+ expected = 3;
+ }
+ } else {
+ // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec
+ // until the holddown timer runs down.
+ if (expect_resumption_) {
+ if (role_ == CLIENT) {
+ expected = 3;
+ }
+ } else {
+ if (role_ == SERVER) {
+ expected = 3;
+ }
+ }
+ }
+ }
+ // This function will be run before the handshake completes if false start is
+ // enabled. In that case, the client will still be reading cleartext, but
+ // will have a spec prepared for reading ciphertext. With DTLS, the client
+ // will also have a spec retained for retransmission of handshake messages.
+ if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) {
+ EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_);
+ expected = (variant_ == ssl_variant_datagram) ? 4 : 3;
+ }
+ EXPECT_EQ(expected, cipherSpecs);
+ if (expected != cipherSpecs) {
+ SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd());
+ }
+}
+
void TlsAgent::Connected() {
if (state_ == STATE_CONNECTED) {
return;
@@ -743,6 +748,8 @@ void TlsAgent::Connected() {
EXPECT_EQ(SECSuccess, rv);
EXPECT_EQ(sizeof(info_), info_.length);
+ EXPECT_EQ(expect_resumption_, info_.resumed == PR_TRUE);
+
// Preliminary values are exposed through callbacks during the handshake.
// If either expected values were set or the callbacks were called, check
// that the final values are correct.
@@ -753,32 +760,13 @@ void TlsAgent::Connected() {
EXPECT_EQ(SECSuccess, rv);
EXPECT_EQ(sizeof(csinfo_), csinfo_.length);
- if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
- PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd());
- // We use one ciphersuite in each direction, plus one that's kept around
- // by DTLS for retransmission.
- PRInt32 expected =
- ((variant_ == ssl_variant_datagram) && (role_ == CLIENT)) ? 3 : 2;
- EXPECT_EQ(expected, cipherSuites);
- if (expected != cipherSuites) {
- SSLInt_PrintTls13CipherSpecs(ssl_fd());
- }
- }
+ ValidateCipherSpecs();
- PRBool short_headers;
- rv = SSLInt_UsingShortHeaders(ssl_fd(), &short_headers);
- EXPECT_EQ(SECSuccess, rv);
- EXPECT_EQ((PRBool)expect_short_headers_, short_headers);
SetState(STATE_CONNECTED);
}
void TlsAgent::EnableExtendedMasterSecret() {
- ASSERT_TRUE(EnsureTlsSetup());
-
- SECStatus rv =
- SSL_OptionSet(ssl_fd(), SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
-
- ASSERT_EQ(SECSuccess, rv);
+ SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
}
void TlsAgent::CheckExtendedMasterSecret(bool expected) {
@@ -801,21 +789,6 @@ void TlsAgent::CheckSecretsDestroyed() {
ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd()));
}
-void TlsAgent::DisableRollbackDetection() {
- ASSERT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ROLLBACK_DETECTION, PR_FALSE);
-
- ASSERT_EQ(SECSuccess, rv);
-}
-
-void TlsAgent::EnableCompression() {
- ASSERT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_DEFLATE, PR_TRUE);
- ASSERT_EQ(SECSuccess, rv);
-}
-
void TlsAgent::SetDowngradeCheckVersion(uint16_t version) {
ASSERT_TRUE(EnsureTlsSetup());
@@ -883,6 +856,14 @@ void TlsAgent::SendDirect(const DataBuffer& buf) {
}
}
+void TlsAgent::SendRecordDirect(const TlsRecord& record) {
+ DataBuffer buf;
+
+ auto rv = record.header.Write(&buf, 0, record.buffer);
+ EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv);
+ SendDirect(buf);
+}
+
static bool ErrorIsNonFatal(PRErrorCode code) {
return code == PR_WOULD_BLOCK_ERROR || code == SSL_ERROR_RX_SHORT_DTLS_READ;
}
@@ -918,6 +899,27 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) {
}
}
+bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
+ uint16_t wireVersion, uint64_t seq,
+ uint8_t ct, const DataBuffer& buf) {
+ LOGV("Writing " << buf.len() << " bytes");
+ // Ensure we are a TLS 1.3 cipher agent.
+ EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
+ TlsRecordHeader header(wireVersion, kTlsApplicationDataType, seq);
+ DataBuffer padded = buf;
+ padded.Write(padded.len(), ct, 1);
+ DataBuffer ciphertext;
+ if (!spec->Protect(header, padded, &ciphertext)) {
+ return false;
+ }
+
+ DataBuffer record;
+ auto rv = header.Write(&record, 0, ciphertext);
+ EXPECT_EQ(header.header_length() + ciphertext.len(), rv);
+ SendDirect(record);
+ return true;
+}
+
void TlsAgent::ReadBytes(size_t amount) {
uint8_t block[16384];
@@ -951,23 +953,20 @@ void TlsAgent::ReadBytes(size_t amount) {
void TlsAgent::ResetSentBytes() { send_ctr_ = 0; }
-void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
- EXPECT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE,
- mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
- EXPECT_EQ(SECSuccess, rv);
+void TlsAgent::SetOption(int32_t option, int value) {
+ ASSERT_TRUE(EnsureTlsSetup());
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value));
+}
- rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS,
- mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
+void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
+ SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
+ SetOption(SSL_ENABLE_SESSION_TICKETS,
+ mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
}
void TlsAgent::DisableECDHEServerKeyReuse() {
- ASSERT_TRUE(EnsureTlsSetup());
ASSERT_EQ(TlsAgent::SERVER, role_);
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
+ SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
}
static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h
index 4bccb9a84..b3fd892ae 100644
--- a/security/nss/gtests/ssl_gtest/tls_agent.h
+++ b/security/nss/gtests/ssl_gtest/tls_agent.h
@@ -66,7 +66,6 @@ class TlsAgent : public PollTarget {
static const std::string kServerRsaSign;
static const std::string kServerRsaPss;
static const std::string kServerRsaDecrypt;
- static const std::string kServerRsaChain; // A cert that requires a chain.
static const std::string kServerEcdsa256;
static const std::string kServerEcdsa384;
static const std::string kServerEcdsa521;
@@ -81,9 +80,11 @@ class TlsAgent : public PollTarget {
adapter_->SetPeer(peer->adapter_);
}
+ // Set a filter that can access plaintext (TLS 1.3 only).
void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) {
filter->SetAgent(this);
adapter_->SetPacketFilter(filter);
+ filter->EnableDecryption();
}
void SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
@@ -95,6 +96,7 @@ class TlsAgent : public PollTarget {
void StartConnect(PRFileDesc* model = nullptr);
void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
size_t kea_size = 0) const;
+ void CheckOriginalKEA(SSLNamedGroup kea_group) const;
void CheckAuthType(SSLAuthType auth_type,
SSLSignatureScheme sig_scheme) const;
@@ -121,12 +123,10 @@ class TlsAgent : public PollTarget {
void SetupClientAuth();
void RequestClientAuth(bool requireAuth);
+ void SetOption(int32_t option, int value);
void ConfigureSessionCache(SessionResumptionMode mode);
- void SetSessionTicketsEnabled(bool en);
- void SetSessionCacheEnabled(bool en);
void Set0RttEnabled(bool en);
void SetFallbackSCSVEnabled(bool en);
- void SetShortHeadersEnabled();
void SetVersionRange(uint16_t minver, uint16_t maxver);
void GetVersionRange(uint16_t* minver, uint16_t* maxver);
void CheckPreliminaryInfo();
@@ -136,7 +136,6 @@ class TlsAgent : public PollTarget {
void ExpectReadWriteError();
void EnableFalseStart();
void ExpectResumption();
- void ExpectShortHeaders();
void SkipVersionChecks();
void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
void EnableAlpn(const uint8_t* val, size_t len);
@@ -149,15 +148,17 @@ class TlsAgent : public PollTarget {
// Send data on the socket, encrypting it.
void SendData(size_t bytes, size_t blocksize = 1024);
void SendBuffer(const DataBuffer& buf);
+ bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
+ uint16_t wireVersion, uint64_t seq, uint8_t ct,
+ const DataBuffer& buf);
// Send data directly to the underlying socket, skipping the TLS layer.
void SendDirect(const DataBuffer& buf);
+ void SendRecordDirect(const TlsRecord& record);
void ReadBytes(size_t max = 16384U);
void ResetSentBytes(); // Hack to test drops.
void EnableExtendedMasterSecret();
void CheckExtendedMasterSecret(bool expected);
void CheckEarlyDataAccepted(bool expected);
- void DisableRollbackDetection();
- void EnableCompression();
void SetDowngradeCheckVersion(uint16_t version);
void CheckSecretsDestroyed();
void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
@@ -170,6 +171,8 @@ class TlsAgent : public PollTarget {
Role role() const { return role_; }
std::string role_str() const { return role_ == SERVER ? "server" : "client"; }
+ SSLProtocolVariant variant() const { return variant_; }
+
State state() const { return state_; }
const CERTCertificate* peer_cert() const {
@@ -253,6 +256,7 @@ class TlsAgent : public PollTarget {
const static char* states[];
void SetState(State state);
+ void ValidateCipherSpecs();
// Dummy auth certificate hook.
static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
@@ -388,7 +392,6 @@ class TlsAgent : public PollTarget {
HandshakeCallbackFunction handshake_callback_;
AuthCertificateCallbackFunction auth_certificate_callback_;
SniCallbackFunction sni_callback_;
- bool expect_short_headers_;
bool skip_version_checks_;
};
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc
index c8de5a1fe..0af5123e9 100644
--- a/security/nss/gtests/ssl_gtest/tls_connect.cc
+++ b/security/nss/gtests/ssl_gtest/tls_connect.cc
@@ -5,6 +5,7 @@
* You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "tls_connect.h"
+#include "sslexp.h"
extern "C" {
#include "libssl_internals.h"
}
@@ -88,6 +89,8 @@ std::string VersionString(uint16_t version) {
switch (version) {
case 0:
return "(no version)";
+ case SSL_LIBRARY_VERSION_3_0:
+ return "1.0";
case SSL_LIBRARY_VERSION_TLS_1_0:
return "1.0";
case SSL_LIBRARY_VERSION_TLS_1_1:
@@ -112,6 +115,7 @@ TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant,
server_model_(nullptr),
version_(version),
expected_resumption_mode_(RESUME_NONE),
+ expected_resumptions_(0),
session_ids_(),
expect_extended_master_secret_(false),
expect_early_data_accepted_(false),
@@ -161,6 +165,22 @@ void TlsConnectTestBase::CheckShares(
EXPECT_EQ(shares.len(), i);
}
+void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch,
+ uint16_t server_epoch) const {
+ uint16_t read_epoch = 0;
+ uint16_t write_epoch = 0;
+
+ EXPECT_EQ(SECSuccess,
+ SSLInt_GetEpochs(client_->ssl_fd(), &read_epoch, &write_epoch));
+ EXPECT_EQ(server_epoch, read_epoch) << "client read epoch";
+ EXPECT_EQ(client_epoch, write_epoch) << "client write epoch";
+
+ EXPECT_EQ(SECSuccess,
+ SSLInt_GetEpochs(server_->ssl_fd(), &read_epoch, &write_epoch));
+ EXPECT_EQ(client_epoch, read_epoch) << "server read epoch";
+ EXPECT_EQ(server_epoch, write_epoch) << "server write epoch";
+}
+
void TlsConnectTestBase::ClearStats() {
// Clear statistics.
SSL3Statistics* stats = SSL_GetStatistics();
@@ -178,6 +198,7 @@ void TlsConnectTestBase::SetUp() {
SSLInt_ClearSelfEncryptKey();
SSLInt_SetTicketLifetime(30);
SSLInt_SetMaxEarlyDataSize(1024);
+ SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3);
ClearStats();
Init();
}
@@ -219,12 +240,27 @@ void TlsConnectTestBase::Reset(const std::string& server_name,
Init();
}
-void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) {
+void TlsConnectTestBase::MakeNewServer() {
+ auto replacement = std::make_shared<TlsAgent>(
+ server_->name(), TlsAgent::SERVER, server_->variant());
+ server_ = replacement;
+ if (version_) {
+ server_->SetVersionRange(version_, version_);
+ }
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+ server_->StartConnect();
+}
+
+void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected,
+ uint8_t num_resumptions) {
expected_resumption_mode_ = expected;
if (expected != RESUME_NONE) {
client_->ExpectResumption();
server_->ExpectResumption();
+ expected_resumptions_ = num_resumptions;
}
+ EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE);
}
void TlsConnectTestBase::EnsureTlsSetup() {
@@ -258,6 +294,11 @@ void TlsConnectTestBase::Connect() {
CheckConnected();
}
+void TlsConnectTestBase::StartConnect() {
+ server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr);
+ client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr);
+}
+
void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
EnsureTlsSetup();
client_->EnableSingleCipher(cipher_suite);
@@ -274,6 +315,19 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
}
void TlsConnectTestBase::CheckConnected() {
+ // Have the client read handshake twice to make sure we get the
+ // NST and the ACK.
+ if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ variant_ == ssl_variant_datagram) {
+ client_->Handshake();
+ client_->Handshake();
+ auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd());
+ // Verify that we dropped the client's retransmission cipher suites.
+ EXPECT_EQ(2, suites) << "Client has the wrong number of suites";
+ if (suites != 2) {
+ SSLInt_PrintCipherSpecs("client", client_->ssl_fd());
+ }
+ }
EXPECT_EQ(client_->version(), server_->version());
if (!skip_version_checks_) {
// Check the version is as expected
@@ -314,10 +368,12 @@ void TlsConnectTestBase::CheckConnected() {
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
SSLAuthType auth_type,
SSLSignatureScheme sig_scheme) const {
- client_->CheckKEA(kea_type, kea_group);
- server_->CheckKEA(kea_type, kea_group);
- client_->CheckAuthType(auth_type, sig_scheme);
+ if (kea_group != ssl_grp_none) {
+ client_->CheckKEA(kea_type, kea_group);
+ server_->CheckKEA(kea_type, kea_group);
+ }
server_->CheckAuthType(auth_type, sig_scheme);
+ client_->CheckAuthType(auth_type, sig_scheme);
}
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
@@ -372,9 +428,19 @@ void TlsConnectTestBase::CheckKeys() const {
CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
}
+void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type,
+ SSLNamedGroup kea_group,
+ SSLNamedGroup original_kea_group,
+ SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) {
+ CheckKeys(kea_type, kea_group, auth_type, sig_scheme);
+ EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE);
+ client_->CheckOriginalKEA(original_kea_group);
+ server_->CheckOriginalKEA(original_kea_group);
+}
+
void TlsConnectTestBase::ConnectExpectFail() {
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
Handshake();
ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state());
ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
@@ -395,8 +461,7 @@ void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender,
}
void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) {
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
client_->SetServerKeyBits(server_->server_key_bits());
client_->Handshake();
server_->Handshake();
@@ -455,29 +520,33 @@ void TlsConnectTestBase::EnableSomeEcdhCiphers() {
}
}
+void TlsConnectTestBase::ConfigureSelfEncrypt() {
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey privKey;
+ ASSERT_TRUE(
+ TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey));
+
+ ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
+ ASSERT_TRUE(pubKey);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
+}
+
void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
SessionResumptionMode server) {
client_->ConfigureSessionCache(client);
server_->ConfigureSessionCache(server);
if ((server & RESUME_TICKET) != 0) {
- ScopedCERTCertificate cert;
- ScopedSECKEYPrivateKey privKey;
- ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert,
- &privKey));
-
- ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
- ASSERT_TRUE(pubKey);
-
- EXPECT_EQ(SECSuccess,
- SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
+ ConfigureSelfEncrypt();
}
}
void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
EXPECT_NE(RESUME_BOTH, expected);
- int resume_count = expected ? 1 : 0;
- int stateless_count = (expected & RESUME_TICKET) ? 1 : 0;
+ int resume_count = expected ? expected_resumptions_ : 0;
+ int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0;
// Note: hch == server counter; hsh == client counter.
SSL3Statistics* stats = SSL_GetStatistics();
@@ -490,7 +559,7 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
if (expected != RESUME_NONE) {
if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
// Check that the last two session ids match.
- ASSERT_EQ(2U, session_ids_.size());
+ ASSERT_EQ(1U + expected_resumptions_, session_ids_.size());
EXPECT_EQ(session_ids_[session_ids_.size() - 1],
session_ids_[session_ids_.size() - 2]);
} else {
@@ -540,31 +609,28 @@ void TlsConnectTestBase::CheckSrtp() const {
server_->CheckSrtp();
}
-void TlsConnectTestBase::SendReceive() {
- client_->SendData(50);
- server_->SendData(50);
- Receive(50);
+void TlsConnectTestBase::SendReceive(size_t total) {
+ ASSERT_GT(total, client_->received_bytes());
+ ASSERT_GT(total, server_->received_bytes());
+ client_->SendData(total - server_->received_bytes());
+ server_->SendData(total - client_->received_bytes());
+ Receive(total); // Receive() is cumulative
}
// Do a first connection so we can do 0-RTT on the second one.
void TlsConnectTestBase::SetupForZeroRtt() {
+ // If we don't do this, then all 0-RTT attempts will be rejected.
+ SSLInt_RolloverAntiReplay();
+
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT.
Connect();
SendReceive(); // Need to read so that we absorb the session ticket.
CheckKeys();
Reset();
- client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
}
// Do a first connection so we can do resumption
@@ -584,10 +650,6 @@ void TlsConnectTestBase::ZeroRttSendReceive(
const char* k0RttData = "ABCDEF";
const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
- if (expect_writable && expect_readable) {
- ExpectAlert(client_, kTlsAlertEndOfEarlyData);
- }
-
client_->Handshake(); // Send ClientHello.
if (post_clienthello_check) {
if (!post_clienthello_check()) return;
@@ -599,7 +661,7 @@ void TlsConnectTestBase::ZeroRttSendReceive(
} else {
EXPECT_EQ(SECFailure, rv);
}
- server_->Handshake(); // Consume ClientHello, EE, Finished.
+ server_->Handshake(); // Consume ClientHello
std::vector<uint8_t> buf(k0RttDataLen);
rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read
@@ -653,6 +715,30 @@ void TlsConnectTestBase::SkipVersionChecks() {
server_->SkipVersionChecks();
}
+// Shift the DTLS timers, to the minimum time necessary to let the next timer
+// run on either client or server. This allows tests to skip waiting without
+// having timers run out of order.
+void TlsConnectTestBase::ShiftDtlsTimers() {
+ PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT;
+ PRIntervalTime time;
+ SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time);
+ if (rv == SECSuccess) {
+ time_shift = time;
+ }
+ rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time);
+ if (rv == SECSuccess &&
+ (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) {
+ time_shift = time;
+ }
+
+ if (time_shift == PR_INTERVAL_NO_TIMEOUT) {
+ return;
+ }
+
+ EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift));
+ EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift));
+}
+
TlsConnectGeneric::TlsConnectGeneric()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
@@ -691,11 +777,15 @@ void TlsKeyExchangeTest::ConfigNamedGroups(
}
std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
- const DataBuffer& ext) {
+ const std::shared_ptr<TlsExtensionCapture>& capture) {
+ EXPECT_TRUE(capture->captured());
+ const DataBuffer& ext = capture->extension();
+
uint32_t tmp = 0;
EXPECT_TRUE(ext.Read(0, 2, &tmp));
EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
EXPECT_TRUE(ext.len() % 2 == 0);
+
std::vector<SSLNamedGroup> groups;
for (size_t i = 1; i < ext.len() / 2; i += 1) {
EXPECT_TRUE(ext.Read(2 * i, 2, &tmp));
@@ -705,10 +795,14 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
}
std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
- const DataBuffer& ext) {
+ const std::shared_ptr<TlsExtensionCapture>& capture) {
+ EXPECT_TRUE(capture->captured());
+ const DataBuffer& ext = capture->extension();
+
uint32_t tmp = 0;
EXPECT_TRUE(ext.Read(0, 2, &tmp));
EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
+
std::vector<SSLNamedGroup> shares;
size_t i = 2;
while (i < ext.len()) {
@@ -724,17 +818,15 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
void TlsKeyExchangeTest::CheckKEXDetails(
const std::vector<SSLNamedGroup>& expected_groups,
const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) {
- std::vector<SSLNamedGroup> groups =
- GetGroupDetails(groups_capture_->extension());
+ std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_);
EXPECT_EQ(expected_groups, groups);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
ASSERT_LT(0U, expected_shares.size());
- std::vector<SSLNamedGroup> shares =
- GetShareDetails(shares_capture_->extension());
+ std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_);
EXPECT_EQ(expected_shares, shares);
} else {
- EXPECT_EQ(0U, shares_capture_->extension().len());
+ EXPECT_FALSE(shares_capture_->captured());
}
EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0);
@@ -756,8 +848,6 @@ void TlsKeyExchangeTest::CheckKEXDetails(
EXPECT_NE(expected_share2, it);
}
std::vector<SSLNamedGroup> expected_shares2 = {expected_share2};
- std::vector<SSLNamedGroup> shares =
- GetShareDetails(shares_capture2_->extension());
- EXPECT_EQ(expected_shares2, shares);
+ EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_));
}
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h
index 73e8dc81a..c650dda1d 100644
--- a/security/nss/gtests/ssl_gtest/tls_connect.h
+++ b/security/nss/gtests/ssl_gtest/tls_connect.h
@@ -61,7 +61,11 @@ class TlsConnectTestBase : public ::testing::Test {
// Reset, and update the certificate names on both peers
void Reset(const std::string& server_name,
const std::string& client_name = "client");
+ // Replace the server.
+ void MakeNewServer();
+ // Set up
+ void StartConnect();
// Run the handshake.
void Handshake();
// Connect and check that it works.
@@ -81,20 +85,28 @@ class TlsConnectTestBase : public ::testing::Test {
void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const;
// This version assumes defaults.
void CheckKeys() const;
+ // Check that keys on resumed sessions.
+ void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ SSLNamedGroup original_kea_group,
+ SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme);
void CheckGroups(const DataBuffer& groups,
std::function<void(SSLNamedGroup)> check_group);
void CheckShares(const DataBuffer& shares,
std::function<void(SSLNamedGroup)> check_group);
+ void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const;
void ConfigureVersion(uint16_t version);
void SetExpectedVersion(uint16_t version);
// Expect resumption of a particular type.
- void ExpectResumption(SessionResumptionMode expected);
+ void ExpectResumption(SessionResumptionMode expected,
+ uint8_t num_resumed = 1);
void DisableAllCiphers();
void EnableOnlyStaticRsaCiphers();
void EnableOnlyDheCiphers();
void EnableSomeEcdhCiphers();
void EnableExtendedMasterSecret();
+ void ConfigureSelfEncrypt();
void ConfigureSessionCache(SessionResumptionMode client,
SessionResumptionMode server);
void EnableAlpn();
@@ -103,7 +115,7 @@ class TlsConnectTestBase : public ::testing::Test {
void CheckAlpn(const std::string& val);
void EnableSrtp();
void CheckSrtp() const;
- void SendReceive();
+ void SendReceive(size_t total = 50);
void SetupForZeroRtt();
void SetupForResume();
void ZeroRttSendReceive(
@@ -115,6 +127,9 @@ class TlsConnectTestBase : public ::testing::Test {
void DisableECDHEServerKeyReuse();
void SkipVersionChecks();
+ // Move the DTLS timers for both endpoints to pop the next timer.
+ void ShiftDtlsTimers();
+
protected:
SSLProtocolVariant variant_;
std::shared_ptr<TlsAgent> client_;
@@ -123,6 +138,7 @@ class TlsConnectTestBase : public ::testing::Test {
std::unique_ptr<TlsAgent> server_model_;
uint16_t version_;
SessionResumptionMode expected_resumption_mode_;
+ uint8_t expected_resumptions_;
std::vector<std::vector<uint8_t>> session_ids_;
// A simple value of "a", "b". Note that the preferred value of "a" is placed
@@ -244,6 +260,11 @@ class TlsConnectDatagram13 : public TlsConnectTestBase {
: TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {}
};
+class TlsConnectDatagramPre13 : public TlsConnectDatagram {
+ public:
+ TlsConnectDatagramPre13() {}
+};
+
// A variant that is used only with Pre13.
class TlsConnectGenericPre13 : public TlsConnectGeneric {};
@@ -256,8 +277,10 @@ class TlsKeyExchangeTest : public TlsConnectGeneric {
void EnsureKeyShareSetup();
void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
- std::vector<SSLNamedGroup> GetGroupDetails(const DataBuffer& ext);
- std::vector<SSLNamedGroup> GetShareDetails(const DataBuffer& ext);
+ std::vector<SSLNamedGroup> GetGroupDetails(
+ const std::shared_ptr<TlsExtensionCapture>& capture);
+ std::vector<SSLNamedGroup> GetShareDetails(
+ const std::shared_ptr<TlsExtensionCapture>& capture);
void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
const std::vector<SSLNamedGroup>& expectedShares);
void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc
index 76d9aaaff..89f201295 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.cc
+++ b/security/nss/gtests/ssl_gtest/tls_filter.cc
@@ -12,6 +12,7 @@ extern "C" {
#include "libssl_internals.h"
}
+#include <cassert>
#include <iostream>
#include "gtest_utils.h"
#include "tls_agent.h"
@@ -57,17 +58,22 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending,
PRBool isServer = self->agent()->role() == TlsAgent::SERVER;
if (g_ssl_gtest_verbose) {
- std::cerr << "Cipher spec changed. Role="
- << (isServer ? "server" : "client")
- << " direction=" << (sending ? "send" : "receive") << std::endl;
+ std::cerr << (isServer ? "server" : "client") << ": "
+ << (sending ? "send" : "receive")
+ << " cipher spec changed: " << newSpec->epoch << " ("
+ << newSpec->phase << ")" << std::endl;
+ }
+ if (!sending) {
+ return;
}
- if (!sending) return;
+ self->in_sequence_number_ = 0;
+ self->out_sequence_number_ = 0;
+ self->dropped_record_ = false;
self->cipher_spec_.reset(new TlsCipherSpec());
- bool ret =
- self->cipher_spec_->Init(SSLInt_CipherSpecToAlgorithm(isServer, newSpec),
- SSLInt_CipherSpecToKey(isServer, newSpec),
- SSLInt_CipherSpecToIv(isServer, newSpec));
+ bool ret = self->cipher_spec_->Init(
+ SSLInt_CipherSpecToEpoch(newSpec), SSLInt_CipherSpecToAlgorithm(newSpec),
+ SSLInt_CipherSpecToKey(newSpec), SSLInt_CipherSpecToIv(newSpec));
EXPECT_EQ(true, ret);
}
@@ -83,11 +89,23 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
TlsRecordHeader header;
DataBuffer record;
- if (!header.Parse(&parser, &record)) {
+ if (!header.Parse(in_sequence_number_, &parser, &record)) {
ADD_FAILURE() << "not a valid record";
return KEEP;
}
+ // Track the sequence number, which is necessary for stream mode (the
+ // sequence number is in the header for datagram).
+ //
+ // This isn't perfectly robust. If there is a change from an active cipher
+ // spec to another active cipher spec (KeyUpdate for instance) AND writes
+ // are consolidated across that change AND packets were dropped from the
+ // older epoch, we will not correctly re-encrypt records in the old epoch to
+ // update their sequence numbers.
+ if (cipher_spec_ && header.content_type() == kTlsApplicationDataType) {
+ ++in_sequence_number_;
+ }
+
if (FilterRecord(header, record, &offset, output) != KEEP) {
changed = true;
} else {
@@ -120,30 +138,49 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
header.sequence_number()};
PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
+ // In stream mode, even if something doesn't change we need to re-encrypt if
+ // previous packets were dropped.
if (action == KEEP) {
- return KEEP;
+ if (header.is_dtls() || !dropped_record_) {
+ return KEEP;
+ }
+ filtered = plaintext;
}
if (action == DROP) {
- std::cerr << "record drop: " << record << std::endl;
+ std::cerr << "record drop: " << header << ":" << record << std::endl;
+ dropped_record_ = true;
return DROP;
}
EXPECT_GT(0x10000U, filtered.len());
- std::cerr << "record old: " << plaintext << std::endl;
- std::cerr << "record new: " << filtered << std::endl;
+ if (action != KEEP) {
+ std::cerr << "record old: " << plaintext << std::endl;
+ std::cerr << "record new: " << filtered << std::endl;
+ }
+
+ uint64_t seq_num;
+ if (header.is_dtls() || !cipher_spec_ ||
+ header.content_type() != kTlsApplicationDataType) {
+ seq_num = header.sequence_number();
+ } else {
+ seq_num = out_sequence_number_++;
+ }
+ TlsRecordHeader out_header = {header.version(), header.content_type(),
+ seq_num};
DataBuffer ciphertext;
- bool rv = Protect(header, inner_content_type, filtered, &ciphertext);
+ bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext);
EXPECT_TRUE(rv);
if (!rv) {
return KEEP;
}
- *offset = header.Write(output, *offset, ciphertext);
+ *offset = out_header.Write(output, *offset, ciphertext);
return CHANGE;
}
-bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
+bool TlsRecordHeader::Parse(uint64_t sequence_number, TlsParser* parser,
+ DataBuffer* body) {
if (!parser->Read(&content_type_)) {
return false;
}
@@ -154,7 +191,7 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
}
version_ = version;
- sequence_number_ = 0;
+ // If this is DTLS, overwrite the sequence number.
if (IsDtls(version)) {
uint32_t tmp;
if (!parser->Read(&tmp, 4)) {
@@ -165,6 +202,8 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
return false;
}
sequence_number_ |= static_cast<uint64_t>(tmp);
+ } else {
+ sequence_number_ = sequence_number;
}
return parser->ReadVariable(body, 2);
}
@@ -193,7 +232,9 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
return true;
}
- if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) return false;
+ if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) {
+ return false;
+ }
size_t len = plaintext->len();
while (len > 0 && !plaintext->data()[len - 1]) {
@@ -206,6 +247,11 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
*inner_content_type = plaintext->data()[len - 1];
plaintext->Truncate(len - 1);
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "unprotect: " << std::hex << header.sequence_number()
+ << std::dec << " type=" << static_cast<int>(*inner_content_type)
+ << " " << *plaintext << std::endl;
+ }
return true;
}
@@ -218,16 +264,44 @@ bool TlsRecordFilter::Protect(const TlsRecordHeader& header,
*ciphertext = plaintext;
return true;
}
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "protect: " << header.sequence_number() << std::endl;
+ }
DataBuffer padded = plaintext;
padded.Write(padded.len(), inner_content_type, 1);
return cipher_spec_->Protect(header, padded, ciphertext);
}
+bool IsHelloRetry(const DataBuffer& body) {
+ static const uint8_t ssl_hello_retry_random[] = {
+ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C,
+ 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB,
+ 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C};
+ return memcmp(body.data() + 2, ssl_hello_retry_random,
+ sizeof(ssl_hello_retry_random)) == 0;
+}
+
+bool TlsHandshakeFilter::IsFilteredType(const HandshakeHeader& header,
+ const DataBuffer& body) {
+ if (handshake_types_.empty()) {
+ return true;
+ }
+
+ uint8_t type = header.handshake_type();
+ if (type == kTlsHandshakeServerHello) {
+ if (IsHelloRetry(body)) {
+ type = kTlsHandshakeHelloRetryRequest;
+ }
+ }
+ return handshake_types_.count(type) > 0U;
+}
+
PacketFilter::Action TlsHandshakeFilter::FilterRecord(
const TlsRecordHeader& record_header, const DataBuffer& input,
DataBuffer* output) {
// Check that the first byte is as requested.
- if (record_header.content_type() != kTlsHandshakeType) {
+ if ((record_header.content_type() != kTlsHandshakeType) &&
+ (record_header.content_type() != kTlsAltHandshakeType)) {
return KEEP;
}
@@ -239,12 +313,29 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord(
while (parser.remaining()) {
HandshakeHeader header;
DataBuffer handshake;
- if (!header.Parse(&parser, record_header, &handshake)) {
+ bool complete = false;
+ if (!header.Parse(&parser, record_header, preceding_fragment_, &handshake,
+ &complete)) {
return KEEP;
}
+ if (!complete) {
+ EXPECT_TRUE(record_header.is_dtls());
+ // Save the fragment and drop it from this record. Fragments are
+ // coalesced with the last fragment of the handshake message.
+ changed = true;
+ preceding_fragment_.Assign(handshake);
+ continue;
+ }
+ preceding_fragment_.Truncate(0);
+
DataBuffer filtered;
- PacketFilter::Action action = FilterHandshake(header, handshake, &filtered);
+ PacketFilter::Action action;
+ if (!IsFilteredType(header, handshake)) {
+ action = KEEP;
+ } else {
+ action = FilterHandshake(header, handshake, &filtered);
+ }
if (action == DROP) {
changed = true;
std::cerr << "handshake drop: " << handshake << std::endl;
@@ -258,6 +349,8 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord(
std::cerr << "handshake old: " << handshake << std::endl;
std::cerr << "handshake new: " << filtered << std::endl;
source = &filtered;
+ } else if (preceding_fragment_.len()) {
+ changed = true;
}
offset = header.Write(output, offset, *source);
@@ -267,12 +360,16 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord(
}
bool TlsHandshakeFilter::HandshakeHeader::ReadLength(
- TlsParser* parser, const TlsRecordHeader& header, uint32_t* length) {
- if (!parser->Read(length, 3)) {
+ TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset,
+ uint32_t* length, bool* last_fragment) {
+ uint32_t message_length;
+ if (!parser->Read(&message_length, 3)) {
return false; // malformed
}
if (!header.is_dtls()) {
+ *last_fragment = true;
+ *length = message_length;
return true; // nothing left to do
}
@@ -283,32 +380,50 @@ bool TlsHandshakeFilter::HandshakeHeader::ReadLength(
}
message_seq_ = message_seq_tmp;
- uint32_t fragment_offset;
- if (!parser->Read(&fragment_offset, 3)) {
+ uint32_t offset = 0;
+ if (!parser->Read(&offset, 3)) {
+ return false;
+ }
+ // We only parse if the fragments are all complete and in order.
+ if (offset != expected_offset) {
+ EXPECT_NE(0U, header.epoch())
+ << "Received out of order handshake fragment for epoch 0";
return false;
}
- uint32_t fragment_length;
- if (!parser->Read(&fragment_length, 3)) {
+ // For DTLS, we return the length of just this fragment.
+ if (!parser->Read(length, 3)) {
return false;
}
- // All current tests where we are using this code don't fragment.
- return (fragment_offset == 0 && fragment_length == *length);
+ // It's a fragment if the entire message is longer than what we have.
+ *last_fragment = message_length == (*length + offset);
+ return true;
}
bool TlsHandshakeFilter::HandshakeHeader::Parse(
- TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body) {
+ TlsParser* parser, const TlsRecordHeader& record_header,
+ const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) {
+ *complete = false;
+
version_ = record_header.version();
if (!parser->Read(&handshake_type_)) {
return false; // malformed
}
+
uint32_t length;
- if (!ReadLength(parser, record_header, &length)) {
+ if (!ReadLength(parser, record_header, preceding_fragment.len(), &length,
+ complete)) {
return false;
}
- return parser->Read(body, length);
+ if (!parser->Read(body, length)) {
+ return false;
+ }
+ if (preceding_fragment.len()) {
+ body->Splice(preceding_fragment, 0);
+ }
+ return true;
}
size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment(
@@ -345,20 +460,23 @@ PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake(
return KEEP;
}
- if (header.handshake_type() == handshake_type_) {
- buffer_ = input;
- }
+ buffer_ = input;
return KEEP;
}
PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() == handshake_type_) {
- *output = buffer_;
- return CHANGE;
- }
+ *output = buffer_;
+ return CHANGE;
+}
+PacketFilter::Action TlsRecordRecorder::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (!filter_ || (header.content_type() == ct_)) {
+ records_.push_back({header, input});
+ }
return KEEP;
}
@@ -369,15 +487,30 @@ PacketFilter::Action TlsConversationRecorder::FilterRecord(
return KEEP;
}
+PacketFilter::Action TlsHeaderRecorder::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ headers_.push_back(header);
+ return KEEP;
+}
+
+const TlsRecordHeader* TlsHeaderRecorder::header(size_t index) {
+ if (index > headers_.size() + 1) {
+ return nullptr;
+ }
+ return &headers_[index];
+}
+
PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
DataBuffer in(input);
bool changed = false;
for (auto it = filters_.begin(); it != filters_.end(); ++it) {
- PacketFilter::Action action = (*it)->Filter(in, output);
+ PacketFilter::Action action = (*it)->Process(in, output);
if (action == DROP) {
return DROP;
}
+
if (action == CHANGE) {
in = *output;
changed = true;
@@ -430,15 +563,6 @@ bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
return true;
}
-static bool FindHelloRetryExtensions(TlsParser* parser,
- const TlsVersioned& header) {
- // TODO for -19 add cipher suite
- if (!parser->Skip(2)) { // version
- return false;
- }
- return true;
-}
-
bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) {
return true;
}
@@ -448,13 +572,6 @@ static bool FindCertReqExtensions(TlsParser* parser,
if (!parser->SkipVariable(1)) { // request context
return false;
}
- // TODO remove the next two for -19
- if (!parser->SkipVariable(2)) { // signature_algorithms
- return false;
- }
- if (!parser->SkipVariable(2)) { // certificate_authorities
- return false;
- }
return true;
}
@@ -478,6 +595,9 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser,
if (!parser->Skip(8)) { // lifetime, age add
return false;
}
+ if (!parser->SkipVariable(1)) { // ticket_nonce
+ return false;
+ }
if (!parser->SkipVariable(2)) { // ticket
return false;
}
@@ -487,7 +607,6 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser,
static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = {
{kTlsHandshakeClientHello, FindClientHelloExtensions},
{kTlsHandshakeServerHello, FindServerHelloExtensions},
- {kTlsHandshakeHelloRetryRequest, FindHelloRetryExtensions},
{kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions},
{kTlsHandshakeCertificateRequest, FindCertReqExtensions},
{kTlsHandshakeCertificate, FindCertificateExtensions},
@@ -505,10 +624,6 @@ bool TlsExtensionFilter::FindExtensions(TlsParser* parser,
PacketFilter::Action TlsExtensionFilter::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (handshake_types_.count(header.handshake_type()) == 0) {
- return KEEP;
- }
-
TlsParser parser(input);
if (!FindExtensions(&parser, header)) {
return KEEP;
@@ -610,6 +725,38 @@ PacketFilter::Action TlsExtensionDropper::FilterExtension(
return KEEP;
}
+PacketFilter::Action TlsExtensionInjector::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
+ return KEEP;
+ }
+ size_t offset = parser.consumed();
+
+ *output = input;
+
+ // Increase the size of the extensions.
+ uint16_t ext_len;
+ memcpy(&ext_len, output->data() + offset, sizeof(ext_len));
+ ext_len = htons(ntohs(ext_len) + data_.len() + 4);
+ memcpy(output->data() + offset, &ext_len, sizeof(ext_len));
+
+ // Insert the extension type and length.
+ DataBuffer type_length;
+ type_length.Allocate(4);
+ type_length.Write(0, extension_, 2);
+ type_length.Write(2, data_.len(), 2);
+ output->Splice(type_length, offset + 2);
+
+ // Insert the payload.
+ if (data_.len() > 0) {
+ output->Splice(data_, offset + 6);
+ }
+
+ return CHANGE;
+}
+
PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) {
@@ -628,10 +775,8 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() == kTlsHandshakeClientKeyExchange) {
- EXPECT_EQ(SECSuccess,
- SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd()));
- }
+ EXPECT_EQ(SECSuccess,
+ SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd()));
return KEEP;
}
@@ -643,15 +788,49 @@ PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input,
return ((1 << counter_++) & pattern_) ? DROP : KEEP;
}
+PacketFilter::Action SelectiveRecordDropFilter::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& data,
+ DataBuffer* changed) {
+ if (counter_ >= 32) {
+ return KEEP;
+ }
+ return ((1 << counter_++) & pattern_) ? DROP : KEEP;
+}
+
+/* static */ uint32_t SelectiveRecordDropFilter::ToPattern(
+ std::initializer_list<size_t> records) {
+ uint32_t pattern = 0;
+ for (auto it = records.begin(); it != records.end(); ++it) {
+ EXPECT_GT(32U, *it);
+ assert(*it < 32U);
+ pattern |= 1 << *it;
+ }
+ return pattern;
+}
+
PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() == kTlsHandshakeClientHello) {
- *output = input;
- output->Write(0, version_, 2);
- return CHANGE;
- }
- return KEEP;
+ *output = input;
+ output->Write(0, version_, 2);
+ return CHANGE;
+}
+
+PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ uint32_t temp = 0;
+ EXPECT_TRUE(input.Read(0, 2, &temp));
+ // Cipher suite is after version(2) and random(32).
+ size_t pos = 34;
+ if (temp < SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In old versions, we have to skip a session_id too.
+ EXPECT_TRUE(input.Read(pos, 1, &temp));
+ pos += 1 + temp;
+ }
+ output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
+ return CHANGE;
}
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h
index e4030e23f..1db3b90f6 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.h
+++ b/security/nss/gtests/ssl_gtest/tls_filter.h
@@ -50,10 +50,13 @@ class TlsRecordHeader : public TlsVersioned {
uint8_t content_type() const { return content_type_; }
uint64_t sequence_number() const { return sequence_number_; }
- size_t header_length() const { return is_dtls() ? 11 : 3; }
+ uint16_t epoch() const {
+ return static_cast<uint16_t>(sequence_number_ >> 48);
+ }
+ size_t header_length() const { return is_dtls() ? 13 : 5; }
// Parse the header; return true if successful; body in an outparam if OK.
- bool Parse(TlsParser* parser, DataBuffer* body);
+ bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body);
// Write the header and body to a buffer at the given offset.
// Return the offset of the end of the write.
size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
@@ -63,10 +66,21 @@ class TlsRecordHeader : public TlsVersioned {
uint64_t sequence_number_;
};
+struct TlsRecord {
+ const TlsRecordHeader header;
+ const DataBuffer buffer;
+};
+
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
- TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {}
+ TlsRecordFilter()
+ : agent_(nullptr),
+ count_(0),
+ cipher_spec_(),
+ dropped_record_(false),
+ in_sequence_number_(0),
+ out_sequence_number_(0) {}
void SetAgent(const TlsAgent* agent) { agent_ = agent; }
const TlsAgent* agent() const { return agent_; }
@@ -115,14 +129,21 @@ class TlsRecordFilter : public PacketFilter {
const TlsAgent* agent_;
size_t count_;
std::unique_ptr<TlsCipherSpec> cipher_spec_;
+ // Whether we dropped a record since the cipher spec changed.
+ bool dropped_record_;
+ // The sequence number we use for reading records as they are written.
+ uint64_t in_sequence_number_;
+ // The sequence number we use for writing modified records.
+ uint64_t out_sequence_number_;
};
-inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) {
+inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) {
v.WriteStream(stream);
return stream;
}
-inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
+inline std::ostream& operator<<(std::ostream& stream,
+ const TlsRecordHeader& hdr) {
hdr.WriteStream(stream);
stream << ' ';
switch (hdr.content_type()) {
@@ -133,13 +154,17 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
stream << "Alert";
break;
case kTlsHandshakeType:
+ case kTlsAltHandshakeType:
stream << "Handshake";
break;
case kTlsApplicationDataType:
stream << "Data";
break;
+ case kTlsAckType:
+ stream << "ACK";
+ break;
default:
- stream << '<' << hdr.content_type() << '>';
+ stream << '<' << static_cast<int>(hdr.content_type()) << '>';
break;
}
return stream << ' ' << std::hex << hdr.sequence_number() << std::dec;
@@ -150,7 +175,16 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
// records and that they don't span records or anything crazy like that.
class TlsHandshakeFilter : public TlsRecordFilter {
public:
- TlsHandshakeFilter() {}
+ TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::set<uint8_t>& types)
+ : handshake_types_(types), preceding_fragment_() {}
+
+ // This filter can be set to be selective based on handshake message type. If
+ // this function isn't used (or the set is empty), then all handshake messages
+ // will be filtered.
+ void SetHandshakeTypes(const std::set<uint8_t>& types) {
+ handshake_types_ = types;
+ }
class HandshakeHeader : public TlsVersioned {
public:
@@ -158,7 +192,8 @@ class TlsHandshakeFilter : public TlsRecordFilter {
uint8_t handshake_type() const { return handshake_type_; }
bool Parse(TlsParser* parser, const TlsRecordHeader& record_header,
- DataBuffer* body);
+ const DataBuffer& preceding_fragment, DataBuffer* body,
+ bool* complete);
size_t Write(DataBuffer* buffer, size_t offset,
const DataBuffer& body) const;
size_t WriteFragment(DataBuffer* buffer, size_t offset,
@@ -169,7 +204,8 @@ class TlsHandshakeFilter : public TlsRecordFilter {
// Reads the length from the record header.
// This also reads the DTLS fragment information and checks it.
bool ReadLength(TlsParser* parser, const TlsRecordHeader& header,
- uint32_t* length);
+ uint32_t expected_offset, uint32_t* length,
+ bool* last_fragment);
uint8_t handshake_type_;
uint16_t message_seq_;
@@ -185,22 +221,30 @@ class TlsHandshakeFilter : public TlsRecordFilter {
DataBuffer* output) = 0;
private:
+ bool IsFilteredType(const HandshakeHeader& header,
+ const DataBuffer& handshake);
+
+ std::set<uint8_t> handshake_types_;
+ DataBuffer preceding_fragment_;
};
// Make a copy of the first instance of a handshake message.
class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
public:
TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
- : handshake_type_(handshake_type), buffer_() {}
+ : TlsHandshakeFilter({handshake_type}), buffer_() {}
+ TlsInspectorRecordHandshakeMessage(const std::set<uint8_t>& handshake_types)
+ : TlsHandshakeFilter(handshake_types), buffer_() {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output);
+ void Reset() { buffer_.Truncate(0); }
+
const DataBuffer& buffer() const { return buffer_; }
private:
- uint8_t handshake_type_;
DataBuffer buffer_;
};
@@ -209,17 +253,39 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
public:
TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type,
const DataBuffer& replacement)
- : handshake_type_(handshake_type), buffer_(replacement) {}
+ : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output);
private:
- uint8_t handshake_type_;
DataBuffer buffer_;
};
+// Make a copy of each record of a given type.
+class TlsRecordRecorder : public TlsRecordFilter {
+ public:
+ TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {}
+ TlsRecordRecorder()
+ : filter_(false),
+ ct_(content_handshake), // dummy (<optional> is C++14)
+ records_() {}
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ size_t count() const { return records_.size(); }
+ void Clear() { records_.clear(); }
+
+ const TlsRecord& record(size_t i) const { return records_[i]; }
+
+ private:
+ bool filter_;
+ uint8_t ct_;
+ std::vector<TlsRecord> records_;
+};
+
// Make a copy of the complete conversation.
class TlsConversationRecorder : public TlsRecordFilter {
public:
@@ -230,15 +296,31 @@ class TlsConversationRecorder : public TlsRecordFilter {
DataBuffer* output);
private:
- DataBuffer& buffer_;
+ DataBuffer buffer_;
};
+// Make a copy of the records
+class TlsHeaderRecorder : public TlsRecordFilter {
+ public:
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+ const TlsRecordHeader* header(size_t index);
+
+ private:
+ std::vector<TlsRecordHeader> headers_;
+};
+
+typedef std::initializer_list<std::shared_ptr<PacketFilter>>
+ ChainedPacketFilterInit;
+
// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
public:
ChainedPacketFilter() {}
ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters)
: filters_(filters.begin(), filters.end()) {}
+ ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {}
virtual ~ChainedPacketFilter() {}
virtual PacketFilter::Action Filter(const DataBuffer& input,
@@ -256,13 +338,13 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
class TlsExtensionFilter : public TlsHandshakeFilter {
public:
- TlsExtensionFilter() : handshake_types_() {
- handshake_types_.insert(kTlsHandshakeClientHello);
- handshake_types_.insert(kTlsHandshakeServerHello);
- }
+ TlsExtensionFilter()
+ : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello,
+ kTlsHandshakeHelloRetryRequest,
+ kTlsHandshakeEncryptedExtensions}) {}
TlsExtensionFilter(const std::set<uint8_t>& types)
- : handshake_types_(types) {}
+ : TlsHandshakeFilter(types) {}
static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
@@ -279,8 +361,6 @@ class TlsExtensionFilter : public TlsHandshakeFilter {
PacketFilter::Action FilterExtensions(TlsParser* parser,
const DataBuffer& input,
DataBuffer* output);
-
- std::set<uint8_t> handshake_types_;
};
class TlsExtensionCapture : public TlsExtensionFilter {
@@ -326,6 +406,21 @@ class TlsExtensionDropper : public TlsExtensionFilter {
uint16_t extension_;
};
+class TlsExtensionInjector : public TlsHandshakeFilter {
+ public:
+ TlsExtensionInjector(uint16_t ext, const DataBuffer& data)
+ : extension_(ext), data_(data) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
class TlsAgent;
typedef std::function<void(void)> VoidFunction;
@@ -352,7 +447,7 @@ class AfterRecordN : public TlsRecordFilter {
class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
public:
TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server)
- : server_(server) {}
+ : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -377,10 +472,47 @@ class SelectiveDropFilter : public PacketFilter {
uint8_t counter_;
};
+// This class selectively drops complete records. The difference from
+// SelectiveDropFilter is that if multiple DTLS records are in the same
+// datagram, we just drop one.
+class SelectiveRecordDropFilter : public TlsRecordFilter {
+ public:
+ SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true)
+ : pattern_(pattern), counter_(0) {
+ if (!enabled) {
+ Disable();
+ }
+ }
+ SelectiveRecordDropFilter(std::initializer_list<size_t> records)
+ : SelectiveRecordDropFilter(ToPattern(records), true) {}
+
+ void Reset(uint32_t pattern) {
+ counter_ = 0;
+ PacketFilter::Enable();
+ pattern_ = pattern;
+ }
+
+ void Reset(std::initializer_list<size_t> records) {
+ Reset(ToPattern(records));
+ }
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override;
+
+ private:
+ static uint32_t ToPattern(std::initializer_list<size_t> records);
+
+ uint32_t pattern_;
+ uint8_t counter_;
+};
+
// Set the version number in the ClientHello.
class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {}
+ TlsInspectorClientHelloVersionSetter(uint16_t version)
+ : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -411,6 +543,20 @@ class TlsLastByteDamager : public TlsHandshakeFilter {
uint8_t type_;
};
+class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
+ public:
+ SelectedCipherSuiteReplacer(uint16_t suite)
+ : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ uint16_t cipher_suite_;
+};
+
} // namespace nss_test
#endif
diff --git a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc
index 51ff938b1..45f6cf2bd 100644
--- a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc
@@ -241,13 +241,13 @@ TEST_P(TlsHkdfTest, HkdfExpandLabel) {
{/* ssl_hash_md5 */},
{/* ssl_hash_sha1 */},
{/* ssl_hash_sha224 */},
- {0x34, 0x7c, 0x67, 0x80, 0xff, 0x0b, 0xba, 0xd7, 0x1c, 0x28, 0x3b,
- 0x16, 0xeb, 0x2f, 0x9c, 0xf6, 0x2d, 0x24, 0xe6, 0xcd, 0xb6, 0x13,
- 0xd5, 0x17, 0x76, 0x54, 0x8c, 0xb0, 0x7d, 0xcd, 0xe7, 0x4c},
- {0x4b, 0x1e, 0x5e, 0xc1, 0x49, 0x30, 0x78, 0xea, 0x35, 0xbd, 0x3f, 0x01,
- 0x04, 0xe6, 0x1a, 0xea, 0x14, 0xcc, 0x18, 0x2a, 0xd1, 0xc4, 0x76, 0x21,
- 0xc4, 0x64, 0xc0, 0x4e, 0x4b, 0x36, 0x16, 0x05, 0x6f, 0x04, 0xab, 0xe9,
- 0x43, 0xb1, 0x2d, 0xa8, 0xa7, 0x17, 0x9a, 0x5f, 0x09, 0x91, 0x7d, 0x1f}};
+ {0xc6, 0xdd, 0x6e, 0xc4, 0x76, 0xb8, 0x55, 0xf2, 0xa4, 0xfc, 0x59,
+ 0x04, 0xa4, 0x90, 0xdc, 0xa7, 0xa7, 0x0d, 0x94, 0x8f, 0xc2, 0xdc,
+ 0x15, 0x6d, 0x48, 0x93, 0x9d, 0x05, 0xbb, 0x9a, 0xbc, 0xc1},
+ {0x41, 0xea, 0x77, 0x09, 0x8c, 0x90, 0x04, 0x10, 0xec, 0xbc, 0x37, 0xd8,
+ 0x5b, 0x54, 0xcd, 0x7b, 0x08, 0x15, 0x13, 0x20, 0xed, 0x1e, 0x3f, 0x54,
+ 0x74, 0xf7, 0x8b, 0x06, 0x38, 0x28, 0x06, 0x37, 0x75, 0x23, 0xa2, 0xb7,
+ 0x34, 0xb1, 0x72, 0x2e, 0x59, 0x6d, 0x5a, 0x31, 0xf5, 0x53, 0xab, 0x99}};
const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]);
HkdfExpandLabel(&k1_, hash_type_, kSessionHash, kHashLength[hash_type_],
diff --git a/security/nss/gtests/ssl_gtest/tls_protect.cc b/security/nss/gtests/ssl_gtest/tls_protect.cc
index efcd89e14..6c945f66e 100644
--- a/security/nss/gtests/ssl_gtest/tls_protect.cc
+++ b/security/nss/gtests/ssl_gtest/tls_protect.cc
@@ -32,7 +32,6 @@ void AeadCipher::FormatNonce(uint64_t seq, uint8_t *nonce) {
}
DataBuffer d(nonce, 12);
- std::cerr << "Nonce " << d << std::endl;
}
bool AeadCipher::AeadInner(bool decrypt, void *params, size_t param_length,
@@ -92,8 +91,9 @@ bool AeadCipherChacha20Poly1305::Aead(bool decrypt, uint64_t seq,
in, inlen, out, outlen, maxlen);
}
-bool TlsCipherSpec::Init(SSLCipherAlgorithm cipher, PK11SymKey *key,
- const uint8_t *iv) {
+bool TlsCipherSpec::Init(uint16_t epoch, SSLCipherAlgorithm cipher,
+ PK11SymKey *key, const uint8_t *iv) {
+ epoch_ = epoch;
switch (cipher) {
case ssl_calg_aes_gcm:
aead_.reset(new AeadCipherAesGcm());
diff --git a/security/nss/gtests/ssl_gtest/tls_protect.h b/security/nss/gtests/ssl_gtest/tls_protect.h
index 4efbd6e6b..93ffd6322 100644
--- a/security/nss/gtests/ssl_gtest/tls_protect.h
+++ b/security/nss/gtests/ssl_gtest/tls_protect.h
@@ -20,7 +20,7 @@ class TlsRecordHeader;
class AeadCipher {
public:
AeadCipher(CK_MECHANISM_TYPE mech) : mech_(mech), key_(nullptr) {}
- ~AeadCipher();
+ virtual ~AeadCipher();
bool Init(PK11SymKey *key, const uint8_t *iv);
virtual bool Aead(bool decrypt, uint64_t seq, const uint8_t *in, size_t inlen,
@@ -58,16 +58,19 @@ class AeadCipherAesGcm : public AeadCipher {
// Our analog of ssl3CipherSpec
class TlsCipherSpec {
public:
- TlsCipherSpec() : aead_() {}
+ TlsCipherSpec() : epoch_(0), aead_() {}
- bool Init(SSLCipherAlgorithm cipher, PK11SymKey *key, const uint8_t *iv);
+ bool Init(uint16_t epoch, SSLCipherAlgorithm cipher, PK11SymKey *key,
+ const uint8_t *iv);
bool Protect(const TlsRecordHeader &header, const DataBuffer &plaintext,
DataBuffer *ciphertext);
bool Unprotect(const TlsRecordHeader &header, const DataBuffer &ciphertext,
DataBuffer *plaintext);
+ uint16_t epoch() const { return epoch_; }
private:
+ uint16_t epoch_;
std::unique_ptr<AeadCipher> aead_;
};