diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc | 89 |
1 files changed, 88 insertions, 1 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc index 87b8e4ace..fb995953f 100644 --- a/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_extension_unittest.cc @@ -1,4 +1,5 @@ /* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* vim: set ts=2 et sw=2 tw=80: */ /* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at http://mozilla.org/MPL/2.0/. */ @@ -19,6 +20,45 @@ namespace nss_test { +class Dtls13LegacyCookieInjector : public TlsHandshakeFilter { + public: + Dtls13LegacyCookieInjector(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}) {} + + virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) { + const uint8_t cookie_bytes[] = {0x03, 0x0A, 0x0B, 0x0C}; + uint32_t offset = 2 /* version */ + 32 /* random */; + + if (agent()->variant() != ssl_variant_datagram) { + ADD_FAILURE(); + return KEEP; + } + + if (header.handshake_type() != ssl_hs_client_hello) { + return KEEP; + } + + DataBuffer cookie(cookie_bytes, sizeof(cookie_bytes)); + *output = input; + + // Add the SID length (if any) to locate the cookie. + uint32_t sid_len = 0; + if (!output->Read(offset, 1, &sid_len)) { + ADD_FAILURE(); + return KEEP; + } + offset += 1 + sid_len; + output->Splice(cookie, offset, 1); + + return CHANGE; + } + + private: + DataBuffer cookie_; +}; + class TlsExtensionTruncator : public TlsExtensionFilter { public: TlsExtensionTruncator(const std::shared_ptr<TlsAgent>& a, uint16_t extension, @@ -188,8 +228,27 @@ class TlsExtensionTest13 } void ConnectWithReplacementVersionList(uint16_t version) { - DataBuffer versions_buf; + // Convert the version encoding for DTLS, if needed. + if (variant_ == ssl_variant_datagram) { + switch (version) { +#ifdef DTLS_1_3_DRAFT_VERSION + case SSL_LIBRARY_VERSION_TLS_1_3: + version = 0x7f00 | DTLS_1_3_DRAFT_VERSION; + break; +#endif + case SSL_LIBRARY_VERSION_TLS_1_2: + version = SSL_LIBRARY_VERSION_DTLS_1_2_WIRE; + break; + case SSL_LIBRARY_VERSION_TLS_1_1: + /* TLS_1_1 maps to DTLS_1_0, see sslproto.h. */ + version = SSL_LIBRARY_VERSION_DTLS_1_0_WIRE; + break; + default: + PORT_Assert(0); + } + } + DataBuffer versions_buf; size_t index = versions_buf.Write(0, 2, 1); versions_buf.Write(index, version, 2); MakeTlsFilter<TlsExtensionReplacer>( @@ -887,6 +946,26 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) { server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); } +// Do the same with an External PSK. +TEST_P(TlsConnectTls13, TestTls13PskInvalidBinderValue) { + ScopedPK11SlotInfo slot(PK11_GetInternalSlot()); + ASSERT_TRUE(!!slot); + ScopedPK11SymKey key( + PK11_KeyGen(slot.get(), CKM_HKDF_KEY_GEN, nullptr, 16, nullptr)); + ASSERT_TRUE(!!key); + AddPsk(key, std::string("foo"), ssl_hash_sha256); + StartConnect(); + ASSERT_TRUE(client_->MaybeSetResumptionToken()); + + MakeTlsFilter<TlsPreSharedKeyReplacer>( + client_, [](TlsPreSharedKeyReplacer* r) { + r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1); + }); + ConnectExpectAlert(server_, kTlsAlertDecryptError); + client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT); + server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE); +} + // Extend the binder by one. TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) { SetupForResume(); @@ -1226,6 +1305,14 @@ TEST_P(TlsConnectStream, IncludePadding) { EXPECT_TRUE(capture->captured()); } +TEST_F(TlsConnectDatagram13, Dtls13RejectLegacyCookie) { + EnsureTlsSetup(); + MakeTlsFilter<Dtls13LegacyCookieInjector>(client_); + ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO); + client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT); +} + INSTANTIATE_TEST_CASE_P( ExtensionStream, TlsExtensionTestGeneric, ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, |