summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest')
-rw-r--r--security/nss/gtests/ssl_gtest/Makefile4
-rw-r--r--security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc108
-rw-r--r--security/nss/gtests/ssl_gtest/libssl_internals.c143
-rw-r--r--security/nss/gtests/ssl_gtest/libssl_internals.h21
-rw-r--r--security/nss/gtests/ssl_gtest/manifest.mn7
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc343
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc22
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc204
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc35
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc65
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc498
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc46
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc134
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc768
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc78
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc1
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc298
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc14
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc104
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_gtest.cc2
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_gtest.gyp7
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc695
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc118
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc178
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc304
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc20
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_record_unittest.cc75
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc212
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc492
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc117
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc28
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc363
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc19
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_version_unittest.cc152
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc6
-rw-r--r--security/nss/gtests/ssl_gtest/test_io.cc13
-rw-r--r--security/nss/gtests/ssl_gtest/test_io.h26
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.cc258
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.h53
-rw-r--r--security/nss/gtests/ssl_gtest/tls_connect.cc226
-rw-r--r--security/nss/gtests/ssl_gtest/tls_connect.h85
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.cc329
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.h281
-rw-r--r--security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc14
-rw-r--r--security/nss/gtests/ssl_gtest/tls_protect.cc6
-rw-r--r--security/nss/gtests/ssl_gtest/tls_protect.h9
46 files changed, 5531 insertions, 1450 deletions
diff --git a/security/nss/gtests/ssl_gtest/Makefile b/security/nss/gtests/ssl_gtest/Makefile
index a9a9290e0..95c111aeb 100644
--- a/security/nss/gtests/ssl_gtest/Makefile
+++ b/security/nss/gtests/ssl_gtest/Makefile
@@ -29,10 +29,6 @@ include ../common/gtest.mk
CFLAGS += -I$(CORE_DEPTH)/lib/ssl
-ifdef NSS_SSL_ENABLE_ZLIB
-include $(CORE_DEPTH)/coreconf/zlib.mk
-endif
-
ifdef NSS_DISABLE_TLS_1_3
NSS_DISABLE_TLS_1_3=1
# Run parameterized tests only, for which we can easily exclude TLS 1.3
diff --git a/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc
new file mode 100644
index 000000000..6efe06ec7
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/bloomfilter_unittest.cc
@@ -0,0 +1,108 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+extern "C" {
+#include "sslbloom.h"
+}
+
+#include "gtest_utils.h"
+
+namespace nss_test {
+
+// Some random-ish inputs to test with. These don't result in collisions in any
+// of the configurations that are tested below.
+static const uint8_t kHashes1[] = {
+ 0x79, 0x53, 0xb8, 0xdd, 0x6b, 0x98, 0xce, 0x00, 0xb7, 0xdc, 0xe8,
+ 0x03, 0x70, 0x8c, 0xe3, 0xac, 0x06, 0x8b, 0x22, 0xfd, 0x0e, 0x34,
+ 0x48, 0xe6, 0xe5, 0xe0, 0x8a, 0xd6, 0x16, 0x18, 0xe5, 0x48};
+static const uint8_t kHashes2[] = {
+ 0xc6, 0xdd, 0x6e, 0xc4, 0x76, 0xb8, 0x55, 0xf2, 0xa4, 0xfc, 0x59,
+ 0x04, 0xa4, 0x90, 0xdc, 0xa7, 0xa7, 0x0d, 0x94, 0x8f, 0xc2, 0xdc,
+ 0x15, 0x6d, 0x48, 0x93, 0x9d, 0x05, 0xbb, 0x9a, 0xbc, 0xc1};
+
+typedef struct {
+ unsigned int k;
+ unsigned int bits;
+} BloomFilterConfig;
+
+class BloomFilterTest
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<BloomFilterConfig> {
+ public:
+ BloomFilterTest() : filter_() {}
+
+ void SetUp() { Init(); }
+
+ void TearDown() { sslBloom_Destroy(&filter_); }
+
+ protected:
+ void Init() {
+ if (filter_.filter) {
+ sslBloom_Destroy(&filter_);
+ }
+ ASSERT_EQ(SECSuccess,
+ sslBloom_Init(&filter_, GetParam().k, GetParam().bits));
+ }
+
+ bool Check(const uint8_t* hashes) {
+ return sslBloom_Check(&filter_, hashes) ? true : false;
+ }
+
+ void Add(const uint8_t* hashes, bool expect_collision = false) {
+ EXPECT_EQ(expect_collision, sslBloom_Add(&filter_, hashes) ? true : false);
+ EXPECT_TRUE(Check(hashes));
+ }
+
+ sslBloomFilter filter_;
+};
+
+TEST_P(BloomFilterTest, InitOnly) {}
+
+TEST_P(BloomFilterTest, AddToEmpty) {
+ EXPECT_FALSE(Check(kHashes1));
+ Add(kHashes1);
+}
+
+TEST_P(BloomFilterTest, AddTwo) {
+ Add(kHashes1);
+ Add(kHashes2);
+}
+
+TEST_P(BloomFilterTest, AddOneTwice) {
+ Add(kHashes1);
+ Add(kHashes1, true);
+}
+
+TEST_P(BloomFilterTest, Zero) {
+ Add(kHashes1);
+ sslBloom_Zero(&filter_);
+ EXPECT_FALSE(Check(kHashes1));
+ EXPECT_FALSE(Check(kHashes2));
+}
+
+TEST_P(BloomFilterTest, Fill) {
+ sslBloom_Fill(&filter_);
+ EXPECT_TRUE(Check(kHashes1));
+ EXPECT_TRUE(Check(kHashes2));
+}
+
+static const BloomFilterConfig kBloomFilterConfigurations[] = {
+ {1, 1}, // 1 hash, 1 bit input - high chance of collision.
+ {1, 2}, // 1 hash, 2 bits - smaller than the basic unit size.
+ {1, 3}, // 1 hash, 3 bits - same as basic unit size.
+ {1, 4}, // 1 hash, 4 bits - 2 octets each.
+ {3, 10}, // 3 hashes over a reasonable number of bits.
+ {3, 3}, // Test that we can read multiple bits.
+ {4, 15}, // A credible filter.
+ {2, 18}, // A moderately large allocation.
+ {16, 16}, // Insane, use all of the bits from the hashes.
+ {16, 9}, // This also uses all of the bits from the hashes.
+};
+
+INSTANTIATE_TEST_CASE_P(BloomFilterConfigurations, BloomFilterTest,
+ ::testing::ValuesIn(kBloomFilterConfigurations));
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.c b/security/nss/gtests/ssl_gtest/libssl_internals.c
index 97b8354ae..17b4ffe49 100644
--- a/security/nss/gtests/ssl_gtest/libssl_internals.c
+++ b/security/nss/gtests/ssl_gtest/libssl_internals.c
@@ -34,18 +34,17 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
return SECFailure;
}
- ssl3_InitState(ss);
ssl3_RestartHandshakeHashes(ss);
// Ensure we don't overrun hs.client_random.
rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len);
- // Zero the client_random struct.
- PORT_Memset(&ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH);
+ // Zero the client_random.
+ PORT_Memset(ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH);
// Copy over the challenge bytes.
size_t offset = SSL3_RANDOM_LENGTH - rnd_len;
- PORT_Memcpy(&ss->ssl3.hs.client_random.rand[offset], rnd, rnd_len);
+ PORT_Memcpy(ss->ssl3.hs.client_random + offset, rnd, rnd_len);
// Rehash the SSLv2 client hello message.
return ssl3_UpdateHandshakeHashes(ss, msg, msg_len);
@@ -73,10 +72,11 @@ SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) {
return SECFailure;
}
ss->ssl3.mtu = mtu;
+ ss->ssl3.hs.rtRetries = 0; /* Avoid DTLS shrinking the MTU any more. */
return SECSuccess;
}
-PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) {
+PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd) {
PRCList *cur_p;
PRInt32 ct = 0;
@@ -92,7 +92,7 @@ PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) {
return ct;
}
-void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) {
+void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd) {
PRCList *cur_p;
sslSocket *ss = ssl_FindSocket(fd);
@@ -100,27 +100,31 @@ void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) {
return;
}
- fprintf(stderr, "Cipher specs\n");
+ fprintf(stderr, "Cipher specs for %s\n", label);
for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs);
cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) {
ssl3CipherSpec *spec = (ssl3CipherSpec *)cur_p;
- fprintf(stderr, " %s\n", spec->phase);
+ fprintf(stderr, " %s spec epoch=%d (%s) refct=%d\n", SPEC_DIR(spec),
+ spec->epoch, spec->phase, spec->refCt);
}
}
-/* Force a timer expiry by backdating when the timer was started.
- * We could set the remaining time to 0 but then backoff would not
- * work properly if we decide to test it. */
-void SSLInt_ForceTimerExpiry(PRFileDesc *fd) {
+/* Force a timer expiry by backdating when all active timers were started. We
+ * could set the remaining time to 0 but then backoff would not work properly if
+ * we decide to test it. */
+SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift) {
+ size_t i;
sslSocket *ss = ssl_FindSocket(fd);
if (!ss) {
- return;
+ return SECFailure;
}
- if (!ss->ssl3.hs.rtTimerCb) return;
-
- ss->ssl3.hs.rtTimerStarted =
- PR_IntervalNow() - PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs + 1);
+ for (i = 0; i < PR_ARRAY_SIZE(ss->ssl3.hs.timers); ++i) {
+ if (ss->ssl3.hs.timers[i].cb) {
+ ss->ssl3.hs.timers[i].started -= shift;
+ }
+ }
+ return SECSuccess;
}
#define CHECK_SECRET(secret) \
@@ -136,7 +140,6 @@ PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) {
}
CHECK_SECRET(currentSecret);
- CHECK_SECRET(resumptionMasterSecret);
CHECK_SECRET(dheSecret);
CHECK_SECRET(clientEarlyTrafficSecret);
CHECK_SECRET(clientHsTrafficSecret);
@@ -226,28 +229,7 @@ PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) {
return PR_TRUE;
}
-PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd) {
- sslSocket *ss = ssl_FindSocket(fd);
- if (!ss) {
- return PR_FALSE;
- }
-
- ssl_GetSSL3HandshakeLock(ss);
- ssl_GetXmitBufLock(ss);
-
- SECStatus rv = tls13_SendNewSessionTicket(ss);
- if (rv == SECSuccess) {
- rv = ssl3_FlushHandshake(ss, 0);
- }
-
- ssl_ReleaseXmitBufLock(ss);
- ssl_ReleaseSSL3HandshakeLock(ss);
-
- return rv == SECSuccess;
-}
-
SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) {
- PRUint64 epoch;
sslSocket *ss;
ssl3CipherSpec *spec;
@@ -255,43 +237,40 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) {
if (!ss) {
return SECFailure;
}
- if (to >= (1ULL << 48)) {
+ if (to >= RECORD_SEQ_MAX) {
PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
ssl_GetSpecWriteLock(ss);
spec = ss->ssl3.crSpec;
- epoch = spec->read_seq_num >> 48;
- spec->read_seq_num = (epoch << 48) | to;
+ spec->seqNum = to;
/* For DTLS, we need to fix the record sequence number. For this, we can just
* scrub the entire structure on the assumption that the new sequence number
* is far enough past the last received sequence number. */
- if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
+ if (spec->seqNum <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
- dtls_RecordSetRecvd(&spec->recvdRecords, to);
+ dtls_RecordSetRecvd(&spec->recvdRecords, spec->seqNum);
ssl_ReleaseSpecWriteLock(ss);
return SECSuccess;
}
SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) {
- PRUint64 epoch;
sslSocket *ss;
ss = ssl_FindSocket(fd);
if (!ss) {
return SECFailure;
}
- if (to >= (1ULL << 48)) {
+ if (to >= RECORD_SEQ_MAX) {
PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
ssl_GetSpecWriteLock(ss);
- epoch = ss->ssl3.cwSpec->write_seq_num >> 48;
- ss->ssl3.cwSpec->write_seq_num = (epoch << 48) | to;
+ ss->ssl3.cwSpec->seqNum = to;
ssl_ReleaseSpecWriteLock(ss);
return SECSuccess;
}
@@ -305,9 +284,9 @@ SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) {
return SECFailure;
}
ssl_GetSpecReadLock(ss);
- to = ss->ssl3.cwSpec->write_seq_num + DTLS_RECVD_RECORDS_WINDOW + extra;
+ to = ss->ssl3.cwSpec->seqNum + DTLS_RECVD_RECORDS_WINDOW + extra;
ssl_ReleaseSpecReadLock(ss);
- return SSLInt_AdvanceWriteSeqNum(fd, to & RECORD_SEQ_MAX);
+ return SSLInt_AdvanceWriteSeqNum(fd, to);
}
SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) {
@@ -333,56 +312,26 @@ SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd,
return SECSuccess;
}
-static ssl3KeyMaterial *GetKeyingMaterial(PRBool isServer,
- ssl3CipherSpec *spec) {
- return isServer ? &spec->server : &spec->client;
+PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec) {
+ return spec->keyMaterial.key;
}
-PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec) {
- return GetKeyingMaterial(isServer, spec)->write_key;
+SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec) {
+ return spec->cipherDef->calg;
}
-SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer,
- ssl3CipherSpec *spec) {
- return spec->cipher_def->calg;
+const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec) {
+ return spec->keyMaterial.iv;
}
-unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec) {
- return GetKeyingMaterial(isServer, spec)->write_iv;
-}
-
-SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd) {
- sslSocket *ss;
-
- ss = ssl_FindSocket(fd);
- if (!ss) {
- return SECFailure;
- }
-
- ss->opt.enableShortHeaders = PR_TRUE;
- return SECSuccess;
-}
-
-SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result) {
- sslSocket *ss;
-
- ss = ssl_FindSocket(fd);
- if (!ss) {
- return SECFailure;
- }
-
- *result = ss->ssl3.hs.shortHeaders;
- return SECSuccess;
+PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec) {
+ return spec->epoch;
}
void SSLInt_SetTicketLifetime(uint32_t lifetime) {
ssl_ticket_lifetime = lifetime;
}
-void SSLInt_SetMaxEarlyDataSize(uint32_t size) {
- ssl_max_early_data_size = size;
-}
-
SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) {
sslSocket *ss;
@@ -405,3 +354,21 @@ SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) {
return SECSuccess;
}
+
+void SSLInt_RolloverAntiReplay(void) {
+ tls13_AntiReplayRollover(ssl_TimeUsec());
+}
+
+SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch,
+ PRUint16 *writeEpoch) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss || !readEpoch || !writeEpoch) {
+ return SECFailure;
+ }
+
+ ssl_GetSpecReadLock(ss);
+ *readEpoch = ss->ssl3.crSpec->epoch;
+ *writeEpoch = ss->ssl3.cwSpec->epoch;
+ ssl_ReleaseSpecReadLock(ss);
+ return SECSuccess;
+}
diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.h b/security/nss/gtests/ssl_gtest/libssl_internals.h
index 33709c4b4..3efb362c2 100644
--- a/security/nss/gtests/ssl_gtest/libssl_internals.h
+++ b/security/nss/gtests/ssl_gtest/libssl_internals.h
@@ -24,9 +24,9 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext);
void SSLInt_ClearSelfEncryptKey();
void SSLInt_SetSelfEncryptMacKey(PK11SymKey *key);
-PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd);
-void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd);
-void SSLInt_ForceTimerExpiry(PRFileDesc *fd);
+PRInt32 SSLInt_CountCipherSpecs(PRFileDesc *fd);
+void SSLInt_PrintCipherSpecs(const char *label, PRFileDesc *fd);
+SECStatus SSLInt_ShiftDtlsTimers(PRFileDesc *fd, PRIntervalTime shift);
SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu);
PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd);
PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd);
@@ -35,23 +35,22 @@ PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd);
SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len);
PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType);
PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type);
-PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd);
SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to);
SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to);
SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra);
SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group);
+SECStatus SSLInt_GetEpochs(PRFileDesc *fd, PRUint16 *readEpoch,
+ PRUint16 *writeEpoch);
SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd,
sslCipherSpecChangedFunc func,
void *arg);
-PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec);
-SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer,
- ssl3CipherSpec *spec);
-unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec);
-SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd);
-SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result);
+PRUint16 SSLInt_CipherSpecToEpoch(const ssl3CipherSpec *spec);
+PK11SymKey *SSLInt_CipherSpecToKey(const ssl3CipherSpec *spec);
+SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(const ssl3CipherSpec *spec);
+const PRUint8 *SSLInt_CipherSpecToIv(const ssl3CipherSpec *spec);
void SSLInt_SetTicketLifetime(uint32_t lifetime);
-void SSLInt_SetMaxEarlyDataSize(uint32_t size);
SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size);
+void SSLInt_RolloverAntiReplay(void);
#endif // ndef libssl_internals_h_
diff --git a/security/nss/gtests/ssl_gtest/manifest.mn b/security/nss/gtests/ssl_gtest/manifest.mn
index cc729c0f1..5d893bab3 100644
--- a/security/nss/gtests/ssl_gtest/manifest.mn
+++ b/security/nss/gtests/ssl_gtest/manifest.mn
@@ -12,11 +12,13 @@ CSRCS = \
$(NULL)
CPPSRCS = \
+ bloomfilter_unittest.cc \
ssl_0rtt_unittest.cc \
ssl_agent_unittest.cc \
ssl_auth_unittest.cc \
ssl_cert_ext_unittest.cc \
ssl_ciphersuite_unittest.cc \
+ ssl_custext_unittest.cc \
ssl_damage_unittest.cc \
ssl_dhe_unittest.cc \
ssl_drop_unittest.cc \
@@ -29,11 +31,16 @@ CPPSRCS = \
ssl_gather_unittest.cc \
ssl_gtest.cc \
ssl_hrr_unittest.cc \
+ ssl_keylog_unittest.cc \
+ ssl_keyupdate_unittest.cc \
ssl_loopback_unittest.cc \
+ ssl_misc_unittest.cc \
ssl_record_unittest.cc \
ssl_resumption_unittest.cc \
+ ssl_renegotiation_unittest.cc \
ssl_skip_unittest.cc \
ssl_staticrsa_unittest.cc \
+ ssl_tls13compat_unittest.cc \
ssl_v2_client_hello_unittest.cc \
ssl_version_unittest.cc \
ssl_versionpolicy_unittest.cc \
diff --git a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
index 85b7011a1..08781af71 100644
--- a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
@@ -7,6 +7,7 @@
#include "secerr.h"
#include "ssl.h"
#include "sslerr.h"
+#include "sslexp.h"
#include "sslproto.h"
extern "C" {
@@ -44,6 +45,92 @@ TEST_P(TlsConnectTls13, ZeroRttServerRejectByOption) {
SendReceive();
}
+TEST_P(TlsConnectTls13, ZeroRttApparentReplayAfterRestart) {
+ // The test fixtures call SSL_SetupAntiReplay() in SetUp(). This results in
+ // 0-RTT being rejected until at least one window passes. SetupFor0Rtt()
+ // forces a rollover of the anti-replay filters, which clears this state.
+ // Here, we do the setup manually here without that forced rollover.
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT.
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ StartConnect();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+class TlsZeroRttReplayTest : public TlsConnectTls13 {
+ private:
+ class SaveFirstPacket : public PacketFilter {
+ public:
+ PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override {
+ if (!packet_.len() && input.len()) {
+ packet_ = input;
+ }
+ return KEEP;
+ }
+
+ const DataBuffer& packet() const { return packet_; }
+
+ private:
+ DataBuffer packet_;
+ };
+
+ protected:
+ void RunTest(bool rollover) {
+ // Run the initial handshake
+ SetupForZeroRtt();
+
+ // Now run a true 0-RTT handshake, but capture the first packet.
+ auto first_packet = std::make_shared<SaveFirstPacket>();
+ client_->SetFilter(first_packet);
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ EXPECT_LT(0U, first_packet->packet().len());
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+
+ if (rollover) {
+ SSLInt_RolloverAntiReplay();
+ }
+
+ // Now replay that packet against the server.
+ Reset();
+ server_->StartConnect();
+ server_->Set0RttEnabled(true);
+
+ // Capture the early_data extension, which should not appear.
+ auto early_data_ext =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_early_data_xtn);
+
+ // Finally, replay the ClientHello and force the server to consume it. Stop
+ // after the server sends its first flight; the client will not be able to
+ // complete this handshake.
+ server_->adapter()->PacketReceived(first_packet->packet());
+ server_->Handshake();
+ EXPECT_FALSE(early_data_ext->captured());
+ }
+};
+
+TEST_P(TlsZeroRttReplayTest, ZeroRttReplay) { RunTest(false); }
+
+TEST_P(TlsZeroRttReplayTest, ZeroRttReplayAfterRollover) { RunTest(true); }
+
// Test that we don't try to send 0-RTT data when the server sent
// us a ticket without the 0-RTT flags.
TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) {
@@ -52,8 +139,7 @@ TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) {
SendReceive(); // Need to read so that we absorb the session ticket.
CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
Reset();
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
// Now turn on 0-RTT but too late for the ticket.
client_->Set0RttEnabled(true);
server_->Set0RttEnabled(true);
@@ -80,8 +166,7 @@ TEST_P(TlsConnectTls13, ZeroRttServerForgetTicket) {
TEST_P(TlsConnectTls13, ZeroRttServerOnly) {
ExpectResumption(RESUME_NONE);
server_->Set0RttEnabled(true);
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
// Client sends ordinary ClientHello.
client_->Handshake();
@@ -99,6 +184,61 @@ TEST_P(TlsConnectTls13, ZeroRttServerOnly) {
CheckKeys();
}
+// A small sleep after sending the ClientHello means that the ticket age that
+// arrives at the server is too low. With a small tolerance for variation in
+// ticket age (which is determined by the |window| parameter that is passed to
+// SSL_SetupAntiReplay()), the server then rejects early data.
+TEST_P(TlsConnectTls13, ZeroRttRejectOldTicket) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ EXPECT_EQ(SECSuccess, SSL_SetupAntiReplay(1, 1, 3));
+ SSLInt_RolloverAntiReplay(); // Make sure to flush replay state.
+ SSLInt_RolloverAntiReplay();
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false, []() {
+ PR_Sleep(PR_MillisecondsToInterval(10));
+ return true;
+ });
+ Handshake();
+ ExpectEarlyDataAccepted(false);
+ CheckConnected();
+ SendReceive();
+}
+
+// In this test, we falsely inflate the estimate of the RTT by delaying the
+// ServerHello on the first handshake. This results in the server estimating a
+// higher value of the ticket age than the client ultimately provides. Add a
+// small tolerance for variation in ticket age and the ticket will appear to
+// arrive prematurely, causing the server to reject early data.
+TEST_P(TlsConnectTls13, ZeroRttRejectPrematureTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->Set0RttEnabled(true);
+ StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+ PR_Sleep(PR_MillisecondsToInterval(10));
+ Handshake(); // Remainder of handshake
+ CheckConnected();
+ SendReceive();
+ CheckKeys();
+
+ Reset();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ EXPECT_EQ(SECSuccess, SSL_SetupAntiReplay(1, 1, 3));
+ SSLInt_RolloverAntiReplay(); // Make sure to flush replay state.
+ SSLInt_RolloverAntiReplay();
+ ExpectResumption(RESUME_TICKET);
+ ExpectEarlyDataAccepted(false);
+ StartConnect();
+ ZeroRttSendReceive(true, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) {
EnableAlpn();
SetupForZeroRtt();
@@ -117,6 +257,14 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) {
CheckAlpn("a");
}
+// NOTE: In this test and those below, the client always sends
+// post-ServerHello alerts with the handshake keys, even if the server
+// has accepted 0-RTT. In some cases, as with errors in
+// EncryptedExtensions, the client can't know the server's behavior,
+// and in others it's just simpler. What the server is expecting
+// depends on whether it accepted 0-RTT or not. Eventually, we may
+// make the server trial decrypt.
+//
// Have the server negotiate a different ALPN value, and therefore
// reject 0-RTT.
TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeServer) {
@@ -155,12 +303,17 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnServer) {
client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a");
EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, sizeof(b)));
client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
- ExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
return true;
});
- Handshake();
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ } else {
+ client_->Handshake();
+ }
client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID);
- server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
}
// Set up with no ALPN and then set the client so it thinks it has ALPN.
@@ -175,12 +328,17 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnClient) {
PRUint8 b[] = {'b'};
EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, 1));
client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
- ExpectAlert(client_, kTlsAlertIllegalParameter);
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
return true;
});
- Handshake();
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ } else {
+ client_->Handshake();
+ }
client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID);
- server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
}
// Remove the old ALPN value and so the client will not offer early data.
@@ -218,9 +376,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
- client_->StartConnect();
- server_->StartConnect();
-
+ StartConnect();
// We will send the early data xtn without sending actual early data. Thus
// a 1.2 server shouldn't fail until the client sends an alert because the
// client sends end_of_early_data only after reading the server's flight.
@@ -248,6 +404,9 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) {
// The client should abort the connection when sending a 0-rtt handshake but
// the servers responds with a TLS 1.2 ServerHello. (with app data)
TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) {
+ const char* k0RttData = "ABCDEF";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
server_->Set0RttEnabled(true); // set ticket_allow_early_data
Connect();
@@ -261,33 +420,32 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
- client_->StartConnect();
- server_->StartConnect();
-
+ StartConnect();
// Send the early data xtn in the CH, followed by early app data. The server
// will fail right after sending its flight, when receiving the early data.
client_->Set0RttEnabled(true);
- ZeroRttSendReceive(true, false, [this]() {
- client_->ExpectSendAlert(kTlsAlertIllegalParameter);
- if (variant_ == ssl_variant_stream) {
- server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
- }
- return true;
- });
-
- client_->Handshake();
- server_->Handshake();
- ASSERT_TRUE_WAIT(
- (client_->error_code() == SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA), 2000);
+ client_->Handshake(); // Send ClientHello.
+ PRInt32 rv =
+ PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write.
+ EXPECT_EQ(k0RttDataLen, rv);
- // DTLS will timeout as we bump the epoch when installing the early app data
- // cipher suite. Thus the encrypted alert will be ignored.
if (variant_ == ssl_variant_stream) {
- // The server sends an alert when receiving the early app data record.
- ASSERT_TRUE_WAIT(
- (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA),
- 2000);
+ // When the server receives the early data, it will fail.
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ server_->Handshake(); // Consume ClientHello
+ EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state());
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA);
+ } else {
+ // If it's datagram, we just discard the early data.
+ server_->Handshake(); // Consume ClientHello
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
}
+
+ // The client now reads the ServerHello and fails.
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA);
}
static void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent,
@@ -300,17 +458,19 @@ static void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent,
}
TEST_P(TlsConnectTls13, SendTooMuchEarlyData) {
+ EnsureTlsSetup();
const char* big_message = "0123456789abcdef";
const size_t short_size = strlen(big_message) - 1;
const PRInt32 short_length = static_cast<PRInt32>(short_size);
- SSLInt_SetMaxEarlyDataSize(static_cast<PRUint32>(short_size));
+ EXPECT_EQ(SECSuccess,
+ SSL_SetMaxEarlyDataSize(server_->ssl_fd(),
+ static_cast<PRUint32>(short_size)));
SetupForZeroRtt();
client_->Set0RttEnabled(true);
server_->Set0RttEnabled(true);
ExpectResumption(RESUME_TICKET);
- ExpectAlert(client_, kTlsAlertEndOfEarlyData);
client_->Handshake();
CheckEarlyDataLimit(client_, short_size);
@@ -356,18 +516,21 @@ TEST_P(TlsConnectTls13, SendTooMuchEarlyData) {
}
TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) {
+ EnsureTlsSetup();
+
const size_t limit = 5;
- SSLInt_SetMaxEarlyDataSize(limit);
+ EXPECT_EQ(SECSuccess, SSL_SetMaxEarlyDataSize(server_->ssl_fd(), limit));
SetupForZeroRtt();
client_->Set0RttEnabled(true);
server_->Set0RttEnabled(true);
ExpectResumption(RESUME_TICKET);
- client_->ExpectSendAlert(kTlsAlertEndOfEarlyData);
client_->Handshake(); // Send ClientHello
CheckEarlyDataLimit(client_, limit);
+ server_->Handshake(); // Process ClientHello, send server flight.
+
// Lift the limit on the client.
EXPECT_EQ(SECSuccess,
SSLInt_SetSocketMaxEarlyDataSize(client_->ssl_fd(), 1000));
@@ -381,22 +544,114 @@ TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) {
// This error isn't fatal for DTLS.
ExpectAlert(server_, kTlsAlertUnexpectedMessage);
}
- server_->Handshake(); // Process ClientHello, send server flight.
- server_->Handshake(); // Just to make sure that we don't read ahead.
+
+ server_->Handshake(); // This reads the early data and maybe throws an error.
+ if (variant_ == ssl_variant_stream) {
+ server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA);
+ } else {
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+ }
CheckEarlyDataLimit(server_, limit);
- // Attempt to read early data.
+ // Attempt to read early data. This will get an error.
std::vector<uint8_t> buf(strlen(message) + 1);
EXPECT_GT(0, PR_Read(server_->ssl_fd(), buf.data(), buf.capacity()));
if (variant_ == ssl_variant_stream) {
- server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA);
+ EXPECT_EQ(SSL_ERROR_HANDSHAKE_FAILED, PORT_GetError());
+ } else {
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
}
- client_->Handshake(); // Process the handshake.
- client_->Handshake(); // Process the alert.
+ client_->Handshake(); // Process the server's first flight.
if (variant_ == ssl_variant_stream) {
+ client_->Handshake(); // Process the alert.
client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ } else {
+ server_->Handshake(); // Finish connecting.
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
}
}
+class PacketCoalesceFilter : public PacketFilter {
+ public:
+ PacketCoalesceFilter() : packet_data_() {}
+
+ void SendCoalesced(std::shared_ptr<TlsAgent> agent) {
+ agent->SendDirect(packet_data_);
+ }
+
+ protected:
+ PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override {
+ packet_data_.Write(packet_data_.len(), input);
+ return DROP;
+ }
+
+ private:
+ DataBuffer packet_data_;
+};
+
+TEST_P(TlsConnectTls13, ZeroRttOrdering) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ // Send out the ClientHello.
+ client_->Handshake();
+
+ // Now, coalesce the next three things from the client: early data, second
+ // flight and 1-RTT data.
+ auto coalesce = std::make_shared<PacketCoalesceFilter>();
+ client_->SetFilter(coalesce);
+
+ // Send (and hold) early data.
+ static const std::vector<uint8_t> early_data = {3, 2, 1};
+ EXPECT_EQ(static_cast<PRInt32>(early_data.size()),
+ PR_Write(client_->ssl_fd(), early_data.data(), early_data.size()));
+
+ // Send (and hold) the second client handshake flight.
+ // The client sends EndOfEarlyData after seeing the server Finished.
+ server_->Handshake();
+ client_->Handshake();
+
+ // Send (and hold) 1-RTT data.
+ static const std::vector<uint8_t> late_data = {7, 8, 9, 10};
+ EXPECT_EQ(static_cast<PRInt32>(late_data.size()),
+ PR_Write(client_->ssl_fd(), late_data.data(), late_data.size()));
+
+ // Now release them all at once.
+ coalesce->SendCoalesced(client_);
+
+ // Now ensure that the three steps are exposed in the right order on the
+ // server: delivery of early data, handshake callback, delivery of 1-RTT.
+ size_t step = 0;
+ server_->SetHandshakeCallback([&step](TlsAgent*) {
+ EXPECT_EQ(1U, step);
+ ++step;
+ });
+
+ std::vector<uint8_t> buf(10);
+ PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.size());
+ ASSERT_EQ(static_cast<PRInt32>(early_data.size()), read);
+ buf.resize(read);
+ EXPECT_EQ(early_data, buf);
+ EXPECT_EQ(0U, step);
+ ++step;
+
+ // The third read should be after the handshake callback and should return the
+ // data that was sent after the handshake completed.
+ buf.resize(10);
+ read = PR_Read(server_->ssl_fd(), buf.data(), buf.size());
+ ASSERT_EQ(static_cast<PRInt32>(late_data.size()), read);
+ buf.resize(read);
+ EXPECT_EQ(late_data, buf);
+ EXPECT_EQ(2U, step);
+}
+
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_CASE_P(Tls13ZeroRttReplayTest, TlsZeroRttReplayTest,
+ TlsConnectTestBase::kTlsVariantsAll);
+#endif
+
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
index 5035a338d..f0c57e8b1 100644
--- a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
@@ -31,7 +31,7 @@ const static uint8_t kCannedTls13ClientHello[] = {
0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x09, 0x00, 0x00, 0x06,
0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00,
0x0a, 0x00, 0x12, 0x00, 0x10, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x01,
- 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x28, 0x00,
+ 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x33, 0x00,
0x47, 0x00, 0x45, 0x00, 0x17, 0x00, 0x41, 0x04, 0x86, 0x4a, 0xb9, 0xdc,
0x6a, 0x38, 0xa7, 0xce, 0xe7, 0xc2, 0x4f, 0xa6, 0x28, 0xb9, 0xdc, 0x65,
0xbf, 0x73, 0x47, 0x3c, 0x9c, 0x65, 0x8c, 0x47, 0x6d, 0x57, 0x22, 0x8a,
@@ -44,13 +44,14 @@ const static uint8_t kCannedTls13ClientHello[] = {
0x02, 0x05, 0x02, 0x06, 0x02, 0x02, 0x02};
const static uint8_t kCannedTls13ServerHello[] = {
- 0x7f, kD13, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, 0xf0,
- 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, 0xdf, 0xe5,
- 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, 0x08, 0x13, 0x01,
- 0x00, 0x28, 0x00, 0x28, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf,
- 0x23, 0x17, 0x64, 0x23, 0x03, 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65,
- 0x24, 0xa1, 0x6c, 0xa9, 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a,
- 0xcb, 0xe3, 0x08, 0x84, 0xae, 0x19};
+ 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3,
+ 0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b,
+ 0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76,
+ 0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24,
+ 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03,
+ 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9,
+ 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08,
+ 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x7f, kD13};
static const char *k0RttData = "ABCDEF";
TEST_P(TlsAgentTest, EarlyFinished) {
@@ -159,9 +160,8 @@ TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) {
SSL_LIBRARY_VERSION_TLS_1_3);
agent_->StartConnect();
agent_->Set0RttEnabled(true);
- auto filter = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeClientHello);
- agent_->SetPacketFilter(filter);
+ auto filter =
+ MakeTlsFilter<TlsHandshakeRecorder>(agent_, kTlsHandshakeClientHello);
PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData));
EXPECT_EQ(-1, rv);
int32_t err = PORT_GetError();
diff --git a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
index dbcbc9aa3..7f2b2840d 100644
--- a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
@@ -29,7 +29,25 @@ TEST_P(TlsConnectGeneric, ServerAuthBigRsa) {
}
TEST_P(TlsConnectGeneric, ServerAuthRsaChain) {
- Reset(TlsAgent::kServerRsaChain);
+ Reset("rsa_chain");
+ Connect();
+ CheckKeys();
+ size_t chain_length;
+ EXPECT_TRUE(client_->GetPeerChainLength(&chain_length));
+ EXPECT_EQ(2UL, chain_length);
+}
+
+TEST_P(TlsConnectGeneric, ServerAuthRsaPssChain) {
+ Reset("rsa_pss_chain");
+ Connect();
+ CheckKeys();
+ size_t chain_length;
+ EXPECT_TRUE(client_->GetPeerChainLength(&chain_length));
+ EXPECT_EQ(2UL, chain_length);
+}
+
+TEST_P(TlsConnectGeneric, ServerAuthRsaCARsaPssChain) {
+ Reset("rsa_ca_rsa_pss_chain");
Connect();
CheckKeys();
size_t chain_length;
@@ -77,10 +95,9 @@ TEST_P(TlsConnectGeneric, ClientAuthBigRsa) {
}
// Offset is the position in the captured buffer where the signature sits.
-static void CheckSigScheme(
- std::shared_ptr<TlsInspectorRecordHandshakeMessage>& capture, size_t offset,
- std::shared_ptr<TlsAgent>& peer, uint16_t expected_scheme,
- size_t expected_size) {
+static void CheckSigScheme(std::shared_ptr<TlsHandshakeRecorder>& capture,
+ size_t offset, std::shared_ptr<TlsAgent>& peer,
+ uint16_t expected_scheme, size_t expected_size) {
EXPECT_LT(offset + 2U, capture->buffer().len());
uint32_t scheme = 0;
@@ -96,9 +113,8 @@ static void CheckSigScheme(
// in the default certificate.
TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
EnsureTlsSetup();
- auto capture_ske = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(capture_ske);
+ auto capture_ske = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
Connect();
CheckKeys();
@@ -109,15 +125,14 @@ TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
EXPECT_TRUE(buffer.Read(1, 2, &tmp)) << "read NamedCurve";
EXPECT_EQ(ssl_grp_ec_curve25519, tmp);
EXPECT_TRUE(buffer.Read(3, 1, &tmp)) << " read ECPoint";
- CheckSigScheme(capture_ske, 4 + tmp, client_, ssl_sig_rsa_pss_sha256, 1024);
+ CheckSigScheme(capture_ske, 4 + tmp, client_, ssl_sig_rsa_pss_rsae_sha256,
+ 1024);
}
TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
EnsureTlsSetup();
- auto capture_cert_verify =
- std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeCertificateVerify);
- client_->SetPacketFilter(capture_cert_verify);
+ auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
Connect();
@@ -128,26 +143,23 @@ TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) {
Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048);
- auto capture_cert_verify =
- std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeCertificateVerify);
- client_->SetPacketFilter(capture_cert_verify);
+ auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
Connect();
CheckKeys();
- CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_sha256, 2048);
+ CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_rsae_sha256,
+ 2048);
}
class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter {
public:
+ TlsZeroCertificateRequestSigAlgsFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeCertificateRequest}) {}
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
- if (header.handshake_type() != kTlsHandshakeCertificateRequest) {
- return KEEP;
- }
-
TlsParser parser(input);
std::cerr << "Zeroing CertReq.supported_signature_algorithms" << std::endl;
@@ -189,12 +201,9 @@ class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter {
// supported_signature_algorithms in the CertificateRequest message.
TEST_P(TlsConnectTls12, ClientAuthNoSigAlgsFallback) {
EnsureTlsSetup();
- auto filter = std::make_shared<TlsZeroCertificateRequestSigAlgsFilter>();
- server_->SetPacketFilter(filter);
- auto capture_cert_verify =
- std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeCertificateVerify);
- client_->SetPacketFilter(capture_cert_verify);
+ MakeTlsFilter<TlsZeroCertificateRequestSigAlgsFilter>(server_);
+ auto capture_cert_verify = MakeTlsFilter<TlsHandshakeRecorder>(
+ client_, kTlsHandshakeCertificateVerify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
@@ -342,8 +351,7 @@ TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) {
// The signature_algorithms extension is mandatory in TLS 1.3.
TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn));
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_signature_algorithms_xtn);
ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION);
@@ -352,8 +360,7 @@ TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
// TLS 1.2 has trouble detecting this sort of modification: it uses SHA1 and
// only fails when the Finished is checked.
TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) {
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn));
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_signature_algorithms_xtn);
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
@@ -371,11 +378,11 @@ class BeforeFinished : public TlsRecordFilter {
enum HandshakeState { BEFORE_CCS, AFTER_CCS, DONE };
public:
- BeforeFinished(std::shared_ptr<TlsAgent>& client,
- std::shared_ptr<TlsAgent>& server, VoidFunction before_ccs,
- VoidFunction before_finished)
- : client_(client),
- server_(server),
+ BeforeFinished(const std::shared_ptr<TlsAgent>& server,
+ const std::shared_ptr<TlsAgent>& client,
+ VoidFunction before_ccs, VoidFunction before_finished)
+ : TlsRecordFilter(server),
+ client_(client),
before_ccs_(before_ccs),
before_finished_(before_finished),
state_(BEFORE_CCS) {}
@@ -395,7 +402,7 @@ class BeforeFinished : public TlsRecordFilter {
// but that means that they both get processed together.
DataBuffer ccs;
header.Write(&ccs, 0, body);
- server_.lock()->SendDirect(ccs);
+ agent()->SendDirect(ccs);
client_.lock()->Handshake();
state_ = AFTER_CCS;
// Request that the original record be dropped by the filter.
@@ -420,7 +427,6 @@ class BeforeFinished : public TlsRecordFilter {
private:
std::weak_ptr<TlsAgent> client_;
- std::weak_ptr<TlsAgent> server_;
VoidFunction before_ccs_;
VoidFunction before_finished_;
HandshakeState state_;
@@ -445,11 +451,11 @@ class BeforeFinished13 : public PacketFilter {
};
public:
- BeforeFinished13(std::shared_ptr<TlsAgent>& client,
- std::shared_ptr<TlsAgent>& server,
+ BeforeFinished13(const std::shared_ptr<TlsAgent>& server,
+ const std::shared_ptr<TlsAgent>& client,
VoidFunction before_finished)
- : client_(client),
- server_(server),
+ : server_(server),
+ client_(client),
before_finished_(before_finished),
records_(0) {}
@@ -481,8 +487,8 @@ class BeforeFinished13 : public PacketFilter {
}
private:
- std::weak_ptr<TlsAgent> client_;
std::weak_ptr<TlsAgent> server_;
+ std::weak_ptr<TlsAgent> client_;
VoidFunction before_finished_;
size_t records_;
};
@@ -496,11 +502,9 @@ static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) {
// processed by the client, SSL_AuthCertificateComplete() is called.
TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) {
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->SetPacketFilter(
- std::make_shared<BeforeFinished13>(client_, server_, [this]() {
- EXPECT_EQ(SECSuccess,
- SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
- }));
+ MakeTlsFilter<BeforeFinished13>(server_, client_, [this]() {
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ });
Connect();
}
@@ -528,13 +532,13 @@ TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) {
TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
client_->EnableFalseStart();
- server_->SetPacketFilter(std::make_shared<BeforeFinished>(
- client_, server_,
+ MakeTlsFilter<BeforeFinished>(
+ server_, client_,
[this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); },
[this]() {
// Write something, which used to fail: bug 1235366.
client_->SendData(10);
- }));
+ });
Connect();
server_->SendData(10);
@@ -544,8 +548,8 @@ TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) {
client_->EnableFalseStart();
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->SetPacketFilter(std::make_shared<BeforeFinished>(
- client_, server_,
+ MakeTlsFilter<BeforeFinished>(
+ server_, client_,
[]() {
// Do nothing before CCS
},
@@ -556,7 +560,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) {
SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
EXPECT_TRUE(client_->can_falsestart_hook_called());
client_->SendData(10);
- }));
+ });
Connect();
server_->SendData(10);
@@ -581,8 +585,7 @@ class EnforceNoActivity : public PacketFilter {
TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
client_->Handshake(); // Send ClientHello
server_->Handshake(); // Send ServerHello
client_->Handshake(); // Send ClientKeyExchange and Finished
@@ -591,7 +594,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
// The client should send nothing from here on.
- client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
@@ -601,8 +604,33 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
- // Remove this before closing or the close_notify alert will trigger it.
- client_->DeletePacketFilter();
+ // Remove filter before closing or the close_notify alert will trigger it.
+ client_->ClearFilter();
+}
+
+TEST_P(TlsConnectGenericPre13, AuthCompleteFailDelayed) {
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send ServerHello
+ client_->Handshake(); // Send ClientKeyExchange and Finished
+ server_->Handshake(); // Send Finished
+ // The server should now report that it is connected
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ // The client should send nothing from here on.
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+
+ // Report failure.
+ client_->ClearFilter();
+ client_->ExpectSendAlert(kTlsAlertBadCertificate);
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(),
+ SSL_ERROR_BAD_CERTIFICATE));
+ client_->Handshake(); // Fail
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
}
// TLS 1.3 handles a delayed AuthComplete callback differently since the
@@ -610,20 +638,19 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
TEST_P(TlsConnectTls13, AuthCompleteDelayed) {
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
client_->Handshake(); // Send ClientHello
server_->Handshake(); // Send ServerHello
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
// The client will send nothing until AuthCertificateComplete is called.
- client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
// This should allow the handshake to complete now.
- client_->DeletePacketFilter();
+ client_->ClearFilter();
EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
client_->Handshake(); // Send Finished
server_->Handshake(); // Transition to connected and send NewSessionTicket
@@ -631,6 +658,44 @@ TEST_P(TlsConnectTls13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
}
+TEST_P(TlsConnectTls13, AuthCompleteFailDelayed) {
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+
+ StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send ServerHello
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+
+ // The client will send nothing until AuthCertificateComplete is called.
+ client_->SetFilter(std::make_shared<EnforceNoActivity>());
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+
+ // Report failure.
+ client_->ClearFilter();
+ ExpectAlert(client_, kTlsAlertBadCertificate);
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(),
+ SSL_ERROR_BAD_CERTIFICATE));
+ client_->Handshake(); // This should now fail.
+ server_->Handshake(); // Get the error.
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state());
+}
+
+static SECStatus AuthCompleteFail(TlsAgent*, PRBool, PRBool) {
+ PORT_SetError(SSL_ERROR_BAD_CERTIFICATE);
+ return SECFailure;
+}
+
+TEST_P(TlsConnectGeneric, AuthFailImmediate) {
+ client_->SetAuthCertificateCallback(AuthCompleteFail);
+
+ StartConnect();
+ ConnectExpectAlert(client_, kTlsAlertBadCertificate);
+ client_->CheckErrorCode(SSL_ERROR_BAD_CERTIFICATE);
+}
+
static const SSLExtraServerCertData ServerCertDataRsaPkcs1Decrypt = {
ssl_auth_rsa_decrypt, nullptr, nullptr, nullptr};
static const SSLExtraServerCertData ServerCertDataRsaPkcs1Sign = {
@@ -753,8 +818,7 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) {
TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) {
Reset(certificate_);
auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn);
- client_->SetPacketFilter(capture);
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn);
TestSignatureSchemeConfig(client_);
const DataBuffer& ext = capture->extension();
@@ -782,8 +846,8 @@ INSTANTIATE_TEST_CASE_P(
::testing::Values(TlsAgent::kServerRsaSign),
::testing::Values(ssl_auth_rsa_sign),
::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384,
- ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_sha256,
- ssl_sig_rsa_pss_sha384)));
+ ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_rsae_sha256,
+ ssl_sig_rsa_pss_rsae_sha384)));
// PSS with SHA-512 needs a bigger key to work.
INSTANTIATE_TEST_CASE_P(
SignatureSchemeBigRsa, TlsSignatureSchemeConfiguration,
@@ -791,7 +855,7 @@ INSTANTIATE_TEST_CASE_P(
TlsConnectTestBase::kTlsV12Plus,
::testing::Values(TlsAgent::kRsa2048),
::testing::Values(ssl_auth_rsa_sign),
- ::testing::Values(ssl_sig_rsa_pss_sha512)));
+ ::testing::Values(ssl_sig_rsa_pss_rsae_sha512)));
INSTANTIATE_TEST_CASE_P(
SignatureSchemeRsaSha1, TlsSignatureSchemeConfiguration,
::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
@@ -828,4 +892,4 @@ INSTANTIATE_TEST_CASE_P(
TlsAgent::kServerEcdsa384),
::testing::Values(ssl_auth_ecdsa),
::testing::Values(ssl_sig_ecdsa_sha1)));
-}
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
index 3463782e0..573c69c75 100644
--- a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
@@ -82,9 +82,8 @@ TEST_P(TlsConnectGenericPre13, SignedCertificateTimestampsLegacy) {
ssl_kea_rsa));
EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(),
&kSctItem, ssl_kea_rsa));
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS,
- PR_TRUE));
+
+ client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
SignedCertificateTimestampsExtractor timestamps_extractor(client_);
Connect();
@@ -96,9 +95,7 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsSuccess) {
EnsureTlsSetup();
EXPECT_TRUE(
server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData));
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS,
- PR_TRUE));
+ client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
SignedCertificateTimestampsExtractor timestamps_extractor(client_);
Connect();
@@ -120,9 +117,7 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveClient) {
TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveServer) {
EnsureTlsSetup();
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS,
- PR_TRUE));
+ client_->SetOption(SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
SignedCertificateTimestampsExtractor timestamps_extractor(client_);
Connect();
@@ -173,23 +168,20 @@ TEST_P(TlsConnectGeneric, OcspNotRequested) {
// Even if the client asks, the server has nothing unless it is configured.
TEST_P(TlsConnectGeneric, OcspNotProvided) {
EnsureTlsSetup();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
client_->SetAuthCertificateCallback(CheckNoOCSP);
Connect();
}
TEST_P(TlsConnectGenericPre13, OcspMangled) {
EnsureTlsSetup();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
EXPECT_TRUE(
server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData));
static const uint8_t val[] = {1};
- auto replacer = std::make_shared<TlsExtensionReplacer>(
- ssl_cert_status_xtn, DataBuffer(val, sizeof(val)));
- server_->SetPacketFilter(replacer);
+ auto replacer = MakeTlsFilter<TlsExtensionReplacer>(
+ server_, ssl_cert_status_xtn, DataBuffer(val, sizeof(val)));
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
@@ -197,11 +189,9 @@ TEST_P(TlsConnectGenericPre13, OcspMangled) {
TEST_P(TlsConnectGeneric, OcspSuccess) {
EnsureTlsSetup();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
auto capture_ocsp =
- std::make_shared<TlsExtensionCapture>(ssl_cert_status_xtn);
- server_->SetPacketFilter(capture_ocsp);
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_cert_status_xtn);
// The value should be available during the AuthCertificateCallback
client_->SetAuthCertificateCallback([](TlsAgent* agent, bool checksig,
@@ -225,8 +215,7 @@ TEST_P(TlsConnectGeneric, OcspSuccess) {
TEST_P(TlsConnectGeneric, OcspHugeSuccess) {
EnsureTlsSetup();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+ client_->SetOption(SSL_ENABLE_OCSP_STAPLING, PR_TRUE);
uint8_t hugeOcspValue[16385];
memset(hugeOcspValue, 0xa1, sizeof(hugeOcspValue));
@@ -254,4 +243,4 @@ TEST_P(TlsConnectGeneric, OcspHugeSuccess) {
Connect();
}
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
index 85c30b2bf..fa2238be7 100644
--- a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
@@ -31,11 +31,11 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase {
public:
TlsCipherSuiteTestBase(SSLProtocolVariant variant, uint16_t version,
uint16_t cipher_suite, SSLNamedGroup group,
- SSLSignatureScheme signature_scheme)
+ SSLSignatureScheme sig_scheme)
: TlsConnectTestBase(variant, version),
cipher_suite_(cipher_suite),
group_(group),
- signature_scheme_(signature_scheme),
+ sig_scheme_(sig_scheme),
csinfo_({0}) {
SECStatus rv =
SSL_GetCipherSuiteInfo(cipher_suite_, &csinfo_, sizeof(csinfo_));
@@ -60,26 +60,26 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase {
server_->ConfigNamedGroups(groups);
kea_type_ = SSLInt_GetKEAType(group_);
- client_->SetSignatureSchemes(&signature_scheme_, 1);
- server_->SetSignatureSchemes(&signature_scheme_, 1);
+ client_->SetSignatureSchemes(&sig_scheme_, 1);
+ server_->SetSignatureSchemes(&sig_scheme_, 1);
}
}
virtual void SetupCertificate() {
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
- switch (signature_scheme_) {
+ switch (sig_scheme_) {
case ssl_sig_rsa_pkcs1_sha256:
case ssl_sig_rsa_pkcs1_sha384:
case ssl_sig_rsa_pkcs1_sha512:
Reset(TlsAgent::kServerRsaSign);
auth_type_ = ssl_auth_rsa_sign;
break;
- case ssl_sig_rsa_pss_sha256:
- case ssl_sig_rsa_pss_sha384:
+ case ssl_sig_rsa_pss_rsae_sha256:
+ case ssl_sig_rsa_pss_rsae_sha384:
Reset(TlsAgent::kServerRsaSign);
auth_type_ = ssl_auth_rsa_sign;
break;
- case ssl_sig_rsa_pss_sha512:
+ case ssl_sig_rsa_pss_rsae_sha512:
// You can't fit SHA-512 PSS in a 1024-bit key.
Reset(TlsAgent::kRsa2048);
auth_type_ = ssl_auth_rsa_sign;
@@ -93,8 +93,7 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase {
auth_type_ = ssl_auth_ecdsa;
break;
default:
- ASSERT_TRUE(false) << "Unsupported signature scheme: "
- << signature_scheme_;
+ ADD_FAILURE() << "Unsupported signature scheme: " << sig_scheme_;
break;
}
} else {
@@ -187,7 +186,7 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase {
SSLAuthType auth_type_;
SSLKEAType kea_type_;
SSLNamedGroup group_;
- SSLSignatureScheme signature_scheme_;
+ SSLSignatureScheme sig_scheme_;
SSLCipherSuiteInfo csinfo_;
};
@@ -236,27 +235,29 @@ TEST_P(TlsCipherSuiteTest, ResumeCipherSuite) {
ConnectAndCheckCipherSuite();
}
-// This only works for stream ciphers because we modify the sequence number -
-// which is included explicitly in the DTLS record header - and that trips a
-// different error code. Note that the message that the client sends would not
-// decrypt (the nonce/IV wouldn't match), but the record limit is hit before
-// attempting to decrypt a record.
TEST_P(TlsCipherSuiteTest, ReadLimit) {
SetupCertificate();
EnableSingleCipher();
ConnectAndCheckCipherSuite();
- EXPECT_EQ(SECSuccess,
- SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last_safe_write()));
- EXPECT_EQ(SECSuccess,
- SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last_safe_write()));
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ uint64_t last = last_safe_write();
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last));
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last));
- client_->SendData(10, 10);
- server_->ReadBytes(); // This should be OK.
+ client_->SendData(10, 10);
+ server_->ReadBytes(); // This should be OK.
+ } else {
+ // In TLS 1.3, reading or writing triggers a KeyUpdate. That would mean
+ // that the sequence numbers would reset and we wouldn't hit the limit. So
+ // we move the sequence number to one less than the limit directly and don't
+ // test sending and receiving just before the limit.
+ uint64_t last = record_limit() - 1;
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last));
+ }
- // The payload needs to be big enough to pass for encrypted. In the extreme
- // case (TLS 1.3), this means 1 for payload, 1 for content type and 16 for
- // authentication tag.
- static const uint8_t payload[18] = {6};
+ // The payload needs to be big enough to pass for encrypted. The code checks
+ // the limit before it tries to decrypt.
+ static const uint8_t payload[32] = {6};
DataBuffer record;
uint64_t epoch;
if (variant_ == ssl_variant_datagram) {
@@ -271,13 +272,17 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) {
TlsAgentTestBase::MakeRecord(variant_, kTlsApplicationDataType, version_,
payload, sizeof(payload), &record,
(epoch << 48) | record_limit());
- server_->adapter()->PacketReceived(record);
+ client_->SendDirect(record);
server_->ExpectReadWriteError();
server_->ReadBytes();
EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, server_->error_code());
}
TEST_P(TlsCipherSuiteTest, WriteLimit) {
+ // This asserts in TLS 1.3 because we expect an automatic update.
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ return;
+ }
SetupCertificate();
EnableSingleCipher();
ConnectAndCheckCipherSuite();
@@ -308,8 +313,8 @@ static const auto kDummySignatureSchemesParams =
static SSLSignatureScheme kSignatureSchemesParamsArr[] = {
ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384,
ssl_sig_rsa_pkcs1_sha512, ssl_sig_ecdsa_secp256r1_sha256,
- ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_rsa_pss_sha256,
- ssl_sig_rsa_pss_sha384, ssl_sig_rsa_pss_sha512,
+ ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_rsa_pss_rsae_sha256,
+ ssl_sig_rsa_pss_rsae_sha384, ssl_sig_rsa_pss_rsae_sha512,
};
#endif
@@ -461,4 +466,4 @@ static const SecStatusParams kSecStatusTestValuesArr[] = {
INSTANTIATE_TEST_CASE_P(TestSecurityStatus, SecurityStatusTest,
::testing::ValuesIn(kSecStatusTestValuesArr));
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc
new file mode 100644
index 000000000..c2f582a93
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_custext_unittest.cc
@@ -0,0 +1,498 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "ssl3prot.h"
+#include "sslerr.h"
+#include "sslproto.h"
+#include "sslexp.h"
+
+#include <memory>
+
+#include "tls_connect.h"
+
+namespace nss_test {
+
+static void IncrementCounterArg(void *arg) {
+ if (arg) {
+ auto *called = reinterpret_cast<size_t *>(arg);
+ ++*called;
+ }
+}
+
+PRBool NoopExtensionWriter(PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) {
+ IncrementCounterArg(arg);
+ return PR_FALSE;
+}
+
+PRBool EmptyExtensionWriter(PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) {
+ IncrementCounterArg(arg);
+ return PR_TRUE;
+}
+
+SECStatus NoopExtensionHandler(PRFileDesc *fd, SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert, void *arg) {
+ return SECSuccess;
+}
+
+// All of the (current) set of supported extensions, plus a few extra.
+static const uint16_t kManyExtensions[] = {
+ ssl_server_name_xtn,
+ ssl_cert_status_xtn,
+ ssl_supported_groups_xtn,
+ ssl_ec_point_formats_xtn,
+ ssl_signature_algorithms_xtn,
+ ssl_signature_algorithms_cert_xtn,
+ ssl_use_srtp_xtn,
+ ssl_app_layer_protocol_xtn,
+ ssl_signed_cert_timestamp_xtn,
+ ssl_padding_xtn,
+ ssl_extended_master_secret_xtn,
+ ssl_session_ticket_xtn,
+ ssl_tls13_key_share_xtn,
+ ssl_tls13_pre_shared_key_xtn,
+ ssl_tls13_early_data_xtn,
+ ssl_tls13_supported_versions_xtn,
+ ssl_tls13_cookie_xtn,
+ ssl_tls13_psk_key_exchange_modes_xtn,
+ ssl_tls13_ticket_early_data_info_xtn,
+ ssl_tls13_certificate_authorities_xtn,
+ ssl_next_proto_nego_xtn,
+ ssl_renegotiation_info_xtn,
+ ssl_tls13_short_header_xtn,
+ 1,
+ 0xffff};
+// The list here includes all extensions we expect to use (SSL_MAX_EXTENSIONS),
+// plus the deprecated values (see sslt.h), and two extra dummy values.
+PR_STATIC_ASSERT((SSL_MAX_EXTENSIONS + 5) == PR_ARRAY_SIZE(kManyExtensions));
+
+void InstallManyWriters(std::shared_ptr<TlsAgent> agent,
+ SSLExtensionWriter writer, size_t *installed = nullptr,
+ size_t *called = nullptr) {
+ for (size_t i = 0; i < PR_ARRAY_SIZE(kManyExtensions); ++i) {
+ SSLExtensionSupport support = ssl_ext_none;
+ SECStatus rv = SSL_GetExtensionSupport(kManyExtensions[i], &support);
+ ASSERT_EQ(SECSuccess, rv) << "SSL_GetExtensionSupport cannot fail";
+
+ rv = SSL_InstallExtensionHooks(agent->ssl_fd(), kManyExtensions[i], writer,
+ called, NoopExtensionHandler, nullptr);
+ if (support == ssl_ext_native_only) {
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ } else {
+ if (installed) {
+ ++*installed;
+ }
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ }
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopClient) {
+ EnsureTlsSetup();
+ size_t installed = 0;
+ size_t called = 0;
+ InstallManyWriters(client_, NoopExtensionWriter, &installed, &called);
+ EXPECT_LT(0U, installed);
+ Connect();
+ EXPECT_EQ(installed, called);
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionAllNoopServer) {
+ EnsureTlsSetup();
+ size_t installed = 0;
+ size_t called = 0;
+ InstallManyWriters(server_, NoopExtensionWriter, &installed, &called);
+ EXPECT_LT(0U, installed);
+ Connect();
+ // Extension writers are all called for each of ServerHello,
+ // EncryptedExtensions, and Certificate.
+ EXPECT_EQ(installed * 3, called);
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterClient) {
+ EnsureTlsSetup();
+ InstallManyWriters(client_, EmptyExtensionWriter);
+ InstallManyWriters(server_, EmptyExtensionWriter);
+ Connect();
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionEmptyWriterServer) {
+ EnsureTlsSetup();
+ InstallManyWriters(server_, EmptyExtensionWriter);
+ // Sending extensions that the client doesn't expect leads to extensions
+ // appearing even if the client didn't send one, or in the wrong messages.
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+}
+
+// Install an writer to disable sending of a natively-supported extension.
+TEST_F(TlsConnectStreamTls13, CustomExtensionWriterDisable) {
+ EnsureTlsSetup();
+
+ // This option enables sending the extension via the native support.
+ SECStatus rv = SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // This installs an override that doesn't do anything. You have to specify
+ // something; passing all nullptr values removes an existing handler.
+ rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NoopExtensionWriter,
+ nullptr, NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_signed_cert_timestamp_xtn);
+
+ Connect();
+ // So nothing will be sent.
+ EXPECT_FALSE(capture->captured());
+}
+
+// An extension that is unlikely to be parsed as valid.
+static uint8_t kNonsenseExtension[] = {91, 82, 73, 64, 55, 46, 37, 28, 19};
+
+static PRBool NonsenseExtensionWriter(PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) {
+ TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg);
+ EXPECT_NE(nullptr, agent);
+ EXPECT_NE(nullptr, data);
+ EXPECT_NE(nullptr, len);
+ EXPECT_EQ(0U, *len);
+ EXPECT_LT(0U, maxLen);
+ EXPECT_EQ(agent->ssl_fd(), fd);
+
+ if (message != ssl_hs_client_hello && message != ssl_hs_server_hello &&
+ message != ssl_hs_encrypted_extensions) {
+ return PR_FALSE;
+ }
+
+ *len = static_cast<unsigned int>(sizeof(kNonsenseExtension));
+ EXPECT_GE(maxLen, *len);
+ if (maxLen < *len) {
+ return PR_FALSE;
+ }
+ PORT_Memcpy(data, kNonsenseExtension, *len);
+ return PR_TRUE;
+}
+
+// Override the extension handler for an natively-supported and produce
+// nonsense, which results in a handshake failure.
+TEST_F(TlsConnectStreamTls13, CustomExtensionOverride) {
+ EnsureTlsSetup();
+
+ // This option enables sending the extension via the native support.
+ SECStatus rv = SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // This installs an override that sends nonsense.
+ rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), ssl_signed_cert_timestamp_xtn, NonsenseExtensionWriter,
+ client_.get(), NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture it to see what we got.
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_signed_cert_timestamp_xtn);
+
+ ConnectExpectAlert(server_, kTlsAlertDecodeError);
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+static SECStatus NonsenseExtensionHandler(PRFileDesc *fd,
+ SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert,
+ void *arg) {
+ TlsAgent *agent = reinterpret_cast<TlsAgent *>(arg);
+ EXPECT_EQ(agent->ssl_fd(), fd);
+ if (agent->role() == TlsAgent::SERVER) {
+ EXPECT_EQ(ssl_hs_client_hello, message);
+ } else {
+ EXPECT_TRUE(message == ssl_hs_server_hello ||
+ message == ssl_hs_encrypted_extensions);
+ }
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ DataBuffer(data, len));
+ EXPECT_NE(nullptr, alert);
+ return SECSuccess;
+}
+
+// Send nonsense in an extension from client to server.
+TEST_F(TlsConnectStreamTls13, CustomExtensionClientToServer) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nonsense.
+ const uint16_t extension_code = 0xffe5;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), extension_code, NonsenseExtensionWriter, client_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture it to see what we got.
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, extension_code);
+
+ // Handle it so that the handshake completes.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NoopExtensionWriter, nullptr,
+ NonsenseExtensionHandler, server_.get());
+ EXPECT_EQ(SECSuccess, rv);
+
+ Connect();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+static PRBool NonsenseExtensionWriterSH(PRFileDesc *fd,
+ SSLHandshakeType message, PRUint8 *data,
+ unsigned int *len, unsigned int maxLen,
+ void *arg) {
+ if (message == ssl_hs_server_hello) {
+ return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg);
+ }
+ return PR_FALSE;
+}
+
+// Send nonsense in an extension from server to client, in ServerHello.
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientSH) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff5e;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr,
+ NonsenseExtensionHandler, client_.get());
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NonsenseExtensionWriterSH, server_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture the extension from the ServerHello only and check it.
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code);
+ capture->SetHandshakeTypes({kTlsHandshakeServerHello});
+
+ Connect();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+static PRBool NonsenseExtensionWriterEE(PRFileDesc *fd,
+ SSLHandshakeType message, PRUint8 *data,
+ unsigned int *len, unsigned int maxLen,
+ void *arg) {
+ if (message == ssl_hs_encrypted_extensions) {
+ return NonsenseExtensionWriter(fd, message, data, len, maxLen, arg);
+ }
+ return PR_FALSE;
+}
+
+// Send nonsense in an extension from server to client, in EncryptedExtensions.
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerToClientEE) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff5e;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ client_->ssl_fd(), extension_code, EmptyExtensionWriter, nullptr,
+ NonsenseExtensionHandler, client_.get());
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NonsenseExtensionWriterEE, server_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture the extension from the EncryptedExtensions only and check it.
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code);
+ capture->SetHandshakeTypes({kTlsHandshakeEncryptedExtensions});
+ capture->EnableDecryption();
+
+ Connect();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionUnsolicitedServer) {
+ EnsureTlsSetup();
+
+ const uint16_t extension_code = 0xff5e;
+ SECStatus rv = SSL_InstallExtensionHooks(
+ server_->ssl_fd(), extension_code, NonsenseExtensionWriter, server_.get(),
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Capture it to see what we got.
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(server_, extension_code);
+
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(DataBuffer(kNonsenseExtension, sizeof(kNonsenseExtension)),
+ capture->extension());
+}
+
+SECStatus RejectExtensionHandler(PRFileDesc *fd, SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert, void *arg) {
+ return SECFailure;
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerReject) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nonsense.
+ const uint16_t extension_code = 0xffe7;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Reject the extension for no good reason.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NoopExtensionWriter, nullptr,
+ RejectExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+}
+
+// Send nonsense in an extension from client to server.
+TEST_F(TlsConnectStreamTls13, CustomExtensionClientReject) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff58;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ RejectExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ client_->ExpectSendAlert(kTlsAlertHandshakeFailure);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+}
+
+static const uint8_t kCustomAlert = 0xf6;
+
+SECStatus AlertExtensionHandler(PRFileDesc *fd, SSLHandshakeType message,
+ const PRUint8 *data, unsigned int len,
+ SSLAlertDescription *alert, void *arg) {
+ *alert = kCustomAlert;
+ return SECFailure;
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionServerRejectAlert) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nonsense.
+ const uint16_t extension_code = 0xffea;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Reject the extension for no good reason.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ NoopExtensionWriter, nullptr,
+ AlertExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ ConnectExpectAlert(server_, kCustomAlert);
+}
+
+// Send nonsense in an extension from client to server.
+TEST_F(TlsConnectStreamTls13, CustomExtensionClientRejectAlert) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ const uint16_t extension_code = 0xff5a;
+ SECStatus rv = SSL_InstallExtensionHooks(client_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ AlertExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ // Have the server send nonsense.
+ rv = SSL_InstallExtensionHooks(server_->ssl_fd(), extension_code,
+ EmptyExtensionWriter, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+
+ client_->ExpectSendAlert(kCustomAlert);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+}
+
+// Configure a custom extension hook badly.
+TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyWriter) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ SECStatus rv =
+ SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6c, EmptyExtensionWriter,
+ nullptr, nullptr, nullptr);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionOnlyHandler) {
+ EnsureTlsSetup();
+
+ // This installs an override that sends nothing but expects nonsense.
+ SECStatus rv =
+ SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff6d, nullptr, nullptr,
+ NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectStreamTls13, CustomExtensionOverrunBuffer) {
+ EnsureTlsSetup();
+ // This doesn't actually overrun the buffer, but it says that it does.
+ auto overrun_writer = [](PRFileDesc *fd, SSLHandshakeType message,
+ PRUint8 *data, unsigned int *len,
+ unsigned int maxLen, void *arg) -> PRBool {
+ *len = maxLen + 1;
+ return PR_TRUE;
+ };
+ SECStatus rv =
+ SSL_InstallExtensionHooks(client_->ssl_fd(), 0xff71, overrun_writer,
+ nullptr, NoopExtensionHandler, nullptr);
+ EXPECT_EQ(SECSuccess, rv);
+ client_->StartConnect();
+ client_->Handshake();
+ client_->CheckErrorCode(SEC_ERROR_APPLICATION_CALLBACK_ERROR);
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
index 69fd00331..b8836d7fc 100644
--- a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
@@ -29,8 +29,7 @@ TEST_F(TlsConnectTest, DamageSecretHandleClientFinished) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_3);
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
client_->Handshake();
server_->Handshake();
std::cerr << "Damaging HS secret" << std::endl;
@@ -51,23 +50,19 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_3);
- client_->ExpectSendAlert(kTlsAlertDecryptError);
- // The server can't read the client's alert, so it also sends an alert.
- server_->ExpectSendAlert(kTlsAlertBadRecordMac);
- server_->SetPacketFilter(std::make_shared<AfterRecordN>(
+ MakeTlsFilter<AfterRecordN>(
server_, client_,
0, // ServerHello.
- [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); }));
- ConnectExpectFail();
+ [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); });
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
- server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
}
TEST_P(TlsConnectGenericPre13, DamageServerSignature) {
EnsureTlsSetup();
- auto filter =
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeServerKeyExchange);
- server_->SetTlsRecordFilter(filter);
+ auto filter = MakeTlsFilter<TlsLastByteDamager>(
+ server_, kTlsHandshakeServerKeyExchange);
+ filter->EnableDecryption();
ExpectAlert(client_, kTlsAlertDecryptError);
ConnectExpectFail();
client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
@@ -76,19 +71,10 @@ TEST_P(TlsConnectGenericPre13, DamageServerSignature) {
TEST_P(TlsConnectTls13, DamageServerSignature) {
EnsureTlsSetup();
- auto filter =
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
- server_->SetTlsRecordFilter(filter);
+ auto filter = MakeTlsFilter<TlsLastByteDamager>(
+ server_, kTlsHandshakeCertificateVerify);
filter->EnableDecryption();
- client_->ExpectSendAlert(kTlsAlertDecryptError);
- // The server can't read the client's alert, so it also sends an alert.
- if (variant_ == ssl_variant_stream) {
- server_->ExpectSendAlert(kTlsAlertBadRecordMac);
- ConnectExpectFail();
- server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
- } else {
- ConnectExpectFailOneSide(TlsAgent::CLIENT);
- }
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
}
@@ -96,15 +82,13 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) {
EnsureTlsSetup();
client_->SetupClientAuth();
server_->RequestClientAuth(true);
- auto filter =
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
- client_->SetTlsRecordFilter(filter);
- server_->ExpectSendAlert(kTlsAlertDecryptError);
+ auto filter = MakeTlsFilter<TlsLastByteDamager>(
+ client_, kTlsHandshakeCertificateVerify);
filter->EnableDecryption();
+ server_->ExpectSendAlert(kTlsAlertDecryptError);
// Do these handshakes by hand to avoid race condition on
// the client processing the server's alert.
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake();
server_->Handshake();
client_->Handshake();
@@ -116,4 +100,4 @@ TEST_P(TlsConnectGeneric, DamageClientSignature) {
server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
}
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
index 97943303a..cdafa7a84 100644
--- a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
@@ -24,7 +24,7 @@ TEST_P(TlsConnectGeneric, ConnectDhe) {
EnableOnlyDheCiphers();
Connect();
CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
}
TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) {
@@ -32,12 +32,12 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) {
client_->ConfigNamedGroups(kAllDHEGroups);
auto groups_capture =
- std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
auto shares_capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
shares_capture};
- client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
Connect();
@@ -59,15 +59,14 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) {
TEST_P(TlsConnectGeneric, ConnectFfdheClient) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
auto groups_capture =
- std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
auto shares_capture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
shares_capture};
- client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
Connect();
@@ -90,8 +89,7 @@ TEST_P(TlsConnectGeneric, ConnectFfdheClient) {
// because the client automatically sends the supported groups extension.
TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ server_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
Connect();
@@ -105,14 +103,11 @@ TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) {
class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter {
public:
- TlsDheServerKeyExchangeDamager() {}
+ TlsDheServerKeyExchangeDamager(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {}
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
// Damage the first octet of dh_p. Anything other than the known prime will
// be rejected as "weak" when we have SSL_REQUIRE_DH_NAMED_GROUPS enabled.
*output = input;
@@ -126,9 +121,8 @@ class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter {
// the signature until everything else has been checked.
TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
- server_->SetPacketFilter(std::make_shared<TlsDheServerKeyExchangeDamager>());
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
+ MakeTlsFilter<TlsDheServerKeyExchangeDamager>(server_);
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
@@ -147,7 +141,9 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter {
kYZeroPad
};
- TlsDheSkeChangeY(ChangeYTo change) : change_Y_(change) {}
+ TlsDheSkeChangeY(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type, ChangeYTo change)
+ : TlsHandshakeFilter(agent, {handshake_type}), change_Y_(change) {}
protected:
void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset,
@@ -212,8 +208,11 @@ class TlsDheSkeChangeY : public TlsHandshakeFilter {
class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
public:
- TlsDheSkeChangeYServer(ChangeYTo change, bool modify)
- : TlsDheSkeChangeY(change), modify_(modify), p_() {}
+ TlsDheSkeChangeYServer(const std::shared_ptr<TlsAgent>& agent,
+ ChangeYTo change, bool modify)
+ : TlsDheSkeChangeY(agent, kTlsHandshakeServerKeyExchange, change),
+ modify_(modify),
+ p_() {}
const DataBuffer& prime() const { return p_; }
@@ -221,10 +220,6 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) override {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
size_t offset = 2;
// Read dh_p
uint32_t dh_len = 0;
@@ -252,18 +247,15 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
class TlsDheSkeChangeYClient : public TlsDheSkeChangeY {
public:
TlsDheSkeChangeYClient(
- ChangeYTo change,
+ const std::shared_ptr<TlsAgent>& agent, ChangeYTo change,
std::shared_ptr<const TlsDheSkeChangeYServer> server_filter)
- : TlsDheSkeChangeY(change), server_filter_(server_filter) {}
+ : TlsDheSkeChangeY(agent, kTlsHandshakeClientKeyExchange, change),
+ server_filter_(server_filter) {}
protected:
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) override {
- if (header.handshake_type() != kTlsHandshakeClientKeyExchange) {
- return KEEP;
- }
-
ChangeY(input, output, 0, server_filter_->prime());
return CHANGE;
}
@@ -289,12 +281,10 @@ class TlsDamageDHYTest
TEST_P(TlsDamageDHYTest, DamageServerY) {
EnableOnlyDheCiphers();
if (std::get<3>(GetParam())) {
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
}
TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
- server_->SetPacketFilter(
- std::make_shared<TlsDheSkeChangeYServer>(change, true));
+ MakeTlsFilter<TlsDheSkeChangeYServer>(server_, change, true);
if (change == TlsDheSkeChangeY::kYZeroPad) {
ExpectAlert(client_, kTlsAlertDecryptError);
@@ -320,18 +310,15 @@ TEST_P(TlsDamageDHYTest, DamageServerY) {
TEST_P(TlsDamageDHYTest, DamageClientY) {
EnableOnlyDheCiphers();
if (std::get<3>(GetParam())) {
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
}
// The filter on the server is required to capture the prime.
- auto server_filter =
- std::make_shared<TlsDheSkeChangeYServer>(TlsDheSkeChangeY::kYZero, false);
- server_->SetPacketFilter(server_filter);
+ auto server_filter = MakeTlsFilter<TlsDheSkeChangeYServer>(
+ server_, TlsDheSkeChangeY::kYZero, false);
// The client filter does the damage.
TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
- client_->SetPacketFilter(
- std::make_shared<TlsDheSkeChangeYClient>(change, server_filter));
+ MakeTlsFilter<TlsDheSkeChangeYClient>(client_, change, server_filter);
if (change == TlsDheSkeChangeY::kYZeroPad) {
ExpectAlert(server_, kTlsAlertDecryptError);
@@ -370,13 +357,12 @@ INSTANTIATE_TEST_CASE_P(
class TlsDheSkeMakePEven : public TlsHandshakeFilter {
public:
+ TlsDheSkeMakePEven(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {}
+
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
// Find the end of dh_p
uint32_t dh_len = 0;
EXPECT_TRUE(input.Read(0, 2, &dh_len));
@@ -394,7 +380,7 @@ class TlsDheSkeMakePEven : public TlsHandshakeFilter {
// Even without requiring named groups, an even value for p is bad news.
TEST_P(TlsConnectGenericPre13, MakeDhePEven) {
EnableOnlyDheCiphers();
- server_->SetPacketFilter(std::make_shared<TlsDheSkeMakePEven>());
+ MakeTlsFilter<TlsDheSkeMakePEven>(server_);
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
@@ -404,13 +390,12 @@ TEST_P(TlsConnectGenericPre13, MakeDhePEven) {
class TlsDheSkeZeroPadP : public TlsHandshakeFilter {
public:
+ TlsDheSkeZeroPadP(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}) {}
+
virtual PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
*output = input;
uint32_t dh_len = 0;
EXPECT_TRUE(input.Read(0, 2, &dh_len));
@@ -425,7 +410,7 @@ class TlsDheSkeZeroPadP : public TlsHandshakeFilter {
// Zero padding only causes signature failure.
TEST_P(TlsConnectGenericPre13, PadDheP) {
EnableOnlyDheCiphers();
- server_->SetPacketFilter(std::make_shared<TlsDheSkeZeroPadP>());
+ MakeTlsFilter<TlsDheSkeZeroPadP>(server_);
ConnectExpectAlert(client_, kTlsAlertDecryptError);
@@ -445,8 +430,7 @@ TEST_P(TlsConnectGenericPre13, PadDheP) {
// Note: This test case can take ages to generate the weak DH key.
TEST_P(TlsConnectGenericPre13, WeakDHGroup) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
EXPECT_EQ(SECSuccess,
SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE));
@@ -474,7 +458,7 @@ TEST_P(TlsConnectGenericPre13, NamedGroupMismatchPre13) {
Connect();
CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_custom, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
}
// Same test but for TLS 1.3. This has to fail.
@@ -496,8 +480,7 @@ TEST_P(TlsConnectTls13, NamedGroupMismatch13) {
// custom group in contrast to the previous test.
TEST_P(TlsConnectGenericPre13, RequireNamedGroupsMismatchPre13) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072};
static const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1,
ssl_grp_ffdhe_2048};
@@ -519,14 +502,13 @@ TEST_P(TlsConnectGenericPre13, PreferredFfdhe) {
Connect();
client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
- client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256);
- server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256);
+ client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
+ server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
}
TEST_P(TlsConnectGenericPre13, MismatchDHE) {
EnableOnlyDheCiphers();
- EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
- SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ client_->SetOption(SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE);
static const SSLDHEGroupType serverGroups[] = {ssl_ff_dhe_3072_group};
EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), serverGroups,
PR_ARRAY_SIZE(serverGroups)));
@@ -544,37 +526,37 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
Connect();
SendReceive(); // Need to read so that we absorb the session ticket.
- CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign);
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
Reset();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
EnableOnlyDheCiphers();
auto clientCapture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(clientCapture);
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
auto serverCapture =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- server_->SetPacketFilter(serverCapture);
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_pre_shared_key_xtn);
ExpectResumption(RESUME_TICKET);
Connect();
- CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none);
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
ASSERT_LT(0UL, clientCapture->extension().len());
ASSERT_LT(0UL, serverCapture->extension().len());
}
class TlsDheSkeChangeSignature : public TlsHandshakeFilter {
public:
- TlsDheSkeChangeSignature(uint16_t version, const uint8_t* data, size_t len)
- : version_(version), data_(data), len_(len) {}
+ TlsDheSkeChangeSignature(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t version, const uint8_t* data, size_t len)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}),
+ version_(version),
+ data_(data),
+ len_(len) {}
protected:
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
TlsParser parser(input);
EXPECT_TRUE(parser.SkipVariable(2)); // dh_p
EXPECT_TRUE(parser.SkipVariable(2)); // dh_g
@@ -615,8 +597,8 @@ TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) {
const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ffdhe_2048};
client_->ConfigNamedGroups(client_groups);
- server_->SetPacketFilter(std::make_shared<TlsDheSkeChangeSignature>(
- version_, kBogusDheSignature, sizeof(kBogusDheSignature)));
+ MakeTlsFilter<TlsDheSkeChangeSignature>(server_, version_, kBogusDheSignature,
+ sizeof(kBogusDheSignature));
ConnectExpectAlert(client_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
diff --git a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
index 3cc3b0e62..ee8906deb 100644
--- a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
@@ -6,6 +6,7 @@
#include "secerr.h"
#include "ssl.h"
+#include "sslexp.h"
extern "C" {
// This is not something that should make you happy.
@@ -20,14 +21,14 @@ extern "C" {
namespace nss_test {
-TEST_P(TlsConnectDatagram, DropClientFirstFlightOnce) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
+TEST_P(TlsConnectDatagramPre13, DropClientFirstFlightOnce) {
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
}
-TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) {
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
+TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightOnce) {
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
}
@@ -35,36 +36,770 @@ TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) {
// This drops the first transmission from both the client and server of all
// flights that they send. Note: In DTLS 1.3, the shorter handshake means that
// this will also drop some application data, so we can't call SendReceive().
-TEST_P(TlsConnectDatagram, DropAllFirstTransmissions) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x15));
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x5));
+TEST_P(TlsConnectDatagramPre13, DropAllFirstTransmissions) {
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x15));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x5));
Connect();
}
// This drops the server's first flight three times.
-TEST_P(TlsConnectDatagram, DropServerFirstFlightThrice) {
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x7));
+TEST_P(TlsConnectDatagramPre13, DropServerFirstFlightThrice) {
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x7));
Connect();
}
// This drops the client's second flight once
-TEST_P(TlsConnectDatagram, DropClientSecondFlightOnce) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2));
+TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightOnce) {
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2));
Connect();
}
// This drops the client's second flight three times.
-TEST_P(TlsConnectDatagram, DropClientSecondFlightThrice) {
- client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
+TEST_P(TlsConnectDatagramPre13, DropClientSecondFlightThrice) {
+ client_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}
// This drops the server's second flight three times.
-TEST_P(TlsConnectDatagram, DropServerSecondFlightThrice) {
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
+TEST_P(TlsConnectDatagramPre13, DropServerSecondFlightThrice) {
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}
+class TlsDropDatagram13 : public TlsConnectDatagram13 {
+ public:
+ TlsDropDatagram13()
+ : client_filters_(),
+ server_filters_(),
+ expected_client_acks_(0),
+ expected_server_acks_(1) {}
+
+ void SetUp() override {
+ TlsConnectDatagram13::SetUp();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ SetFilters();
+ }
+
+ void SetFilters() {
+ EnsureTlsSetup();
+ client_filters_.Init(client_);
+ server_filters_.Init(server_);
+ }
+
+ void HandshakeAndAck(const std::shared_ptr<TlsAgent>& agent) {
+ agent->Handshake(); // Read flight.
+ ShiftDtlsTimers();
+ agent->Handshake(); // Generate ACK.
+ }
+
+ void ShrinkPostServerHelloMtu() {
+ // Abuse the custom extension mechanism to modify the MTU so that the
+ // Certificate message is split into two pieces.
+ ASSERT_EQ(
+ SECSuccess,
+ SSL_InstallExtensionHooks(
+ server_->ssl_fd(), 1,
+ [](PRFileDesc* fd, SSLHandshakeType message, PRUint8* data,
+ unsigned int* len, unsigned int maxLen, void* arg) -> PRBool {
+ SSLInt_SetMTU(fd, 500); // Splits the certificate.
+ return PR_FALSE;
+ },
+ nullptr,
+ [](PRFileDesc* fd, SSLHandshakeType message, const PRUint8* data,
+ unsigned int len, SSLAlertDescription* alert,
+ void* arg) -> SECStatus { return SECSuccess; },
+ nullptr));
+ }
+
+ protected:
+ class DropAckChain {
+ public:
+ DropAckChain()
+ : records_(nullptr), ack_(nullptr), drop_(nullptr), chain_(nullptr) {}
+
+ void Init(const std::shared_ptr<TlsAgent>& agent) {
+ records_ = std::make_shared<TlsRecordRecorder>(agent);
+ ack_ = std::make_shared<TlsRecordRecorder>(agent, content_ack);
+ ack_->EnableDecryption();
+ drop_ = std::make_shared<SelectiveRecordDropFilter>(agent, 0, false);
+ chain_ = std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({records_, ack_, drop_}));
+ agent->SetFilter(chain_);
+ }
+
+ const TlsRecord& record(size_t i) const { return records_->record(i); }
+
+ std::shared_ptr<TlsRecordRecorder> records_;
+ std::shared_ptr<TlsRecordRecorder> ack_;
+ std::shared_ptr<SelectiveRecordDropFilter> drop_;
+ std::shared_ptr<PacketFilter> chain_;
+ };
+
+ void CheckAcks(const DropAckChain& chain, size_t index,
+ std::vector<uint64_t> acks) {
+ const DataBuffer& buf = chain.ack_->record(index).buffer;
+ size_t offset = 0;
+
+ EXPECT_EQ(acks.size() * 8, buf.len());
+ if ((acks.size() * 8) != buf.len()) {
+ while (offset < buf.len()) {
+ uint64_t ack;
+ ASSERT_TRUE(buf.Read(offset, 8, &ack));
+ offset += 8;
+ std::cerr << "Ack=0x" << std::hex << ack << std::dec << std::endl;
+ }
+ return;
+ }
+
+ for (size_t i = 0; i < acks.size(); ++i) {
+ uint64_t a = acks[i];
+ uint64_t ack;
+ ASSERT_TRUE(buf.Read(offset, 8, &ack));
+ offset += 8;
+ if (a != ack) {
+ ADD_FAILURE() << "Wrong ack " << i << " expected=0x" << std::hex << a
+ << " got=0x" << ack << std::dec;
+ }
+ }
+ }
+
+ void CheckedHandshakeSendReceive() {
+ Handshake();
+ CheckPostHandshake();
+ }
+
+ void CheckPostHandshake() {
+ CheckConnected();
+ SendReceive();
+ EXPECT_EQ(expected_client_acks_, client_filters_.ack_->count());
+ EXPECT_EQ(expected_server_acks_, server_filters_.ack_->count());
+ }
+
+ protected:
+ DropAckChain client_filters_;
+ DropAckChain server_filters_;
+ size_t expected_client_acks_;
+ size_t expected_server_acks_;
+};
+
+// All of these tests produce a minimum one ACK, from the server
+// to the client upon receiving the client Finished.
+// Dropping complete first and second flights does not produce
+// ACKs
+TEST_F(TlsDropDatagram13, DropClientFirstFlightOnce) {
+ client_filters_.drop_->Reset({0});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ CheckedHandshakeSendReceive();
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+TEST_F(TlsDropDatagram13, DropServerFirstFlightOnce) {
+ server_filters_.drop_->Reset(0xff);
+ StartConnect();
+ client_->Handshake();
+ // Send the first flight, all dropped.
+ server_->Handshake();
+ server_filters_.drop_->Disable();
+ CheckedHandshakeSendReceive();
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// Dropping the server's first record also does not produce
+// an ACK because the next record is ignored.
+// TODO(ekr@rtfm.com): We should generate an empty ACK.
+TEST_F(TlsDropDatagram13, DropServerFirstRecordOnce) {
+ server_filters_.drop_->Reset({0});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ Handshake();
+ CheckedHandshakeSendReceive();
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// Dropping the second packet of the server's flight should
+// produce an ACK.
+TEST_F(TlsDropDatagram13, DropServerSecondRecordOnce) {
+ server_filters_.drop_->Reset({1});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ HandshakeAndAck(client_);
+ expected_client_acks_ = 1;
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_, 0, {0}); // ServerHello
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// Drop the server ACK and verify that the client retransmits
+// the ClientHello.
+TEST_F(TlsDropDatagram13, DropServerAckOnce) {
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ // At this point the server has sent it's first flight,
+ // so make it drop the ACK.
+ server_filters_.drop_->Reset({0});
+ client_->Handshake(); // Send the client Finished.
+ server_->Handshake(); // Receive the Finished and send the ACK.
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ // Wait for the DTLS timeout to make sure we retransmit the
+ // Finished.
+ ShiftDtlsTimers();
+ client_->Handshake(); // Retransmit the Finished.
+ server_->Handshake(); // Read the Finished and send an ACK.
+ uint8_t buf[1];
+ PRInt32 rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf));
+ expected_server_acks_ = 2;
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+ CheckPostHandshake();
+ // There should be two copies of the finished ACK
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ CheckAcks(server_filters_, 1, {0x0002000000000000ULL});
+}
+
+// Drop the client certificate verify.
+TEST_F(TlsDropDatagram13, DropClientCertVerify) {
+ StartConnect();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->Handshake();
+ server_->Handshake();
+ // Have the client drop Cert Verify
+ client_filters_.drop_->Reset({1});
+ expected_server_acks_ = 2;
+ CheckedHandshakeSendReceive();
+ // Ack of the Cert.
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ // Ack of the whole client handshake.
+ CheckAcks(
+ server_filters_, 1,
+ {0x0002000000000000ULL, // CH (we drop everything after this on client)
+ 0x0002000000000003ULL, // CT (2)
+ 0x0002000000000004ULL}); // FIN (2)
+}
+
+// Shrink the MTU down so that certs get split and drop the first piece.
+TEST_F(TlsDropDatagram13, DropFirstHalfOfServerCertificate) {
+ server_filters_.drop_->Reset({2});
+ StartConnect();
+ ShrinkPostServerHelloMtu();
+ client_->Handshake();
+ server_->Handshake();
+ // Check that things got split.
+ EXPECT_EQ(6UL,
+ server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN
+ size_t ct1_size = server_filters_.record(2).buffer.len();
+ server_filters_.records_->Clear();
+ expected_client_acks_ = 1;
+ HandshakeAndAck(client_);
+ server_->Handshake(); // Retransmit
+ EXPECT_EQ(3UL, server_filters_.records_->count()); // CT2, CV, FIN
+ // Check that the first record is CT1 (which is identical to the same
+ // as the previous CT1).
+ EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len());
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_, 0,
+ {0, // SH
+ 0x0002000000000000ULL, // EE
+ 0x0002000000000002ULL}); // CT2
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// Shrink the MTU down so that certs get split and drop the second piece.
+TEST_F(TlsDropDatagram13, DropSecondHalfOfServerCertificate) {
+ server_filters_.drop_->Reset({3});
+ StartConnect();
+ ShrinkPostServerHelloMtu();
+ client_->Handshake();
+ server_->Handshake();
+ // Check that things got split.
+ EXPECT_EQ(6UL,
+ server_filters_.records_->count()); // SH, EE, CT1, CT2, CV, FIN
+ size_t ct1_size = server_filters_.record(3).buffer.len();
+ server_filters_.records_->Clear();
+ expected_client_acks_ = 1;
+ HandshakeAndAck(client_);
+ server_->Handshake(); // Retransmit
+ EXPECT_EQ(3UL, server_filters_.records_->count()); // CT1, CV, FIN
+ // Check that the first record is CT1
+ EXPECT_EQ(ct1_size, server_filters_.record(0).buffer.len());
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_, 0,
+ {
+ 0, // SH
+ 0x0002000000000000ULL, // EE
+ 0x0002000000000001ULL, // CT1
+ });
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// In this test, the Certificate message is sent four times, we drop all or part
+// of the first three attempts:
+// 1. Without fragmentation so that we can see how big it is - we drop that.
+// 2. In two pieces - we drop half AND the resulting ACK.
+// 3. In three pieces - we drop the middle piece.
+//
+// After that we let all the ACKs through and allow the handshake to complete
+// without further interference.
+//
+// This allows us to test that ranges of handshake messages are sent correctly
+// even when there are overlapping acknowledgments; that ACKs with duplicate or
+// overlapping message ranges are handled properly; and that extra
+// retransmissions are handled properly.
+class TlsFragmentationAndRecoveryTest : public TlsDropDatagram13 {
+ public:
+ TlsFragmentationAndRecoveryTest() : cert_len_(0) {}
+
+ protected:
+ void RunTest(size_t dropped_half) {
+ FirstFlightDropCertificate();
+
+ SecondAttemptDropHalf(dropped_half);
+ size_t dropped_half_size = server_record_len(dropped_half);
+ size_t second_flight_count = server_filters_.records_->count();
+
+ ThirdAttemptDropMiddle();
+ size_t repaired_third_size = server_record_len((dropped_half == 0) ? 0 : 2);
+ size_t third_flight_count = server_filters_.records_->count();
+
+ AckAndCompleteRetransmission();
+ size_t final_server_flight_count = server_filters_.records_->count();
+ EXPECT_LE(3U, final_server_flight_count); // CT(sixth), CV, Fin
+ CheckSizeOfSixth(dropped_half_size, repaired_third_size);
+
+ SendDelayedAck();
+ // Same number of messages as the last flight.
+ EXPECT_EQ(final_server_flight_count, server_filters_.records_->count());
+ // Double check that the Certificate size is still correct.
+ CheckSizeOfSixth(dropped_half_size, repaired_third_size);
+
+ CompleteHandshake(final_server_flight_count);
+
+ // This is the ACK for the first attempt to send a whole certificate.
+ std::vector<uint64_t> client_acks = {
+ 0, // SH
+ 0x0002000000000000ULL // EE
+ };
+ CheckAcks(client_filters_, 0, client_acks);
+ // And from the second attempt for the half was kept (we delayed this ACK).
+ client_acks.push_back(0x0002000000000000ULL + second_flight_count +
+ ~dropped_half % 2);
+ CheckAcks(client_filters_, 1, client_acks);
+ // And the third attempt where the first and last thirds got through.
+ client_acks.push_back(0x0002000000000000ULL + second_flight_count +
+ third_flight_count - 1);
+ client_acks.push_back(0x0002000000000000ULL + second_flight_count +
+ third_flight_count + 1);
+ CheckAcks(client_filters_, 2, client_acks);
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ }
+
+ private:
+ void FirstFlightDropCertificate() {
+ StartConnect();
+ client_->Handshake();
+
+ // Note: 1 << N is the Nth packet, starting from zero.
+ server_filters_.drop_->Reset(1 << 2); // Drop Cert0.
+ server_->Handshake();
+ EXPECT_EQ(5U, server_filters_.records_->count()); // SH, EE, CT, CV, Fin
+ cert_len_ = server_filters_.records_->record(2).buffer.len();
+
+ HandshakeAndAck(client_);
+ EXPECT_EQ(2U, client_filters_.records_->count());
+ }
+
+ // Lower the MTU so that the server has to split the certificate in two
+ // pieces. The server resends Certificate (in two), plus CV and Fin.
+ void SecondAttemptDropHalf(size_t dropped_half) {
+ ASSERT_LE(0U, dropped_half);
+ ASSERT_GT(2U, dropped_half);
+ server_filters_.records_->Clear();
+ server_filters_.drop_->Reset({dropped_half}); // Drop Cert1[half]
+ SplitServerMtu(2);
+ server_->Handshake();
+ EXPECT_LE(4U, server_filters_.records_->count()); // CT x2, CV, Fin
+
+ // Generate and capture the ACK from the client.
+ client_filters_.drop_->Reset({0});
+ HandshakeAndAck(client_);
+ EXPECT_EQ(3U, client_filters_.records_->count());
+ }
+
+ // Lower the MTU again so that the server sends Certificate cut into three
+ // pieces. Drop the middle piece.
+ void ThirdAttemptDropMiddle() {
+ server_filters_.records_->Clear();
+ server_filters_.drop_->Reset({1}); // Drop Cert2[1] (of 3)
+ SplitServerMtu(3);
+ // Because we dropped the client ACK, the server retransmits on a timer.
+ ShiftDtlsTimers();
+ server_->Handshake();
+ EXPECT_LE(5U, server_filters_.records_->count()); // CT x3, CV, Fin
+ }
+
+ void AckAndCompleteRetransmission() {
+ // Generate ACKs.
+ HandshakeAndAck(client_);
+ // The server should send the final sixth of the certificate: the client has
+ // acknowledged the first half and the last third. Also send CV and Fin.
+ server_filters_.records_->Clear();
+ server_->Handshake();
+ }
+
+ void CheckSizeOfSixth(size_t size_of_half, size_t size_of_third) {
+ // Work out if the final sixth is the right size. We get the records with
+ // overheads added, which obscures the length of the payload. We want to
+ // ensure that the server only sent the missing sixth of the Certificate.
+ //
+ // We captured |size_of_half + overhead| and |size_of_third + overhead| and
+ // want to calculate |size_of_third - size_of_third + overhead|. We can't
+ // calculate |overhead|, but it is is (currently) always a handshake message
+ // header, a content type, and an authentication tag:
+ static const size_t record_overhead = 12 + 1 + 16;
+ EXPECT_EQ(size_of_half - size_of_third + record_overhead,
+ server_filters_.records_->record(0).buffer.len());
+ }
+
+ void SendDelayedAck() {
+ // Send the ACK we held back. The reordered ACK doesn't add new
+ // information,
+ // but triggers an extra retransmission of the missing records again (even
+ // though the client has all that it needs).
+ client_->SendRecordDirect(client_filters_.records_->record(2));
+ server_filters_.records_->Clear();
+ server_->Handshake();
+ }
+
+ void CompleteHandshake(size_t extra_retransmissions) {
+ // All this messing around shouldn't cause a failure...
+ Handshake();
+ // ...but it leaves a mess. Add an extra few calls to Handshake() for the
+ // client so that it absorbs the extra retransmissions.
+ for (size_t i = 0; i < extra_retransmissions; ++i) {
+ client_->Handshake();
+ }
+ CheckConnected();
+ }
+
+ // Split the server MTU so that the Certificate is split into |count| pieces.
+ // The calculation doesn't need to be perfect as long as the Certificate
+ // message is split into the right number of pieces.
+ void SplitServerMtu(size_t count) {
+ // Set the MTU based on the formula:
+ // bare_size = cert_len_ - actual_overhead
+ // MTU = ceil(bare_size / count) + pessimistic_overhead
+ //
+ // actual_overhead is the amount of actual overhead on the record we
+ // captured, which is (note that our length doesn't include the header):
+ static const size_t actual_overhead = 12 + // handshake message header
+ 1 + // content type
+ 16; // authentication tag
+ size_t bare_size = cert_len_ - actual_overhead;
+
+ // pessimistic_overhead is the amount of expansion that NSS assumes will be
+ // added to each handshake record. Right now, that is DTLS_MIN_FRAGMENT:
+ static const size_t pessimistic_overhead =
+ 12 + // handshake message header
+ 1 + // content type
+ 13 + // record header length
+ 64; // maximum record expansion: IV, MAC and block cipher expansion
+
+ size_t mtu = (bare_size + count - 1) / count + pessimistic_overhead;
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "server: set MTU to " << mtu << std::endl;
+ }
+ EXPECT_EQ(SECSuccess, SSLInt_SetMTU(server_->ssl_fd(), mtu));
+ }
+
+ size_t server_record_len(size_t index) const {
+ return server_filters_.records_->record(index).buffer.len();
+ }
+
+ size_t cert_len_;
+};
+
+TEST_F(TlsFragmentationAndRecoveryTest, DropFirstHalf) { RunTest(0); }
+
+TEST_F(TlsFragmentationAndRecoveryTest, DropSecondHalf) { RunTest(1); }
+
+TEST_F(TlsDropDatagram13, NoDropsDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ CheckAcks(server_filters_, 0,
+ {0x0001000000000001ULL, // EOED
+ 0x0002000000000000ULL}); // Finished
+}
+
+TEST_F(TlsDropDatagram13, DropEEDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ server_filters_.drop_->Reset({1});
+ ZeroRttSendReceive(true, true);
+ HandshakeAndAck(client_);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+ CheckAcks(client_filters_, 0, {0});
+ CheckAcks(server_filters_, 0,
+ {0x0001000000000002ULL, // EOED
+ 0x0002000000000000ULL}); // Finished
+}
+
+class TlsReorderDatagram13 : public TlsDropDatagram13 {
+ public:
+ TlsReorderDatagram13() {}
+
+ // Send records from the records buffer in the given order.
+ void ReSend(TlsAgent::Role side, std::vector<size_t> indices) {
+ std::shared_ptr<TlsAgent> agent;
+ std::shared_ptr<TlsRecordRecorder> records;
+
+ if (side == TlsAgent::CLIENT) {
+ agent = client_;
+ records = client_filters_.records_;
+ } else {
+ agent = server_;
+ records = server_filters_.records_;
+ }
+
+ for (auto i : indices) {
+ agent->SendRecordDirect(records->record(i));
+ }
+ }
+};
+
+// Reorder the server records so that EE comes at the end
+// of the flight and will still produce an ACK.
+TEST_F(TlsDropDatagram13, ReorderServerEE) {
+ server_filters_.drop_->Reset({1});
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ // We dropped EE, now reinject.
+ server_->SendRecordDirect(server_filters_.record(1));
+ expected_client_acks_ = 1;
+ HandshakeAndAck(client_);
+ CheckedHandshakeSendReceive();
+ CheckAcks(client_filters_, 0,
+ {
+ 0, // SH
+ 0x0002000000000000, // EE
+ });
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+// The client sends an out of order non-handshake message
+// but with the handshake key.
+class TlsSendCipherSpecCapturer {
+ public:
+ TlsSendCipherSpecCapturer(std::shared_ptr<TlsAgent>& agent)
+ : send_cipher_specs_() {
+ SSLInt_SetCipherSpecChangeFunc(agent->ssl_fd(), CipherSpecChanged,
+ (void*)this);
+ }
+
+ std::shared_ptr<TlsCipherSpec> spec(size_t i) {
+ if (i >= send_cipher_specs_.size()) {
+ return nullptr;
+ }
+ return send_cipher_specs_[i];
+ }
+
+ private:
+ static void CipherSpecChanged(void* arg, PRBool sending,
+ ssl3CipherSpec* newSpec) {
+ if (!sending) {
+ return;
+ }
+
+ auto self = static_cast<TlsSendCipherSpecCapturer*>(arg);
+
+ auto spec = std::make_shared<TlsCipherSpec>();
+ bool ret = spec->Init(SSLInt_CipherSpecToEpoch(newSpec),
+ SSLInt_CipherSpecToAlgorithm(newSpec),
+ SSLInt_CipherSpecToKey(newSpec),
+ SSLInt_CipherSpecToIv(newSpec));
+ EXPECT_EQ(true, ret);
+ self->send_cipher_specs_.push_back(spec);
+ }
+
+ std::vector<std::shared_ptr<TlsCipherSpec>> send_cipher_specs_;
+};
+
+TEST_F(TlsDropDatagram13, SendOutOfOrderAppWithHandshakeKey) {
+ StartConnect();
+ TlsSendCipherSpecCapturer capturer(client_);
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ server_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ // After the client sends Finished, inject an app data record
+ // with the handshake key. This should produce an alert.
+ uint8_t buf[] = {'a', 'b', 'c'};
+ auto spec = capturer.spec(0);
+ ASSERT_NE(nullptr, spec.get());
+ ASSERT_EQ(2, spec->epoch());
+ ASSERT_TRUE(client_->SendEncryptedRecord(
+ spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002,
+ kTlsApplicationDataType, DataBuffer(buf, sizeof(buf))));
+
+ // Now have the server consume the bogus message.
+ server_->ExpectSendAlert(illegal_parameter, kTlsAlertFatal);
+ server_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, server_->state());
+ EXPECT_EQ(SSL_ERROR_RX_UNKNOWN_RECORD_TYPE, PORT_GetError());
+}
+
+TEST_F(TlsDropDatagram13, SendOutOfOrderHsNonsenseWithHandshakeKey) {
+ StartConnect();
+ TlsSendCipherSpecCapturer capturer(client_);
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ server_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ // Inject a new bogus handshake record, which the server responds
+ // to by just ACKing the original one (we ignore the contents).
+ uint8_t buf[] = {'a', 'b', 'c'};
+ auto spec = capturer.spec(0);
+ ASSERT_NE(nullptr, spec.get());
+ ASSERT_EQ(2, spec->epoch());
+ ASSERT_TRUE(client_->SendEncryptedRecord(
+ spec, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 0x0002000000000002,
+ kTlsHandshakeType, DataBuffer(buf, sizeof(buf))));
+ server_->Handshake();
+ EXPECT_EQ(2UL, server_filters_.ack_->count());
+ // The server acknowledges client Finished twice.
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+ CheckAcks(server_filters_, 1, {0x0002000000000000ULL});
+}
+
+// Shrink the MTU down so that certs get split and then swap the first and
+// second pieces of the server certificate.
+TEST_F(TlsReorderDatagram13, ReorderServerCertificate) {
+ StartConnect();
+ ShrinkPostServerHelloMtu();
+ client_->Handshake();
+ // Drop the entire handshake flight so we can reorder.
+ server_filters_.drop_->Reset(0xff);
+ server_->Handshake();
+ // Check that things got split.
+ EXPECT_EQ(6UL,
+ server_filters_.records_->count()); // CH, EE, CT1, CT2, CV, FIN
+ // Now re-send things in a different order.
+ ReSend(TlsAgent::SERVER, std::vector<size_t>{0, 1, 3, 2, 4, 5});
+ // Clear.
+ server_filters_.drop_->Disable();
+ server_filters_.records_->Clear();
+ // Wait for client to send ACK.
+ ShiftDtlsTimers();
+ CheckedHandshakeSendReceive();
+ EXPECT_EQ(2UL, server_filters_.records_->count()); // ACK + Data
+ CheckAcks(server_filters_, 0, {0x0002000000000000ULL});
+}
+
+TEST_F(TlsReorderDatagram13, DataAfterEOEDDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ // Send the client's first flight of zero RTT data.
+ ZeroRttSendReceive(true, true);
+ // Now send another client application data record but
+ // capture it.
+ client_filters_.records_->Clear();
+ client_filters_.drop_->Reset(0xff);
+ const char* k0RttData = "123456";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+ PRInt32 rv =
+ PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write.
+ EXPECT_EQ(k0RttDataLen, rv);
+ EXPECT_EQ(1UL, client_filters_.records_->count()); // data
+ server_->Handshake();
+ client_->Handshake();
+ ExpectEarlyDataAccepted(true);
+ // The server still hasn't received anything at this point.
+ EXPECT_EQ(3UL, client_filters_.records_->count()); // data, EOED, FIN
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+ // Now re-send the client's messages: EOED, data, FIN
+ ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 0, 2}));
+ server_->Handshake();
+ CheckConnected();
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ // Acknowledgements for EOED and Finished.
+ CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL});
+ uint8_t buf[8];
+ rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_EQ(-1, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
+TEST_F(TlsReorderDatagram13, DataAfterFinDuringZeroRtt) {
+ SetupForZeroRtt();
+ SetFilters();
+ std::cerr << "Starting second handshake" << std::endl;
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ // Send the client's first flight of zero RTT data.
+ ZeroRttSendReceive(true, true);
+ // Now send another client application data record but
+ // capture it.
+ client_filters_.records_->Clear();
+ client_filters_.drop_->Reset(0xff);
+ const char* k0RttData = "123456";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+ PRInt32 rv =
+ PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write.
+ EXPECT_EQ(k0RttDataLen, rv);
+ EXPECT_EQ(1UL, client_filters_.records_->count()); // data
+ server_->Handshake();
+ client_->Handshake();
+ ExpectEarlyDataAccepted(true);
+ // The server still hasn't received anything at this point.
+ EXPECT_EQ(3UL, client_filters_.records_->count()); // EOED, FIN, Data
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+ // Now re-send the client's messages: EOED, FIN, Data
+ ReSend(TlsAgent::CLIENT, std::vector<size_t>({1, 2, 0}));
+ server_->Handshake();
+ CheckConnected();
+ EXPECT_EQ(0U, client_filters_.ack_->count());
+ // Acknowledgements for EOED and Finished.
+ CheckAcks(server_filters_, 0, {0x0001000000000002ULL, 0x0002000000000000ULL});
+ uint8_t buf[8];
+ rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_EQ(-1, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
static void GetCipherAndLimit(uint16_t version, uint16_t* cipher,
uint64_t* limit = nullptr) {
uint64_t l;
@@ -111,7 +846,6 @@ TEST_P(TlsConnectDatagram12Plus, MissAWindow) {
GetCipherAndLimit(version_, &cipher);
server_->EnableSingleCipher(cipher);
Connect();
-
EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 0));
SendReceive();
}
@@ -129,5 +863,7 @@ TEST_P(TlsConnectDatagram12Plus, MissAWindowAndOne) {
INSTANTIATE_TEST_CASE_P(Datagram12Plus, TlsConnectDatagram12Plus,
TlsConnectTestBase::kTlsV12Plus);
+INSTANTIATE_TEST_CASE_P(DatagramPre13, TlsConnectDatagramPre13,
+ TlsConnectTestBase::kTlsV11V12);
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
index 1e406b6c2..3c7cd2ecf 100644
--- a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
@@ -69,20 +69,19 @@ TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) {
server_->ConfigNamedGroups(groups);
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
}
// This causes a HelloRetryRequest in TLS 1.3. Earlier versions don't care.
TEST_P(TlsConnectGeneric, ConnectEcdheP384Server) {
EnsureTlsSetup();
- auto hrr_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeHelloRetryRequest);
- server_->SetPacketFilter(hrr_capture);
+ auto hrr_capture = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
server_->ConfigNamedGroups(groups);
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
EXPECT_EQ(version_ == SSL_LIBRARY_VERSION_TLS_1_3,
hrr_capture->buffer().len() != 0);
}
@@ -112,7 +111,7 @@ TEST_P(TlsKeyExchangeTest, P384Priority) {
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1};
CheckKEXDetails(groups, shares);
@@ -129,7 +128,7 @@ TEST_P(TlsKeyExchangeTest, DuplicateGroupConfig) {
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1};
std::vector<SSLNamedGroup> expectedGroups = {ssl_grp_ec_secp384r1,
@@ -147,7 +146,7 @@ TEST_P(TlsKeyExchangeTest, P384PriorityDHEnabled) {
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1};
@@ -172,7 +171,7 @@ TEST_P(TlsConnectGenericPre13, P384PriorityOnServer) {
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
}
TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) {
@@ -188,12 +187,14 @@ TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) {
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
}
class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter {
public:
- TlsKeyExchangeGroupCapture() : group_(ssl_grp_none) {}
+ TlsKeyExchangeGroupCapture(const std::shared_ptr<TlsAgent> &agent)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerKeyExchange}),
+ group_(ssl_grp_none) {}
SSLNamedGroup group() const { return group_; }
@@ -201,10 +202,6 @@ class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter {
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
const DataBuffer &input,
DataBuffer *output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
uint32_t value = 0;
EXPECT_TRUE(input.Read(0, 1, &value));
EXPECT_EQ(3U, value) << "curve type has to be 3";
@@ -223,10 +220,8 @@ class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter {
// P-256 is supported by the client (<= 1.2 only).
TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) {
EnsureTlsSetup();
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn));
- auto group_capture = std::make_shared<TlsKeyExchangeGroupCapture>();
- server_->SetPacketFilter(group_capture);
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_supported_groups_xtn);
+ auto group_capture = MakeTlsFilter<TlsKeyExchangeGroupCapture>(server_);
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
@@ -238,8 +233,7 @@ TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) {
// Supported groups is mandatory in TLS 1.3.
TEST_P(TlsConnectTls13, DropSupportedGroupExtension) {
EnsureTlsSetup();
- client_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn));
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_supported_groups_xtn);
ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION);
@@ -278,7 +272,7 @@ TEST_P(TlsConnectStreamPre13, ConfiguredGroupsRenegotiate) {
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
CheckConnected();
// The renegotiation has to use the same preferences as the original session.
@@ -286,7 +280,7 @@ TEST_P(TlsConnectStreamPre13, ConfiguredGroupsRenegotiate) {
client_->StartRenegotiate();
Handshake();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
}
TEST_P(TlsKeyExchangeTest, Curve25519) {
@@ -320,7 +314,7 @@ TEST_P(TlsConnectGenericPre13, GroupPreferenceServerPriority) {
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
}
#ifndef NSS_DISABLE_TLS_1_3
@@ -339,7 +333,7 @@ TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityClient13) {
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp256r1};
CheckKEXDetails(client_groups, shares);
}
@@ -359,7 +353,7 @@ TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityServer13) {
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
CheckKEXDetails(client_groups, shares);
}
@@ -381,7 +375,7 @@ TEST_P(TlsKeyExchangeTest13, EqualPriorityTestRetryECServer13) {
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
}
@@ -403,7 +397,7 @@ TEST_P(TlsKeyExchangeTest13, NotEqualPriorityWithIntermediateGroup13) {
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
}
@@ -425,7 +419,7 @@ TEST_P(TlsKeyExchangeTest13,
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
}
@@ -447,7 +441,7 @@ TEST_P(TlsKeyExchangeTest13,
Connect();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
}
@@ -509,7 +503,7 @@ TEST_P(TlsKeyExchangeTest13, MultipleClientShares) {
// The server would accept 25519 but its preferred group (P256) has to win.
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
- ssl_sig_rsa_pss_sha256);
+ ssl_sig_rsa_pss_rsae_sha256);
const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519,
ssl_grp_ec_secp256r1};
CheckKEXDetails(client_groups, shares);
@@ -518,16 +512,13 @@ TEST_P(TlsKeyExchangeTest13, MultipleClientShares) {
// Replace the point in the client key exchange message with an empty one
class ECCClientKEXFilter : public TlsHandshakeFilter {
public:
- ECCClientKEXFilter() {}
+ ECCClientKEXFilter(const std::shared_ptr<TlsAgent> &client)
+ : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}) {}
protected:
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
const DataBuffer &input,
DataBuffer *output) {
- if (header.handshake_type() != kTlsHandshakeClientKeyExchange) {
- return KEEP;
- }
-
// Replace the client key exchange message with an empty point
output->Allocate(1);
output->Write(0, 0U, 1); // set point length 0
@@ -538,20 +529,17 @@ class ECCClientKEXFilter : public TlsHandshakeFilter {
// Replace the point in the server key exchange message with an empty one
class ECCServerKEXFilter : public TlsHandshakeFilter {
public:
- ECCServerKEXFilter() {}
+ ECCServerKEXFilter(const std::shared_ptr<TlsAgent> &server)
+ : TlsHandshakeFilter(server, {kTlsHandshakeServerKeyExchange}) {}
protected:
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
const DataBuffer &input,
DataBuffer *output) {
- if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
- return KEEP;
- }
-
// Replace the server key exchange message with an empty point
output->Allocate(4);
output->Write(0, 3U, 1); // named curve
- uint32_t curve;
+ uint32_t curve = 0;
EXPECT_TRUE(input.Read(1, 2, &curve)); // get curve id
output->Write(1, curve, 2); // write curve id
output->Write(3, 0U, 1); // point length 0
@@ -560,15 +548,13 @@ class ECCServerKEXFilter : public TlsHandshakeFilter {
};
TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyServerPoint) {
- // add packet filter
- server_->SetPacketFilter(std::make_shared<ECCServerKEXFilter>());
+ MakeTlsFilter<ECCServerKEXFilter>(server_);
ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH);
}
TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) {
- // add packet filter
- client_->SetPacketFilter(std::make_shared<ECCClientKEXFilter>());
+ MakeTlsFilter<ECCClientKEXFilter>(client_);
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_KEY_EXCH);
}
diff --git a/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
index be407b42e..c42883eb7 100644
--- a/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
@@ -118,7 +118,6 @@ int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr,
TEST_P(TlsConnectTls13, EarlyExporter) {
SetupForZeroRtt();
- ExpectAlert(client_, kTlsAlertEndOfEarlyData);
client_->Set0RttEnabled(true);
server_->Set0RttEnabled(true);
ExpectResumption(RESUME_TICKET);
diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
index d15139419..0453dabdb 100644
--- a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -19,8 +19,9 @@ namespace nss_test {
class TlsExtensionTruncator : public TlsExtensionFilter {
public:
- TlsExtensionTruncator(uint16_t extension, size_t length)
- : extension_(extension), length_(length) {}
+ TlsExtensionTruncator(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, size_t length)
+ : TlsExtensionFilter(agent), extension_(extension), length_(length) {}
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) {
@@ -42,8 +43,9 @@ class TlsExtensionTruncator : public TlsExtensionFilter {
class TlsExtensionDamager : public TlsExtensionFilter {
public:
- TlsExtensionDamager(uint16_t extension, size_t index)
- : extension_(extension), index_(index) {}
+ TlsExtensionDamager(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, size_t index)
+ : TlsExtensionFilter(agent), extension_(extension), index_(index) {}
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) {
@@ -61,60 +63,17 @@ class TlsExtensionDamager : public TlsExtensionFilter {
size_t index_;
};
-class TlsExtensionInjector : public TlsHandshakeFilter {
- public:
- TlsExtensionInjector(uint16_t ext, DataBuffer& data)
- : extension_(ext), data_(data) {}
-
- virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
- const DataBuffer& input,
- DataBuffer* output) {
- TlsParser parser(input);
- if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
- return KEEP;
- }
- size_t offset = parser.consumed();
-
- *output = input;
-
- // Increase the size of the extensions.
- uint16_t ext_len;
- memcpy(&ext_len, output->data() + offset, sizeof(ext_len));
- ext_len = htons(ntohs(ext_len) + data_.len() + 4);
- memcpy(output->data() + offset, &ext_len, sizeof(ext_len));
-
- // Insert the extension type and length.
- DataBuffer type_length;
- type_length.Allocate(4);
- type_length.Write(0, extension_, 2);
- type_length.Write(2, data_.len(), 2);
- output->Splice(type_length, offset + 2);
-
- // Insert the payload.
- if (data_.len() > 0) {
- output->Splice(data_, offset + 6);
- }
-
- return CHANGE;
- }
-
- private:
- const uint16_t extension_;
- const DataBuffer data_;
-};
-
class TlsExtensionAppender : public TlsHandshakeFilter {
public:
- TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data)
- : handshake_type_(handshake_type), extension_(ext), data_(data) {}
+ TlsExtensionAppender(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type, uint16_t ext, DataBuffer& data)
+ : TlsHandshakeFilter(agent, {handshake_type}),
+ extension_(ext),
+ data_(data) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() != handshake_type_) {
- return KEEP;
- }
-
TlsParser parser(input);
if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
return KEEP;
@@ -159,7 +118,6 @@ class TlsExtensionAppender : public TlsHandshakeFilter {
return true;
}
- const uint8_t handshake_type_;
const uint16_t extension_;
const DataBuffer data_;
};
@@ -171,13 +129,13 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
void ClientHelloErrorTest(std::shared_ptr<PacketFilter> filter,
uint8_t desc = kTlsAlertDecodeError) {
- client_->SetPacketFilter(filter);
+ client_->SetFilter(filter);
ConnectExpectAlert(server_, desc);
}
void ServerHelloErrorTest(std::shared_ptr<PacketFilter> filter,
uint8_t desc = kTlsAlertDecodeError) {
- server_->SetPacketFilter(filter);
+ server_->SetFilter(filter);
ConnectExpectAlert(client_, desc);
}
@@ -200,11 +158,10 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
client_->ConfigNamedGroups(client_groups);
server_->ConfigNamedGroups(server_groups);
EnsureTlsSetup();
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake(); // Send ClientHello
server_->Handshake(); // Send HRR.
- client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(type));
+ MakeTlsFilter<TlsExtensionDropper>(client_, type);
Handshake();
client_->CheckErrorCode(client_error);
server_->CheckErrorCode(server_error);
@@ -245,8 +202,8 @@ class TlsExtensionTest13
void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) {
DataBuffer versions_buf(buf, len);
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_supported_versions_xtn, versions_buf));
+ MakeTlsFilter<TlsExtensionReplacer>(
+ client_, ssl_tls13_supported_versions_xtn, versions_buf);
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -257,8 +214,8 @@ class TlsExtensionTest13
size_t index = versions_buf.Write(0, 2, 1);
versions_buf.Write(index, version, 2);
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_supported_versions_xtn, versions_buf));
+ MakeTlsFilter<TlsExtensionReplacer>(
+ client_, ssl_tls13_supported_versions_xtn, versions_buf);
ConnectExpectFail();
}
};
@@ -289,26 +246,26 @@ class TlsExtensionTestPre13 : public TlsExtensionTestBase,
TEST_P(TlsExtensionTestGeneric, DamageSniLength) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 1));
+ std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 1));
}
TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 4));
+ std::make_shared<TlsExtensionDamager>(client_, ssl_server_name_xtn, 4));
}
TEST_P(TlsExtensionTestGeneric, TruncateSni) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_server_name_xtn, 7));
+ std::make_shared<TlsExtensionTruncator>(client_, ssl_server_name_xtn, 7));
}
// A valid extension that appears twice will be reported as unsupported.
TEST_P(TlsExtensionTestGeneric, RepeatSni) {
DataBuffer extension;
InitSimpleSni(&extension);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionInjector>(ssl_server_name_xtn, extension),
- kTlsAlertIllegalParameter);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionInjector>(
+ client_, ssl_server_name_xtn, extension),
+ kTlsAlertIllegalParameter);
}
// An SNI entry with zero length is considered invalid (strangely, not if it is
@@ -320,23 +277,23 @@ TEST_P(TlsExtensionTestGeneric, BadSni) {
extension.Allocate(simple.len() + 3);
extension.Write(0, static_cast<uint32_t>(0), 3);
extension.Write(3, simple);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_server_name_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, EmptySni) {
DataBuffer extension;
extension.Allocate(2);
extension.Write(0, static_cast<uint32_t>(0), 2);
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_server_name_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) {
EnableAlpn();
DataBuffer extension;
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ client_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertIllegalParameter);
}
@@ -347,21 +304,21 @@ TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ client_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertNoApplicationProtocol);
}
TEST_P(TlsExtensionTestGeneric, OneByteAlpn) {
EnableAlpn();
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 1));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_app_layer_protocol_xtn, 1));
}
TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) {
EnableAlpn();
// This will leave the length of the second entry, but no value.
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 5));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_app_layer_protocol_xtn, 5));
}
TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
@@ -369,7 +326,7 @@ TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
const uint8_t val[] = {0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ client_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, AlpnMismatch) {
@@ -388,7 +345,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
@@ -396,7 +353,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
const uint8_t val[] = {0x00, 0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
@@ -404,7 +361,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
@@ -412,7 +369,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
@@ -420,7 +377,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
@@ -428,7 +385,7 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
const uint8_t val[] = {0x00, 0x02, 0x99, 0x61};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension));
+ server_, ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) {
@@ -436,55 +393,64 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) {
const uint8_t val[] = {0x00, 0x02, 0x01, 0x67};
DataBuffer extension(val, sizeof(val));
ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_app_layer_protocol_xtn, extension),
+ server_, ssl_app_layer_protocol_xtn, extension),
kTlsAlertIllegalParameter);
}
TEST_P(TlsExtensionTestDtls, SrtpShort) {
EnableSrtp();
ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_use_srtp_xtn, 3));
+ std::make_shared<TlsExtensionTruncator>(client_, ssl_use_srtp_xtn, 3));
}
TEST_P(TlsExtensionTestDtls, SrtpOdd) {
EnableSrtp();
const uint8_t val[] = {0x00, 0x01, 0xff, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionReplacer>(ssl_use_srtp_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_use_srtp_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) {
const uint8_t val[] = {0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension),
+ kTlsAlertHandshakeFailure);
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsNoOverlap) {
+ const uint8_t val[] = {0x00, 0x02, 0xff, 0xff};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ client_, ssl_signature_algorithms_xtn, extension),
+ kTlsAlertHandshakeFailure);
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) {
const uint8_t val[] = {0x00, 0x01, 0x04};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_signature_algorithms_xtn, extension));
+ client_, ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) {
ClientHelloErrorTest(
- std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn),
+ std::make_shared<TlsExtensionDropper>(client_, ssl_supported_groups_xtn),
version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError
: kTlsAlertMissingExtension);
}
@@ -493,75 +459,74 @@ TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) {
const uint8_t val[] = {0x00, 0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) {
const uint8_t val[] = {0x09, 0x99, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_elliptic_curves_xtn, extension));
+ client_, ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) {
const uint8_t val[] = {0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) {
const uint8_t val[] = {0x99, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) {
const uint8_t val[] = {0x01, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_ec_point_formats_xtn, extension));
+ client_, ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) {
const uint8_t val[] = {0x99};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) {
const uint8_t val[] = {0x01, 0x00};
DataBuffer extension(val, sizeof(val));
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
// The extension has to contain a length.
TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) {
DataBuffer extension;
ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
- ssl_renegotiation_info_xtn, extension));
+ client_, ssl_renegotiation_info_xtn, extension));
}
// This only works on TLS 1.2, since it relies on static RSA; otherwise libssl
// picks the wrong cipher suite.
TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
- const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_sha512,
- ssl_sig_rsa_pss_sha384};
+ const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_rsae_sha512,
+ ssl_sig_rsa_pss_rsae_sha384};
auto capture =
- std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn);
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_signature_algorithms_xtn);
client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes));
- client_->SetPacketFilter(capture);
EnableOnlyStaticRsaCiphers();
Connect();
@@ -579,9 +544,9 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
// Temporary test to verify that we choke on an empty ClientKeyShare.
// This test will fail when we implement HelloRetryRequest.
TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
- ClientHelloErrorTest(
- std::make_shared<TlsExtensionTruncator>(ssl_tls13_key_share_xtn, 2),
- kTlsAlertHandshakeFailure);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionTruncator>(
+ client_, ssl_tls13_key_share_xtn, 2),
+ kTlsAlertHandshakeFailure);
}
// These tests only work in stream mode because the client sends a
@@ -590,8 +555,7 @@ TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
// packet gets dropped.
TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) {
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn));
+ MakeTlsFilter<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn);
client_->ExpectSendAlert(kTlsAlertMissingExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -611,8 +575,7 @@ TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) {
0x02};
DataBuffer buf(key_share, sizeof(key_share));
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf));
+ MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf);
client_->ExpectSendAlert(kTlsAlertIllegalParameter);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -633,8 +596,7 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
0x02};
DataBuffer buf(key_share, sizeof(key_share));
EnsureTlsSetup();
- server_->SetPacketFilter(
- std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf));
+ MakeTlsFilter<TlsExtensionReplacer>(server_, ssl_tls13_key_share_xtn, buf);
client_->ExpectSendAlert(kTlsAlertMissingExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -645,8 +607,8 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) {
SetupForResume();
DataBuffer empty;
- server_->SetPacketFilter(std::make_shared<TlsExtensionInjector>(
- ssl_signature_algorithms_xtn, empty));
+ MakeTlsFilter<TlsExtensionInjector>(server_, ssl_signature_algorithms_xtn,
+ empty);
client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -666,8 +628,12 @@ typedef std::function<void(TlsPreSharedKeyReplacer*)>
class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
public:
- TlsPreSharedKeyReplacer(TlsPreSharedKeyReplacerFunc function)
- : identities_(), binders_(), function_(function) {}
+ TlsPreSharedKeyReplacer(const std::shared_ptr<TlsAgent>& agent,
+ TlsPreSharedKeyReplacerFunc function)
+ : TlsExtensionFilter(agent),
+ identities_(),
+ binders_(),
+ function_(function) {}
static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size,
const std::unique_ptr<DataBuffer>& replace,
@@ -781,8 +747,10 @@ class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([](
- TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); }));
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->identities_[0].identity.Truncate(0);
+ });
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -792,10 +760,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1);
- }));
+ });
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
@@ -805,10 +773,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->binders_[0].Write(r->binders_[0].len(), 0xff, 1);
- }));
+ });
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -818,8 +786,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>(
- [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }));
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); });
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -830,11 +798,11 @@ TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->identities_.push_back(r->identities_[0]);
r->binders_.push_back(r->binders_[0]);
- }));
+ });
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
@@ -845,10 +813,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
SetupForResume();
- client_->SetPacketFilter(
- std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
r->identities_.push_back(r->identities_[0]);
- }));
+ });
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -857,8 +825,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) {
SetupForResume();
- client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([](
- TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); }));
+ MakeTlsFilter<TlsPreSharedKeyReplacer>(
+ client_, [](TlsPreSharedKeyReplacer* r) {
+ r->binders_.push_back(r->binders_[0]);
+ });
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -870,8 +840,8 @@ TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) {
const uint8_t empty_buf[] = {0};
DataBuffer empty(empty_buf, 0);
// Inject an unused extension after the PSK extension.
- client_->SetPacketFilter(std::make_shared<TlsExtensionAppender>(
- kTlsHandshakeClientHello, 0xffff, empty));
+ MakeTlsFilter<TlsExtensionAppender>(client_, kTlsHandshakeClientHello, 0xffff,
+ empty);
ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
@@ -881,8 +851,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) {
SetupForResume();
DataBuffer empty;
- client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(
- ssl_tls13_psk_key_exchange_modes_xtn));
+ MakeTlsFilter<TlsExtensionDropper>(client_,
+ ssl_tls13_psk_key_exchange_modes_xtn);
ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES);
@@ -897,8 +867,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
kTls13PskKe};
DataBuffer modes(ke_modes, sizeof(ke_modes));
- client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
- ssl_tls13_psk_key_exchange_modes_xtn, modes));
+ MakeTlsFilter<TlsExtensionReplacer>(
+ client_, ssl_tls13_psk_key_exchange_modes_xtn, modes);
client_->ExpectSendAlert(kTlsAlertBadRecordMac);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
@@ -908,9 +878,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) {
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
- auto capture = std::make_shared<TlsExtensionCapture>(
- ssl_tls13_psk_key_exchange_modes_xtn);
- client_->SetPacketFilter(capture);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ client_, ssl_tls13_psk_key_exchange_modes_xtn);
Connect();
EXPECT_FALSE(capture->captured());
}
@@ -1006,12 +975,9 @@ class TlsBogusExtensionTest : public TlsConnectTestBase,
static uint8_t empty_buf[1] = {0};
DataBuffer empty(empty_buf, 0);
auto filter =
- std::make_shared<TlsExtensionAppender>(message, extension, empty);
+ MakeTlsFilter<TlsExtensionAppender>(server_, message, extension, empty);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
- server_->SetTlsRecordFilter(filter);
filter->EnableDecryption();
- } else {
- server_->SetPacketFilter(filter);
}
}
@@ -1032,17 +998,20 @@ class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest {
class TlsBogusExtensionTest13 : public TlsBogusExtensionTest {
protected:
void ConnectAndFail(uint8_t message) override {
- if (message == kTlsHandshakeHelloRetryRequest) {
+ if (message != kTlsHandshakeServerHello) {
ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
return;
}
- client_->StartConnect();
- server_->StartConnect();
+ FailWithAlert(kTlsAlertUnsupportedExtension);
+ }
+
+ void FailWithAlert(uint8_t alert) {
+ StartConnect();
client_->Handshake(); // ClientHello
server_->Handshake(); // ServerHello
- client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ client_->ExpectSendAlert(alert);
client_->Handshake();
if (variant_ == ssl_variant_stream) {
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
@@ -1067,9 +1036,12 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) {
Run(kTlsHandshakeCertificate);
}
+// It's perfectly valid to set unknown extensions in CertificateRequest.
TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) {
server_->RequestClientAuth(false);
- Run(kTlsHandshakeCertificateRequest);
+ AddFilter(kTlsHandshakeCertificateRequest, 0xff);
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
}
TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) {
@@ -1079,10 +1051,6 @@ TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) {
Run(kTlsHandshakeHelloRetryRequest);
}
-TEST_P(TlsBogusExtensionTest13, AddVersionExtensionServerHello) {
- Run(kTlsHandshakeServerHello, ssl_tls13_supported_versions_xtn);
-}
-
TEST_P(TlsBogusExtensionTest13, AddVersionExtensionEncryptedExtensions) {
Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn);
}
@@ -1096,13 +1064,6 @@ TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificateRequest) {
Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn);
}
-TEST_P(TlsBogusExtensionTest13, AddVersionExtensionHelloRetryRequest) {
- static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
- server_->ConfigNamedGroups(groups);
-
- Run(kTlsHandshakeHelloRetryRequest, ssl_tls13_supported_versions_xtn);
-}
-
// NewSessionTicket allows unknown extensions AND it isn't protected by the
// Finished. So adding an unknown extension doesn't cause an error.
TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) {
@@ -1132,8 +1093,7 @@ TEST_P(TlsConnectStream, IncludePadding) {
SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name);
EXPECT_EQ(SECSuccess, rv);
- auto capture = std::make_shared<TlsExtensionCapture>(ssl_padding_xtn);
- client_->SetPacketFilter(capture);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, ssl_padding_xtn);
client_->StartConnect();
client_->Handshake();
EXPECT_TRUE(capture->captured());
diff --git a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc
index 44cacce46..f4940bf28 100644
--- a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc
@@ -51,10 +51,16 @@ class RecordFragmenter : public PacketFilter {
while (parser.remaining()) {
TlsHandshakeFilter::HandshakeHeader handshake_header;
DataBuffer handshake_body;
- if (!handshake_header.Parse(&parser, record_header, &handshake_body)) {
+ bool complete = false;
+ if (!handshake_header.Parse(&parser, record_header, DataBuffer(),
+ &handshake_body, &complete)) {
ADD_FAILURE() << "couldn't parse handshake header";
return false;
}
+ if (!complete) {
+ ADD_FAILURE() << "don't want to deal with fragmented messages";
+ return false;
+ }
DataBuffer record_fragment;
// We can't fragment handshake records that are too small.
@@ -82,7 +88,7 @@ class RecordFragmenter : public PacketFilter {
while (parser.remaining()) {
TlsRecordHeader header;
DataBuffer record;
- if (!header.Parse(&parser, &record)) {
+ if (!header.Parse(0, &parser, &record)) {
ADD_FAILURE() << "bad record header";
return false;
}
@@ -143,13 +149,13 @@ class RecordFragmenter : public PacketFilter {
};
TEST_P(TlsConnectDatagram, FragmentClientPackets) {
- client_->SetPacketFilter(std::make_shared<RecordFragmenter>());
+ client_->SetFilter(std::make_shared<RecordFragmenter>());
Connect();
SendReceive();
}
TEST_P(TlsConnectDatagram, FragmentServerPackets) {
- server_->SetPacketFilter(std::make_shared<RecordFragmenter>());
+ server_->SetFilter(std::make_shared<RecordFragmenter>());
Connect();
SendReceive();
}
diff --git a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
index 1587b66de..99448321c 100644
--- a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
@@ -27,7 +27,8 @@ class TlsFuzzTest : public ::testing::Test {};
// Record the application data stream.
class TlsApplicationDataRecorder : public TlsRecordFilter {
public:
- TlsApplicationDataRecorder() : buffer_() {}
+ TlsApplicationDataRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent), buffer_() {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
@@ -47,9 +48,9 @@ class TlsApplicationDataRecorder : public TlsRecordFilter {
// Ensure that ssl_Time() returns a constant value.
FUZZ_F(TlsFuzzTest, SSL_Time_Constant) {
- PRUint32 now = ssl_Time();
+ PRUint32 now = ssl_TimeSec();
PR_Sleep(PR_SecondsToInterval(2));
- EXPECT_EQ(ssl_Time(), now);
+ EXPECT_EQ(ssl_TimeSec(), now);
}
// Check that due to the deterministic PRNG we derive
@@ -106,16 +107,16 @@ FUZZ_P(TlsConnectGeneric, DeterministicTranscript) {
DisableECDHEServerKeyReuse();
DataBuffer buffer;
- client_->SetPacketFilter(std::make_shared<TlsConversationRecorder>(buffer));
- server_->SetPacketFilter(std::make_shared<TlsConversationRecorder>(buffer));
+ MakeTlsFilter<TlsConversationRecorder>(client_, buffer);
+ MakeTlsFilter<TlsConversationRecorder>(server_, buffer);
// Reset the RNG state.
EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0));
Connect();
// Ensure the filters go away before |buffer| does.
- client_->DeletePacketFilter();
- server_->DeletePacketFilter();
+ client_->ClearFilter();
+ server_->ClearFilter();
if (last.len() > 0) {
EXPECT_EQ(last, buffer);
@@ -133,10 +134,8 @@ FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) {
EnsureTlsSetup();
// Set up app data filters.
- auto client_recorder = std::make_shared<TlsApplicationDataRecorder>();
- client_->SetPacketFilter(client_recorder);
- auto server_recorder = std::make_shared<TlsApplicationDataRecorder>();
- server_->SetPacketFilter(server_recorder);
+ auto client_recorder = MakeTlsFilter<TlsApplicationDataRecorder>(client_);
+ auto server_recorder = MakeTlsFilter<TlsApplicationDataRecorder>(server_);
Connect();
@@ -161,10 +160,9 @@ FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) {
FUZZ_P(TlsConnectGeneric, BogusClientFinished) {
EnsureTlsSetup();
- auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeFinished,
+ MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeFinished,
DataBuffer(kShortEmptyFinished, sizeof(kShortEmptyFinished)));
- client_->SetPacketFilter(i1);
Connect();
SendReceive();
}
@@ -173,10 +171,9 @@ FUZZ_P(TlsConnectGeneric, BogusClientFinished) {
FUZZ_P(TlsConnectGeneric, BogusServerFinished) {
EnsureTlsSetup();
- auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeFinished,
+ MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>(
+ server_, kTlsHandshakeFinished,
DataBuffer(kLongEmptyFinished, sizeof(kLongEmptyFinished)));
- server_->SetPacketFilter(i1);
Connect();
SendReceive();
}
@@ -187,7 +184,7 @@ FUZZ_P(TlsConnectGeneric, BogusServerAuthSignature) {
uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3
? kTlsHandshakeCertificateVerify
: kTlsHandshakeServerKeyExchange;
- server_->SetPacketFilter(std::make_shared<TlsLastByteDamager>(msg_type));
+ MakeTlsFilter<TlsLastByteDamager>(server_, msg_type);
Connect();
SendReceive();
}
@@ -197,8 +194,7 @@ FUZZ_P(TlsConnectGeneric, BogusClientAuthSignature) {
EnsureTlsSetup();
client_->SetupClientAuth();
server_->RequestClientAuth(true);
- client_->SetPacketFilter(
- std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify));
+ MakeTlsFilter<TlsLastByteDamager>(client_, kTlsHandshakeCertificateVerify);
Connect();
}
@@ -215,82 +211,32 @@ FUZZ_P(TlsConnectGeneric, SessionTicketResumption) {
SendReceive();
}
-class TlsSessionTicketMacDamager : public TlsExtensionFilter {
- public:
- TlsSessionTicketMacDamager() {}
- virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
- const DataBuffer& input,
- DataBuffer* output) {
- if (extension_type != ssl_session_ticket_xtn &&
- extension_type != ssl_tls13_pre_shared_key_xtn) {
- return KEEP;
- }
-
- *output = input;
-
- // Handle everything before TLS 1.3.
- if (extension_type == ssl_session_ticket_xtn) {
- // Modify the last byte of the MAC.
- output->data()[output->len() - 1] ^= 0xff;
- }
-
- // Handle TLS 1.3.
- if (extension_type == ssl_tls13_pre_shared_key_xtn) {
- TlsParser parser(input);
-
- uint32_t ids_len;
- EXPECT_TRUE(parser.Read(&ids_len, 2) && ids_len > 0);
-
- uint32_t ticket_len;
- EXPECT_TRUE(parser.Read(&ticket_len, 2) && ticket_len > 0);
-
- // Modify the last byte of the MAC.
- output->data()[2 + 2 + ticket_len - 1] ^= 0xff;
- }
-
- return CHANGE;
- }
-};
-
-// Check that session ticket resumption works with a bad MAC.
-FUZZ_P(TlsConnectGeneric, SessionTicketResumptionBadMac) {
- ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- Connect();
- SendReceive();
-
- Reset();
- ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- ExpectResumption(RESUME_TICKET);
-
- client_->SetPacketFilter(std::make_shared<TlsSessionTicketMacDamager>());
- Connect();
- SendReceive();
-}
-
// Check that session tickets are not encrypted.
FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) {
ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
- auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeNewSessionTicket);
- server_->SetPacketFilter(i1);
+ auto filter = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeNewSessionTicket);
Connect();
+ std::cerr << "ticket" << filter->buffer() << std::endl;
size_t offset = 4; /* lifetime */
if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
- offset += 1 + 1 + /* ke_modes */
- 1 + 1; /* auth_modes */
+ offset += 4; /* ticket_age_add */
+ uint32_t nonce_len = 0;
+ EXPECT_TRUE(filter->buffer().Read(offset, 1, &nonce_len));
+ offset += 1 + nonce_len;
}
offset += 2 + /* ticket length */
2; /* TLS_EX_SESS_TICKET_VERSION */
// Check the protocol version number.
uint32_t tls_version = 0;
- EXPECT_TRUE(i1->buffer().Read(offset, sizeof(version_), &tls_version));
+ EXPECT_TRUE(filter->buffer().Read(offset, sizeof(version_), &tls_version));
EXPECT_EQ(version_, static_cast<decltype(version_)>(tls_version));
// Check the cipher suite.
uint32_t suite = 0;
- EXPECT_TRUE(i1->buffer().Read(offset + sizeof(version_), 2, &suite));
+ EXPECT_TRUE(filter->buffer().Read(offset + sizeof(version_), 2, &suite));
client_->CheckCipherSuite(static_cast<uint16_t>(suite));
}
}
diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.cc b/security/nss/gtests/ssl_gtest/ssl_gtest.cc
index cd10076b8..2fff9d7cb 100644
--- a/security/nss/gtests/ssl_gtest/ssl_gtest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_gtest.cc
@@ -6,6 +6,7 @@
#include <cstdlib>
#include "test_io.h"
+#include "databuffer.h"
#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
@@ -28,6 +29,7 @@ int main(int argc, char** argv) {
++i;
} else if (!strcmp(argv[i], "-v")) {
g_ssl_gtest_verbose = true;
+ nss_test::DataBuffer::SetLogLimit(16384);
}
}
diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp
index 8cd7d1009..e2a8d830a 100644
--- a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp
+++ b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp
@@ -11,6 +11,7 @@
'target_name': 'ssl_gtest',
'type': 'executable',
'sources': [
+ 'bloomfilter_unittest.cc',
'libssl_internals.c',
'selfencrypt_unittest.cc',
'ssl_0rtt_unittest.cc',
@@ -18,6 +19,7 @@
'ssl_auth_unittest.cc',
'ssl_cert_ext_unittest.cc',
'ssl_ciphersuite_unittest.cc',
+ 'ssl_custext_unittest.cc',
'ssl_damage_unittest.cc',
'ssl_dhe_unittest.cc',
'ssl_drop_unittest.cc',
@@ -30,11 +32,16 @@
'ssl_gather_unittest.cc',
'ssl_gtest.cc',
'ssl_hrr_unittest.cc',
+ 'ssl_keylog_unittest.cc',
+ 'ssl_keyupdate_unittest.cc',
'ssl_loopback_unittest.cc',
+ 'ssl_misc_unittest.cc',
'ssl_record_unittest.cc',
'ssl_resumption_unittest.cc',
+ 'ssl_renegotiation_unittest.cc',
'ssl_skip_unittest.cc',
'ssl_staticrsa_unittest.cc',
+ 'ssl_tls13compat_unittest.cc',
'ssl_v2_client_hello_unittest.cc',
'ssl_version_unittest.cc',
'ssl_versionpolicy_unittest.cc',
diff --git a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
index 39055f641..05ae87034 100644
--- a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
@@ -35,17 +35,15 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
// Send first ClientHello and send 0-RTT data
auto capture_early_data =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
- client_->SetPacketFilter(capture_early_data);
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn);
client_->Handshake();
EXPECT_EQ(k0RttDataLen, PR_Write(client_->ssl_fd(), k0RttData,
k0RttDataLen)); // 0-RTT write.
EXPECT_TRUE(capture_early_data->captured());
// Send the HelloRetryRequest
- auto hrr_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeHelloRetryRequest);
- server_->SetPacketFilter(hrr_capture);
+ auto hrr_capture = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
server_->Handshake();
EXPECT_LT(0U, hrr_capture->buffer().len());
@@ -56,8 +54,7 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
// Make a new capture for the early data.
capture_early_data =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
- client_->SetPacketFilter(capture_early_data);
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_early_data_xtn);
// Complete the handshake successfully
Handshake();
@@ -71,6 +68,10 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
// packet. If the record is split into two packets, or there are multiple
// handshake packets, this will break.
class CorrectMessageSeqAfterHrrFilter : public TlsRecordFilter {
+ public:
+ CorrectMessageSeqAfterHrrFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent) {}
+
protected:
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& record, size_t* offset,
@@ -131,8 +132,7 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) {
// Correct the DTLS message sequence number after an HRR.
if (variant_ == ssl_variant_datagram) {
- client_->SetPacketFilter(
- std::make_shared<CorrectMessageSeqAfterHrrFilter>());
+ MakeTlsFilter<CorrectMessageSeqAfterHrrFilter>(client_);
}
server_->SetPeer(client_);
@@ -151,7 +151,8 @@ TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) {
class KeyShareReplayer : public TlsExtensionFilter {
public:
- KeyShareReplayer() {}
+ KeyShareReplayer(const std::shared_ptr<TlsAgent>& agent)
+ : TlsExtensionFilter(agent) {}
virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
@@ -178,7 +179,22 @@ class KeyShareReplayer : public TlsExtensionFilter {
// server should reject this.
TEST_P(TlsConnectTls13, RetryWithSameKeyShare) {
EnsureTlsSetup();
- client_->SetPacketFilter(std::make_shared<KeyShareReplayer>());
+ MakeTlsFilter<KeyShareReplayer>(client_);
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(groups);
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code());
+ EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
+}
+
+// Here we modify the second ClientHello so that the client retries with the
+// same shares, even though the server wanted something else.
+TEST_P(TlsConnectTls13, RetryWithTwoShares) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+ MakeTlsFilter<KeyShareReplayer>(client_);
+
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
ssl_grp_ec_secp521r1};
server_->ConfigNamedGroups(groups);
@@ -187,13 +203,574 @@ TEST_P(TlsConnectTls13, RetryWithSameKeyShare) {
EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
}
+TEST_P(TlsConnectTls13, RetryCallbackAccept) {
+ EnsureTlsSetup();
+
+ auto accept_hello = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<bool*>(arg);
+ *called = true;
+
+ EXPECT_TRUE(firstHello);
+ EXPECT_EQ(0U, clientTokenLen);
+ return ssl_hello_retry_accept;
+ };
+
+ bool cb_run = false;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ accept_hello, &cb_run));
+ Connect();
+ EXPECT_TRUE(cb_run);
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackAcceptGroupMismatch) {
+ EnsureTlsSetup();
+
+ auto accept_hello_twice = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen,
+ unsigned int appTokenMax, void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+
+ EXPECT_EQ(0U, clientTokenLen);
+ return ssl_hello_retry_accept;
+ };
+
+ auto capture =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn);
+ capture->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ size_t cb_run = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), accept_hello_twice, &cb_run));
+ Connect();
+ EXPECT_EQ(2U, cb_run);
+ EXPECT_TRUE(capture->captured()) << "expected a cookie in HelloRetryRequest";
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackFail) {
+ EnsureTlsSetup();
+
+ auto fail_hello = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<bool*>(arg);
+ *called = true;
+
+ EXPECT_TRUE(firstHello);
+ EXPECT_EQ(0U, clientTokenLen);
+ return ssl_hello_retry_fail;
+ };
+
+ bool cb_run = false;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ fail_hello, &cb_run));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
+ server_->CheckErrorCode(SSL_ERROR_APPLICATION_ABORT);
+ EXPECT_TRUE(cb_run);
+}
+
+// Asking for retry twice isn't allowed.
+TEST_P(TlsConnectTls13, RetryCallbackRequestHrrTwice) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ return ssl_hello_retry_request;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+// Accepting the CH and modifying the token isn't allowed.
+TEST_P(TlsConnectTls13, RetryCallbackAcceptAndSetToken) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ *appTokenLen = 1;
+ return ssl_hello_retry_accept;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+// As above, but with reject.
+TEST_P(TlsConnectTls13, RetryCallbackRejectAndSetToken) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ *appTokenLen = 1;
+ return ssl_hello_retry_fail;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+// This is a (pretend) buffer overflow.
+TEST_P(TlsConnectTls13, RetryCallbackSetTooLargeToken) {
+ EnsureTlsSetup();
+
+ auto bad_callback = [](PRBool firstHello, const PRUint8* clientToken,
+ unsigned int clientTokenLen, PRUint8* appToken,
+ unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) -> SSLHelloRetryRequestAction {
+ *appTokenLen = appTokenMax + 1;
+ return ssl_hello_retry_accept;
+ };
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ bad_callback, NULL));
+ ConnectExpectAlert(server_, kTlsAlertInternalError);
+ server_->CheckErrorCode(SSL_ERROR_APP_CALLBACK_ERROR);
+}
+
+SSLHelloRetryRequestAction RetryHello(PRBool firstHello,
+ const PRUint8* clientToken,
+ unsigned int clientTokenLen,
+ PRUint8* appToken,
+ unsigned int* appTokenLen,
+ unsigned int appTokenMax, void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+
+ EXPECT_EQ(0U, clientTokenLen);
+ return firstHello ? ssl_hello_retry_request : ssl_hello_retry_accept;
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetry) {
+ EnsureTlsSetup();
+
+ auto capture_hrr = std::make_shared<TlsHandshakeRecorder>(
+ server_, ssl_hs_hello_retry_request);
+ auto capture_key_share =
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ std::vector<std::shared_ptr<PacketFilter>> chain = {capture_hrr,
+ capture_key_share};
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(chain));
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Do the first message exchange.
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+ EXPECT_LT(0U, capture_hrr->buffer().len()) << "HelloRetryRequest expected";
+ EXPECT_FALSE(capture_key_share->captured())
+ << "no key_share extension expected";
+
+ auto capture_cookie =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_cookie_xtn);
+
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_cookie->captured()) << "should have a cookie";
+}
+
+static size_t CountShares(const DataBuffer& key_share) {
+ size_t count = 0;
+ uint32_t len = 0;
+ size_t offset = 2;
+
+ EXPECT_TRUE(key_share.Read(0, 2, &len));
+ EXPECT_EQ(key_share.len() - 2, len);
+ while (offset < key_share.len()) {
+ offset += 2; // Skip KeyShareEntry.group
+ EXPECT_TRUE(key_share.Read(offset, 2, &len));
+ offset += 2 + len; // Skip KeyShareEntry.key_exchange
+ ++count;
+ }
+ return count;
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithAdditionalShares) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+ auto capture_server =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Do the first message exchange.
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+ EXPECT_FALSE(capture_server->captured())
+ << "no key_share extension expected from server";
+
+ auto capture_client_2nd =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
+
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_client_2nd->captured()) << "client should send key_share";
+ EXPECT_EQ(2U, CountShares(capture_client_2nd->extension()))
+ << "client should still send two shares";
+}
+
+// The callback should be run even if we have another reason to send
+// HelloRetryRequest. In this case, the server sends HRR because the server
+// wants a P-384 key share and the client didn't offer one.
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithGroupMismatch) {
+ EnsureTlsSetup();
+
+ auto capture_cookie =
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_cookie_xtn);
+ capture_cookie->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ auto capture_key_share =
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{capture_cookie, capture_key_share}));
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+ Connect();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_cookie->captured()) << "cookie expected";
+ EXPECT_TRUE(capture_key_share->captured()) << "key_share expected";
+}
+
+static const uint8_t kApplicationToken[] = {0x92, 0x44, 0x00};
+
+SSLHelloRetryRequestAction RetryHelloWithToken(
+ PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen,
+ PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<size_t*>(arg);
+ ++*called;
+
+ if (firstHello) {
+ memcpy(appToken, kApplicationToken, sizeof(kApplicationToken));
+ *appTokenLen = sizeof(kApplicationToken);
+ return ssl_hello_retry_request;
+ }
+
+ EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)),
+ DataBuffer(clientToken, static_cast<size_t>(clientTokenLen)));
+ return ssl_hello_retry_accept;
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithToken) {
+ EnsureTlsSetup();
+
+ auto capture_key_share =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess,
+ SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHelloWithToken, &cb_called));
+ Connect();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_FALSE(capture_key_share->captured()) << "no key share expected";
+}
+
+TEST_P(TlsConnectTls13, RetryCallbackRetryWithTokenAndGroupMismatch) {
+ EnsureTlsSetup();
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ auto capture_key_share =
+ MakeTlsFilter<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_key_share->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess,
+ SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHelloWithToken, &cb_called));
+ Connect();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(capture_key_share->captured()) << "key share expected";
+}
+
+SSLHelloRetryRequestAction CheckTicketToken(
+ PRBool firstHello, const PRUint8* clientToken, unsigned int clientTokenLen,
+ PRUint8* appToken, unsigned int* appTokenLen, unsigned int appTokenMax,
+ void* arg) {
+ auto* called = reinterpret_cast<bool*>(arg);
+ *called = true;
+
+ EXPECT_TRUE(firstHello);
+ EXPECT_EQ(DataBuffer(kApplicationToken, sizeof(kApplicationToken)),
+ DataBuffer(clientToken, static_cast<size_t>(clientTokenLen)));
+ return ssl_hello_retry_accept;
+}
+
+// Stream because SSL_SendSessionTicket only supports that.
+TEST_F(TlsConnectStreamTls13, RetryCallbackWithSessionTicketToken) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ EXPECT_EQ(SECSuccess,
+ SSL_SendSessionTicket(server_->ssl_fd(), kApplicationToken,
+ sizeof(kApplicationToken)));
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+
+ bool cb_run = false;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(
+ server_->ssl_fd(), CheckTicketToken, &cb_run));
+ Connect();
+ EXPECT_TRUE(cb_run);
+}
+
+void TriggerHelloRetryRequest(std::shared_ptr<TlsAgent>& client,
+ std::shared_ptr<TlsAgent>& server) {
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Start the handshake.
+ client->StartConnect();
+ server->StartConnect();
+ client->Handshake();
+ server->Handshake();
+ EXPECT_EQ(1U, cb_called);
+}
+
+TEST_P(TlsConnectTls13, RetryStateless) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ Handshake();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, RetryStatefulDropCookie) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeTlsFilter<TlsExtensionDropper>(client_, ssl_tls13_cookie_xtn);
+
+ ExpectAlert(server_, kTlsAlertMissingExtension);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_COOKIE_EXTENSION);
+}
+
+// Stream only because DTLS drops bad packets.
+TEST_F(TlsConnectStreamTls13, RetryStatelessDamageFirstClientHello) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ auto damage_ch =
+ MakeTlsFilter<TlsExtensionInjector>(client_, 0xfff3, DataBuffer());
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ // Key exchange fails when the handshake continues because client and server
+ // disagree about the transcript.
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+TEST_F(TlsConnectStreamTls13, RetryStatelessDamageSecondClientHello) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ auto damage_ch =
+ MakeTlsFilter<TlsExtensionInjector>(client_, 0xfff3, DataBuffer());
+
+ // Key exchange fails when the handshake continues because client and server
+ // disagree about the transcript.
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+// Read the cipher suite from the HRR and disable it on the identified agent.
+static void DisableSuiteFromHrr(
+ std::shared_ptr<TlsAgent>& agent,
+ std::shared_ptr<TlsHandshakeRecorder>& capture_hrr) {
+ uint32_t tmp;
+ size_t offset = 2 + 32; // skip version + server_random
+ ASSERT_TRUE(
+ capture_hrr->buffer().Read(offset, 1, &tmp)); // session_id length
+ EXPECT_EQ(0U, tmp);
+ offset += 1 + tmp;
+ ASSERT_TRUE(capture_hrr->buffer().Read(offset, 2, &tmp)); // suite
+ EXPECT_EQ(
+ SECSuccess,
+ SSL_CipherPrefSet(agent->ssl_fd(), static_cast<uint16_t>(tmp), PR_FALSE));
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteClient) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ auto capture_hrr =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_hello_retry_request);
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ DisableSuiteFromHrr(client_, capture_hrr);
+
+ // The client thinks that the HelloRetryRequest is bad, even though its
+ // because it changed its mind about the cipher suite.
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableSuiteServer) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ auto capture_hrr =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_hello_retry_request);
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ DisableSuiteFromHrr(server_, capture_hrr);
+
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableGroupClient) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ client_->ConfigNamedGroups(groups);
+
+ // We're into undefined behavior on the client side, but - at the point this
+ // test was written - the client here doesn't amend its key shares because the
+ // server doesn't ask it to. The server notices that the key share (x25519)
+ // doesn't match the negotiated group (P-384) and objects.
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessDisableGroupServer) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+ MakeNewServer();
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectTls13, RetryStatelessBadCookie) {
+ ConfigureSelfEncrypt();
+ EnsureTlsSetup();
+
+ TriggerHelloRetryRequest(client_, server_);
+
+ // Now replace the self-encrypt MAC key with a garbage key.
+ static const uint8_t bad_hmac_key[32] = {0};
+ SECItem key_item = {siBuffer, const_cast<uint8_t*>(bad_hmac_key),
+ sizeof(bad_hmac_key)};
+ ScopedPK11SlotInfo slot(PK11_GetInternalSlot());
+ PK11SymKey* hmac_key =
+ PK11_ImportSymKey(slot.get(), CKM_SHA256_HMAC, PK11_OriginUnwrap,
+ CKA_SIGN, &key_item, nullptr);
+ ASSERT_NE(nullptr, hmac_key);
+ SSLInt_SetSelfEncryptMacKey(hmac_key); // Passes ownership.
+
+ MakeNewServer();
+
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+ Handshake();
+ server_->CheckErrorCode(SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Stream because the server doesn't consume the alert and terminate.
+TEST_F(TlsConnectStreamTls13, RetryWithDifferentCipherSuite) {
+ EnsureTlsSetup();
+ // Force a HelloRetryRequest.
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+ // Then switch out the default suite (TLS_AES_128_GCM_SHA256).
+ MakeTlsFilter<SelectedCipherSuiteReplacer>(server_,
+ TLS_CHACHA20_POLY1305_SHA256);
+
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
+}
+
// This tests that the second attempt at sending a ClientHello (after receiving
// a HelloRetryRequest) is correctly retransmitted.
TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) {
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
ssl_grp_ec_secp521r1};
server_->ConfigNamedGroups(groups);
- server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2));
+ server_->SetFilter(std::make_shared<SelectiveDropFilter>(0x2));
Connect();
}
@@ -233,6 +810,54 @@ TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrrExtraShares) {
CheckKEXDetails(client_groups, client_groups);
}
+// The callback should be run even if we have another reason to send
+// HelloRetryRequest. In this case, the server sends HRR because the server
+// wants an X25519 key share and the client didn't offer one.
+TEST_P(TlsKeyExchange13,
+ RetryCallbackRetryWithGroupMismatchAndAdditionalShares) {
+ EnsureKeyShareSetup();
+
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_curve25519};
+ server_->ConfigNamedGroups(server_groups);
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+ auto capture_server =
+ std::make_shared<TlsExtensionCapture>(server_, ssl_tls13_key_share_xtn);
+ capture_server->SetHandshakeTypes({kTlsHandshakeHelloRetryRequest});
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{capture_hrr_, capture_server}));
+
+ size_t cb_called = 0;
+ EXPECT_EQ(SECSuccess, SSL_HelloRetryRequestCallback(server_->ssl_fd(),
+ RetryHello, &cb_called));
+
+ // Do the first message exchange.
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ EXPECT_EQ(1U, cb_called) << "callback should be called once here";
+ EXPECT_TRUE(capture_server->captured()) << "key_share extension expected";
+
+ uint32_t server_group = 0;
+ EXPECT_TRUE(capture_server->extension().Read(0, 2, &server_group));
+ EXPECT_EQ(ssl_grp_ec_curve25519, static_cast<SSLNamedGroup>(server_group));
+
+ Handshake();
+ CheckConnected();
+ EXPECT_EQ(2U, cb_called);
+ EXPECT_TRUE(shares_capture2_->captured()) << "client should send shares";
+
+ CheckKeys();
+ static const std::vector<SSLNamedGroup> client_shares(
+ client_groups.begin(), client_groups.begin() + 2);
+ CheckKEXDetails(client_groups, client_shares, server_groups[0]);
+}
+
TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) {
EnsureTlsSetup();
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
@@ -245,8 +870,7 @@ TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) {
static const std::vector<SSLNamedGroup> server_groups = {
ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1};
server_->ConfigNamedGroups(server_groups);
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake();
server_->Handshake();
@@ -276,15 +900,30 @@ class HelloRetryRequestAgentTest : public TlsAgentTestClient {
void MakeCannedHrr(const uint8_t* body, size_t len, DataBuffer* hrr_record,
uint32_t seq_num = 0) const {
DataBuffer hrr_data;
- hrr_data.Allocate(len + 4);
+ const uint8_t ssl_hello_retry_random[] = {
+ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C,
+ 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB,
+ 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C};
+
+ hrr_data.Allocate(len + 6);
size_t i = 0;
+ i = hrr_data.Write(i, 0x0303, 2);
+ i = hrr_data.Write(i, ssl_hello_retry_random,
+ sizeof(ssl_hello_retry_random));
+ i = hrr_data.Write(i, static_cast<uint32_t>(0), 1); // session_id
+ i = hrr_data.Write(i, TLS_AES_128_GCM_SHA256, 2);
+ i = hrr_data.Write(i, ssl_compression_null, 1);
+ // Add extensions. First a length, which includes the supported version.
+ i = hrr_data.Write(i, static_cast<uint32_t>(len) + 6, 2);
+ // Now the supported version.
+ i = hrr_data.Write(i, ssl_tls13_supported_versions_xtn, 2);
+ i = hrr_data.Write(i, 2, 2);
i = hrr_data.Write(i, 0x7f00 | TLS_1_3_DRAFT_VERSION, 2);
- i = hrr_data.Write(i, static_cast<uint32_t>(len), 2);
if (len) {
hrr_data.Write(i, body, len);
}
DataBuffer hrr;
- MakeHandshakeMessage(kTlsHandshakeHelloRetryRequest, hrr_data.data(),
+ MakeHandshakeMessage(kTlsHandshakeServerHello, hrr_data.data(),
hrr_data.len(), &hrr, seq_num);
MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, hrr.data(),
hrr.len(), hrr_record, seq_num);
@@ -334,28 +973,6 @@ TEST_P(HelloRetryRequestAgentTest, HandleNoopHelloRetryRequest) {
SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST);
}
-TEST_P(HelloRetryRequestAgentTest, HandleHelloRetryRequestCookie) {
- const uint8_t canned_cookie_hrr[] = {
- static_cast<uint8_t>(ssl_tls13_cookie_xtn >> 8),
- static_cast<uint8_t>(ssl_tls13_cookie_xtn),
- 0,
- 5, // length of cookie extension
- 0,
- 3, // cookie value length
- 0xc0,
- 0x0c,
- 0x13};
- DataBuffer hrr;
- MakeCannedHrr(canned_cookie_hrr, sizeof(canned_cookie_hrr), &hrr);
- auto capture = std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
- agent_->SetPacketFilter(capture);
- ProcessMessage(hrr, TlsAgent::STATE_CONNECTING);
- const size_t cookie_pos = 2 + 2; // cookie_xtn, extension len
- DataBuffer cookie(canned_cookie_hrr + cookie_pos,
- sizeof(canned_cookie_hrr) - cookie_pos);
- EXPECT_EQ(cookie, capture->extension());
-}
-
INSTANTIATE_TEST_CASE_P(HelloRetryRequestAgentTests, HelloRetryRequestAgentTest,
::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV13));
diff --git a/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc
new file mode 100644
index 000000000..322b64837
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_keylog_unittest.cc
@@ -0,0 +1,118 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifdef NSS_ALLOW_SSLKEYLOGFILE
+
+#include <cstdlib>
+#include <fstream>
+#include <sstream>
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+static const std::string keylog_file_path = "keylog.txt";
+static const std::string keylog_env = "SSLKEYLOGFILE=" + keylog_file_path;
+
+class KeyLogFileTest : public TlsConnectGeneric {
+ public:
+ void SetUp() override {
+ TlsConnectGeneric::SetUp();
+ // Remove previous results (if any).
+ (void)remove(keylog_file_path.c_str());
+ PR_SetEnv(keylog_env.c_str());
+ }
+
+ void CheckKeyLog() {
+ std::ifstream f(keylog_file_path);
+ std::map<std::string, size_t> labels;
+ std::set<std::string> client_randoms;
+ for (std::string line; std::getline(f, line);) {
+ if (line[0] == '#') {
+ continue;
+ }
+
+ std::istringstream iss(line);
+ std::string label, client_random, secret;
+ iss >> label >> client_random >> secret;
+
+ ASSERT_EQ(64U, client_random.size());
+ client_randoms.insert(client_random);
+ labels[label]++;
+ }
+
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ ASSERT_EQ(1U, client_randoms.size());
+ } else {
+ /* two handshakes for 0-RTT */
+ ASSERT_EQ(2U, client_randoms.size());
+ }
+
+ // Every entry occurs twice (one log from server, one from client).
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ ASSERT_EQ(2U, labels["CLIENT_RANDOM"]);
+ } else {
+ ASSERT_EQ(2U, labels["CLIENT_EARLY_TRAFFIC_SECRET"]);
+ ASSERT_EQ(2U, labels["EARLY_EXPORTER_SECRET"]);
+ ASSERT_EQ(4U, labels["CLIENT_HANDSHAKE_TRAFFIC_SECRET"]);
+ ASSERT_EQ(4U, labels["SERVER_HANDSHAKE_TRAFFIC_SECRET"]);
+ ASSERT_EQ(4U, labels["CLIENT_TRAFFIC_SECRET_0"]);
+ ASSERT_EQ(4U, labels["SERVER_TRAFFIC_SECRET_0"]);
+ ASSERT_EQ(4U, labels["EXPORTER_SECRET"]);
+ }
+ }
+
+ void ConnectAndCheck() {
+ // This is a child process, ensure that error messages immediately
+ // propagate or else it will not be visible.
+ ::testing::GTEST_FLAG(throw_on_failure) = true;
+
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+ } else {
+ Connect();
+ }
+ CheckKeyLog();
+ _exit(0);
+ }
+};
+
+// Tests are run in a separate process to ensure that NSS is not initialized yet
+// and can process the SSLKEYLOGFILE environment variable.
+
+TEST_P(KeyLogFileTest, KeyLogFile) {
+ testing::GTEST_FLAG(death_test_style) = "threadsafe";
+
+ ASSERT_EXIT(ConnectAndCheck(), ::testing::ExitedWithCode(0), "");
+}
+
+INSTANTIATE_TEST_CASE_P(
+ KeyLogFileDTLS12, KeyLogFileTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11V12));
+INSTANTIATE_TEST_CASE_P(
+ KeyLogFileTLS12, KeyLogFileTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_CASE_P(
+ KeyLogFileTLS13, KeyLogFileTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV13));
+#endif
+
+} // namespace nss_test
+
+#endif // NSS_ALLOW_SSLKEYLOGFILE
diff --git a/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc
new file mode 100644
index 000000000..d03775c25
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_keyupdate_unittest.cc
@@ -0,0 +1,178 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// All stream only tests; DTLS isn't supported yet.
+
+TEST_F(TlsConnectTest, KeyUpdateClient) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(4, 3);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateClientRequestUpdate) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE));
+ // SendReceive() only gives each peer one chance to read. This isn't enough
+ // when the read on one side generates another handshake message. A second
+ // read gives each peer an extra chance to consume the KeyUpdate.
+ SendReceive(50);
+ SendReceive(60); // Cumulative count.
+ CheckEpochs(4, 4);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateServer) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(3, 4);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateServerRequestUpdate) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(4, 4);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateConsecutiveRequests) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ // The server should have updated twice, but the client should have declined
+ // to respond to the second request from the server, since it doesn't send
+ // anything in between those two requests.
+ CheckEpochs(4, 5);
+}
+
+// Check that a local update can be immediately followed by a remotely triggered
+// update even if there is no use of the keys.
+TEST_F(TlsConnectTest, KeyUpdateLocalUpdateThenConsecutiveRequests) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ // This should trigger an update on the client.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
+ // The client should update for the first request.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ // ...but not the second.
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ // Both should have updated twice.
+ CheckEpochs(5, 5);
+}
+
+TEST_F(TlsConnectTest, KeyUpdateMultiple) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_FALSE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_FALSE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(5, 6);
+}
+
+// Both ask the other for an update, and both should react.
+TEST_F(TlsConnectTest, KeyUpdateBothRequest) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(client_->ssl_fd(), PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_KeyUpdate(server_->ssl_fd(), PR_TRUE));
+ SendReceive(50);
+ SendReceive(60);
+ CheckEpochs(5, 5);
+}
+
+// If the sequence number exceeds the number of writes before an automatic
+// update (currently 3/4 of the max records for the cipher suite), then the
+// stack should send an update automatically (but not request one).
+TEST_F(TlsConnectTest, KeyUpdateAutomaticOnWrite) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256);
+
+ // Set this to one below the write threshold.
+ uint64_t threshold = (0x5aULL << 28) * 3 / 4;
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold));
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold));
+
+ // This should be OK.
+ client_->SendData(10);
+ server_->ReadBytes();
+
+ // This should cause the client to update.
+ client_->SendData(10);
+ server_->ReadBytes();
+
+ SendReceive(100);
+ CheckEpochs(4, 3);
+}
+
+// If the sequence number exceeds a certain number of reads (currently 7/8 of
+// the max records for the cipher suite), then the stack should send AND request
+// an update automatically. However, the sender (client) will be above its
+// automatic update threshold, so the KeyUpdate - that it sends with the old
+// cipher spec - will exceed the receiver (server) automatic update threshold.
+// The receiver gets a packet with a sequence number over its automatic read
+// update threshold. Even though the sender has updated, the code that checks
+// the sequence numbers at the receiver doesn't know this and it will request an
+// update. This causes two updates: one from the sender (without requesting a
+// response) and one from the receiver (which does request a response).
+TEST_F(TlsConnectTest, KeyUpdateAutomaticOnRead) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectWithCipherSuite(TLS_AES_128_GCM_SHA256);
+
+ // Move to right at the read threshold. Unlike the write test, we can't send
+ // packets because that would cause the client to update, which would spoil
+ // the test.
+ uint64_t threshold = ((0x5aULL << 28) * 7 / 8) + 1;
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), threshold));
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), threshold));
+
+ // This should cause the client to update, but not early enough to prevent the
+ // server from updating also.
+ client_->SendData(10);
+ server_->ReadBytes();
+
+ // Need two SendReceive() calls to ensure that the update that the server
+ // requested is properly generated and consumed.
+ SendReceive(70);
+ SendReceive(80);
+ CheckEpochs(5, 4);
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
index 77703dd8e..f1b78f52f 100644
--- a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -6,6 +6,7 @@
#include <functional>
#include <memory>
+#include <vector>
#include "secerr.h"
#include "ssl.h"
#include "sslerr.h"
@@ -55,7 +56,8 @@ TEST_P(TlsConnectGeneric, CipherSuiteMismatch) {
class TlsAlertRecorder : public TlsRecordFilter {
public:
- TlsAlertRecorder() : level_(255), description_(255) {}
+ TlsAlertRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent), level_(255), description_(255) {}
PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
@@ -84,13 +86,13 @@ class TlsAlertRecorder : public TlsRecordFilter {
};
class HelloTruncator : public TlsHandshakeFilter {
+ public:
+ HelloTruncator(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(
+ agent, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {}
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) override {
- if (header.handshake_type() != kTlsHandshakeClientHello &&
- header.handshake_type() != kTlsHandshakeServerHello) {
- return KEEP;
- }
output->Assign(input.data(), input.len() - 1);
return CHANGE;
}
@@ -98,19 +100,17 @@ class HelloTruncator : public TlsHandshakeFilter {
// Verify that when NSS reports that an alert is sent, it is actually sent.
TEST_P(TlsConnectGeneric, CaptureAlertServer) {
- client_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- server_->SetPacketFilter(alert_recorder);
+ MakeTlsFilter<HelloTruncator>(client_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(server_);
- ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ ConnectExpectAlert(server_, kTlsAlertDecodeError);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
}
TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
- server_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- client_->SetPacketFilter(alert_recorder);
+ MakeTlsFilter<HelloTruncator>(server_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_);
ConnectExpectAlert(client_, kTlsAlertDecodeError);
EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
@@ -119,12 +119,10 @@ TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
// In TLS 1.3, the server can't read the client alert.
TEST_P(TlsConnectTls13, CaptureAlertClient) {
- server_->SetPacketFilter(std::make_shared<HelloTruncator>());
- auto alert_recorder = std::make_shared<TlsAlertRecorder>();
- client_->SetPacketFilter(alert_recorder);
+ MakeTlsFilter<HelloTruncator>(server_);
+ auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_);
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
client_->Handshake();
client_->ExpectSendAlert(kTlsAlertDecodeError);
@@ -166,26 +164,111 @@ TEST_P(TlsConnectDatagram, ConnectSrtp) {
SendReceive();
}
-// 1.3 is disabled in the next few tests because we don't
-// presently support resumption in 1.3.
-TEST_P(TlsConnectStreamPre13, ConnectAndClientRenegotiate) {
+TEST_P(TlsConnectGeneric, ConnectSendReceive) {
Connect();
- server_->PrepareForRenegotiate();
- client_->StartRenegotiate();
- Handshake();
- CheckConnected();
+ SendReceive();
}
-TEST_P(TlsConnectStreamPre13, ConnectAndServerRenegotiate) {
+class SaveTlsRecord : public TlsRecordFilter {
+ public:
+ SaveTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index)
+ : TlsRecordFilter(agent), index_(index), count_(0), contents_() {}
+
+ const DataBuffer& contents() const { return contents_; }
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ if (count_++ == index_) {
+ contents_ = data;
+ }
+ return KEEP;
+ }
+
+ private:
+ const size_t index_;
+ size_t count_;
+ DataBuffer contents_;
+};
+
+// Check that decrypting filters work and can read any record.
+// This test (currently) only works in TLS 1.3 where we can decrypt.
+TEST_F(TlsConnectStreamTls13, DecryptRecordClient) {
+ EnsureTlsSetup();
+ // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer
+ auto saved = MakeTlsFilter<SaveTlsRecord>(client_, 3);
+ saved->EnableDecryption();
Connect();
- client_->PrepareForRenegotiate();
- server_->StartRenegotiate();
- Handshake();
- CheckConnected();
+ SendReceive();
+
+ static const uint8_t data[] = {0xde, 0xad, 0xdc};
+ DataBuffer buf(data, sizeof(data));
+ client_->SendBuffer(buf);
+ EXPECT_EQ(buf, saved->contents());
}
-TEST_P(TlsConnectGeneric, ConnectSendReceive) {
+TEST_F(TlsConnectStreamTls13, DecryptRecordServer) {
+ EnsureTlsSetup();
+ // Disable tickets so that we are sure to not get NewSessionTicket.
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
+ // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer
+ auto saved = MakeTlsFilter<SaveTlsRecord>(server_, 3);
+ saved->EnableDecryption();
+ Connect();
+ SendReceive();
+
+ static const uint8_t data[] = {0xde, 0xad, 0xd5};
+ DataBuffer buf(data, sizeof(data));
+ server_->SendBuffer(buf);
+ EXPECT_EQ(buf, saved->contents());
+}
+
+class DropTlsRecord : public TlsRecordFilter {
+ public:
+ DropTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index)
+ : TlsRecordFilter(agent), index_(index), count_(0) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ if (count_++ == index_) {
+ return DROP;
+ }
+ return KEEP;
+ }
+
+ private:
+ const size_t index_;
+ size_t count_;
+};
+
+// Test that decrypting filters work correctly and are able to drop records.
+TEST_F(TlsConnectStreamTls13, DropRecordServer) {
+ EnsureTlsSetup();
+ // Disable session tickets so that the server doesn't send an extra record.
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
+
+ // 0 = ServerHello, 1 = other handshake, 2 = first write
+ auto filter = MakeTlsFilter<DropTlsRecord>(server_, 2);
+ filter->EnableDecryption();
+ Connect();
+ server_->SendData(23, 23); // This should be dropped, so it won't be counted.
+ server_->ResetSentBytes();
+ SendReceive();
+}
+
+TEST_F(TlsConnectStreamTls13, DropRecordClient) {
+ EnsureTlsSetup();
+ // 0 = ClientHello, 1 = Finished, 2 = first write
+ auto filter = MakeTlsFilter<DropTlsRecord>(client_, 2);
+ filter->EnableDecryption();
Connect();
+ client_->SendData(26, 26); // This should be dropped, so it won't be counted.
+ client_->ResetSentBytes();
SendReceive();
}
@@ -224,32 +307,74 @@ TEST_P(TlsConnectStream, ShortRead) {
ASSERT_EQ(50U, client_->received_bytes());
}
-TEST_P(TlsConnectGeneric, ConnectWithCompressionMaybe) {
+// We enable compression via the API but it's disabled internally,
+// so we should never get it.
+TEST_P(TlsConnectGeneric, ConnectWithCompressionEnabled) {
EnsureTlsSetup();
- client_->EnableCompression();
- server_->EnableCompression();
+ client_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE);
+ server_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE);
Connect();
- EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 &&
- variant_ != ssl_variant_datagram,
- client_->is_compressed());
+ EXPECT_FALSE(client_->is_compressed());
SendReceive();
}
-TEST_P(TlsConnectDatagram, TestDtlsHolddownExpiry) {
+class TlsHolddownTest : public TlsConnectDatagram {
+ protected:
+ // This causes all timers to run to completion. It advances the clock and
+ // handshakes on both peers until both peers have no more timers pending,
+ // which should happen at the end of a handshake. This is necessary to ensure
+ // that the relatively long holddown timer expires, but that any other timers
+ // also expire and run correctly.
+ void RunAllTimersDown() {
+ while (true) {
+ PRIntervalTime time;
+ SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time);
+ if (rv != SECSuccess) {
+ rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time);
+ if (rv != SECSuccess) {
+ break; // Neither peer has an outstanding timer.
+ }
+ }
+
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "Shifting timers" << std::endl;
+ }
+ ShiftDtlsTimers();
+ Handshake();
+ }
+ }
+};
+
+TEST_P(TlsHolddownTest, TestDtlsHolddownExpiry) {
Connect();
- std::cerr << "Expiring holddown timer\n";
- SSLInt_ForceTimerExpiry(client_->ssl_fd());
- SSLInt_ForceTimerExpiry(server_->ssl_fd());
+ std::cerr << "Expiring holddown timer" << std::endl;
+ RunAllTimersDown();
SendReceive();
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
// One for send, one for receive.
- EXPECT_EQ(2, SSLInt_CountTls13CipherSpecs(client_->ssl_fd()));
+ EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd()));
}
}
+TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ RunAllTimersDown();
+ SendReceive();
+ // One for send, one for receive.
+ EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd()));
+}
+
class TlsPreCCSHeaderInjector : public TlsRecordFilter {
public:
- TlsPreCCSHeaderInjector() {}
+ TlsPreCCSHeaderInjector(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent) {}
virtual PacketFilter::Action FilterRecord(
const TlsRecordHeader& record_header, const DataBuffer& input,
size_t* offset, DataBuffer* output) override {
@@ -266,16 +391,15 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter {
};
TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) {
- client_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>());
+ MakeTlsFilter<TlsPreCCSHeaderInjector>(client_);
ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
}
TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) {
- server_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>());
- client_->StartConnect();
- server_->StartConnect();
+ MakeTlsFilter<TlsPreCCSHeaderInjector>(server_);
+ StartConnect();
ExpectAlert(client_, kTlsAlertUnexpectedMessage);
Handshake();
EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
@@ -306,21 +430,64 @@ TEST_P(TlsConnectTls13, AlertWrongLevel) {
TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) {
EnsureTlsSetup();
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake();
server_->Handshake(); // Send first flight.
- client_->adapter()->CloseWrites();
+ client_->adapter()->SetWriteError(PR_IO_ERROR);
client_->Handshake(); // This will get an error, but shouldn't crash.
client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE);
}
-TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) {
- client_->SetShortHeadersEnabled();
- server_->SetShortHeadersEnabled();
- client_->ExpectShortHeaders();
- server_->ExpectShortHeaders();
+TEST_P(TlsConnectDatagram, BlockedWrite) {
Connect();
+
+ // Mark the socket as blocked.
+ client_->adapter()->SetWriteError(PR_WOULD_BLOCK_ERROR);
+ static const uint8_t data[] = {1, 2, 3};
+ int32_t rv = PR_Write(client_->ssl_fd(), data, sizeof(data));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Remove the write error and though the previous write failed, future reads
+ // and writes should just work as if it never happened.
+ client_->adapter()->SetWriteError(0);
+ SendReceive();
+}
+
+TEST_F(TlsConnectTest, ConnectSSLv3) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_3_0);
+ EnableOnlyStaticRsaCiphers();
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+TEST_F(TlsConnectTest, ConnectSSLv3ClientAuth) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_3_0);
+ EnableOnlyStaticRsaCiphers();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+static size_t ExpectedCbcLen(size_t in, size_t hmac = 20, size_t block = 16) {
+ // MAC-then-Encrypt expansion formula:
+ return ((in + hmac + (block - 1)) / block) * block;
+}
+
+TEST_F(TlsConnectTest, OneNRecordSplitting) {
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0);
+ EnsureTlsSetup();
+ ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA);
+ auto records = MakeTlsFilter<TlsRecordRecorder>(server_);
+ // This should be split into 1, 16384 and 20.
+ DataBuffer big_buffer;
+ big_buffer.Allocate(1 + 16384 + 20);
+ server_->SendBuffer(big_buffer);
+ ASSERT_EQ(3U, records->count());
+ EXPECT_EQ(ExpectedCbcLen(1), records->record(0).buffer.len());
+ EXPECT_EQ(ExpectedCbcLen(16384), records->record(1).buffer.len());
+ EXPECT_EQ(ExpectedCbcLen(20), records->record(2).buffer.len());
}
INSTANTIATE_TEST_CASE_P(
@@ -336,6 +503,8 @@ INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream,
TlsConnectTestBase::kTlsVAll);
INSTANTIATE_TEST_CASE_P(DatagramOnly, TlsConnectDatagram,
TlsConnectTestBase::kTlsV11Plus);
+INSTANTIATE_TEST_CASE_P(DatagramHolddown, TlsHolddownTest,
+ TlsConnectTestBase::kTlsV11Plus);
INSTANTIATE_TEST_CASE_P(
Pre12Stream, TlsConnectPre12,
@@ -368,4 +537,27 @@ INSTANTIATE_TEST_CASE_P(Version12Plus, TlsConnectTls12Plus,
::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV12Plus));
-} // namespace nspr_test
+INSTANTIATE_TEST_CASE_P(
+ GenericStream, TlsConnectGenericResumption,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll,
+ ::testing::Values(true, false)));
+INSTANTIATE_TEST_CASE_P(
+ GenericDatagram, TlsConnectGenericResumption,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus,
+ ::testing::Values(true, false)));
+
+INSTANTIATE_TEST_CASE_P(
+ GenericStream, TlsConnectGenericResumptionToken,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_CASE_P(
+ GenericDatagram, TlsConnectGenericResumptionToken,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus));
+
+INSTANTIATE_TEST_CASE_P(GenericDatagram, TlsConnectTls13ResumptionToken,
+ TlsConnectTestBase::kTlsVariantsAll);
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc
new file mode 100644
index 000000000..2b1b92dcd
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_misc_unittest.cc
@@ -0,0 +1,20 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "sslexp.h"
+
+#include "gtest_utils.h"
+
+namespace nss_test {
+
+class MiscTest : public ::testing::Test {};
+
+TEST_F(MiscTest, NonExistentExperimentalAPI) {
+ EXPECT_EQ(nullptr, SSL_GetExperimentalAPI("blah"));
+ EXPECT_EQ(SSL_ERROR_UNSUPPORTED_EXPERIMENTAL_API, PORT_GetError());
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc
index ef81b222c..3b8727850 100644
--- a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc
@@ -10,6 +10,8 @@
#include "databuffer.h"
#include "gtest_utils.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
namespace nss_test {
@@ -51,8 +53,8 @@ class TlsPaddingTest
<< " total length=" << plaintext_.len() << std::endl;
std::cerr << "Plaintext: " << plaintext_ << std::endl;
sslBuffer s;
- s.buf = const_cast<unsigned char *>(
- static_cast<const unsigned char *>(plaintext_.data()));
+ s.buf = const_cast<unsigned char*>(
+ static_cast<const unsigned char*>(plaintext_.data()));
s.len = plaintext_.len();
SECStatus rv = ssl_RemoveTLSCBCPadding(&s, kMacSize);
if (expect_success) {
@@ -99,6 +101,73 @@ TEST_P(TlsPaddingTest, LastByteOfPadWrong) {
}
}
+class RecordReplacer : public TlsRecordFilter {
+ public:
+ RecordReplacer(const std::shared_ptr<TlsAgent>& agent, size_t size)
+ : TlsRecordFilter(agent), enabled_(false), size_(size) {}
+
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ if (!enabled_) {
+ return KEEP;
+ }
+
+ EXPECT_EQ(kTlsApplicationDataType, header.content_type());
+ changed->Allocate(size_);
+
+ for (size_t i = 0; i < size_; ++i) {
+ changed->data()[i] = i & 0xff;
+ }
+
+ enabled_ = false;
+ return CHANGE;
+ }
+
+ void Enable() { enabled_ = true; }
+
+ private:
+ bool enabled_;
+ size_t size_;
+};
+
+TEST_F(TlsConnectStreamTls13, LargeRecord) {
+ EnsureTlsSetup();
+
+ const size_t record_limit = 16384;
+ auto replacer = MakeTlsFilter<RecordReplacer>(client_, record_limit);
+ replacer->EnableDecryption();
+ Connect();
+
+ replacer->Enable();
+ client_->SendData(10);
+ WAIT_(server_->received_bytes() == record_limit, 2000);
+ ASSERT_EQ(record_limit, server_->received_bytes());
+}
+
+TEST_F(TlsConnectStreamTls13, TooLargeRecord) {
+ EnsureTlsSetup();
+
+ const size_t record_limit = 16384;
+ auto replacer = MakeTlsFilter<RecordReplacer>(client_, record_limit + 1);
+ replacer->EnableDecryption();
+ Connect();
+
+ replacer->Enable();
+ ExpectAlert(server_, kTlsAlertRecordOverflow);
+ client_->SendData(10); // This is expanded.
+
+ uint8_t buf[record_limit + 2];
+ PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(SSL_ERROR_RX_RECORD_TOO_LONG, PORT_GetError());
+
+ // Read the server alert.
+ rv = PR_Read(client_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_GT(0, rv);
+ EXPECT_EQ(SSL_ERROR_RECORD_OVERFLOW_ALERT, PORT_GetError());
+}
+
const static size_t kContentSizesArr[] = {
1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288};
@@ -108,4 +177,4 @@ auto kTrueFalse = ::testing::ValuesIn(kTrueFalseArr);
INSTANTIATE_TEST_CASE_P(TlsPadding, TlsPaddingTest,
::testing::Combine(kContentSizes, kTrueFalse));
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc
new file mode 100644
index 000000000..a902a5f7f
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_renegotiation_unittest.cc
@@ -0,0 +1,212 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+// 1.3 is disabled in the next few tests because we don't
+// presently support resumption in 1.3.
+TEST_P(TlsConnectStreamPre13, RenegotiateClient) {
+ Connect();
+ server_->PrepareForRenegotiate();
+ client_->StartRenegotiate();
+ Handshake();
+ CheckConnected();
+}
+
+TEST_P(TlsConnectStreamPre13, RenegotiateServer) {
+ Connect();
+ client_->PrepareForRenegotiate();
+ server_->StartRenegotiate();
+ Handshake();
+ CheckConnected();
+}
+
+// The renegotiation options shouldn't cause an error if TLS 1.3 is chosen.
+TEST_F(TlsConnectTest, RenegotiationConfigTls13) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetOption(SSL_ENABLE_RENEGOTIATION, SSL_RENEGOTIATE_UNRESTRICTED);
+ server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
+ Connect();
+ SendReceive();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ // Set the client so it will accept any version from 1.0
+ // to |version_|.
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // Reset version so that the checks succeed.
+ uint16_t test_version = version_;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_0;
+ Connect();
+
+ // Now renegotiate, with the server being set to do
+ // |version_|.
+ client_->PrepareForRenegotiate();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+ server_->StartRenegotiate();
+
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ } else {
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+ }
+
+ Handshake();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In TLS 1.3, the server detects this problem.
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+ }
+}
+
+TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ // Set the client so it will accept any version from 1.0
+ // to |version_|.
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // Reset version so that the checks succeed.
+ uint16_t test_version = version_;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_0;
+ Connect();
+
+ // Now renegotiate, with the server being set to do
+ // |version_|.
+ server_->PrepareForRenegotiate();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+ client_->StartRenegotiate();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ } else {
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+ }
+ Handshake();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In TLS 1.3, the server detects this problem.
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+ }
+}
+
+TEST_P(TlsConnectStream, ConnectAndServerRenegotiateLower) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ Connect();
+
+ // Now renegotiate with the server set to TLS 1.0.
+ client_->PrepareForRenegotiate();
+ server_->PrepareForRenegotiate();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+
+ SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ EXPECT_EQ(SECFailure, rv);
+ return;
+ }
+ ASSERT_EQ(SECSuccess, rv);
+
+ // Now, before handshaking, tweak the server configuration.
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+
+ // The server should catch the own error.
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
+
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+}
+
+TEST_P(TlsConnectStream, ConnectAndServerWontRenegotiateLower) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ Connect();
+
+ // Now renegotiate with the server set to TLS 1.0.
+ client_->PrepareForRenegotiate();
+ server_->PrepareForRenegotiate();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+
+ EXPECT_EQ(SECFailure, SSL_ReHandshake(server_->ssl_fd(), PR_TRUE));
+}
+
+TEST_P(TlsConnectStream, ConnectAndClientWontRenegotiateLower) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ Connect();
+
+ // Now renegotiate with the client set to TLS 1.0.
+ client_->PrepareForRenegotiate();
+ server_->PrepareForRenegotiate();
+ server_->ResetPreliminaryInfo();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // The client will refuse to renegotiate down.
+ EXPECT_EQ(SECFailure, SSL_ReHandshake(client_->ssl_fd(), PR_TRUE));
+}
+
+TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
+}
+
+TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
index ce0e3ca8d..eb78c0585 100644
--- a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
@@ -9,6 +9,7 @@
#include "secerr.h"
#include "ssl.h"
#include "sslerr.h"
+#include "sslexp.h"
#include "sslproto.h"
extern "C" {
@@ -59,7 +60,7 @@ TEST_P(TlsConnectGenericPre13, ConnectResumed) {
Connect();
}
-TEST_P(TlsConnectGeneric, ConnectClientCacheDisabled) {
+TEST_P(TlsConnectGenericResumption, ConnectClientCacheDisabled) {
ConfigureSessionCache(RESUME_NONE, RESUME_SESSIONID);
Connect();
SendReceive();
@@ -70,7 +71,7 @@ TEST_P(TlsConnectGeneric, ConnectClientCacheDisabled) {
SendReceive();
}
-TEST_P(TlsConnectGeneric, ConnectServerCacheDisabled) {
+TEST_P(TlsConnectGenericResumption, ConnectServerCacheDisabled) {
ConfigureSessionCache(RESUME_SESSIONID, RESUME_NONE);
Connect();
SendReceive();
@@ -81,7 +82,7 @@ TEST_P(TlsConnectGeneric, ConnectServerCacheDisabled) {
SendReceive();
}
-TEST_P(TlsConnectGeneric, ConnectSessionCacheDisabled) {
+TEST_P(TlsConnectGenericResumption, ConnectSessionCacheDisabled) {
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
SendReceive();
@@ -92,7 +93,7 @@ TEST_P(TlsConnectGeneric, ConnectSessionCacheDisabled) {
SendReceive();
}
-TEST_P(TlsConnectGeneric, ConnectResumeSupportBoth) {
+TEST_P(TlsConnectGenericResumption, ConnectResumeSupportBoth) {
// This prefers tickets.
ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
Connect();
@@ -105,7 +106,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeSupportBoth) {
SendReceive();
}
-TEST_P(TlsConnectGeneric, ConnectResumeClientTicketServerBoth) {
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientTicketServerBoth) {
// This causes no resumption because the client needs the
// session cache to resume even with tickets.
ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH);
@@ -119,7 +120,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientTicketServerBoth) {
SendReceive();
}
-TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicket) {
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientBothTicketServerTicket) {
// This causes a ticket resumption.
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
Connect();
@@ -132,7 +133,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicket) {
SendReceive();
}
-TEST_P(TlsConnectGeneric, ConnectResumeClientServerTicketOnly) {
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientServerTicketOnly) {
// This causes no resumption because the client needs the
// session cache to resume even with tickets.
ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
@@ -146,7 +147,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientServerTicketOnly) {
SendReceive();
}
-TEST_P(TlsConnectGeneric, ConnectResumeClientBothServerNone) {
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientBothServerNone) {
ConfigureSessionCache(RESUME_BOTH, RESUME_NONE);
Connect();
SendReceive();
@@ -158,7 +159,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothServerNone) {
SendReceive();
}
-TEST_P(TlsConnectGeneric, ConnectResumeClientNoneServerBoth) {
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientNoneServerBoth) {
ConfigureSessionCache(RESUME_NONE, RESUME_BOTH);
Connect();
SendReceive();
@@ -201,7 +202,7 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicketForget) {
SendReceive();
}
-TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtClient) {
+TEST_P(TlsConnectGenericResumption, ConnectWithExpiredTicketAtClient) {
SSLInt_SetTicketLifetime(1); // one second
// This causes a ticket resumption.
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
@@ -218,8 +219,7 @@ TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtClient) {
SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3)
? ssl_tls13_pre_shared_key_xtn
: ssl_session_ticket_xtn;
- auto capture = std::make_shared<TlsExtensionCapture>(xtn);
- client_->SetPacketFilter(capture);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, xtn);
Connect();
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
@@ -244,10 +244,8 @@ TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) {
SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3)
? ssl_tls13_pre_shared_key_xtn
: ssl_session_ticket_xtn;
- auto capture = std::make_shared<TlsExtensionCapture>(xtn);
- client_->SetPacketFilter(capture);
- client_->StartConnect();
- server_->StartConnect();
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(client_, xtn);
+ StartConnect();
client_->Handshake();
EXPECT_TRUE(capture->captured());
EXPECT_LT(0U, capture->extension().len());
@@ -327,25 +325,23 @@ TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) {
// Prior to TLS 1.3, we were not fully ephemeral; though 1.3 fixes that
TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) {
- auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i1);
+ auto filter = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe1;
- EXPECT_TRUE(dhe1.Parse(i1->buffer()));
+ EXPECT_TRUE(dhe1.Parse(filter->buffer()));
// Restart
Reset();
- auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i2);
+ auto filter2 = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe2;
- EXPECT_TRUE(dhe2.Parse(i2->buffer()));
+ EXPECT_TRUE(dhe2.Parse(filter2->buffer()));
// Make sure they are the same.
EXPECT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len());
@@ -355,32 +351,25 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) {
// This test parses the ServerKeyExchange, which isn't in 1.3
TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) {
- server_->EnsureTlsSetup();
- SECStatus rv =
- SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
- auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i1);
+ server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
+ auto filter = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe1;
- EXPECT_TRUE(dhe1.Parse(i1->buffer()));
+ EXPECT_TRUE(dhe1.Parse(filter->buffer()));
// Restart
Reset();
- server_->EnsureTlsSetup();
- rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
- auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerKeyExchange);
- server_->SetPacketFilter(i2);
+ server_->SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
+ auto filter2 = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeServerKeyExchange);
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
CheckKeys();
TlsServerKeyExchangeEcdhe dhe2;
- EXPECT_TRUE(dhe2.Parse(i2->buffer()));
+ EXPECT_TRUE(dhe2.Parse(filter2->buffer()));
// Make sure they are different.
EXPECT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) &&
@@ -401,7 +390,8 @@ TEST_P(TlsConnectTls13, TestTls13ResumeDifferentGroup) {
client_->ConfigNamedGroups(kFFDHEGroups);
server_->ConfigNamedGroups(kFFDHEGroups);
Connect();
- CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none);
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
}
// We need to enable different cipher suites at different times in the following
@@ -421,7 +411,7 @@ static uint16_t ChooseAnotherCipher(uint16_t version) {
}
// Test that we don't resume when we can't negotiate the same cipher.
-TEST_P(TlsConnectGeneric, TestResumeClientDifferentCipher) {
+TEST_P(TlsConnectGenericResumption, TestResumeClientDifferentCipher) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
client_->EnableSingleCipher(ChooseOneCipher(version_));
Connect();
@@ -438,15 +428,15 @@ TEST_P(TlsConnectGeneric, TestResumeClientDifferentCipher) {
} else {
ticket_extension = ssl_session_ticket_xtn;
}
- auto ticket_capture = std::make_shared<TlsExtensionCapture>(ticket_extension);
- client_->SetPacketFilter(ticket_capture);
+ auto ticket_capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ticket_extension);
Connect();
CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
EXPECT_EQ(0U, ticket_capture->extension().len());
}
// Test that we don't resume when we can't negotiate the same cipher.
-TEST_P(TlsConnectGeneric, TestResumeServerDifferentCipher) {
+TEST_P(TlsConnectGenericResumption, TestResumeServerDifferentCipher) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
server_->EnableSingleCipher(ChooseOneCipher(version_));
Connect();
@@ -461,36 +451,6 @@ TEST_P(TlsConnectGeneric, TestResumeServerDifferentCipher) {
CheckKeys();
}
-class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
- public:
- SelectedCipherSuiteReplacer(uint16_t suite) : cipher_suite_(suite) {}
-
- protected:
- PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
- const DataBuffer& input,
- DataBuffer* output) override {
- if (header.handshake_type() != kTlsHandshakeServerHello) {
- return KEEP;
- }
-
- *output = input;
- uint32_t temp = 0;
- EXPECT_TRUE(input.Read(0, 2, &temp));
- // Cipher suite is after version(2) and random(32).
- size_t pos = 34;
- if (temp < SSL_LIBRARY_VERSION_TLS_1_3) {
- // In old versions, we have to skip a session_id too.
- EXPECT_TRUE(input.Read(pos, 1, &temp));
- pos += 1 + temp;
- }
- output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
- return CHANGE;
- }
-
- private:
- uint16_t cipher_suite_;
-};
-
// Test that the client doesn't tolerate the server picking a different cipher
// suite for resumption.
TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
@@ -502,8 +462,8 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
Reset();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- server_->SetPacketFilter(std::make_shared<SelectedCipherSuiteReplacer>(
- ChooseAnotherCipher(version_)));
+ MakeTlsFilter<SelectedCipherSuiteReplacer>(server_,
+ ChooseAnotherCipher(version_));
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
client_->ExpectSendAlert(kTlsAlertIllegalParameter);
@@ -524,16 +484,15 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
class SelectedVersionReplacer : public TlsHandshakeFilter {
public:
- SelectedVersionReplacer(uint16_t version) : version_(version) {}
+ SelectedVersionReplacer(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t version)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}),
+ version_(version) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) override {
- if (header.handshake_type() != kTlsHandshakeServerHello) {
- return KEEP;
- }
-
*output = input;
output->Write(0, static_cast<uint32_t>(version_), 2);
return CHANGE;
@@ -580,8 +539,7 @@ TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) {
// Enable the lower version on the client.
client_->SetVersionRange(version_ - 1, version_);
server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
- server_->SetPacketFilter(
- std::make_shared<SelectedVersionReplacer>(override_version));
+ MakeTlsFilter<SelectedVersionReplacer>(server_, override_version);
ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
@@ -604,12 +562,12 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
ExpectResumption(RESUME_TICKET);
- auto c1 = std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(c1);
+ auto c1 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
Connect();
SendReceive();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
- ssl_sig_none);
+ ssl_sig_rsa_pss_rsae_sha256);
// The filter will go away when we reset, so save the captured extension.
DataBuffer initialTicket(c1->extension());
ASSERT_LT(0U, initialTicket.len());
@@ -621,13 +579,13 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
ClearStats();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
- auto c2 = std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
- client_->SetPacketFilter(c2);
+ auto c2 =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
ExpectResumption(RESUME_TICKET);
Connect();
SendReceive();
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
- ssl_sig_none);
+ ssl_sig_rsa_pss_rsae_sha256);
ASSERT_LT(0U, c2->extension().len());
ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd()));
@@ -652,7 +610,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) {
// Clear the session ticket keys to invalidate the old ticket.
SSLInt_ClearSelfEncryptKey();
- SSLInt_SendNewSessionTicket(server_->ssl_fd());
+ SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0);
SendReceive(); // Need to read so that we absorb the session tickets.
CheckKeys();
@@ -666,6 +624,144 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) {
SendReceive();
}
+// Check that the value captured in a NewSessionTicket message matches the value
+// captured from a pre_shared_key extension.
+void NstTicketMatchesPskIdentity(const DataBuffer& nst, const DataBuffer& psk) {
+ uint32_t len;
+
+ size_t offset = 4 + 4; // Skip ticket_lifetime and ticket_age_add.
+ ASSERT_TRUE(nst.Read(offset, 1, &len));
+ offset += 1 + len; // Skip ticket_nonce.
+
+ ASSERT_TRUE(nst.Read(offset, 2, &len));
+ offset += 2; // Skip the ticket length.
+ ASSERT_LE(offset + len, nst.len());
+ DataBuffer nst_ticket(nst.data() + offset, static_cast<size_t>(len));
+
+ offset = 2; // Skip the identities length.
+ ASSERT_TRUE(psk.Read(offset, 2, &len));
+ offset += 2; // Skip the identity length.
+ ASSERT_LE(offset + len, psk.len());
+ DataBuffer psk_ticket(psk.data() + offset, static_cast<size_t>(len));
+
+ EXPECT_EQ(nst_ticket, psk_ticket);
+}
+
+TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNSTWithToken) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ auto nst_capture =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket);
+ nst_capture->EnableDecryption();
+ Connect();
+
+ // Clear the session ticket keys to invalidate the old ticket.
+ SSLInt_ClearSelfEncryptKey();
+ nst_capture->Reset();
+ uint8_t token[] = {0x20, 0x20, 0xff, 0x00};
+ EXPECT_EQ(SECSuccess,
+ SSL_SendSessionTicket(server_->ssl_fd(), token, sizeof(token)));
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+ EXPECT_LT(0U, nst_capture->buffer().len());
+
+ // Resume the connection.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+
+ auto psk_capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
+ Connect();
+ SendReceive();
+
+ NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension());
+}
+
+// Disable SSL_ENABLE_SESSION_TICKETS but ensure that tickets can still be sent
+// by invoking SSL_SendSessionTicket directly (and that the ticket is usable).
+TEST_F(TlsConnectTest, SendSessionTicketWithTicketsDisabled) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_ENABLE_SESSION_TICKETS, PR_FALSE));
+
+ auto nst_capture =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, ssl_hs_new_session_ticket);
+ nst_capture->EnableDecryption();
+ Connect();
+
+ EXPECT_EQ(0U, nst_capture->buffer().len()) << "expect nothing captured yet";
+
+ EXPECT_EQ(SECSuccess, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0));
+ EXPECT_LT(0U, nst_capture->buffer().len()) << "should capture now";
+
+ SendReceive(); // Ensure that the client reads the ticket.
+
+ // Resume the connection.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+
+ auto psk_capture =
+ MakeTlsFilter<TlsExtensionCapture>(client_, ssl_tls13_pre_shared_key_xtn);
+ Connect();
+ SendReceive();
+
+ NstTicketMatchesPskIdentity(nst_capture->buffer(), psk_capture->extension());
+}
+
+// Test calling SSL_SendSessionTicket in inappropriate conditions.
+TEST_F(TlsConnectTest, SendSessionTicketInappropriate) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(client_->ssl_fd(), NULL, 0))
+ << "clients can't send tickets";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ StartConnect();
+
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0))
+ << "no ticket before the handshake has started";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+ Handshake();
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0))
+ << "no special tickets in TLS 1.2";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectTest, SendSessionTicketMassiveToken) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ // It should be safe to set length with a NULL token because the length should
+ // be checked before reading token.
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0x1ffff))
+ << "this is clearly too big";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+
+ static const uint8_t big_token[0xffff] = {1};
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), big_token,
+ sizeof(big_token)))
+ << "this is too big, but that's not immediately obvious";
+ EXPECT_EQ(SEC_ERROR_INVALID_ARGS, PORT_GetError());
+}
+
+TEST_F(TlsConnectDatagram13, SendSessionTicketDtls) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ EXPECT_EQ(SECFailure, SSL_SendSessionTicket(server_->ssl_fd(), NULL, 0))
+ << "no extra tickets in DTLS until we have Ack support";
+ EXPECT_EQ(SSL_ERROR_FEATURE_NOT_SUPPORTED_FOR_VERSION, PORT_GetError());
+}
+
TEST_F(TlsConnectTest, TestTls13ResumptionDowngrade) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
@@ -716,16 +812,220 @@ TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) {
// We will eventually fail the (sid.version == SH.version) check.
std::vector<std::shared_ptr<PacketFilter>> filters;
filters.push_back(std::make_shared<SelectedCipherSuiteReplacer>(
- TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256));
+ server_, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256));
+ filters.push_back(std::make_shared<SelectedVersionReplacer>(
+ server_, SSL_LIBRARY_VERSION_TLS_1_2));
+
+ // Drop a bunch of extensions so that we get past the SH processing. The
+ // version extension says TLS 1.3, which is counter to our goal, the others
+ // are not permitted in TLS 1.2 handshakes.
+ filters.push_back(std::make_shared<TlsExtensionDropper>(
+ server_, ssl_tls13_supported_versions_xtn));
filters.push_back(
- std::make_shared<SelectedVersionReplacer>(SSL_LIBRARY_VERSION_TLS_1_2));
- server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(filters));
-
- client_->ExpectSendAlert(kTlsAlertDecodeError);
+ std::make_shared<TlsExtensionDropper>(server_, ssl_tls13_key_share_xtn));
+ filters.push_back(std::make_shared<TlsExtensionDropper>(
+ server_, ssl_tls13_pre_shared_key_xtn));
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(filters));
+
+ // The client here generates an unexpected_message alert when it receives an
+ // encrypted handshake message from the server (EncryptedExtension). The
+ // client expects to receive an unencrypted TLS 1.2 Certificate message.
+ // The server can't decrypt the alert.
+ client_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
server_->ExpectSendAlert(kTlsAlertBadRecordMac); // Server can't read
ConnectExpectFail();
- client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA);
server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
}
+TEST_P(TlsConnectGenericResumption, ReConnectTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ // Resume
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ // Only the client knows this.
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
+}
+
+TEST_P(TlsConnectGenericPre13, ReConnectCache) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ // Resume
+ Reset();
+ ExpectResumption(RESUME_SESSIONID);
+ Connect();
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
+}
+
+TEST_P(TlsConnectGenericResumption, ReConnectAgainTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_rsae_sha256);
+ // Resume
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ // Only the client knows this.
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
+ // Resume connection again
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET, 2);
+ Connect();
+ // Only the client knows this.
+ CheckKeysResumption(ssl_kea_ecdh, ssl_grp_none, ssl_grp_ec_curve25519,
+ ssl_auth_rsa_sign, ssl_sig_rsa_pss_rsae_sha256);
+}
+
+void CheckGetInfoResult(uint32_t alpnSize, uint32_t earlyDataSize,
+ ScopedCERTCertificate& cert,
+ ScopedSSLResumptionTokenInfo& token) {
+ ASSERT_TRUE(cert);
+ ASSERT_TRUE(token->peerCert);
+
+ // Check that the server cert is the correct one.
+ ASSERT_EQ(cert->derCert.len, token->peerCert->derCert.len);
+ EXPECT_EQ(0, memcmp(cert->derCert.data, token->peerCert->derCert.data,
+ cert->derCert.len));
+
+ ASSERT_EQ(alpnSize, token->alpnSelectionLen);
+ EXPECT_EQ(0, memcmp("a", token->alpnSelection, token->alpnSelectionLen));
+
+ ASSERT_EQ(earlyDataSize, token->maxEarlyDataSize);
+}
+
+TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfo) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+
+ StartConnect();
+ ASSERT_TRUE(client_->MaybeSetResumptionToken());
+
+ // Get resumption token infos
+ SSLResumptionTokenInfo tokenInfo = {0};
+ ScopedSSLResumptionTokenInfo token(&tokenInfo);
+ client_->GetTokenInfo(token);
+ ScopedCERTCertificate cert(
+ PK11_FindCertFromNickname(server_->name().c_str(), nullptr));
+
+ CheckGetInfoResult(0, 0, cert, token);
+
+ Handshake();
+ CheckConnected();
+
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericResumptionToken, ConnectResumeGetInfoAlpn) {
+ EnableAlpn();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+ CheckAlpn("a");
+ SendReceive();
+
+ Reset();
+ EnableAlpn();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+
+ StartConnect();
+ ASSERT_TRUE(client_->MaybeSetResumptionToken());
+
+ // Get resumption token infos
+ SSLResumptionTokenInfo tokenInfo = {0};
+ ScopedSSLResumptionTokenInfo token(&tokenInfo);
+ client_->GetTokenInfo(token);
+ ScopedCERTCertificate cert(
+ PK11_FindCertFromNickname(server_->name().c_str(), nullptr));
+
+ CheckGetInfoResult(1, 0, cert, token);
+
+ Handshake();
+ CheckConnected();
+ CheckAlpn("a");
+
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13ResumptionToken, ConnectResumeGetInfoZeroRtt) {
+ EnableAlpn();
+ SSLInt_RolloverAntiReplay();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ server_->Set0RttEnabled(true);
+ Connect();
+ CheckAlpn("a");
+ SendReceive();
+
+ Reset();
+ EnableAlpn();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+
+ StartConnect();
+ server_->Set0RttEnabled(true);
+ client_->Set0RttEnabled(true);
+ ASSERT_TRUE(client_->MaybeSetResumptionToken());
+
+ // Get resumption token infos
+ SSLResumptionTokenInfo tokenInfo = {0};
+ ScopedSSLResumptionTokenInfo token(&tokenInfo);
+ client_->GetTokenInfo(token);
+ ScopedCERTCertificate cert(
+ PK11_FindCertFromNickname(server_->name().c_str(), nullptr));
+
+ CheckGetInfoResult(1, 1024, cert, token);
+
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ CheckAlpn("a");
+
+ SendReceive();
+}
+
+// Resumption on sessions with client authentication only works with internal
+// caching.
+TEST_P(TlsConnectGenericResumption, ConnectResumeClientAuth) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ SendReceive();
+ EXPECT_FALSE(client_->resumption_callback_called());
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ if (use_external_cache()) {
+ ExpectResumption(RESUME_NONE);
+ } else {
+ ExpectResumption(RESUME_TICKET);
+ }
+ Connect();
+ SendReceive();
+}
+
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
index a130ef77f..e4a9e5aed 100644
--- a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -22,8 +22,11 @@ namespace nss_test {
class TlsHandshakeSkipFilter : public TlsRecordFilter {
public:
// A TLS record filter that skips handshake messages of the identified type.
- TlsHandshakeSkipFilter(uint8_t handshake_type)
- : handshake_type_(handshake_type), skipped_(false) {}
+ TlsHandshakeSkipFilter(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type)
+ : TlsRecordFilter(agent),
+ handshake_type_(handshake_type),
+ skipped_(false) {}
protected:
// Takes a record; if it is a handshake record, it removes the first handshake
@@ -43,7 +46,14 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter {
size_t start = parser.consumed();
TlsHandshakeFilter::HandshakeHeader header;
DataBuffer ignored;
- if (!header.Parse(&parser, record_header, &ignored)) {
+ bool complete = false;
+ if (!header.Parse(&parser, record_header, DataBuffer(), &ignored,
+ &complete)) {
+ ADD_FAILURE() << "Error parsing handshake header";
+ return KEEP;
+ }
+ if (!complete) {
+ ADD_FAILURE() << "Don't want to deal with fragmented input";
return KEEP;
}
@@ -85,9 +95,14 @@ class TlsSkipTest : public TlsConnectTestBase,
TlsSkipTest()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
+ EnsureTlsSetup();
+ }
+
void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
uint8_t alert = kTlsAlertUnexpectedMessage) {
- server_->SetPacketFilter(filter);
+ server_->SetFilter(filter);
ConnectExpectAlert(client_, alert);
}
};
@@ -98,29 +113,23 @@ class Tls13SkipTest : public TlsConnectTestBase,
Tls13SkipTest()
: TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
- void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ void SetUp() override {
+ TlsConnectTestBase::SetUp();
EnsureTlsSetup();
- server_->SetTlsRecordFilter(filter);
+ }
+
+ void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
filter->EnableDecryption();
- client_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
- if (variant_ == ssl_variant_stream) {
- server_->ExpectSendAlert(kTlsAlertBadRecordMac);
- ConnectExpectFail();
- } else {
- ConnectExpectFailOneSide(TlsAgent::CLIENT);
- }
+ server_->SetFilter(filter);
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
+ ConnectExpectFail();
client_->CheckErrorCode(error);
- if (variant_ == ssl_variant_stream) {
- server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
- } else {
- ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
- }
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
}
void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
- EnsureTlsSetup();
- client_->SetTlsRecordFilter(filter);
filter->EnableDecryption();
+ client_->SetFilter(filter);
server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
ConnectExpectFailOneSide(TlsAgent::SERVER);
@@ -133,49 +142,49 @@ class Tls13SkipTest : public TlsConnectTestBase,
TEST_P(TlsSkipTest, SkipCertificateRsa) {
EnableOnlyStaticRsaCiphers();
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertificateDhe) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdhe) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipServerKeyExchange) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
- auto chain = std::make_shared<ChainedPacketFilter>();
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ auto chain = std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit{std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange)});
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
@@ -183,48 +192,48 @@ TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
auto chain = std::make_shared<ChainedPacketFilter>();
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
- chain->Add(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate));
+ chain->Add(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeServerKeyExchange));
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(Tls13SkipTest, SkipEncryptedExtensions) {
ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
- kTlsHandshakeEncryptedExtensions),
+ server_, kTlsHandshakeEncryptedExtensions),
SSL_ERROR_RX_UNEXPECTED_CERTIFICATE);
}
TEST_P(Tls13SkipTest, SkipServerCertificate) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
- SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}
TEST_P(Tls13SkipTest, SkipServerCertificateVerify) {
- ServerSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
- SSL_ERROR_RX_UNEXPECTED_FINISHED);
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ server_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
}
TEST_P(Tls13SkipTest, SkipClientCertificate) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
- ClientSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
- SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}
TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
client_->SetupClientAuth();
server_->RequestClientAuth(true);
client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
- ClientSkipTest(
- std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
- SSL_ERROR_RX_UNEXPECTED_FINISHED);
+ ClientSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ client_, kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
}
INSTANTIATE_TEST_CASE_P(
diff --git a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
index 8db1f30e1..e5fccc12b 100644
--- a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
@@ -48,10 +48,9 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSA) {
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
EnableOnlyStaticRsaCiphers();
- auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeClientKeyExchange,
+ MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeClientKeyExchange,
DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
- client_->SetPacketFilter(i1);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -59,8 +58,7 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
EnableOnlyStaticRsaCiphers();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -69,9 +67,8 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
// ConnectStaticRSABogusPMSVersionDetect.
TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
EnableOnlyStaticRsaCiphers();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
- server_->DisableRollbackDetection();
+ MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_);
+ server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE);
Connect();
}
@@ -79,10 +76,9 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- auto inspect = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
- kTlsHandshakeClientKeyExchange,
+ MakeTlsFilter<TlsInspectorReplaceHandshakeMessage>(
+ client_, kTlsHandshakeClientKeyExchange,
DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
- client_->SetPacketFilter(inspect);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -91,8 +87,7 @@ TEST_P(TlsConnectStreamPre13,
ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_);
ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
@@ -100,10 +95,9 @@ TEST_P(TlsConnectStreamPre13,
ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
- server_->DisableRollbackDetection();
+ MakeTlsFilter<TlsClientHelloVersionChanger>(client_, server_);
+ server_->SetOption(SSL_ROLLBACK_DETECTION, PR_FALSE);
Connect();
}
-} // namespace nspr_test
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
new file mode 100644
index 000000000..f5ccf096b
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_tls13compat_unittest.cc
@@ -0,0 +1,363 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <memory>
+#include <vector>
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+class Tls13CompatTest : public TlsConnectStreamTls13 {
+ protected:
+ void EnableCompatMode() {
+ client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE);
+ }
+
+ void InstallFilters() {
+ EnsureTlsSetup();
+ client_recorders_.Install(client_);
+ server_recorders_.Install(server_);
+ }
+
+ void CheckRecordVersions() {
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_0,
+ client_recorders_.records_->record(0).header.version());
+ CheckRecordsAreTls12("client", client_recorders_.records_, 1);
+ CheckRecordsAreTls12("server", server_recorders_.records_, 0);
+ }
+
+ void CheckHelloVersions() {
+ uint32_t ver;
+ ASSERT_TRUE(server_recorders_.hello_->buffer().Read(0, 2, &ver));
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver));
+ ASSERT_TRUE(client_recorders_.hello_->buffer().Read(0, 2, &ver));
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_2, static_cast<uint16_t>(ver));
+ }
+
+ void CheckForCCS(bool expected_client, bool expected_server) {
+ client_recorders_.CheckForCCS(expected_client);
+ server_recorders_.CheckForCCS(expected_server);
+ }
+
+ void CheckForRegularHandshake() {
+ CheckRecordVersions();
+ CheckHelloVersions();
+ EXPECT_EQ(0U, client_recorders_.session_id_length());
+ EXPECT_EQ(0U, server_recorders_.session_id_length());
+ CheckForCCS(false, false);
+ }
+
+ void CheckForCompatHandshake() {
+ CheckRecordVersions();
+ CheckHelloVersions();
+ EXPECT_EQ(32U, client_recorders_.session_id_length());
+ EXPECT_EQ(32U, server_recorders_.session_id_length());
+ CheckForCCS(true, true);
+ }
+
+ private:
+ struct Recorders {
+ Recorders() : records_(nullptr), hello_(nullptr) {}
+
+ uint8_t session_id_length() const {
+ // session_id is always after version (2) and random (32).
+ uint32_t len = 0;
+ EXPECT_TRUE(hello_->buffer().Read(2 + 32, 1, &len));
+ return static_cast<uint8_t>(len);
+ }
+
+ void CheckForCCS(bool expected) const {
+ EXPECT_LT(0U, records_->count());
+ for (size_t i = 0; i < records_->count(); ++i) {
+ // Only the second record can be a CCS.
+ bool expected_match = expected && (i == 1);
+ EXPECT_EQ(expected_match,
+ kTlsChangeCipherSpecType ==
+ records_->record(i).header.content_type());
+ }
+ }
+
+ void Install(std::shared_ptr<TlsAgent>& agent) {
+ if (records_ && records_->agent() == agent) {
+ // Avoid replacing the filters if they are already installed on this
+ // agent. This ensures that InstallFilters() can be used after
+ // MakeNewServer() without losing state on the client filters.
+ return;
+ }
+ records_.reset(new TlsRecordRecorder(agent));
+ hello_.reset(new TlsHandshakeRecorder(
+ agent, std::set<uint8_t>(
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello})));
+ agent->SetFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({records_, hello_})));
+ }
+
+ std::shared_ptr<TlsRecordRecorder> records_;
+ std::shared_ptr<TlsHandshakeRecorder> hello_;
+ };
+
+ void CheckRecordsAreTls12(const std::string& agent,
+ const std::shared_ptr<TlsRecordRecorder>& records,
+ size_t start) {
+ EXPECT_LE(start, records->count());
+ for (size_t i = start; i < records->count(); ++i) {
+ EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_2,
+ records->record(i).header.version())
+ << agent << ": record " << i << " has wrong version";
+ }
+ }
+
+ Recorders client_recorders_;
+ Recorders server_recorders_;
+};
+
+TEST_F(Tls13CompatTest, Disabled) {
+ InstallFilters();
+ Connect();
+ CheckForRegularHandshake();
+}
+
+TEST_F(Tls13CompatTest, Enabled) {
+ EnableCompatMode();
+ InstallFilters();
+ Connect();
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledZeroRtt) {
+ SetupForZeroRtt();
+ EnableCompatMode();
+ InstallFilters();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ CheckForCCS(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledHrr) {
+ EnableCompatMode();
+ InstallFilters();
+
+ // Force a HelloRetryRequest. The server sends CCS immediately.
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ CheckForCCS(false, true);
+
+ Handshake();
+ CheckConnected();
+ CheckForCompatHandshake();
+}
+
+TEST_F(Tls13CompatTest, EnabledStatelessHrr) {
+ EnableCompatMode();
+ InstallFilters();
+
+ // Force a HelloRetryRequest
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+
+ // The server should send CCS before HRR.
+ CheckForCCS(false, true);
+
+ // A new server should complete the handshake, and not send CCS.
+ MakeNewServer();
+ InstallFilters();
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+
+ Handshake();
+ CheckConnected();
+ CheckRecordVersions();
+ CheckHelloVersions();
+ CheckForCCS(true, false);
+}
+
+TEST_F(Tls13CompatTest, EnabledHrrZeroRtt) {
+ SetupForZeroRtt();
+ EnableCompatMode();
+ InstallFilters();
+ server_->ConfigNamedGroups({ssl_grp_ec_secp384r1});
+
+ // With 0-RTT, the client sends CCS immediately. With HRR, the server sends
+ // CCS immediately too.
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false);
+ CheckForCCS(true, true);
+
+ Handshake();
+ ExpectEarlyDataAccepted(false);
+ CheckConnected();
+ CheckForCompatHandshake();
+}
+
+static const uint8_t kCannedCcs[] = {
+ kTlsChangeCipherSpecType,
+ SSL_LIBRARY_VERSION_TLS_1_2 >> 8,
+ SSL_LIBRARY_VERSION_TLS_1_2 & 0xff,
+ 0,
+ 1, // length
+ 1 // change_cipher_spec_choice
+};
+
+// A ChangeCipherSpec is ignored by a server because we have to tolerate it for
+// compatibility mode. That doesn't mean that we have to tolerate it
+// unconditionally. If we negotiate 1.3, we expect to see a cookie extension.
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello13) {
+ EnsureTlsSetup();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ // Client sends CCS before starting the handshake.
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+// A ChangeCipherSpec is ignored by a server because we have to tolerate it for
+// compatibility mode. That doesn't mean that we have to tolerate it
+// unconditionally. If we negotiate 1.3, we expect to see a cookie extension.
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHelloTwice) {
+ EnsureTlsSetup();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ // Client sends CCS before starting the handshake.
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+// If we negotiate 1.2, we abort.
+TEST_F(TlsConnectStreamTls13, ChangeCipherSpecBeforeClientHello12) {
+ EnsureTlsSetup();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ // Client sends CCS before starting the handshake.
+ client_->SendDirect(DataBuffer(kCannedCcs, sizeof(kCannedCcs)));
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+}
+
+TEST_F(TlsConnectDatagram13, CompatModeDtlsClient) {
+ EnsureTlsSetup();
+ client_->SetOption(SSL_ENABLE_TLS13_COMPAT_MODE, PR_TRUE);
+ auto client_records = MakeTlsFilter<TlsRecordRecorder>(client_);
+ auto server_records = MakeTlsFilter<TlsRecordRecorder>(server_);
+ Connect();
+
+ ASSERT_EQ(2U, client_records->count()); // CH, Fin
+ EXPECT_EQ(kTlsHandshakeType, client_records->record(0).header.content_type());
+ EXPECT_EQ(kTlsApplicationDataType,
+ client_records->record(1).header.content_type());
+
+ ASSERT_EQ(6U, server_records->count()); // SH, EE, CT, CV, Fin, Ack
+ EXPECT_EQ(kTlsHandshakeType, server_records->record(0).header.content_type());
+ for (size_t i = 1; i < server_records->count(); ++i) {
+ EXPECT_EQ(kTlsApplicationDataType,
+ server_records->record(i).header.content_type());
+ }
+}
+
+class AddSessionIdFilter : public TlsHandshakeFilter {
+ public:
+ AddSessionIdFilter(const std::shared_ptr<TlsAgent>& client)
+ : TlsHandshakeFilter(client, {ssl_hs_client_hello}) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ uint32_t session_id_len = 0;
+ EXPECT_TRUE(input.Read(2 + 32, 1, &session_id_len));
+ EXPECT_EQ(0U, session_id_len);
+ uint8_t session_id[33] = {32}; // 32 for length, the rest zero.
+ *output = input;
+ output->Splice(session_id, sizeof(session_id), 34, 1);
+ return CHANGE;
+ }
+};
+
+// Adding a session ID to a DTLS ClientHello should not trigger compatibility
+// mode. It should be ignored instead.
+TEST_F(TlsConnectDatagram13, CompatModeDtlsServer) {
+ EnsureTlsSetup();
+ auto client_records = std::make_shared<TlsRecordRecorder>(client_);
+ client_->SetFilter(
+ std::make_shared<ChainedPacketFilter>(ChainedPacketFilterInit(
+ {client_records, std::make_shared<AddSessionIdFilter>(client_)})));
+ auto server_hello =
+ std::make_shared<TlsHandshakeRecorder>(server_, kTlsHandshakeServerHello);
+ auto server_records = std::make_shared<TlsRecordRecorder>(server_);
+ server_->SetFilter(std::make_shared<ChainedPacketFilter>(
+ ChainedPacketFilterInit({server_records, server_hello})));
+ StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ // The client will consume the ServerHello, but discard everything else
+ // because it doesn't decrypt. And don't wait around for the client to ACK.
+ client_->Handshake();
+
+ ASSERT_EQ(1U, client_records->count());
+ EXPECT_EQ(kTlsHandshakeType, client_records->record(0).header.content_type());
+
+ ASSERT_EQ(5U, server_records->count()); // SH, EE, CT, CV, Fin
+ EXPECT_EQ(kTlsHandshakeType, server_records->record(0).header.content_type());
+ for (size_t i = 1; i < server_records->count(); ++i) {
+ EXPECT_EQ(kTlsApplicationDataType,
+ server_records->record(i).header.content_type());
+ }
+
+ uint32_t session_id_len = 0;
+ EXPECT_TRUE(server_hello->buffer().Read(2 + 32, 1, &session_id_len));
+ EXPECT_EQ(0U, session_id_len);
+}
+
+TEST_F(Tls13CompatTest, ConnectWith12ThenAttemptToResume13CompatMode) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+ Connect();
+
+ Reset();
+ ExpectResumption(RESUME_NONE);
+ version_ = SSL_LIBRARY_VERSION_TLS_1_3;
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ EnableCompatMode();
+ Connect();
+}
+
+} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
index 110e3e0b6..100595732 100644
--- a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
@@ -23,7 +23,8 @@ namespace nss_test {
// Replaces the client hello with an SSLv2 version once.
class SSLv2ClientHelloFilter : public PacketFilter {
public:
- SSLv2ClientHelloFilter(std::shared_ptr<TlsAgent>& client, uint16_t version)
+ SSLv2ClientHelloFilter(const std::shared_ptr<TlsAgent>& client,
+ uint16_t version)
: replaced_(false),
client_(client),
version_(version),
@@ -147,17 +148,9 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase {
SSLv2ClientHelloTestF(SSLProtocolVariant variant, uint16_t version)
: TlsConnectTestBase(variant, version), filter_(nullptr) {}
- void SetUp() {
+ void SetUp() override {
TlsConnectTestBase::SetUp();
- filter_ = std::make_shared<SSLv2ClientHelloFilter>(client_, version_);
- client_->SetPacketFilter(filter_);
- }
-
- void RequireSafeRenegotiation() {
- server_->EnsureTlsSetup();
- SECStatus rv =
- SSL_OptionSet(server_->ssl_fd(), SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
- EXPECT_EQ(rv, SECSuccess);
+ filter_ = MakeTlsFilter<SSLv2ClientHelloFilter>(client_, version_);
}
void SetExpectedVersion(uint16_t version) {
@@ -319,7 +312,7 @@ TEST_P(SSLv2ClientHelloTest, BigClientRandom) {
// Connection must fail if we require safe renegotiation but the client doesn't
// include TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites.
TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) {
- RequireSafeRenegotiation();
+ server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
EXPECT_EQ(SSL_ERROR_UNSAFE_NEGOTIATION, server_->error_code());
@@ -328,7 +321,7 @@ TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) {
// Connection must succeed when requiring safe renegotiation and the client
// includes TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites.
TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiationWithSCSV) {
- RequireSafeRenegotiation();
+ server_->SetOption(SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
TLS_EMPTY_RENEGOTIATION_INFO_SCSV};
SetAvailableCipherSuites(cipher_suites);
diff --git a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc
index 379a67e35..4e9099561 100644
--- a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc
@@ -56,18 +56,15 @@ TEST_P(TlsConnectGeneric, ServerNegotiateTls12) {
// two validate that we can also detect fallback using the
// SSL_SetDowngradeCheckVersion() API.
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_1));
+ MakeTlsFilter<TlsClientHelloVersionSetter>(client_,
+ SSL_LIBRARY_VERSION_TLS_1_1);
ConnectExpectFail();
ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
}
/* Attempt to negotiate the bogus DTLS 1.1 version. */
TEST_F(DtlsConnectTest, TestDtlsVersion11) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- ((~0x0101) & 0xffff)));
+ MakeTlsFilter<TlsClientHelloVersionSetter>(client_, ((~0x0101) & 0xffff));
ConnectExpectFail();
// It's kind of surprising that SSL_ERROR_NO_CYPHER_OVERLAP is
// what is returned here, but this is deliberate in ssl3_HandleAlert().
@@ -78,9 +75,8 @@ TEST_F(DtlsConnectTest, TestDtlsVersion11) {
// Disabled as long as we have draft version.
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) {
EnsureTlsSetup();
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_2));
+ MakeTlsFilter<TlsClientHelloVersionSetter>(client_,
+ SSL_LIBRARY_VERSION_TLS_1_2);
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
@@ -92,9 +88,8 @@ TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) {
// TLS 1.1 clients do not check the random values, so we should
// instead get a handshake failure alert from the server.
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_0));
+ MakeTlsFilter<TlsClientHelloVersionSetter>(client_,
+ SSL_LIBRARY_VERSION_TLS_1_0);
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
SSL_LIBRARY_VERSION_TLS_1_1);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
@@ -128,12 +123,12 @@ TEST_F(TlsConnectTest, TestFallbackFromTls13) {
#endif
TEST_P(TlsConnectGeneric, TestFallbackSCSVVersionMatch) {
- client_->SetFallbackSCSVEnabled(true);
+ client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE);
Connect();
}
TEST_P(TlsConnectGenericPre13, TestFallbackSCSVVersionMismatch) {
- client_->SetFallbackSCSVEnabled(true);
+ client_->SetOption(SSL_ENABLE_FALLBACK_SCSV, PR_TRUE);
server_->SetVersionRange(version_, version_ + 1);
ConnectExpectAlert(server_, kTlsAlertInappropriateFallback);
client_->CheckErrorCode(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT);
@@ -155,107 +150,10 @@ TEST_F(TlsConnectTest, DisallowSSLv3HelloWithTLSv13Enabled) {
EXPECT_EQ(SECFailure, rv);
}
-TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) {
- if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
- return;
- }
- // Set the client so it will accept any version from 1.0
- // to |version_|.
- client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
- SSL_LIBRARY_VERSION_TLS_1_0);
- // Reset version so that the checks succeed.
- uint16_t test_version = version_;
- version_ = SSL_LIBRARY_VERSION_TLS_1_0;
- Connect();
-
- // Now renegotiate, with the server being set to do
- // |version_|.
- client_->PrepareForRenegotiate();
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
- // Reset version and cipher suite so that the preinfo callback
- // doesn't fail.
- server_->ResetPreliminaryInfo();
- server_->StartRenegotiate();
-
- if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
- ExpectAlert(server_, kTlsAlertUnexpectedMessage);
- } else {
- ExpectAlert(client_, kTlsAlertIllegalParameter);
- }
-
- Handshake();
- if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
- // In TLS 1.3, the server detects this problem.
- client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
- server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
- } else {
- client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
- server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
- }
-}
-
-TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) {
- if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
- return;
- }
- // Set the client so it will accept any version from 1.0
- // to |version_|.
- client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
- SSL_LIBRARY_VERSION_TLS_1_0);
- // Reset version so that the checks succeed.
- uint16_t test_version = version_;
- version_ = SSL_LIBRARY_VERSION_TLS_1_0;
- Connect();
-
- // Now renegotiate, with the server being set to do
- // |version_|.
- server_->PrepareForRenegotiate();
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
- // Reset version and cipher suite so that the preinfo callback
- // doesn't fail.
- server_->ResetPreliminaryInfo();
- client_->StartRenegotiate();
- if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
- ExpectAlert(server_, kTlsAlertUnexpectedMessage);
- } else {
- ExpectAlert(client_, kTlsAlertIllegalParameter);
- }
- Handshake();
- if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
- // In TLS 1.3, the server detects this problem.
- client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
- server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
- } else {
- client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
- server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
- }
-}
-
-TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) {
- EnsureTlsSetup();
- ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
- Connect();
- SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE);
- EXPECT_EQ(SECFailure, rv);
- EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
-}
-
-TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) {
- EnsureTlsSetup();
- ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
- Connect();
- SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE);
- EXPECT_EQ(SECFailure, rv);
- EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
-}
-
TEST_P(TlsConnectGeneric, AlertBeforeServerHello) {
EnsureTlsSetup();
client_->ExpectReceiveAlert(kTlsAlertUnrecognizedName, kTlsAlertWarning);
- client_->StartConnect();
- server_->StartConnect();
+ StartConnect();
client_->Handshake(); // Send ClientHello.
static const uint8_t kWarningAlert[] = {kTlsAlertWarning,
kTlsAlertUnrecognizedName};
@@ -274,12 +172,10 @@ class Tls13NoSupportedVersions : public TlsConnectStreamTls12 {
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, max_server_version);
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- overwritten_client_version));
- auto capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerHello);
- server_->SetPacketFilter(capture);
+ MakeTlsFilter<TlsClientHelloVersionSetter>(client_,
+ overwritten_client_version);
+ auto capture =
+ MakeTlsFilter<TlsHandshakeRecorder>(server_, kTlsHandshakeServerHello);
ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
@@ -311,23 +207,21 @@ TEST_F(Tls13NoSupportedVersions,
// Offer 1.3 but with ClientHello.legacy_version == TLS 1.4. This
// causes a bad MAC error when we read EncryptedExtensions.
TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) {
- client_->SetPacketFilter(
- std::make_shared<TlsInspectorClientHelloVersionSetter>(
- SSL_LIBRARY_VERSION_TLS_1_3 + 1));
- auto capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeServerHello);
- server_->SetPacketFilter(capture);
+ MakeTlsFilter<TlsClientHelloVersionSetter>(client_,
+ SSL_LIBRARY_VERSION_TLS_1_3 + 1);
+ auto capture = MakeTlsFilter<TlsExtensionCapture>(
+ server_, ssl_tls13_supported_versions_xtn);
client_->ExpectSendAlert(kTlsAlertBadRecordMac);
server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
- const DataBuffer& server_hello = capture->buffer();
- ASSERT_GT(server_hello.len(), 2U);
- uint32_t ver;
- ASSERT_TRUE(server_hello.Read(0, 2, &ver));
+
+ ASSERT_EQ(2U, capture->extension().len());
+ uint32_t version = 0;
+ ASSERT_TRUE(capture->extension().Read(0, 2, &version));
// This way we don't need to change with new draft version.
- ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), ver);
+ ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), version);
}
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
index eda96831c..7f3c4a896 100644
--- a/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
@@ -189,12 +189,12 @@ class TestPolicyVersionRange
}
}
- void SetUp() {
- SetPolicy(policy_.range());
+ void SetUp() override {
TlsConnectTestBase::SetUp();
+ SetPolicy(policy_.range());
}
- void TearDown() {
+ void TearDown() override {
TlsConnectTestBase::TearDown();
saved_version_policy_.RestoreOriginalPolicy();
}
diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc
index b9f0c672e..728217851 100644
--- a/security/nss/gtests/ssl_gtest/test_io.cc
+++ b/security/nss/gtests/ssl_gtest/test_io.cc
@@ -25,10 +25,6 @@ namespace nss_test {
if (g_ssl_gtest_verbose) LOG(a); \
} while (false)
-void DummyPrSocket::SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
- filter_ = filter;
-}
-
ScopedPRFileDesc DummyPrSocket::CreateFD() {
static PRDescIdentity test_fd_identity =
PR_GetUniqueIdentity("testtransportadapter");
@@ -98,8 +94,13 @@ int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen,
}
int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) {
+ if (write_error_) {
+ PR_SetError(write_error_, 0);
+ return -1;
+ }
+
auto peer = peer_.lock();
- if (!peer || !writeable_) {
+ if (!peer) {
PR_SetError(PR_IO_ERROR, 0);
return -1;
}
@@ -109,7 +110,7 @@ int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) {
DataBuffer filtered;
PacketFilter::Action action = PacketFilter::KEEP;
if (filter_) {
- action = filter_->Filter(packet, &filtered);
+ action = filter_->Process(packet, &filtered);
}
switch (action) {
case PacketFilter::CHANGE:
diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h
index ac2497222..dbeb6b9d4 100644
--- a/security/nss/gtests/ssl_gtest/test_io.h
+++ b/security/nss/gtests/ssl_gtest/test_io.h
@@ -33,9 +33,18 @@ class PacketFilter {
CHANGE, // change the packet to a different value
DROP // drop the packet
};
-
+ PacketFilter(bool enabled = true) : enabled_(enabled) {}
virtual ~PacketFilter() {}
+ virtual Action Process(const DataBuffer& input, DataBuffer* output) {
+ if (!enabled_) {
+ return KEEP;
+ }
+ return Filter(input, output);
+ }
+ void Enable() { enabled_ = true; }
+ void Disable() { enabled_ = false; }
+
// The packet filter takes input and has the option of mutating it.
//
// A filter that modifies the data places the modified data in *output and
@@ -43,6 +52,9 @@ class PacketFilter {
// case the value in *output is ignored. A Filter can return DROP, in which
// case the packet is dropped (and *output is ignored).
virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0;
+
+ private:
+ bool enabled_;
};
class DummyPrSocket : public DummyIOLayerMethods {
@@ -53,7 +65,7 @@ class DummyPrSocket : public DummyIOLayerMethods {
peer_(),
input_(),
filter_(nullptr),
- writeable_(true) {}
+ write_error_(0) {}
virtual ~DummyPrSocket() {}
// Create a file descriptor that will reference this object. The fd must not
@@ -62,7 +74,9 @@ class DummyPrSocket : public DummyIOLayerMethods {
std::weak_ptr<DummyPrSocket>& peer() { return peer_; }
void SetPeer(const std::shared_ptr<DummyPrSocket>& peer) { peer_ = peer; }
- void SetPacketFilter(std::shared_ptr<PacketFilter> filter);
+ void SetPacketFilter(const std::shared_ptr<PacketFilter>& filter) {
+ filter_ = filter;
+ }
// Drops peer, packet filter and any outstanding packets.
void Reset();
@@ -71,7 +85,7 @@ class DummyPrSocket : public DummyIOLayerMethods {
int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags,
PRIntervalTime to) override;
int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override;
- void CloseWrites() { writeable_ = false; }
+ void SetWriteError(PRErrorCode code) { write_error_ = code; }
SSLProtocolVariant variant() const { return variant_; }
bool readable() const { return !input_.empty(); }
@@ -98,7 +112,7 @@ class DummyPrSocket : public DummyIOLayerMethods {
std::weak_ptr<DummyPrSocket> peer_;
std::queue<Packet> input_;
std::shared_ptr<PacketFilter> filter_;
- bool writeable_;
+ PRErrorCode write_error_;
};
// Marker interface.
@@ -164,6 +178,6 @@ class Poller {
timers_;
};
-} // end of namespace
+} // namespace nss_test
#endif
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc
index d6d91f7f7..2f71caedb 100644
--- a/security/nss/gtests/ssl_gtest/tls_agent.cc
+++ b/security/nss/gtests/ssl_gtest/tls_agent.cc
@@ -10,7 +10,9 @@
#include "pk11func.h"
#include "ssl.h"
#include "sslerr.h"
+#include "sslexp.h"
#include "sslproto.h"
+#include "tls_filter.h"
#include "tls_parser.h"
extern "C" {
@@ -35,7 +37,6 @@ const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt
const std::string TlsAgent::kServerRsaSign = "rsa_sign";
const std::string TlsAgent::kServerRsaPss = "rsa_pss";
const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt";
-const std::string TlsAgent::kServerRsaChain = "rsa_chain";
const std::string TlsAgent::kServerEcdsa256 = "ecdsa256";
const std::string TlsAgent::kServerEcdsa384 = "ecdsa384";
const std::string TlsAgent::kServerEcdsa521 = "ecdsa521";
@@ -66,6 +67,7 @@ TlsAgent::TlsAgent(const std::string& name, Role role,
expected_sent_alert_(kTlsAlertCloseNotify),
expected_sent_alert_level_(kTlsAlertWarning),
handshake_callback_called_(false),
+ resumption_callback_called_(false),
error_code_(0),
send_ctr_(0),
recv_ctr_(0),
@@ -73,8 +75,8 @@ TlsAgent::TlsAgent(const std::string& name, Role role,
handshake_callback_(),
auth_certificate_callback_(),
sni_callback_(),
- expect_short_headers_(false),
- skip_version_checks_(false) {
+ skip_version_checks_(false),
+ resumption_token_() {
memset(&info_, 0, sizeof(info_));
memset(&csinfo_, 0, sizeof(csinfo_));
SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_);
@@ -93,11 +95,11 @@ TlsAgent::~TlsAgent() {
// Add failures manually, if any, so we don't throw in a destructor.
if (expected_received_alert_ != kTlsAlertCloseNotify ||
expected_received_alert_level_ != kTlsAlertWarning) {
- ADD_FAILURE() << "Wrong expected_received_alert status";
+ ADD_FAILURE() << "Wrong expected_received_alert status: " << role_str();
}
if (expected_sent_alert_ != kTlsAlertCloseNotify ||
expected_sent_alert_level_ != kTlsAlertWarning) {
- ADD_FAILURE() << "Wrong expected_sent_alert status";
+ ADD_FAILURE() << "Wrong expected_sent_alert status: " << role_str();
}
}
@@ -183,6 +185,10 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
ScopedCERTCertList anchors(CERT_NewCertList());
rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get());
if (rv != SECSuccess) return false;
+
+ rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
} else {
rv = SSL_SetURL(ssl_fd(), "server");
EXPECT_EQ(SECSuccess, rv);
@@ -208,6 +214,29 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
return true;
}
+bool TlsAgent::MaybeSetResumptionToken() {
+ if (!resumption_token_.empty()) {
+ SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(),
+ resumption_token_.size());
+
+ // rv is SECFailure with error set to SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR
+ // if the resumption token was bad (expired/malformed/etc.).
+ if (expect_resumption_) {
+ // Only in case we expect resumption this has to be successful. We might
+ // not expect resumption due to some reason but the token is totally fine.
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ if (rv != SECSuccess) {
+ EXPECT_EQ(SSL_ERROR_BAD_RESUMPTION_TOKEN_ERROR, PORT_GetError());
+ resumption_token_.clear();
+ EXPECT_FALSE(expect_resumption_);
+ if (expect_resumption_) return false;
+ }
+ }
+
+ return true;
+}
+
void TlsAgent::SetupClientAuth() {
EXPECT_TRUE(EnsureTlsSetup());
ASSERT_EQ(CLIENT, role_);
@@ -258,13 +287,10 @@ void TlsAgent::CheckCipherSuite(uint16_t cipher_suite) {
}
void TlsAgent::RequestClientAuth(bool requireAuth) {
- EXPECT_TRUE(EnsureTlsSetup());
ASSERT_EQ(SERVER, role_);
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(ssl_fd(), SSL_REQUEST_CERTIFICATE, PR_TRUE));
- EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_REQUIRE_CERTIFICATE,
- requireAuth ? PR_TRUE : PR_FALSE));
+ SetOption(SSL_REQUEST_CERTIFICATE, PR_TRUE);
+ SetOption(SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE);
EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
ssl_fd(), &TlsAgent::ClientAuthenticated, this));
@@ -376,42 +402,8 @@ void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
EXPECT_EQ(SECSuccess, rv);
}
-void TlsAgent::SetSessionTicketsEnabled(bool en) {
- EXPECT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS,
- en ? PR_TRUE : PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
-}
-
-void TlsAgent::SetSessionCacheEnabled(bool en) {
- EXPECT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE);
- EXPECT_EQ(SECSuccess, rv);
-}
-
void TlsAgent::Set0RttEnabled(bool en) {
- EXPECT_TRUE(EnsureTlsSetup());
-
- SECStatus rv =
- SSL_OptionSet(ssl_fd(), SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
-}
-
-void TlsAgent::SetFallbackSCSVEnabled(bool en) {
- EXPECT_TRUE(role_ == CLIENT && EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALLBACK_SCSV,
- en ? PR_TRUE : PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
-}
-
-void TlsAgent::SetShortHeadersEnabled() {
- EXPECT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd());
- EXPECT_EQ(SECSuccess, rv);
+ SetOption(SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
}
void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
@@ -424,6 +416,27 @@ void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
}
}
+SECStatus ResumptionTokenCallback(PRFileDesc* fd,
+ const PRUint8* resumptionToken,
+ unsigned int len, void* ctx) {
+ EXPECT_NE(nullptr, resumptionToken);
+ if (!resumptionToken) {
+ return SECFailure;
+ }
+
+ std::vector<uint8_t> new_token(resumptionToken, resumptionToken + len);
+ reinterpret_cast<TlsAgent*>(ctx)->SetResumptionToken(new_token);
+ reinterpret_cast<TlsAgent*>(ctx)->SetResumptionCallbackCalled();
+ return SECSuccess;
+}
+
+void TlsAgent::SetResumptionTokenCallback() {
+ EXPECT_TRUE(EnsureTlsSetup());
+ SECStatus rv =
+ SSL_SetResumptionTokenCallback(ssl_fd(), ResumptionTokenCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) {
*minver = vrange_.min;
*maxver = vrange_.max;
@@ -437,8 +450,6 @@ void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; }
void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; }
-void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; }
-
void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; }
void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
@@ -517,6 +528,12 @@ void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group,
}
}
+void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const {
+ if (kea_group != ssl_grp_ffdhe_custom) {
+ EXPECT_EQ(kea_group, info_.originalKeaGroup);
+ }
+}
+
void TlsAgent::CheckAuthType(SSLAuthType auth_type,
SSLSignatureScheme sig_scheme) const {
EXPECT_EQ(STATE_CONNECTED, state_);
@@ -569,8 +586,7 @@ void TlsAgent::EnableFalseStart() {
falsestart_enabled_ = true;
EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback(
ssl_fd(), CanFalseStartCallback, this));
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALSE_START, PR_TRUE));
+ SetOption(SSL_ENABLE_FALSE_START, PR_TRUE);
}
void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
@@ -578,7 +594,7 @@ void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
EXPECT_TRUE(EnsureTlsSetup());
- EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_ENABLE_ALPN, PR_TRUE));
+ SetOption(SSL_ENABLE_ALPN, PR_TRUE);
EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len));
}
@@ -622,12 +638,8 @@ void TlsAgent::CheckErrorCode(int32_t expected) const {
}
static uint8_t GetExpectedAlertLevel(uint8_t alert) {
- switch (alert) {
- case kTlsAlertCloseNotify:
- case kTlsAlertEndOfEarlyData:
- return kTlsAlertWarning;
- default:
- break;
+ if (alert == kTlsAlertCloseNotify) {
+ return kTlsAlertWarning;
}
return kTlsAlertFatal;
}
@@ -730,6 +742,50 @@ void TlsAgent::ResetPreliminaryInfo() {
expected_cipher_suite_ = 0;
}
+void TlsAgent::ValidateCipherSpecs() {
+ PRInt32 cipherSpecs = SSLInt_CountCipherSpecs(ssl_fd());
+ // We use one ciphersuite in each direction.
+ PRInt32 expected = 2;
+ if (variant_ == ssl_variant_datagram) {
+ // For DTLS 1.3, the client retains the cipher spec for early data and the
+ // handshake so that it can retransmit EndOfEarlyData and its final flight.
+ // It also retains the handshake read cipher spec so that it can read ACKs
+ // from the server. The server retains the handshake read cipher spec so it
+ // can read the client's retransmitted Finished.
+ if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ if (role_ == CLIENT) {
+ expected = info_.earlyDataAccepted ? 5 : 4;
+ } else {
+ expected = 3;
+ }
+ } else {
+ // For DTLS 1.1 and 1.2, the last endpoint to send maintains a cipher spec
+ // until the holddown timer runs down.
+ if (expect_resumption_) {
+ if (role_ == CLIENT) {
+ expected = 3;
+ }
+ } else {
+ if (role_ == SERVER) {
+ expected = 3;
+ }
+ }
+ }
+ }
+ // This function will be run before the handshake completes if false start is
+ // enabled. In that case, the client will still be reading cleartext, but
+ // will have a spec prepared for reading ciphertext. With DTLS, the client
+ // will also have a spec retained for retransmission of handshake messages.
+ if (role_ == CLIENT && falsestart_enabled_ && !handshake_callback_called_) {
+ EXPECT_GT(SSL_LIBRARY_VERSION_TLS_1_3, expected_version_);
+ expected = (variant_ == ssl_variant_datagram) ? 4 : 3;
+ }
+ EXPECT_EQ(expected, cipherSpecs);
+ if (expected != cipherSpecs) {
+ SSLInt_PrintCipherSpecs(role_str().c_str(), ssl_fd());
+ }
+}
+
void TlsAgent::Connected() {
if (state_ == STATE_CONNECTED) {
return;
@@ -743,6 +799,8 @@ void TlsAgent::Connected() {
EXPECT_EQ(SECSuccess, rv);
EXPECT_EQ(sizeof(info_), info_.length);
+ EXPECT_EQ(expect_resumption_, info_.resumed == PR_TRUE);
+
// Preliminary values are exposed through callbacks during the handshake.
// If either expected values were set or the callbacks were called, check
// that the final values are correct.
@@ -753,32 +811,13 @@ void TlsAgent::Connected() {
EXPECT_EQ(SECSuccess, rv);
EXPECT_EQ(sizeof(csinfo_), csinfo_.length);
- if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
- PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd());
- // We use one ciphersuite in each direction, plus one that's kept around
- // by DTLS for retransmission.
- PRInt32 expected =
- ((variant_ == ssl_variant_datagram) && (role_ == CLIENT)) ? 3 : 2;
- EXPECT_EQ(expected, cipherSuites);
- if (expected != cipherSuites) {
- SSLInt_PrintTls13CipherSpecs(ssl_fd());
- }
- }
+ ValidateCipherSpecs();
- PRBool short_headers;
- rv = SSLInt_UsingShortHeaders(ssl_fd(), &short_headers);
- EXPECT_EQ(SECSuccess, rv);
- EXPECT_EQ((PRBool)expect_short_headers_, short_headers);
SetState(STATE_CONNECTED);
}
void TlsAgent::EnableExtendedMasterSecret() {
- ASSERT_TRUE(EnsureTlsSetup());
-
- SECStatus rv =
- SSL_OptionSet(ssl_fd(), SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
-
- ASSERT_EQ(SECSuccess, rv);
+ SetOption(SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
}
void TlsAgent::CheckExtendedMasterSecret(bool expected) {
@@ -801,21 +840,6 @@ void TlsAgent::CheckSecretsDestroyed() {
ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd()));
}
-void TlsAgent::DisableRollbackDetection() {
- ASSERT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ROLLBACK_DETECTION, PR_FALSE);
-
- ASSERT_EQ(SECSuccess, rv);
-}
-
-void TlsAgent::EnableCompression() {
- ASSERT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_DEFLATE, PR_TRUE);
- ASSERT_EQ(SECSuccess, rv);
-}
-
void TlsAgent::SetDowngradeCheckVersion(uint16_t version) {
ASSERT_TRUE(EnsureTlsSetup());
@@ -883,6 +907,14 @@ void TlsAgent::SendDirect(const DataBuffer& buf) {
}
}
+void TlsAgent::SendRecordDirect(const TlsRecord& record) {
+ DataBuffer buf;
+
+ auto rv = record.header.Write(&buf, 0, record.buffer);
+ EXPECT_EQ(record.header.header_length() + record.buffer.len(), rv);
+ SendDirect(buf);
+}
+
static bool ErrorIsNonFatal(PRErrorCode code) {
return code == PR_WOULD_BLOCK_ERROR || code == SSL_ERROR_RX_SHORT_DTLS_READ;
}
@@ -918,6 +950,27 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) {
}
}
+bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
+ uint16_t wireVersion, uint64_t seq,
+ uint8_t ct, const DataBuffer& buf) {
+ LOGV("Writing " << buf.len() << " bytes");
+ // Ensure we are a TLS 1.3 cipher agent.
+ EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
+ TlsRecordHeader header(wireVersion, kTlsApplicationDataType, seq);
+ DataBuffer padded = buf;
+ padded.Write(padded.len(), ct, 1);
+ DataBuffer ciphertext;
+ if (!spec->Protect(header, padded, &ciphertext)) {
+ return false;
+ }
+
+ DataBuffer record;
+ auto rv = header.Write(&record, 0, ciphertext);
+ EXPECT_EQ(header.header_length() + ciphertext.len(), rv);
+ SendDirect(record);
+ return true;
+}
+
void TlsAgent::ReadBytes(size_t amount) {
uint8_t block[16384];
@@ -951,23 +1004,20 @@ void TlsAgent::ReadBytes(size_t amount) {
void TlsAgent::ResetSentBytes() { send_ctr_ = 0; }
-void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
- EXPECT_TRUE(EnsureTlsSetup());
-
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE,
- mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
- EXPECT_EQ(SECSuccess, rv);
+void TlsAgent::SetOption(int32_t option, int value) {
+ ASSERT_TRUE(EnsureTlsSetup());
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), option, value));
+}
- rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS,
- mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
+void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
+ SetOption(SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
+ SetOption(SSL_ENABLE_SESSION_TICKETS,
+ mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
}
void TlsAgent::DisableECDHEServerKeyReuse() {
- ASSERT_TRUE(EnsureTlsSetup());
ASSERT_EQ(TlsAgent::SERVER, role_);
- SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
- EXPECT_EQ(SECSuccess, rv);
+ SetOption(SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
}
static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h
index 4bccb9a84..6cd6d5073 100644
--- a/security/nss/gtests/ssl_gtest/tls_agent.h
+++ b/security/nss/gtests/ssl_gtest/tls_agent.h
@@ -14,7 +14,6 @@
#include <iostream>
#include "test_io.h"
-#include "tls_filter.h"
#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
@@ -37,7 +36,10 @@ enum SessionResumptionMode {
RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
};
+class PacketFilter;
class TlsAgent;
+class TlsCipherSpec;
+struct TlsRecord;
const extern std::vector<SSLNamedGroup> kAllDHEGroups;
const extern std::vector<SSLNamedGroup> kECDHEGroups;
@@ -66,7 +68,6 @@ class TlsAgent : public PollTarget {
static const std::string kServerRsaSign;
static const std::string kServerRsaPss;
static const std::string kServerRsaDecrypt;
- static const std::string kServerRsaChain; // A cert that requires a chain.
static const std::string kServerEcdsa256;
static const std::string kServerEcdsa384;
static const std::string kServerEcdsa521;
@@ -81,20 +82,15 @@ class TlsAgent : public PollTarget {
adapter_->SetPeer(peer->adapter_);
}
- void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) {
- filter->SetAgent(this);
+ void SetFilter(std::shared_ptr<PacketFilter> filter) {
adapter_->SetPacketFilter(filter);
}
-
- void SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
- adapter_->SetPacketFilter(filter);
- }
-
- void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); }
+ void ClearFilter() { adapter_->SetPacketFilter(nullptr); }
void StartConnect(PRFileDesc* model = nullptr);
void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
size_t kea_size = 0) const;
+ void CheckOriginalKEA(SSLNamedGroup kea_group) const;
void CheckAuthType(SSLAuthType auth_type,
SSLSignatureScheme sig_scheme) const;
@@ -121,12 +117,10 @@ class TlsAgent : public PollTarget {
void SetupClientAuth();
void RequestClientAuth(bool requireAuth);
+ void SetOption(int32_t option, int value);
void ConfigureSessionCache(SessionResumptionMode mode);
- void SetSessionTicketsEnabled(bool en);
- void SetSessionCacheEnabled(bool en);
void Set0RttEnabled(bool en);
void SetFallbackSCSVEnabled(bool en);
- void SetShortHeadersEnabled();
void SetVersionRange(uint16_t minver, uint16_t maxver);
void GetVersionRange(uint16_t* minver, uint16_t* maxver);
void CheckPreliminaryInfo();
@@ -136,7 +130,6 @@ class TlsAgent : public PollTarget {
void ExpectReadWriteError();
void EnableFalseStart();
void ExpectResumption();
- void ExpectShortHeaders();
void SkipVersionChecks();
void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
void EnableAlpn(const uint8_t* val, size_t len);
@@ -149,27 +142,49 @@ class TlsAgent : public PollTarget {
// Send data on the socket, encrypting it.
void SendData(size_t bytes, size_t blocksize = 1024);
void SendBuffer(const DataBuffer& buf);
+ bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
+ uint16_t wireVersion, uint64_t seq, uint8_t ct,
+ const DataBuffer& buf);
// Send data directly to the underlying socket, skipping the TLS layer.
void SendDirect(const DataBuffer& buf);
+ void SendRecordDirect(const TlsRecord& record);
void ReadBytes(size_t max = 16384U);
void ResetSentBytes(); // Hack to test drops.
void EnableExtendedMasterSecret();
void CheckExtendedMasterSecret(bool expected);
void CheckEarlyDataAccepted(bool expected);
- void DisableRollbackDetection();
- void EnableCompression();
void SetDowngradeCheckVersion(uint16_t version);
void CheckSecretsDestroyed();
void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
void DisableECDHEServerKeyReuse();
bool GetPeerChainLength(size_t* count);
void CheckCipherSuite(uint16_t cipher_suite);
+ void SetResumptionTokenCallback();
+ bool MaybeSetResumptionToken();
+ void SetResumptionToken(const std::vector<uint8_t>& resumption_token) {
+ resumption_token_ = resumption_token;
+ }
+ const std::vector<uint8_t>& GetResumptionToken() const {
+ return resumption_token_;
+ }
+ void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) {
+ SECStatus rv = SSL_GetResumptionTokenInfo(
+ resumption_token_.data(), resumption_token_.size(), token.get(),
+ sizeof(SSLResumptionTokenInfo));
+ ASSERT_EQ(SECSuccess, rv);
+ }
+ void SetResumptionCallbackCalled() { resumption_callback_called_ = true; }
+ bool resumption_callback_called() const {
+ return resumption_callback_called_;
+ }
const std::string& name() const { return name_; }
Role role() const { return role_; }
std::string role_str() const { return role_ == SERVER ? "server" : "client"; }
+ SSLProtocolVariant variant() const { return variant_; }
+
State state() const { return state_; }
const CERTCertificate* peer_cert() const {
@@ -253,6 +268,7 @@ class TlsAgent : public PollTarget {
const static char* states[];
void SetState(State state);
+ void ValidateCipherSpecs();
// Dummy auth certificate hook.
static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
@@ -378,6 +394,7 @@ class TlsAgent : public PollTarget {
uint8_t expected_sent_alert_;
uint8_t expected_sent_alert_level_;
bool handshake_callback_called_;
+ bool resumption_callback_called_;
SSLChannelInfo info_;
SSLCipherSuiteInfo csinfo_;
SSLVersionRange vrange_;
@@ -388,8 +405,8 @@ class TlsAgent : public PollTarget {
HandshakeCallbackFunction handshake_callback_;
AuthCertificateCallbackFunction auth_certificate_callback_;
SniCallbackFunction sni_callback_;
- bool expect_short_headers_;
bool skip_version_checks_;
+ std::vector<uint8_t> resumption_token_;
};
inline std::ostream& operator<<(std::ostream& stream,
@@ -440,7 +457,7 @@ class TlsAgentTestBase : public ::testing::Test {
void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
int32_t error_code = 0);
- std::unique_ptr<TlsAgent> agent_;
+ std::shared_ptr<TlsAgent> agent_;
TlsAgent::Role role_;
SSLProtocolVariant variant_;
uint16_t version_;
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc
index c8de5a1fe..8567b392f 100644
--- a/security/nss/gtests/ssl_gtest/tls_connect.cc
+++ b/security/nss/gtests/ssl_gtest/tls_connect.cc
@@ -5,6 +5,7 @@
* You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "tls_connect.h"
+#include "sslexp.h"
extern "C" {
#include "libssl_internals.h"
}
@@ -88,6 +89,8 @@ std::string VersionString(uint16_t version) {
switch (version) {
case 0:
return "(no version)";
+ case SSL_LIBRARY_VERSION_3_0:
+ return "1.0";
case SSL_LIBRARY_VERSION_TLS_1_0:
return "1.0";
case SSL_LIBRARY_VERSION_TLS_1_1:
@@ -112,6 +115,7 @@ TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant,
server_model_(nullptr),
version_(version),
expected_resumption_mode_(RESUME_NONE),
+ expected_resumptions_(0),
session_ids_(),
expect_extended_master_secret_(false),
expect_early_data_accepted_(false),
@@ -161,6 +165,22 @@ void TlsConnectTestBase::CheckShares(
EXPECT_EQ(shares.len(), i);
}
+void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch,
+ uint16_t server_epoch) const {
+ uint16_t read_epoch = 0;
+ uint16_t write_epoch = 0;
+
+ EXPECT_EQ(SECSuccess,
+ SSLInt_GetEpochs(client_->ssl_fd(), &read_epoch, &write_epoch));
+ EXPECT_EQ(server_epoch, read_epoch) << "client read epoch";
+ EXPECT_EQ(client_epoch, write_epoch) << "client write epoch";
+
+ EXPECT_EQ(SECSuccess,
+ SSLInt_GetEpochs(server_->ssl_fd(), &read_epoch, &write_epoch));
+ EXPECT_EQ(client_epoch, read_epoch) << "server read epoch";
+ EXPECT_EQ(server_epoch, write_epoch) << "server write epoch";
+}
+
void TlsConnectTestBase::ClearStats() {
// Clear statistics.
SSL3Statistics* stats = SSL_GetStatistics();
@@ -177,7 +197,7 @@ void TlsConnectTestBase::SetUp() {
SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
SSLInt_ClearSelfEncryptKey();
SSLInt_SetTicketLifetime(30);
- SSLInt_SetMaxEarlyDataSize(1024);
+ SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3);
ClearStats();
Init();
}
@@ -209,7 +229,9 @@ void TlsConnectTestBase::Reset() {
void TlsConnectTestBase::Reset(const std::string& server_name,
const std::string& client_name) {
+ auto token = client_->GetResumptionToken();
client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_));
+ client_->SetResumptionToken(token);
server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_));
if (skip_version_checks_) {
client_->SkipVersionChecks();
@@ -219,12 +241,27 @@ void TlsConnectTestBase::Reset(const std::string& server_name,
Init();
}
-void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) {
+void TlsConnectTestBase::MakeNewServer() {
+ auto replacement = std::make_shared<TlsAgent>(
+ server_->name(), TlsAgent::SERVER, server_->variant());
+ server_ = replacement;
+ if (version_) {
+ server_->SetVersionRange(version_, version_);
+ }
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+ server_->StartConnect();
+}
+
+void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected,
+ uint8_t num_resumptions) {
expected_resumption_mode_ = expected;
if (expected != RESUME_NONE) {
client_->ExpectResumption();
server_->ExpectResumption();
+ expected_resumptions_ = num_resumptions;
}
+ EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE);
}
void TlsConnectTestBase::EnsureTlsSetup() {
@@ -254,10 +291,16 @@ void TlsConnectTestBase::EnableExtendedMasterSecret() {
void TlsConnectTestBase::Connect() {
server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr);
client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr);
+ client_->MaybeSetResumptionToken();
Handshake();
CheckConnected();
}
+void TlsConnectTestBase::StartConnect() {
+ server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr);
+ client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr);
+}
+
void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
EnsureTlsSetup();
client_->EnableSingleCipher(cipher_suite);
@@ -274,6 +317,19 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
}
void TlsConnectTestBase::CheckConnected() {
+ // Have the client read handshake twice to make sure we get the
+ // NST and the ACK.
+ if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ variant_ == ssl_variant_datagram) {
+ client_->Handshake();
+ client_->Handshake();
+ auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd());
+ // Verify that we dropped the client's retransmission cipher suites.
+ EXPECT_EQ(2, suites) << "Client has the wrong number of suites";
+ if (suites != 2) {
+ SSLInt_PrintCipherSpecs("client", client_->ssl_fd());
+ }
+ }
EXPECT_EQ(client_->version(), server_->version());
if (!skip_version_checks_) {
// Check the version is as expected
@@ -314,10 +370,12 @@ void TlsConnectTestBase::CheckConnected() {
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
SSLAuthType auth_type,
SSLSignatureScheme sig_scheme) const {
- client_->CheckKEA(kea_type, kea_group);
- server_->CheckKEA(kea_type, kea_group);
- client_->CheckAuthType(auth_type, sig_scheme);
+ if (kea_group != ssl_grp_none) {
+ client_->CheckKEA(kea_type, kea_group);
+ server_->CheckKEA(kea_type, kea_group);
+ }
server_->CheckAuthType(auth_type, sig_scheme);
+ client_->CheckAuthType(auth_type, sig_scheme);
}
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
@@ -346,13 +404,13 @@ void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
break;
case ssl_auth_rsa_sign:
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) {
- scheme = ssl_sig_rsa_pss_sha256;
+ scheme = ssl_sig_rsa_pss_rsae_sha256;
} else {
scheme = ssl_sig_rsa_pkcs1_sha256;
}
break;
case ssl_auth_rsa_pss:
- scheme = ssl_sig_rsa_pss_sha256;
+ scheme = ssl_sig_rsa_pss_rsae_sha256;
break;
case ssl_auth_ecdsa:
scheme = ssl_sig_ecdsa_secp256r1_sha256;
@@ -372,9 +430,19 @@ void TlsConnectTestBase::CheckKeys() const {
CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
}
+void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type,
+ SSLNamedGroup kea_group,
+ SSLNamedGroup original_kea_group,
+ SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) {
+ CheckKeys(kea_type, kea_group, auth_type, sig_scheme);
+ EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE);
+ client_->CheckOriginalKEA(original_kea_group);
+ server_->CheckOriginalKEA(original_kea_group);
+}
+
void TlsConnectTestBase::ConnectExpectFail() {
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
Handshake();
ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state());
ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
@@ -395,8 +463,7 @@ void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender,
}
void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) {
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
client_->SetServerKeyBits(server_->server_key_bits());
client_->Handshake();
server_->Handshake();
@@ -455,29 +522,33 @@ void TlsConnectTestBase::EnableSomeEcdhCiphers() {
}
}
+void TlsConnectTestBase::ConfigureSelfEncrypt() {
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey privKey;
+ ASSERT_TRUE(
+ TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey));
+
+ ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
+ ASSERT_TRUE(pubKey);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
+}
+
void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
SessionResumptionMode server) {
client_->ConfigureSessionCache(client);
server_->ConfigureSessionCache(server);
if ((server & RESUME_TICKET) != 0) {
- ScopedCERTCertificate cert;
- ScopedSECKEYPrivateKey privKey;
- ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert,
- &privKey));
-
- ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
- ASSERT_TRUE(pubKey);
-
- EXPECT_EQ(SECSuccess,
- SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
+ ConfigureSelfEncrypt();
}
}
void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
EXPECT_NE(RESUME_BOTH, expected);
- int resume_count = expected ? 1 : 0;
- int stateless_count = (expected & RESUME_TICKET) ? 1 : 0;
+ int resume_count = expected ? expected_resumptions_ : 0;
+ int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0;
// Note: hch == server counter; hsh == client counter.
SSL3Statistics* stats = SSL_GetStatistics();
@@ -490,7 +561,7 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
if (expected != RESUME_NONE) {
if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
// Check that the last two session ids match.
- ASSERT_EQ(2U, session_ids_.size());
+ ASSERT_EQ(1U + expected_resumptions_, session_ids_.size());
EXPECT_EQ(session_ids_[session_ids_.size() - 1],
session_ids_[session_ids_.size() - 2]);
} else {
@@ -540,31 +611,28 @@ void TlsConnectTestBase::CheckSrtp() const {
server_->CheckSrtp();
}
-void TlsConnectTestBase::SendReceive() {
- client_->SendData(50);
- server_->SendData(50);
- Receive(50);
+void TlsConnectTestBase::SendReceive(size_t total) {
+ ASSERT_GT(total, client_->received_bytes());
+ ASSERT_GT(total, server_->received_bytes());
+ client_->SendData(total - server_->received_bytes());
+ server_->SendData(total - client_->received_bytes());
+ Receive(total); // Receive() is cumulative
}
// Do a first connection so we can do 0-RTT on the second one.
void TlsConnectTestBase::SetupForZeroRtt() {
+ // If we don't do this, then all 0-RTT attempts will be rejected.
+ SSLInt_RolloverAntiReplay();
+
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT.
Connect();
SendReceive(); // Need to read so that we absorb the session ticket.
CheckKeys();
Reset();
- client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
}
// Do a first connection so we can do resumption
@@ -584,10 +652,6 @@ void TlsConnectTestBase::ZeroRttSendReceive(
const char* k0RttData = "ABCDEF";
const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
- if (expect_writable && expect_readable) {
- ExpectAlert(client_, kTlsAlertEndOfEarlyData);
- }
-
client_->Handshake(); // Send ClientHello.
if (post_clienthello_check) {
if (!post_clienthello_check()) return;
@@ -599,7 +663,7 @@ void TlsConnectTestBase::ZeroRttSendReceive(
} else {
EXPECT_EQ(SECFailure, rv);
}
- server_->Handshake(); // Consume ClientHello, EE, Finished.
+ server_->Handshake(); // Consume ClientHello
std::vector<uint8_t> buf(k0RttDataLen);
rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read
@@ -608,7 +672,8 @@ void TlsConnectTestBase::ZeroRttSendReceive(
EXPECT_EQ(k0RttDataLen, rv);
} else {
EXPECT_EQ(SECFailure, rv);
- EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError())
+ << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
}
// Do a second read. this should fail.
@@ -653,6 +718,30 @@ void TlsConnectTestBase::SkipVersionChecks() {
server_->SkipVersionChecks();
}
+// Shift the DTLS timers, to the minimum time necessary to let the next timer
+// run on either client or server. This allows tests to skip waiting without
+// having timers run out of order.
+void TlsConnectTestBase::ShiftDtlsTimers() {
+ PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT;
+ PRIntervalTime time;
+ SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time);
+ if (rv == SECSuccess) {
+ time_shift = time;
+ }
+ rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time);
+ if (rv == SECSuccess &&
+ (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) {
+ time_shift = time;
+ }
+
+ if (time_shift == PR_INTERVAL_NO_TIMEOUT) {
+ return;
+ }
+
+ EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift));
+ EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift));
+}
+
TlsConnectGeneric::TlsConnectGeneric()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
@@ -668,20 +757,29 @@ TlsConnectTls12Plus::TlsConnectTls12Plus()
TlsConnectTls13::TlsConnectTls13()
: TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+TlsConnectGenericResumption::TlsConnectGenericResumption()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())),
+ external_cache_(std::get<2>(GetParam())) {}
+
+TlsConnectTls13ResumptionToken::TlsConnectTls13ResumptionToken()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
void TlsKeyExchangeTest::EnsureKeyShareSetup() {
EnsureTlsSetup();
groups_capture_ =
- std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
shares_capture_ =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
- shares_capture2_ =
- std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn, true);
+ std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
+ shares_capture2_ = std::make_shared<TlsExtensionCapture>(
+ client_, ssl_tls13_key_share_xtn, true);
std::vector<std::shared_ptr<PacketFilter>> captures = {
groups_capture_, shares_capture_, shares_capture2_};
- client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
- capture_hrr_ = std::make_shared<TlsInspectorRecordHandshakeMessage>(
- kTlsHandshakeHelloRetryRequest);
- server_->SetPacketFilter(capture_hrr_);
+ client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
+ capture_hrr_ = MakeTlsFilter<TlsHandshakeRecorder>(
+ server_, kTlsHandshakeHelloRetryRequest);
}
void TlsKeyExchangeTest::ConfigNamedGroups(
@@ -691,11 +789,15 @@ void TlsKeyExchangeTest::ConfigNamedGroups(
}
std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
- const DataBuffer& ext) {
+ const std::shared_ptr<TlsExtensionCapture>& capture) {
+ EXPECT_TRUE(capture->captured());
+ const DataBuffer& ext = capture->extension();
+
uint32_t tmp = 0;
EXPECT_TRUE(ext.Read(0, 2, &tmp));
EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
EXPECT_TRUE(ext.len() % 2 == 0);
+
std::vector<SSLNamedGroup> groups;
for (size_t i = 1; i < ext.len() / 2; i += 1) {
EXPECT_TRUE(ext.Read(2 * i, 2, &tmp));
@@ -705,10 +807,14 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
}
std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
- const DataBuffer& ext) {
+ const std::shared_ptr<TlsExtensionCapture>& capture) {
+ EXPECT_TRUE(capture->captured());
+ const DataBuffer& ext = capture->extension();
+
uint32_t tmp = 0;
EXPECT_TRUE(ext.Read(0, 2, &tmp));
EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
+
std::vector<SSLNamedGroup> shares;
size_t i = 2;
while (i < ext.len()) {
@@ -724,17 +830,15 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
void TlsKeyExchangeTest::CheckKEXDetails(
const std::vector<SSLNamedGroup>& expected_groups,
const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) {
- std::vector<SSLNamedGroup> groups =
- GetGroupDetails(groups_capture_->extension());
+ std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_);
EXPECT_EQ(expected_groups, groups);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
ASSERT_LT(0U, expected_shares.size());
- std::vector<SSLNamedGroup> shares =
- GetShareDetails(shares_capture_->extension());
+ std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_);
EXPECT_EQ(expected_shares, shares);
} else {
- EXPECT_EQ(0U, shares_capture_->extension().len());
+ EXPECT_FALSE(shares_capture_->captured());
}
EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0);
@@ -756,8 +860,6 @@ void TlsKeyExchangeTest::CheckKEXDetails(
EXPECT_NE(expected_share2, it);
}
std::vector<SSLNamedGroup> expected_shares2 = {expected_share2};
- std::vector<SSLNamedGroup> shares =
- GetShareDetails(shares_capture2_->extension());
- EXPECT_EQ(expected_shares2, shares);
+ EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_));
}
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h
index 73e8dc81a..7dffe7f8a 100644
--- a/security/nss/gtests/ssl_gtest/tls_connect.h
+++ b/security/nss/gtests/ssl_gtest/tls_connect.h
@@ -45,8 +45,8 @@ class TlsConnectTestBase : public ::testing::Test {
TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version);
virtual ~TlsConnectTestBase();
- void SetUp();
- void TearDown();
+ virtual void SetUp();
+ virtual void TearDown();
// Initialize client and server.
void Init();
@@ -55,13 +55,17 @@ class TlsConnectTestBase : public ::testing::Test {
// Clear the server session cache.
void ClearServerCache();
// Make sure TLS is configured for a connection.
- void EnsureTlsSetup();
+ virtual void EnsureTlsSetup();
// Reset and keep the same certificate names
void Reset();
// Reset, and update the certificate names on both peers
void Reset(const std::string& server_name,
const std::string& client_name = "client");
+ // Replace the server.
+ void MakeNewServer();
+ // Set up
+ void StartConnect();
// Run the handshake.
void Handshake();
// Connect and check that it works.
@@ -81,20 +85,28 @@ class TlsConnectTestBase : public ::testing::Test {
void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const;
// This version assumes defaults.
void CheckKeys() const;
+ // Check that keys on resumed sessions.
+ void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ SSLNamedGroup original_kea_group,
+ SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme);
void CheckGroups(const DataBuffer& groups,
std::function<void(SSLNamedGroup)> check_group);
void CheckShares(const DataBuffer& shares,
std::function<void(SSLNamedGroup)> check_group);
+ void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const;
void ConfigureVersion(uint16_t version);
void SetExpectedVersion(uint16_t version);
// Expect resumption of a particular type.
- void ExpectResumption(SessionResumptionMode expected);
+ void ExpectResumption(SessionResumptionMode expected,
+ uint8_t num_resumed = 1);
void DisableAllCiphers();
void EnableOnlyStaticRsaCiphers();
void EnableOnlyDheCiphers();
void EnableSomeEcdhCiphers();
void EnableExtendedMasterSecret();
+ void ConfigureSelfEncrypt();
void ConfigureSessionCache(SessionResumptionMode client,
SessionResumptionMode server);
void EnableAlpn();
@@ -103,7 +115,7 @@ class TlsConnectTestBase : public ::testing::Test {
void CheckAlpn(const std::string& val);
void EnableSrtp();
void CheckSrtp() const;
- void SendReceive();
+ void SendReceive(size_t total = 50);
void SetupForZeroRtt();
void SetupForResume();
void ZeroRttSendReceive(
@@ -115,6 +127,9 @@ class TlsConnectTestBase : public ::testing::Test {
void DisableECDHEServerKeyReuse();
void SkipVersionChecks();
+ // Move the DTLS timers for both endpoints to pop the next timer.
+ void ShiftDtlsTimers();
+
protected:
SSLProtocolVariant variant_;
std::shared_ptr<TlsAgent> client_;
@@ -123,6 +138,7 @@ class TlsConnectTestBase : public ::testing::Test {
std::unique_ptr<TlsAgent> server_model_;
uint16_t version_;
SessionResumptionMode expected_resumption_mode_;
+ uint8_t expected_resumptions_;
std::vector<std::vector<uint8_t>> session_ids_;
// A simple value of "a", "b". Note that the preferred value of "a" is placed
@@ -192,6 +208,52 @@ class TlsConnectGeneric : public TlsConnectTestBase,
TlsConnectGeneric();
};
+class TlsConnectGenericResumption
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t, bool>> {
+ private:
+ bool external_cache_;
+
+ public:
+ TlsConnectGenericResumption();
+
+ virtual void EnsureTlsSetup() {
+ TlsConnectTestBase::EnsureTlsSetup();
+ // Enable external resumption token cache.
+ if (external_cache_) {
+ client_->SetResumptionTokenCallback();
+ }
+ }
+
+ bool use_external_cache() const { return external_cache_; }
+};
+
+class TlsConnectTls13ResumptionToken
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ public:
+ TlsConnectTls13ResumptionToken();
+
+ virtual void EnsureTlsSetup() {
+ TlsConnectTestBase::EnsureTlsSetup();
+ client_->SetResumptionTokenCallback();
+ }
+};
+
+class TlsConnectGenericResumptionToken
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsConnectGenericResumptionToken();
+
+ virtual void EnsureTlsSetup() {
+ TlsConnectTestBase::EnsureTlsSetup();
+ client_->SetResumptionTokenCallback();
+ }
+};
+
// A Pre TLS 1.2 generic test.
class TlsConnectPre12 : public TlsConnectTestBase,
public ::testing::WithParamInterface<
@@ -244,6 +306,11 @@ class TlsConnectDatagram13 : public TlsConnectTestBase {
: TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {}
};
+class TlsConnectDatagramPre13 : public TlsConnectDatagram {
+ public:
+ TlsConnectDatagramPre13() {}
+};
+
// A variant that is used only with Pre13.
class TlsConnectGenericPre13 : public TlsConnectGeneric {};
@@ -252,12 +319,14 @@ class TlsKeyExchangeTest : public TlsConnectGeneric {
std::shared_ptr<TlsExtensionCapture> groups_capture_;
std::shared_ptr<TlsExtensionCapture> shares_capture_;
std::shared_ptr<TlsExtensionCapture> shares_capture2_;
- std::shared_ptr<TlsInspectorRecordHandshakeMessage> capture_hrr_;
+ std::shared_ptr<TlsHandshakeRecorder> capture_hrr_;
void EnsureKeyShareSetup();
void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
- std::vector<SSLNamedGroup> GetGroupDetails(const DataBuffer& ext);
- std::vector<SSLNamedGroup> GetShareDetails(const DataBuffer& ext);
+ std::vector<SSLNamedGroup> GetGroupDetails(
+ const std::shared_ptr<TlsExtensionCapture>& capture);
+ std::vector<SSLNamedGroup> GetShareDetails(
+ const std::shared_ptr<TlsExtensionCapture>& capture);
void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
const std::vector<SSLNamedGroup>& expectedShares);
void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc
index 76d9aaaff..d34b13bcb 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.cc
+++ b/security/nss/gtests/ssl_gtest/tls_filter.cc
@@ -12,6 +12,7 @@ extern "C" {
#include "libssl_internals.h"
}
+#include <cassert>
#include <iostream>
#include "gtest_utils.h"
#include "tls_agent.h"
@@ -57,17 +58,22 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending,
PRBool isServer = self->agent()->role() == TlsAgent::SERVER;
if (g_ssl_gtest_verbose) {
- std::cerr << "Cipher spec changed. Role="
- << (isServer ? "server" : "client")
- << " direction=" << (sending ? "send" : "receive") << std::endl;
+ std::cerr << (isServer ? "server" : "client") << ": "
+ << (sending ? "send" : "receive")
+ << " cipher spec changed: " << newSpec->epoch << " ("
+ << newSpec->phase << ")" << std::endl;
+ }
+ if (!sending) {
+ return;
}
- if (!sending) return;
+ self->in_sequence_number_ = 0;
+ self->out_sequence_number_ = 0;
+ self->dropped_record_ = false;
self->cipher_spec_.reset(new TlsCipherSpec());
- bool ret =
- self->cipher_spec_->Init(SSLInt_CipherSpecToAlgorithm(isServer, newSpec),
- SSLInt_CipherSpecToKey(isServer, newSpec),
- SSLInt_CipherSpecToIv(isServer, newSpec));
+ bool ret = self->cipher_spec_->Init(
+ SSLInt_CipherSpecToEpoch(newSpec), SSLInt_CipherSpecToAlgorithm(newSpec),
+ SSLInt_CipherSpecToKey(newSpec), SSLInt_CipherSpecToIv(newSpec));
EXPECT_EQ(true, ret);
}
@@ -83,11 +89,23 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
TlsRecordHeader header;
DataBuffer record;
- if (!header.Parse(&parser, &record)) {
+ if (!header.Parse(in_sequence_number_, &parser, &record)) {
ADD_FAILURE() << "not a valid record";
return KEEP;
}
+ // Track the sequence number, which is necessary for stream mode (the
+ // sequence number is in the header for datagram).
+ //
+ // This isn't perfectly robust. If there is a change from an active cipher
+ // spec to another active cipher spec (KeyUpdate for instance) AND writes
+ // are consolidated across that change AND packets were dropped from the
+ // older epoch, we will not correctly re-encrypt records in the old epoch to
+ // update their sequence numbers.
+ if (cipher_spec_ && header.content_type() == kTlsApplicationDataType) {
+ ++in_sequence_number_;
+ }
+
if (FilterRecord(header, record, &offset, output) != KEEP) {
changed = true;
} else {
@@ -120,30 +138,49 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
header.sequence_number()};
PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
+ // In stream mode, even if something doesn't change we need to re-encrypt if
+ // previous packets were dropped.
if (action == KEEP) {
- return KEEP;
+ if (header.is_dtls() || !dropped_record_) {
+ return KEEP;
+ }
+ filtered = plaintext;
}
if (action == DROP) {
- std::cerr << "record drop: " << record << std::endl;
+ std::cerr << "record drop: " << header << ":" << record << std::endl;
+ dropped_record_ = true;
return DROP;
}
EXPECT_GT(0x10000U, filtered.len());
- std::cerr << "record old: " << plaintext << std::endl;
- std::cerr << "record new: " << filtered << std::endl;
+ if (action != KEEP) {
+ std::cerr << "record old: " << plaintext << std::endl;
+ std::cerr << "record new: " << filtered << std::endl;
+ }
+
+ uint64_t seq_num;
+ if (header.is_dtls() || !cipher_spec_ ||
+ header.content_type() != kTlsApplicationDataType) {
+ seq_num = header.sequence_number();
+ } else {
+ seq_num = out_sequence_number_++;
+ }
+ TlsRecordHeader out_header = {header.version(), header.content_type(),
+ seq_num};
DataBuffer ciphertext;
- bool rv = Protect(header, inner_content_type, filtered, &ciphertext);
+ bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext);
EXPECT_TRUE(rv);
if (!rv) {
return KEEP;
}
- *offset = header.Write(output, *offset, ciphertext);
+ *offset = out_header.Write(output, *offset, ciphertext);
return CHANGE;
}
-bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
+bool TlsRecordHeader::Parse(uint64_t sequence_number, TlsParser* parser,
+ DataBuffer* body) {
if (!parser->Read(&content_type_)) {
return false;
}
@@ -154,7 +191,7 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
}
version_ = version;
- sequence_number_ = 0;
+ // If this is DTLS, overwrite the sequence number.
if (IsDtls(version)) {
uint32_t tmp;
if (!parser->Read(&tmp, 4)) {
@@ -165,6 +202,8 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
return false;
}
sequence_number_ |= static_cast<uint64_t>(tmp);
+ } else {
+ sequence_number_ = sequence_number;
}
return parser->ReadVariable(body, 2);
}
@@ -193,7 +232,9 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
return true;
}
- if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) return false;
+ if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) {
+ return false;
+ }
size_t len = plaintext->len();
while (len > 0 && !plaintext->data()[len - 1]) {
@@ -206,6 +247,11 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
*inner_content_type = plaintext->data()[len - 1];
plaintext->Truncate(len - 1);
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "unprotect: " << std::hex << header.sequence_number()
+ << std::dec << " type=" << static_cast<int>(*inner_content_type)
+ << " " << *plaintext << std::endl;
+ }
return true;
}
@@ -218,16 +264,44 @@ bool TlsRecordFilter::Protect(const TlsRecordHeader& header,
*ciphertext = plaintext;
return true;
}
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "protect: " << header.sequence_number() << std::endl;
+ }
DataBuffer padded = plaintext;
padded.Write(padded.len(), inner_content_type, 1);
return cipher_spec_->Protect(header, padded, ciphertext);
}
+bool IsHelloRetry(const DataBuffer& body) {
+ static const uint8_t ssl_hello_retry_random[] = {
+ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C,
+ 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB,
+ 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C};
+ return memcmp(body.data() + 2, ssl_hello_retry_random,
+ sizeof(ssl_hello_retry_random)) == 0;
+}
+
+bool TlsHandshakeFilter::IsFilteredType(const HandshakeHeader& header,
+ const DataBuffer& body) {
+ if (handshake_types_.empty()) {
+ return true;
+ }
+
+ uint8_t type = header.handshake_type();
+ if (type == kTlsHandshakeServerHello) {
+ if (IsHelloRetry(body)) {
+ type = kTlsHandshakeHelloRetryRequest;
+ }
+ }
+ return handshake_types_.count(type) > 0U;
+}
+
PacketFilter::Action TlsHandshakeFilter::FilterRecord(
const TlsRecordHeader& record_header, const DataBuffer& input,
DataBuffer* output) {
// Check that the first byte is as requested.
- if (record_header.content_type() != kTlsHandshakeType) {
+ if ((record_header.content_type() != kTlsHandshakeType) &&
+ (record_header.content_type() != kTlsAltHandshakeType)) {
return KEEP;
}
@@ -239,12 +313,29 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord(
while (parser.remaining()) {
HandshakeHeader header;
DataBuffer handshake;
- if (!header.Parse(&parser, record_header, &handshake)) {
+ bool complete = false;
+ if (!header.Parse(&parser, record_header, preceding_fragment_, &handshake,
+ &complete)) {
return KEEP;
}
+ if (!complete) {
+ EXPECT_TRUE(record_header.is_dtls());
+ // Save the fragment and drop it from this record. Fragments are
+ // coalesced with the last fragment of the handshake message.
+ changed = true;
+ preceding_fragment_.Assign(handshake);
+ continue;
+ }
+ preceding_fragment_.Truncate(0);
+
DataBuffer filtered;
- PacketFilter::Action action = FilterHandshake(header, handshake, &filtered);
+ PacketFilter::Action action;
+ if (!IsFilteredType(header, handshake)) {
+ action = KEEP;
+ } else {
+ action = FilterHandshake(header, handshake, &filtered);
+ }
if (action == DROP) {
changed = true;
std::cerr << "handshake drop: " << handshake << std::endl;
@@ -258,6 +349,8 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord(
std::cerr << "handshake old: " << handshake << std::endl;
std::cerr << "handshake new: " << filtered << std::endl;
source = &filtered;
+ } else if (preceding_fragment_.len()) {
+ changed = true;
}
offset = header.Write(output, offset, *source);
@@ -267,12 +360,16 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord(
}
bool TlsHandshakeFilter::HandshakeHeader::ReadLength(
- TlsParser* parser, const TlsRecordHeader& header, uint32_t* length) {
- if (!parser->Read(length, 3)) {
+ TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset,
+ uint32_t* length, bool* last_fragment) {
+ uint32_t message_length;
+ if (!parser->Read(&message_length, 3)) {
return false; // malformed
}
if (!header.is_dtls()) {
+ *last_fragment = true;
+ *length = message_length;
return true; // nothing left to do
}
@@ -283,32 +380,50 @@ bool TlsHandshakeFilter::HandshakeHeader::ReadLength(
}
message_seq_ = message_seq_tmp;
- uint32_t fragment_offset;
- if (!parser->Read(&fragment_offset, 3)) {
+ uint32_t offset = 0;
+ if (!parser->Read(&offset, 3)) {
+ return false;
+ }
+ // We only parse if the fragments are all complete and in order.
+ if (offset != expected_offset) {
+ EXPECT_NE(0U, header.epoch())
+ << "Received out of order handshake fragment for epoch 0";
return false;
}
- uint32_t fragment_length;
- if (!parser->Read(&fragment_length, 3)) {
+ // For DTLS, we return the length of just this fragment.
+ if (!parser->Read(length, 3)) {
return false;
}
- // All current tests where we are using this code don't fragment.
- return (fragment_offset == 0 && fragment_length == *length);
+ // It's a fragment if the entire message is longer than what we have.
+ *last_fragment = message_length == (*length + offset);
+ return true;
}
bool TlsHandshakeFilter::HandshakeHeader::Parse(
- TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body) {
+ TlsParser* parser, const TlsRecordHeader& record_header,
+ const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) {
+ *complete = false;
+
version_ = record_header.version();
if (!parser->Read(&handshake_type_)) {
return false; // malformed
}
+
uint32_t length;
- if (!ReadLength(parser, record_header, &length)) {
+ if (!ReadLength(parser, record_header, preceding_fragment.len(), &length,
+ complete)) {
return false;
}
- return parser->Read(body, length);
+ if (!parser->Read(body, length)) {
+ return false;
+ }
+ if (preceding_fragment.len()) {
+ body->Splice(preceding_fragment, 0);
+ }
+ return true;
}
size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment(
@@ -337,7 +452,7 @@ size_t TlsHandshakeFilter::HandshakeHeader::Write(
return offset;
}
-PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake(
+PacketFilter::Action TlsHandshakeRecorder::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
// Only do this once.
@@ -345,20 +460,23 @@ PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake(
return KEEP;
}
- if (header.handshake_type() == handshake_type_) {
- buffer_ = input;
- }
+ buffer_ = input;
return KEEP;
}
PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() == handshake_type_) {
- *output = buffer_;
- return CHANGE;
- }
+ *output = buffer_;
+ return CHANGE;
+}
+PacketFilter::Action TlsRecordRecorder::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (!filter_ || (header.content_type() == ct_)) {
+ records_.push_back({header, input});
+ }
return KEEP;
}
@@ -369,15 +487,30 @@ PacketFilter::Action TlsConversationRecorder::FilterRecord(
return KEEP;
}
+PacketFilter::Action TlsHeaderRecorder::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ headers_.push_back(header);
+ return KEEP;
+}
+
+const TlsRecordHeader* TlsHeaderRecorder::header(size_t index) {
+ if (index > headers_.size() + 1) {
+ return nullptr;
+ }
+ return &headers_[index];
+}
+
PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
DataBuffer in(input);
bool changed = false;
for (auto it = filters_.begin(); it != filters_.end(); ++it) {
- PacketFilter::Action action = (*it)->Filter(in, output);
+ PacketFilter::Action action = (*it)->Process(in, output);
if (action == DROP) {
return DROP;
}
+
if (action == CHANGE) {
in = *output;
changed = true;
@@ -430,15 +563,6 @@ bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
return true;
}
-static bool FindHelloRetryExtensions(TlsParser* parser,
- const TlsVersioned& header) {
- // TODO for -19 add cipher suite
- if (!parser->Skip(2)) { // version
- return false;
- }
- return true;
-}
-
bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) {
return true;
}
@@ -448,13 +572,6 @@ static bool FindCertReqExtensions(TlsParser* parser,
if (!parser->SkipVariable(1)) { // request context
return false;
}
- // TODO remove the next two for -19
- if (!parser->SkipVariable(2)) { // signature_algorithms
- return false;
- }
- if (!parser->SkipVariable(2)) { // certificate_authorities
- return false;
- }
return true;
}
@@ -478,6 +595,9 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser,
if (!parser->Skip(8)) { // lifetime, age add
return false;
}
+ if (!parser->SkipVariable(1)) { // ticket_nonce
+ return false;
+ }
if (!parser->SkipVariable(2)) { // ticket
return false;
}
@@ -487,7 +607,6 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser,
static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = {
{kTlsHandshakeClientHello, FindClientHelloExtensions},
{kTlsHandshakeServerHello, FindServerHelloExtensions},
- {kTlsHandshakeHelloRetryRequest, FindHelloRetryExtensions},
{kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions},
{kTlsHandshakeCertificateRequest, FindCertReqExtensions},
{kTlsHandshakeCertificate, FindCertificateExtensions},
@@ -505,10 +624,6 @@ bool TlsExtensionFilter::FindExtensions(TlsParser* parser,
PacketFilter::Action TlsExtensionFilter::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (handshake_types_.count(header.handshake_type()) == 0) {
- return KEEP;
- }
-
TlsParser parser(input);
if (!FindExtensions(&parser, header)) {
return KEEP;
@@ -610,13 +725,45 @@ PacketFilter::Action TlsExtensionDropper::FilterExtension(
return KEEP;
}
+PacketFilter::Action TlsExtensionInjector::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
+ return KEEP;
+ }
+ size_t offset = parser.consumed();
+
+ *output = input;
+
+ // Increase the size of the extensions.
+ uint16_t ext_len;
+ memcpy(&ext_len, output->data() + offset, sizeof(ext_len));
+ ext_len = htons(ntohs(ext_len) + data_.len() + 4);
+ memcpy(output->data() + offset, &ext_len, sizeof(ext_len));
+
+ // Insert the extension type and length.
+ DataBuffer type_length;
+ type_length.Allocate(4);
+ type_length.Write(0, extension_, 2);
+ type_length.Write(2, data_.len(), 2);
+ output->Splice(type_length, offset + 2);
+
+ // Insert the payload.
+ if (data_.len() > 0) {
+ output->Splice(data_, offset + 6);
+ }
+
+ return CHANGE;
+}
+
PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) {
if (counter_++ == record_) {
DataBuffer buf;
header.Write(&buf, 0, body);
- src_.lock()->SendDirect(buf);
+ agent()->SendDirect(buf);
dest_.lock()->Handshake();
func_();
return DROP;
@@ -625,13 +772,11 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
return KEEP;
}
-PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake(
+PacketFilter::Action TlsClientHelloVersionChanger::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() == kTlsHandshakeClientKeyExchange) {
- EXPECT_EQ(SECSuccess,
- SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd()));
- }
+ EXPECT_EQ(SECSuccess,
+ SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd()));
return KEEP;
}
@@ -643,15 +788,49 @@ PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input,
return ((1 << counter_++) & pattern_) ? DROP : KEEP;
}
-PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake(
+PacketFilter::Action SelectiveRecordDropFilter::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& data,
+ DataBuffer* changed) {
+ if (counter_ >= 32) {
+ return KEEP;
+ }
+ return ((1 << counter_++) & pattern_) ? DROP : KEEP;
+}
+
+/* static */ uint32_t SelectiveRecordDropFilter::ToPattern(
+ std::initializer_list<size_t> records) {
+ uint32_t pattern = 0;
+ for (auto it = records.begin(); it != records.end(); ++it) {
+ EXPECT_GT(32U, *it);
+ assert(*it < 32U);
+ pattern |= 1 << *it;
+ }
+ return pattern;
+}
+
+PacketFilter::Action TlsClientHelloVersionSetter::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() == kTlsHandshakeClientHello) {
- *output = input;
- output->Write(0, version_, 2);
- return CHANGE;
- }
- return KEEP;
+ *output = input;
+ output->Write(0, version_, 2);
+ return CHANGE;
+}
+
+PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ uint32_t temp = 0;
+ EXPECT_TRUE(input.Read(0, 2, &temp));
+ // Cipher suite is after version(2) and random(32).
+ size_t pos = 34;
+ if (temp < SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In old versions, we have to skip a session_id too.
+ EXPECT_TRUE(input.Read(pos, 1, &temp));
+ pos += 1 + temp;
+ }
+ output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
+ return CHANGE;
}
} // namespace nss_test
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h
index e4030e23f..1bbe190ab 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.h
+++ b/security/nss/gtests/ssl_gtest/tls_filter.h
@@ -13,6 +13,7 @@
#include <vector>
#include "test_io.h"
+#include "tls_agent.h"
#include "tls_parser.h"
#include "tls_protect.h"
@@ -23,7 +24,6 @@ extern "C" {
namespace nss_test {
class TlsCipherSpec;
-class TlsAgent;
class TlsVersioned {
public:
@@ -50,10 +50,13 @@ class TlsRecordHeader : public TlsVersioned {
uint8_t content_type() const { return content_type_; }
uint64_t sequence_number() const { return sequence_number_; }
- size_t header_length() const { return is_dtls() ? 11 : 3; }
+ uint16_t epoch() const {
+ return static_cast<uint16_t>(sequence_number_ >> 48);
+ }
+ size_t header_length() const { return is_dtls() ? 13 : 5; }
// Parse the header; return true if successful; body in an outparam if OK.
- bool Parse(TlsParser* parser, DataBuffer* body);
+ bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body);
// Write the header and body to a buffer at the given offset.
// Return the offset of the end of the write.
size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
@@ -63,13 +66,32 @@ class TlsRecordHeader : public TlsVersioned {
uint64_t sequence_number_;
};
+struct TlsRecord {
+ const TlsRecordHeader header;
+ const DataBuffer buffer;
+};
+
+// Make a filter and install it on a TlsAgent.
+template <class T, typename... Args>
+inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent,
+ Args&&... args) {
+ auto filter = std::make_shared<T>(agent, std::forward<Args>(args)...);
+ agent->SetFilter(filter);
+ return filter;
+}
+
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
- TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {}
+ TlsRecordFilter(const std::shared_ptr<TlsAgent>& agent)
+ : agent_(agent),
+ count_(0),
+ cipher_spec_(),
+ dropped_record_(false),
+ in_sequence_number_(0),
+ out_sequence_number_(0) {}
- void SetAgent(const TlsAgent* agent) { agent_ = agent; }
- const TlsAgent* agent() const { return agent_; }
+ std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); }
// External interface. Overrides PacketFilter.
PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output);
@@ -112,17 +134,24 @@ class TlsRecordFilter : public PacketFilter {
static void CipherSpecChanged(void* arg, PRBool sending,
ssl3CipherSpec* newSpec);
- const TlsAgent* agent_;
+ std::weak_ptr<TlsAgent> agent_;
size_t count_;
std::unique_ptr<TlsCipherSpec> cipher_spec_;
+ // Whether we dropped a record since the cipher spec changed.
+ bool dropped_record_;
+ // The sequence number we use for reading records as they are written.
+ uint64_t in_sequence_number_;
+ // The sequence number we use for writing modified records.
+ uint64_t out_sequence_number_;
};
-inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) {
+inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) {
v.WriteStream(stream);
return stream;
}
-inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
+inline std::ostream& operator<<(std::ostream& stream,
+ const TlsRecordHeader& hdr) {
hdr.WriteStream(stream);
stream << ' ';
switch (hdr.content_type()) {
@@ -133,13 +162,17 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
stream << "Alert";
break;
case kTlsHandshakeType:
+ case kTlsAltHandshakeType:
stream << "Handshake";
break;
case kTlsApplicationDataType:
stream << "Data";
break;
+ case kTlsAckType:
+ stream << "ACK";
+ break;
default:
- stream << '<' << hdr.content_type() << '>';
+ stream << '<' << static_cast<int>(hdr.content_type()) << '>';
break;
}
return stream << ' ' << std::hex << hdr.sequence_number() << std::dec;
@@ -150,7 +183,20 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
// records and that they don't span records or anything crazy like that.
class TlsHandshakeFilter : public TlsRecordFilter {
public:
- TlsHandshakeFilter() {}
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent), handshake_types_(), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& types)
+ : TlsRecordFilter(agent),
+ handshake_types_(types),
+ preceding_fragment_() {}
+
+ // This filter can be set to be selective based on handshake message type. If
+ // this function isn't used (or the set is empty), then all handshake messages
+ // will be filtered.
+ void SetHandshakeTypes(const std::set<uint8_t>& types) {
+ handshake_types_ = types;
+ }
class HandshakeHeader : public TlsVersioned {
public:
@@ -158,7 +204,8 @@ class TlsHandshakeFilter : public TlsRecordFilter {
uint8_t handshake_type() const { return handshake_type_; }
bool Parse(TlsParser* parser, const TlsRecordHeader& record_header,
- DataBuffer* body);
+ const DataBuffer& preceding_fragment, DataBuffer* body,
+ bool* complete);
size_t Write(DataBuffer* buffer, size_t offset,
const DataBuffer& body) const;
size_t WriteFragment(DataBuffer* buffer, size_t offset,
@@ -169,7 +216,8 @@ class TlsHandshakeFilter : public TlsRecordFilter {
// Reads the length from the record header.
// This also reads the DTLS fragment information and checks it.
bool ReadLength(TlsParser* parser, const TlsRecordHeader& header,
- uint32_t* length);
+ uint32_t expected_offset, uint32_t* length,
+ bool* last_fragment);
uint8_t handshake_type_;
uint16_t message_seq_;
@@ -185,60 +233,115 @@ class TlsHandshakeFilter : public TlsRecordFilter {
DataBuffer* output) = 0;
private:
+ bool IsFilteredType(const HandshakeHeader& header,
+ const DataBuffer& handshake);
+
+ std::set<uint8_t> handshake_types_;
+ DataBuffer preceding_fragment_;
};
// Make a copy of the first instance of a handshake message.
-class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
+class TlsHandshakeRecorder : public TlsHandshakeFilter {
public:
- TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
- : handshake_type_(handshake_type), buffer_() {}
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type)
+ : TlsHandshakeFilter(agent, {handshake_type}), buffer_() {}
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& handshake_types)
+ : TlsHandshakeFilter(agent, handshake_types), buffer_() {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output);
+ void Reset() { buffer_.Truncate(0); }
+
const DataBuffer& buffer() const { return buffer_; }
private:
- uint8_t handshake_type_;
DataBuffer buffer_;
};
// Replace all instances of a handshake message.
class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
public:
- TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type,
+ TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type,
const DataBuffer& replacement)
- : handshake_type_(handshake_type), buffer_(replacement) {}
+ : TlsHandshakeFilter(agent, {handshake_type}), buffer_(replacement) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output);
private:
- uint8_t handshake_type_;
DataBuffer buffer_;
};
+// Make a copy of each record of a given type.
+class TlsRecordRecorder : public TlsRecordFilter {
+ public:
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent, uint8_t ct)
+ : TlsRecordFilter(agent), filter_(true), ct_(ct), records_() {}
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent),
+ filter_(false),
+ ct_(content_handshake), // dummy (<optional> is C++14)
+ records_() {}
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ size_t count() const { return records_.size(); }
+ void Clear() { records_.clear(); }
+
+ const TlsRecord& record(size_t i) const { return records_[i]; }
+
+ private:
+ bool filter_;
+ uint8_t ct_;
+ std::vector<TlsRecord> records_;
+};
+
// Make a copy of the complete conversation.
class TlsConversationRecorder : public TlsRecordFilter {
public:
- TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {}
+ TlsConversationRecorder(const std::shared_ptr<TlsAgent>& agent,
+ DataBuffer& buffer)
+ : TlsRecordFilter(agent), buffer_(buffer) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output);
private:
- DataBuffer& buffer_;
+ DataBuffer buffer_;
};
+// Make a copy of the records
+class TlsHeaderRecorder : public TlsRecordFilter {
+ public:
+ TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent) {}
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+ const TlsRecordHeader* header(size_t index);
+
+ private:
+ std::vector<TlsRecordHeader> headers_;
+};
+
+typedef std::initializer_list<std::shared_ptr<PacketFilter>>
+ ChainedPacketFilterInit;
+
// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
public:
ChainedPacketFilter() {}
ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters)
: filters_(filters.begin(), filters.end()) {}
+ ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {}
virtual ~ChainedPacketFilter() {}
virtual PacketFilter::Action Filter(const DataBuffer& input,
@@ -256,13 +359,15 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
class TlsExtensionFilter : public TlsHandshakeFilter {
public:
- TlsExtensionFilter() : handshake_types_() {
- handshake_types_.insert(kTlsHandshakeClientHello);
- handshake_types_.insert(kTlsHandshakeServerHello);
- }
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent,
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello,
+ kTlsHandshakeHelloRetryRequest,
+ kTlsHandshakeEncryptedExtensions}) {}
- TlsExtensionFilter(const std::set<uint8_t>& types)
- : handshake_types_(types) {}
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& types)
+ : TlsHandshakeFilter(agent, types) {}
static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
@@ -279,14 +384,17 @@ class TlsExtensionFilter : public TlsHandshakeFilter {
PacketFilter::Action FilterExtensions(TlsParser* parser,
const DataBuffer& input,
DataBuffer* output);
-
- std::set<uint8_t> handshake_types_;
};
class TlsExtensionCapture : public TlsExtensionFilter {
public:
- TlsExtensionCapture(uint16_t ext, bool last = false)
- : extension_(ext), captured_(false), last_(last), data_() {}
+ TlsExtensionCapture(const std::shared_ptr<TlsAgent>& agent, uint16_t ext,
+ bool last = false)
+ : TlsExtensionFilter(agent),
+ extension_(ext),
+ captured_(false),
+ last_(last),
+ data_() {}
const DataBuffer& extension() const { return data_; }
bool captured() const { return captured_; }
@@ -305,8 +413,9 @@ class TlsExtensionCapture : public TlsExtensionFilter {
class TlsExtensionReplacer : public TlsExtensionFilter {
public:
- TlsExtensionReplacer(uint16_t extension, const DataBuffer& data)
- : extension_(extension), data_(data) {}
+ TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, const DataBuffer& data)
+ : TlsExtensionFilter(agent), extension_(extension), data_(data) {}
PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) override;
@@ -318,7 +427,9 @@ class TlsExtensionReplacer : public TlsExtensionFilter {
class TlsExtensionDropper : public TlsExtensionFilter {
public:
- TlsExtensionDropper(uint16_t extension) : extension_(extension) {}
+ TlsExtensionDropper(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension)
+ : TlsExtensionFilter(agent), extension_(extension) {}
PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer&, DataBuffer*) override;
@@ -326,21 +437,41 @@ class TlsExtensionDropper : public TlsExtensionFilter {
uint16_t extension_;
};
+class TlsExtensionInjector : public TlsHandshakeFilter {
+ public:
+ TlsExtensionInjector(const std::shared_ptr<TlsAgent>& agent, uint16_t ext,
+ const DataBuffer& data)
+ : TlsHandshakeFilter(agent), extension_(ext), data_(data) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
class TlsAgent;
typedef std::function<void(void)> VoidFunction;
class AfterRecordN : public TlsRecordFilter {
public:
- AfterRecordN(std::shared_ptr<TlsAgent>& src, std::shared_ptr<TlsAgent>& dest,
- unsigned int record, VoidFunction func)
- : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {}
+ AfterRecordN(const std::shared_ptr<TlsAgent>& src,
+ const std::shared_ptr<TlsAgent>& dest, unsigned int record,
+ VoidFunction func)
+ : TlsRecordFilter(src),
+ dest_(dest),
+ record_(record),
+ func_(func),
+ counter_(0) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) override;
private:
- std::weak_ptr<TlsAgent> src_;
std::weak_ptr<TlsAgent> dest_;
unsigned int record_;
VoidFunction func_;
@@ -349,10 +480,12 @@ class AfterRecordN : public TlsRecordFilter {
// When we see the ClientKeyExchange from |client|, increment the
// ClientHelloVersion on |server|.
-class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
+class TlsClientHelloVersionChanger : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server)
- : server_(server) {}
+ TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client,
+ const std::shared_ptr<TlsAgent>& server)
+ : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}),
+ server_(server) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -377,10 +510,51 @@ class SelectiveDropFilter : public PacketFilter {
uint8_t counter_;
};
+// This class selectively drops complete records. The difference from
+// SelectiveDropFilter is that if multiple DTLS records are in the same
+// datagram, we just drop one.
+class SelectiveRecordDropFilter : public TlsRecordFilter {
+ public:
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent,
+ uint32_t pattern, bool enabled = true)
+ : TlsRecordFilter(agent), pattern_(pattern), counter_(0) {
+ if (!enabled) {
+ Disable();
+ }
+ }
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent,
+ std::initializer_list<size_t> records)
+ : SelectiveRecordDropFilter(agent, ToPattern(records), true) {}
+
+ void Reset(uint32_t pattern) {
+ counter_ = 0;
+ PacketFilter::Enable();
+ pattern_ = pattern;
+ }
+
+ void Reset(std::initializer_list<size_t> records) {
+ Reset(ToPattern(records));
+ }
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override;
+
+ private:
+ static uint32_t ToPattern(std::initializer_list<size_t> records);
+
+ uint32_t pattern_;
+ uint8_t counter_;
+};
+
// Set the version number in the ClientHello.
-class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
+class TlsClientHelloVersionSetter : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {}
+ TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t version)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeClientHello}),
+ version_(version) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -393,7 +567,8 @@ class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
// Damages the last byte of a handshake message.
class TlsLastByteDamager : public TlsHandshakeFilter {
public:
- TlsLastByteDamager(uint8_t type) : type_(type) {}
+ TlsLastByteDamager(const std::shared_ptr<TlsAgent>& agent, uint8_t type)
+ : TlsHandshakeFilter(agent), type_(type) {}
PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) override {
@@ -411,6 +586,22 @@ class TlsLastByteDamager : public TlsHandshakeFilter {
uint8_t type_;
};
+class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
+ public:
+ SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t suite)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}),
+ cipher_suite_(suite) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ uint16_t cipher_suite_;
+};
+
} // namespace nss_test
#endif
diff --git a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc
index 51ff938b1..45f6cf2bd 100644
--- a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc
+++ b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc
@@ -241,13 +241,13 @@ TEST_P(TlsHkdfTest, HkdfExpandLabel) {
{/* ssl_hash_md5 */},
{/* ssl_hash_sha1 */},
{/* ssl_hash_sha224 */},
- {0x34, 0x7c, 0x67, 0x80, 0xff, 0x0b, 0xba, 0xd7, 0x1c, 0x28, 0x3b,
- 0x16, 0xeb, 0x2f, 0x9c, 0xf6, 0x2d, 0x24, 0xe6, 0xcd, 0xb6, 0x13,
- 0xd5, 0x17, 0x76, 0x54, 0x8c, 0xb0, 0x7d, 0xcd, 0xe7, 0x4c},
- {0x4b, 0x1e, 0x5e, 0xc1, 0x49, 0x30, 0x78, 0xea, 0x35, 0xbd, 0x3f, 0x01,
- 0x04, 0xe6, 0x1a, 0xea, 0x14, 0xcc, 0x18, 0x2a, 0xd1, 0xc4, 0x76, 0x21,
- 0xc4, 0x64, 0xc0, 0x4e, 0x4b, 0x36, 0x16, 0x05, 0x6f, 0x04, 0xab, 0xe9,
- 0x43, 0xb1, 0x2d, 0xa8, 0xa7, 0x17, 0x9a, 0x5f, 0x09, 0x91, 0x7d, 0x1f}};
+ {0xc6, 0xdd, 0x6e, 0xc4, 0x76, 0xb8, 0x55, 0xf2, 0xa4, 0xfc, 0x59,
+ 0x04, 0xa4, 0x90, 0xdc, 0xa7, 0xa7, 0x0d, 0x94, 0x8f, 0xc2, 0xdc,
+ 0x15, 0x6d, 0x48, 0x93, 0x9d, 0x05, 0xbb, 0x9a, 0xbc, 0xc1},
+ {0x41, 0xea, 0x77, 0x09, 0x8c, 0x90, 0x04, 0x10, 0xec, 0xbc, 0x37, 0xd8,
+ 0x5b, 0x54, 0xcd, 0x7b, 0x08, 0x15, 0x13, 0x20, 0xed, 0x1e, 0x3f, 0x54,
+ 0x74, 0xf7, 0x8b, 0x06, 0x38, 0x28, 0x06, 0x37, 0x75, 0x23, 0xa2, 0xb7,
+ 0x34, 0xb1, 0x72, 0x2e, 0x59, 0x6d, 0x5a, 0x31, 0xf5, 0x53, 0xab, 0x99}};
const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]);
HkdfExpandLabel(&k1_, hash_type_, kSessionHash, kHashLength[hash_type_],
diff --git a/security/nss/gtests/ssl_gtest/tls_protect.cc b/security/nss/gtests/ssl_gtest/tls_protect.cc
index efcd89e14..6c945f66e 100644
--- a/security/nss/gtests/ssl_gtest/tls_protect.cc
+++ b/security/nss/gtests/ssl_gtest/tls_protect.cc
@@ -32,7 +32,6 @@ void AeadCipher::FormatNonce(uint64_t seq, uint8_t *nonce) {
}
DataBuffer d(nonce, 12);
- std::cerr << "Nonce " << d << std::endl;
}
bool AeadCipher::AeadInner(bool decrypt, void *params, size_t param_length,
@@ -92,8 +91,9 @@ bool AeadCipherChacha20Poly1305::Aead(bool decrypt, uint64_t seq,
in, inlen, out, outlen, maxlen);
}
-bool TlsCipherSpec::Init(SSLCipherAlgorithm cipher, PK11SymKey *key,
- const uint8_t *iv) {
+bool TlsCipherSpec::Init(uint16_t epoch, SSLCipherAlgorithm cipher,
+ PK11SymKey *key, const uint8_t *iv) {
+ epoch_ = epoch;
switch (cipher) {
case ssl_calg_aes_gcm:
aead_.reset(new AeadCipherAesGcm());
diff --git a/security/nss/gtests/ssl_gtest/tls_protect.h b/security/nss/gtests/ssl_gtest/tls_protect.h
index 4efbd6e6b..93ffd6322 100644
--- a/security/nss/gtests/ssl_gtest/tls_protect.h
+++ b/security/nss/gtests/ssl_gtest/tls_protect.h
@@ -20,7 +20,7 @@ class TlsRecordHeader;
class AeadCipher {
public:
AeadCipher(CK_MECHANISM_TYPE mech) : mech_(mech), key_(nullptr) {}
- ~AeadCipher();
+ virtual ~AeadCipher();
bool Init(PK11SymKey *key, const uint8_t *iv);
virtual bool Aead(bool decrypt, uint64_t seq, const uint8_t *in, size_t inlen,
@@ -58,16 +58,19 @@ class AeadCipherAesGcm : public AeadCipher {
// Our analog of ssl3CipherSpec
class TlsCipherSpec {
public:
- TlsCipherSpec() : aead_() {}
+ TlsCipherSpec() : epoch_(0), aead_() {}
- bool Init(SSLCipherAlgorithm cipher, PK11SymKey *key, const uint8_t *iv);
+ bool Init(uint16_t epoch, SSLCipherAlgorithm cipher, PK11SymKey *key,
+ const uint8_t *iv);
bool Protect(const TlsRecordHeader &header, const DataBuffer &plaintext,
DataBuffer *ciphertext);
bool Unprotect(const TlsRecordHeader &header, const DataBuffer &ciphertext,
DataBuffer *plaintext);
+ uint16_t epoch() const { return epoch_; }
private:
+ uint16_t epoch_;
std::unique_ptr<AeadCipher> aead_;
};