diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest')
40 files changed, 12157 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/Makefile b/security/nss/gtests/ssl_gtest/Makefile new file mode 100644 index 000000000..dfb8df943 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/Makefile @@ -0,0 +1,59 @@ +#! gmake +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +####################################################################### +# (1) Include initial platform-independent assignments (MANDATORY). # +####################################################################### + +include manifest.mn + +####################################################################### +# (2) Include "global" configuration information. (OPTIONAL) # +####################################################################### + +include $(CORE_DEPTH)/coreconf/config.mk + +####################################################################### +# (3) Include "component" configuration information. (OPTIONAL) # +####################################################################### + + +####################################################################### +# (4) Include "local" platform-dependent assignments (OPTIONAL). # +####################################################################### + +include ../common/gtest.mk + +CFLAGS += -I$(CORE_DEPTH)/lib/ssl + +ifdef NSS_SSL_ENABLE_ZLIB +include $(CORE_DEPTH)/coreconf/zlib.mk +endif + +ifndef NSS_ENABLE_TLS_1_3 +NSS_DISABLE_TLS_1_3=1 +endif + +ifdef NSS_DISABLE_TLS_1_3 +# Run parameterized tests only, for which we can easily exclude TLS 1.3 +CPPSRCS := $(filter-out $(shell grep -l '^TEST_F' $(CPPSRCS)), $(CPPSRCS)) +CFLAGS += -DNSS_DISABLE_TLS_1_3 +endif + +####################################################################### +# (5) Execute "global" rules. (OPTIONAL) # +####################################################################### + +include $(CORE_DEPTH)/coreconf/rules.mk + +####################################################################### +# (6) Execute "component" rules. (OPTIONAL) # +####################################################################### + + +####################################################################### +# (7) Execute "local" rules. (OPTIONAL). # +####################################################################### diff --git a/security/nss/gtests/ssl_gtest/databuffer.h b/security/nss/gtests/ssl_gtest/databuffer.h new file mode 100644 index 000000000..e7236d4e9 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/databuffer.h @@ -0,0 +1,191 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef databuffer_h__ +#define databuffer_h__ + +#include <algorithm> +#include <cassert> +#include <cstring> +#include <iomanip> +#include <iostream> +#if defined(WIN32) || defined(WIN64) +#include <winsock2.h> +#else +#include <arpa/inet.h> +#endif + +extern bool g_ssl_gtest_verbose; + +namespace nss_test { + +class DataBuffer { + public: + DataBuffer() : data_(nullptr), len_(0) {} + DataBuffer(const uint8_t* data, size_t len) : data_(nullptr), len_(0) { + Assign(data, len); + } + DataBuffer(const DataBuffer& other) : data_(nullptr), len_(0) { + Assign(other); + } + ~DataBuffer() { delete[] data_; } + + DataBuffer& operator=(const DataBuffer& other) { + if (&other != this) { + Assign(other); + } + return *this; + } + + void Allocate(size_t len) { + delete[] data_; + data_ = new uint8_t[len ? len : 1]; // Don't depend on new [0]. + len_ = len; + } + + void Truncate(size_t len) { len_ = std::min(len_, len); } + + void Assign(const DataBuffer& other) { Assign(other.data(), other.len()); } + + void Assign(const uint8_t* data, size_t len) { + if (data) { + Allocate(len); + memcpy(static_cast<void*>(data_), static_cast<const void*>(data), len); + } else { + assert(len == 0); + data_ = nullptr; + len_ = 0; + } + } + + // Write will do a new allocation and expand the size of the buffer if needed. + // Returns the offset of the end of the write. + size_t Write(size_t index, const uint8_t* val, size_t count) { + assert(val); + if (index + count > len_) { + size_t newlen = index + count; + uint8_t* tmp = new uint8_t[newlen]; // Always > 0. + if (data_) { + memcpy(static_cast<void*>(tmp), static_cast<const void*>(data_), len_); + } + if (index > len_) { + memset(static_cast<void*>(tmp + len_), 0, index - len_); + } + delete[] data_; + data_ = tmp; + len_ = newlen; + } + if (data_) { + memcpy(static_cast<void*>(data_ + index), static_cast<const void*>(val), + count); + } + return index + count; + } + + size_t Write(size_t index, const DataBuffer& buf) { + return Write(index, buf.data(), buf.len()); + } + + // Write an integer, also performing host-to-network order conversion. + // Returns the offset of the end of the write. + size_t Write(size_t index, uint32_t val, size_t count) { + assert(count <= sizeof(uint32_t)); + uint32_t nvalue = htonl(val); + auto* addr = reinterpret_cast<const uint8_t*>(&nvalue); + return Write(index, addr + sizeof(uint32_t) - count, count); + } + + // This can't use the same trick as Write(), since we might be reading from a + // smaller data source. + bool Read(size_t index, size_t count, uint32_t* val) const { + assert(count < sizeof(uint32_t)); + assert(val); + if ((index > len()) || (count > (len() - index))) { + return false; + } + *val = 0; + for (size_t i = 0; i < count; ++i) { + *val = (*val << 8) | data()[index + i]; + } + return true; + } + + // Starting at |index|, remove |remove| bytes and replace them with the + // contents of |buf|. + void Splice(const DataBuffer& buf, size_t index, size_t remove = 0) { + Splice(buf.data(), buf.len(), index, remove); + } + + void Splice(const uint8_t* ins, size_t ins_len, size_t index, + size_t remove = 0) { + assert(ins); + uint8_t* old_value = data_; + size_t old_len = len_; + + // The amount of stuff remaining from the tail of the old. + size_t tail_len = old_len - std::min(old_len, index + remove); + // The new length: the head of the old, the new, and the tail of the old. + len_ = index + ins_len + tail_len; + data_ = new uint8_t[len_ ? len_ : 1]; + + // The head of the old. + if (old_value) { + Write(0, old_value, std::min(old_len, index)); + } + // Maybe a gap. + if (old_value && index > old_len) { + memset(old_value + index, 0, index - old_len); + } + // The new. + Write(index, ins, ins_len); + // The tail of the old. + if (tail_len > 0) { + Write(index + ins_len, old_value + index + remove, tail_len); + } + + delete[] old_value; + } + + void Append(const DataBuffer& buf) { Splice(buf, len_); } + + const uint8_t* data() const { return data_; } + uint8_t* data() { return data_; } + size_t len() const { return len_; } + bool empty() const { return len_ == 0; } + + private: + uint8_t* data_; + size_t len_; +}; + +static const size_t kMaxBufferPrint = 32; + +inline std::ostream& operator<<(std::ostream& stream, const DataBuffer& buf) { + stream << "[" << buf.len() << "] "; + for (size_t i = 0; i < buf.len(); ++i) { + if (!g_ssl_gtest_verbose && i >= kMaxBufferPrint) { + stream << "..."; + break; + } + stream << std::hex << std::setfill('0') << std::setw(2) + << static_cast<unsigned>(buf.data()[i]); + } + stream << std::dec; + return stream; +} + +inline bool operator==(const DataBuffer& a, const DataBuffer& b) { + return (a.empty() && b.empty()) || + (a.len() == b.len() && 0 == memcmp(a.data(), b.data(), a.len())); +} + +inline bool operator!=(const DataBuffer& a, const DataBuffer& b) { + return !(a == b); +} + +} // namespace nss_test + +#endif diff --git a/security/nss/gtests/ssl_gtest/gtest_utils.h b/security/nss/gtests/ssl_gtest/gtest_utils.h new file mode 100644 index 000000000..3ecd96cbe --- /dev/null +++ b/security/nss/gtests/ssl_gtest/gtest_utils.h @@ -0,0 +1,57 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef gtest_utils_h__ +#define gtest_utils_h__ + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "test_io.h" + +namespace nss_test { + +// Gtest utilities +class Timeout : public PollTarget { + public: + Timeout(int32_t timer_ms) : handle_(nullptr) { + Poller::Instance()->SetTimer(timer_ms, this, &Timeout::ExpiredCallback, + &handle_); + } + ~Timeout() { + if (handle_) { + handle_->Cancel(); + } + } + + static void ExpiredCallback(PollTarget* target, Event event) { + Timeout* timeout = static_cast<Timeout*>(target); + timeout->handle_ = nullptr; + } + + bool timed_out() const { return !handle_; } + + private: + Poller::Timer* handle_; +}; + +} // namespace nss_test + +#define WAIT_(expression, timeout) \ + do { \ + Timeout tm(timeout); \ + while (!(expression)) { \ + Poller::Instance()->Poll(); \ + if (tm.timed_out()) break; \ + } \ + } while (0) + +#define ASSERT_TRUE_WAIT(expression, timeout) \ + do { \ + WAIT_(expression, timeout); \ + ASSERT_TRUE(expression); \ + } while (0) + +#endif diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.c b/security/nss/gtests/ssl_gtest/libssl_internals.c new file mode 100644 index 000000000..5136ee8ec --- /dev/null +++ b/security/nss/gtests/ssl_gtest/libssl_internals.c @@ -0,0 +1,340 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +/* This file contains functions for frobbing the internals of libssl */ +#include "libssl_internals.h" + +#include "nss.h" +#include "pk11pub.h" +#include "seccomon.h" +#include "ssl.h" +#include "sslimpl.h" + +SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + ++ss->clientHelloVersion; + + return SECSuccess; +} + +/* Use this function to update the ClientRandom of a client's handshake state + * after replacing its ClientHello message. We for example need to do this + * when replacing an SSLv3 ClientHello with its SSLv2 equivalent. */ +SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, + size_t rnd_len, uint8_t *msg, + size_t msg_len) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + SECStatus rv = ssl3_InitState(ss); + if (rv != SECSuccess) { + return rv; + } + + rv = ssl3_RestartHandshakeHashes(ss); + if (rv != SECSuccess) { + return rv; + } + + // Ensure we don't overrun hs.client_random. + rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len); + + // Zero the client_random struct. + PORT_Memset(&ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH); + + // Copy over the challenge bytes. + size_t offset = SSL3_RANDOM_LENGTH - rnd_len; + PORT_Memcpy(&ss->ssl3.hs.client_random.rand[offset], rnd, rnd_len); + + // Rehash the SSLv2 client hello message. + return ssl3_UpdateHandshakeHashes(ss, msg, msg_len); +} + +PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext) { + sslSocket *ss = ssl_FindSocket(fd); + return (PRBool)(ss && ssl3_ExtensionNegotiated(ss, ext)); +} + +void SSLInt_ClearSessionTicketKey() { + ssl3_SessionTicketShutdown(NULL, NULL); + NSS_UnregisterShutdown(ssl3_SessionTicketShutdown, NULL); +} + +SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) { + sslSocket *ss = ssl_FindSocket(fd); + if (ss) { + ss->ssl3.mtu = mtu; + return SECSuccess; + } + return SECFailure; +} + +PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) { + PRCList *cur_p; + PRInt32 ct = 0; + + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return -1; + } + + for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs); + cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) { + ++ct; + } + return ct; +} + +void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) { + PRCList *cur_p; + + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return; + } + + fprintf(stderr, "Cipher specs\n"); + for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs); + cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) { + ssl3CipherSpec *spec = (ssl3CipherSpec *)cur_p; + fprintf(stderr, " %s\n", spec->phase); + } +} + +/* Force a timer expiry by backdating when the timer was started. + * We could set the remaining time to 0 but then backoff would not + * work properly if we decide to test it. */ +void SSLInt_ForceTimerExpiry(PRFileDesc *fd) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return; + } + + if (!ss->ssl3.hs.rtTimerCb) return; + + ss->ssl3.hs.rtTimerStarted = + PR_IntervalNow() - PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs + 1); +} + +#define CHECK_SECRET(secret) \ + if (ss->ssl3.hs.secret) { \ + fprintf(stderr, "%s != NULL\n", #secret); \ + return PR_FALSE; \ + } + +PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return PR_FALSE; + } + + CHECK_SECRET(currentSecret); + CHECK_SECRET(resumptionMasterSecret); + CHECK_SECRET(dheSecret); + CHECK_SECRET(clientEarlyTrafficSecret); + CHECK_SECRET(clientHsTrafficSecret); + CHECK_SECRET(serverHsTrafficSecret); + + return PR_TRUE; +} + +PRBool sslint_DamageTrafficSecret(PRFileDesc *fd, size_t offset) { + unsigned char data[32] = {0}; + PK11SymKey **keyPtr; + PK11SlotInfo *slot = PK11_GetInternalSlot(); + SECItem key_item = {siBuffer, data, sizeof(data)}; + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return PR_FALSE; + } + if (!slot) { + return PR_FALSE; + } + keyPtr = (PK11SymKey **)((char *)&ss->ssl3.hs + offset); + if (!*keyPtr) { + return PR_FALSE; + } + PK11_FreeSymKey(*keyPtr); + *keyPtr = PK11_ImportSymKey(slot, CKM_NSS_HKDF_SHA256, PK11_OriginUnwrap, + CKA_DERIVE, &key_item, NULL); + PK11_FreeSlot(slot); + if (!*keyPtr) { + return PR_FALSE; + } + + return PR_TRUE; +} + +PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd) { + return sslint_DamageTrafficSecret( + fd, offsetof(SSL3HandshakeState, clientHsTrafficSecret)); +} + +PRBool SSLInt_DamageServerHsTrafficSecret(PRFileDesc *fd) { + return sslint_DamageTrafficSecret( + fd, offsetof(SSL3HandshakeState, serverHsTrafficSecret)); +} + +PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd) { + return sslint_DamageTrafficSecret( + fd, offsetof(SSL3HandshakeState, clientEarlyTrafficSecret)); +} + +SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + ss->xtnData.nextProtoState = SSL_NEXT_PROTO_EARLY_VALUE; + if (ss->xtnData.nextProto.data) { + SECITEM_FreeItem(&ss->xtnData.nextProto, PR_FALSE); + } + if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) return SECFailure; + PORT_Memcpy(ss->xtnData.nextProto.data, data, len); + + return SECSuccess; +} + +PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return PR_FALSE; + } + + return (PRBool)(!!ssl_FindServerCertByAuthType(ss, authType)); +} + +PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return PR_FALSE; + } + + SECStatus rv = SSL3_SendAlert(ss, level, type); + if (rv != SECSuccess) return PR_FALSE; + + return PR_TRUE; +} + +PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd) { + sslSocket *ss = ssl_FindSocket(fd); + if (!ss) { + return PR_FALSE; + } + + ssl_GetSSL3HandshakeLock(ss); + ssl_GetXmitBufLock(ss); + + SECStatus rv = tls13_SendNewSessionTicket(ss); + if (rv == SECSuccess) { + rv = ssl3_FlushHandshake(ss, 0); + } + + ssl_ReleaseXmitBufLock(ss); + ssl_ReleaseSSL3HandshakeLock(ss); + + return rv == SECSuccess; +} + +SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) { + PRUint64 epoch; + sslSocket *ss; + ssl3CipherSpec *spec; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + if (to >= (1ULL << 48)) { + return SECFailure; + } + ssl_GetSpecWriteLock(ss); + spec = ss->ssl3.crSpec; + epoch = spec->read_seq_num >> 48; + spec->read_seq_num = (epoch << 48) | to; + + /* For DTLS, we need to fix the record sequence number. For this, we can just + * scrub the entire structure on the assumption that the new sequence number + * is far enough past the last received sequence number. */ + if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) { + return SECFailure; + } + dtls_RecordSetRecvd(&spec->recvdRecords, to); + + ssl_ReleaseSpecWriteLock(ss); + return SECSuccess; +} + +SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) { + PRUint64 epoch; + sslSocket *ss; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + if (to >= (1ULL << 48)) { + return SECFailure; + } + ssl_GetSpecWriteLock(ss); + epoch = ss->ssl3.cwSpec->write_seq_num >> 48; + ss->ssl3.cwSpec->write_seq_num = (epoch << 48) | to; + ssl_ReleaseSpecWriteLock(ss); + return SECSuccess; +} + +SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) { + sslSocket *ss; + sslSequenceNumber to; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + ssl_GetSpecReadLock(ss); + to = ss->ssl3.cwSpec->write_seq_num + DTLS_RECVD_RECORDS_WINDOW + extra; + ssl_ReleaseSpecReadLock(ss); + return SSLInt_AdvanceWriteSeqNum(fd, to & RECORD_SEQ_MAX); +} + +SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) { + const sslNamedGroupDef *groupDef = ssl_LookupNamedGroup(group); + if (!groupDef) return ssl_kea_null; + + return groupDef->keaType; +} + +SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd) { + sslSocket *ss; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + ss->opt.enableShortHeaders = PR_TRUE; + return SECSuccess; +} + +SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result) { + sslSocket *ss; + + ss = ssl_FindSocket(fd); + if (!ss) { + return SECFailure; + } + + *result = ss->ssl3.hs.shortHeaders; + + return SECSuccess; +} diff --git a/security/nss/gtests/ssl_gtest/libssl_internals.h b/security/nss/gtests/ssl_gtest/libssl_internals.h new file mode 100644 index 000000000..6ea66db81 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/libssl_internals.h @@ -0,0 +1,43 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef libssl_internals_h_ +#define libssl_internals_h_ + +#include <stdint.h> + +#include "prio.h" +#include "seccomon.h" +#include "sslt.h" + +SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd); + +SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd, + size_t rnd_len, uint8_t *msg, + size_t msg_len); + +PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext); +void SSLInt_ClearSessionTicketKey(); +PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd); +void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd); +void SSLInt_ForceTimerExpiry(PRFileDesc *fd); +SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu); +PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd); +PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd); +PRBool SSLInt_DamageServerHsTrafficSecret(PRFileDesc *fd); +PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd); +SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len); +PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType); +PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type); +PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd); +SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to); +SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to); +SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra); +SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group); +SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd); +SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result); + +#endif // ndef libssl_internals_h_ diff --git a/security/nss/gtests/ssl_gtest/manifest.mn b/security/nss/gtests/ssl_gtest/manifest.mn new file mode 100644 index 000000000..391db813b --- /dev/null +++ b/security/nss/gtests/ssl_gtest/manifest.mn @@ -0,0 +1,54 @@ +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +CORE_DEPTH = ../.. +DEPTH = ../.. +MODULE = nss + +# These sources have access to libssl internals +CSRCS = \ + libssl_internals.c \ + $(NULL) + +CPPSRCS = \ + ssl_0rtt_unittest.cc \ + ssl_agent_unittest.cc \ + ssl_auth_unittest.cc \ + ssl_cert_ext_unittest.cc \ + ssl_ciphersuite_unittest.cc \ + ssl_damage_unittest.cc \ + ssl_dhe_unittest.cc \ + ssl_drop_unittest.cc \ + ssl_ecdh_unittest.cc \ + ssl_ems_unittest.cc \ + ssl_exporter_unittest.cc \ + ssl_extension_unittest.cc \ + ssl_fuzz_unittest.cc \ + ssl_gtest.cc \ + ssl_hrr_unittest.cc \ + ssl_loopback_unittest.cc \ + ssl_record_unittest.cc \ + ssl_resumption_unittest.cc \ + ssl_skip_unittest.cc \ + ssl_staticrsa_unittest.cc \ + ssl_v2_client_hello_unittest.cc \ + ssl_version_unittest.cc \ + test_io.cc \ + tls_agent.cc \ + tls_connect.cc \ + tls_hkdf_unittest.cc \ + tls_filter.cc \ + tls_parser.cc \ + $(NULL) + +INCLUDES += -I$(CORE_DEPTH)/gtests/google_test/gtest/include \ + -I$(CORE_DEPTH)/gtests/common + +REQUIRES = nspr nss libdbm gtest + +PROGRAM = ssl_gtest +EXTRA_LIBS = $(DIST)/lib/$(LIB_PREFIX)gtest.$(LIB_SUFFIX) \ + $(DIST)/lib/$(LIB_PREFIX)softokn.$(LIB_SUFFIX) + +USE_STATIC_LIBS = 1 diff --git a/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc new file mode 100644 index 000000000..cf5a27fed --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc @@ -0,0 +1,203 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectTls13, ZeroRtt) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true); + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, ZeroRttServerRejectByOption) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + +// Test that we don't try to send 0-RTT data when the server sent +// us a ticket without the 0-RTT flags. +TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); + Reset(); + server_->StartConnect(); + client_->StartConnect(); + // Now turn on 0-RTT but too late for the ticket. + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(false, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, ZeroRttServerForgetTicket) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ClearServerCache(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + ZeroRttSendReceive(true, false); + Handshake(); + CheckConnected(); + SendReceive(); +} + +TEST_P(TlsConnectTls13, ZeroRttServerOnly) { + ExpectResumption(RESUME_NONE); + server_->Set0RttEnabled(true); + client_->StartConnect(); + server_->StartConnect(); + + // Client sends ordinary ClientHello. + client_->Handshake(); + + // Verify that the server doesn't get data. + uint8_t buf[100]; + PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf)); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Now make sure that things complete. + Handshake(); + CheckConnected(); + SendReceive(); + CheckKeys(); +} + +TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) { + EnableAlpn(); + SetupForZeroRtt(); + EnableAlpn(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ExpectEarlyDataAccepted(true); + ZeroRttSendReceive(true, true, [this]() { + client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a"); + return true; + }); + Handshake(); + CheckConnected(); + SendReceive(); + CheckAlpn("a"); +} + +// Have the server negotiate a different ALPN value, and therefore +// reject 0-RTT. +TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeServer) { + EnableAlpn(); + SetupForZeroRtt(); + static const uint8_t client_alpn[] = {0x01, 0x61, 0x01, 0x62}; // "a", "b" + static const uint8_t server_alpn[] = {0x01, 0x62}; // "b" + client_->EnableAlpn(client_alpn, sizeof(client_alpn)); + server_->EnableAlpn(server_alpn, sizeof(server_alpn)); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false, [this]() { + client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a"); + return true; + }); + Handshake(); + CheckConnected(); + SendReceive(); + CheckAlpn("b"); +} + +// Check that the client validates the ALPN selection of the server. +// Stomp the ALPN on the client after sending the ClientHello so +// that the server selection appears to be incorrect. The client +// should then fail the connection. +TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnServer) { + EnableAlpn(); + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + EnableAlpn(); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true, [this]() { + PRUint8 b[] = {'b'}; + client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a"); + EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, sizeof(b))); + client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b"); + return true; + }); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Set up with no ALPN and then set the client so it thinks it has ALPN. +// The server responds without the extension and the client returns an +// error. +TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnClient) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, true, [this]() { + PRUint8 b[] = {'b'}; + EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, 1)); + client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b"); + return true; + }); + Handshake(); + client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +// Remove the old ALPN value and so the client will not offer early data. +TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeBoth) { + EnableAlpn(); + SetupForZeroRtt(); + static const uint8_t alpn[] = {0x01, 0x62}; // "b" + EnableAlpn(alpn, sizeof(alpn)); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + ZeroRttSendReceive(true, false, [this]() { + client_->CheckAlpn(SSL_NEXT_PROTO_NO_SUPPORT); + return false; + }); + Handshake(); + CheckConnected(); + SendReceive(); + CheckAlpn("b"); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc new file mode 100644 index 000000000..0e6ddaef8 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_agent_unittest.cc @@ -0,0 +1,210 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +// This is an internal header, used to get TLS_1_3_DRAFT_VERSION. +#include "ssl3prot.h" + +#include <memory> + +#include "databuffer.h" +#include "tls_agent.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +static const uint8_t kD13 = TLS_1_3_DRAFT_VERSION; +// This is a 1-RTT ClientHello with ECDHE. +const static uint8_t kCannedTls13ClientHello[] = { + 0x01, 0x00, 0x00, 0xcf, 0x03, 0x03, 0x6c, 0xb3, 0x46, 0x81, 0xc8, 0x1a, + 0xf9, 0xd2, 0x05, 0x97, 0x48, 0x7c, 0xa8, 0x31, 0x03, 0x1c, 0x06, 0xa8, + 0x62, 0xb1, 0x90, 0xd6, 0x21, 0x44, 0x7f, 0xc1, 0x9b, 0x87, 0x3e, 0xad, + 0x91, 0x85, 0x00, 0x00, 0x06, 0x13, 0x01, 0x13, 0x03, 0x13, 0x02, 0x01, + 0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x09, 0x00, 0x00, 0x06, + 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, + 0x0a, 0x00, 0x12, 0x00, 0x10, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x01, + 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x28, 0x00, + 0x47, 0x00, 0x45, 0x00, 0x17, 0x00, 0x41, 0x04, 0x86, 0x4a, 0xb9, 0xdc, + 0x6a, 0x38, 0xa7, 0xce, 0xe7, 0xc2, 0x4f, 0xa6, 0x28, 0xb9, 0xdc, 0x65, + 0xbf, 0x73, 0x47, 0x3c, 0x9c, 0x65, 0x8c, 0x47, 0x6d, 0x57, 0x22, 0x8a, + 0xc2, 0xb3, 0xc6, 0x80, 0x72, 0x86, 0x08, 0x86, 0x8f, 0x52, 0xc5, 0xcb, + 0xbf, 0x2a, 0xb5, 0x59, 0x64, 0xcc, 0x0c, 0x49, 0x95, 0x36, 0xe4, 0xd9, + 0x2f, 0xd4, 0x24, 0x66, 0x71, 0x6f, 0x5d, 0x70, 0xe2, 0xa0, 0xea, 0x26, + 0x00, 0x2b, 0x00, 0x03, 0x02, 0x7f, kD13, 0x00, 0x0d, 0x00, 0x20, 0x00, + 0x1e, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x02, 0x03, 0x08, 0x04, 0x08, + 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x02, 0x01, 0x04, + 0x02, 0x05, 0x02, 0x06, 0x02, 0x02, 0x02}; + +const static uint8_t kCannedTls13ServerHello[] = { + 0x7f, kD13, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, 0xf0, + 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, 0xdf, 0xe5, + 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, 0x08, 0x13, 0x01, + 0x00, 0x28, 0x00, 0x28, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, + 0x23, 0x17, 0x64, 0x23, 0x03, 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, + 0x24, 0xa1, 0x6c, 0xa9, 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, + 0xcb, 0xe3, 0x08, 0x84, 0xae, 0x19}; +static const char *k0RttData = "ABCDEF"; + +TEST_P(TlsAgentTest, EarlyFinished) { + DataBuffer buffer; + MakeTrivialHandshakeRecord(kTlsHandshakeFinished, 0, &buffer); + ProcessMessage(buffer, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_FINISHED); +} + +TEST_P(TlsAgentTest, EarlyCertificateVerify) { + DataBuffer buffer; + MakeTrivialHandshakeRecord(kTlsHandshakeCertificateVerify, 0, &buffer); + ProcessMessage(buffer, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); +} + +TEST_P(TlsAgentTestClient, CannedHello) { + DataBuffer buffer; + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, + SSL_LIBRARY_VERSION_TLS_1_3); + DataBuffer server_hello; + MakeHandshakeMessage(kTlsHandshakeServerHello, kCannedTls13ServerHello, + sizeof(kCannedTls13ServerHello), &server_hello); + MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, + server_hello.data(), server_hello.len(), &buffer); + ProcessMessage(buffer, TlsAgent::STATE_CONNECTING); +} + +TEST_P(TlsAgentTestClient, EncryptedExtensionsInClear) { + DataBuffer server_hello; + MakeHandshakeMessage(kTlsHandshakeServerHello, kCannedTls13ServerHello, + sizeof(kCannedTls13ServerHello), &server_hello); + DataBuffer encrypted_extensions; + MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0, + &encrypted_extensions, 1); + server_hello.Append(encrypted_extensions); + DataBuffer buffer; + MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, + server_hello.data(), server_hello.len(), &buffer); + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, + SSL_LIBRARY_VERSION_TLS_1_3); + ProcessMessage(buffer, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_HANDSHAKE); +} + +TEST_F(TlsAgentStreamTestClient, EncryptedExtensionsInClearTwoPieces) { + DataBuffer server_hello; + MakeHandshakeMessage(kTlsHandshakeServerHello, kCannedTls13ServerHello, + sizeof(kCannedTls13ServerHello), &server_hello); + DataBuffer encrypted_extensions; + MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0, + &encrypted_extensions, 1); + server_hello.Append(encrypted_extensions); + DataBuffer buffer; + MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, + server_hello.data(), 20, &buffer); + + DataBuffer buffer2; + MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, + server_hello.data() + 20, server_hello.len() - 20, &buffer2); + + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, + SSL_LIBRARY_VERSION_TLS_1_3); + ProcessMessage(buffer, TlsAgent::STATE_CONNECTING); + ProcessMessage(buffer2, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_HANDSHAKE); +} + +TEST_F(TlsAgentDgramTestClient, EncryptedExtensionsInClearTwoPieces) { + DataBuffer server_hello_frag1; + MakeHandshakeMessageFragment( + kTlsHandshakeServerHello, kCannedTls13ServerHello, + sizeof(kCannedTls13ServerHello), &server_hello_frag1, 0, 0, 20); + DataBuffer server_hello_frag2; + MakeHandshakeMessageFragment( + kTlsHandshakeServerHello, kCannedTls13ServerHello + 20, + sizeof(kCannedTls13ServerHello), &server_hello_frag2, 0, 20, + sizeof(kCannedTls13ServerHello) - 20); + DataBuffer encrypted_extensions; + MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0, + &encrypted_extensions, 1); + server_hello_frag2.Append(encrypted_extensions); + DataBuffer buffer; + MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, + server_hello_frag1.data(), server_hello_frag1.len(), &buffer); + + DataBuffer buffer2; + MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, + server_hello_frag2.data(), server_hello_frag2.len(), &buffer2, 1); + + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, + SSL_LIBRARY_VERSION_TLS_1_3); + ProcessMessage(buffer, TlsAgent::STATE_CONNECTING); + ProcessMessage(buffer2, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_HANDSHAKE); +} + +TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) { + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + agent_->StartConnect(); + agent_->Set0RttEnabled(true); + auto filter = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeClientHello); + agent_->SetPacketFilter(filter); + PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData)); + EXPECT_EQ(-1, rv); + int32_t err = PORT_GetError(); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, err); + EXPECT_LT(0UL, filter->buffer().len()); +} + +TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenRead) { + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + agent_->StartConnect(); + agent_->Set0RttEnabled(true); + DataBuffer buffer; + MakeRecord(kTlsApplicationDataType, SSL_LIBRARY_VERSION_TLS_1_3, + reinterpret_cast<const uint8_t *>(k0RttData), strlen(k0RttData), + &buffer); + ProcessMessage(buffer, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA); +} + +// The server is allowing 0-RTT but the client doesn't offer it, +// so trial decryption isn't engaged and 0-RTT messages cause +// an error. +TEST_F(TlsAgentStreamTestServer, Set0RttOptionClientHelloThenRead) { + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + agent_->StartConnect(); + agent_->Set0RttEnabled(true); + DataBuffer buffer; + MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, + kCannedTls13ClientHello, sizeof(kCannedTls13ClientHello), &buffer); + ProcessMessage(buffer, TlsAgent::STATE_CONNECTING); + MakeRecord(kTlsApplicationDataType, SSL_LIBRARY_VERSION_TLS_1_3, + reinterpret_cast<const uint8_t *>(k0RttData), strlen(k0RttData), + &buffer); + ProcessMessage(buffer, TlsAgent::STATE_ERROR, SSL_ERROR_BAD_MAC_READ); +} + +INSTANTIATE_TEST_CASE_P( + AgentTests, TlsAgentTest, + ::testing::Combine(TlsAgentTestBase::kTlsRolesAll, + TlsConnectTestBase::kTlsModesStream)); +INSTANTIATE_TEST_CASE_P(ClientTests, TlsAgentTestClient, + TlsConnectTestBase::kTlsModesAll); +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc new file mode 100644 index 000000000..d663d81e0 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_auth_unittest.cc @@ -0,0 +1,754 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectGeneric, ServerAuthBigRsa) { + Reset(TlsAgent::kRsa2048); + Connect(); + CheckKeys(); +} + +TEST_P(TlsConnectGeneric, ServerAuthRsaChain) { + Reset(TlsAgent::kServerRsaChain); + Connect(); + CheckKeys(); + size_t chain_length; + EXPECT_TRUE(client_->GetPeerChainLength(&chain_length)); + EXPECT_EQ(2UL, chain_length); +} + +TEST_P(TlsConnectGeneric, ClientAuth) { + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(); +} + +// In TLS 1.3, the client sends its cert rejection on the +// second flight, and since it has already received the +// server's Finished, it transitions to complete and +// then gets an alert from the server. The test harness +// doesn't handle this right yet. +TEST_P(TlsConnectStream, DISABLED_ClientAuthRequiredRejected) { + server_->RequestClientAuth(true); + ConnectExpectFail(); +} + +TEST_P(TlsConnectGeneric, ClientAuthRequestedRejected) { + server_->RequestClientAuth(false); + Connect(); + CheckKeys(); +} + +TEST_P(TlsConnectGeneric, ClientAuthEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); +} + +TEST_P(TlsConnectGeneric, ClientAuthBigRsa) { + Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(); +} + +// Offset is the position in the captured buffer where the signature sits. +static void CheckSigScheme(TlsInspectorRecordHandshakeMessage* capture, + size_t offset, TlsAgent* peer, + uint16_t expected_scheme, size_t expected_size) { + EXPECT_LT(offset + 2U, capture->buffer().len()); + + uint32_t scheme = 0; + capture->buffer().Read(offset, 2, &scheme); + EXPECT_EQ(expected_scheme, static_cast<uint16_t>(scheme)); + + ScopedCERTCertificate remote_cert(SSL_PeerCertificate(peer->ssl_fd())); + ScopedSECKEYPublicKey remote_key(CERT_ExtractPublicKey(remote_cert.get())); + EXPECT_EQ(expected_size, SECKEY_PublicKeyStrengthInBits(remote_key.get())); +} + +// The server should prefer SHA-256 by default, even for the small key size used +// in the default certificate. +TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) { + EnsureTlsSetup(); + auto capture_ske = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + server_->SetPacketFilter(capture_ske); + Connect(); + CheckKeys(); + + const DataBuffer& buffer = capture_ske->buffer(); + EXPECT_LT(3U, buffer.len()); + EXPECT_EQ(3U, buffer.data()[0]) << "curve_type == named_curve"; + uint32_t tmp; + EXPECT_TRUE(buffer.Read(1, 2, &tmp)) << "read NamedCurve"; + EXPECT_EQ(ssl_grp_ec_curve25519, tmp); + EXPECT_TRUE(buffer.Read(3, 1, &tmp)) << " read ECPoint"; + CheckSigScheme(capture_ske, 4 + tmp, client_, ssl_sig_rsa_pss_sha256, 1024); +} + +TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) { + EnsureTlsSetup(); + auto capture_cert_verify = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeCertificateVerify); + client_->SetPacketFilter(capture_cert_verify); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(); + + CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pkcs1_sha1, 1024); +} + +TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) { + Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048); + auto capture_cert_verify = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeCertificateVerify); + client_->SetPacketFilter(capture_cert_verify); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(); + CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_sha256, 2048); +} + +static const SSLSignatureScheme SignatureSchemeEcdsaSha384[] = { + ssl_sig_ecdsa_secp384r1_sha384}; +static const SSLSignatureScheme SignatureSchemeEcdsaSha256[] = { + ssl_sig_ecdsa_secp256r1_sha256}; +static const SSLSignatureScheme SignatureSchemeRsaSha384[] = { + ssl_sig_rsa_pkcs1_sha384}; +static const SSLSignatureScheme SignatureSchemeRsaSha256[] = { + ssl_sig_rsa_pkcs1_sha256}; + +static SSLNamedGroup NamedGroupForEcdsa384(uint16_t version) { + // NSS tries to match the group size to the symmetric cipher. In TLS 1.1 and + // 1.0, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA is the highest priority suite, so + // we use P-384. With TLS 1.2 on we pick AES-128 GCM so use x25519. + if (version <= SSL_LIBRARY_VERSION_TLS_1_1) { + return ssl_grp_ec_secp384r1; + } + return ssl_grp_ec_curve25519; +} + +// When signature algorithms match up, this should connect successfully; even +// for TLS 1.1 and 1.0, where they should be ignored. +TEST_P(TlsConnectGeneric, SignatureAlgorithmServerAuth) { + Reset(TlsAgent::kServerEcdsa384); + client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); + server_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); + Connect(); + CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa, + ssl_sig_ecdsa_secp384r1_sha384); +} + +// Here the client picks a single option, which should work in all versions. +// Defaults on the server include the first option. +TEST_P(TlsConnectGeneric, SignatureAlgorithmClientOnly) { + const SSLSignatureAndHashAlg clientAlgorithms[] = { + {ssl_hash_sha384, ssl_sign_ecdsa}, + {ssl_hash_sha384, ssl_sign_rsa}, // supported but unusable + {ssl_hash_md5, ssl_sign_ecdsa} // unsupported and ignored + }; + Reset(TlsAgent::kServerEcdsa384); + EnsureTlsSetup(); + // Use the old API for this function. + EXPECT_EQ(SECSuccess, + SSL_SignaturePrefSet(client_->ssl_fd(), clientAlgorithms, + PR_ARRAY_SIZE(clientAlgorithms))); + Connect(); + CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa, + ssl_sig_ecdsa_secp384r1_sha384); +} + +// Here the server picks a single option, which should work in all versions. +// Defaults on the client include the provided option. +TEST_P(TlsConnectGeneric, SignatureAlgorithmServerOnly) { + Reset(TlsAgent::kServerEcdsa384); + server_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); + Connect(); + CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa, + ssl_sig_ecdsa_secp384r1_sha384); +} + +// In TLS 1.2, curve and hash aren't bound together. +TEST_P(TlsConnectTls12, SignatureSchemeCurveMismatch) { + Reset(TlsAgent::kServerEcdsa256); + client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); + Connect(); +} + +// In TLS 1.3, curve and hash are coupled. +TEST_P(TlsConnectTls13, SignatureSchemeCurveMismatch) { + Reset(TlsAgent::kServerEcdsa256); + client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); + ConnectExpectFail(); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// Configuring a P-256 cert with only SHA-384 signatures is OK in TLS 1.2. +TEST_P(TlsConnectTls12, SignatureSchemeBadConfig) { + Reset(TlsAgent::kServerEcdsa256); // P-256 cert can't be used. + server_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); + Connect(); +} + +// A P-256 certificate in TLS 1.3 needs a SHA-256 signature scheme. +TEST_P(TlsConnectTls13, SignatureSchemeBadConfig) { + Reset(TlsAgent::kServerEcdsa256); // P-256 cert can't be used. + server_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); + ConnectExpectFail(); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// Where there is no overlap on signature schemes, we still connect successfully +// if we aren't going to use a signature. +TEST_P(TlsConnectGenericPre13, SignatureAlgorithmNoOverlapStaticRsa) { + client_->SetSignatureSchemes(SignatureSchemeRsaSha384, + PR_ARRAY_SIZE(SignatureSchemeRsaSha384)); + server_->SetSignatureSchemes(SignatureSchemeRsaSha256, + PR_ARRAY_SIZE(SignatureSchemeRsaSha256)); + EnableOnlyStaticRsaCiphers(); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_auth_rsa_decrypt); +} + +TEST_P(TlsConnectTls12Plus, SignatureAlgorithmNoOverlapEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); + server_->SetSignatureSchemes(SignatureSchemeEcdsaSha256, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha256)); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM); +} + +// Pre 1.2, a mismatch on signature algorithms shouldn't affect anything. +TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384)); + server_->SetSignatureSchemes(SignatureSchemeEcdsaSha256, + PR_ARRAY_SIZE(SignatureSchemeEcdsaSha256)); + Connect(); +} + +// The signature_algorithms extension is mandatory in TLS 1.3. +TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) { + client_->SetPacketFilter( + new TlsExtensionDropper(ssl_signature_algorithms_xtn)); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION); +} + +// TLS 1.2 has trouble detecting this sort of modification: it uses SHA1 and +// only fails when the Finished is checked. +TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) { + client_->SetPacketFilter( + new TlsExtensionDropper(ssl_signature_algorithms_xtn)); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + +TEST_P(TlsConnectTls12Plus, RequestClientAuthWithSha384) { + server_->SetSignatureSchemes(SignatureSchemeRsaSha384, + PR_ARRAY_SIZE(SignatureSchemeRsaSha384)); + server_->RequestClientAuth(false); + Connect(); +} + +class BeforeFinished : public TlsRecordFilter { + private: + enum HandshakeState { BEFORE_CCS, AFTER_CCS, DONE }; + + public: + BeforeFinished(TlsAgent* client, TlsAgent* server, VoidFunction before_ccs, + VoidFunction before_finished) + : client_(client), + server_(server), + before_ccs_(before_ccs), + before_finished_(before_finished), + state_(BEFORE_CCS) {} + + protected: + virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + const DataBuffer& body, + DataBuffer* out) { + switch (state_) { + case BEFORE_CCS: + // Awaken when we see the CCS. + if (header.content_type() == kTlsChangeCipherSpecType) { + before_ccs_(); + + // Write the CCS out as a separate write, so that we can make + // progress. Ordinarily, libssl sends the CCS and Finished together, + // but that means that they both get processed together. + DataBuffer ccs; + header.Write(&ccs, 0, body); + server_->SendDirect(ccs); + client_->Handshake(); + state_ = AFTER_CCS; + // Request that the original record be dropped by the filter. + return DROP; + } + break; + + case AFTER_CCS: + EXPECT_EQ(kTlsHandshakeType, header.content_type()); + // This could check that data contains a Finished message, but it's + // encrypted, so that's too much extra work. + + before_finished_(); + state_ = DONE; + break; + + case DONE: + break; + } + return KEEP; + } + + private: + TlsAgent* client_; + TlsAgent* server_; + VoidFunction before_ccs_; + VoidFunction before_finished_; + HandshakeState state_; +}; + +// Running code after the client has started processing the encrypted part of +// the server's first flight, but before the Finished is processed is very hard +// in TLS 1.3. These encrypted messages are sent in a single encrypted blob. +// The following test uses DTLS to make it possible to force the client to +// process the handshake in pieces. +// +// The first encrypted message from the server is dropped, and the MTU is +// reduced to just below the original message size so that the server sends two +// messages. The Finished message is then processed separately. +class BeforeFinished13 : public PacketFilter { + private: + enum HandshakeState { + INIT, + BEFORE_FIRST_FRAGMENT, + BEFORE_SECOND_FRAGMENT, + DONE + }; + + public: + BeforeFinished13(TlsAgent* client, TlsAgent* server, + VoidFunction before_finished) + : client_(client), + server_(server), + before_finished_(before_finished), + records_(0) {} + + protected: + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) { + switch (++records_) { + case 1: + // Packet 1 is the server's entire first flight. Drop it. + EXPECT_EQ(SECSuccess, + SSLInt_SetMTU(server_->ssl_fd(), input.len() - 1)); + return DROP; + + // Packet 2 is the first part of the server's retransmitted first + // flight. Keep that. + + case 3: + // Packet 3 is the second part of the server's retransmitted first + // flight. Before passing that on, make sure that the client processes + // packet 2, then call the before_finished_() callback. + client_->Handshake(); + before_finished_(); + break; + + default: + break; + } + return KEEP; + } + + private: + TlsAgent* client_; + TlsAgent* server_; + VoidFunction before_finished_; + size_t records_; +}; + +static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) { + return SECWouldBlock; +} + +// This test uses an AuthCertificateCallback that blocks. A filter is used to +// split the server's first flight into two pieces. Before the second piece is +// processed by the client, SSL_AuthCertificateComplete() is called. +TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + server_->SetPacketFilter(new BeforeFinished13(client_, server_, [this]() { + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + })); + Connect(); +} + +static void TriggerAuthComplete(PollTarget* target, Event event) { + std::cerr << "client: call SSL_AuthCertificateComplete" << std::endl; + EXPECT_EQ(TIMER_EVENT, event); + TlsAgent* client = static_cast<TlsAgent*>(target); + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client->ssl_fd(), 0)); +} + +// This test uses a simple AuthCertificateCallback. Due to the way that the +// entire server flight is processed, the call to SSL_AuthCertificateComplete +// will trigger after the Finished message is processed. +TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) { + client_->SetAuthCertificateCallback( + [this](TlsAgent*, PRBool, PRBool) -> SECStatus { + Poller::Timer* timer_handle; + // This is really just to unroll the stack. + Poller::Instance()->SetTimer(1U, client_, TriggerAuthComplete, + &timer_handle); + return SECWouldBlock; + }); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) { + client_->EnableFalseStart(); + server_->SetPacketFilter(new BeforeFinished( + client_, server_, + [this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); }, + [this]() { + // Write something, which used to fail: bug 1235366. + client_->SendData(10); + })); + + Connect(); + server_->SendData(10); + Receive(10); +} + +TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) { + client_->EnableFalseStart(); + client_->SetAuthCertificateCallback(AuthCompleteBlock); + server_->SetPacketFilter(new BeforeFinished( + client_, server_, + []() { + // Do nothing before CCS + }, + [this]() { + EXPECT_FALSE(client_->can_falsestart_hook_called()); + // AuthComplete before Finished still enables false start. + EXPECT_EQ(SECSuccess, + SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + EXPECT_TRUE(client_->can_falsestart_hook_called()); + client_->SendData(10); + })); + + Connect(); + server_->SendData(10); + Receive(10); +} + +class EnforceNoActivity : public PacketFilter { + protected: + PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override { + std::cerr << "Unexpected packet: " << input << std::endl; + EXPECT_TRUE(false) << "should not send anything"; + return KEEP; + } +}; + +// In this test, we want to make sure that the server completes its handshake, +// but the client does not. Because the AuthCertificate callback blocks and we +// never call SSL_AuthCertificateComplete(), the client should never report that +// it has completed the handshake. Manually call Handshake(), alternating sides +// between client and server, until the desired state is reached. +TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + server_->StartConnect(); + client_->StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + client_->Handshake(); // Send ClientKeyExchange and Finished + server_->Handshake(); // Send Finished + // The server should now report that it is connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + // The client should send nothing from here on. + client_->SetPacketFilter(new EnforceNoActivity()); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // This should allow the handshake to complete now. + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); // Transition to connected + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + // Remove this before closing or the close_notify alert will trigger it. + client_->SetPacketFilter(nullptr); +} + +// TLS 1.3 handles a delayed AuthComplete callback differently since the +// shape of the handshake is different. +TEST_P(TlsConnectTls13, AuthCompleteDelayed) { + client_->SetAuthCertificateCallback(AuthCompleteBlock); + + server_->StartConnect(); + client_->StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send ServerHello + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state()); + + // The client will send nothing until AuthCertificateComplete is called. + client_->SetPacketFilter(new EnforceNoActivity()); + client_->Handshake(); + EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state()); + + // This should allow the handshake to complete now. + client_->SetPacketFilter(nullptr); + EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0)); + client_->Handshake(); // Send Finished + server_->Handshake(); // Transition to connected and send NewSessionTicket + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); +} + +static const SSLExtraServerCertData ServerCertDataRsaPkcs1Decrypt = { + ssl_auth_rsa_decrypt, nullptr, nullptr, nullptr}; +static const SSLExtraServerCertData ServerCertDataRsaPkcs1Sign = { + ssl_auth_rsa_sign, nullptr, nullptr, nullptr}; +static const SSLExtraServerCertData ServerCertDataRsaPss = { + ssl_auth_rsa_pss, nullptr, nullptr, nullptr}; + +// Test RSA cert with usage=[signature, encipherment]. +TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1SignAndKEX) { + Reset(TlsAgent::kServerRsa); + + PRFileDesc* ssl_fd = agent_->ssl_fd(); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt)); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign)); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss)); + + // Configuring for only rsa_sign, rsa_pss, or rsa_decrypt should work. + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false, + &ServerCertDataRsaPkcs1Decrypt)); + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false, + &ServerCertDataRsaPkcs1Sign)); + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false, + &ServerCertDataRsaPss)); +} + +// Test RSA cert with usage=[signature]. +TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1Sign) { + Reset(TlsAgent::kServerRsaSign); + + PRFileDesc* ssl_fd = agent_->ssl_fd(); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt)); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign)); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss)); + + // Configuring for only rsa_decrypt should fail. + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false, + &ServerCertDataRsaPkcs1Decrypt)); + + // Configuring for only rsa_sign or rsa_pss should work. + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false, + &ServerCertDataRsaPkcs1Sign)); + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false, + &ServerCertDataRsaPss)); +} + +// Test RSA cert with usage=[encipherment]. +TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1KEX) { + Reset(TlsAgent::kServerRsaDecrypt); + + PRFileDesc* ssl_fd = agent_->ssl_fd(); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt)); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign)); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss)); + + // Configuring for only rsa_sign or rsa_pss should fail. + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false, + &ServerCertDataRsaPkcs1Sign)); + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false, + &ServerCertDataRsaPss)); + + // Configuring for only rsa_decrypt should work. + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false, + &ServerCertDataRsaPkcs1Decrypt)); +} + +// Test configuring an RSA-PSS cert. +TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPss) { + Reset(TlsAgent::kServerRsaPss); + + PRFileDesc* ssl_fd = agent_->ssl_fd(); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt)); + EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign)); + EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss)); + + // Configuring for only rsa_sign or rsa_decrypt should fail. + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false, + &ServerCertDataRsaPkcs1Sign)); + EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false, + &ServerCertDataRsaPkcs1Decrypt)); + + // Configuring for only rsa_pss should work. + EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false, + &ServerCertDataRsaPss)); +} + +// mode, version, certificate, auth type, signature scheme +typedef std::tuple<std::string, uint16_t, std::string, SSLAuthType, + SSLSignatureScheme> + SignatureSchemeProfile; + +class TlsSignatureSchemeConfiguration + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SignatureSchemeProfile> { + public: + TlsSignatureSchemeConfiguration() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())), + certificate_(std::get<2>(GetParam())), + auth_type_(std::get<3>(GetParam())), + signature_scheme_(std::get<4>(GetParam())) {} + + protected: + void TestSignatureSchemeConfig(TlsAgent* configPeer) { + EnsureTlsSetup(); + configPeer->SetSignatureSchemes(&signature_scheme_, 1); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, auth_type_, + signature_scheme_); + } + + std::string certificate_; + SSLAuthType auth_type_; + SSLSignatureScheme signature_scheme_; +}; + +TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) { + Reset(certificate_); + TestSignatureSchemeConfig(server_); +} + +TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) { + Reset(certificate_); + TlsExtensionCapture* capture = + new TlsExtensionCapture(ssl_signature_algorithms_xtn); + client_->SetPacketFilter(capture); + TestSignatureSchemeConfig(client_); + + const DataBuffer& ext = capture->extension(); + ASSERT_EQ(2U + 2U, ext.len()); + uint32_t v = 0; + ASSERT_TRUE(ext.Read(0, 2, &v)); + EXPECT_EQ(2U, v); + ASSERT_TRUE(ext.Read(2, 2, &v)); + EXPECT_EQ(signature_scheme_, static_cast<SSLSignatureScheme>(v)); +} + +TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigBoth) { + Reset(certificate_); + EnsureTlsSetup(); + client_->SetSignatureSchemes(&signature_scheme_, 1); + server_->SetSignatureSchemes(&signature_scheme_, 1); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, auth_type_, signature_scheme_); +} + +INSTANTIATE_TEST_CASE_P( + SignatureSchemeRsa, TlsSignatureSchemeConfiguration, + ::testing::Combine( + TlsConnectTestBase::kTlsModesAll, TlsConnectTestBase::kTlsV12Plus, + ::testing::Values(TlsAgent::kServerRsaSign), + ::testing::Values(ssl_auth_rsa_sign), + ::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384, + ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_sha256, + ssl_sig_rsa_pss_sha384))); +// PSS with SHA-512 needs a bigger key to work. +INSTANTIATE_TEST_CASE_P( + SignatureSchemeBigRsa, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV12Plus, + ::testing::Values(TlsAgent::kRsa2048), + ::testing::Values(ssl_auth_rsa_sign), + ::testing::Values(ssl_sig_rsa_pss_sha512))); +INSTANTIATE_TEST_CASE_P( + SignatureSchemeRsaSha1, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV12, + ::testing::Values(TlsAgent::kServerRsa), + ::testing::Values(ssl_auth_rsa_sign), + ::testing::Values(ssl_sig_rsa_pkcs1_sha1))); +INSTANTIATE_TEST_CASE_P( + SignatureSchemeEcdsaP256, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV12Plus, + ::testing::Values(TlsAgent::kServerEcdsa256), + ::testing::Values(ssl_auth_ecdsa), + ::testing::Values(ssl_sig_ecdsa_secp256r1_sha256))); +INSTANTIATE_TEST_CASE_P( + SignatureSchemeEcdsaP384, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV12Plus, + ::testing::Values(TlsAgent::kServerEcdsa384), + ::testing::Values(ssl_auth_ecdsa), + ::testing::Values(ssl_sig_ecdsa_secp384r1_sha384))); +INSTANTIATE_TEST_CASE_P( + SignatureSchemeEcdsaP521, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV12Plus, + ::testing::Values(TlsAgent::kServerEcdsa521), + ::testing::Values(ssl_auth_ecdsa), + ::testing::Values(ssl_sig_ecdsa_secp521r1_sha512))); +INSTANTIATE_TEST_CASE_P( + SignatureSchemeEcdsaSha1, TlsSignatureSchemeConfiguration, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV12, + ::testing::Values(TlsAgent::kServerEcdsa256, + TlsAgent::kServerEcdsa384), + ::testing::Values(ssl_auth_ecdsa), + ::testing::Values(ssl_sig_ecdsa_sha1))); +} diff --git a/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc new file mode 100644 index 000000000..876c36881 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc @@ -0,0 +1,214 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include <memory> + +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +// Tests for Certificate Transparency (RFC 6962) +// These don't work with TLS 1.3: see bug 1252745. + +// Helper class - stores signed certificate timestamps as provided +// by the relevant callbacks on the client. +class SignedCertificateTimestampsExtractor { + public: + SignedCertificateTimestampsExtractor(TlsAgent* client) : client_(client) { + client_->SetAuthCertificateCallback( + [&](TlsAgent* agent, bool checksig, bool isServer) -> SECStatus { + const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd()); + EXPECT_TRUE(scts); + if (!scts) { + return SECFailure; + } + auth_timestamps_.reset(new DataBuffer(scts->data, scts->len)); + return SECSuccess; + }); + client_->SetHandshakeCallback([&](TlsAgent* agent) { + const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd()); + ASSERT_TRUE(scts); + handshake_timestamps_.reset(new DataBuffer(scts->data, scts->len)); + }); + } + + void assertTimestamps(const DataBuffer& timestamps) { + EXPECT_TRUE(auth_timestamps_); + EXPECT_EQ(timestamps, *auth_timestamps_); + + EXPECT_TRUE(handshake_timestamps_); + EXPECT_EQ(timestamps, *handshake_timestamps_); + + const SECItem* current = SSL_PeerSignedCertTimestamps(client_->ssl_fd()); + EXPECT_EQ(timestamps, DataBuffer(current->data, current->len)); + } + + private: + TlsAgent* client_; + std::unique_ptr<DataBuffer> auth_timestamps_; + std::unique_ptr<DataBuffer> handshake_timestamps_; +}; + +static const uint8_t kSctValue[] = {0x01, 0x23, 0x45, 0x67, 0x89}; +static const SECItem kSctItem = {siBuffer, const_cast<uint8_t*>(kSctValue), + sizeof(kSctValue)}; +static const DataBuffer kSctBuffer(kSctValue, sizeof(kSctValue)); + +// Test timestamps extraction during a successful handshake. +TEST_P(TlsConnectGeneric, SignedCertificateTimestampsHandshake) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(), + &kSctItem, ssl_kea_rsa)); + EXPECT_EQ(SECSuccess, + SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, + PR_TRUE)); + SignedCertificateTimestampsExtractor timestamps_extractor(client_); + + Connect(); + + timestamps_extractor.assertTimestamps(kSctBuffer); +} + +TEST_P(TlsConnectGeneric, SignedCertificateTimestampsConfig) { + static const SSLExtraServerCertData kExtraData = {ssl_auth_rsa_sign, nullptr, + nullptr, &kSctItem}; + + EnsureTlsSetup(); + EXPECT_TRUE( + server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraData)); + EXPECT_EQ(SECSuccess, + SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, + PR_TRUE)); + SignedCertificateTimestampsExtractor timestamps_extractor(client_); + + Connect(); + + timestamps_extractor.assertTimestamps(kSctBuffer); +} + +// Test SSL_PeerSignedCertTimestamps returning zero-length SECItem +// when the client / the server / both have not enabled the feature. +TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveClient) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(), + &kSctItem, ssl_kea_rsa)); + SignedCertificateTimestampsExtractor timestamps_extractor(client_); + + Connect(); + timestamps_extractor.assertTimestamps(DataBuffer()); +} + +TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveServer) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, + SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS, + PR_TRUE)); + SignedCertificateTimestampsExtractor timestamps_extractor(client_); + + Connect(); + timestamps_extractor.assertTimestamps(DataBuffer()); +} + +TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveBoth) { + EnsureTlsSetup(); + SignedCertificateTimestampsExtractor timestamps_extractor(client_); + + Connect(); + timestamps_extractor.assertTimestamps(DataBuffer()); +} + +// Check that the given agent doesn't have an OCSP response for its peer. +static SECStatus CheckNoOCSP(TlsAgent* agent, bool checksig, bool isServer) { + const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd()); + EXPECT_TRUE(ocsp); + EXPECT_EQ(0U, ocsp->len); + return SECSuccess; +} + +static const uint8_t kOcspValue1[] = {1, 2, 3, 4, 5, 6}; +static const uint8_t kOcspValue2[] = {7, 8, 9}; +static const SECItem kOcspItems[] = { + {siBuffer, const_cast<uint8_t*>(kOcspValue1), sizeof(kOcspValue1)}, + {siBuffer, const_cast<uint8_t*>(kOcspValue2), sizeof(kOcspValue2)}}; +static const SECItemArray kOcspResponses = {const_cast<SECItem*>(kOcspItems), + PR_ARRAY_SIZE(kOcspItems)}; +const static SSLExtraServerCertData kOcspExtraData = { + ssl_auth_rsa_sign, nullptr, &kOcspResponses, nullptr}; + +TEST_P(TlsConnectGeneric, NoOcsp) { + EnsureTlsSetup(); + client_->SetAuthCertificateCallback(CheckNoOCSP); + Connect(); +} + +// The client doesn't get OCSP stapling unless it asks. +TEST_P(TlsConnectGeneric, OcspNotRequested) { + EnsureTlsSetup(); + client_->SetAuthCertificateCallback(CheckNoOCSP); + EXPECT_TRUE( + server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData)); + Connect(); +} + +// Even if the client asks, the server has nothing unless it is configured. +TEST_P(TlsConnectGeneric, OcspNotProvided) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + client_->SetAuthCertificateCallback(CheckNoOCSP); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, OcspMangled) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + EXPECT_TRUE( + server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData)); + + static const uint8_t val[] = {1}; + auto replacer = new TlsExtensionReplacer(ssl_cert_status_xtn, + DataBuffer(val, sizeof(val))); + server_->SetPacketFilter(replacer); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +TEST_P(TlsConnectGeneric, OcspSuccess) { + EnsureTlsSetup(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_ENABLE_OCSP_STAPLING, PR_TRUE)); + auto capture_ocsp = new TlsExtensionCapture(ssl_cert_status_xtn); + server_->SetPacketFilter(capture_ocsp); + + // The value should be available during the AuthCertificateCallback + client_->SetAuthCertificateCallback([](TlsAgent* agent, bool checksig, + bool isServer) -> SECStatus { + const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd()); + if (!ocsp) { + return SECFailure; + } + EXPECT_EQ(1U, ocsp->len) << "We only provide the first item"; + EXPECT_EQ(0, SECITEM_CompareItem(&kOcspItems[0], &ocsp->items[0])); + return SECSuccess; + }); + EXPECT_TRUE( + server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData)); + + Connect(); + // In TLS 1.3, the server doesn't provide a visible ServerHello extension. + // For earlier versions, the extension is just empty. + EXPECT_EQ(0U, capture_ocsp->extension().len()); +} + +} // namespace nspr_test diff --git a/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc new file mode 100644 index 000000000..ab10a84a0 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc @@ -0,0 +1,455 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "tls_connect.h" +#include "tls_parser.h" + +namespace nss_test { + +// mode, version, cipher suite +typedef std::tuple<std::string, uint16_t, uint16_t, SSLNamedGroup, + SSLSignatureScheme> + CipherSuiteProfile; + +class TlsCipherSuiteTestBase : public TlsConnectTestBase { + public: + TlsCipherSuiteTestBase(const std::string &mode, uint16_t version, + uint16_t cipher_suite, SSLNamedGroup group, + SSLSignatureScheme signature_scheme) + : TlsConnectTestBase(mode, version), + cipher_suite_(cipher_suite), + group_(group), + signature_scheme_(signature_scheme), + csinfo_({0}) { + SECStatus rv = + SSL_GetCipherSuiteInfo(cipher_suite_, &csinfo_, sizeof(csinfo_)); + EXPECT_EQ(SECSuccess, rv); + if (rv == SECSuccess) { + std::cerr << "Cipher suite: " << csinfo_.cipherSuiteName << std::endl; + } + auth_type_ = csinfo_.authType; + kea_type_ = csinfo_.keaType; + } + + protected: + void EnableSingleCipher() { + EnsureTlsSetup(); + // It doesn't matter which does this, but the test is better if both do it. + client_->EnableSingleCipher(cipher_suite_); + server_->EnableSingleCipher(cipher_suite_); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + std::vector<SSLNamedGroup> groups = {group_}; + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); + kea_type_ = SSLInt_GetKEAType(group_); + + client_->SetSignatureSchemes(&signature_scheme_, 1); + server_->SetSignatureSchemes(&signature_scheme_, 1); + } + } + + virtual void SetupCertificate() { + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + switch (signature_scheme_) { + case ssl_sig_rsa_pkcs1_sha256: + case ssl_sig_rsa_pkcs1_sha384: + case ssl_sig_rsa_pkcs1_sha512: + Reset(TlsAgent::kServerRsaSign); + auth_type_ = ssl_auth_rsa_sign; + break; + case ssl_sig_rsa_pss_sha256: + case ssl_sig_rsa_pss_sha384: + Reset(TlsAgent::kServerRsaSign); + auth_type_ = ssl_auth_rsa_sign; + break; + case ssl_sig_rsa_pss_sha512: + // You can't fit SHA-512 PSS in a 1024-bit key. + Reset(TlsAgent::kRsa2048); + auth_type_ = ssl_auth_rsa_sign; + break; + case ssl_sig_ecdsa_secp256r1_sha256: + Reset(TlsAgent::kServerEcdsa256); + auth_type_ = ssl_auth_ecdsa; + break; + case ssl_sig_ecdsa_secp384r1_sha384: + Reset(TlsAgent::kServerEcdsa384); + auth_type_ = ssl_auth_ecdsa; + break; + default: + ASSERT_TRUE(false) << "Unsupported signature scheme: " + << signature_scheme_; + break; + } + } else { + switch (csinfo_.authType) { + case ssl_auth_rsa_sign: + Reset(TlsAgent::kServerRsaSign); + break; + case ssl_auth_rsa_decrypt: + Reset(TlsAgent::kServerRsaDecrypt); + break; + case ssl_auth_ecdsa: + Reset(TlsAgent::kServerEcdsa256); + break; + case ssl_auth_ecdh_ecdsa: + Reset(TlsAgent::kServerEcdhEcdsa); + break; + case ssl_auth_ecdh_rsa: + Reset(TlsAgent::kServerEcdhRsa); + break; + case ssl_auth_dsa: + Reset(TlsAgent::kServerDsa); + break; + default: + ASSERT_TRUE(false) << "Unsupported cipher suite: " << cipher_suite_; + break; + } + } + } + + void ConnectAndCheckCipherSuite() { + Connect(); + SendReceive(); + + // Check that we used the right cipher suite. + uint16_t actual; + EXPECT_TRUE(client_->cipher_suite(&actual) && actual == cipher_suite_); + EXPECT_TRUE(server_->cipher_suite(&actual) && actual == cipher_suite_); + SSLAuthType auth; + EXPECT_TRUE(client_->auth_type(&auth) && auth == auth_type_); + EXPECT_TRUE(server_->auth_type(&auth) && auth == auth_type_); + SSLKEAType kea; + EXPECT_TRUE(client_->kea_type(&kea) && kea == kea_type_); + EXPECT_TRUE(server_->kea_type(&kea) && kea == kea_type_); + } + + // Get the expected limit on the number of records that can be sent for the + // cipher suite. + uint64_t record_limit() const { + switch (csinfo_.symCipher) { + case ssl_calg_rc4: + case ssl_calg_3des: + return 1ULL << 20; + case ssl_calg_aes: + case ssl_calg_aes_gcm: + return 0x5aULL << 28; + case ssl_calg_null: + case ssl_calg_chacha20: + return (1ULL << 48) - 1; + case ssl_calg_rc2: + case ssl_calg_des: + case ssl_calg_idea: + case ssl_calg_fortezza: + case ssl_calg_camellia: + case ssl_calg_seed: + break; + } + EXPECT_TRUE(false) << "No limit for " << csinfo_.cipherSuiteName; + return 1ULL < 48; + } + + uint64_t last_safe_write() const { + uint64_t limit = record_limit() - 1; + if (version_ < SSL_LIBRARY_VERSION_TLS_1_1 && + (csinfo_.symCipher == ssl_calg_3des || + csinfo_.symCipher == ssl_calg_aes)) { + // 1/n-1 record splitting needs space for two records. + limit--; + } + return limit; + } + + protected: + uint16_t cipher_suite_; + SSLAuthType auth_type_; + SSLKEAType kea_type_; + SSLNamedGroup group_; + SSLSignatureScheme signature_scheme_; + SSLCipherSuiteInfo csinfo_; +}; + +class TlsCipherSuiteTest + : public TlsCipherSuiteTestBase, + public ::testing::WithParamInterface<CipherSuiteProfile> { + public: + TlsCipherSuiteTest() + : TlsCipherSuiteTestBase(std::get<0>(GetParam()), std::get<1>(GetParam()), + std::get<2>(GetParam()), std::get<3>(GetParam()), + std::get<4>(GetParam())) {} + + protected: + bool SkipIfCipherSuiteIsDSA() { + bool isDSA = csinfo_.authType == ssl_auth_dsa; + if (isDSA) { + std::cerr << "Skipping DSA suite: " << csinfo_.cipherSuiteName + << std::endl; + } + return isDSA; + } +}; + +TEST_P(TlsCipherSuiteTest, SingleCipherSuite) { + SetupCertificate(); + EnableSingleCipher(); + ConnectAndCheckCipherSuite(); +} + +TEST_P(TlsCipherSuiteTest, ResumeCipherSuite) { + if (SkipIfCipherSuiteIsDSA()) { + return; // Tickets don't work with DSA (bug 1174677). + } + + SetupCertificate(); // This is only needed once. + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + EnableSingleCipher(); + + ConnectAndCheckCipherSuite(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + EnableSingleCipher(); + ExpectResumption(RESUME_TICKET); + ConnectAndCheckCipherSuite(); +} + +// This only works for stream ciphers because we modify the sequence number - +// which is included explicitly in the DTLS record header - and that trips a +// different error code. Note that the message that the client sends would not +// decrypt (the nonce/IV wouldn't match), but the record limit is hit before +// attempting to decrypt a record. +TEST_P(TlsCipherSuiteTest, ReadLimit) { + SetupCertificate(); + EnableSingleCipher(); + ConnectAndCheckCipherSuite(); + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last_safe_write())); + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last_safe_write())); + + client_->SendData(10, 10); + server_->ReadBytes(); // This should be OK. + + // The payload needs to be big enough to pass for encrypted. In the extreme + // case (TLS 1.3), this means 1 for payload, 1 for content type and 16 for + // authentication tag. + static const uint8_t payload[18] = {6}; + DataBuffer record; + uint64_t epoch = 0; + if (mode_ == DGRAM) { + epoch++; + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + epoch++; + } + } + TlsAgentTestBase::MakeRecord(mode_, kTlsApplicationDataType, version_, + payload, sizeof(payload), &record, + (epoch << 48) | record_limit()); + server_->adapter()->PacketReceived(record); + server_->ExpectReadWriteError(); + server_->ReadBytes(); + EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, server_->error_code()); +} + +TEST_P(TlsCipherSuiteTest, WriteLimit) { + SetupCertificate(); + EnableSingleCipher(); + ConnectAndCheckCipherSuite(); + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last_safe_write())); + client_->SendData(10, 10); + client_->ExpectReadWriteError(); + client_->SendData(10, 10); + EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, client_->error_code()); +} + +// This awful macro makes the test instantiations easier to read. +#define INSTANTIATE_CIPHER_TEST_P(name, modes, versions, groups, sigalgs, ...) \ + static const uint16_t k##name##CiphersArr[] = {__VA_ARGS__}; \ + static const ::testing::internal::ParamGenerator<uint16_t> \ + k##name##Ciphers = ::testing::ValuesIn(k##name##CiphersArr); \ + INSTANTIATE_TEST_CASE_P( \ + CipherSuite##name, TlsCipherSuiteTest, \ + ::testing::Combine(TlsConnectTestBase::kTlsModes##modes, \ + TlsConnectTestBase::kTls##versions, k##name##Ciphers, \ + groups, sigalgs)); + +static const auto kDummyNamedGroupParams = ::testing::Values(ssl_grp_none); +static const auto kDummySignatureSchemesParams = + ::testing::Values(ssl_sig_none); + +#ifndef NSS_DISABLE_TLS_1_3 +static SSLSignatureScheme kSignatureSchemesParamsArr[] = { + ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384, + ssl_sig_rsa_pkcs1_sha512, ssl_sig_ecdsa_secp256r1_sha256, + ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_rsa_pss_sha256, + ssl_sig_rsa_pss_sha384, ssl_sig_rsa_pss_sha512, +}; +#endif + +INSTANTIATE_CIPHER_TEST_P(RC4, Stream, V10ToV12, kDummyNamedGroupParams, + kDummySignatureSchemesParams, + TLS_RSA_WITH_RC4_128_SHA, + TLS_ECDH_ECDSA_WITH_RC4_128_SHA, + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + TLS_ECDH_RSA_WITH_RC4_128_SHA, + TLS_ECDHE_RSA_WITH_RC4_128_SHA); +INSTANTIATE_CIPHER_TEST_P(AEAD12, All, V12, kDummyNamedGroupParams, + kDummySignatureSchemesParams, + TLS_RSA_WITH_AES_128_GCM_SHA256, + TLS_RSA_WITH_AES_256_GCM_SHA384, + TLS_DHE_DSS_WITH_AES_128_GCM_SHA256, + TLS_DHE_DSS_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384); +INSTANTIATE_CIPHER_TEST_P(AEAD, All, V12, kDummyNamedGroupParams, + kDummySignatureSchemesParams, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, + TLS_DHE_RSA_WITH_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256); +INSTANTIATE_CIPHER_TEST_P( + CBC12, All, V12, kDummyNamedGroupParams, kDummySignatureSchemesParams, + TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, TLS_RSA_WITH_AES_256_CBC_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, + TLS_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_DSS_WITH_AES_128_CBC_SHA256, + TLS_DHE_DSS_WITH_AES_256_CBC_SHA256); +INSTANTIATE_CIPHER_TEST_P( + CBCStream, Stream, V10ToV12, kDummyNamedGroupParams, + kDummySignatureSchemesParams, TLS_ECDH_ECDSA_WITH_NULL_SHA, + TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, + TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_ECDSA_WITH_NULL_SHA, + TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDH_RSA_WITH_NULL_SHA, + TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_NULL_SHA, + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA); +INSTANTIATE_CIPHER_TEST_P( + CBCDatagram, Datagram, V11V12, kDummyNamedGroupParams, + kDummySignatureSchemesParams, TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, + TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_CIPHER_TEST_P(TLS13, All, V13, + ::testing::ValuesIn(kFasterDHEGroups), + ::testing::ValuesIn(kSignatureSchemesParamsArr), + TLS_AES_128_GCM_SHA256, TLS_CHACHA20_POLY1305_SHA256, + TLS_AES_256_GCM_SHA384); +INSTANTIATE_CIPHER_TEST_P(TLS13AllGroups, All, V13, + ::testing::ValuesIn(kAllDHEGroups), + ::testing::Values(ssl_sig_ecdsa_secp384r1_sha384), + TLS_AES_256_GCM_SHA384); +#endif + +// Fields are: version, cipher suite, bulk cipher name, secretKeySize +struct SecStatusParams { + uint16_t version; + uint16_t cipher_suite; + std::string name; + int keySize; +}; + +inline std::ostream &operator<<(std::ostream &stream, + const SecStatusParams &vals) { + SSLCipherSuiteInfo csinfo; + SECStatus rv = + SSL_GetCipherSuiteInfo(vals.cipher_suite, &csinfo, sizeof(csinfo)); + if (rv != SECSuccess) { + return stream << "Error invoking SSL_GetCipherSuiteInfo()"; + } + + return stream << "TLS " << VersionString(vals.version) << ", " + << csinfo.cipherSuiteName << ", name = \"" << vals.name + << "\", key size = " << vals.keySize; +} + +class SecurityStatusTest + : public TlsCipherSuiteTestBase, + public ::testing::WithParamInterface<SecStatusParams> { + public: + SecurityStatusTest() + : TlsCipherSuiteTestBase("TLS", GetParam().version, + GetParam().cipher_suite, ssl_grp_none, + ssl_sig_none) {} +}; + +// SSL_SecurityStatus produces fairly useless output when compared to +// SSL_GetCipherSuiteInfo and SSL_GetChannelInfo, but we can't break it, so we +// need to check it. +TEST_P(SecurityStatusTest, CheckSecurityStatus) { + SetupCertificate(); + EnableSingleCipher(); + ConnectAndCheckCipherSuite(); + + int on; + char *cipher; + int keySize; + int secretKeySize; + char *issuer; + char *subject; + EXPECT_EQ(SECSuccess, + SSL_SecurityStatus(client_->ssl_fd(), &on, &cipher, &keySize, + &secretKeySize, &issuer, &subject)); + if (std::string(cipher) == "NULL") { + EXPECT_EQ(0, on); + } else { + EXPECT_NE(0, on); + } + EXPECT_EQ(GetParam().name, std::string(cipher)); + // All the ciphers we support have secret key size == key size. + EXPECT_EQ(GetParam().keySize, keySize); + EXPECT_EQ(GetParam().keySize, secretKeySize); + EXPECT_LT(0U, strlen(issuer)); + EXPECT_LT(0U, strlen(subject)); + + PORT_Free(cipher); + PORT_Free(issuer); + PORT_Free(subject); +} + +static const SecStatusParams kSecStatusTestValuesArr[] = { + {SSL_LIBRARY_VERSION_TLS_1_0, TLS_ECDHE_RSA_WITH_NULL_SHA, "NULL", 0}, + {SSL_LIBRARY_VERSION_TLS_1_0, TLS_RSA_WITH_RC4_128_SHA, "RC4", 128}, + {SSL_LIBRARY_VERSION_TLS_1_0, TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + "3DES-EDE-CBC", 168}, + {SSL_LIBRARY_VERSION_TLS_1_0, TLS_RSA_WITH_AES_128_CBC_SHA, "AES-128", 128}, + {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_256_CBC_SHA256, "AES-256", + 256}, + {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_128_GCM_SHA256, + "AES-128-GCM", 128}, + {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_256_GCM_SHA384, + "AES-256-GCM", 256}, + {SSL_LIBRARY_VERSION_TLS_1_2, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + "ChaCha20-Poly1305", 256}}; +INSTANTIATE_TEST_CASE_P(TestSecurityStatus, SecurityStatusTest, + ::testing::ValuesIn(kSecStatusTestValuesArr)); + +} // namespace nspr_test diff --git a/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc new file mode 100644 index 000000000..9dadcbdf6 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_damage_unittest.cc @@ -0,0 +1,61 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_F(TlsConnectTest, DamageSecretHandleClientFinished) { + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->StartConnect(); + client_->StartConnect(); + client_->Handshake(); + server_->Handshake(); + std::cerr << "Damaging HS secret\n"; + SSLInt_DamageClientHsTrafficSecret(server_->ssl_fd()); + client_->Handshake(); + server_->Handshake(); + // The client thinks it has connected. + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); +} + +TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) { + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetPacketFilter(new AfterRecordN( + server_, client_, + 0, // ServerHello. + [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); })); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + +} // namespace nspr_test diff --git a/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc new file mode 100644 index 000000000..82d55586b --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc @@ -0,0 +1,609 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include <set> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectGeneric, ConnectDhe) { + EnableOnlyDheCiphers(); + Connect(); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); +} + +TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) { + EnsureTlsSetup(); + client_->ConfigNamedGroups(kAllDHEGroups); + + auto groups_capture = new TlsExtensionCapture(ssl_supported_groups_xtn); + auto shares_capture = new TlsExtensionCapture(ssl_tls13_key_share_xtn); + std::vector<PacketFilter*> captures; + captures.push_back(groups_capture); + captures.push_back(shares_capture); + client_->SetPacketFilter(new ChainedPacketFilter(captures)); + + Connect(); + + CheckKeys(); + + bool ec, dh; + auto track_group_type = [&ec, &dh](SSLNamedGroup group) { + if ((group & 0xff00U) == 0x100U) { + dh = true; + } else { + ec = true; + } + }; + CheckGroups(groups_capture->extension(), track_group_type); + CheckShares(shares_capture->extension(), track_group_type); + EXPECT_TRUE(ec) << "Should include an EC group and share"; + EXPECT_TRUE(dh) << "Should include an FFDHE group and share"; +} + +TEST_P(TlsConnectGeneric, ConnectFfdheClient) { + EnableOnlyDheCiphers(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + auto groups_capture = new TlsExtensionCapture(ssl_supported_groups_xtn); + auto shares_capture = new TlsExtensionCapture(ssl_tls13_key_share_xtn); + std::vector<PacketFilter*> captures; + captures.push_back(groups_capture); + captures.push_back(shares_capture); + client_->SetPacketFilter(new ChainedPacketFilter(captures)); + + Connect(); + + CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign); + auto is_ffdhe = [](SSLNamedGroup group) { + // The group has to be in this range. + EXPECT_LE(ssl_grp_ffdhe_2048, group); + EXPECT_GE(ssl_grp_ffdhe_8192, group); + }; + CheckGroups(groups_capture->extension(), is_ffdhe); + if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) { + CheckShares(shares_capture->extension(), is_ffdhe); + } else { + EXPECT_EQ(0U, shares_capture->extension().len()); + } +} + +// Requiring the FFDHE extension on the server alone means that clients won't be +// able to connect using a DHE suite. They should still connect in TLS 1.3, +// because the client automatically sends the supported groups extension. +TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) { + EnableOnlyDheCiphers(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + Connect(); + CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign); + } else { + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + } +} + +class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter { + public: + TlsDheServerKeyExchangeDamager() {} + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) { + if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { + return KEEP; + } + + // Damage the first octet of dh_p. Anything other than the known prime will + // be rejected as "weak" when we have SSL_REQUIRE_DH_NAMED_GROUPS enabled. + *output = input; + output->data()[3] ^= 73; + return CHANGE; + } +}; + +// Changing the prime in the server's key share results in an error. This will +// invalidate the signature over the ServerKeyShare. That's ok, NSS won't check +// the signature until everything else has been checked. +TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) { + EnableOnlyDheCiphers(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + server_->SetPacketFilter(new TlsDheServerKeyExchangeDamager()); + + ConnectExpectFail(); + + client_->CheckErrorCode(SSL_ERROR_WEAK_SERVER_EPHEMERAL_DH_KEY); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +class TlsDheSkeChangeY : public TlsHandshakeFilter { + public: + enum ChangeYTo { + kYZero, + kYOne, + kYPMinusOne, + kYGreaterThanP, + kYTooLarge, + kYZeroPad + }; + + TlsDheSkeChangeY(ChangeYTo change) : change_Y_(change) {} + + protected: + void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset, + const DataBuffer& prime) { + static const uint8_t kExtraZero = 0; + static const uint8_t kTooLargeExtra = 1; + + uint32_t dh_Ys_len; + EXPECT_TRUE(input.Read(offset, 2, &dh_Ys_len)); + EXPECT_LT(offset + dh_Ys_len, input.len()); + offset += 2; + + // This isn't generally true, but our code pads. + EXPECT_EQ(prime.len(), dh_Ys_len) + << "Length of dh_Ys must equal length of dh_p"; + + *output = input; + switch (change_Y_) { + case kYZero: + memset(output->data() + offset, 0, prime.len()); + break; + + case kYOne: + memset(output->data() + offset, 0, prime.len() - 1); + output->Write(offset + prime.len() - 1, 1U, 1); + break; + + case kYPMinusOne: + output->Write(offset, prime); + EXPECT_TRUE(output->data()[offset + prime.len() - 1] & 0x01) + << "P must at least be odd"; + --output->data()[offset + prime.len() - 1]; + break; + + case kYGreaterThanP: + // Set the first 32 octets of Y to 0xff, except the first which we set + // to p[0]. This will make Y > p. That is, unless p is Mersenne, or + // improbably large (but still the same bit length). We currently only + // use a fixed prime that isn't a problem for this code. + EXPECT_LT(0, prime.data()[0]) << "dh_p should not be zero-padded"; + offset = output->Write(offset, prime.data()[0], 1); + memset(output->data() + offset, 0xff, 31); + break; + + case kYTooLarge: + // Increase the dh_Ys length. + output->Write(offset - 2, prime.len() + sizeof(kTooLargeExtra), 2); + // Then insert the octet. + output->Splice(&kTooLargeExtra, sizeof(kTooLargeExtra), offset); + break; + + case kYZeroPad: + output->Write(offset - 2, prime.len() + sizeof(kExtraZero), 2); + output->Splice(&kExtraZero, sizeof(kExtraZero), offset); + break; + } + } + + private: + ChangeYTo change_Y_; +}; + +class TlsDheSkeChangeYServer : public TlsDheSkeChangeY { + public: + TlsDheSkeChangeYServer(ChangeYTo change, bool modify) + : TlsDheSkeChangeY(change), modify_(modify), p_() {} + + const DataBuffer& prime() const { return p_; } + + protected: + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) override { + if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { + return KEEP; + } + + size_t offset = 2; + // Read dh_p + uint32_t dh_len = 0; + EXPECT_TRUE(input.Read(0, 2, &dh_len)); + EXPECT_GT(input.len(), offset + dh_len); + p_.Assign(input.data() + offset, dh_len); + offset += dh_len; + + // Skip dh_g to find dh_Ys + EXPECT_TRUE(input.Read(offset, 2, &dh_len)); + offset += 2 + dh_len; + + if (modify_) { + ChangeY(input, output, offset, p_); + return CHANGE; + } + return KEEP; + } + + private: + bool modify_; + DataBuffer p_; +}; + +class TlsDheSkeChangeYClient : public TlsDheSkeChangeY { + public: + TlsDheSkeChangeYClient(ChangeYTo change, + const TlsDheSkeChangeYServer* server_filter) + : TlsDheSkeChangeY(change), server_filter_(server_filter) {} + + protected: + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) override { + if (header.handshake_type() != kTlsHandshakeClientKeyExchange) { + return KEEP; + } + + ChangeY(input, output, 0, server_filter_->prime()); + return CHANGE; + } + + private: + const TlsDheSkeChangeYServer* server_filter_; +}; + +/* This matrix includes: mode (stream/datagram), TLS version, what change to + * make to dh_Ys, whether the client will be configured to require DH named + * groups. Test all combinations. */ +typedef std::tuple<std::string, uint16_t, TlsDheSkeChangeY::ChangeYTo, bool> + DamageDHYProfile; +class TlsDamageDHYTest + : public TlsConnectTestBase, + public ::testing::WithParamInterface<DamageDHYProfile> { + public: + TlsDamageDHYTest() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} +}; + +TEST_P(TlsDamageDHYTest, DamageServerY) { + EnableOnlyDheCiphers(); + if (std::get<3>(GetParam())) { + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + } + TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); + server_->SetPacketFilter(new TlsDheSkeChangeYServer(change, true)); + + ConnectExpectFail(); + if (change == TlsDheSkeChangeY::kYZeroPad) { + // Zero padding Y only manifests in a signature failure. + // In TLS 1.0 and 1.1, the client reports a device error. + if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) { + client_->CheckErrorCode(SEC_ERROR_PKCS11_DEVICE_ERROR); + } else { + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); + } + server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + } else { + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + } +} + +TEST_P(TlsDamageDHYTest, DamageClientY) { + EnableOnlyDheCiphers(); + if (std::get<3>(GetParam())) { + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + } + // The filter on the server is required to capture the prime. + TlsDheSkeChangeYServer* server_filter = + new TlsDheSkeChangeYServer(TlsDheSkeChangeY::kYZero, false); + server_->SetPacketFilter(server_filter); + + // The client filter does the damage. + TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam()); + client_->SetPacketFilter(new TlsDheSkeChangeYClient(change, server_filter)); + + ConnectExpectFail(); + if (change == TlsDheSkeChangeY::kYZeroPad) { + // Zero padding Y only manifests in a finished error. + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + } else { + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE); + } +} + +static const TlsDheSkeChangeY::ChangeYTo kAllYArr[] = { + TlsDheSkeChangeY::kYZero, TlsDheSkeChangeY::kYOne, + TlsDheSkeChangeY::kYPMinusOne, TlsDheSkeChangeY::kYGreaterThanP, + TlsDheSkeChangeY::kYTooLarge, TlsDheSkeChangeY::kYZeroPad}; +static ::testing::internal::ParamGenerator<TlsDheSkeChangeY::ChangeYTo> kAllY = + ::testing::ValuesIn(kAllYArr); +static const bool kTrueFalseArr[] = {true, false}; +static ::testing::internal::ParamGenerator<bool> kTrueFalse = + ::testing::ValuesIn(kTrueFalseArr); + +INSTANTIATE_TEST_CASE_P(DamageYStream, TlsDamageDHYTest, + ::testing::Combine(TlsConnectTestBase::kTlsModesStream, + TlsConnectTestBase::kTlsV10ToV12, + kAllY, kTrueFalse)); +INSTANTIATE_TEST_CASE_P( + DamageYDatagram, TlsDamageDHYTest, + ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + TlsConnectTestBase::kTlsV11V12, kAllY, kTrueFalse)); + +class TlsDheSkeMakePEven : public TlsHandshakeFilter { + public: + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) { + if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { + return KEEP; + } + + // Find the end of dh_p + uint32_t dh_len = 0; + EXPECT_TRUE(input.Read(0, 2, &dh_len)); + EXPECT_GT(input.len(), 2 + dh_len) << "enough space for dh_p"; + size_t offset = 2 + dh_len - 1; + EXPECT_TRUE((input.data()[offset] & 0x01) == 0x01) << "p should be odd"; + + *output = input; + output->data()[offset] &= 0xfe; + + return CHANGE; + } +}; + +// Even without requiring named groups, an even value for p is bad news. +TEST_P(TlsConnectGenericPre13, MakeDhePEven) { + EnableOnlyDheCiphers(); + server_->SetPacketFilter(new TlsDheSkeMakePEven()); + + ConnectExpectFail(); + + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + +class TlsDheSkeZeroPadP : public TlsHandshakeFilter { + public: + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) { + if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { + return KEEP; + } + + *output = input; + uint32_t dh_len = 0; + EXPECT_TRUE(input.Read(0, 2, &dh_len)); + static const uint8_t kZeroPad = 0; + output->Write(0, dh_len + sizeof(kZeroPad), 2); // increment the length + output->Splice(&kZeroPad, sizeof(kZeroPad), 2); // insert a zero + + return CHANGE; + } +}; + +// Zero padding only causes signature failure. +TEST_P(TlsConnectGenericPre13, PadDheP) { + EnableOnlyDheCiphers(); + server_->SetPacketFilter(new TlsDheSkeZeroPadP()); + + ConnectExpectFail(); + + // In TLS 1.0 and 1.1, the client reports a device error. + if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) { + client_->CheckErrorCode(SEC_ERROR_PKCS11_DEVICE_ERROR); + } else { + client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE); + } + server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); +} + +// The server should not pick the weak DH group if the client includes FFDHE +// named groups in the supported_groups extension. The server then picks a +// commonly-supported named DH group and this connects. +// +// Note: This test case can take ages to generate the weak DH key. +TEST_P(TlsConnectGenericPre13, WeakDHGroup) { + EnableOnlyDheCiphers(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + EXPECT_EQ(SECSuccess, + SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE)); + + Connect(); +} + +TEST_P(TlsConnectGeneric, Ffdhe3072) { + EnableOnlyDheCiphers(); + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ffdhe_3072}; + client_->ConfigNamedGroups(groups); + + Connect(); +} + +// Even though the client doesn't have DHE groups enabled the server assumes it +// does. Because the client doesn't require named groups it accepts FF3072 as +// custom group. +TEST_P(TlsConnectGenericPre13, NamedGroupMismatchPre13) { + EnableOnlyDheCiphers(); + static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072}; + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp256r1}; + server_->ConfigNamedGroups(server_groups); + client_->ConfigNamedGroups(client_groups); + + Connect(); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_custom, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); +} + +// Same test but for TLS 1.3. This has to fail. +TEST_P(TlsConnectTls13, NamedGroupMismatch13) { + EnableOnlyDheCiphers(); + static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072}; + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp256r1}; + server_->ConfigNamedGroups(server_groups); + client_->ConfigNamedGroups(client_groups); + + ConnectExpectFail(); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// Even though the client doesn't have DHE groups enabled the server assumes it +// does. The client requires named groups and thus does not accept FF3072 as +// custom group in contrast to the previous test. +TEST_P(TlsConnectGenericPre13, RequireNamedGroupsMismatchPre13) { + EnableOnlyDheCiphers(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072}; + static const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ffdhe_2048}; + server_->ConfigNamedGroups(server_groups); + client_->ConfigNamedGroups(client_groups); + + ConnectExpectFail(); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(TlsConnectGenericPre13, PreferredFfdhe) { + EnableOnlyDheCiphers(); + static const SSLDHEGroupType groups[] = {ssl_ff_dhe_3072_group, + ssl_ff_dhe_2048_group}; + EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), groups, + PR_ARRAY_SIZE(groups))); + + Connect(); + client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); + server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072); + client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); + server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256); +} + +TEST_P(TlsConnectGenericPre13, MismatchDHE) { + EnableOnlyDheCiphers(); + EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(), + SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE)); + static const SSLDHEGroupType serverGroups[] = {ssl_ff_dhe_3072_group}; + EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), serverGroups, + PR_ARRAY_SIZE(serverGroups))); + static const SSLDHEGroupType clientGroups[] = {ssl_ff_dhe_2048_group}; + EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(client_->ssl_fd(), clientGroups, + PR_ARRAY_SIZE(clientGroups))); + + ConnectExpectFail(); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(TlsConnectTls13, ResumeFfdhe) { + EnableOnlyDheCiphers(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + EnableOnlyDheCiphers(); + TlsExtensionCapture* clientCapture = + new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn); + client_->SetPacketFilter(clientCapture); + TlsExtensionCapture* serverCapture = + new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn); + server_->SetPacketFilter(serverCapture); + ExpectResumption(RESUME_TICKET); + Connect(); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none); + ASSERT_LT(0UL, clientCapture->extension().len()); + ASSERT_LT(0UL, serverCapture->extension().len()); +} + +class TlsDheSkeChangeSignature : public TlsHandshakeFilter { + public: + TlsDheSkeChangeSignature(uint16_t version, const uint8_t* data, size_t len) + : version_(version), data_(data), len_(len) {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { + return KEEP; + } + + TlsParser parser(input); + EXPECT_TRUE(parser.SkipVariable(2)); // dh_p + EXPECT_TRUE(parser.SkipVariable(2)); // dh_g + EXPECT_TRUE(parser.SkipVariable(2)); // dh_Ys + + // Copy DH params to output. + size_t offset = output->Write(0, input.data(), parser.consumed()); + + if (version_ == SSL_LIBRARY_VERSION_TLS_1_2) { + // Write signature algorithm. + offset = output->Write(offset, ssl_sig_dsa_sha256, 2); + } + + // Write new signature. + offset = output->Write(offset, len_, 2); + offset = output->Write(offset, data_, len_); + + return CHANGE; + } + + private: + uint16_t version_; + const uint8_t* data_; + size_t len_; +}; + +TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) { + const uint8_t kBogusDheSignature[] = { + 0x30, 0x69, 0x3c, 0x02, 0x1c, 0x7d, 0x0b, 0x2f, 0x64, 0x00, 0x27, + 0xae, 0xcf, 0x1e, 0x28, 0x08, 0x6a, 0x7f, 0xb1, 0xbd, 0x78, 0xb5, + 0x3b, 0x8c, 0x8f, 0x59, 0xed, 0x8f, 0xee, 0x78, 0xeb, 0x2c, 0xe9, + 0x02, 0x1c, 0x6d, 0x7f, 0x3c, 0x0f, 0xf4, 0x44, 0x35, 0x0b, 0xb2, + 0x6d, 0xdc, 0xb8, 0x21, 0x87, 0xdd, 0x0d, 0xb9, 0x46, 0x09, 0x3e, + 0xef, 0x81, 0x5b, 0x37, 0x09, 0x39, 0xeb}; + + Reset(TlsAgent::kServerDsa); + + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ffdhe_2048}; + client_->ConfigNamedGroups(client_groups); + + server_->SetPacketFilter(new TlsDheSkeChangeSignature( + version_, kBogusDheSignature, sizeof(kBogusDheSignature))); + + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc new file mode 100644 index 000000000..89ca28e97 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_drop_unittest.cc @@ -0,0 +1,133 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectDatagram, DropClientFirstFlightOnce) { + client_->SetPacketFilter(new SelectiveDropFilter(0x1)); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) { + server_->SetPacketFilter(new SelectiveDropFilter(0x1)); + Connect(); + SendReceive(); +} + +// This drops the first transmission from both the client and server of all +// flights that they send. Note: In DTLS 1.3, the shorter handshake means that +// this will also drop some application data, so we can't call SendReceive(). +TEST_P(TlsConnectDatagram, DropAllFirstTransmissions) { + client_->SetPacketFilter(new SelectiveDropFilter(0x15)); + server_->SetPacketFilter(new SelectiveDropFilter(0x5)); + Connect(); +} + +// This drops the server's first flight three times. +TEST_P(TlsConnectDatagram, DropServerFirstFlightThrice) { + server_->SetPacketFilter(new SelectiveDropFilter(0x7)); + Connect(); +} + +// This drops the client's second flight once +TEST_P(TlsConnectDatagram, DropClientSecondFlightOnce) { + client_->SetPacketFilter(new SelectiveDropFilter(0x2)); + Connect(); +} + +// This drops the client's second flight three times. +TEST_P(TlsConnectDatagram, DropClientSecondFlightThrice) { + client_->SetPacketFilter(new SelectiveDropFilter(0xe)); + Connect(); +} + +// This drops the server's second flight three times. +TEST_P(TlsConnectDatagram, DropServerSecondFlightThrice) { + server_->SetPacketFilter(new SelectiveDropFilter(0xe)); + Connect(); +} + +static void GetCipherAndLimit(uint16_t version, uint16_t* cipher, + uint64_t* limit = nullptr) { + uint64_t l; + if (!limit) limit = &l; + + if (version < SSL_LIBRARY_VERSION_TLS_1_2) { + *cipher = TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA; + *limit = 0x5aULL << 28; + } else if (version == SSL_LIBRARY_VERSION_TLS_1_2) { + *cipher = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256; + *limit = (1ULL << 48) - 1; + } else { + *cipher = TLS_CHACHA20_POLY1305_SHA256; + *limit = (1ULL << 48) - 1; + } +} + +// This simulates a huge number of drops on one side. +TEST_P(TlsConnectDatagram, MissLotsOfPackets) { + uint16_t cipher; + uint64_t limit; + + GetCipherAndLimit(version_, &cipher, &limit); + + EnsureTlsSetup(); + server_->EnableSingleCipher(cipher); + Connect(); + + // Note that the limit for ChaCha is 2^48-1. + EXPECT_EQ(SECSuccess, + SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), limit - 10)); + SendReceive(); +} + +class TlsConnectDatagram12Plus : public TlsConnectDatagram { + public: + TlsConnectDatagram12Plus() : TlsConnectDatagram() {} +}; + +// This simulates missing a window's worth of packets. +TEST_P(TlsConnectDatagram12Plus, MissAWindow) { + EnsureTlsSetup(); + uint16_t cipher; + GetCipherAndLimit(version_, &cipher); + server_->EnableSingleCipher(cipher); + Connect(); + + EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 0)); + SendReceive(); +} + +TEST_P(TlsConnectDatagram12Plus, MissAWindowAndOne) { + EnsureTlsSetup(); + uint16_t cipher; + GetCipherAndLimit(version_, &cipher); + server_->EnableSingleCipher(cipher); + Connect(); + + EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 1)); + SendReceive(); +} + +INSTANTIATE_TEST_CASE_P(Datagram12Plus, TlsConnectDatagram12Plus, + TlsConnectTestBase::kTlsV12Plus); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc new file mode 100644 index 000000000..79b7736e4 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc @@ -0,0 +1,584 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectGenericPre13, ConnectEcdh) { + SetExpectedVersion(std::get<1>(GetParam())); + Reset(TlsAgent::kServerEcdhEcdsa); + DisableAllCiphers(); + EnableSomeEcdhCiphers(); + + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_ecdh_ecdsa, + ssl_sig_none); +} + +TEST_P(TlsConnectGenericPre13, ConnectEcdhWithoutDisablingSuites) { + SetExpectedVersion(std::get<1>(GetParam())); + Reset(TlsAgent::kServerEcdhEcdsa); + EnableSomeEcdhCiphers(); + + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_ecdh_ecdsa, + ssl_sig_none); +} + +TEST_P(TlsConnectGeneric, ConnectEcdhe) { + Connect(); + CheckKeys(); +} + +// If we pick a 256-bit cipher suite and use a P-384 certificate, the server +// should choose P-384 for key exchange too. Only valid for TLS == 1.2 because +// we don't have 256-bit ciphers before then and 1.3 doesn't try to couple +// DHE size to symmetric size. +TEST_P(TlsConnectTls12, ConnectEcdheP384) { + Reset(TlsAgent::kServerEcdsa384); + ConnectWithCipherSuite(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_ecdsa, + ssl_sig_ecdsa_secp256r1_sha256); +} + +TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) { + EnsureTlsSetup(); + const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ffdhe_2048}; + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); +} + +// This causes a HelloRetryRequest in TLS 1.3. Earlier versions don't care. +TEST_P(TlsConnectGeneric, ConnectEcdheP384Server) { + EnsureTlsSetup(); + auto hrr_capture = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest); + server_->SetPacketFilter(hrr_capture); + const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1}; + server_->ConfigNamedGroups(groups); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + EXPECT_EQ(version_ == SSL_LIBRARY_VERSION_TLS_1_3, + hrr_capture->buffer().len() != 0); +} + +// This enables only P-256 on the client and disables it on the server. +// This test will fail when we add other groups that identify as ECDHE. +TEST_P(TlsConnectGeneric, ConnectEcdheGroupMismatch) { + EnsureTlsSetup(); + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ffdhe_2048}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_2048}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign); +} + +TEST_P(TlsKeyExchangeTest, P384Priority) { + // P256, P384 and P521 are enabled. Both prefer P384. + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1}; + EnsureKeyShareSetup(); + ConfigNamedGroups(groups); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + + std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1}; + CheckKEXDetails(groups, shares); +} + +TEST_P(TlsKeyExchangeTest, DuplicateGroupConfig) { + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp384r1, ssl_grp_ec_secp384r1, + ssl_grp_ec_secp256r1, ssl_grp_ec_secp256r1}; + EnsureKeyShareSetup(); + ConfigNamedGroups(groups); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + + std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1}; + std::vector<SSLNamedGroup> expectedGroups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp256r1}; + CheckKEXDetails(expectedGroups, shares); +} + +TEST_P(TlsKeyExchangeTest, P384PriorityDHEnabled) { + // P256, P384, P521, and FFDHE2048 are enabled. Both prefer P384. + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_secp384r1, ssl_grp_ffdhe_2048, ssl_grp_ec_secp256r1, + ssl_grp_ec_secp521r1}; + EnsureKeyShareSetup(); + ConfigNamedGroups(groups); + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1}; + CheckKEXDetails(groups, shares); + } else { + std::vector<SSLNamedGroup> oldtlsgroups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1}; + CheckKEXDetails(oldtlsgroups, std::vector<SSLNamedGroup>()); + } +} + +TEST_P(TlsConnectGenericPre13, P384PriorityOnServer) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + // The server prefers P384. It has to win. + const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); +} + +TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) { + EnsureModelSockets(); + + /* Both prefer P384, set on the model socket. */ + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1, + ssl_grp_ffdhe_2048}; + client_model_->ConfigNamedGroups(groups); + server_model_->ConfigNamedGroups(groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); +} + +class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter { + public: + TlsKeyExchangeGroupCapture() : group_(ssl_grp_none) {} + + SSLNamedGroup group() const { return group_; } + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, + const DataBuffer &input, + DataBuffer *output) { + if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { + return KEEP; + } + + uint32_t value = 0; + EXPECT_TRUE(input.Read(0, 1, &value)); + EXPECT_EQ(3U, value) << "curve type has to be 3"; + + EXPECT_TRUE(input.Read(1, 2, &value)); + group_ = static_cast<SSLNamedGroup>(value); + + return KEEP; + } + + private: + SSLNamedGroup group_; +}; + +// If we strip the client's supported groups extension, the server should assume +// P-256 is supported by the client (<= 1.2 only). +TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) { + EnsureTlsSetup(); + client_->SetPacketFilter(new TlsExtensionDropper(ssl_supported_groups_xtn)); + auto group_capture = new TlsKeyExchangeGroupCapture(); + server_->SetPacketFilter(group_capture); + + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + + EXPECT_EQ(ssl_grp_ec_secp256r1, group_capture->group()); +} + +// Supported groups is mandatory in TLS 1.3. +TEST_P(TlsConnectTls13, DropSupportedGroupExtension) { + EnsureTlsSetup(); + client_->SetPacketFilter(new TlsExtensionDropper(ssl_supported_groups_xtn)); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION); +} + +// If we only have a lame group, we fall back to static RSA. +TEST_P(TlsConnectGenericPre13, UseLameGroup) { + const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp192r1}; + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +// In TLS 1.3, we can't generate the ClientHello. +TEST_P(TlsConnectTls13, UseLameGroup) { + const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_sect283k1}; + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); + client_->StartConnect(); + client_->Handshake(); + client_->CheckErrorCode(SSL_ERROR_NO_CIPHERS_SUPPORTED); +} + +TEST_P(TlsConnectStreamPre13, ConfiguredGroupsRenegotiate) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_secp256r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + CheckConnected(); + + // The renegotiation has to use the same preferences as the original session. + server_->PrepareForRenegotiate(); + client_->StartRenegotiate(); + Handshake(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); +} + +TEST_P(TlsKeyExchangeTest, Curve25519) { + Reset(TlsAgent::kServerEcdsa256); + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1}; + EnsureKeyShareSetup(); + ConfigNamedGroups(groups); + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_ecdsa, + ssl_sig_ecdsa_secp256r1_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(groups, shares); +} + +TEST_P(TlsConnectGenericPre13, GroupPreferenceServerPriority) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + // The client prefers P256 while the server prefers 25519. + // The server's preference has to win. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_curve25519}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); +} + +#ifndef NSS_DISABLE_TLS_1_3 +TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityClient13) { + EnsureKeyShareSetup(); + + // The client sends a P256 key share while the server prefers 25519. + // We have to accept P256 without retry. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_curve25519}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp256r1}; + CheckKEXDetails(client_groups, shares); +} + +TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityServer13) { + EnsureKeyShareSetup(); + + // The client sends a 25519 key share while the server prefers P256. + // We have to accept 25519 without retry. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares); +} + +TEST_P(TlsKeyExchangeTest13, EqualPriorityTestRetryECServer13) { + EnsureKeyShareSetup(); + + // The client sends a 25519 key share while the server prefers P256. + // The server prefers P-384 over x25519, so it must not consider P-256 and + // x25519 to be equivalent. It will therefore request a P-256 share + // with a HelloRetryRequest. + const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1}; + const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); +} + +TEST_P(TlsKeyExchangeTest13, NotEqualPriorityWithIntermediateGroup13) { + EnsureKeyShareSetup(); + + // The client sends a 25519 key share while the server prefers P256. + // The server prefers ffdhe_2048 over x25519, so it must not consider the + // P-256 and x25519 to be equivalent. It will therefore request a P-256 share + // with a HelloRetryRequest. + const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048}; + const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); +} + +TEST_P(TlsKeyExchangeTest13, + NotEqualPriorityWithUnsupportedFFIntermediateGroup13) { + EnsureKeyShareSetup(); + + // As in the previous test, the server prefers ffdhe_2048. Thus, even though + // the client doesn't support this group, the server must not regard x25519 as + // equivalent to P-256. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); +} + +TEST_P(TlsKeyExchangeTest13, + NotEqualPriorityWithUnsupportedECIntermediateGroup13) { + EnsureKeyShareSetup(); + + // As in the previous test, the server prefers P-384. Thus, even though + // the client doesn't support this group, the server must not regard x25519 as + // equivalent to P-256. The server sends a HelloRetryRequest. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1); +} + +TEST_P(TlsKeyExchangeTest13, EqualPriority13) { + EnsureKeyShareSetup(); + + // The client sends a 25519 key share while the server prefers P256. + // We have to accept 25519 without retry because it's considered equivalent to + // P256 by the server. + const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_curve25519, ssl_grp_ffdhe_2048, ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + Connect(); + + CheckKeys(); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519}; + CheckKEXDetails(client_groups, shares); +} +#endif + +TEST_P(TlsConnectGeneric, P256ClientAndCurve25519Server) { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_ecdh); + + // The client sends a P256 key share while the server prefers 25519. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519}; + + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(TlsKeyExchangeTest13, MultipleClientShares) { + EnsureKeyShareSetup(); + + // The client sends 25519 and P256 key shares. The server prefers P256, + // which must be chosen here. + const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1, + ssl_grp_ec_curve25519}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + + // Generate a key share on the client for both curves. + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + + Connect(); + + // The server would accept 25519 but its preferred group (P256) has to win. + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign, + ssl_sig_rsa_pss_sha256); + const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519, + ssl_grp_ec_secp256r1}; + CheckKEXDetails(client_groups, shares); +} + +// Replace the point in the client key exchange message with an empty one +class ECCClientKEXFilter : public TlsHandshakeFilter { + public: + ECCClientKEXFilter() {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, + const DataBuffer &input, + DataBuffer *output) { + if (header.handshake_type() != kTlsHandshakeClientKeyExchange) { + return KEEP; + } + + // Replace the client key exchange message with an empty point + output->Allocate(1); + output->Write(0, 0U, 1); // set point length 0 + return CHANGE; + } +}; + +// Replace the point in the server key exchange message with an empty one +class ECCServerKEXFilter : public TlsHandshakeFilter { + public: + ECCServerKEXFilter() {} + + protected: + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header, + const DataBuffer &input, + DataBuffer *output) { + if (header.handshake_type() != kTlsHandshakeServerKeyExchange) { + return KEEP; + } + + // Replace the server key exchange message with an empty point + output->Allocate(4); + output->Write(0, 3U, 1); // named curve + uint32_t curve; + EXPECT_TRUE(input.Read(1, 2, &curve)); // get curve id + output->Write(1, curve, 2); // write curve id + output->Write(3, 0U, 1); // point length 0 + return CHANGE; + } +}; + +TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyServerPoint) { + // add packet filter + server_->SetPacketFilter(new ECCServerKEXFilter()); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH); +} + +TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) { + // add packet filter + client_->SetPacketFilter(new ECCClientKEXFilter()); + ConnectExpectFail(); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_KEY_EXCH); +} + +INSTANTIATE_TEST_CASE_P(KeyExchangeTest, TlsKeyExchangeTest, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV11Plus)); + +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_CASE_P(KeyExchangeTest, TlsKeyExchangeTest13, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV13)); +#endif + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_ems_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_ems_unittest.cc new file mode 100644 index 000000000..b9c725b36 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_ems_unittest.cc @@ -0,0 +1,100 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecret) { + EnableExtendedMasterSecret(); + Connect(); + Reset(); + ExpectResumption(RESUME_SESSIONID); + EnableExtendedMasterSecret(); + Connect(); +} + +TEST_P(TlsConnectTls12, ConnectExtendedMasterSecretSha384) { + EnableExtendedMasterSecret(); + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384); + ConnectWithCipherSuite(TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretStaticRSA) { + EnableOnlyStaticRsaCiphers(); + EnableExtendedMasterSecret(); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretECDHE) { + EnableExtendedMasterSecret(); + Connect(); + + Reset(); + EnableExtendedMasterSecret(); + ExpectResumption(RESUME_SESSIONID); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretTicket) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + EnableExtendedMasterSecret(); + Connect(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + + EnableExtendedMasterSecret(); + ExpectResumption(RESUME_TICKET); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretClientOnly) { + client_->EnableExtendedMasterSecret(); + ExpectExtendedMasterSecret(false); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretServerOnly) { + server_->EnableExtendedMasterSecret(); + ExpectExtendedMasterSecret(false); + Connect(); +} + +TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretResumeWithout) { + EnableExtendedMasterSecret(); + Connect(); + + Reset(); + server_->EnableExtendedMasterSecret(); + auto alert_recorder = new TlsAlertRecorder(); + server_->SetPacketFilter(alert_recorder); + ConnectExpectFail(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertHandshakeFailure, alert_recorder->description()); +} + +TEST_P(TlsConnectGenericPre13, ConnectNormalResumeWithExtendedMasterSecret) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + ExpectExtendedMasterSecret(false); + Connect(); + + Reset(); + EnableExtendedMasterSecret(); + ExpectResumption(RESUME_NONE); + Connect(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc new file mode 100644 index 000000000..0a0d9f25f --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc @@ -0,0 +1,122 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" + +#include "gtest_utils.h" +#include "tls_connect.h" + +namespace nss_test { + +static const char* kExporterLabel = "EXPORTER-duck"; +static const uint8_t kExporterContext[] = {0x12, 0x34, 0x56}; + +static void ExportAndCompare(TlsAgent* client, TlsAgent* server, bool context) { + static const size_t exporter_len = 10; + uint8_t client_value[exporter_len] = {0}; + EXPECT_EQ(SECSuccess, + SSL_ExportKeyingMaterial( + client->ssl_fd(), kExporterLabel, strlen(kExporterLabel), + context ? PR_TRUE : PR_FALSE, kExporterContext, + sizeof(kExporterContext), client_value, sizeof(client_value))); + uint8_t server_value[exporter_len] = {0xff}; + EXPECT_EQ(SECSuccess, + SSL_ExportKeyingMaterial( + server->ssl_fd(), kExporterLabel, strlen(kExporterLabel), + context ? PR_TRUE : PR_FALSE, kExporterContext, + sizeof(kExporterContext), server_value, sizeof(server_value))); + EXPECT_EQ(0, memcmp(client_value, server_value, sizeof(client_value))); +} + +TEST_P(TlsConnectGeneric, ExporterBasic) { + EnsureTlsSetup(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + } else { + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + } + Connect(); + CheckKeys(); + ExportAndCompare(client_, server_, false); +} + +TEST_P(TlsConnectGeneric, ExporterContext) { + EnsureTlsSetup(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + } else { + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + } + Connect(); + CheckKeys(); + ExportAndCompare(client_, server_, true); +} + +// Bug 1312976 - SHA-384 doesn't work in 1.2 right now. +TEST_P(TlsConnectTls13, ExporterSha384) { + EnsureTlsSetup(); + client_->EnableSingleCipher(TLS_AES_256_GCM_SHA384); + Connect(); + CheckKeys(); + ExportAndCompare(client_, server_, false); +} + +TEST_P(TlsConnectTls13, ExporterContextEmptyIsSameAsNone) { + EnsureTlsSetup(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + } else { + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + } + Connect(); + CheckKeys(); + ExportAndCompare(client_, server_, false); +} + +// This has a weird signature so that it can be passed to the SNI callback. +int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr, + PRUint32 srvNameArrSize) { + uint8_t val[10]; + EXPECT_EQ(SECFailure, SSL_ExportKeyingMaterial( + agent->ssl_fd(), kExporterLabel, + strlen(kExporterLabel), PR_TRUE, kExporterContext, + sizeof(kExporterContext), val, sizeof(val))) + << "regular exporter should fail"; + return 0; +} + +TEST_P(TlsConnectTls13, EarlyExporter) { + SetupForZeroRtt(); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + client_->Handshake(); // Send ClientHello. + uint8_t client_value[10] = {0}; + RegularExporterShouldFail(client_, nullptr, 0); + EXPECT_EQ(SECSuccess, + SSL_ExportEarlyKeyingMaterial( + client_->ssl_fd(), kExporterLabel, strlen(kExporterLabel), + kExporterContext, sizeof(kExporterContext), client_value, + sizeof(client_value))); + + server_->SetSniCallback(RegularExporterShouldFail); + server_->Handshake(); // Handle ClientHello. + uint8_t server_value[10] = {0}; + EXPECT_EQ(SECSuccess, + SSL_ExportEarlyKeyingMaterial( + server_->ssl_fd(), kExporterLabel, strlen(kExporterLabel), + kExporterContext, sizeof(kExporterContext), server_value, + sizeof(server_value))); + EXPECT_EQ(0, memcmp(client_value, server_value, sizeof(client_value))); + + Handshake(); + ExpectEarlyDataAccepted(true); + CheckConnected(); + SendReceive(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc new file mode 100644 index 000000000..9200e724b --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -0,0 +1,985 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "ssl.h" +#include "ssl3prot.h" +#include "sslerr.h" +#include "sslproto.h" + +#include <memory> + +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +class TlsExtensionTruncator : public TlsExtensionFilter { + public: + TlsExtensionTruncator(uint16_t extension, size_t length) + : extension_(extension), length_(length) {} + virtual PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) { + if (extension_type != extension_) { + return KEEP; + } + if (input.len() <= length_) { + return KEEP; + } + + output->Assign(input.data(), length_); + return CHANGE; + } + + private: + uint16_t extension_; + size_t length_; +}; + +class TlsExtensionDamager : public TlsExtensionFilter { + public: + TlsExtensionDamager(uint16_t extension, size_t index) + : extension_(extension), index_(index) {} + virtual PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) { + if (extension_type != extension_) { + return KEEP; + } + + *output = input; + output->data()[index_] += 73; // Increment selected for maximum damage + return CHANGE; + } + + private: + uint16_t extension_; + size_t index_; +}; + +class TlsExtensionInjector : public TlsHandshakeFilter { + public: + TlsExtensionInjector(uint16_t ext, DataBuffer& data) + : extension_(ext), data_(data) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + size_t offset; + if (header.handshake_type() == kTlsHandshakeClientHello) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) { + return KEEP; + } + offset = parser.consumed(); + } else if (header.handshake_type() == kTlsHandshakeServerHello) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) { + return KEEP; + } + offset = parser.consumed(); + } else { + return KEEP; + } + + *output = input; + + // Increase the size of the extensions. + uint16_t ext_len; + memcpy(&ext_len, output->data() + offset, sizeof(ext_len)); + ext_len = htons(ntohs(ext_len) + data_.len() + 4); + memcpy(output->data() + offset, &ext_len, sizeof(ext_len)); + + // Insert the extension type and length. + DataBuffer type_length; + type_length.Allocate(4); + type_length.Write(0, extension_, 2); + type_length.Write(2, data_.len(), 2); + output->Splice(type_length, offset + 2); + + // Insert the payload. + if (data_.len() > 0) { + output->Splice(data_, offset + 6); + } + + return CHANGE; + } + + private: + const uint16_t extension_; + const DataBuffer data_; +}; + +class TlsExtensionAppender : public TlsHandshakeFilter { + public: + TlsExtensionAppender(uint16_t ext, DataBuffer& data) + : extension_(ext), data_(data) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + size_t offset; + TlsParser parser(input); + if (header.handshake_type() == kTlsHandshakeClientHello) { + if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) { + return KEEP; + } + } else if (header.handshake_type() == kTlsHandshakeServerHello) { + if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) { + return KEEP; + } + } else { + return KEEP; + } + offset = parser.consumed(); + *output = input; + + uint32_t ext_len; + if (!parser.Read(&ext_len, 2)) { + ADD_FAILURE(); + return KEEP; + } + + ext_len += 4 + data_.len(); + output->Write(offset, ext_len, 2); + + offset = output->len(); + offset = output->Write(offset, extension_, 2); + WriteVariable(output, offset, data_, 2); + + return CHANGE; + } + + private: + const uint16_t extension_; + const DataBuffer data_; +}; + +class TlsExtensionTestBase : public TlsConnectTestBase { + protected: + TlsExtensionTestBase(Mode mode, uint16_t version) + : TlsConnectTestBase(mode, version) {} + TlsExtensionTestBase(const std::string& mode, uint16_t version) + : TlsConnectTestBase(mode, version) {} + + void ClientHelloErrorTest(PacketFilter* filter, + uint8_t alert = kTlsAlertDecodeError) { + auto alert_recorder = new TlsAlertRecorder(); + server_->SetPacketFilter(alert_recorder); + if (filter) { + client_->SetPacketFilter(filter); + } + ConnectExpectFail(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(alert, alert_recorder->description()); + } + + void ServerHelloErrorTest(PacketFilter* filter, + uint8_t alert = kTlsAlertDecodeError) { + auto alert_recorder = new TlsAlertRecorder(); + client_->SetPacketFilter(alert_recorder); + if (filter) { + server_->SetPacketFilter(filter); + } + ConnectExpectFail(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(alert, alert_recorder->description()); + } + + static void InitSimpleSni(DataBuffer* extension) { + const char* name = "host.name"; + const size_t namelen = PL_strlen(name); + extension->Allocate(namelen + 5); + extension->Write(0, namelen + 3, 2); + extension->Write(2, static_cast<uint32_t>(0), 1); // 0 == hostname + extension->Write(3, namelen, 2); + extension->Write(5, reinterpret_cast<const uint8_t*>(name), namelen); + } + + void HrrThenRemoveExtensionsTest(SSLExtensionType type, PRInt32 client_error, + PRInt32 server_error) { + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + EnsureTlsSetup(); + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); // Send ClientHello + server_->Handshake(); // Send HRR. + client_->SetPacketFilter(new TlsExtensionDropper(type)); + Handshake(); + client_->CheckErrorCode(client_error); + server_->CheckErrorCode(server_error); + } +}; + +class TlsExtensionTestDtls : public TlsExtensionTestBase, + public ::testing::WithParamInterface<uint16_t> { + public: + TlsExtensionTestDtls() : TlsExtensionTestBase(DGRAM, GetParam()) {} +}; + +class TlsExtensionTest12Plus + : public TlsExtensionTestBase, + public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { + public: + TlsExtensionTest12Plus() + : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { + } +}; + +class TlsExtensionTest12 + : public TlsExtensionTestBase, + public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { + public: + TlsExtensionTest12() + : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { + } +}; + +class TlsExtensionTest13 : public TlsExtensionTestBase, + public ::testing::WithParamInterface<std::string> { + public: + TlsExtensionTest13() + : TlsExtensionTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + + void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) { + DataBuffer versions_buf(buf, len); + client_->SetPacketFilter(new TlsExtensionReplacer( + ssl_tls13_supported_versions_xtn, versions_buf)); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); + } + + void ConnectWithReplacementVersionList(uint16_t version) { + DataBuffer versions_buf; + + size_t index = versions_buf.Write(0, 2, 1); + versions_buf.Write(index, version, 2); + client_->SetPacketFilter(new TlsExtensionReplacer( + ssl_tls13_supported_versions_xtn, versions_buf)); + ConnectExpectFail(); + } +}; + +class TlsExtensionTest13Stream : public TlsExtensionTestBase { + public: + TlsExtensionTest13Stream() + : TlsExtensionTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {} +}; + +class TlsExtensionTestGeneric + : public TlsExtensionTestBase, + public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { + public: + TlsExtensionTestGeneric() + : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { + } +}; + +class TlsExtensionTestPre13 + : public TlsExtensionTestBase, + public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { + public: + TlsExtensionTestPre13() + : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) { + } +}; + +TEST_P(TlsExtensionTestGeneric, DamageSniLength) { + ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 1)); +} + +TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) { + ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 4)); +} + +TEST_P(TlsExtensionTestGeneric, TruncateSni) { + ClientHelloErrorTest(new TlsExtensionTruncator(ssl_server_name_xtn, 7)); +} + +// A valid extension that appears twice will be reported as unsupported. +TEST_P(TlsExtensionTestGeneric, RepeatSni) { + DataBuffer extension; + InitSimpleSni(&extension); + ClientHelloErrorTest(new TlsExtensionInjector(ssl_server_name_xtn, extension), + kTlsAlertIllegalParameter); +} + +// An SNI entry with zero length is considered invalid (strangely, not if it is +// the last entry, which is probably a bug). +TEST_P(TlsExtensionTestGeneric, BadSni) { + DataBuffer simple; + InitSimpleSni(&simple); + DataBuffer extension; + extension.Allocate(simple.len() + 3); + extension.Write(0, static_cast<uint32_t>(0), 3); + extension.Write(3, simple); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_server_name_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, EmptySni) { + DataBuffer extension; + extension.Allocate(2); + extension.Write(0, static_cast<uint32_t>(0), 2); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_server_name_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) { + EnableAlpn(); + DataBuffer extension; + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension), + kTlsAlertIllegalParameter); +} + +// An empty ALPN isn't considered bad, though it does lead to there being no +// protocol for the server to select. +TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension), + kTlsAlertNoApplicationProtocol); +} + +TEST_P(TlsExtensionTestGeneric, OneByteAlpn) { + EnableAlpn(); + ClientHelloErrorTest( + new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 1)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) { + EnableAlpn(); + // This will leave the length of the second entry, but no value. + ClientHelloErrorTest( + new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 5)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { + EnableAlpn(); + const uint8_t val[] = {0x01, 0x61, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, AlpnMismatch) { + const uint8_t client_alpn[] = {0x01, 0x61}; + client_->EnableAlpn(client_alpn, sizeof(client_alpn)); + const uint8_t server_alpn[] = {0x02, 0x61, 0x62}; + server_->EnableAlpn(server_alpn, sizeof(server_alpn)); + + ClientHelloErrorTest(nullptr, kTlsAlertNoApplicationProtocol); +} + +// Many of these tests fail in TLS 1.3 because the extension is encrypted, which +// prevents modification of the value from the ServerHello. +TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest( + new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x01, 0x00}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest( + new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest( + new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest( + new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest( + new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) { + EnableAlpn(); + const uint8_t val[] = {0x00, 0x02, 0x99, 0x61}; + DataBuffer extension(val, sizeof(val)); + ServerHelloErrorTest( + new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); +} + +TEST_P(TlsExtensionTestDtls, SrtpShort) { + EnableSrtp(); + ClientHelloErrorTest(new TlsExtensionTruncator(ssl_use_srtp_xtn, 3)); +} + +TEST_P(TlsExtensionTestDtls, SrtpOdd) { + EnableSrtp(); + const uint8_t val[] = {0x00, 0x01, 0xff, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest(new TlsExtensionReplacer(ssl_use_srtp_xtn, extension)); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) { + const uint8_t val[] = {0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension)); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) { + const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension)); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) { + const uint8_t val[] = {0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension)); +} + +TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) { + const uint8_t val[] = {0x00, 0x01, 0x04}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) { + ClientHelloErrorTest(new TlsExtensionDropper(ssl_supported_groups_xtn), + version_ < SSL_LIBRARY_VERSION_TLS_1_3 + ? kTlsAlertDecryptError + : kTlsAlertMissingExtension); +} + +TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) { + const uint8_t val[] = {0x00, 0x01, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) { + const uint8_t val[] = {0x09, 0x99, 0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension)); +} + +TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) { + const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) { + const uint8_t val[] = {0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) { + const uint8_t val[] = {0x99, 0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) { + const uint8_t val[] = {0x01, 0x00, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) { + const uint8_t val[] = {0x99}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension)); +} + +TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) { + const uint8_t val[] = {0x01, 0x00}; + DataBuffer extension(val, sizeof(val)); + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension)); +} + +// The extension has to contain a length. +TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) { + DataBuffer extension; + ClientHelloErrorTest( + new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension)); +} + +// This only works on TLS 1.2, since it relies on static RSA; otherwise libssl +// picks the wrong cipher suite. +TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) { + const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_sha512, + ssl_sig_rsa_pss_sha384}; + + TlsExtensionCapture* capture = + new TlsExtensionCapture(ssl_signature_algorithms_xtn); + client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes)); + client_->SetPacketFilter(capture); + EnableOnlyStaticRsaCiphers(); + Connect(); + + const DataBuffer& ext = capture->extension(); + EXPECT_EQ(2 + PR_ARRAY_SIZE(schemes) * 2, ext.len()); + for (size_t i = 0, cursor = 2; + i < PR_ARRAY_SIZE(schemes) && cursor < ext.len(); ++i) { + uint32_t v = 0; + EXPECT_TRUE(ext.Read(cursor, 2, &v)); + cursor += 2; + EXPECT_EQ(schemes[i], static_cast<SSLSignatureScheme>(v)); + } +} + +// Temporary test to verify that we choke on an empty ClientKeyShare. +// This test will fail when we implement HelloRetryRequest. +TEST_P(TlsExtensionTest13, EmptyClientKeyShare) { + ClientHelloErrorTest(new TlsExtensionTruncator(ssl_tls13_key_share_xtn, 2), + kTlsAlertHandshakeFailure); +} + +// These tests only work in stream mode because the client sends a +// cleartext alert which causes a MAC error on the server. With +// stream this causes handshake failure but with datagram, the +// packet gets dropped. +TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) { + EnsureTlsSetup(); + server_->SetPacketFilter(new TlsExtensionDropper(ssl_tls13_key_share_xtn)); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code()); + EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code()); +} + +TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) { + const uint16_t wrong_group = ssl_grp_ec_secp384r1; + + static const uint8_t key_share[] = { + wrong_group >> 8, + wrong_group & 0xff, // Group we didn't offer. + 0x00, + 0x02, // length = 2 + 0x01, + 0x02}; + DataBuffer buf(key_share, sizeof(key_share)); + EnsureTlsSetup(); + server_->SetPacketFilter( + new TlsExtensionReplacer(ssl_tls13_key_share_xtn, buf)); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_KEY_SHARE, client_->error_code()); + EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code()); +} + +// TODO(ekr@rtfm.com): This is the wrong error code. See bug 1307269. +TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) { + const uint16_t wrong_group = 0xffff; + + static const uint8_t key_share[] = { + wrong_group >> 8, + wrong_group & 0xff, // Group we didn't offer. + 0x00, + 0x02, // length = 2 + 0x01, + 0x02}; + DataBuffer buf(key_share, sizeof(key_share)); + EnsureTlsSetup(); + server_->SetPacketFilter( + new TlsExtensionReplacer(ssl_tls13_key_share_xtn, buf)); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code()); + EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code()); +} + +TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) { + SetupForResume(); + DataBuffer empty; + server_->SetPacketFilter( + new TlsExtensionInjector(ssl_signature_algorithms_xtn, empty)); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_EXTENSION_DISALLOWED_FOR_VERSION, client_->error_code()); + EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code()); +} + +struct PskIdentity { + DataBuffer identity; + uint32_t obfuscated_ticket_age; +}; + +class TlsPreSharedKeyReplacer; + +typedef std::function<void(TlsPreSharedKeyReplacer*)> + TlsPreSharedKeyReplacerFunc; + +class TlsPreSharedKeyReplacer : public TlsExtensionFilter { + public: + TlsPreSharedKeyReplacer(TlsPreSharedKeyReplacerFunc function) + : identities_(), binders_(), function_(function) {} + + static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size, + const std::unique_ptr<DataBuffer>& replace, + size_t index, DataBuffer* output) { + DataBuffer tmp; + bool ret = parser->ReadVariable(&tmp, size); + EXPECT_EQ(true, ret); + if (!ret) return 0; + if (replace) { + tmp = *replace; + } + + return WriteVariable(output, index, tmp, size); + } + + PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) { + if (extension_type != ssl_tls13_pre_shared_key_xtn) { + return KEEP; + } + + if (!Decode(input)) { + return KEEP; + } + + // Call the function. + function_(this); + + Encode(output); + + return CHANGE; + } + + std::vector<PskIdentity> identities_; + std::vector<DataBuffer> binders_; + + private: + bool Decode(const DataBuffer& input) { + std::unique_ptr<TlsParser> parser(new TlsParser(input)); + DataBuffer identities; + + if (!parser->ReadVariable(&identities, 2)) { + ADD_FAILURE(); + return false; + } + + DataBuffer binders; + if (!parser->ReadVariable(&binders, 2)) { + ADD_FAILURE(); + return false; + } + EXPECT_EQ(0UL, parser->remaining()); + + // Now parse the inner sections. + parser.reset(new TlsParser(identities)); + while (parser->remaining()) { + PskIdentity identity; + + if (!parser->ReadVariable(&identity.identity, 2)) { + ADD_FAILURE(); + return false; + } + + if (!parser->Read(&identity.obfuscated_ticket_age, 4)) { + ADD_FAILURE(); + return false; + } + + identities_.push_back(identity); + } + + parser.reset(new TlsParser(binders)); + while (parser->remaining()) { + DataBuffer binder; + + if (!parser->ReadVariable(&binder, 1)) { + ADD_FAILURE(); + return false; + } + + binders_.push_back(binder); + } + + return true; + } + + void Encode(DataBuffer* output) { + DataBuffer identities; + size_t index = 0; + for (auto id : identities_) { + index = WriteVariable(&identities, index, id.identity, 2); + index = identities.Write(index, id.obfuscated_ticket_age, 4); + } + + DataBuffer binders; + index = 0; + for (auto binder : binders_) { + index = WriteVariable(&binders, index, binder, 1); + } + + output->Truncate(0); + index = 0; + index = WriteVariable(output, index, identities, 2); + index = WriteVariable(output, index, binders, 2); + } + + TlsPreSharedKeyReplacerFunc function_; +}; + +TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) { + SetupForResume(); + + client_->SetPacketFilter(new TlsPreSharedKeyReplacer([]( + TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); })); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +// Flip the first byte of the binder. +TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { + SetupForResume(); + + client_->SetPacketFilter( + new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) { + r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1); + })); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + +// Extend the binder by one. +TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { + SetupForResume(); + + client_->SetPacketFilter( + new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) { + r->binders_[0].Write(r->binders_[0].len(), 0xff, 1); + })); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +// Binders must be at least 32 bytes. +TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) { + SetupForResume(); + + client_->SetPacketFilter(new TlsPreSharedKeyReplacer( + [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); })); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +// Duplicate the identity and binder. This will fail with an error +// processing the binder (because we extended the identity list.) +TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) { + SetupForResume(); + + client_->SetPacketFilter( + new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) { + r->identities_.push_back(r->identities_[0]); + r->binders_.push_back(r->binders_[0]); + })); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + +// The next two tests have mismatches in the number of identities +// and binders. This generates an illegal parameter alert. +TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) { + SetupForResume(); + + client_->SetPacketFilter( + new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) { + r->identities_.push_back(r->identities_[0]); + })); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) { + SetupForResume(); + + client_->SetPacketFilter(new TlsPreSharedKeyReplacer([]( + TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); })); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) { + SetupForResume(); + + const uint8_t empty_buf[] = {0}; + DataBuffer empty(empty_buf, 0); + client_->SetPacketFilter( + // Inject an unused extension. + new TlsExtensionAppender(0xffff, empty)); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); +} + +TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) { + SetupForResume(); + + DataBuffer empty; + client_->SetPacketFilter( + new TlsExtensionDropper(ssl_tls13_psk_key_exchange_modes_xtn)); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES); +} + +// The following test contains valid but unacceptable PreSharedKey +// modes and therefore produces non-resumption followed by MAC +// errors. +TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) { + SetupForResume(); + const static uint8_t ke_modes[] = {1, // Length + kTls13PskKe}; + + DataBuffer modes(ke_modes, sizeof(ke_modes)); + client_->SetPacketFilter( + new TlsExtensionReplacer(ssl_tls13_psk_key_exchange_modes_xtn, modes)); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); +} + +TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) { + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + auto capture = new TlsExtensionCapture(ssl_tls13_psk_key_exchange_modes_xtn); + client_->SetPacketFilter(capture); + Connect(); + EXPECT_FALSE(capture->captured()); +} + +// In these tests, we downgrade to TLS 1.2, causing the +// server to negotiate TLS 1.2. +// 1. Both sides only support TLS 1.3, so we get a cipher version +// error. +TEST_P(TlsExtensionTest13, RemoveTls13FromVersionList) { + ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2); + client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT); + server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); +} + +// 2. Server supports 1.2 and 1.3, client supports 1.2, so we +// can't negotiate any ciphers. +TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListServerV12) { + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +// 3. Server supports 1.2 and 1.3, client supports 1.2 and 1.3 +// but advertises 1.2 (because we changed things). +TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListBothV12) { + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2); +#ifndef TLS_1_3_DRAFT_VERSION + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +#else + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +#endif +} + +TEST_P(TlsExtensionTest13, HrrThenRemoveSignatureAlgorithms) { + HrrThenRemoveExtensionsTest(ssl_signature_algorithms_xtn, + SSL_ERROR_MISSING_EXTENSION_ALERT, + SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION); +} + +TEST_P(TlsExtensionTest13, HrrThenRemoveKeyShare) { + HrrThenRemoveExtensionsTest(ssl_tls13_key_share_xtn, + SSL_ERROR_ILLEGAL_PARAMETER_ALERT, + SSL_ERROR_BAD_2ND_CLIENT_HELLO); +} + +TEST_P(TlsExtensionTest13, HrrThenRemoveSupportedGroups) { + HrrThenRemoveExtensionsTest(ssl_supported_groups_xtn, + SSL_ERROR_MISSING_EXTENSION_ALERT, + SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION); +} + +TEST_P(TlsExtensionTest13, EmptyVersionList) { + static const uint8_t ext[] = {0x00, 0x00}; + ConnectWithBogusVersionList(ext, sizeof(ext)); +} + +TEST_P(TlsExtensionTest13, OddVersionList) { + static const uint8_t ext[] = {0x00, 0x01, 0x00}; + ConnectWithBogusVersionList(ext, sizeof(ext)); +} + +INSTANTIATE_TEST_CASE_P(ExtensionStream, TlsExtensionTestGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsModesStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P(ExtensionDatagram, TlsExtensionTestGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV11Plus)); +INSTANTIATE_TEST_CASE_P(ExtensionDatagramOnly, TlsExtensionTestDtls, + TlsConnectTestBase::kTlsV11Plus); + +INSTANTIATE_TEST_CASE_P(ExtensionTls12Plus, TlsExtensionTest12Plus, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV12Plus)); + +INSTANTIATE_TEST_CASE_P(ExtensionPre13Stream, TlsExtensionTestPre13, + ::testing::Combine(TlsConnectTestBase::kTlsModesStream, + TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_CASE_P(ExtensionPre13Datagram, TlsExtensionTestPre13, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV11V12)); + +INSTANTIATE_TEST_CASE_P(ExtensionTls13, TlsExtensionTest13, + TlsConnectTestBase::kTlsModesAll); + +} // namespace nspr_test diff --git a/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc new file mode 100644 index 000000000..d144cd7d9 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc @@ -0,0 +1,223 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "blapi.h" +#include "ssl.h" +#include "sslimpl.h" +#include "tls_connect.h" + +#include "gtest/gtest.h" + +namespace nss_test { + +#ifdef UNSAFE_FUZZER_MODE + +const uint8_t kShortEmptyFinished[8] = {0}; +const uint8_t kLongEmptyFinished[128] = {0}; + +class TlsFuzzTest : public ::testing::Test {}; + +// Record the application data stream. +class TlsApplicationDataRecorder : public TlsRecordFilter { + public: + TlsApplicationDataRecorder() : buffer_() {} + + virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + const DataBuffer& input, + DataBuffer* output) { + if (header.content_type() == kTlsApplicationDataType) { + buffer_.Append(input); + } + + return KEEP; + } + + const DataBuffer& buffer() const { return buffer_; } + + private: + DataBuffer buffer_; +}; + +// Damages an SKE or CV signature. +class TlsSignatureDamager : public TlsHandshakeFilter { + public: + TlsSignatureDamager(uint8_t type) : type_(type) {} + virtual PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) { + if (header.handshake_type() != type_) { + return KEEP; + } + + *output = input; + + // Modify the last byte of the signature. + output->data()[output->len() - 1]++; + return CHANGE; + } + + private: + uint8_t type_; +}; + +void ResetState() { + // Clear the list of RSA blinding params. + BL_Cleanup(); + + // Reinit the list of RSA blinding params. + EXPECT_EQ(SECSuccess, BL_Init()); + + // Reset the RNG state. + EXPECT_EQ(SECSuccess, RNG_ResetForFuzzing()); +} + +// Ensure that ssl_Time() returns a constant value. +TEST_F(TlsFuzzTest, Fuzz_SSL_Time_Constant) { + PRInt32 now = ssl_Time(); + PR_Sleep(PR_SecondsToInterval(2)); + EXPECT_EQ(ssl_Time(), now); +} + +// Check that due to the deterministic PRNG we derive +// the same master secret in two consecutive TLS sessions. +TEST_P(TlsConnectGeneric, Fuzz_DeterministicExporter) { + const char kLabel[] = "label"; + std::vector<unsigned char> out1(32), out2(32); + + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + DisableECDHEServerKeyReuse(); + + ResetState(); + Connect(); + + // Export a key derived from the MS and nonces. + SECStatus rv = + SSL_ExportKeyingMaterial(client_->ssl_fd(), kLabel, strlen(kLabel), false, + NULL, 0, out1.data(), out1.size()); + EXPECT_EQ(SECSuccess, rv); + + Reset(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + DisableECDHEServerKeyReuse(); + + ResetState(); + Connect(); + + // Export another key derived from the MS and nonces. + rv = SSL_ExportKeyingMaterial(client_->ssl_fd(), kLabel, strlen(kLabel), + false, NULL, 0, out2.data(), out2.size()); + EXPECT_EQ(SECSuccess, rv); + + // The two exported keys should be the same. + EXPECT_EQ(out1, out2); +} + +// Check that due to the deterministic RNG two consecutive +// TLS sessions will have the exact same transcript. +TEST_P(TlsConnectGeneric, Fuzz_DeterministicTranscript) { + // Connect a few times and compare the transcripts byte-by-byte. + DataBuffer last; + for (size_t i = 0; i < 5; i++) { + Reset(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + DisableECDHEServerKeyReuse(); + + DataBuffer buffer; + client_->SetPacketFilter(new TlsConversationRecorder(buffer)); + server_->SetPacketFilter(new TlsConversationRecorder(buffer)); + + ResetState(); + Connect(); + + // Ensure the filters go away before |buffer| does. + client_->SetPacketFilter(nullptr); + server_->SetPacketFilter(nullptr); + + if (last.len() > 0) { + EXPECT_EQ(last, buffer); + } + + last = buffer; + } +} + +// Check that we can establish and use a connection +// with all supported TLS versions, STREAM and DGRAM. +// Check that records are NOT encrypted. +// Check that records don't have a MAC. +TEST_P(TlsConnectGeneric, Fuzz_ConnectSendReceive_NullCipher) { + EnsureTlsSetup(); + + // Set up app data filters. + auto client_recorder = new TlsApplicationDataRecorder(); + client_->SetPacketFilter(client_recorder); + auto server_recorder = new TlsApplicationDataRecorder(); + server_->SetPacketFilter(server_recorder); + + Connect(); + + // Construct the plaintext. + DataBuffer buf; + buf.Allocate(50); + for (size_t i = 0; i < buf.len(); ++i) { + buf.data()[i] = i & 0xff; + } + + // Send/Receive data. + client_->SendBuffer(buf); + server_->SendBuffer(buf); + Receive(buf.len()); + + // Check for plaintext on the wire. + EXPECT_EQ(buf, client_recorder->buffer()); + EXPECT_EQ(buf, server_recorder->buffer()); +} + +// Check that an invalid Finished message doesn't abort the connection. +TEST_P(TlsConnectGeneric, Fuzz_BogusClientFinished) { + EnsureTlsSetup(); + + auto i1 = new TlsInspectorReplaceHandshakeMessage( + kTlsHandshakeFinished, + DataBuffer(kShortEmptyFinished, sizeof(kShortEmptyFinished))); + client_->SetPacketFilter(i1); + Connect(); + SendReceive(); +} + +// Check that an invalid Finished message doesn't abort the connection. +TEST_P(TlsConnectGeneric, Fuzz_BogusServerFinished) { + EnsureTlsSetup(); + + auto i1 = new TlsInspectorReplaceHandshakeMessage( + kTlsHandshakeFinished, + DataBuffer(kLongEmptyFinished, sizeof(kLongEmptyFinished))); + server_->SetPacketFilter(i1); + Connect(); + SendReceive(); +} + +// Check that an invalid server auth signature doesn't abort the connection. +TEST_P(TlsConnectGeneric, Fuzz_BogusServerAuthSignature) { + EnsureTlsSetup(); + uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3 + ? kTlsHandshakeCertificateVerify + : kTlsHandshakeServerKeyExchange; + server_->SetPacketFilter(new TlsSignatureDamager(msg_type)); + Connect(); + SendReceive(); +} + +// Check that an invalid client auth signature doesn't abort the connection. +TEST_P(TlsConnectGeneric, Fuzz_BogusClientAuthSignature) { + EnsureTlsSetup(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + client_->SetPacketFilter( + new TlsSignatureDamager(kTlsHandshakeCertificateVerify)); + Connect(); +} + +#endif +} diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.cc b/security/nss/gtests/ssl_gtest/ssl_gtest.cc new file mode 100644 index 000000000..2d08dd865 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_gtest.cc @@ -0,0 +1,44 @@ +#include "nspr.h" +#include "nss.h" +#include "prenv.h" +#include "ssl.h" + +#include <cstdlib> + +#include "test_io.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +std::string g_working_dir_path; +bool g_ssl_gtest_verbose; + +int main(int argc, char** argv) { + // Start the tests + ::testing::InitGoogleTest(&argc, argv); + g_working_dir_path = "."; + g_ssl_gtest_verbose = false; + + char* workdir = PR_GetEnvSecure("NSS_GTEST_WORKDIR"); + if (workdir) g_working_dir_path = workdir; + + for (int i = 0; i < argc; i++) { + if (!strcmp(argv[i], "-d")) { + g_working_dir_path = argv[i + 1]; + ++i; + } else if (!strcmp(argv[i], "-v")) { + g_ssl_gtest_verbose = true; + } + } + + NSS_Initialize(g_working_dir_path.c_str(), "", "", SECMOD_DB, + NSS_INIT_READONLY); + NSS_SetDomesticPolicy(); + int rv = RUN_ALL_TESTS(); + + NSS_Shutdown(); + + nss_test::Poller::Shutdown(); + + return rv; +} diff --git a/security/nss/gtests/ssl_gtest/ssl_gtest.gyp b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp new file mode 100644 index 000000000..e232a8b7e --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_gtest.gyp @@ -0,0 +1,101 @@ +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. +{ + 'includes': [ + '../../coreconf/config.gypi', + '../common/gtest.gypi', + ], + 'targets': [ + { + 'target_name': 'ssl_gtest', + 'type': 'executable', + 'sources': [ + 'libssl_internals.c', + 'ssl_0rtt_unittest.cc', + 'ssl_agent_unittest.cc', + 'ssl_auth_unittest.cc', + 'ssl_cert_ext_unittest.cc', + 'ssl_ciphersuite_unittest.cc', + 'ssl_damage_unittest.cc', + 'ssl_dhe_unittest.cc', + 'ssl_drop_unittest.cc', + 'ssl_ecdh_unittest.cc', + 'ssl_ems_unittest.cc', + 'ssl_exporter_unittest.cc', + 'ssl_extension_unittest.cc', + 'ssl_fuzz_unittest.cc', + 'ssl_gtest.cc', + 'ssl_hrr_unittest.cc', + 'ssl_loopback_unittest.cc', + 'ssl_record_unittest.cc', + 'ssl_resumption_unittest.cc', + 'ssl_skip_unittest.cc', + 'ssl_staticrsa_unittest.cc', + 'ssl_v2_client_hello_unittest.cc', + 'ssl_version_unittest.cc', + 'test_io.cc', + 'tls_agent.cc', + 'tls_connect.cc', + 'tls_filter.cc', + 'tls_hkdf_unittest.cc', + 'tls_parser.cc' + ], + 'dependencies': [ + '<(DEPTH)/exports.gyp:nss_exports', + '<(DEPTH)/lib/util/util.gyp:nssutil3', + '<(DEPTH)/lib/sqlite/sqlite.gyp:sqlite3', + '<(DEPTH)/gtests/google_test/google_test.gyp:gtest', + '<(DEPTH)/lib/softoken/softoken.gyp:softokn', + '<(DEPTH)/lib/smime/smime.gyp:smime', + '<(DEPTH)/lib/ssl/ssl.gyp:ssl', + '<(DEPTH)/lib/nss/nss.gyp:nss_static', + '<(DEPTH)/cmd/lib/lib.gyp:sectool', + '<(DEPTH)/lib/pkcs12/pkcs12.gyp:pkcs12', + '<(DEPTH)/lib/pkcs7/pkcs7.gyp:pkcs7', + '<(DEPTH)/lib/certhigh/certhigh.gyp:certhi', + '<(DEPTH)/lib/cryptohi/cryptohi.gyp:cryptohi', + '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap', + '<(DEPTH)/lib/softoken/softoken.gyp:softokn', + '<(DEPTH)/lib/certdb/certdb.gyp:certdb', + '<(DEPTH)/lib/pki/pki.gyp:nsspki', + '<(DEPTH)/lib/dev/dev.gyp:nssdev', + '<(DEPTH)/lib/base/base.gyp:nssb', + '<(DEPTH)/lib/freebl/freebl.gyp:<(freebl_name)', + '<(DEPTH)/lib/zlib/zlib.gyp:nss_zlib' + ], + 'conditions': [ + [ 'disable_dbm==0', { + 'dependencies': [ + '<(DEPTH)/lib/dbm/src/src.gyp:dbm', + ], + }], + [ 'disable_libpkix==0', { + 'dependencies': [ + '<(DEPTH)/lib/libpkix/pkix/certsel/certsel.gyp:pkixcertsel', + '<(DEPTH)/lib/libpkix/pkix/checker/checker.gyp:pkixchecker', + '<(DEPTH)/lib/libpkix/pkix/crlsel/crlsel.gyp:pkixcrlsel', + '<(DEPTH)/lib/libpkix/pkix/params/params.gyp:pkixparams', + '<(DEPTH)/lib/libpkix/pkix/results/results.gyp:pkixresults', + '<(DEPTH)/lib/libpkix/pkix/store/store.gyp:pkixstore', + '<(DEPTH)/lib/libpkix/pkix/top/top.gyp:pkixtop', + '<(DEPTH)/lib/libpkix/pkix/util/util.gyp:pkixutil', + '<(DEPTH)/lib/libpkix/pkix_pl_nss/system/system.gyp:pkixsystem', + '<(DEPTH)/lib/libpkix/pkix_pl_nss/module/module.gyp:pkixmodule', + '<(DEPTH)/lib/libpkix/pkix_pl_nss/pki/pki.gyp:pkixpki', + ], + }], + ], + } + ], + 'target_defaults': { + 'include_dirs': [ + '../../gtests/google_test/gtest/include', + '../../gtests/common', + '../../lib/ssl' + ], + }, + 'variables': { + 'module': 'nss', + } +} diff --git a/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc new file mode 100644 index 000000000..5d670fa82 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc @@ -0,0 +1,285 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +// This is internal, just to get TLS_1_3_DRAFT_VERSION. +#include "ssl3prot.h" + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) { + const char* k0RttData = "Such is life"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + + SetupForZeroRtt(); // initial handshake as normal + + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(groups); + client_->Set0RttEnabled(true); + server_->Set0RttEnabled(true); + ExpectResumption(RESUME_TICKET); + + // Send first ClientHello and send 0-RTT data + auto capture_early_data = new TlsExtensionCapture(ssl_tls13_early_data_xtn); + client_->SetPacketFilter(capture_early_data); + client_->Handshake(); + EXPECT_EQ(k0RttDataLen, PR_Write(client_->ssl_fd(), k0RttData, + k0RttDataLen)); // 0-RTT write. + EXPECT_TRUE(capture_early_data->captured()); + + // Send the HelloRetryRequest + auto hrr_capture = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest); + server_->SetPacketFilter(hrr_capture); + server_->Handshake(); + EXPECT_LT(0U, hrr_capture->buffer().len()); + + // The server can't read + std::vector<uint8_t> buf(k0RttDataLen); + EXPECT_EQ(SECFailure, PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen)); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Make a new capture for the early data. + capture_early_data = new TlsExtensionCapture(ssl_tls13_early_data_xtn); + client_->SetPacketFilter(capture_early_data); + + // Complete the handshake successfully + Handshake(); + ExpectEarlyDataAccepted(false); // The server should reject 0-RTT + CheckConnected(); + SendReceive(); + EXPECT_FALSE(capture_early_data->captured()); +} + +class KeyShareReplayer : public TlsExtensionFilter { + public: + KeyShareReplayer() {} + + virtual PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) { + if (extension_type != ssl_tls13_key_share_xtn) { + return KEEP; + } + + if (!data_.len()) { + data_ = input; + return KEEP; + } + + *output = data_; + return CHANGE; + } + + private: + DataBuffer data_; +}; + +// This forces a HelloRetryRequest by disabling P-256 on the server. However, +// the second ClientHello is modified so that it omits the requested share. The +// server should reject this. +TEST_P(TlsConnectTls13, RetryWithSameKeyShare) { + EnsureTlsSetup(); + client_->SetPacketFilter(new KeyShareReplayer()); + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(groups); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code()); + EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code()); +} + +// This tests that the second attempt at sending a ClientHello (after receiving +// a HelloRetryRequest) is correctly retransmitted. +TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) { + static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(groups); + server_->SetPacketFilter(new SelectiveDropFilter(0x2)); + Connect(); +} + +class TlsKeyExchange13 : public TlsKeyExchangeTest {}; + +// This should work, with an HRR, because the server prefers x25519 and the +// client generates a share for P-384 on the initial ClientHello. +TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrr) { + EnsureKeyShareSetup(); + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + Connect(); + CheckKeys(); + static const std::vector<SSLNamedGroup> expectedShares = { + ssl_grp_ec_secp384r1}; + CheckKEXDetails(client_groups, expectedShares, ssl_grp_ec_curve25519); +} + +// This should work, but not use HRR because the key share for x25519 was +// pre-generated by the client. +TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrrExtraShares) { + EnsureKeyShareSetup(); + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519}; + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1}; + client_->ConfigNamedGroups(client_groups); + server_->ConfigNamedGroups(server_groups); + EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1)); + + Connect(); + CheckKeys(); + CheckKEXDetails(client_groups, client_groups); +} + +TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) { + EnsureTlsSetup(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + static const std::vector<SSLNamedGroup> client_groups = { + ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1}; + client_->ConfigNamedGroups(client_groups); + static const std::vector<SSLNamedGroup> server_groups = { + ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1}; + server_->ConfigNamedGroups(server_groups); + client_->StartConnect(); + server_->StartConnect(); + + client_->Handshake(); + server_->Handshake(); + + // Here we replace the TLS server with one that does TLS 1.2 only. + // This will happily send the client a TLS 1.2 ServerHello. + TlsAgent* replacement_server = + new TlsAgent(server_->name(), TlsAgent::SERVER, mode_); + delete server_; + server_ = replacement_server; + server_->Init(); + client_->SetPeer(server_); + server_->SetPeer(client_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->StartConnect(); + Handshake(); + EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, server_->error_code()); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); +} + +class HelloRetryRequestAgentTest : public TlsAgentTestClient { + protected: + void SetUp() override { + TlsAgentTestClient::SetUp(); + EnsureInit(); + agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3, + SSL_LIBRARY_VERSION_TLS_1_3); + agent_->StartConnect(); + } + + void MakeCannedHrr(const uint8_t* body, size_t len, DataBuffer* hrr_record, + uint32_t seq_num = 0) const { + DataBuffer hrr_data; + hrr_data.Allocate(len + 4); + size_t i = 0; + i = hrr_data.Write(i, 0x7f00 | TLS_1_3_DRAFT_VERSION, 2); + i = hrr_data.Write(i, static_cast<uint32_t>(len), 2); + if (len) { + hrr_data.Write(i, body, len); + } + DataBuffer hrr; + MakeHandshakeMessage(kTlsHandshakeHelloRetryRequest, hrr_data.data(), + hrr_data.len(), &hrr, seq_num); + MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, hrr.data(), + hrr.len(), hrr_record, seq_num); + } + + void MakeGroupHrr(SSLNamedGroup group, DataBuffer* hrr_record, + uint32_t seq_num = 0) const { + const uint8_t group_hrr[] = { + static_cast<uint8_t>(ssl_tls13_key_share_xtn >> 8), + static_cast<uint8_t>(ssl_tls13_key_share_xtn), + 0, + 2, // length of key share extension + static_cast<uint8_t>(group >> 8), + static_cast<uint8_t>(group)}; + MakeCannedHrr(group_hrr, sizeof(group_hrr), hrr_record, seq_num); + } +}; + +// Send two HelloRetryRequest messages in response to the ClientHello. The are +// constructed to appear legitimate by asking for a new share in each, so that +// the client has to count to work out that the server is being unreasonable. +TEST_P(HelloRetryRequestAgentTest, SendSecondHelloRetryRequest) { + DataBuffer hrr; + MakeGroupHrr(ssl_grp_ec_secp384r1, &hrr, 0); + ProcessMessage(hrr, TlsAgent::STATE_CONNECTING); + MakeGroupHrr(ssl_grp_ec_secp521r1, &hrr, 1); + ProcessMessage(hrr, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_UNEXPECTED_HELLO_RETRY_REQUEST); +} + +// Here the client receives a HelloRetryRequest with a group that they already +// provided a share for. +TEST_P(HelloRetryRequestAgentTest, HandleBogusHelloRetryRequest) { + DataBuffer hrr; + MakeGroupHrr(ssl_grp_ec_curve25519, &hrr); + ProcessMessage(hrr, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST); +} + +TEST_P(HelloRetryRequestAgentTest, HandleNoopHelloRetryRequest) { + DataBuffer hrr; + MakeCannedHrr(nullptr, 0U, &hrr); + ProcessMessage(hrr, TlsAgent::STATE_ERROR, + SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST); +} + +TEST_P(HelloRetryRequestAgentTest, HandleHelloRetryRequestCookie) { + const uint8_t canned_cookie_hrr[] = { + static_cast<uint8_t>(ssl_tls13_cookie_xtn >> 8), + static_cast<uint8_t>(ssl_tls13_cookie_xtn), + 0, + 5, // length of cookie extension + 0, + 3, // cookie value length + 0xc0, + 0x0c, + 0x13}; + DataBuffer hrr; + MakeCannedHrr(canned_cookie_hrr, sizeof(canned_cookie_hrr), &hrr); + TlsExtensionCapture* capture = new TlsExtensionCapture(ssl_tls13_cookie_xtn); + agent_->SetPacketFilter(capture); + ProcessMessage(hrr, TlsAgent::STATE_CONNECTING); + const size_t cookie_pos = 2 + 2; // cookie_xtn, extension len + DataBuffer cookie(canned_cookie_hrr + cookie_pos, + sizeof(canned_cookie_hrr) - cookie_pos); + EXPECT_EQ(cookie, capture->extension()); +} + +INSTANTIATE_TEST_CASE_P(HelloRetryRequestAgentTests, HelloRetryRequestAgentTest, + TlsConnectTestBase::kTlsModesAll); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_CASE_P(HelloRetryRequestKeyExchangeTests, TlsKeyExchange13, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV13)); +#endif + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc new file mode 100644 index 000000000..65c0ca1f8 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc @@ -0,0 +1,274 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectGeneric, SetupOnly) {} + +TEST_P(TlsConnectGeneric, Connect) { + SetExpectedVersion(std::get<1>(GetParam())); + Connect(); + CheckKeys(); +} + +TEST_P(TlsConnectGeneric, ConnectEcdsa) { + SetExpectedVersion(std::get<1>(GetParam())); + Reset(TlsAgent::kServerEcdsa256); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); +} + +TEST_P(TlsConnectGenericPre13, CipherSuiteMismatch) { + EnsureTlsSetup(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256); + server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384); + } else { + client_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA); + } + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); + server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP); +} + +TEST_P(TlsConnectGenericPre13, ConnectFalseStart) { + client_->EnableFalseStart(); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectAlpn) { + EnableAlpn(); + Connect(); + CheckAlpn("a"); +} + +TEST_P(TlsConnectGeneric, ConnectAlpnClone) { + EnsureModelSockets(); + client_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); + server_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); + Connect(); + CheckAlpn("a"); +} + +TEST_P(TlsConnectDatagram, ConnectSrtp) { + EnableSrtp(); + Connect(); + CheckSrtp(); + SendReceive(); +} + +// 1.3 is disabled in the next few tests because we don't +// presently support resumption in 1.3. +TEST_P(TlsConnectStreamPre13, ConnectAndClientRenegotiate) { + Connect(); + server_->PrepareForRenegotiate(); + client_->StartRenegotiate(); + Handshake(); + CheckConnected(); +} + +TEST_P(TlsConnectStreamPre13, ConnectAndServerRenegotiate) { + Connect(); + client_->PrepareForRenegotiate(); + server_->StartRenegotiate(); + Handshake(); + CheckConnected(); +} + +TEST_P(TlsConnectGeneric, ConnectSendReceive) { + Connect(); + SendReceive(); +} + +// The next two tests takes advantage of the fact that we +// automatically read the first 1024 bytes, so if +// we provide 1200 bytes, they overrun the read buffer +// provided by the calling test. + +// DTLS should return an error. +TEST_P(TlsConnectDatagram, ShortRead) { + Connect(); + client_->ExpectReadWriteError(); + server_->SendData(1200, 1200); + client_->WaitForErrorCode(SSL_ERROR_RX_SHORT_DTLS_READ, 2000); + + // Now send and receive another packet. + server_->ResetSentBytes(); // Reset the counter. + SendReceive(); +} + +// TLS should get the write in two chunks. +TEST_P(TlsConnectStream, ShortRead) { + // This test behaves oddly with TLS 1.0 because of 1/n+1 splitting, + // so skip in that case. + if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) return; + + Connect(); + server_->SendData(1200, 1200); + // Read the first tranche. + WAIT_(client_->received_bytes() == 1024, 2000); + ASSERT_EQ(1024U, client_->received_bytes()); + // The second tranche should now immediately be available. + client_->ReadBytes(); + ASSERT_EQ(1200U, client_->received_bytes()); +} + +TEST_P(TlsConnectGeneric, ConnectWithCompressionMaybe) { + EnsureTlsSetup(); + client_->EnableCompression(); + server_->EnableCompression(); + Connect(); + EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && mode_ != DGRAM, + client_->is_compressed()); + SendReceive(); +} + +TEST_P(TlsConnectDatagram, TestDtlsHolddownExpiry) { + Connect(); + std::cerr << "Expiring holddown timer\n"; + SSLInt_ForceTimerExpiry(client_->ssl_fd()); + SSLInt_ForceTimerExpiry(server_->ssl_fd()); + SendReceive(); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // One for send, one for receive. + EXPECT_EQ(2, SSLInt_CountTls13CipherSpecs(client_->ssl_fd())); + } +} + +class TlsPreCCSHeaderInjector : public TlsRecordFilter { + public: + TlsPreCCSHeaderInjector() {} + virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header, + const DataBuffer& input, + size_t* offset, + DataBuffer* output) override { + if (record_header.content_type() != kTlsChangeCipherSpecType) return KEEP; + + std::cerr << "Injecting Finished header before CCS\n"; + const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c}; + DataBuffer hhdr_buf(hhdr, sizeof(hhdr)); + RecordHeader nhdr(record_header.version(), kTlsHandshakeType, 0); + *offset = nhdr.Write(output, *offset, hhdr_buf); + *offset = record_header.Write(output, *offset, input); + return CHANGE; + } +}; + +TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) { + client_->SetPacketFilter(new TlsPreCCSHeaderInjector()); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); +} + +TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) { + server_->SetPacketFilter(new TlsPreCCSHeaderInjector()); + client_->StartConnect(); + server_->StartConnect(); + Handshake(); + EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); +} + +TEST_P(TlsConnectTls13, UnknownAlert) { + Connect(); + SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning, + 0xff); // Unknown value. + client_->ExpectReadWriteError(); + client_->WaitForErrorCode(SSL_ERROR_RX_UNKNOWN_ALERT, 2000); +} + +TEST_P(TlsConnectTls13, AlertWrongLevel) { + Connect(); + SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning, + kTlsAlertUnexpectedMessage); + client_->ExpectReadWriteError(); + client_->WaitForErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT, 2000); +} + +TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) { + client_->SetShortHeadersEnabled(); + server_->SetShortHeadersEnabled(); + client_->ExpectShortHeaders(); + server_->ExpectShortHeaders(); + Connect(); +} + +TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) { + EnsureTlsSetup(); + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); + server_->Handshake(); // Send first flight. + client_->adapter()->CloseWrites(); + client_->Handshake(); // This will get an error, but shouldn't crash. + client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE); +} + +INSTANTIATE_TEST_CASE_P(GenericStream, TlsConnectGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsModesStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P( + GenericDatagram, TlsConnectGeneric, + ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + TlsConnectTestBase::kTlsV11Plus)); + +INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream, + TlsConnectTestBase::kTlsVAll); +INSTANTIATE_TEST_CASE_P(DatagramOnly, TlsConnectDatagram, + TlsConnectTestBase::kTlsV11Plus); + +INSTANTIATE_TEST_CASE_P(Pre12Stream, TlsConnectPre12, + ::testing::Combine(TlsConnectTestBase::kTlsModesStream, + TlsConnectTestBase::kTlsV10V11)); +INSTANTIATE_TEST_CASE_P( + Pre12Datagram, TlsConnectPre12, + ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + TlsConnectTestBase::kTlsV11)); + +INSTANTIATE_TEST_CASE_P(Version12Only, TlsConnectTls12, + TlsConnectTestBase::kTlsModesAll); +#ifndef NSS_DISABLE_TLS_1_3 +INSTANTIATE_TEST_CASE_P(Version13Only, TlsConnectTls13, + TlsConnectTestBase::kTlsModesAll); +#endif + +INSTANTIATE_TEST_CASE_P(Pre13Stream, TlsConnectGenericPre13, + ::testing::Combine(TlsConnectTestBase::kTlsModesStream, + TlsConnectTestBase::kTlsV10ToV12)); +INSTANTIATE_TEST_CASE_P( + Pre13Datagram, TlsConnectGenericPre13, + ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram, + TlsConnectTestBase::kTlsV11V12)); +INSTANTIATE_TEST_CASE_P(Pre13StreamOnly, TlsConnectStreamPre13, + TlsConnectTestBase::kTlsV10ToV12); + +INSTANTIATE_TEST_CASE_P(Version12Plus, TlsConnectTls12Plus, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV12Plus)); + +} // namespace nspr_test diff --git a/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc new file mode 100644 index 000000000..ef81b222c --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_record_unittest.cc @@ -0,0 +1,111 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "nss.h" +#include "ssl.h" +#include "sslimpl.h" + +#include "databuffer.h" +#include "gtest_utils.h" + +namespace nss_test { + +const static size_t kMacSize = 20; + +class TlsPaddingTest + : public ::testing::Test, + public ::testing::WithParamInterface<std::tuple<size_t, bool>> { + public: + TlsPaddingTest() : plaintext_len_(std::get<0>(GetParam())) { + size_t extra = + (plaintext_len_ + 1) % 16; // Bytes past a block (1 == pad len) + // Minimal padding. + pad_len_ = extra ? 16 - extra : 0; + if (std::get<1>(GetParam())) { + // Maximal padding. + pad_len_ += 240; + } + MakePaddedPlaintext(); + } + + // Makes a plaintext record with correct padding. + void MakePaddedPlaintext() { + EXPECT_EQ(0UL, (plaintext_len_ + pad_len_ + 1) % 16); + size_t i = 0; + plaintext_.Allocate(plaintext_len_ + pad_len_ + 1); + for (; i < plaintext_len_; ++i) { + plaintext_.Write(i, 'A', 1); + } + + for (; i < plaintext_len_ + pad_len_ + 1; ++i) { + plaintext_.Write(i, pad_len_, 1); + } + } + + void Unpad(bool expect_success) { + std::cerr << "Content length=" << plaintext_len_ + << " padding length=" << pad_len_ + << " total length=" << plaintext_.len() << std::endl; + std::cerr << "Plaintext: " << plaintext_ << std::endl; + sslBuffer s; + s.buf = const_cast<unsigned char *>( + static_cast<const unsigned char *>(plaintext_.data())); + s.len = plaintext_.len(); + SECStatus rv = ssl_RemoveTLSCBCPadding(&s, kMacSize); + if (expect_success) { + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(plaintext_len_, static_cast<size_t>(s.len)); + } else { + EXPECT_EQ(SECFailure, rv); + } + } + + protected: + size_t plaintext_len_; + size_t pad_len_; + DataBuffer plaintext_; +}; + +TEST_P(TlsPaddingTest, Correct) { + if (plaintext_len_ >= kMacSize) { + Unpad(true); + } else { + Unpad(false); + } +} + +TEST_P(TlsPaddingTest, PadTooLong) { + if (plaintext_.len() < 255) { + plaintext_.Write(plaintext_.len() - 1, plaintext_.len(), 1); + Unpad(false); + } +} + +TEST_P(TlsPaddingTest, FirstByteOfPadWrong) { + if (pad_len_) { + plaintext_.Write(plaintext_len_, plaintext_.data()[plaintext_len_] + 1, 1); + Unpad(false); + } +} + +TEST_P(TlsPaddingTest, LastByteOfPadWrong) { + if (pad_len_) { + plaintext_.Write(plaintext_.len() - 2, + plaintext_.data()[plaintext_.len() - 1] + 1, 1); + Unpad(false); + } +} + +const static size_t kContentSizesArr[] = { + 1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288}; + +auto kContentSizes = ::testing::ValuesIn(kContentSizesArr); +const static bool kTrueFalseArr[] = {true, false}; +auto kTrueFalse = ::testing::ValuesIn(kTrueFalseArr); + +INSTANTIATE_TEST_CASE_P(TlsPadding, TlsPaddingTest, + ::testing::Combine(kContentSizes, kTrueFalse)); +} // namespace nspr_test diff --git a/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc new file mode 100644 index 000000000..cfe42cb9f --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc @@ -0,0 +1,582 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +class TlsServerKeyExchangeEcdhe { + public: + bool Parse(const DataBuffer& buffer) { + TlsParser parser(buffer); + + uint8_t curve_type; + if (!parser.Read(&curve_type)) { + return false; + } + + if (curve_type != 3) { // named_curve + return false; + } + + uint32_t named_curve; + if (!parser.Read(&named_curve, 2)) { + return false; + } + + return parser.ReadVariable(&public_key_, 1); + } + + DataBuffer public_key_; +}; + +TEST_P(TlsConnectGenericPre13, ConnectResumed) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + Connect(); + + Reset(); + ExpectResumption(RESUME_SESSIONID); + Connect(); +} + +TEST_P(TlsConnectGeneric, ConnectClientCacheDisabled) { + ConfigureSessionCache(RESUME_NONE, RESUME_SESSIONID); + Connect(); + SendReceive(); + + Reset(); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectServerCacheDisabled) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_NONE); + Connect(); + SendReceive(); + + Reset(); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectSessionCacheDisabled) { + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + Connect(); + SendReceive(); + + Reset(); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectResumeSupportBoth) { + // This prefers tickets. + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectResumeClientTicketServerBoth) { + // This causes no resumption because the client needs the + // session cache to resume even with tickets. + ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicket) { + // This causes a ticket resumption. + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectResumeClientServerTicketOnly) { + // This causes no resumption because the client needs the + // session cache to resume even with tickets. + ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectResumeClientBothServerNone) { + ConfigureSessionCache(RESUME_BOTH, RESUME_NONE); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_NONE); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGeneric, ConnectResumeClientNoneServerBoth) { + ConfigureSessionCache(RESUME_NONE, RESUME_BOTH); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_NONE, RESUME_BOTH); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +TEST_P(TlsConnectGenericPre13, ConnectResumeWithHigherVersion) { + ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_1); + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_1); + Connect(); + + Reset(); + EnsureTlsSetup(); + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_2); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2); + ExpectResumption(RESUME_NONE); + Connect(); +} + +TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicketForget) { + // This causes a ticket resumption. + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ClearServerCache(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + Connect(); + SendReceive(); +} + +// This callback switches out the "server" cert used on the server with +// the "client" certificate, which should be the same type. +static int32_t SwitchCertificates(TlsAgent* agent, const SECItem* srvNameArr, + uint32_t srvNameArrSize) { + bool ok = agent->ConfigServerCert("client"); + if (!ok) return SSL_SNI_SEND_ALERT; + + return 0; // first config +}; + +TEST_P(TlsConnectGeneric, ServerSNICertSwitch) { + Connect(); + ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd())); + + Reset(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + + server_->SetSniCallback(SwitchCertificates); + + Connect(); + ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); + CheckKeys(); + EXPECT_FALSE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); +} + +TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) { + Reset(TlsAgent::kServerEcdsa256); + Connect(); + ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd())); + + Reset(); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + + // Because we configure an RSA certificate here, it only adds a second, unused + // certificate, which has no effect on what the server uses. + server_->SetSniCallback(SwitchCertificates); + + Connect(); + ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); + CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa); + EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert)); +} + +// Prior to TLS 1.3, we were not fully ephemeral; though 1.3 fixes that +TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) { + TlsInspectorRecordHandshakeMessage* i1 = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + server_->SetPacketFilter(i1); + Connect(); + CheckKeys(); + TlsServerKeyExchangeEcdhe dhe1; + EXPECT_TRUE(dhe1.Parse(i1->buffer())); + + // Restart + Reset(); + TlsInspectorRecordHandshakeMessage* i2 = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + server_->SetPacketFilter(i2); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + Connect(); + CheckKeys(); + + TlsServerKeyExchangeEcdhe dhe2; + EXPECT_TRUE(dhe2.Parse(i2->buffer())); + + // Make sure they are the same. + EXPECT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len()); + EXPECT_TRUE(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(), + dhe1.public_key_.len())); +} + +// This test parses the ServerKeyExchange, which isn't in 1.3 +TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) { + server_->EnsureTlsSetup(); + SECStatus rv = + SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); + EXPECT_EQ(SECSuccess, rv); + TlsInspectorRecordHandshakeMessage* i1 = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + server_->SetPacketFilter(i1); + Connect(); + CheckKeys(); + TlsServerKeyExchangeEcdhe dhe1; + EXPECT_TRUE(dhe1.Parse(i1->buffer())); + + // Restart + Reset(); + server_->EnsureTlsSetup(); + rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); + EXPECT_EQ(SECSuccess, rv); + TlsInspectorRecordHandshakeMessage* i2 = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); + server_->SetPacketFilter(i2); + ConfigureSessionCache(RESUME_NONE, RESUME_NONE); + Connect(); + CheckKeys(); + + TlsServerKeyExchangeEcdhe dhe2; + EXPECT_TRUE(dhe2.Parse(i2->buffer())); + + // Make sure they are different. + EXPECT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) && + (!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(), + dhe1.public_key_.len()))); +} + +// Verify that TLS 1.3 reports an accurate group on resumption. +TEST_P(TlsConnectTls13, TestTls13ResumeDifferentGroup) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + client_->ConfigNamedGroups(kFFDHEGroups); + server_->ConfigNamedGroups(kFFDHEGroups); + Connect(); + CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none); +} + +// We need to enable different cipher suites at different times in the following +// tests. Those cipher suites need to be suited to the version. +static uint16_t ChooseOneCipher(uint16_t version) { + if (version >= SSL_LIBRARY_VERSION_TLS_1_3) { + return TLS_AES_128_GCM_SHA256; + } + return TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA; +} + +static uint16_t ChooseAnotherCipher(uint16_t version) { + if (version >= SSL_LIBRARY_VERSION_TLS_1_3) { + return TLS_AES_256_GCM_SHA384; + } + return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA; +} + +// Test that we don't resume when we can't negotiate the same cipher. +TEST_P(TlsConnectGeneric, TestResumeClientDifferentCipher) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + client_->EnableSingleCipher(ChooseAnotherCipher(version_)); + uint16_t ticket_extension; + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + ticket_extension = ssl_tls13_pre_shared_key_xtn; + } else { + ticket_extension = ssl_session_ticket_xtn; + } + auto ticket_capture = new TlsExtensionCapture(ticket_extension); + client_->SetPacketFilter(ticket_capture); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); + EXPECT_EQ(0U, ticket_capture->extension().len()); +} + +// Test that we don't resume when we can't negotiate the same cipher. +TEST_P(TlsConnectGeneric, TestResumeServerDifferentCipher) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_NONE); + server_->EnableSingleCipher(ChooseAnotherCipher(version_)); + Connect(); + CheckKeys(); +} + +class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { + public: + SelectedCipherSuiteReplacer(uint16_t suite) : cipher_suite_(suite) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + if (header.handshake_type() != kTlsHandshakeServerHello) { + return KEEP; + } + + *output = input; + uint32_t temp = 0; + EXPECT_TRUE(input.Read(0, 2, &temp)); + // Cipher suite is after version(2) and random(32). + size_t pos = 34; + if (temp < SSL_LIBRARY_VERSION_TLS_1_3) { + // In old versions, we have to skip a session_id too. + EXPECT_TRUE(input.Read(pos, 1, &temp)); + pos += 1 + temp; + } + output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2); + return CHANGE; + } + + private: + uint16_t cipher_suite_; +}; + +// Test that the client doesn't tolerate the server picking a different cipher +// suite for resumption. +TEST_P(TlsConnectStream, TestResumptionOverrideCipher) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->EnableSingleCipher(ChooseOneCipher(version_)); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + server_->SetPacketFilter( + new SelectedCipherSuiteReplacer(ChooseAnotherCipher(version_))); + + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + // The reason this test is stream only: the server is unable to decrypt + // the alert that the client sends, see bug 1304603. + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + } else { + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT); + } +} + +class SelectedVersionReplacer : public TlsHandshakeFilter { + public: + SelectedVersionReplacer(uint16_t version) : version_(version) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override { + if (header.handshake_type() != kTlsHandshakeServerHello) { + return KEEP; + } + + *output = input; + output->Write(0, static_cast<uint32_t>(version_), 2); + return CHANGE; + } + + private: + uint16_t version_; +}; + +// Test how the client handles the case where the server picks a +// lower version number on resumption. +TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) { + uint16_t override_version = 0; + if (mode_ == STREAM) { + switch (version_) { + case SSL_LIBRARY_VERSION_TLS_1_0: + return; // Skip the test. + case SSL_LIBRARY_VERSION_TLS_1_1: + override_version = SSL_LIBRARY_VERSION_TLS_1_0; + break; + case SSL_LIBRARY_VERSION_TLS_1_2: + override_version = SSL_LIBRARY_VERSION_TLS_1_1; + break; + default: + ASSERT_TRUE(false) << "unknown version"; + } + } else { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_2) { + override_version = SSL_LIBRARY_VERSION_DTLS_1_0_WIRE; + } else { + ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_1, version_); + return; // Skip the test. + } + } + + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + // Need to use a cipher that is plausible for the lower version. + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + Connect(); + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + // Enable the lower version on the client. + client_->SetVersionRange(version_ - 1, version_); + server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + server_->SetPacketFilter(new SelectedVersionReplacer(override_version)); + + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO); + server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT); +} + +// Test that two TLS resumptions work and produce the same ticket. +// This will change after bug 1257047 is fixed. +TEST_F(TlsConnectTest, TestTls13ResumptionTwice) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + uint16_t original_suite; + EXPECT_TRUE(client_->cipher_suite(&original_suite)); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + TlsExtensionCapture* c1 = + new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn); + client_->SetPacketFilter(c1); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_none); + // The filter will go away when we reset, so save the captured extension. + DataBuffer initialTicket(c1->extension()); + ASSERT_LT(0U, initialTicket.len()); + + ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_TRUE(!!cert1.get()); + + Reset(); + ClearStats(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + TlsExtensionCapture* c2 = + new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn); + client_->SetPacketFilter(c2); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); + CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign, + ssl_sig_none); + ASSERT_LT(0U, c2->extension().len()); + + ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd())); + ASSERT_TRUE(!!cert2.get()); + + // Check that the cipher suite is reported the same on both sides, though in + // TLS 1.3 resumption actually negotiates a different cipher suite. + uint16_t resumed_suite; + EXPECT_TRUE(server_->cipher_suite(&resumed_suite)); + EXPECT_EQ(original_suite, resumed_suite); + EXPECT_TRUE(client_->cipher_suite(&resumed_suite)); + EXPECT_EQ(original_suite, resumed_suite); + + ASSERT_NE(initialTicket, c2->extension()); +} + +// Check that resumption works after receiving two NST messages. +TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + + // Clear the session ticket keys to invalidate the old ticket. + SSLInt_ClearSessionTicketKey(); + SSLInt_SendNewSessionTicket(server_->ssl_fd()); + + SendReceive(); // Need to read so that we absorb the session tickets. + CheckKeys(); + + // Resume the connection. + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ExpectResumption(RESUME_TICKET); + Connect(); + SendReceive(); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc new file mode 100644 index 000000000..523a37499 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_skip_unittest.cc @@ -0,0 +1,158 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "sslerr.h" + +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +/* + * The tests in this file test that the TLS state machine is robust against + * attacks that alter the order of handshake messages. + * + * See <https://www.smacktls.com/smack.pdf> for a description of the problems + * that this sort of attack can enable. + */ +namespace nss_test { + +class TlsHandshakeSkipFilter : public TlsRecordFilter { + public: + // A TLS record filter that skips handshake messages of the identified type. + TlsHandshakeSkipFilter(uint8_t handshake_type) + : handshake_type_(handshake_type), skipped_(false) {} + + protected: + // Takes a record; if it is a handshake record, it removes the first handshake + // message that is of handshake_type_ type. + virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header, + const DataBuffer& input, + DataBuffer* output) { + if (record_header.content_type() != kTlsHandshakeType) { + return KEEP; + } + + size_t output_offset = 0U; + output->Allocate(input.len()); + + TlsParser parser(input); + while (parser.remaining()) { + size_t start = parser.consumed(); + TlsHandshakeFilter::HandshakeHeader header; + DataBuffer ignored; + if (!header.Parse(&parser, record_header, &ignored)) { + return KEEP; + } + + if (skipped_ || header.handshake_type() != handshake_type_) { + size_t entire_length = parser.consumed() - start; + output->Write(output_offset, input.data() + start, entire_length); + // DTLS sequence numbers need to be rewritten + if (skipped_ && header.is_dtls()) { + output->data()[start + 5] -= 1; + } + output_offset += entire_length; + } else { + std::cerr << "Dropping handshake: " + << static_cast<unsigned>(handshake_type_) << std::endl; + // We only need to report that the output contains changed data if we + // drop a handshake message. But once we've skipped one message, we + // have to modify all subsequent handshake messages so that they include + // the correct DTLS sequence numbers. + skipped_ = true; + } + } + output->Truncate(output_offset); + return skipped_ ? CHANGE : KEEP; + } + + private: + // The type of handshake message to drop. + uint8_t handshake_type_; + // Whether this filter has ever skipped a handshake message. Track this so + // that sequence numbers on DTLS handshake messages can be rewritten in + // subsequent calls. + bool skipped_; +}; + +class TlsSkipTest + : public TlsConnectTestBase, + public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { + protected: + TlsSkipTest() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + + void ServerSkipTest(PacketFilter* filter, + uint8_t alert = kTlsAlertUnexpectedMessage) { + auto alert_recorder = new TlsAlertRecorder(); + client_->SetPacketFilter(alert_recorder); + if (filter) { + server_->SetPacketFilter(filter); + } + ConnectExpectFail(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(alert, alert_recorder->description()); + } +}; + +TEST_P(TlsSkipTest, SkipCertificateRsa) { + EnableOnlyStaticRsaCiphers(); + ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); +} + +TEST_P(TlsSkipTest, SkipCertificateDhe) { + ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); +} + +TEST_P(TlsSkipTest, SkipCertificateEcdhe) { + ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); +} + +TEST_P(TlsSkipTest, SkipCertificateEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); +} + +TEST_P(TlsSkipTest, SkipServerKeyExchange) { + ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); +} + +TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); +} + +TEST_P(TlsSkipTest, SkipCertAndKeyExch) { + auto chain = new ChainedPacketFilter(); + chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + ServerSkipTest(chain); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); +} + +TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) { + Reset(TlsAgent::kServerEcdsa256); + auto chain = new ChainedPacketFilter(); + chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); + chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); + ServerSkipTest(chain); + client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); +} + +INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest, + ::testing::Combine(TlsConnectTestBase::kTlsModesStream, + TlsConnectTestBase::kTlsV10)); +INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest, + ::testing::Combine(TlsConnectTestBase::kTlsModesAll, + TlsConnectTestBase::kTlsV11V12)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc new file mode 100644 index 000000000..baf24ed9c --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc @@ -0,0 +1,123 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <functional> +#include <memory> +#include "secerr.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +const uint8_t kBogusClientKeyExchange[] = { + 0x01, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, +}; + +TEST_P(TlsConnectGenericPre13, ConnectStaticRSA) { + EnableOnlyStaticRsaCiphers(); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +// Test that a totally bogus EPMS is handled correctly. +// This test is stream so we can catch the bad_record_mac alert. +TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) { + EnableOnlyStaticRsaCiphers(); + TlsInspectorReplaceHandshakeMessage* i1 = + new TlsInspectorReplaceHandshakeMessage( + kTlsHandshakeClientKeyExchange, + DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); + client_->SetPacketFilter(i1); + auto alert_recorder = new TlsAlertRecorder(); + server_->SetPacketFilter(alert_recorder); + ConnectExpectFail(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); +} + +// Test that a PMS with a bogus version number is handled correctly. +// This test is stream so we can catch the bad_record_mac alert. +TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) { + EnableOnlyStaticRsaCiphers(); + client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_)); + auto alert_recorder = new TlsAlertRecorder(); + server_->SetPacketFilter(alert_recorder); + ConnectExpectFail(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); +} + +// Test that a PMS with a bogus version number is ignored when +// rollback detection is disabled. This is a positive control for +// ConnectStaticRSABogusPMSVersionDetect. +TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) { + EnableOnlyStaticRsaCiphers(); + client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_)); + server_->DisableRollbackDetection(); + Connect(); +} + +// This test is stream so we can catch the bad_record_mac alert. +TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) { + EnableOnlyStaticRsaCiphers(); + EnableExtendedMasterSecret(); + TlsInspectorReplaceHandshakeMessage* inspect = + new TlsInspectorReplaceHandshakeMessage( + kTlsHandshakeClientKeyExchange, + DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange))); + client_->SetPacketFilter(inspect); + auto alert_recorder = new TlsAlertRecorder(); + server_->SetPacketFilter(alert_recorder); + ConnectExpectFail(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); +} + +// This test is stream so we can catch the bad_record_mac alert. +TEST_P(TlsConnectStreamPre13, + ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) { + EnableOnlyStaticRsaCiphers(); + EnableExtendedMasterSecret(); + client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_)); + auto alert_recorder = new TlsAlertRecorder(); + server_->SetPacketFilter(alert_recorder); + ConnectExpectFail(); + EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); + EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); +} + +TEST_P(TlsConnectStreamPre13, + ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) { + EnableOnlyStaticRsaCiphers(); + EnableExtendedMasterSecret(); + client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_)); + server_->DisableRollbackDetection(); + Connect(); +} + +} // namespace nspr_test diff --git a/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc new file mode 100644 index 000000000..8b586beae --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc @@ -0,0 +1,351 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "pk11pub.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include "gtest_utils.h" +#include "tls_connect.h" +#include "tls_filter.h" + +namespace nss_test { + +// Replaces the client hello with an SSLv2 version once. +class SSLv2ClientHelloFilter : public PacketFilter { + public: + SSLv2ClientHelloFilter(TlsAgent* client, uint16_t version) + : replaced_(false), + client_(client), + version_(version), + pad_len_(0), + reported_pad_len_(0), + client_random_len_(16), + ciphers_(0), + send_escape_(false) {} + + void SetVersion(uint16_t version) { version_ = version; } + + void SetCipherSuites(const std::vector<uint16_t>& ciphers) { + ciphers_ = ciphers; + } + + // Set a padding length and announce it correctly. + void SetPadding(uint8_t pad_len) { SetPadding(pad_len, pad_len); } + + // Set a padding length and allow to lie about its length. + void SetPadding(uint8_t pad_len, uint8_t reported_pad_len) { + pad_len_ = pad_len; + reported_pad_len_ = reported_pad_len; + } + + void SetClientRandomLength(uint16_t client_random_len) { + client_random_len_ = client_random_len; + } + + void SetSendEscape(bool send_escape) { send_escape_ = send_escape; } + + protected: + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) { + if (replaced_) { + return KEEP; + } + + // Replace only the very first packet. + replaced_ = true; + + // The SSLv2 client hello size. + size_t packet_len = SSL_HL_CLIENT_HELLO_HBYTES + (ciphers_.size() * 3) + + client_random_len_ + pad_len_; + + size_t idx = 0; + *output = input; + output->Allocate(packet_len); + output->Truncate(packet_len); + + // Write record length. + if (pad_len_ > 0) { + size_t masked_len = 0x3fff & packet_len; + if (send_escape_) { + masked_len |= 0x4000; + } + + idx = output->Write(idx, masked_len, 2); + idx = output->Write(idx, reported_pad_len_, 1); + } else { + PR_ASSERT(!send_escape_); + idx = output->Write(idx, 0x8000 | packet_len, 2); + } + + // Remember header length. + size_t hdr_len = idx; + + // Write client hello. + idx = output->Write(idx, SSL_MT_CLIENT_HELLO, 1); + idx = output->Write(idx, version_, 2); + + // Cipher list length. + idx = output->Write(idx, (ciphers_.size() * 3), 2); + + // Session ID length. + idx = output->Write(idx, static_cast<uint32_t>(0), 2); + + // ClientRandom length. + idx = output->Write(idx, client_random_len_, 2); + + // Cipher suites. + for (auto cipher : ciphers_) { + idx = output->Write(idx, static_cast<uint32_t>(cipher), 3); + } + + // Challenge. + std::vector<uint8_t> challenge(client_random_len_); + PK11_GenerateRandom(challenge.data(), challenge.size()); + idx = output->Write(idx, challenge.data(), challenge.size()); + + // Add padding if any. + if (pad_len_ > 0) { + std::vector<uint8_t> pad(pad_len_); + idx = output->Write(idx, pad.data(), pad.size()); + } + + // Update the client random so that the handshake succeeds. + SECStatus rv = SSLInt_UpdateSSLv2ClientRandom( + client_->ssl_fd(), challenge.data(), challenge.size(), + output->data() + hdr_len, output->len() - hdr_len); + EXPECT_EQ(SECSuccess, rv); + + return CHANGE; + } + + private: + bool replaced_; + TlsAgent* client_; + uint16_t version_; + uint8_t pad_len_; + uint8_t reported_pad_len_; + uint16_t client_random_len_; + std::vector<uint16_t> ciphers_; + bool send_escape_; +}; + +class SSLv2ClientHelloTestF : public TlsConnectTestBase { + public: + SSLv2ClientHelloTestF() : TlsConnectTestBase(STREAM, 0), filter_(nullptr) {} + + SSLv2ClientHelloTestF(Mode mode, uint16_t version) + : TlsConnectTestBase(mode, version), filter_(nullptr) {} + + void SetUp() { + TlsConnectTestBase::SetUp(); + filter_ = new SSLv2ClientHelloFilter(client_, version_); + client_->SetPacketFilter(filter_); + } + + void RequireSafeRenegotiation() { + server_->EnsureTlsSetup(); + SECStatus rv = + SSL_OptionSet(server_->ssl_fd(), SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE); + EXPECT_EQ(rv, SECSuccess); + } + + void SetExpectedVersion(uint16_t version) { + TlsConnectTestBase::SetExpectedVersion(version); + filter_->SetVersion(version); + } + + void SetAvailableCipherSuite(uint16_t cipher) { + filter_->SetCipherSuites(std::vector<uint16_t>(1, cipher)); + } + + void SetAvailableCipherSuites(const std::vector<uint16_t>& ciphers) { + filter_->SetCipherSuites(ciphers); + } + + void SetPadding(uint8_t pad_len) { filter_->SetPadding(pad_len); } + + void SetPadding(uint8_t pad_len, uint8_t reported_pad_len) { + filter_->SetPadding(pad_len, reported_pad_len); + } + + void SetClientRandomLength(uint16_t client_random_len) { + filter_->SetClientRandomLength(client_random_len); + } + + void SetSendEscape(bool send_escape) { filter_->SetSendEscape(send_escape); } + + private: + SSLv2ClientHelloFilter* filter_; +}; + +// Parameterized version of SSLv2ClientHelloTestF we can +// use with TEST_P to test multiple TLS versions easily. +class SSLv2ClientHelloTest : public SSLv2ClientHelloTestF, + public ::testing::WithParamInterface<uint16_t> { + public: + SSLv2ClientHelloTest() : SSLv2ClientHelloTestF(STREAM, GetParam()) {} +}; + +// Test negotiating TLS 1.0 - 1.2. +TEST_P(SSLv2ClientHelloTest, Connect) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + Connect(); +} + +// Test negotiating TLS 1.3. +TEST_F(SSLv2ClientHelloTestF, Connect13) { + EnsureTlsSetup(); + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_3); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + + std::vector<uint16_t> cipher_suites = {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}; + SetAvailableCipherSuites(cipher_suites); + + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); +} + +// Test negotiating an EC suite. +TEST_P(SSLv2ClientHelloTest, NegotiateECSuite) { + SetAvailableCipherSuite(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA); + Connect(); +} + +// Test negotiating TLS 1.0 - 1.2 with a padded client hello. +TEST_P(SSLv2ClientHelloTest, AddPadding) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + SetPadding(255); + Connect(); +} + +// Test that sending a security escape fails the handshake. +TEST_P(SSLv2ClientHelloTest, SendSecurityEscape) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + // Send a security escape. + SetSendEscape(true); + + // Set a big padding so that the server fails instead of timing out. + SetPadding(255); + + ConnectExpectFail(); +} + +// Invalid SSLv2 client hello padding must fail the handshake. +TEST_P(SSLv2ClientHelloTest, AddErroneousPadding) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + // Append 5 bytes of padding but say it's only 4. + SetPadding(5, 4); + + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); +} + +// Invalid SSLv2 client hello padding must fail the handshake. +TEST_P(SSLv2ClientHelloTest, AddErroneousPadding2) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + // Append 5 bytes of padding but say it's 6. + SetPadding(5, 6); + + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); +} + +// Wrong amount of bytes for the ClientRandom must fail the handshake. +TEST_P(SSLv2ClientHelloTest, SmallClientRandom) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + // Send a ClientRandom that's too small. + SetClientRandomLength(15); + + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); +} + +// Test sending the maximum accepted number of ClientRandom bytes. +TEST_P(SSLv2ClientHelloTest, MaxClientRandom) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + SetClientRandomLength(32); + Connect(); +} + +// Wrong amount of bytes for the ClientRandom must fail the handshake. +TEST_P(SSLv2ClientHelloTest, BigClientRandom) { + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + + // Send a ClientRandom that's too big. + SetClientRandomLength(33); + + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code()); +} + +// Connection must fail if we require safe renegotiation but the client doesn't +// include TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites. +TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) { + RequireSafeRenegotiation(); + SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA); + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_UNSAFE_NEGOTIATION, server_->error_code()); +} + +// Connection must succeed when requiring safe renegotiation and the client +// includes TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites. +TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiationWithSCSV) { + RequireSafeRenegotiation(); + std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA, + TLS_EMPTY_RENEGOTIATION_INFO_SCSV}; + SetAvailableCipherSuites(cipher_suites); + Connect(); +} + +// Connect to the server with TLS 1.1, signalling that this is a fallback from +// a higher version. As the server doesn't support anything higher than TLS 1.1 +// it must accept the connection. +TEST_F(SSLv2ClientHelloTestF, FallbackSCSV) { + EnsureTlsSetup(); + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_1); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_1); + + std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA, + TLS_FALLBACK_SCSV}; + SetAvailableCipherSuites(cipher_suites); + Connect(); +} + +// Connect to the server with TLS 1.1, signalling that this is a fallback from +// a higher version. As the server supports TLS 1.2 though it must reject the +// connection due to a possible downgrade attack. +TEST_F(SSLv2ClientHelloTestF, InappropriateFallbackSCSV) { + SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_1); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_1); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2); + + std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA, + TLS_FALLBACK_SCSV}; + SetAvailableCipherSuites(cipher_suites); + + ConnectExpectFail(); + EXPECT_EQ(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT, server_->error_code()); +} + +INSTANTIATE_TEST_CASE_P(VersionsStream10Pre13, SSLv2ClientHelloTest, + TlsConnectTestBase::kTlsV10); +INSTANTIATE_TEST_CASE_P(VersionsStreamPre13, SSLv2ClientHelloTest, + TlsConnectTestBase::kTlsV11V12); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc new file mode 100644 index 000000000..b3538497e --- /dev/null +++ b/security/nss/gtests/ssl_gtest/ssl_version_unittest.cc @@ -0,0 +1,300 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "secerr.h" +#include "ssl.h" +#include "ssl3prot.h" +#include "sslerr.h" +#include "sslproto.h" + +#include "gtest_utils.h" +#include "scoped_ptrs.h" +#include "tls_connect.h" +#include "tls_filter.h" +#include "tls_parser.h" + +namespace nss_test { + +TEST_P(TlsConnectStream, ServerNegotiateTls10) { + uint16_t minver, maxver; + client_->GetVersionRange(&minver, &maxver); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, maxver); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + Connect(); +} + +TEST_P(TlsConnectGeneric, ServerNegotiateTls11) { + if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) return; + + uint16_t minver, maxver; + client_->GetVersionRange(&minver, &maxver); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, maxver); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_1); + Connect(); +} + +TEST_P(TlsConnectGeneric, ServerNegotiateTls12) { + if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) return; + + uint16_t minver, maxver; + client_->GetVersionRange(&minver, &maxver); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, maxver); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + Connect(); +} +#ifndef TLS_1_3_DRAFT_VERSION + +// Test the ServerRandom version hack from +// [draft-ietf-tls-tls13-11 Section 6.3.1.1]. +// The first three tests test for active tampering. The next +// two validate that we can also detect fallback using the +// SSL_SetDowngradeCheckVersion() API. +TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) { + client_->SetPacketFilter( + new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_1)); + ConnectExpectFail(); + ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); +} + +/* Attempt to negotiate the bogus DTLS 1.1 version. */ +TEST_F(DtlsConnectTest, TestDtlsVersion11) { + client_->SetPacketFilter( + new TlsInspectorClientHelloVersionSetter(((~0x0101) & 0xffff))); + ConnectExpectFail(); + // It's kind of surprising that SSL_ERROR_NO_CYPHER_OVERLAP is + // what is returned here, but this is deliberate in ssl3_HandleAlert(). + EXPECT_EQ(SSL_ERROR_NO_CYPHER_OVERLAP, client_->error_code()); + EXPECT_EQ(SSL_ERROR_UNSUPPORTED_VERSION, server_->error_code()); +} + +// Disabled as long as we have draft version. +TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) { + EnsureTlsSetup(); + client_->SetPacketFilter( + new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_2)); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_3); + ConnectExpectFail(); + ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); +} + +// TLS 1.1 clients do not check the random values, so we should +// instead get a handshake failure alert from the server. +TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) { + client_->SetPacketFilter( + new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_0)); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_1); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_2); + ConnectExpectFail(); + ASSERT_EQ(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE, server_->error_code()); + ASSERT_EQ(SSL_ERROR_DECRYPT_ERROR_ALERT, client_->error_code()); +} + +TEST_F(TlsConnectTest, TestFallbackFromTls12) { + EnsureTlsSetup(); + client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_2); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_1); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2); + ConnectExpectFail(); + ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); +} + +TEST_F(TlsConnectTest, TestFallbackFromTls13) { + EnsureTlsSetup(); + client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_3); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + ConnectExpectFail(); + ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code()); +} +#endif + +// The TLS v1.3 spec section C.4 states that 'Implementations MUST NOT send or +// accept any records with a version less than { 3, 0 }'. Thus we will not +// allow version ranges including both SSL v3 and TLS v1.3. +TEST_F(TlsConnectTest, DisallowSSLv3HelloWithTLSv13Enabled) { + SECStatus rv; + SSLVersionRange vrange = {SSL_LIBRARY_VERSION_3_0, + SSL_LIBRARY_VERSION_TLS_1_3}; + + EnsureTlsSetup(); + rv = SSL_VersionRangeSet(client_->ssl_fd(), &vrange); + EXPECT_EQ(SECFailure, rv); + + rv = SSL_VersionRangeSet(server_->ssl_fd(), &vrange); + EXPECT_EQ(SECFailure, rv); +} + +TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + // Set the client so it will accept any version from 1.0 + // to |version_|. + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // Reset version so that the checks succeed. + uint16_t test_version = version_; + version_ = SSL_LIBRARY_VERSION_TLS_1_0; + Connect(); + + // Now renegotiate, with the server being set to do + // |version_|. + client_->PrepareForRenegotiate(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + server_->StartRenegotiate(); + Handshake(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + // In TLS 1.3, the server detects this problem. + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); + } else { + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + } +} + +TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) { + if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) { + return; + } + // Set the client so it will accept any version from 1.0 + // to |version_|. + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_0); + // Reset version so that the checks succeed. + uint16_t test_version = version_; + version_ = SSL_LIBRARY_VERSION_TLS_1_0; + Connect(); + + // Now renegotiate, with the server being set to do + // |version_|. + server_->PrepareForRenegotiate(); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version); + // Reset version and cipher suite so that the preinfo callback + // doesn't fail. + server_->ResetPreliminaryInfo(); + client_->StartRenegotiate(); + Handshake(); + if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) { + // In TLS 1.3, the server detects this problem. + client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); + server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED); + } else { + client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION); + server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); + } +} + +TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); +} + +TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) { + EnsureTlsSetup(); + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3); + Connect(); + SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError()); +} + +TEST_P(TlsConnectGeneric, AlertBeforeServerHello) { + EnsureTlsSetup(); + client_->StartConnect(); + server_->StartConnect(); + client_->Handshake(); // Send ClientHello. + static const uint8_t kWarningAlert[] = {kTlsAlertWarning, + kTlsAlertUnrecognizedName}; + DataBuffer alert; + TlsAgentTestBase::MakeRecord(mode_, kTlsAlertType, + SSL_LIBRARY_VERSION_TLS_1_0, kWarningAlert, + PR_ARRAY_SIZE(kWarningAlert), &alert); + client_->adapter()->PacketReceived(alert); + Handshake(); + CheckConnected(); +} + +class Tls13NoSupportedVersions : public TlsConnectStreamTls12 { + protected: + void Run(uint16_t overwritten_client_version, uint16_t max_server_version) { + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, + SSL_LIBRARY_VERSION_TLS_1_2); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, max_server_version); + client_->SetPacketFilter( + new TlsInspectorClientHelloVersionSetter(overwritten_client_version)); + auto capture = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerHello); + server_->SetPacketFilter(capture); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); + const DataBuffer& server_hello = capture->buffer(); + ASSERT_GT(server_hello.len(), 2U); + uint32_t ver; + ASSERT_TRUE(server_hello.Read(0, 2, &ver)); + ASSERT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), ver); + } +}; + +// If we offer a 1.3 ClientHello w/o supported_versions, the server should +// negotiate 1.2. +TEST_F(Tls13NoSupportedVersions, + Tls13ClientHelloWithoutSupportedVersionsServer12) { + Run(SSL_LIBRARY_VERSION_TLS_1_3, SSL_LIBRARY_VERSION_TLS_1_2); +} + +TEST_F(Tls13NoSupportedVersions, + Tls13ClientHelloWithoutSupportedVersionsServer13) { + Run(SSL_LIBRARY_VERSION_TLS_1_3, SSL_LIBRARY_VERSION_TLS_1_3); +} + +TEST_F(Tls13NoSupportedVersions, + Tls14ClientHelloWithoutSupportedVersionsServer13) { + Run(SSL_LIBRARY_VERSION_TLS_1_3 + 1, SSL_LIBRARY_VERSION_TLS_1_3); +} + +// Offer 1.3 but with ClientHello.legacy_version == TLS 1.4. This +// causes a bad MAC error when we read EncryptedExtensions. +TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) { + client_->SetPacketFilter(new TlsInspectorClientHelloVersionSetter( + SSL_LIBRARY_VERSION_TLS_1_3 + 1)); + auto capture = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerHello); + server_->SetPacketFilter(capture); + ConnectExpectFail(); + client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ); + const DataBuffer& server_hello = capture->buffer(); + ASSERT_GT(server_hello.len(), 2U); + uint32_t ver; + ASSERT_TRUE(server_hello.Read(0, 2, &ver)); + // This way we don't need to change with new draft version. + ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), ver); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc new file mode 100644 index 000000000..f3fd0b24c --- /dev/null +++ b/security/nss/gtests/ssl_gtest/test_io.cc @@ -0,0 +1,536 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "test_io.h" + +#include <algorithm> +#include <cassert> +#include <iostream> +#include <memory> + +#include "prerror.h" +#include "prlog.h" +#include "prthread.h" + +#include "databuffer.h" + +extern bool g_ssl_gtest_verbose; + +namespace nss_test { + +static PRDescIdentity test_fd_identity = PR_INVALID_IO_LAYER; + +#define UNIMPLEMENTED() \ + std::cerr << "Call to unimplemented function " << __FUNCTION__ << std::endl; \ + PR_ASSERT(PR_FALSE); \ + PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0) + +#define LOG(a) std::cerr << name_ << ": " << a << std::endl +#define LOGV(a) \ + do { \ + if (g_ssl_gtest_verbose) LOG(a); \ + } while (false) + +class Packet : public DataBuffer { + public: + Packet(const DataBuffer &buf) : DataBuffer(buf), offset_(0) {} + + void Advance(size_t delta) { + PR_ASSERT(offset_ + delta <= len()); + offset_ = std::min(len(), offset_ + delta); + } + + size_t offset() const { return offset_; } + size_t remaining() const { return len() - offset_; } + + private: + size_t offset_; +}; + +// Implementation of NSPR methods +static PRStatus DummyClose(PRFileDesc *f) { + DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); + f->secret = nullptr; + f->dtor(f); + delete io; + return PR_SUCCESS; +} + +static int32_t DummyRead(PRFileDesc *f, void *buf, int32_t length) { + DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); + return io->Read(buf, length); +} + +static int32_t DummyWrite(PRFileDesc *f, const void *buf, int32_t length) { + DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); + return io->Write(buf, length); +} + +static int32_t DummyAvailable(PRFileDesc *f) { + UNIMPLEMENTED(); + return -1; +} + +int64_t DummyAvailable64(PRFileDesc *f) { + UNIMPLEMENTED(); + return -1; +} + +static PRStatus DummySync(PRFileDesc *f) { + UNIMPLEMENTED(); + return PR_FAILURE; +} + +static int32_t DummySeek(PRFileDesc *f, int32_t offset, PRSeekWhence how) { + UNIMPLEMENTED(); + return -1; +} + +static int64_t DummySeek64(PRFileDesc *f, int64_t offset, PRSeekWhence how) { + UNIMPLEMENTED(); + return -1; +} + +static PRStatus DummyFileInfo(PRFileDesc *f, PRFileInfo *info) { + UNIMPLEMENTED(); + return PR_FAILURE; +} + +static PRStatus DummyFileInfo64(PRFileDesc *f, PRFileInfo64 *info) { + UNIMPLEMENTED(); + return PR_FAILURE; +} + +static int32_t DummyWritev(PRFileDesc *f, const PRIOVec *iov, int32_t iov_size, + PRIntervalTime to) { + UNIMPLEMENTED(); + return -1; +} + +static PRStatus DummyConnect(PRFileDesc *f, const PRNetAddr *addr, + PRIntervalTime to) { + UNIMPLEMENTED(); + return PR_FAILURE; +} + +static PRFileDesc *DummyAccept(PRFileDesc *sd, PRNetAddr *addr, + PRIntervalTime to) { + UNIMPLEMENTED(); + return nullptr; +} + +static PRStatus DummyBind(PRFileDesc *f, const PRNetAddr *addr) { + UNIMPLEMENTED(); + return PR_FAILURE; +} + +static PRStatus DummyListen(PRFileDesc *f, int32_t depth) { + UNIMPLEMENTED(); + return PR_FAILURE; +} + +static PRStatus DummyShutdown(PRFileDesc *f, int32_t how) { + DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); + io->Reset(); + return PR_SUCCESS; +} + +// This function does not support peek. +static int32_t DummyRecv(PRFileDesc *f, void *buf, int32_t buflen, + int32_t flags, PRIntervalTime to) { + PR_ASSERT(flags == 0); + if (flags != 0) { + PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); + return -1; + } + + DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); + + if (io->mode() == DGRAM) { + return io->Recv(buf, buflen); + } else { + return io->Read(buf, buflen); + } +} + +// Note: this is always nonblocking and assumes a zero timeout. +static int32_t DummySend(PRFileDesc *f, const void *buf, int32_t amount, + int32_t flags, PRIntervalTime to) { + int32_t written = DummyWrite(f, buf, amount); + return written; +} + +static int32_t DummyRecvfrom(PRFileDesc *f, void *buf, int32_t amount, + int32_t flags, PRNetAddr *addr, + PRIntervalTime to) { + UNIMPLEMENTED(); + return -1; +} + +static int32_t DummySendto(PRFileDesc *f, const void *buf, int32_t amount, + int32_t flags, const PRNetAddr *addr, + PRIntervalTime to) { + UNIMPLEMENTED(); + return -1; +} + +static int16_t DummyPoll(PRFileDesc *f, int16_t in_flags, int16_t *out_flags) { + UNIMPLEMENTED(); + return -1; +} + +static int32_t DummyAcceptRead(PRFileDesc *sd, PRFileDesc **nd, + PRNetAddr **raddr, void *buf, int32_t amount, + PRIntervalTime t) { + UNIMPLEMENTED(); + return -1; +} + +static int32_t DummyTransmitFile(PRFileDesc *sd, PRFileDesc *f, + const void *headers, int32_t hlen, + PRTransmitFileFlags flags, PRIntervalTime t) { + UNIMPLEMENTED(); + return -1; +} + +static PRStatus DummyGetpeername(PRFileDesc *f, PRNetAddr *addr) { + // TODO: Modify to return unique names for each channel + // somehow, as opposed to always the same static address. The current + // implementation messes up the session cache, which is why it's off + // elsewhere + addr->inet.family = PR_AF_INET; + addr->inet.port = 0; + addr->inet.ip = 0; + + return PR_SUCCESS; +} + +static PRStatus DummyGetsockname(PRFileDesc *f, PRNetAddr *addr) { + UNIMPLEMENTED(); + return PR_FAILURE; +} + +static PRStatus DummyGetsockoption(PRFileDesc *f, PRSocketOptionData *opt) { + switch (opt->option) { + case PR_SockOpt_Nonblocking: + opt->value.non_blocking = PR_TRUE; + return PR_SUCCESS; + default: + UNIMPLEMENTED(); + break; + } + + return PR_FAILURE; +} + +// Imitate setting socket options. These are mostly noops. +static PRStatus DummySetsockoption(PRFileDesc *f, + const PRSocketOptionData *opt) { + switch (opt->option) { + case PR_SockOpt_Nonblocking: + return PR_SUCCESS; + case PR_SockOpt_NoDelay: + return PR_SUCCESS; + default: + UNIMPLEMENTED(); + break; + } + + return PR_FAILURE; +} + +static int32_t DummySendfile(PRFileDesc *out, PRSendFileData *in, + PRTransmitFileFlags flags, PRIntervalTime to) { + UNIMPLEMENTED(); + return -1; +} + +static PRStatus DummyConnectContinue(PRFileDesc *f, int16_t flags) { + UNIMPLEMENTED(); + return PR_FAILURE; +} + +static int32_t DummyReserved(PRFileDesc *f) { + UNIMPLEMENTED(); + return -1; +} + +DummyPrSocket::~DummyPrSocket() { Reset(); } + +void DummyPrSocket::SetPacketFilter(PacketFilter *filter) { + if (filter_) { + delete filter_; + } + filter_ = filter; +} + +void DummyPrSocket::Reset() { + delete filter_; + if (peer_) { + peer_->SetPeer(nullptr); + peer_ = nullptr; + } + while (!input_.empty()) { + Packet *front = input_.front(); + input_.pop(); + delete front; + } +} + +static const struct PRIOMethods DummyMethods = { + PR_DESC_LAYERED, DummyClose, + DummyRead, DummyWrite, + DummyAvailable, DummyAvailable64, + DummySync, DummySeek, + DummySeek64, DummyFileInfo, + DummyFileInfo64, DummyWritev, + DummyConnect, DummyAccept, + DummyBind, DummyListen, + DummyShutdown, DummyRecv, + DummySend, DummyRecvfrom, + DummySendto, DummyPoll, + DummyAcceptRead, DummyTransmitFile, + DummyGetsockname, DummyGetpeername, + DummyReserved, DummyReserved, + DummyGetsockoption, DummySetsockoption, + DummySendfile, DummyConnectContinue, + DummyReserved, DummyReserved, + DummyReserved, DummyReserved}; + +PRFileDesc *DummyPrSocket::CreateFD(const std::string &name, Mode mode) { + if (test_fd_identity == PR_INVALID_IO_LAYER) { + test_fd_identity = PR_GetUniqueIdentity("testtransportadapter"); + } + + PRFileDesc *fd = (PR_CreateIOLayerStub(test_fd_identity, &DummyMethods)); + fd->secret = reinterpret_cast<PRFilePrivate *>(new DummyPrSocket(name, mode)); + + return fd; +} + +DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) { + return reinterpret_cast<DummyPrSocket *>(fd->secret); +} + +void DummyPrSocket::PacketReceived(const DataBuffer &packet) { + input_.push(new Packet(packet)); +} + +int32_t DummyPrSocket::Read(void *data, int32_t len) { + PR_ASSERT(mode_ == STREAM); + + if (mode_ != STREAM) { + PR_SetError(PR_INVALID_METHOD_ERROR, 0); + return -1; + } + + if (input_.empty()) { + LOGV("Read --> wouldblock " << len); + PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + return -1; + } + + Packet *front = input_.front(); + size_t to_read = + std::min(static_cast<size_t>(len), front->len() - front->offset()); + memcpy(data, static_cast<const void *>(front->data() + front->offset()), + to_read); + front->Advance(to_read); + + if (!front->remaining()) { + input_.pop(); + delete front; + } + + return static_cast<int32_t>(to_read); +} + +int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) { + if (input_.empty()) { + PR_SetError(PR_WOULD_BLOCK_ERROR, 0); + return -1; + } + + Packet *front = input_.front(); + if (static_cast<size_t>(buflen) < front->len()) { + PR_ASSERT(false); + PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0); + return -1; + } + + size_t count = front->len(); + memcpy(buf, front->data(), count); + + input_.pop(); + delete front; + + return static_cast<int32_t>(count); +} + +int32_t DummyPrSocket::Write(const void *buf, int32_t length) { + if (!peer_ || !writeable_) { + PR_SetError(PR_IO_ERROR, 0); + return -1; + } + + DataBuffer packet(static_cast<const uint8_t *>(buf), + static_cast<size_t>(length)); + DataBuffer filtered; + PacketFilter::Action action = PacketFilter::KEEP; + if (filter_) { + action = filter_->Filter(packet, &filtered); + } + switch (action) { + case PacketFilter::CHANGE: + LOG("Original packet: " << packet); + LOG("Filtered packet: " << filtered); + peer_->PacketReceived(filtered); + break; + case PacketFilter::DROP: + LOG("Droppped packet: " << packet); + break; + case PacketFilter::KEEP: + LOGV("Packet: " << packet); + peer_->PacketReceived(packet); + break; + } + // libssl can't handle it if this reports something other than the length + // of what was passed in (or less, but we're not doing partial writes). + return static_cast<int32_t>(packet.len()); +} + +Poller *Poller::instance; + +Poller *Poller::Instance() { + if (!instance) instance = new Poller(); + + return instance; +} + +void Poller::Shutdown() { + delete instance; + instance = nullptr; +} + +Poller::~Poller() { + while (!timers_.empty()) { + Timer *timer = timers_.top(); + timers_.pop(); + delete timer; + } +} + +void Poller::Wait(Event event, DummyPrSocket *adapter, PollTarget *target, + PollCallback cb) { + auto it = waiters_.find(adapter); + Waiter *waiter; + + if (it == waiters_.end()) { + waiter = new Waiter(adapter); + } else { + waiter = it->second; + } + + assert(event < TIMER_EVENT); + if (event >= TIMER_EVENT) return; + + waiter->targets_[event] = target; + waiter->callbacks_[event] = cb; + waiters_[adapter] = waiter; +} + +void Poller::Cancel(Event event, DummyPrSocket *adapter) { + auto it = waiters_.find(adapter); + Waiter *waiter; + + if (it == waiters_.end()) { + return; + } + + waiter = it->second; + + waiter->targets_[event] = nullptr; + waiter->callbacks_[event] = nullptr; + + // Clean up if there are no callbacks. + for (size_t i = 0; i < TIMER_EVENT; ++i) { + if (waiter->callbacks_[i]) return; + } + + delete waiter; + waiters_.erase(adapter); +} + +void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb, + Timer **timer) { + Timer *t = new Timer(PR_Now() + timer_ms * 1000, target, cb); + timers_.push(t); + if (timer) *timer = t; +} + +bool Poller::Poll() { + if (g_ssl_gtest_verbose) { + std::cerr << "Poll() waiters = " << waiters_.size() + << " timers = " << timers_.size() << std::endl; + } + PRIntervalTime timeout = PR_INTERVAL_NO_TIMEOUT; + PRTime now = PR_Now(); + bool fired = false; + + // Figure out the timer for the select. + if (!timers_.empty()) { + Timer *first_timer = timers_.top(); + if (now >= first_timer->deadline_) { + // Timer expired. + timeout = PR_INTERVAL_NO_WAIT; + } else { + timeout = + PR_MillisecondsToInterval((first_timer->deadline_ - now) / 1000); + } + } + + for (auto it = waiters_.begin(); it != waiters_.end(); ++it) { + Waiter *waiter = it->second; + + if (waiter->callbacks_[READABLE_EVENT]) { + if (waiter->io_->readable()) { + PollCallback callback = waiter->callbacks_[READABLE_EVENT]; + PollTarget *target = waiter->targets_[READABLE_EVENT]; + waiter->callbacks_[READABLE_EVENT] = nullptr; + waiter->targets_[READABLE_EVENT] = nullptr; + callback(target, READABLE_EVENT); + fired = true; + } + } + } + + if (fired) timeout = PR_INTERVAL_NO_WAIT; + + // Can't wait forever and also have nothing readable now. + if (timeout == PR_INTERVAL_NO_TIMEOUT) return false; + + // Sleep. + if (timeout != PR_INTERVAL_NO_WAIT) { + PR_Sleep(timeout); + } + + // Now process anything that timed out. + now = PR_Now(); + while (!timers_.empty()) { + if (now < timers_.top()->deadline_) break; + + Timer *timer = timers_.top(); + timers_.pop(); + if (timer->callback_) { + timer->callback_(timer->target_, TIMER_EVENT); + } + delete timer; + } + + return true; +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h new file mode 100644 index 000000000..b78db0dc6 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/test_io.h @@ -0,0 +1,152 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef test_io_h_ +#define test_io_h_ + +#include <string.h> +#include <map> +#include <memory> +#include <ostream> +#include <queue> +#include <string> + +#include "prio.h" + +namespace nss_test { + +class DataBuffer; +class Packet; +class DummyPrSocket; // Fwd decl. + +// Allow us to inspect a packet before it is written. +class PacketFilter { + public: + enum Action { + KEEP, // keep the original packet unmodified + CHANGE, // change the packet to a different value + DROP // drop the packet + }; + + virtual ~PacketFilter() {} + + // The packet filter takes input and has the option of mutating it. + // + // A filter that modifies the data places the modified data in *output and + // returns CHANGE. A filter that does not modify data returns LEAVE, in which + // case the value in *output is ignored. A Filter can return DROP, in which + // case the packet is dropped (and *output is ignored). + virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0; +}; + +enum Mode { STREAM, DGRAM }; + +inline std::ostream& operator<<(std::ostream& os, Mode m) { + return os << ((m == STREAM) ? "TLS" : "DTLS"); +} + +class DummyPrSocket { + public: + ~DummyPrSocket(); + + static PRFileDesc* CreateFD(const std::string& name, + Mode mode); // Returns an FD. + static DummyPrSocket* GetAdapter(PRFileDesc* fd); + + DummyPrSocket* peer() const { return peer_; } + void SetPeer(DummyPrSocket* peer) { peer_ = peer; } + void SetPacketFilter(PacketFilter* filter); + // Drops peer, packet filter and any outstanding packets. + void Reset(); + + void PacketReceived(const DataBuffer& data); + int32_t Read(void* data, int32_t len); + int32_t Recv(void* buf, int32_t buflen); + int32_t Write(const void* buf, int32_t length); + void CloseWrites() { writeable_ = false; } + + Mode mode() const { return mode_; } + bool readable() const { return !input_.empty(); } + + private: + DummyPrSocket(const std::string& name, Mode mode) + : name_(name), + mode_(mode), + peer_(nullptr), + input_(), + filter_(nullptr), + writeable_(true) {} + + const std::string name_; + Mode mode_; + DummyPrSocket* peer_; + std::queue<Packet*> input_; + PacketFilter* filter_; + bool writeable_; +}; + +// Marker interface. +class PollTarget {}; + +enum Event { READABLE_EVENT, TIMER_EVENT /* Must be last */ }; + +typedef void (*PollCallback)(PollTarget*, Event); + +class Poller { + public: + static Poller* Instance(); // Get a singleton. + static void Shutdown(); // Shut it down. + + class Timer { + public: + Timer(PRTime deadline, PollTarget* target, PollCallback callback) + : deadline_(deadline), target_(target), callback_(callback) {} + void Cancel() { callback_ = nullptr; } + + PRTime deadline_; + PollTarget* target_; + PollCallback callback_; + }; + + void Wait(Event event, DummyPrSocket* adapter, PollTarget* target, + PollCallback cb); + void Cancel(Event event, DummyPrSocket* adapter); + void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb, + Timer** handle); + bool Poll(); + + private: + Poller() : waiters_(), timers_() {} + ~Poller(); + + class Waiter { + public: + Waiter(DummyPrSocket* io) : io_(io) { + memset(&callbacks_[0], 0, sizeof(callbacks_)); + } + + void WaitFor(Event event, PollCallback callback); + + DummyPrSocket* io_; + PollTarget* targets_[TIMER_EVENT]; + PollCallback callbacks_[TIMER_EVENT]; + }; + + class TimerComparator { + public: + bool operator()(const Timer* lhs, const Timer* rhs) { + return lhs->deadline_ > rhs->deadline_; + } + }; + + static Poller* instance; + std::map<DummyPrSocket*, Waiter*> waiters_; + std::priority_queue<Timer*, std::vector<Timer*>, TimerComparator> timers_; +}; + +} // end of namespace + +#endif diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc new file mode 100644 index 000000000..b75bba567 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_agent.cc @@ -0,0 +1,992 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_agent.h" +#include "databuffer.h" +#include "keyhi.h" +#include "pk11func.h" +#include "ssl.h" +#include "sslerr.h" +#include "sslproto.h" +#include "tls_parser.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" +#include "gtest_utils.h" +#include "scoped_ptrs.h" + +extern std::string g_working_dir_path; + +namespace nss_test { + +const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"}; + +const std::string TlsAgent::kClient = "client"; // both sign and encrypt +const std::string TlsAgent::kRsa2048 = "rsa2048"; // bigger +const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt +const std::string TlsAgent::kServerRsaSign = "rsa_sign"; +const std::string TlsAgent::kServerRsaPss = "rsa_pss"; +const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt"; +const std::string TlsAgent::kServerRsaChain = "rsa_chain"; +const std::string TlsAgent::kServerEcdsa256 = "ecdsa256"; +const std::string TlsAgent::kServerEcdsa384 = "ecdsa384"; +const std::string TlsAgent::kServerEcdsa521 = "ecdsa521"; +const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa"; +const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa"; +const std::string TlsAgent::kServerDsa = "dsa"; + +TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode) + : name_(name), + mode_(mode), + server_key_bits_(0), + pr_fd_(nullptr), + adapter_(nullptr), + ssl_fd_(nullptr), + role_(role), + state_(STATE_INIT), + timer_handle_(nullptr), + falsestart_enabled_(false), + expected_version_(0), + expected_cipher_suite_(0), + expect_resumption_(false), + expect_client_auth_(false), + can_falsestart_hook_called_(false), + sni_hook_called_(false), + auth_certificate_hook_called_(false), + handshake_callback_called_(false), + error_code_(0), + send_ctr_(0), + recv_ctr_(0), + expect_readwrite_error_(false), + handshake_callback_(), + auth_certificate_callback_(), + sni_callback_(), + expect_short_headers_(false) { + memset(&info_, 0, sizeof(info_)); + memset(&csinfo_, 0, sizeof(csinfo_)); + SECStatus rv = SSL_VersionRangeGetDefault( + mode_ == STREAM ? ssl_variant_stream : ssl_variant_datagram, &vrange_); + EXPECT_EQ(SECSuccess, rv); +} + +TlsAgent::~TlsAgent() { + if (adapter_) { + Poller::Instance()->Cancel(READABLE_EVENT, adapter_); + // The adapter is closed when the FD closes. + } + if (timer_handle_) { + timer_handle_->Cancel(); + } + + if (pr_fd_) { + PR_Close(pr_fd_); + } + + if (ssl_fd_) { + PR_Close(ssl_fd_); + } +} + +void TlsAgent::SetState(State state) { + if (state_ == state) return; + + LOG("Changing state from " << state_ << " to " << state); + state_ = state; +} + +bool TlsAgent::ConfigServerCert(const std::string& name, bool updateKeyBits, + const SSLExtraServerCertData* serverCertData) { + ScopedCERTCertificate cert(PK11_FindCertFromNickname(name.c_str(), nullptr)); + EXPECT_NE(nullptr, cert.get()); + if (!cert.get()) return false; + + ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get())); + EXPECT_NE(nullptr, pub.get()); + if (!pub.get()) return false; + if (updateKeyBits) { + server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get()); + } + + ScopedSECKEYPrivateKey priv(PK11_FindKeyByAnyCert(cert.get(), nullptr)); + EXPECT_NE(nullptr, priv.get()); + if (!priv.get()) return false; + + SECStatus rv = + SSL_ConfigSecureServer(ssl_fd_, nullptr, nullptr, ssl_kea_null); + EXPECT_EQ(SECFailure, rv); + rv = SSL_ConfigServerCert(ssl_fd_, cert.get(), priv.get(), serverCertData, + serverCertData ? sizeof(*serverCertData) : 0); + return rv == SECSuccess; +} + +bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { + // Don't set up twice + if (ssl_fd_) return true; + + if (adapter_->mode() == STREAM) { + ssl_fd_ = SSL_ImportFD(modelSocket, pr_fd_); + } else { + ssl_fd_ = DTLS_ImportFD(modelSocket, pr_fd_); + } + + EXPECT_NE(nullptr, ssl_fd_); + if (!ssl_fd_) return false; + pr_fd_ = nullptr; + + SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + if (role_ == SERVER) { + EXPECT_TRUE(ConfigServerCert(name_, true)); + + rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + ScopedCERTCertList anchors(CERT_NewCertList()); + rv = SSL_SetTrustAnchors(ssl_fd_, anchors.get()); + if (rv != SECSuccess) return false; + } else { + rv = SSL_SetURL(ssl_fd_, "server"); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + } + + rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + return true; +} + +void TlsAgent::SetupClientAuth() { + EXPECT_TRUE(EnsureTlsSetup()); + ASSERT_EQ(CLIENT, role_); + + EXPECT_EQ(SECSuccess, + SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook, + reinterpret_cast<void*>(this))); +} + +bool TlsAgent::GetClientAuthCredentials(CERTCertificate** cert, + SECKEYPrivateKey** priv) const { + *cert = PK11_FindCertFromNickname(name_.c_str(), nullptr); + EXPECT_NE(nullptr, *cert); + if (!*cert) return false; + + *priv = PK11_FindKeyByAnyCert(*cert, nullptr); + EXPECT_NE(nullptr, *priv); + if (!*priv) return false; // Leak cert. + + return true; +} + +SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, + CERTDistNames* caNames, + CERTCertificate** cert, + SECKEYPrivateKey** privKey) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(self); + ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd())); + EXPECT_TRUE(peerCert) << "Client should be able to see the server cert"; + if (agent->GetClientAuthCredentials(cert, privKey)) { + return SECSuccess; + } + return SECFailure; +} + +bool TlsAgent::GetPeerChainLength(size_t* count) { + CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd_); + if (!chain) return false; + *count = 0; + + for (PRCList* cursor = PR_NEXT_LINK(&chain->list); cursor != &chain->list; + cursor = PR_NEXT_LINK(cursor)) { + CERTCertListNode* node = (CERTCertListNode*)cursor; + std::cerr << node->cert->subjectName << std::endl; + ++(*count); + } + + CERT_DestroyCertList(chain); + + return true; +} + +void TlsAgent::RequestClientAuth(bool requireAuth) { + EXPECT_TRUE(EnsureTlsSetup()); + ASSERT_EQ(SERVER, role_); + + EXPECT_EQ(SECSuccess, + SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE, + requireAuth ? PR_TRUE : PR_FALSE)); + + EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook( + ssl_fd_, &TlsAgent::ClientAuthenticated, this)); + expect_client_auth_ = true; +} + +void TlsAgent::StartConnect(PRFileDesc* model) { + EXPECT_TRUE(EnsureTlsSetup(model)); + + SECStatus rv; + rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE); + EXPECT_EQ(SECSuccess, rv); + SetState(STATE_CONNECTING); +} + +void TlsAgent::DisableAllCiphers() { + for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { + SECStatus rv = + SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_FALSE); + EXPECT_EQ(SECSuccess, rv); + } +} + +// Not actually all groups, just the onece that we are actually willing +// to use. +const std::vector<SSLNamedGroup> kAllDHEGroups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, + ssl_grp_ffdhe_4096, ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192}; + +const std::vector<SSLNamedGroup> kECDHEGroups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, + ssl_grp_ec_secp521r1}; + +const std::vector<SSLNamedGroup> kFFDHEGroups = { + ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096, + ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192}; + +// Defined because the big DHE groups are ridiculously slow. +const std::vector<SSLNamedGroup> kFasterDHEGroups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, + ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072}; + +void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) { + EXPECT_TRUE(EnsureTlsSetup()); + + for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { + SSLCipherSuiteInfo csinfo; + + SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, + sizeof(csinfo)); + ASSERT_EQ(SECSuccess, rv); + EXPECT_EQ(sizeof(csinfo), csinfo.length); + + if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) { + rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE); + EXPECT_EQ(SECSuccess, rv); + } + } +} + +void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) { + switch (kea) { + case ssl_kea_dh: + ConfigNamedGroups(kFFDHEGroups); + break; + case ssl_kea_ecdh: + ConfigNamedGroups(kECDHEGroups); + break; + default: + break; + } +} + +void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) { + if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa || + authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) { + ConfigNamedGroups(kECDHEGroups); + } +} + +void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) { + EXPECT_TRUE(EnsureTlsSetup()); + + for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { + SSLCipherSuiteInfo csinfo; + + SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, + sizeof(csinfo)); + ASSERT_EQ(SECSuccess, rv); + + if ((csinfo.authType == authType) || + (csinfo.keaType == ssl_kea_tls13_any)) { + rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE); + EXPECT_EQ(SECSuccess, rv); + } + } +} + +void TlsAgent::EnableSingleCipher(uint16_t cipher) { + DisableAllCiphers(); + SECStatus rv = SSL_CipherPrefSet(ssl_fd_, cipher, PR_TRUE); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) { + EXPECT_TRUE(EnsureTlsSetup()); + SECStatus rv = SSL_NamedGroupConfig(ssl_fd_, &groups[0], groups.size()); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::SetSessionTicketsEnabled(bool en) { + EXPECT_TRUE(EnsureTlsSetup()); + + SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS, + en ? PR_TRUE : PR_FALSE); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::SetSessionCacheEnabled(bool en) { + EXPECT_TRUE(EnsureTlsSetup()); + + SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::Set0RttEnabled(bool en) { + EXPECT_TRUE(EnsureTlsSetup()); + + SECStatus rv = + SSL_OptionSet(ssl_fd_, SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::SetShortHeadersEnabled() { + EXPECT_TRUE(EnsureTlsSetup()); + + SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd_); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { + vrange_.min = minver; + vrange_.max = maxver; + + if (ssl_fd_) { + SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_); + EXPECT_EQ(SECSuccess, rv); + } +} + +void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) { + *minver = vrange_.min; + *maxver = vrange_.max; +} + +void TlsAgent::SetExpectedVersion(uint16_t version) { + expected_version_ = version; +} + +void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; } + +void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; } + +void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; } + +void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes, + size_t count) { + EXPECT_TRUE(EnsureTlsSetup()); + EXPECT_LE(count, SSL_SignatureMaxCount()); + EXPECT_EQ(SECSuccess, + SSL_SignatureSchemePrefSet(ssl_fd_, schemes, + static_cast<unsigned int>(count))); + EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd_, schemes, 0)) + << "setting no schemes should fail and do nothing"; + + std::vector<SSLSignatureScheme> configuredSchemes(count); + unsigned int configuredCount; + EXPECT_EQ(SECFailure, + SSL_SignatureSchemePrefGet(ssl_fd_, nullptr, &configuredCount, 1)) + << "get schemes, schemes is nullptr"; + EXPECT_EQ(SECFailure, + SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0], + &configuredCount, 0)) + << "get schemes, too little space"; + EXPECT_EQ(SECFailure, + SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0], nullptr, + configuredSchemes.size())) + << "get schemes, countOut is nullptr"; + + EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet( + ssl_fd_, &configuredSchemes[0], &configuredCount, + configuredSchemes.size())); + // SignatureSchemePrefSet drops unsupported algorithms silently, so the + // number that are configured might be fewer. + EXPECT_LE(configuredCount, count); + unsigned int i = 0; + for (unsigned int j = 0; j < count && i < configuredCount; ++j) { + if (i < configuredCount && schemes[j] == configuredSchemes[i]) { + ++i; + } + } + EXPECT_EQ(i, configuredCount) << "schemes in use were all set"; +} + +void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group, + size_t kea_size) const { + EXPECT_EQ(STATE_CONNECTED, state_); + EXPECT_EQ(kea_type, info_.keaType); + if (kea_size == 0) { + switch (kea_group) { + case ssl_grp_ec_curve25519: + kea_size = 255; + break; + case ssl_grp_ec_secp256r1: + kea_size = 256; + break; + case ssl_grp_ec_secp384r1: + kea_size = 384; + break; + case ssl_grp_ffdhe_2048: + kea_size = 2048; + break; + case ssl_grp_ffdhe_3072: + kea_size = 3072; + break; + case ssl_grp_ffdhe_custom: + break; + default: + if (kea_type == ssl_kea_rsa) { + kea_size = server_key_bits_; + } else { + EXPECT_TRUE(false) << "need to update group sizes"; + } + } + } + if (kea_group != ssl_grp_ffdhe_custom) { + EXPECT_EQ(kea_size, info_.keaKeyBits); + EXPECT_EQ(kea_group, info_.keaGroup); + } +} + +void TlsAgent::CheckAuthType(SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) const { + EXPECT_EQ(STATE_CONNECTED, state_); + EXPECT_EQ(auth_type, info_.authType); + EXPECT_EQ(server_key_bits_, info_.authKeyBits); + if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) { + switch (auth_type) { + case ssl_auth_rsa_sign: + sig_scheme = ssl_sig_rsa_pkcs1_sha1md5; + break; + case ssl_auth_ecdsa: + sig_scheme = ssl_sig_ecdsa_sha1; + break; + default: + break; + } + } + EXPECT_EQ(sig_scheme, info_.signatureScheme); + + if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) { + return; + } + + // Check authAlgorithm, which is the old value for authType. This is a second + // switch + // statement because default label is different. + switch (auth_type) { + case ssl_auth_rsa_sign: + EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) + << "authAlgorithm for RSA is always decrypt"; + break; + case ssl_auth_ecdh_rsa: + EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) + << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)"; + break; + case ssl_auth_ecdh_ecdsa: + EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm) + << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)"; + break; + default: + EXPECT_EQ(auth_type, csinfo_.authAlgorithm) + << "authAlgorithm is (usually) the same as authType"; + break; + } +} + +void TlsAgent::EnableFalseStart() { + EXPECT_TRUE(EnsureTlsSetup()); + + falsestart_enabled_ = true; + EXPECT_EQ(SECSuccess, + SSL_SetCanFalseStartCallback(ssl_fd_, CanFalseStartCallback, this)); + EXPECT_EQ(SECSuccess, + SSL_OptionSet(ssl_fd_, SSL_ENABLE_FALSE_START, PR_TRUE)); +} + +void TlsAgent::ExpectResumption() { expect_resumption_ = true; } + +void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { + EXPECT_TRUE(EnsureTlsSetup()); + + EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len)); +} + +void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, + const std::string& expected) const { + SSLNextProtoState state; + char chosen[10]; + unsigned int chosen_len; + SECStatus rv = SSL_GetNextProto(ssl_fd_, &state, + reinterpret_cast<unsigned char*>(chosen), + &chosen_len, sizeof(chosen)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(expected_state, state); + if (state == SSL_NEXT_PROTO_NO_SUPPORT) { + EXPECT_EQ("", expected); + } else { + EXPECT_NE("", expected); + EXPECT_EQ(expected, std::string(chosen, chosen_len)); + } +} + +void TlsAgent::EnableSrtp() { + EXPECT_TRUE(EnsureTlsSetup()); + const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80, + SRTP_AES128_CM_HMAC_SHA1_32}; + EXPECT_EQ(SECSuccess, + SSL_SetSRTPCiphers(ssl_fd_, ciphers, PR_ARRAY_SIZE(ciphers))); +} + +void TlsAgent::CheckSrtp() const { + uint16_t actual; + EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual)); + EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual); +} + +void TlsAgent::CheckErrorCode(int32_t expected) const { + EXPECT_EQ(STATE_ERROR, state_); + EXPECT_EQ(expected, error_code_) + << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " + << PORT_ErrorToName(expected) << std::endl; +} + +void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { + ASSERT_EQ(0, error_code_); + WAIT_(error_code_ != 0, delay); + EXPECT_EQ(expected, error_code_) + << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " + << PORT_ErrorToName(expected) << std::endl; +} + +void TlsAgent::CheckPreliminaryInfo() { + SSLPreliminaryChannelInfo info; + EXPECT_EQ(SECSuccess, + SSL_GetPreliminaryChannelInfo(ssl_fd_, &info, sizeof(info))); + EXPECT_EQ(sizeof(info), info.length); + EXPECT_TRUE(info.valuesSet & ssl_preinfo_version); + EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite); + + // A version of 0 is invalid and indicates no expectation. This value is + // initialized to 0 so that tests that don't explicitly set an expected + // version can negotiate a version. + if (!expected_version_) { + expected_version_ = info.protocolVersion; + } + EXPECT_EQ(expected_version_, info.protocolVersion); + + // As with the version; 0 is the null cipher suite (and also invalid). + if (!expected_cipher_suite_) { + expected_cipher_suite_ = info.cipherSuite; + } + EXPECT_EQ(expected_cipher_suite_, info.cipherSuite); +} + +// Check that all the expected callbacks have been called. +void TlsAgent::CheckCallbacks() const { + // If false start happens, the handshake is reported as being complete at the + // point that false start happens. + if (expect_resumption_ || !falsestart_enabled_) { + EXPECT_TRUE(handshake_callback_called_); + } + + // These callbacks shouldn't fire if we are resuming, except on TLS 1.3. + if (role_ == SERVER) { + PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd_, ssl_server_name_xtn); + EXPECT_EQ(((!expect_resumption_ && have_sni) || + expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3), + sni_hook_called_); + } else { + EXPECT_EQ(!expect_resumption_, auth_certificate_hook_called_); + // Note that this isn't unconditionally called, even with false start on. + // But the callback is only skipped if a cipher that is ridiculously weak + // (80 bits) is chosen. Don't test that: plan to remove bad ciphers. + EXPECT_EQ(falsestart_enabled_ && !expect_resumption_, + can_falsestart_hook_called_); + } +} + +void TlsAgent::ResetPreliminaryInfo() { + expected_version_ = 0; + expected_cipher_suite_ = 0; +} + +void TlsAgent::Connected() { + LOG("Handshake success"); + CheckPreliminaryInfo(); + CheckCallbacks(); + + SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(sizeof(info_), info_.length); + + // Preliminary values are exposed through callbacks during the handshake. + // If either expected values were set or the callbacks were called, check + // that the final values are correct. + EXPECT_EQ(expected_version_, info_.protocolVersion); + EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite); + + rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_)); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ(sizeof(csinfo_), csinfo_.length); + + if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd_); + // We use one ciphersuite in each direction, plus one that's kept around + // by DTLS for retransmission. + PRInt32 expected = ((mode_ == DGRAM) && (role_ == CLIENT)) ? 3 : 2; + EXPECT_EQ(expected, cipherSuites); + if (expected != cipherSuites) { + SSLInt_PrintTls13CipherSpecs(ssl_fd_); + } + } + + PRBool short_headers; + rv = SSLInt_UsingShortHeaders(ssl_fd_, &short_headers); + EXPECT_EQ(SECSuccess, rv); + EXPECT_EQ((PRBool)expect_short_headers_, short_headers); + SetState(STATE_CONNECTED); +} + +void TlsAgent::EnableExtendedMasterSecret() { + ASSERT_TRUE(EnsureTlsSetup()); + + SECStatus rv = + SSL_OptionSet(ssl_fd_, SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); + + ASSERT_EQ(SECSuccess, rv); +} + +void TlsAgent::CheckExtendedMasterSecret(bool expected) { + if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) { + expected = PR_TRUE; + } + ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE) + << "unexpected extended master secret state for " << name_; +} + +void TlsAgent::CheckEarlyDataAccepted(bool expected) { + if (version() < SSL_LIBRARY_VERSION_TLS_1_3) { + expected = false; + } + ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE) + << "unexpected early data state for " << name_; +} + +void TlsAgent::CheckSecretsDestroyed() { + ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd_)); +} + +void TlsAgent::DisableRollbackDetection() { + ASSERT_TRUE(EnsureTlsSetup()); + + SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ROLLBACK_DETECTION, PR_FALSE); + + ASSERT_EQ(SECSuccess, rv); +} + +void TlsAgent::EnableCompression() { + ASSERT_TRUE(EnsureTlsSetup()); + + SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_DEFLATE, PR_TRUE); + ASSERT_EQ(SECSuccess, rv); +} + +void TlsAgent::SetDowngradeCheckVersion(uint16_t version) { + ASSERT_TRUE(EnsureTlsSetup()); + + SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd_, version); + ASSERT_EQ(SECSuccess, rv); +} + +void TlsAgent::Handshake() { + LOGV("Handshake"); + SECStatus rv = SSL_ForceHandshake(ssl_fd_); + if (rv == SECSuccess) { + Connected(); + + Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, + &TlsAgent::ReadableCallback); + return; + } + + int32_t err = PR_GetError(); + if (err == PR_WOULD_BLOCK_ERROR) { + LOGV("Would have blocked"); + if (mode_ == DGRAM) { + if (timer_handle_) { + timer_handle_->Cancel(); + timer_handle_ = nullptr; + } + + PRIntervalTime timeout; + rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout); + if (rv == SECSuccess) { + Poller::Instance()->SetTimer( + timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_); + } + } + Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, + &TlsAgent::ReadableCallback); + return; + } + + LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": " + << PORT_ErrorToString(err)); + error_code_ = err; + SetState(STATE_ERROR); +} + +void TlsAgent::PrepareForRenegotiate() { + EXPECT_EQ(STATE_CONNECTED, state_); + + SetState(STATE_CONNECTING); +} + +void TlsAgent::StartRenegotiate() { + PrepareForRenegotiate(); + + SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::SendDirect(const DataBuffer& buf) { + LOG("Send Direct " << buf); + adapter_->peer()->PacketReceived(buf); +} + +static bool ErrorIsNonFatal(PRErrorCode code) { + return code == PR_WOULD_BLOCK_ERROR || code == SSL_ERROR_RX_SHORT_DTLS_READ; +} + +void TlsAgent::SendData(size_t bytes, size_t blocksize) { + uint8_t block[4096]; + + ASSERT_LT(blocksize, sizeof(block)); + + while (bytes) { + size_t tosend = std::min(blocksize, bytes); + + for (size_t i = 0; i < tosend; ++i) { + block[i] = 0xff & send_ctr_; + ++send_ctr_; + } + + SendBuffer(DataBuffer(block, tosend)); + bytes -= tosend; + } +} + +void TlsAgent::SendBuffer(const DataBuffer& buf) { + LOGV("Writing " << buf.len() << " bytes"); + int32_t rv = PR_Write(ssl_fd_, buf.data(), buf.len()); + if (expect_readwrite_error_) { + EXPECT_GT(0, rv); + EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_); + error_code_ = PR_GetError(); + expect_readwrite_error_ = false; + } else { + ASSERT_EQ(buf.len(), static_cast<size_t>(rv)); + } +} + +void TlsAgent::ReadBytes() { + uint8_t block[1024]; + + int32_t rv = PR_Read(ssl_fd_, block, sizeof(block)); + LOGV("ReadBytes " << rv); + int32_t err; + + if (rv >= 0) { + size_t count = static_cast<size_t>(rv); + for (size_t i = 0; i < count; ++i) { + ASSERT_EQ(recv_ctr_ & 0xff, block[i]); + recv_ctr_++; + } + } else { + err = PR_GetError(); + LOG("Read error " << PORT_ErrorToName(err) << ": " + << PORT_ErrorToString(err)); + if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) { + error_code_ = err; + expect_readwrite_error_ = false; + } + } + + // If closed, then don't bother waiting around. + if (rv > 0 || (rv < 0 && ErrorIsNonFatal(err))) { + LOGV("Re-arming"); + Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, + &TlsAgent::ReadableCallback); + } +} + +void TlsAgent::ResetSentBytes() { send_ctr_ = 0; } + +void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { + EXPECT_TRUE(EnsureTlsSetup()); + + SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, + mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); + EXPECT_EQ(SECSuccess, rv); + + rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS, + mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::DisableECDHEServerKeyReuse() { + ASSERT_EQ(TlsAgent::SERVER, role_); + SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); + EXPECT_EQ(SECSuccess, rv); +} + +static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"}; +::testing::internal::ParamGenerator<std::string> + TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr); + +void TlsAgentTestBase::SetUp() { + SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); +} + +void TlsAgentTestBase::TearDown() { + delete agent_; + SSL_ClearSessionCache(); + SSL_ShutdownServerSessionIDCache(); +} + +void TlsAgentTestBase::Reset(const std::string& server_name) { + delete agent_; + Init(server_name); +} + +void TlsAgentTestBase::Init(const std::string& server_name) { + agent_ = + new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name, + role_, mode_); + agent_->Init(); + fd_ = DummyPrSocket::CreateFD(agent_->role_str(), mode_); + agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_)); + agent_->StartConnect(); +} + +void TlsAgentTestBase::EnsureInit() { + if (!agent_) { + Init(); + } + const std::vector<SSLNamedGroup> groups = { + ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, + ssl_grp_ffdhe_2048}; + agent_->ConfigNamedGroups(groups); +} + +void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, + TlsAgent::State expected_state, + int32_t error_code) { + std::cerr << "Process message: " << buffer << std::endl; + EnsureInit(); + agent_->adapter()->PacketReceived(buffer); + agent_->Handshake(); + + ASSERT_EQ(expected_state, agent_->state()); + + if (expected_state == TlsAgent::STATE_ERROR) { + ASSERT_EQ(error_code, agent_->error_code()); + } +} + +void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version, + const uint8_t* buf, size_t len, + DataBuffer* out, uint64_t seq_num) { + size_t index = 0; + index = out->Write(index, type, 1); + index = out->Write( + index, mode == STREAM ? version : TlsVersionToDtlsVersion(version), 2); + if (mode == DGRAM) { + index = out->Write(index, seq_num >> 32, 4); + index = out->Write(index, seq_num & PR_UINT32_MAX, 4); + } + index = out->Write(index, len, 2); + out->Write(index, buf, len); +} + +void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version, + const uint8_t* buf, size_t len, + DataBuffer* out, uint64_t seq_num) const { + MakeRecord(mode_, type, version, buf, len, out, seq_num); +} + +void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type, + const uint8_t* data, size_t hs_len, + DataBuffer* out, + uint64_t seq_num) const { + return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0, + 0); +} + +void TlsAgentTestBase::MakeHandshakeMessageFragment( + uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out, + uint64_t seq_num, uint32_t fragment_offset, + uint32_t fragment_length) const { + size_t index = 0; + if (!fragment_length) fragment_length = hs_len; + index = out->Write(index, hs_type, 1); // Handshake record type. + index = out->Write(index, hs_len, 3); // Handshake length + if (mode_ == DGRAM) { + index = out->Write(index, seq_num, 2); + index = out->Write(index, fragment_offset, 3); + index = out->Write(index, fragment_length, 3); + } + if (data) { + index = out->Write(index, data, fragment_length); + } else { + for (size_t i = 0; i < fragment_length; ++i) { + index = out->Write(index, 1, 1); + } + } +} + +void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type, + size_t hs_len, + DataBuffer* out) { + size_t index = 0; + index = out->Write(index, kTlsHandshakeType, 1); // Content Type + index = out->Write(index, 3, 1); // Version high + index = out->Write(index, 1, 1); // Version low + index = out->Write(index, 4 + hs_len, 2); // Length + + index = out->Write(index, hs_type, 1); // Handshake record type. + index = out->Write(index, hs_len, 3); // Handshake length + for (size_t i = 0; i < hs_len; ++i) { + index = out->Write(index, 1, 1); + } +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h new file mode 100644 index 000000000..78923c930 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_agent.h @@ -0,0 +1,457 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_agent_h_ +#define tls_agent_h_ + +#include "prio.h" +#include "ssl.h" + +#include <functional> +#include <iostream> + +#include "test_io.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +extern bool g_ssl_gtest_verbose; + +namespace nss_test { + +#define LOG(msg) std::cerr << role_str() << ": " << msg << std::endl +#define LOGV(msg) \ + do { \ + if (g_ssl_gtest_verbose) LOG(msg); \ + } while (false) + +enum SessionResumptionMode { + RESUME_NONE = 0, + RESUME_SESSIONID = 1, + RESUME_TICKET = 2, + RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET +}; + +class TlsAgent; + +const extern std::vector<SSLNamedGroup> kAllDHEGroups; +const extern std::vector<SSLNamedGroup> kECDHEGroups; +const extern std::vector<SSLNamedGroup> kFFDHEGroups; +const extern std::vector<SSLNamedGroup> kFasterDHEGroups; + +typedef std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)> + AuthCertificateCallbackFunction; + +typedef std::function<void(TlsAgent* agent)> HandshakeCallbackFunction; + +typedef std::function<int32_t(TlsAgent* agent, const SECItem* srvNameArr, + PRUint32 srvNameArrSize)> + SniCallbackFunction; + +class TlsAgent : public PollTarget { + public: + enum Role { CLIENT, SERVER }; + enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR }; + + static const std::string kClient; // the client key is sign only + static const std::string kRsa2048; // bigger sign and encrypt for either + static const std::string kServerRsa; // both sign and encrypt + static const std::string kServerRsaSign; + static const std::string kServerRsaPss; + static const std::string kServerRsaDecrypt; + static const std::string kServerRsaChain; // A cert that requires a chain. + static const std::string kServerEcdsa256; + static const std::string kServerEcdsa384; + static const std::string kServerEcdsa521; + static const std::string kServerEcdhEcdsa; + static const std::string kServerEcdhRsa; + static const std::string kServerDsa; + + TlsAgent(const std::string& name, Role role, Mode mode); + virtual ~TlsAgent(); + + bool Init() { + pr_fd_ = DummyPrSocket::CreateFD(role_str(), mode_); + if (!pr_fd_) return false; + + adapter_ = DummyPrSocket::GetAdapter(pr_fd_); + if (!adapter_) return false; + + return true; + } + + void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); } + + void SetPacketFilter(PacketFilter* filter) { + adapter_->SetPacketFilter(filter); + } + + void StartConnect(PRFileDesc* model = nullptr); + void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group, + size_t kea_size = 0) const; + void CheckAuthType(SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) const; + + void DisableAllCiphers(); + void EnableCiphersByAuthType(SSLAuthType authType); + void EnableCiphersByKeyExchange(SSLKEAType kea); + void EnableGroupsByKeyExchange(SSLKEAType kea); + void EnableGroupsByAuthType(SSLAuthType authType); + void EnableSingleCipher(uint16_t cipher); + + void Handshake(); + // Marks the internal state as CONNECTING in anticipation of renegotiation. + void PrepareForRenegotiate(); + // Prepares for renegotiation, then actually triggers it. + void StartRenegotiate(); + bool ConfigServerCert(const std::string& name, bool updateKeyBits = false, + const SSLExtraServerCertData* serverCertData = nullptr); + bool ConfigServerCertWithChain(const std::string& name); + bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr); + + void SetupClientAuth(); + void RequestClientAuth(bool requireAuth); + bool GetClientAuthCredentials(CERTCertificate** cert, + SECKEYPrivateKey** priv) const; + + void ConfigureSessionCache(SessionResumptionMode mode); + void SetSessionTicketsEnabled(bool en); + void SetSessionCacheEnabled(bool en); + void Set0RttEnabled(bool en); + void SetShortHeadersEnabled(); + void SetVersionRange(uint16_t minver, uint16_t maxver); + void GetVersionRange(uint16_t* minver, uint16_t* maxver); + void CheckPreliminaryInfo(); + void ResetPreliminaryInfo(); + void SetExpectedVersion(uint16_t version); + void SetServerKeyBits(uint16_t bits); + void ExpectReadWriteError(); + void EnableFalseStart(); + void ExpectResumption(); + void ExpectShortHeaders(); + void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count); + void EnableAlpn(const uint8_t* val, size_t len); + void CheckAlpn(SSLNextProtoState expected_state, + const std::string& expected = "") const; + void EnableSrtp(); + void CheckSrtp() const; + void CheckErrorCode(int32_t expected) const; + void WaitForErrorCode(int32_t expected, uint32_t delay) const; + // Send data on the socket, encrypting it. + void SendData(size_t bytes, size_t blocksize = 1024); + void SendBuffer(const DataBuffer& buf); + // Send data directly to the underlying socket, skipping the TLS layer. + void SendDirect(const DataBuffer& buf); + void ReadBytes(); + void ResetSentBytes(); // Hack to test drops. + void EnableExtendedMasterSecret(); + void CheckExtendedMasterSecret(bool expected); + void CheckEarlyDataAccepted(bool expected); + void DisableRollbackDetection(); + void EnableCompression(); + void SetDowngradeCheckVersion(uint16_t version); + void CheckSecretsDestroyed(); + void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); + void DisableECDHEServerKeyReuse(); + bool GetPeerChainLength(size_t* count); + + const std::string& name() const { return name_; } + + Role role() const { return role_; } + std::string role_str() const { return role_ == SERVER ? "server" : "client"; } + + State state() const { return state_; } + + const CERTCertificate* peer_cert() const { + return SSL_PeerCertificate(ssl_fd_); + } + + const char* state_str() const { return state_str(state()); } + + static const char* state_str(State state) { return states[state]; } + + PRFileDesc* ssl_fd() { return ssl_fd_; } + DummyPrSocket* adapter() { return adapter_; } + + bool is_compressed() const { + return info_.compressionMethod != ssl_compression_null; + } + uint16_t server_key_bits() const { return server_key_bits_; } + uint16_t min_version() const { return vrange_.min; } + uint16_t max_version() const { return vrange_.max; } + uint16_t version() const { + EXPECT_EQ(STATE_CONNECTED, state_); + return info_.protocolVersion; + } + + bool cipher_suite(uint16_t* cipher_suite) const { + if (state_ != STATE_CONNECTED) return false; + + *cipher_suite = info_.cipherSuite; + return true; + } + + std::string cipher_suite_name() const { + if (state_ != STATE_CONNECTED) return "UNKNOWN"; + + return csinfo_.cipherSuiteName; + } + + std::vector<uint8_t> session_id() const { + return std::vector<uint8_t>(info_.sessionID, + info_.sessionID + info_.sessionIDLength); + } + + bool auth_type(SSLAuthType* auth_type) const { + if (state_ != STATE_CONNECTED) return false; + + *auth_type = info_.authType; + return true; + } + + bool kea_type(SSLKEAType* kea_type) const { + if (state_ != STATE_CONNECTED) return false; + + *kea_type = info_.keaType; + return true; + } + + size_t received_bytes() const { return recv_ctr_; } + PRErrorCode error_code() const { return error_code_; } + + bool can_falsestart_hook_called() const { + return can_falsestart_hook_called_; + } + + void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) { + handshake_callback_ = handshake_callback; + } + + void SetAuthCertificateCallback( + AuthCertificateCallbackFunction auth_certificate_callback) { + auth_certificate_callback_ = auth_certificate_callback; + } + + void SetSniCallback(SniCallbackFunction sni_callback) { + sni_callback_ = sni_callback; + } + + private: + const static char* states[]; + + void SetState(State state); + + // Dummy auth certificate hook. + static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, + PRBool checksig, PRBool isServer) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + agent->CheckPreliminaryInfo(); + agent->auth_certificate_hook_called_ = true; + if (agent->auth_certificate_callback_) { + return agent->auth_certificate_callback_(agent, checksig ? true : false, + isServer ? true : false); + } + return SECSuccess; + } + + // Client auth certificate hook. + static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd, + PRBool checksig, PRBool isServer) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + EXPECT_TRUE(agent->expect_client_auth_); + EXPECT_EQ(PR_TRUE, isServer); + if (agent->auth_certificate_callback_) { + return agent->auth_certificate_callback_(agent, checksig ? true : false, + isServer ? true : false); + } + return SECSuccess; + } + + static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd, + CERTDistNames* caNames, + CERTCertificate** cert, + SECKEYPrivateKey** privKey); + + static void ReadableCallback(PollTarget* self, Event event) { + TlsAgent* agent = static_cast<TlsAgent*>(self); + if (event == TIMER_EVENT) { + agent->timer_handle_ = nullptr; + } + agent->ReadableCallback_int(); + } + + void ReadableCallback_int() { + LOGV("Readable"); + switch (state_) { + case STATE_CONNECTING: + Handshake(); + break; + case STATE_CONNECTED: + ReadBytes(); + break; + default: + break; + } + } + + static PRInt32 SniHook(PRFileDesc* fd, const SECItem* srvNameArr, + PRUint32 srvNameArrSize, void* arg) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + agent->CheckPreliminaryInfo(); + agent->sni_hook_called_ = true; + EXPECT_EQ(1UL, srvNameArrSize); + if (agent->sni_callback_) { + return agent->sni_callback_(agent, srvNameArr, srvNameArrSize); + } + return 0; // First configuration. + } + + static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg, + PRBool* canFalseStart) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + agent->CheckPreliminaryInfo(); + EXPECT_TRUE(agent->falsestart_enabled_); + EXPECT_FALSE(agent->can_falsestart_hook_called_); + agent->can_falsestart_hook_called_ = true; + *canFalseStart = true; + return SECSuccess; + } + + static void HandshakeCallback(PRFileDesc* fd, void* arg) { + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + agent->handshake_callback_called_ = true; + agent->Connected(); + if (agent->handshake_callback_) { + agent->handshake_callback_(agent); + } + } + + void DisableLameGroups(); + void ConfigStrongECGroups(bool en); + void ConfigAllDHGroups(bool en); + void CheckCallbacks() const; + void Connected(); + + const std::string name_; + Mode mode_; + uint16_t server_key_bits_; + PRFileDesc* pr_fd_; + DummyPrSocket* adapter_; + PRFileDesc* ssl_fd_; + Role role_; + State state_; + Poller::Timer* timer_handle_; + bool falsestart_enabled_; + uint16_t expected_version_; + uint16_t expected_cipher_suite_; + bool expect_resumption_; + bool expect_client_auth_; + bool can_falsestart_hook_called_; + bool sni_hook_called_; + bool auth_certificate_hook_called_; + bool handshake_callback_called_; + SSLChannelInfo info_; + SSLCipherSuiteInfo csinfo_; + SSLVersionRange vrange_; + PRErrorCode error_code_; + size_t send_ctr_; + size_t recv_ctr_; + bool expect_readwrite_error_; + HandshakeCallbackFunction handshake_callback_; + AuthCertificateCallbackFunction auth_certificate_callback_; + SniCallbackFunction sni_callback_; + bool expect_short_headers_; +}; + +inline std::ostream& operator<<(std::ostream& stream, + const TlsAgent::State& state) { + return stream << TlsAgent::state_str(state); +} + +class TlsAgentTestBase : public ::testing::Test { + public: + static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll; + + TlsAgentTestBase(TlsAgent::Role role, Mode mode) + : agent_(nullptr), fd_(nullptr), role_(role), mode_(mode) {} + ~TlsAgentTestBase() { + if (fd_) { + PR_Close(fd_); + } + } + + void SetUp(); + void TearDown(); + + static void MakeRecord(Mode mode, uint8_t type, uint16_t version, + const uint8_t* buf, size_t len, DataBuffer* out, + uint64_t seq_num = 0); + void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, + size_t len, DataBuffer* out, uint64_t seq_num = 0) const; + void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len, + DataBuffer* out, uint64_t seq_num = 0) const; + void MakeHandshakeMessageFragment(uint8_t hs_type, const uint8_t* data, + size_t hs_len, DataBuffer* out, + uint64_t seq_num, uint32_t fragment_offset, + uint32_t fragment_length) const; + static void MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len, + DataBuffer* out); + static inline TlsAgent::Role ToRole(const std::string& str) { + return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER; + } + + static inline Mode ToMode(const std::string& str) { + return str == "TLS" ? STREAM : DGRAM; + } + + void Init(const std::string& server_name = TlsAgent::kServerRsa); + void Reset(const std::string& server_name = TlsAgent::kServerRsa); + + protected: + void EnsureInit(); + void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, + int32_t error_code = 0); + + TlsAgent* agent_; + PRFileDesc* fd_; + TlsAgent::Role role_; + Mode mode_; +}; + +class TlsAgentTest : public TlsAgentTestBase, + public ::testing::WithParamInterface< + std::tuple<std::string, std::string>> { + public: + TlsAgentTest() + : TlsAgentTestBase(ToRole(std::get<0>(GetParam())), + ToMode(std::get<1>(GetParam()))) {} +}; + +class TlsAgentTestClient : public TlsAgentTestBase, + public ::testing::WithParamInterface<std::string> { + public: + TlsAgentTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, ToMode(GetParam())) {} +}; + +class TlsAgentStreamTestClient : public TlsAgentTestBase { + public: + TlsAgentStreamTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, STREAM) {} +}; + +class TlsAgentStreamTestServer : public TlsAgentTestBase { + public: + TlsAgentStreamTestServer() : TlsAgentTestBase(TlsAgent::SERVER, STREAM) {} +}; + +class TlsAgentDgramTestClient : public TlsAgentTestBase { + public: + TlsAgentDgramTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, DGRAM) {} +}; + +} // namespace nss_test + +#endif diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc new file mode 100644 index 000000000..d02549954 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -0,0 +1,708 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_connect.h" +extern "C" { +#include "libssl_internals.h" +} + +#include <iostream> + +#include "databuffer.h" +#include "gtest_utils.h" +#include "sslproto.h" + +extern std::string g_working_dir_path; + +namespace nss_test { + +static const std::string kTlsModesStreamArr[] = {"TLS"}; +::testing::internal::ParamGenerator<std::string> + TlsConnectTestBase::kTlsModesStream = + ::testing::ValuesIn(kTlsModesStreamArr); +static const std::string kTlsModesDatagramArr[] = {"DTLS"}; +::testing::internal::ParamGenerator<std::string> + TlsConnectTestBase::kTlsModesDatagram = + ::testing::ValuesIn(kTlsModesDatagramArr); +static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"}; +::testing::internal::ParamGenerator<std::string> + TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr); + +static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 = + ::testing::ValuesIn(kTlsV10Arr); +static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11 = + ::testing::ValuesIn(kTlsV11Arr); +static const uint16_t kTlsV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12 = + ::testing::ValuesIn(kTlsV12Arr); +static const uint16_t kTlsV10V11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_1}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10V11 = + ::testing::ValuesIn(kTlsV10V11Arr); +static const uint16_t kTlsV10ToV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0, + SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10ToV12 = + ::testing::ValuesIn(kTlsV10ToV12Arr); +static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11V12 = + ::testing::ValuesIn(kTlsV11V12Arr); + +static const uint16_t kTlsV11PlusArr[] = { +#ifndef NSS_DISABLE_TLS_1_3 + SSL_LIBRARY_VERSION_TLS_1_3, +#endif + SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11Plus = + ::testing::ValuesIn(kTlsV11PlusArr); +static const uint16_t kTlsV12PlusArr[] = { +#ifndef NSS_DISABLE_TLS_1_3 + SSL_LIBRARY_VERSION_TLS_1_3, +#endif + SSL_LIBRARY_VERSION_TLS_1_2}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12Plus = + ::testing::ValuesIn(kTlsV12PlusArr); +static const uint16_t kTlsV13Arr[] = {SSL_LIBRARY_VERSION_TLS_1_3}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV13 = + ::testing::ValuesIn(kTlsV13Arr); +static const uint16_t kTlsVAllArr[] = { +#ifndef NSS_DISABLE_TLS_1_3 + SSL_LIBRARY_VERSION_TLS_1_3, +#endif + SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_0}; +::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsVAll = + ::testing::ValuesIn(kTlsVAllArr); + +std::string VersionString(uint16_t version) { + switch (version) { + case 0: + return "(no version)"; + case SSL_LIBRARY_VERSION_TLS_1_0: + return "1.0"; + case SSL_LIBRARY_VERSION_TLS_1_1: + return "1.1"; + case SSL_LIBRARY_VERSION_TLS_1_2: + return "1.2"; + case SSL_LIBRARY_VERSION_TLS_1_3: + return "1.3"; + default: + std::cerr << "Invalid version: " << version << std::endl; + EXPECT_TRUE(false); + return ""; + } +} + +TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version) + : mode_(mode), + client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_)), + server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_)), + client_model_(nullptr), + server_model_(nullptr), + version_(version), + expected_resumption_mode_(RESUME_NONE), + session_ids_(), + expect_extended_master_secret_(false), + expect_early_data_accepted_(false) { + std::string v; + if (mode_ == DGRAM && version_ == SSL_LIBRARY_VERSION_TLS_1_1) { + v = "1.0"; + } else { + v = VersionString(version_); + } + std::cerr << "Version: " << mode_ << " " << v << std::endl; +} + +TlsConnectTestBase::TlsConnectTestBase(const std::string& mode, + uint16_t version) + : TlsConnectTestBase(TlsConnectTestBase::ToMode(mode), version) {} + +TlsConnectTestBase::~TlsConnectTestBase() {} + +// Check the group of each of the supported groups +void TlsConnectTestBase::CheckGroups( + const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group) { + DuplicateGroupChecker group_set; + uint32_t tmp = 0; + EXPECT_TRUE(groups.Read(0, 2, &tmp)); + EXPECT_EQ(groups.len() - 2, static_cast<size_t>(tmp)); + for (size_t i = 2; i < groups.len(); i += 2) { + EXPECT_TRUE(groups.Read(i, 2, &tmp)); + SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp); + group_set.AddAndCheckGroup(group); + check_group(group); + } +} + +// Check the group of each of the shares +void TlsConnectTestBase::CheckShares( + const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group) { + DuplicateGroupChecker group_set; + uint32_t tmp = 0; + EXPECT_TRUE(shares.Read(0, 2, &tmp)); + EXPECT_EQ(shares.len() - 2, static_cast<size_t>(tmp)); + size_t i; + for (i = 2; i < shares.len(); i += 4 + tmp) { + ASSERT_TRUE(shares.Read(i, 2, &tmp)); + SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp); + group_set.AddAndCheckGroup(group); + check_group(group); + ASSERT_TRUE(shares.Read(i + 2, 2, &tmp)); + } + EXPECT_EQ(shares.len(), i); +} + +void TlsConnectTestBase::ClearStats() { + // Clear statistics. + SSL3Statistics* stats = SSL_GetStatistics(); + memset(stats, 0, sizeof(*stats)); +} + +void TlsConnectTestBase::ClearServerCache() { + SSL_ShutdownServerSessionIDCache(); + SSLInt_ClearSessionTicketKey(); + SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); +} + +void TlsConnectTestBase::SetUp() { + SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); + SSLInt_ClearSessionTicketKey(); + ClearStats(); + Init(); +} + +void TlsConnectTestBase::TearDown() { + delete client_; + delete server_; + if (client_model_) { + ASSERT_NE(server_model_, nullptr); + delete client_model_; + delete server_model_; + } + + SSL_ClearSessionCache(); + SSLInt_ClearSessionTicketKey(); + SSL_ShutdownServerSessionIDCache(); +} + +void TlsConnectTestBase::Init() { + EXPECT_TRUE(client_->Init()); + EXPECT_TRUE(server_->Init()); + + client_->SetPeer(server_); + server_->SetPeer(client_); + + if (version_) { + ConfigureVersion(version_); + } +} + +void TlsConnectTestBase::Reset() { + // Take a copy of the names because they are about to disappear. + std::string server_name = server_->name(); + std::string client_name = client_->name(); + Reset(server_name, client_name); +} + +void TlsConnectTestBase::Reset(const std::string& server_name, + const std::string& client_name) { + delete client_; + delete server_; + + client_ = new TlsAgent(client_name, TlsAgent::CLIENT, mode_); + server_ = new TlsAgent(server_name, TlsAgent::SERVER, mode_); + + Init(); +} + +void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) { + expected_resumption_mode_ = expected; + if (expected != RESUME_NONE) { + client_->ExpectResumption(); + server_->ExpectResumption(); + } +} + +void TlsConnectTestBase::EnsureTlsSetup() { + EXPECT_TRUE(server_->EnsureTlsSetup(server_model_ ? server_model_->ssl_fd() + : nullptr)); + EXPECT_TRUE(client_->EnsureTlsSetup(client_model_ ? client_model_->ssl_fd() + : nullptr)); +} + +void TlsConnectTestBase::Handshake() { + EnsureTlsSetup(); + client_->SetServerKeyBits(server_->server_key_bits()); + client_->Handshake(); + server_->Handshake(); + + ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) && + (server_->state() != TlsAgent::STATE_CONNECTING), + 5000); +} + +void TlsConnectTestBase::EnableExtendedMasterSecret() { + client_->EnableExtendedMasterSecret(); + server_->EnableExtendedMasterSecret(); + ExpectExtendedMasterSecret(true); +} + +void TlsConnectTestBase::Connect() { + server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr); + client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr); + Handshake(); + CheckConnected(); +} + +void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) { + EnsureTlsSetup(); + client_->EnableSingleCipher(cipher_suite); + + Connect(); + SendReceive(); + + // Check that we used the right cipher suite. + uint16_t actual; + EXPECT_TRUE(client_->cipher_suite(&actual)); + EXPECT_EQ(cipher_suite, actual); + EXPECT_TRUE(server_->cipher_suite(&actual)); + EXPECT_EQ(cipher_suite, actual); +} + +void TlsConnectTestBase::CheckConnected() { + // Check the version is as expected + EXPECT_EQ(client_->version(), server_->version()); + EXPECT_EQ(std::min(client_->max_version(), server_->max_version()), + client_->version()); + + EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); + EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); + + uint16_t cipher_suite1, cipher_suite2; + bool ret = client_->cipher_suite(&cipher_suite1); + EXPECT_TRUE(ret); + ret = server_->cipher_suite(&cipher_suite2); + EXPECT_TRUE(ret); + EXPECT_EQ(cipher_suite1, cipher_suite2); + + std::cerr << "Connected with version " << client_->version() + << " cipher suite " << client_->cipher_suite_name() << std::endl; + + if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) { + // Check and store session ids. + std::vector<uint8_t> sid_c1 = client_->session_id(); + EXPECT_EQ(32U, sid_c1.size()); + std::vector<uint8_t> sid_s1 = server_->session_id(); + EXPECT_EQ(32U, sid_s1.size()); + EXPECT_EQ(sid_c1, sid_s1); + session_ids_.push_back(sid_c1); + } + + CheckExtendedMasterSecret(); + CheckEarlyDataAccepted(); + CheckResumption(expected_resumption_mode_); + client_->CheckSecretsDestroyed(); + server_->CheckSecretsDestroyed(); +} + +void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme) const { + client_->CheckKEA(kea_type, kea_group); + server_->CheckKEA(kea_type, kea_group); + client_->CheckAuthType(auth_type, sig_scheme); + server_->CheckAuthType(auth_type, sig_scheme); +} + +void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, + SSLAuthType auth_type) const { + SSLNamedGroup group; + switch (kea_type) { + case ssl_kea_ecdh: + group = ssl_grp_ec_curve25519; + break; + case ssl_kea_dh: + group = ssl_grp_ffdhe_2048; + break; + case ssl_kea_rsa: + group = ssl_grp_none; + break; + default: + EXPECT_TRUE(false) << "unexpected KEA"; + group = ssl_grp_none; + break; + } + + SSLSignatureScheme scheme; + switch (auth_type) { + case ssl_auth_rsa_decrypt: + scheme = ssl_sig_none; + break; + case ssl_auth_rsa_sign: + scheme = ssl_sig_rsa_pss_sha256; + break; + case ssl_auth_ecdsa: + scheme = ssl_sig_ecdsa_secp256r1_sha256; + break; + case ssl_auth_dsa: + scheme = ssl_sig_dsa_sha1; + break; + default: + EXPECT_TRUE(false) << "unexpected auth type"; + scheme = static_cast<SSLSignatureScheme>(0x0100); + break; + } + CheckKeys(kea_type, group, auth_type, scheme); +} + +void TlsConnectTestBase::CheckKeys() const { + CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign); +} + +void TlsConnectTestBase::ConnectExpectFail() { + server_->StartConnect(); + client_->StartConnect(); + Handshake(); + ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); + ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); +} + +void TlsConnectTestBase::ConfigureVersion(uint16_t version) { + client_->SetVersionRange(version, version); + server_->SetVersionRange(version, version); +} + +void TlsConnectTestBase::SetExpectedVersion(uint16_t version) { + client_->SetExpectedVersion(version); + server_->SetExpectedVersion(version); +} + +void TlsConnectTestBase::DisableAllCiphers() { + EnsureTlsSetup(); + client_->DisableAllCiphers(); + server_->DisableAllCiphers(); +} + +void TlsConnectTestBase::EnableOnlyStaticRsaCiphers() { + DisableAllCiphers(); + + client_->EnableCiphersByKeyExchange(ssl_kea_rsa); + server_->EnableCiphersByKeyExchange(ssl_kea_rsa); +} + +void TlsConnectTestBase::EnableOnlyDheCiphers() { + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + DisableAllCiphers(); + client_->EnableCiphersByKeyExchange(ssl_kea_dh); + server_->EnableCiphersByKeyExchange(ssl_kea_dh); + } else { + client_->ConfigNamedGroups(kFFDHEGroups); + server_->ConfigNamedGroups(kFFDHEGroups); + } +} + +void TlsConnectTestBase::EnableSomeEcdhCiphers() { + if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) { + client_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa); + client_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa); + server_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa); + server_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa); + } else { + client_->ConfigNamedGroups(kECDHEGroups); + server_->ConfigNamedGroups(kECDHEGroups); + } +} + +void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, + SessionResumptionMode server) { + client_->ConfigureSessionCache(client); + server_->ConfigureSessionCache(server); + if ((server & RESUME_TICKET) != 0) { + // This is an abomination. NSS encrypts session tickets with the server's + // RSA public key. That means we need the server to have an RSA certificate + // even if it won't be used for the connection. + server_->ConfigServerCert(TlsAgent::kServerRsaDecrypt); + } +} + +void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { + EXPECT_NE(RESUME_BOTH, expected); + + int resume_count = expected ? 1 : 0; + int stateless_count = (expected & RESUME_TICKET) ? 1 : 0; + + // Note: hch == server counter; hsh == client counter. + SSL3Statistics* stats = SSL_GetStatistics(); + EXPECT_EQ(resume_count, stats->hch_sid_cache_hits); + EXPECT_EQ(resume_count, stats->hsh_sid_cache_hits); + + EXPECT_EQ(stateless_count, stats->hch_sid_stateless_resumes); + EXPECT_EQ(stateless_count, stats->hsh_sid_stateless_resumes); + + if (expected != RESUME_NONE) { + if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) { + // Check that the last two session ids match. + ASSERT_EQ(2U, session_ids_.size()); + EXPECT_EQ(session_ids_[session_ids_.size() - 1], + session_ids_[session_ids_.size() - 2]); + } else { + // TLS 1.3 only uses tickets. + EXPECT_TRUE(expected & RESUME_TICKET); + } + } +} + +void TlsConnectTestBase::EnableAlpn() { + client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); + server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); +} + +void TlsConnectTestBase::EnableAlpn(const uint8_t* val, size_t len) { + client_->EnableAlpn(val, len); + server_->EnableAlpn(val, len); +} + +void TlsConnectTestBase::EnsureModelSockets() { + // Make sure models agents are available. + if (!client_model_) { + ASSERT_EQ(server_model_, nullptr); + client_model_ = new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_); + server_model_ = new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_); + } + + // Initialise agents. + ASSERT_TRUE(client_model_->Init()); + ASSERT_TRUE(server_model_->Init()); +} + +void TlsConnectTestBase::CheckAlpn(const std::string& val) { + client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, val); + server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, val); +} + +void TlsConnectTestBase::EnableSrtp() { + client_->EnableSrtp(); + server_->EnableSrtp(); +} + +void TlsConnectTestBase::CheckSrtp() const { + client_->CheckSrtp(); + server_->CheckSrtp(); +} + +void TlsConnectTestBase::SendReceive() { + client_->SendData(50); + server_->SendData(50); + Receive(50); +} + +// Do a first connection so we can do 0-RTT on the second one. +void TlsConnectTestBase::SetupForZeroRtt() { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT. + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); + client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, + SSL_LIBRARY_VERSION_TLS_1_3); + server_->StartConnect(); + client_->StartConnect(); +} + +// Do a first connection so we can do resumption +void TlsConnectTestBase::SetupForResume() { + EnsureTlsSetup(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); // Need to read so that we absorb the session ticket. + CheckKeys(); + + Reset(); +} + +void TlsConnectTestBase::ZeroRttSendReceive( + bool expect_writable, bool expect_readable, + std::function<bool()> post_clienthello_check) { + const char* k0RttData = "ABCDEF"; + const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData)); + + client_->Handshake(); // Send ClientHello. + if (post_clienthello_check) { + if (!post_clienthello_check()) return; + } + PRInt32 rv = + PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write. + if (expect_writable) { + EXPECT_EQ(k0RttDataLen, rv); + } else { + EXPECT_EQ(SECFailure, rv); + } + server_->Handshake(); // Consume ClientHello, EE, Finished. + + std::vector<uint8_t> buf(k0RttDataLen); + rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read + if (expect_readable) { + std::cerr << "0-RTT read " << rv << " bytes\n"; + EXPECT_EQ(k0RttDataLen, rv); + } else { + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + } + + // Do a second read. this should fail. + rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); + EXPECT_EQ(SECFailure, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); +} + +void TlsConnectTestBase::Receive(size_t amount) { + WAIT_(client_->received_bytes() == amount && + server_->received_bytes() == amount, + 2000); + ASSERT_EQ(amount, client_->received_bytes()); + ASSERT_EQ(amount, server_->received_bytes()); +} + +void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) { + expect_extended_master_secret_ = expected; +} + +void TlsConnectTestBase::CheckExtendedMasterSecret() { + client_->CheckExtendedMasterSecret(expect_extended_master_secret_); + server_->CheckExtendedMasterSecret(expect_extended_master_secret_); +} + +void TlsConnectTestBase::ExpectEarlyDataAccepted(bool expected) { + expect_early_data_accepted_ = expected; +} + +void TlsConnectTestBase::CheckEarlyDataAccepted() { + client_->CheckEarlyDataAccepted(expect_early_data_accepted_); + server_->CheckEarlyDataAccepted(expect_early_data_accepted_); +} + +void TlsConnectTestBase::DisableECDHEServerKeyReuse() { + server_->DisableECDHEServerKeyReuse(); +} + +TlsConnectGeneric::TlsConnectGeneric() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + +TlsConnectPre12::TlsConnectPre12() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + +TlsConnectTls12::TlsConnectTls12() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_2) {} + +TlsConnectTls12Plus::TlsConnectTls12Plus() + : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {} + +TlsConnectTls13::TlsConnectTls13() + : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {} + +void TlsKeyExchangeTest::EnsureKeyShareSetup() { + EnsureTlsSetup(); + groups_capture_ = new TlsExtensionCapture(ssl_supported_groups_xtn); + shares_capture_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn); + shares_capture2_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn, true); + std::vector<PacketFilter*> captures; + captures.push_back(groups_capture_); + captures.push_back(shares_capture_); + captures.push_back(shares_capture2_); + client_->SetPacketFilter(new ChainedPacketFilter(captures)); + capture_hrr_ = + new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest); + server_->SetPacketFilter(capture_hrr_); +} + +void TlsKeyExchangeTest::ConfigNamedGroups( + const std::vector<SSLNamedGroup>& groups) { + client_->ConfigNamedGroups(groups); + server_->ConfigNamedGroups(groups); +} + +std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails( + const DataBuffer& ext) { + uint32_t tmp = 0; + EXPECT_TRUE(ext.Read(0, 2, &tmp)); + EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); + EXPECT_TRUE(ext.len() % 2 == 0); + std::vector<SSLNamedGroup> groups; + for (size_t i = 1; i < ext.len() / 2; i += 1) { + EXPECT_TRUE(ext.Read(2 * i, 2, &tmp)); + groups.push_back(static_cast<SSLNamedGroup>(tmp)); + } + return groups; +} + +std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails( + const DataBuffer& ext) { + uint32_t tmp = 0; + EXPECT_TRUE(ext.Read(0, 2, &tmp)); + EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp)); + std::vector<SSLNamedGroup> shares; + size_t i = 2; + while (i < ext.len()) { + EXPECT_TRUE(ext.Read(i, 2, &tmp)); + shares.push_back(static_cast<SSLNamedGroup>(tmp)); + EXPECT_TRUE(ext.Read(i + 2, 2, &tmp)); + i += 4 + tmp; + } + EXPECT_EQ(ext.len(), i); + return shares; +} + +void TlsKeyExchangeTest::CheckKEXDetails( + const std::vector<SSLNamedGroup>& expected_groups, + const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) { + std::vector<SSLNamedGroup> groups = + GetGroupDetails(groups_capture_->extension()); + EXPECT_EQ(expected_groups, groups); + + if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { + ASSERT_LT(0U, expected_shares.size()); + std::vector<SSLNamedGroup> shares = + GetShareDetails(shares_capture_->extension()); + EXPECT_EQ(expected_shares, shares); + } else { + EXPECT_EQ(0U, shares_capture_->extension().len()); + } + + EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0); +} + +void TlsKeyExchangeTest::CheckKEXDetails( + const std::vector<SSLNamedGroup>& expected_groups, + const std::vector<SSLNamedGroup>& expected_shares) { + CheckKEXDetails(expected_groups, expected_shares, false); +} + +void TlsKeyExchangeTest::CheckKEXDetails( + const std::vector<SSLNamedGroup>& expected_groups, + const std::vector<SSLNamedGroup>& expected_shares, + SSLNamedGroup expected_share2) { + CheckKEXDetails(expected_groups, expected_shares, true); + + for (auto it : expected_shares) { + EXPECT_NE(expected_share2, it); + } + std::vector<SSLNamedGroup> expected_shares2 = {expected_share2}; + std::vector<SSLNamedGroup> shares = + GetShareDetails(shares_capture2_->extension()); + EXPECT_EQ(expected_shares2, shares); +} +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h new file mode 100644 index 000000000..aa4a32d96 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_connect.h @@ -0,0 +1,274 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_connect_h_ +#define tls_connect_h_ + +#include <tuple> + +#include "sslproto.h" +#include "sslt.h" + +#include "tls_agent.h" +#include "tls_filter.h" + +#define GTEST_HAS_RTTI 0 +#include "gtest/gtest.h" + +namespace nss_test { + +extern std::string VersionString(uint16_t version); + +// A generic TLS connection test base. +class TlsConnectTestBase : public ::testing::Test { + public: + static ::testing::internal::ParamGenerator<std::string> kTlsModesStream; + static ::testing::internal::ParamGenerator<std::string> kTlsModesDatagram; + static ::testing::internal::ParamGenerator<std::string> kTlsModesAll; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV10; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV11; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV12; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV10V11; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV10ToV12; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV13; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV11Plus; + static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus; + static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll; + + TlsConnectTestBase(Mode mode, uint16_t version); + TlsConnectTestBase(const std::string& mode, uint16_t version); + virtual ~TlsConnectTestBase(); + + void SetUp(); + void TearDown(); + + // Initialize client and server. + void Init(); + // Clear the statistics. + void ClearStats(); + // Clear the server session cache. + void ClearServerCache(); + // Make sure TLS is configured for a connection. + void EnsureTlsSetup(); + // Reset and keep the same certificate names + void Reset(); + // Reset, and update the certificate names on both peers + void Reset(const std::string& server_name, + const std::string& client_name = "client"); + + // Run the handshake. + void Handshake(); + // Connect and check that it works. + void Connect(); + // Check that the connection was successfully established. + void CheckConnected(); + // Connect and expect it to fail. + void ConnectExpectFail(); + void ConnectWithCipherSuite(uint16_t cipher_suite); + // Check that the keys used in the handshake match expectations. + void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, + SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const; + // This version guesses some of the values. + void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const; + // This version assumes defaults. + void CheckKeys() const; + void CheckGroups(const DataBuffer& groups, + std::function<void(SSLNamedGroup)> check_group); + void CheckShares(const DataBuffer& shares, + std::function<void(SSLNamedGroup)> check_group); + + void ConfigureVersion(uint16_t version); + void SetExpectedVersion(uint16_t version); + // Expect resumption of a particular type. + void ExpectResumption(SessionResumptionMode expected); + void DisableAllCiphers(); + void EnableOnlyStaticRsaCiphers(); + void EnableOnlyDheCiphers(); + void EnableSomeEcdhCiphers(); + void EnableExtendedMasterSecret(); + void ConfigureSessionCache(SessionResumptionMode client, + SessionResumptionMode server); + void EnableAlpn(); + void EnableAlpn(const uint8_t* val, size_t len); + void EnsureModelSockets(); + void CheckAlpn(const std::string& val); + void EnableSrtp(); + void CheckSrtp() const; + void SendReceive(); + void SetupForZeroRtt(); + void SetupForResume(); + void ZeroRttSendReceive( + bool expect_writable, bool expect_readable, + std::function<bool()> post_clienthello_check = nullptr); + void Receive(size_t amount); + void ExpectExtendedMasterSecret(bool expected); + void ExpectEarlyDataAccepted(bool expected); + void DisableECDHEServerKeyReuse(); + + protected: + Mode mode_; + TlsAgent* client_; + TlsAgent* server_; + TlsAgent* client_model_; + TlsAgent* server_model_; + uint16_t version_; + SessionResumptionMode expected_resumption_mode_; + std::vector<std::vector<uint8_t>> session_ids_; + + // A simple value of "a", "b". Note that the preferred value of "a" is placed + // at the end, because the NSS API follows the now defunct NPN specification, + // which places the preferred (and default) entry at the end of the list. + // NSS will move this final entry to the front when used with ALPN. + const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61}; + + private: + static inline Mode ToMode(const std::string& str) { + return str == "TLS" ? STREAM : DGRAM; + } + + void CheckResumption(SessionResumptionMode expected); + void CheckExtendedMasterSecret(); + void CheckEarlyDataAccepted(); + + bool expect_extended_master_secret_; + bool expect_early_data_accepted_; + + // Track groups and make sure that there are no duplicates. + class DuplicateGroupChecker { + public: + void AddAndCheckGroup(SSLNamedGroup group) { + EXPECT_EQ(groups_.end(), groups_.find(group)) + << "Group " << group << " should not be duplicated"; + groups_.insert(group); + } + + private: + std::set<SSLNamedGroup> groups_; + }; +}; + +// A non-parametrized TLS test base. +class TlsConnectTest : public TlsConnectTestBase { + public: + TlsConnectTest() : TlsConnectTestBase(STREAM, 0) {} +}; + +// A non-parametrized DTLS-only test base. +class DtlsConnectTest : public TlsConnectTestBase { + public: + DtlsConnectTest() : TlsConnectTestBase(DGRAM, 0) {} +}; + +// A TLS-only test base. +class TlsConnectStream : public TlsConnectTestBase, + public ::testing::WithParamInterface<uint16_t> { + public: + TlsConnectStream() : TlsConnectTestBase(STREAM, GetParam()) {} +}; + +// A TLS-only test base for tests before 1.3 +class TlsConnectStreamPre13 : public TlsConnectStream {}; + +// A DTLS-only test base. +class TlsConnectDatagram : public TlsConnectTestBase, + public ::testing::WithParamInterface<uint16_t> { + public: + TlsConnectDatagram() : TlsConnectTestBase(DGRAM, GetParam()) {} +}; + +// A generic test class that can be either STREAM or DGRAM and a single version +// of TLS. This is configured in ssl_loopback_unittest.cc. All uses of this +// should use TEST_P(). +class TlsConnectGeneric + : public TlsConnectTestBase, + public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { + public: + TlsConnectGeneric(); +}; + +// A Pre TLS 1.2 generic test. +class TlsConnectPre12 + : public TlsConnectTestBase, + public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { + public: + TlsConnectPre12(); +}; + +// A TLS 1.2 only generic test. +class TlsConnectTls12 : public TlsConnectTestBase, + public ::testing::WithParamInterface<std::string> { + public: + TlsConnectTls12(); +}; + +// A TLS 1.2 only stream test. +class TlsConnectStreamTls12 : public TlsConnectTestBase { + public: + TlsConnectStreamTls12() + : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_2) {} +}; + +// A TLS 1.2+ generic test. +class TlsConnectTls12Plus + : public TlsConnectTestBase, + public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { + public: + TlsConnectTls12Plus(); +}; + +// A TLS 1.3 only generic test. +class TlsConnectTls13 : public TlsConnectTestBase, + public ::testing::WithParamInterface<std::string> { + public: + TlsConnectTls13(); +}; + +// A TLS 1.3 only stream test. +class TlsConnectStreamTls13 : public TlsConnectTestBase { + public: + TlsConnectStreamTls13() + : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {} +}; + +class TlsConnectDatagram13 : public TlsConnectTestBase { + public: + TlsConnectDatagram13() + : TlsConnectTestBase(DGRAM, SSL_LIBRARY_VERSION_TLS_1_3) {} +}; + +// A variant that is used only with Pre13. +class TlsConnectGenericPre13 : public TlsConnectGeneric {}; + +class TlsKeyExchangeTest : public TlsConnectGeneric { + protected: + TlsExtensionCapture* groups_capture_; + TlsExtensionCapture* shares_capture_; + TlsExtensionCapture* shares_capture2_; + TlsInspectorRecordHandshakeMessage* capture_hrr_; + + void EnsureKeyShareSetup(); + void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); + std::vector<SSLNamedGroup> GetGroupDetails(const DataBuffer& ext); + std::vector<SSLNamedGroup> GetShareDetails(const DataBuffer& ext); + void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, + const std::vector<SSLNamedGroup>& expectedShares); + void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, + const std::vector<SSLNamedGroup>& expectedShares, + SSLNamedGroup expectedShare2); + + private: + void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, + const std::vector<SSLNamedGroup>& expectedShares, + bool expect_hrr); +}; + +class TlsKeyExchangeTest13 : public TlsKeyExchangeTest {}; +class TlsKeyExchangeTestPre13 : public TlsKeyExchangeTest {}; + +} // namespace nss_test + +#endif diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc new file mode 100644 index 000000000..4f7d195d0 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_filter.cc @@ -0,0 +1,503 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_filter.h" +#include "sslproto.h" + +extern "C" { +// This is not something that should make you happy. +#include "libssl_internals.h" +} + +#include <iostream> +#include "gtest_utils.h" +#include "tls_agent.h" + +namespace nss_test { + +PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, + DataBuffer* output) { + bool changed = false; + size_t offset = 0U; + output->Allocate(input.len()); + + TlsParser parser(input); + while (parser.remaining()) { + RecordHeader header; + DataBuffer record; + if (!header.Parse(&parser, &record)) { + return KEEP; + } + + if (FilterRecord(header, record, &offset, output) != KEEP) { + changed = true; + } else { + offset = header.Write(output, offset, record); + } + } + output->Truncate(offset); + + // Record how many packets we actually touched. + if (changed) { + ++count_; + return (offset == 0) ? DROP : CHANGE; + } + + return KEEP; +} + +PacketFilter::Action TlsRecordFilter::FilterRecord(const RecordHeader& header, + const DataBuffer& record, + size_t* offset, + DataBuffer* output) { + DataBuffer filtered; + PacketFilter::Action action = FilterRecord(header, record, &filtered); + if (action == KEEP) { + return KEEP; + } + + if (action == DROP) { + std::cerr << "record drop: " << record << std::endl; + return DROP; + } + + const DataBuffer* source = &record; + if (action == CHANGE) { + EXPECT_GT(0x10000U, filtered.len()); + std::cerr << "record old: " << record << std::endl; + std::cerr << "record new: " << filtered << std::endl; + source = &filtered; + } + + *offset = header.Write(output, *offset, *source); + return CHANGE; +} + +bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) { + if (!parser->Read(&content_type_)) { + return false; + } + + uint32_t version; + if (!parser->Read(&version, 2)) { + return false; + } + version_ = version; + + sequence_number_ = 0; + if (IsDtls(version)) { + uint32_t tmp; + if (!parser->Read(&tmp, 4)) { + return false; + } + sequence_number_ = static_cast<uint64_t>(tmp) << 32; + if (!parser->Read(&tmp, 4)) { + return false; + } + sequence_number_ |= static_cast<uint64_t>(tmp); + } + return parser->ReadVariable(body, 2); +} + +size_t TlsRecordFilter::RecordHeader::Write(DataBuffer* buffer, size_t offset, + const DataBuffer& body) const { + offset = buffer->Write(offset, content_type_, 1); + offset = buffer->Write(offset, version_, 2); + if (is_dtls()) { + // write epoch (2 octet), and seqnum (6 octet) + offset = buffer->Write(offset, sequence_number_ >> 32, 4); + offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4); + } + offset = buffer->Write(offset, body.len(), 2); + offset = buffer->Write(offset, body); + return offset; +} + +PacketFilter::Action TlsHandshakeFilter::FilterRecord( + const RecordHeader& record_header, const DataBuffer& input, + DataBuffer* output) { + // Check that the first byte is as requested. + if (record_header.content_type() != kTlsHandshakeType) { + return KEEP; + } + + bool changed = false; + size_t offset = 0U; + output->Allocate(input.len()); // Preallocate a little. + + TlsParser parser(input); + while (parser.remaining()) { + HandshakeHeader header; + DataBuffer handshake; + if (!header.Parse(&parser, record_header, &handshake)) { + return KEEP; + } + + DataBuffer filtered; + PacketFilter::Action action = FilterHandshake(header, handshake, &filtered); + if (action == DROP) { + changed = true; + std::cerr << "handshake drop: " << handshake << std::endl; + continue; + } + + const DataBuffer* source = &handshake; + if (action == CHANGE) { + EXPECT_GT(0x1000000U, filtered.len()); + changed = true; + std::cerr << "handshake old: " << handshake << std::endl; + std::cerr << "handshake new: " << filtered << std::endl; + source = &filtered; + } + + offset = header.Write(output, offset, *source); + } + output->Truncate(offset); + return changed ? (offset ? CHANGE : DROP) : KEEP; +} + +bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser, + const RecordHeader& header, + uint32_t* length) { + if (!parser->Read(length, 3)) { + return false; // malformed + } + + if (!header.is_dtls()) { + return true; // nothing left to do + } + + // Read and check DTLS parameters + uint32_t message_seq_tmp; + if (!parser->Read(&message_seq_tmp, 2)) { // sequence number + return false; + } + message_seq_ = message_seq_tmp; + + uint32_t fragment_offset; + if (!parser->Read(&fragment_offset, 3)) { + return false; + } + + uint32_t fragment_length; + if (!parser->Read(&fragment_length, 3)) { + return false; + } + + // All current tests where we are using this code don't fragment. + return (fragment_offset == 0 && fragment_length == *length); +} + +bool TlsHandshakeFilter::HandshakeHeader::Parse( + TlsParser* parser, const RecordHeader& record_header, DataBuffer* body) { + version_ = record_header.version(); + if (!parser->Read(&handshake_type_)) { + return false; // malformed + } + uint32_t length; + if (!ReadLength(parser, record_header, &length)) { + return false; + } + + return parser->Read(body, length); +} + +size_t TlsHandshakeFilter::HandshakeHeader::Write( + DataBuffer* buffer, size_t offset, const DataBuffer& body) const { + offset = buffer->Write(offset, handshake_type(), 1); + offset = buffer->Write(offset, body.len(), 3); + if (is_dtls()) { + offset = buffer->Write(offset, message_seq_, 2); + offset = buffer->Write(offset, 0U, 3); // fragment_offset + offset = buffer->Write(offset, body.len(), 3); + } + offset = buffer->Write(offset, body); + return offset; +} + +PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + // Only do this once. + if (buffer_.len()) { + return KEEP; + } + + if (header.handshake_type() == handshake_type_) { + buffer_ = input; + } + return KEEP; +} + +PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + if (header.handshake_type() == handshake_type_) { + *output = buffer_; + return CHANGE; + } + + return KEEP; +} + +PacketFilter::Action TlsConversationRecorder::FilterRecord( + const RecordHeader& header, const DataBuffer& input, DataBuffer* output) { + buffer_.Append(input); + return KEEP; +} + +PacketFilter::Action TlsAlertRecorder::FilterRecord(const RecordHeader& header, + const DataBuffer& input, + DataBuffer* output) { + if (level_ == kTlsAlertFatal) { // already fatal + return KEEP; + } + if (header.content_type() != kTlsAlertType) { + return KEEP; + } + + std::cerr << "Alert: " << input << std::endl; + + TlsParser parser(input); + uint8_t lvl; + if (!parser.Read(&lvl)) { + return KEEP; + } + if (lvl == kTlsAlertWarning) { // not strong enough + return KEEP; + } + level_ = lvl; + (void)parser.Read(&description_); + return KEEP; +} + +ChainedPacketFilter::~ChainedPacketFilter() { + for (auto it = filters_.begin(); it != filters_.end(); ++it) { + delete *it; + } +} + +PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input, + DataBuffer* output) { + DataBuffer in(input); + bool changed = false; + for (auto it = filters_.begin(); it != filters_.end(); ++it) { + PacketFilter::Action action = (*it)->Filter(in, output); + if (action == DROP) { + return DROP; + } + if (action == CHANGE) { + in = *output; + changed = true; + } + } + return changed ? CHANGE : KEEP; +} + +PacketFilter::Action TlsExtensionFilter::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + if (header.handshake_type() == kTlsHandshakeClientHello) { + TlsParser parser(input); + if (!FindClientHelloExtensions(&parser, header)) { + return KEEP; + } + return FilterExtensions(&parser, input, output); + } + if (header.handshake_type() == kTlsHandshakeServerHello) { + TlsParser parser(input); + if (!FindServerHelloExtensions(&parser)) { + return KEEP; + } + return FilterExtensions(&parser, input, output); + } + return KEEP; +} + +bool TlsExtensionFilter::FindClientHelloExtensions(TlsParser* parser, + const Versioned& header) { + if (!parser->Skip(2 + 32)) { // version + random + return false; + } + if (!parser->SkipVariable(1)) { // session ID + return false; + } + if (header.is_dtls() && !parser->SkipVariable(1)) { // DTLS cookie + return false; + } + if (!parser->SkipVariable(2)) { // cipher suites + return false; + } + if (!parser->SkipVariable(1)) { // compression methods + return false; + } + return true; +} + +bool TlsExtensionFilter::FindServerHelloExtensions(TlsParser* parser) { + uint32_t vtmp; + if (!parser->Read(&vtmp, 2)) { + return false; + } + uint16_t version = static_cast<uint16_t>(vtmp); + if (!parser->Skip(32)) { // random + return false; + } + if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) { + if (!parser->SkipVariable(1)) { // session ID + return false; + } + } + if (!parser->Skip(2)) { // cipher suite + return false; + } + if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) { + if (!parser->Skip(1)) { // compression method + return false; + } + } + return true; +} + +PacketFilter::Action TlsExtensionFilter::FilterExtensions( + TlsParser* parser, const DataBuffer& input, DataBuffer* output) { + size_t length_offset = parser->consumed(); + uint32_t all_extensions; + if (!parser->Read(&all_extensions, 2)) { + return KEEP; // no extensions, odd but OK + } + if (all_extensions != parser->remaining()) { + return KEEP; // malformed + } + + bool changed = false; + + // Write out the start of the message. + output->Allocate(input.len()); + size_t offset = output->Write(0, input.data(), parser->consumed()); + + while (parser->remaining()) { + uint32_t extension_type; + if (!parser->Read(&extension_type, 2)) { + return KEEP; // malformed + } + + DataBuffer extension; + if (!parser->ReadVariable(&extension, 2)) { + return KEEP; // malformed + } + + DataBuffer filtered; + PacketFilter::Action action = + FilterExtension(extension_type, extension, &filtered); + if (action == DROP) { + changed = true; + std::cerr << "extension drop: " << extension << std::endl; + continue; + } + + const DataBuffer* source = &extension; + if (action == CHANGE) { + EXPECT_GT(0x10000U, filtered.len()); + changed = true; + std::cerr << "extension old: " << extension << std::endl; + std::cerr << "extension new: " << filtered << std::endl; + source = &filtered; + } + + // Write out extension. + offset = output->Write(offset, extension_type, 2); + offset = output->Write(offset, source->len(), 2); + if (source->len() > 0) { + offset = output->Write(offset, *source); + } + } + output->Truncate(offset); + + if (changed) { + size_t newlen = output->len() - length_offset - 2; + EXPECT_GT(0x10000U, newlen); + if (newlen >= 0x10000) { + return KEEP; // bad: size increased too much + } + output->Write(length_offset, newlen, 2); + return CHANGE; + } + return KEEP; +} + +PacketFilter::Action TlsExtensionCapture::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type == extension_ && (last_ || !captured_)) { + data_.Assign(input); + captured_ = true; + } + return KEEP; +} + +PacketFilter::Action TlsExtensionReplacer::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type != extension_) { + return KEEP; + } + + *output = data_; + return CHANGE; +} + +PacketFilter::Action TlsExtensionDropper::FilterExtension( + uint16_t extension_type, const DataBuffer& input, DataBuffer* output) { + if (extension_type == extension_) { + return DROP; + } + return KEEP; +} + +PacketFilter::Action AfterRecordN::FilterRecord(const RecordHeader& header, + const DataBuffer& body, + DataBuffer* out) { + if (counter_++ == record_) { + DataBuffer buf; + header.Write(&buf, 0, body); + src_->SendDirect(buf); + dest_->Handshake(); + func_(); + return DROP; + } + + return KEEP; +} + +PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + if (header.handshake_type() == kTlsHandshakeClientKeyExchange) { + EXPECT_EQ(SECSuccess, + SSLInt_IncrementClientHandshakeVersion(server_->ssl_fd())); + } + return KEEP; +} + +PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input, + DataBuffer* output) { + if (counter_ >= 32) { + return KEEP; + } + return ((1 << counter_++) & pattern_) ? DROP : KEEP; +} + +PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + if (header.handshake_type() == kTlsHandshakeClientHello) { + *output = input; + output->Write(0, version_, 2); + return CHANGE; + } + return KEEP; +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h new file mode 100644 index 000000000..fa2e38785 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_filter.h @@ -0,0 +1,343 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_filter_h_ +#define tls_filter_h_ + +#include <functional> +#include <memory> +#include <vector> + +#include "test_io.h" +#include "tls_parser.h" + +namespace nss_test { + +// Abstract filter that operates on entire (D)TLS records. +class TlsRecordFilter : public PacketFilter { + public: + TlsRecordFilter() : count_(0) {} + + // External interface. Overrides PacketFilter. + PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); + + // Report how many packets were altered by the filter. + size_t filtered_packets() const { return count_; } + + class Versioned { + public: + Versioned() : version_(0) {} + explicit Versioned(uint16_t version) : version_(version) {} + + bool is_dtls() const { return IsDtls(version_); } + uint16_t version() const { return version_; } + + protected: + uint16_t version_; + }; + + class RecordHeader : public Versioned { + public: + RecordHeader() : Versioned(), content_type_(0), sequence_number_(0) {} + RecordHeader(uint16_t version, uint8_t content_type, + uint64_t sequence_number) + : Versioned(version), + content_type_(content_type), + sequence_number_(sequence_number) {} + + uint8_t content_type() const { return content_type_; } + uint64_t sequence_number() const { return sequence_number_; } + size_t header_length() const { return is_dtls() ? 11 : 3; } + + // Parse the header; return true if successful; body in an outparam if OK. + bool Parse(TlsParser* parser, DataBuffer* body); + // Write the header and body to a buffer at the given offset. + // Return the offset of the end of the write. + size_t Write(DataBuffer* buffer, size_t offset, + const DataBuffer& body) const; + + private: + uint8_t content_type_; + uint64_t sequence_number_; + }; + + protected: + // There are two filter functions which can be overriden. Both are + // called with the header and the record but the outer one is called + // with a raw pointer to let you write into the buffer and lets you + // do anything with this section of the stream. The inner one + // just lets you change the record contents. By default, the + // outer one calls the inner one, so if you override the outer + // one, the inner one is never called unless you call it yourself. + virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + const DataBuffer& record, + size_t* offset, DataBuffer* output); + + // The record filter receives the record contentType, version and DTLS + // sequence number (which is zero for TLS), plus the existing record payload. + // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed` + // outparam with the new record contents if it chooses to CHANGE the record. + virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) { + return KEEP; + } + + private: + size_t count_; +}; + +// Abstract filter that operates on handshake messages rather than records. +// This assumes that the handshake messages are written in a block as entire +// records and that they don't span records or anything crazy like that. +class TlsHandshakeFilter : public TlsRecordFilter { + public: + TlsHandshakeFilter() {} + + class HandshakeHeader : public Versioned { + public: + HandshakeHeader() : Versioned(), handshake_type_(0), message_seq_(0) {} + + uint8_t handshake_type() const { return handshake_type_; } + bool Parse(TlsParser* parser, const RecordHeader& record_header, + DataBuffer* body); + size_t Write(DataBuffer* buffer, size_t offset, + const DataBuffer& body) const; + + private: + // Reads the length from the record header. + // This also reads the DTLS fragment information and checks it. + bool ReadLength(TlsParser* parser, const RecordHeader& header, + uint32_t* length); + + uint8_t handshake_type_; + uint16_t message_seq_; + // fragment_offset is always zero in these tests. + }; + + protected: + virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) = 0; + + private: +}; + +// Make a copy of the first instance of a handshake message. +class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { + public: + TlsInspectorRecordHandshakeMessage(uint8_t handshake_type) + : handshake_type_(handshake_type), buffer_() {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output); + + const DataBuffer& buffer() const { return buffer_; } + + private: + uint8_t handshake_type_; + DataBuffer buffer_; +}; + +// Replace all instances of a handshake message. +class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { + public: + TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type, + const DataBuffer& replacement) + : handshake_type_(handshake_type), buffer_(replacement) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output); + + private: + uint8_t handshake_type_; + DataBuffer buffer_; +}; + +// Make a copy of the complete conversation. +class TlsConversationRecorder : public TlsRecordFilter { + public: + TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {} + + virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + + private: + DataBuffer& buffer_; +}; + +// Records an alert. If an alert has already been recorded, it won't save the +// new alert unless the old alert is a warning and the new one is fatal. +class TlsAlertRecorder : public TlsRecordFilter { + public: + TlsAlertRecorder() : level_(255), description_(255) {} + + virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + + uint8_t level() const { return level_; } + uint8_t description() const { return description_; } + + private: + uint8_t level_; + uint8_t description_; +}; + +// Runs multiple packet filters in series. +class ChainedPacketFilter : public PacketFilter { + public: + ChainedPacketFilter() {} + ChainedPacketFilter(const std::vector<PacketFilter*> filters) + : filters_(filters.begin(), filters.end()) {} + virtual ~ChainedPacketFilter(); + + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output); + + // Takes ownership of the filter. + void Add(PacketFilter* filter) { filters_.push_back(filter); } + + private: + std::vector<PacketFilter*> filters_; +}; + +class TlsExtensionFilter : public TlsHandshakeFilter { + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + virtual PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) = 0; + + public: + static bool FindClientHelloExtensions(TlsParser* parser, + const Versioned& header); + static bool FindServerHelloExtensions(TlsParser* parser); + + private: + PacketFilter::Action FilterExtensions(TlsParser* parser, + const DataBuffer& input, + DataBuffer* output); +}; + +class TlsExtensionCapture : public TlsExtensionFilter { + public: + TlsExtensionCapture(uint16_t ext, bool last = false) + : extension_(ext), captured_(false), last_(last), data_() {} + + const DataBuffer& extension() const { return data_; } + bool captured() const { return captured_; } + + protected: + PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) override; + + private: + const uint16_t extension_; + bool captured_; + bool last_; + DataBuffer data_; +}; + +class TlsExtensionReplacer : public TlsExtensionFilter { + public: + TlsExtensionReplacer(uint16_t extension, const DataBuffer& data) + : extension_(extension), data_(data) {} + PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer& input, + DataBuffer* output) override; + + private: + const uint16_t extension_; + const DataBuffer data_; +}; + +class TlsExtensionDropper : public TlsExtensionFilter { + public: + TlsExtensionDropper(uint16_t extension) : extension_(extension) {} + PacketFilter::Action FilterExtension(uint16_t extension_type, + const DataBuffer&, DataBuffer*) override; + + private: + uint16_t extension_; +}; + +class TlsAgent; +typedef std::function<void(void)> VoidFunction; + +class AfterRecordN : public TlsRecordFilter { + public: + AfterRecordN(TlsAgent* src, TlsAgent* dest, unsigned int record, + VoidFunction func) + : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {} + + virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + const DataBuffer& body, + DataBuffer* out) override; + + private: + TlsAgent* src_; + TlsAgent* dest_; + unsigned int record_; + VoidFunction func_; + unsigned int counter_; +}; + +// When we see the ClientKeyExchange from |client|, increment the +// ClientHelloVersion on |server|. +class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter { + public: + TlsInspectorClientHelloVersionChanger(TlsAgent* server) : server_(server) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output); + + private: + TlsAgent* server_; +}; + +// This class selectively drops complete writes. This relies on the fact that +// writes in libssl are on record boundaries. +class SelectiveDropFilter : public PacketFilter { + public: + SelectiveDropFilter(uint32_t pattern) : pattern_(pattern), counter_(0) {} + + protected: + virtual PacketFilter::Action Filter(const DataBuffer& input, + DataBuffer* output) override; + + private: + const uint32_t pattern_; + uint8_t counter_; +}; + +// Set the version number in the ClientHello. +class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { + public: + TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output); + + private: + uint16_t version_; +}; + +} // namespace nss_test + +#endif diff --git a/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc new file mode 100644 index 000000000..51ff938b1 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc @@ -0,0 +1,262 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <memory> +#include "nss.h" +#include "pk11pub.h" +#include "tls13hkdf.h" + +#include "databuffer.h" +#include "gtest_utils.h" +#include "scoped_ptrs.h" + +namespace nss_test { + +const uint8_t kKey1Data[] = { + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, + 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f}; +const DataBuffer kKey1(kKey1Data, sizeof(kKey1Data)); + +// The same as key1 but with the first byte +// 0x01. +const uint8_t kKey2Data[] = { + 0x01, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, + 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, + 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f}; +const DataBuffer kKey2(kKey2Data, sizeof(kKey2Data)); + +const char kLabelMasterSecret[] = "master secret"; + +const uint8_t kSessionHash[] = { + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, + 0xfc, 0xfd, 0xfe, 0xff, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, + 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, 0xd0, 0xd1, 0xd2, 0xd3, + 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf, + 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, + 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, + 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0xe0, 0xe1, 0xe2, 0xe3, + 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, + 0xfc, 0xfd, 0xfe, 0xff, +}; + +const size_t kHashLength[] = { + 0, /* ssl_hash_none */ + 16, /* ssl_hash_md5 */ + 20, /* ssl_hash_sha1 */ + 28, /* ssl_hash_sha224 */ + 32, /* ssl_hash_sha256 */ + 48, /* ssl_hash_sha384 */ + 64, /* ssl_hash_sha512 */ +}; + +const std::string kHashName[] = {"None", "MD5", "SHA-1", "SHA-224", + "SHA-256", "SHA-384", "SHA-512"}; + +static void ImportKey(ScopedPK11SymKey* to, const DataBuffer& key, + PK11SlotInfo* slot) { + SECItem key_item = {siBuffer, const_cast<uint8_t*>(key.data()), + static_cast<unsigned int>(key.len())}; + + PK11SymKey* inner = + PK11_ImportSymKey(slot, CKM_SSL3_MASTER_KEY_DERIVE, PK11_OriginUnwrap, + CKA_DERIVE, &key_item, NULL); + ASSERT_NE(nullptr, inner); + to->reset(inner); +} + +static void DumpData(const std::string& label, const uint8_t* buf, size_t len) { + DataBuffer d(buf, len); + + std::cerr << label << ": " << d << std::endl; +} + +void DumpKey(const std::string& label, ScopedPK11SymKey& key) { + SECStatus rv = PK11_ExtractKeyValue(key.get()); + ASSERT_EQ(SECSuccess, rv); + + SECItem* key_data = PK11_GetKeyData(key.get()); + ASSERT_NE(nullptr, key_data); + + DumpData(label, key_data->data, key_data->len); +} + +extern "C" { +extern char ssl_trace; +extern FILE* ssl_trace_iob; +} + +class TlsHkdfTest : public ::testing::Test, + public ::testing::WithParamInterface<SSLHashType> { + public: + TlsHkdfTest() + : k1_(), k2_(), hash_type_(GetParam()), slot_(PK11_GetInternalSlot()) { + EXPECT_NE(nullptr, slot_); + char* ev = getenv("SSLTRACE"); + if (ev && ev[0]) { + ssl_trace = atoi(ev); + ssl_trace_iob = stderr; + } + } + + void SetUp() { + ImportKey(&k1_, kKey1, slot_.get()); + ImportKey(&k2_, kKey2, slot_.get()); + } + + void VerifyKey(const ScopedPK11SymKey& key, const DataBuffer& expected) { + SECStatus rv = PK11_ExtractKeyValue(key.get()); + ASSERT_EQ(SECSuccess, rv); + + SECItem* key_data = PK11_GetKeyData(key.get()); + ASSERT_NE(nullptr, key_data); + + EXPECT_EQ(expected.len(), key_data->len); + EXPECT_EQ(0, memcmp(expected.data(), key_data->data, expected.len())); + } + + void HkdfExtract(const ScopedPK11SymKey& ikmk1, const ScopedPK11SymKey& ikmk2, + SSLHashType base_hash, const DataBuffer& expected) { + std::cerr << "Hash = " << kHashName[base_hash] << std::endl; + + PK11SymKey* prk = nullptr; + SECStatus rv = tls13_HkdfExtract(ikmk1.get(), ikmk2.get(), base_hash, &prk); + ASSERT_EQ(SECSuccess, rv); + ScopedPK11SymKey prkk(prk); + + DumpKey("Output", prkk); + VerifyKey(prkk, expected); + } + + void HkdfExpandLabel(ScopedPK11SymKey* prk, SSLHashType base_hash, + const uint8_t* session_hash, size_t session_hash_len, + const char* label, size_t label_len, + const DataBuffer& expected) { + std::cerr << "Hash = " << kHashName[base_hash] << std::endl; + + std::vector<uint8_t> output(expected.len()); + + SECStatus rv = tls13_HkdfExpandLabelRaw(prk->get(), base_hash, session_hash, + session_hash_len, label, label_len, + &output[0], output.size()); + ASSERT_EQ(SECSuccess, rv); + DumpData("Output", &output[0], output.size()); + EXPECT_EQ(0, memcmp(expected.data(), &output[0], expected.len())); + } + + protected: + ScopedPK11SymKey k1_; + ScopedPK11SymKey k2_; + SSLHashType hash_type_; + + private: + ScopedPK11SlotInfo slot_; +}; + +TEST_P(TlsHkdfTest, HkdfNullNull) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + {0x33, 0xad, 0x0a, 0x1c, 0x60, 0x7e, 0xc0, 0x3b, 0x09, 0xe6, 0xcd, + 0x98, 0x93, 0x68, 0x0c, 0xe2, 0x10, 0xad, 0xf3, 0x00, 0xaa, 0x1f, + 0x26, 0x60, 0xe1, 0xb2, 0x2e, 0x10, 0xf1, 0x70, 0xf9, 0x2a}, + {0x7e, 0xe8, 0x20, 0x6f, 0x55, 0x70, 0x02, 0x3e, 0x6d, 0xc7, 0x51, 0x9e, + 0xb1, 0x07, 0x3b, 0xc4, 0xe7, 0x91, 0xad, 0x37, 0xb5, 0xc3, 0x82, 0xaa, + 0x10, 0xba, 0x18, 0xe2, 0x35, 0x7e, 0x71, 0x69, 0x71, 0xf9, 0x36, 0x2f, + 0x2c, 0x2f, 0xe2, 0xa7, 0x6b, 0xfd, 0x78, 0xdf, 0xec, 0x4e, 0xa9, 0xb5}}; + + const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); + HkdfExtract(nullptr, nullptr, hash_type_, expected_data); +} + +TEST_P(TlsHkdfTest, HkdfKey1Only) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + {0x11, 0x87, 0x38, 0x28, 0xa9, 0x19, 0x78, 0x11, 0x33, 0x91, 0x24, + 0xb5, 0x8a, 0x1b, 0xb0, 0x9f, 0x7f, 0x0d, 0x8d, 0xbb, 0x10, 0xf4, + 0x9c, 0x54, 0xbd, 0x1f, 0xd8, 0x85, 0xcd, 0x15, 0x30, 0x33}, + {0x51, 0xb1, 0xd5, 0xb4, 0x59, 0x79, 0x79, 0x08, 0x4a, 0x15, 0xb2, 0xdb, + 0x84, 0xd3, 0xd6, 0xbc, 0xfc, 0x93, 0x45, 0xd9, 0xdc, 0x74, 0xda, 0x1a, + 0x57, 0xc2, 0x76, 0x9f, 0x3f, 0x83, 0x45, 0x2f, 0xf6, 0xf3, 0x56, 0x1f, + 0x58, 0x63, 0xdb, 0x88, 0xda, 0x40, 0xce, 0x63, 0x7d, 0x24, 0x37, 0xf3}}; + + const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); + HkdfExtract(k1_, nullptr, hash_type_, expected_data); +} + +TEST_P(TlsHkdfTest, HkdfKey2Only) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + { + 0x2f, 0x5f, 0x78, 0xd0, 0xa4, 0xc4, 0x36, 0xee, 0x6c, 0x8a, 0x4e, + 0xf9, 0xd0, 0x43, 0x81, 0x02, 0x13, 0xfd, 0x47, 0x83, 0x63, 0x3a, + 0xd2, 0xe1, 0x40, 0x6d, 0x2d, 0x98, 0x00, 0xfd, 0xc1, 0x87, + }, + {0x7b, 0x40, 0xf9, 0xef, 0x91, 0xff, 0xc9, 0xd1, 0x29, 0x24, 0x5c, 0xbf, + 0xf8, 0x82, 0x76, 0x68, 0xae, 0x4b, 0x63, 0xe8, 0x03, 0xdd, 0x39, 0xa8, + 0xd4, 0x6a, 0xf6, 0xe5, 0xec, 0xea, 0xf8, 0x7d, 0x91, 0x71, 0x81, 0xf1, + 0xdb, 0x3b, 0xaf, 0xbf, 0xde, 0x71, 0x61, 0x15, 0xeb, 0xb5, 0x5f, 0x68}}; + + const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); + HkdfExtract(nullptr, k2_, hash_type_, expected_data); +} + +TEST_P(TlsHkdfTest, HkdfKey1Key2) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + { + 0x79, 0x53, 0xb8, 0xdd, 0x6b, 0x98, 0xce, 0x00, 0xb7, 0xdc, 0xe8, + 0x03, 0x70, 0x8c, 0xe3, 0xac, 0x06, 0x8b, 0x22, 0xfd, 0x0e, 0x34, + 0x48, 0xe6, 0xe5, 0xe0, 0x8a, 0xd6, 0x16, 0x18, 0xe5, 0x48, + }, + {0x01, 0x93, 0xc0, 0x07, 0x3f, 0x6a, 0x83, 0x0e, 0x2e, 0x4f, 0xb2, 0x58, + 0xe4, 0x00, 0x08, 0x5c, 0x68, 0x9c, 0x37, 0x32, 0x00, 0x37, 0xff, 0xc3, + 0x1c, 0x5b, 0x98, 0x0b, 0x02, 0x92, 0x3f, 0xfd, 0x73, 0x5a, 0x6f, 0x2a, + 0x95, 0xa3, 0xee, 0xf6, 0xd6, 0x8e, 0x6f, 0x86, 0xea, 0x63, 0xf8, 0x33}}; + + const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); + HkdfExtract(k1_, k2_, hash_type_, expected_data); +} + +TEST_P(TlsHkdfTest, HkdfExpandLabel) { + const uint8_t tv[][48] = { + {/* ssl_hash_none */}, + {/* ssl_hash_md5 */}, + {/* ssl_hash_sha1 */}, + {/* ssl_hash_sha224 */}, + {0x34, 0x7c, 0x67, 0x80, 0xff, 0x0b, 0xba, 0xd7, 0x1c, 0x28, 0x3b, + 0x16, 0xeb, 0x2f, 0x9c, 0xf6, 0x2d, 0x24, 0xe6, 0xcd, 0xb6, 0x13, + 0xd5, 0x17, 0x76, 0x54, 0x8c, 0xb0, 0x7d, 0xcd, 0xe7, 0x4c}, + {0x4b, 0x1e, 0x5e, 0xc1, 0x49, 0x30, 0x78, 0xea, 0x35, 0xbd, 0x3f, 0x01, + 0x04, 0xe6, 0x1a, 0xea, 0x14, 0xcc, 0x18, 0x2a, 0xd1, 0xc4, 0x76, 0x21, + 0xc4, 0x64, 0xc0, 0x4e, 0x4b, 0x36, 0x16, 0x05, 0x6f, 0x04, 0xab, 0xe9, + 0x43, 0xb1, 0x2d, 0xa8, 0xa7, 0x17, 0x9a, 0x5f, 0x09, 0x91, 0x7d, 0x1f}}; + + const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]); + HkdfExpandLabel(&k1_, hash_type_, kSessionHash, kHashLength[hash_type_], + kLabelMasterSecret, strlen(kLabelMasterSecret), + expected_data); +} + +static const SSLHashType kHashTypes[] = {ssl_hash_sha256, ssl_hash_sha384}; +INSTANTIATE_TEST_CASE_P(AllHashFuncs, TlsHkdfTest, + ::testing::ValuesIn(kHashTypes)); + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_parser.cc b/security/nss/gtests/ssl_gtest/tls_parser.cc new file mode 100644 index 000000000..e4c06aa91 --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_parser.cc @@ -0,0 +1,73 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "tls_parser.h" + +namespace nss_test { + +bool TlsParser::Read(uint8_t* val) { + if (remaining() < 1) { + return false; + } + *val = *ptr(); + consume(1); + return true; +} + +bool TlsParser::Read(uint32_t* val, size_t size) { + if (size > sizeof(uint32_t)) { + return false; + } + + uint32_t v = 0; + for (size_t i = 0; i < size; ++i) { + uint8_t tmp; + if (!Read(&tmp)) { + return false; + } + + v = (v << 8) | tmp; + } + + *val = v; + return true; +} + +bool TlsParser::Read(DataBuffer* val, size_t len) { + if (remaining() < len) { + return false; + } + + val->Assign(ptr(), len); + consume(len); + return true; +} + +bool TlsParser::ReadVariable(DataBuffer* val, size_t len_size) { + uint32_t len; + if (!Read(&len, len_size)) { + return false; + } + return Read(val, len); +} + +bool TlsParser::Skip(size_t len) { + if (len > remaining()) { + return false; + } + consume(len); + return true; +} + +bool TlsParser::SkipVariable(size_t len_size) { + uint32_t len; + if (!Read(&len, len_size)) { + return false; + } + return Skip(len); +} + +} // namespace nss_test diff --git a/security/nss/gtests/ssl_gtest/tls_parser.h b/security/nss/gtests/ssl_gtest/tls_parser.h new file mode 100644 index 000000000..c79d45a7e --- /dev/null +++ b/security/nss/gtests/ssl_gtest/tls_parser.h @@ -0,0 +1,131 @@ +/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this file, + * You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef tls_parser_h_ +#define tls_parser_h_ + +#include <cstdint> +#include <cstring> +#include <memory> +#if defined(WIN32) || defined(WIN64) +#include <winsock2.h> +#else +#include <arpa/inet.h> +#endif +#include "databuffer.h" + +namespace nss_test { + +const uint8_t kTlsChangeCipherSpecType = 20; +const uint8_t kTlsAlertType = 21; +const uint8_t kTlsHandshakeType = 22; +const uint8_t kTlsApplicationDataType = 23; + +const uint8_t kTlsHandshakeClientHello = 1; +const uint8_t kTlsHandshakeServerHello = 2; +const uint8_t kTlsHandshakeHelloRetryRequest = 6; +const uint8_t kTlsHandshakeEncryptedExtensions = 8; +const uint8_t kTlsHandshakeCertificate = 11; +const uint8_t kTlsHandshakeServerKeyExchange = 12; +const uint8_t kTlsHandshakeCertificateVerify = 15; +const uint8_t kTlsHandshakeClientKeyExchange = 16; +const uint8_t kTlsHandshakeFinished = 20; + +const uint8_t kTlsAlertWarning = 1; +const uint8_t kTlsAlertFatal = 2; + +const uint8_t kTlsAlertUnexpectedMessage = 10; +const uint8_t kTlsAlertBadRecordMac = 20; +const uint8_t kTlsAlertHandshakeFailure = 40; +const uint8_t kTlsAlertIllegalParameter = 47; +const uint8_t kTlsAlertDecodeError = 50; +const uint8_t kTlsAlertDecryptError = 51; +const uint8_t kTlsAlertMissingExtension = 109; +const uint8_t kTlsAlertUnsupportedExtension = 110; +const uint8_t kTlsAlertUnrecognizedName = 112; +const uint8_t kTlsAlertNoApplicationProtocol = 120; + +const uint8_t kTlsFakeChangeCipherSpec[] = { + kTlsChangeCipherSpecType, // Type + 0xfe, + 0xff, // Version + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x10, // Fictitious sequence # + 0x00, + 0x01, // Length + 0x01 // Value +}; + +static const uint8_t kTls13PskKe = 0; +static const uint8_t kTls13PskDhKe = 1; +static const uint8_t kTls13PskAuth = 0; +static const uint8_t kTls13PskSignAuth = 1; + +inline bool IsDtls(uint16_t version) { return (version & 0x8000) == 0x8000; } + +inline uint16_t NormalizeTlsVersion(uint16_t version) { + if (version == 0xfeff) { + return 0x0302; // special: DTLS 1.0 == TLS 1.1 + } + if (IsDtls(version)) { + return (version ^ 0xffff) + 0x0201; + } + return version; +} + +inline uint16_t TlsVersionToDtlsVersion(uint16_t version) { + if (version == 0x0302) { + return 0xfeff; + } + if (version == 0x0304) { + return version; + } + return 0xffff - version + 0x0201; +} + +inline size_t WriteVariable(DataBuffer* target, size_t index, + const DataBuffer& buf, size_t len_size) { + index = target->Write(index, static_cast<uint32_t>(buf.len()), len_size); + return target->Write(index, buf.data(), buf.len()); +} + +class TlsParser { + public: + TlsParser(const uint8_t* data, size_t len) : buffer_(data, len), offset_(0) {} + explicit TlsParser(const DataBuffer& buf) : buffer_(buf), offset_(0) {} + + bool Read(uint8_t* val); + // Read an integral type of specified width. + bool Read(uint32_t* val, size_t size); + // Reads len bytes into dest buffer, overwriting it. + bool Read(DataBuffer* dest, size_t len); + // Reads bytes into dest buffer, overwriting it. The number of bytes is + // determined by reading from len_size bytes from the stream first. + bool ReadVariable(DataBuffer* dest, size_t len_size); + + bool Skip(size_t len); + bool SkipVariable(size_t len_size); + + size_t consumed() const { return offset_; } + size_t remaining() const { return buffer_.len() - offset_; } + + private: + void consume(size_t len) { offset_ += len; } + const uint8_t* ptr() const { return buffer_.data() + offset_; } + + DataBuffer buffer_; + size_t offset_; +}; + +} // namespace nss_test + +#endif |