summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_connect.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_connect.cc708
1 files changed, 708 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc
new file mode 100644
index 000000000..d02549954
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_connect.cc
@@ -0,0 +1,708 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_connect.h"
+extern "C" {
+#include "libssl_internals.h"
+}
+
+#include <iostream>
+
+#include "databuffer.h"
+#include "gtest_utils.h"
+#include "sslproto.h"
+
+extern std::string g_working_dir_path;
+
+namespace nss_test {
+
+static const std::string kTlsModesStreamArr[] = {"TLS"};
+::testing::internal::ParamGenerator<std::string>
+ TlsConnectTestBase::kTlsModesStream =
+ ::testing::ValuesIn(kTlsModesStreamArr);
+static const std::string kTlsModesDatagramArr[] = {"DTLS"};
+::testing::internal::ParamGenerator<std::string>
+ TlsConnectTestBase::kTlsModesDatagram =
+ ::testing::ValuesIn(kTlsModesDatagramArr);
+static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"};
+::testing::internal::ParamGenerator<std::string>
+ TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr);
+
+static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 =
+ ::testing::ValuesIn(kTlsV10Arr);
+static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11 =
+ ::testing::ValuesIn(kTlsV11Arr);
+static const uint16_t kTlsV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12 =
+ ::testing::ValuesIn(kTlsV12Arr);
+static const uint16_t kTlsV10V11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_1};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10V11 =
+ ::testing::ValuesIn(kTlsV10V11Arr);
+static const uint16_t kTlsV10ToV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10ToV12 =
+ ::testing::ValuesIn(kTlsV10ToV12Arr);
+static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11V12 =
+ ::testing::ValuesIn(kTlsV11V12Arr);
+
+static const uint16_t kTlsV11PlusArr[] = {
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11Plus =
+ ::testing::ValuesIn(kTlsV11PlusArr);
+static const uint16_t kTlsV12PlusArr[] = {
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12Plus =
+ ::testing::ValuesIn(kTlsV12PlusArr);
+static const uint16_t kTlsV13Arr[] = {SSL_LIBRARY_VERSION_TLS_1_3};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV13 =
+ ::testing::ValuesIn(kTlsV13Arr);
+static const uint16_t kTlsVAllArr[] = {
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_0};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsVAll =
+ ::testing::ValuesIn(kTlsVAllArr);
+
+std::string VersionString(uint16_t version) {
+ switch (version) {
+ case 0:
+ return "(no version)";
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ return "1.0";
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ return "1.1";
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ return "1.2";
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+ return "1.3";
+ default:
+ std::cerr << "Invalid version: " << version << std::endl;
+ EXPECT_TRUE(false);
+ return "";
+ }
+}
+
+TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version)
+ : mode_(mode),
+ client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_)),
+ server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_)),
+ client_model_(nullptr),
+ server_model_(nullptr),
+ version_(version),
+ expected_resumption_mode_(RESUME_NONE),
+ session_ids_(),
+ expect_extended_master_secret_(false),
+ expect_early_data_accepted_(false) {
+ std::string v;
+ if (mode_ == DGRAM && version_ == SSL_LIBRARY_VERSION_TLS_1_1) {
+ v = "1.0";
+ } else {
+ v = VersionString(version_);
+ }
+ std::cerr << "Version: " << mode_ << " " << v << std::endl;
+}
+
+TlsConnectTestBase::TlsConnectTestBase(const std::string& mode,
+ uint16_t version)
+ : TlsConnectTestBase(TlsConnectTestBase::ToMode(mode), version) {}
+
+TlsConnectTestBase::~TlsConnectTestBase() {}
+
+// Check the group of each of the supported groups
+void TlsConnectTestBase::CheckGroups(
+ const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group) {
+ DuplicateGroupChecker group_set;
+ uint32_t tmp = 0;
+ EXPECT_TRUE(groups.Read(0, 2, &tmp));
+ EXPECT_EQ(groups.len() - 2, static_cast<size_t>(tmp));
+ for (size_t i = 2; i < groups.len(); i += 2) {
+ EXPECT_TRUE(groups.Read(i, 2, &tmp));
+ SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
+ group_set.AddAndCheckGroup(group);
+ check_group(group);
+ }
+}
+
+// Check the group of each of the shares
+void TlsConnectTestBase::CheckShares(
+ const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group) {
+ DuplicateGroupChecker group_set;
+ uint32_t tmp = 0;
+ EXPECT_TRUE(shares.Read(0, 2, &tmp));
+ EXPECT_EQ(shares.len() - 2, static_cast<size_t>(tmp));
+ size_t i;
+ for (i = 2; i < shares.len(); i += 4 + tmp) {
+ ASSERT_TRUE(shares.Read(i, 2, &tmp));
+ SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
+ group_set.AddAndCheckGroup(group);
+ check_group(group);
+ ASSERT_TRUE(shares.Read(i + 2, 2, &tmp));
+ }
+ EXPECT_EQ(shares.len(), i);
+}
+
+void TlsConnectTestBase::ClearStats() {
+ // Clear statistics.
+ SSL3Statistics* stats = SSL_GetStatistics();
+ memset(stats, 0, sizeof(*stats));
+}
+
+void TlsConnectTestBase::ClearServerCache() {
+ SSL_ShutdownServerSessionIDCache();
+ SSLInt_ClearSessionTicketKey();
+ SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
+}
+
+void TlsConnectTestBase::SetUp() {
+ SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
+ SSLInt_ClearSessionTicketKey();
+ ClearStats();
+ Init();
+}
+
+void TlsConnectTestBase::TearDown() {
+ delete client_;
+ delete server_;
+ if (client_model_) {
+ ASSERT_NE(server_model_, nullptr);
+ delete client_model_;
+ delete server_model_;
+ }
+
+ SSL_ClearSessionCache();
+ SSLInt_ClearSessionTicketKey();
+ SSL_ShutdownServerSessionIDCache();
+}
+
+void TlsConnectTestBase::Init() {
+ EXPECT_TRUE(client_->Init());
+ EXPECT_TRUE(server_->Init());
+
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+
+ if (version_) {
+ ConfigureVersion(version_);
+ }
+}
+
+void TlsConnectTestBase::Reset() {
+ // Take a copy of the names because they are about to disappear.
+ std::string server_name = server_->name();
+ std::string client_name = client_->name();
+ Reset(server_name, client_name);
+}
+
+void TlsConnectTestBase::Reset(const std::string& server_name,
+ const std::string& client_name) {
+ delete client_;
+ delete server_;
+
+ client_ = new TlsAgent(client_name, TlsAgent::CLIENT, mode_);
+ server_ = new TlsAgent(server_name, TlsAgent::SERVER, mode_);
+
+ Init();
+}
+
+void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) {
+ expected_resumption_mode_ = expected;
+ if (expected != RESUME_NONE) {
+ client_->ExpectResumption();
+ server_->ExpectResumption();
+ }
+}
+
+void TlsConnectTestBase::EnsureTlsSetup() {
+ EXPECT_TRUE(server_->EnsureTlsSetup(server_model_ ? server_model_->ssl_fd()
+ : nullptr));
+ EXPECT_TRUE(client_->EnsureTlsSetup(client_model_ ? client_model_->ssl_fd()
+ : nullptr));
+}
+
+void TlsConnectTestBase::Handshake() {
+ EnsureTlsSetup();
+ client_->SetServerKeyBits(server_->server_key_bits());
+ client_->Handshake();
+ server_->Handshake();
+
+ ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) &&
+ (server_->state() != TlsAgent::STATE_CONNECTING),
+ 5000);
+}
+
+void TlsConnectTestBase::EnableExtendedMasterSecret() {
+ client_->EnableExtendedMasterSecret();
+ server_->EnableExtendedMasterSecret();
+ ExpectExtendedMasterSecret(true);
+}
+
+void TlsConnectTestBase::Connect() {
+ server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr);
+ client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr);
+ Handshake();
+ CheckConnected();
+}
+
+void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
+ EnsureTlsSetup();
+ client_->EnableSingleCipher(cipher_suite);
+
+ Connect();
+ SendReceive();
+
+ // Check that we used the right cipher suite.
+ uint16_t actual;
+ EXPECT_TRUE(client_->cipher_suite(&actual));
+ EXPECT_EQ(cipher_suite, actual);
+ EXPECT_TRUE(server_->cipher_suite(&actual));
+ EXPECT_EQ(cipher_suite, actual);
+}
+
+void TlsConnectTestBase::CheckConnected() {
+ // Check the version is as expected
+ EXPECT_EQ(client_->version(), server_->version());
+ EXPECT_EQ(std::min(client_->max_version(), server_->max_version()),
+ client_->version());
+
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ uint16_t cipher_suite1, cipher_suite2;
+ bool ret = client_->cipher_suite(&cipher_suite1);
+ EXPECT_TRUE(ret);
+ ret = server_->cipher_suite(&cipher_suite2);
+ EXPECT_TRUE(ret);
+ EXPECT_EQ(cipher_suite1, cipher_suite2);
+
+ std::cerr << "Connected with version " << client_->version()
+ << " cipher suite " << client_->cipher_suite_name() << std::endl;
+
+ if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
+ // Check and store session ids.
+ std::vector<uint8_t> sid_c1 = client_->session_id();
+ EXPECT_EQ(32U, sid_c1.size());
+ std::vector<uint8_t> sid_s1 = server_->session_id();
+ EXPECT_EQ(32U, sid_s1.size());
+ EXPECT_EQ(sid_c1, sid_s1);
+ session_ids_.push_back(sid_c1);
+ }
+
+ CheckExtendedMasterSecret();
+ CheckEarlyDataAccepted();
+ CheckResumption(expected_resumption_mode_);
+ client_->CheckSecretsDestroyed();
+ server_->CheckSecretsDestroyed();
+}
+
+void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) const {
+ client_->CheckKEA(kea_type, kea_group);
+ server_->CheckKEA(kea_type, kea_group);
+ client_->CheckAuthType(auth_type, sig_scheme);
+ server_->CheckAuthType(auth_type, sig_scheme);
+}
+
+void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
+ SSLAuthType auth_type) const {
+ SSLNamedGroup group;
+ switch (kea_type) {
+ case ssl_kea_ecdh:
+ group = ssl_grp_ec_curve25519;
+ break;
+ case ssl_kea_dh:
+ group = ssl_grp_ffdhe_2048;
+ break;
+ case ssl_kea_rsa:
+ group = ssl_grp_none;
+ break;
+ default:
+ EXPECT_TRUE(false) << "unexpected KEA";
+ group = ssl_grp_none;
+ break;
+ }
+
+ SSLSignatureScheme scheme;
+ switch (auth_type) {
+ case ssl_auth_rsa_decrypt:
+ scheme = ssl_sig_none;
+ break;
+ case ssl_auth_rsa_sign:
+ scheme = ssl_sig_rsa_pss_sha256;
+ break;
+ case ssl_auth_ecdsa:
+ scheme = ssl_sig_ecdsa_secp256r1_sha256;
+ break;
+ case ssl_auth_dsa:
+ scheme = ssl_sig_dsa_sha1;
+ break;
+ default:
+ EXPECT_TRUE(false) << "unexpected auth type";
+ scheme = static_cast<SSLSignatureScheme>(0x0100);
+ break;
+ }
+ CheckKeys(kea_type, group, auth_type, scheme);
+}
+
+void TlsConnectTestBase::CheckKeys() const {
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+}
+
+void TlsConnectTestBase::ConnectExpectFail() {
+ server_->StartConnect();
+ client_->StartConnect();
+ Handshake();
+ ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
+}
+
+void TlsConnectTestBase::ConfigureVersion(uint16_t version) {
+ client_->SetVersionRange(version, version);
+ server_->SetVersionRange(version, version);
+}
+
+void TlsConnectTestBase::SetExpectedVersion(uint16_t version) {
+ client_->SetExpectedVersion(version);
+ server_->SetExpectedVersion(version);
+}
+
+void TlsConnectTestBase::DisableAllCiphers() {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ server_->DisableAllCiphers();
+}
+
+void TlsConnectTestBase::EnableOnlyStaticRsaCiphers() {
+ DisableAllCiphers();
+
+ client_->EnableCiphersByKeyExchange(ssl_kea_rsa);
+ server_->EnableCiphersByKeyExchange(ssl_kea_rsa);
+}
+
+void TlsConnectTestBase::EnableOnlyDheCiphers() {
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_dh);
+ server_->EnableCiphersByKeyExchange(ssl_kea_dh);
+ } else {
+ client_->ConfigNamedGroups(kFFDHEGroups);
+ server_->ConfigNamedGroups(kFFDHEGroups);
+ }
+}
+
+void TlsConnectTestBase::EnableSomeEcdhCiphers() {
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ client_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
+ client_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
+ server_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
+ server_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
+ } else {
+ client_->ConfigNamedGroups(kECDHEGroups);
+ server_->ConfigNamedGroups(kECDHEGroups);
+ }
+}
+
+void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
+ SessionResumptionMode server) {
+ client_->ConfigureSessionCache(client);
+ server_->ConfigureSessionCache(server);
+ if ((server & RESUME_TICKET) != 0) {
+ // This is an abomination. NSS encrypts session tickets with the server's
+ // RSA public key. That means we need the server to have an RSA certificate
+ // even if it won't be used for the connection.
+ server_->ConfigServerCert(TlsAgent::kServerRsaDecrypt);
+ }
+}
+
+void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
+ EXPECT_NE(RESUME_BOTH, expected);
+
+ int resume_count = expected ? 1 : 0;
+ int stateless_count = (expected & RESUME_TICKET) ? 1 : 0;
+
+ // Note: hch == server counter; hsh == client counter.
+ SSL3Statistics* stats = SSL_GetStatistics();
+ EXPECT_EQ(resume_count, stats->hch_sid_cache_hits);
+ EXPECT_EQ(resume_count, stats->hsh_sid_cache_hits);
+
+ EXPECT_EQ(stateless_count, stats->hch_sid_stateless_resumes);
+ EXPECT_EQ(stateless_count, stats->hsh_sid_stateless_resumes);
+
+ if (expected != RESUME_NONE) {
+ if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
+ // Check that the last two session ids match.
+ ASSERT_EQ(2U, session_ids_.size());
+ EXPECT_EQ(session_ids_[session_ids_.size() - 1],
+ session_ids_[session_ids_.size() - 2]);
+ } else {
+ // TLS 1.3 only uses tickets.
+ EXPECT_TRUE(expected & RESUME_TICKET);
+ }
+ }
+}
+
+void TlsConnectTestBase::EnableAlpn() {
+ client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+ server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+}
+
+void TlsConnectTestBase::EnableAlpn(const uint8_t* val, size_t len) {
+ client_->EnableAlpn(val, len);
+ server_->EnableAlpn(val, len);
+}
+
+void TlsConnectTestBase::EnsureModelSockets() {
+ // Make sure models agents are available.
+ if (!client_model_) {
+ ASSERT_EQ(server_model_, nullptr);
+ client_model_ = new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_);
+ server_model_ = new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_);
+ }
+
+ // Initialise agents.
+ ASSERT_TRUE(client_model_->Init());
+ ASSERT_TRUE(server_model_->Init());
+}
+
+void TlsConnectTestBase::CheckAlpn(const std::string& val) {
+ client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, val);
+ server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, val);
+}
+
+void TlsConnectTestBase::EnableSrtp() {
+ client_->EnableSrtp();
+ server_->EnableSrtp();
+}
+
+void TlsConnectTestBase::CheckSrtp() const {
+ client_->CheckSrtp();
+ server_->CheckSrtp();
+}
+
+void TlsConnectTestBase::SendReceive() {
+ client_->SendData(50);
+ server_->SendData(50);
+ Receive(50);
+}
+
+// Do a first connection so we can do 0-RTT on the second one.
+void TlsConnectTestBase::SetupForZeroRtt() {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT.
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->StartConnect();
+ client_->StartConnect();
+}
+
+// Do a first connection so we can do resumption
+void TlsConnectTestBase::SetupForResume() {
+ EnsureTlsSetup();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+}
+
+void TlsConnectTestBase::ZeroRttSendReceive(
+ bool expect_writable, bool expect_readable,
+ std::function<bool()> post_clienthello_check) {
+ const char* k0RttData = "ABCDEF";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+
+ client_->Handshake(); // Send ClientHello.
+ if (post_clienthello_check) {
+ if (!post_clienthello_check()) return;
+ }
+ PRInt32 rv =
+ PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write.
+ if (expect_writable) {
+ EXPECT_EQ(k0RttDataLen, rv);
+ } else {
+ EXPECT_EQ(SECFailure, rv);
+ }
+ server_->Handshake(); // Consume ClientHello, EE, Finished.
+
+ std::vector<uint8_t> buf(k0RttDataLen);
+ rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read
+ if (expect_readable) {
+ std::cerr << "0-RTT read " << rv << " bytes\n";
+ EXPECT_EQ(k0RttDataLen, rv);
+ } else {
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+ }
+
+ // Do a second read. this should fail.
+ rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
+void TlsConnectTestBase::Receive(size_t amount) {
+ WAIT_(client_->received_bytes() == amount &&
+ server_->received_bytes() == amount,
+ 2000);
+ ASSERT_EQ(amount, client_->received_bytes());
+ ASSERT_EQ(amount, server_->received_bytes());
+}
+
+void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) {
+ expect_extended_master_secret_ = expected;
+}
+
+void TlsConnectTestBase::CheckExtendedMasterSecret() {
+ client_->CheckExtendedMasterSecret(expect_extended_master_secret_);
+ server_->CheckExtendedMasterSecret(expect_extended_master_secret_);
+}
+
+void TlsConnectTestBase::ExpectEarlyDataAccepted(bool expected) {
+ expect_early_data_accepted_ = expected;
+}
+
+void TlsConnectTestBase::CheckEarlyDataAccepted() {
+ client_->CheckEarlyDataAccepted(expect_early_data_accepted_);
+ server_->CheckEarlyDataAccepted(expect_early_data_accepted_);
+}
+
+void TlsConnectTestBase::DisableECDHEServerKeyReuse() {
+ server_->DisableECDHEServerKeyReuse();
+}
+
+TlsConnectGeneric::TlsConnectGeneric()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+TlsConnectPre12::TlsConnectPre12()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+TlsConnectTls12::TlsConnectTls12()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_2) {}
+
+TlsConnectTls12Plus::TlsConnectTls12Plus()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+TlsConnectTls13::TlsConnectTls13()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+void TlsKeyExchangeTest::EnsureKeyShareSetup() {
+ EnsureTlsSetup();
+ groups_capture_ = new TlsExtensionCapture(ssl_supported_groups_xtn);
+ shares_capture_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn);
+ shares_capture2_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn, true);
+ std::vector<PacketFilter*> captures;
+ captures.push_back(groups_capture_);
+ captures.push_back(shares_capture_);
+ captures.push_back(shares_capture2_);
+ client_->SetPacketFilter(new ChainedPacketFilter(captures));
+ capture_hrr_ =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest);
+ server_->SetPacketFilter(capture_hrr_);
+}
+
+void TlsKeyExchangeTest::ConfigNamedGroups(
+ const std::vector<SSLNamedGroup>& groups) {
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+}
+
+std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
+ const DataBuffer& ext) {
+ uint32_t tmp = 0;
+ EXPECT_TRUE(ext.Read(0, 2, &tmp));
+ EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
+ EXPECT_TRUE(ext.len() % 2 == 0);
+ std::vector<SSLNamedGroup> groups;
+ for (size_t i = 1; i < ext.len() / 2; i += 1) {
+ EXPECT_TRUE(ext.Read(2 * i, 2, &tmp));
+ groups.push_back(static_cast<SSLNamedGroup>(tmp));
+ }
+ return groups;
+}
+
+std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
+ const DataBuffer& ext) {
+ uint32_t tmp = 0;
+ EXPECT_TRUE(ext.Read(0, 2, &tmp));
+ EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
+ std::vector<SSLNamedGroup> shares;
+ size_t i = 2;
+ while (i < ext.len()) {
+ EXPECT_TRUE(ext.Read(i, 2, &tmp));
+ shares.push_back(static_cast<SSLNamedGroup>(tmp));
+ EXPECT_TRUE(ext.Read(i + 2, 2, &tmp));
+ i += 4 + tmp;
+ }
+ EXPECT_EQ(ext.len(), i);
+ return shares;
+}
+
+void TlsKeyExchangeTest::CheckKEXDetails(
+ const std::vector<SSLNamedGroup>& expected_groups,
+ const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) {
+ std::vector<SSLNamedGroup> groups =
+ GetGroupDetails(groups_capture_->extension());
+ EXPECT_EQ(expected_groups, groups);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ASSERT_LT(0U, expected_shares.size());
+ std::vector<SSLNamedGroup> shares =
+ GetShareDetails(shares_capture_->extension());
+ EXPECT_EQ(expected_shares, shares);
+ } else {
+ EXPECT_EQ(0U, shares_capture_->extension().len());
+ }
+
+ EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0);
+}
+
+void TlsKeyExchangeTest::CheckKEXDetails(
+ const std::vector<SSLNamedGroup>& expected_groups,
+ const std::vector<SSLNamedGroup>& expected_shares) {
+ CheckKEXDetails(expected_groups, expected_shares, false);
+}
+
+void TlsKeyExchangeTest::CheckKEXDetails(
+ const std::vector<SSLNamedGroup>& expected_groups,
+ const std::vector<SSLNamedGroup>& expected_shares,
+ SSLNamedGroup expected_share2) {
+ CheckKEXDetails(expected_groups, expected_shares, true);
+
+ for (auto it : expected_shares) {
+ EXPECT_NE(expected_share2, it);
+ }
+ std::vector<SSLNamedGroup> expected_shares2 = {expected_share2};
+ std::vector<SSLNamedGroup> shares =
+ GetShareDetails(shares_capture2_->extension());
+ EXPECT_EQ(expected_shares2, shares);
+}
+} // namespace nss_test