/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "tls_connect.h"
extern "C" {
#include "libssl_internals.h"
}

#include <iostream>

#include "databuffer.h"
#include "gtest_utils.h"
#include "scoped_ptrs.h"
#include "sslproto.h"

extern std::string g_working_dir_path;

namespace nss_test {

// Parameter generators for the transport-variant axis of parameterized tests:
// stream (TLS) only, datagram (DTLS) only, or both.
static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsStream =
        ::testing::ValuesIn(kTlsVariantsStreamArr);
static const SSLProtocolVariant kTlsVariantsDatagramArr[] = {
    ssl_variant_datagram};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsDatagram =
        ::testing::ValuesIn(kTlsVariantsDatagramArr);
static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream,
                                                        ssl_variant_datagram};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsAll =
        ::testing::ValuesIn(kTlsVariantsAllArr);

// Parameter generators for the protocol-version axis. Each backing array lists
// the SSL_LIBRARY_VERSION_* values a test is instantiated with.
static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 =
    ::testing::ValuesIn(kTlsV10Arr);
static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11 =
    ::testing::ValuesIn(kTlsV11Arr);
static const uint16_t kTlsV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12 =
    ::testing::ValuesIn(kTlsV12Arr);
static const uint16_t kTlsV10V11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
                                         SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10V11 =
    ::testing::ValuesIn(kTlsV10V11Arr);
static const uint16_t kTlsV10ToV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
                                           SSL_LIBRARY_VERSION_TLS_1_1,
                                           SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10ToV12 =
    ::testing::ValuesIn(kTlsV10ToV12Arr);
static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1,
                                         SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11V12 =
    ::testing::ValuesIn(kTlsV11V12Arr);

// Open-ended version ranges, listed newest version first. TLS 1.3 is left out
// of these lists when NSS_DISABLE_TLS_1_3 is defined.
static const uint16_t kTlsV11PlusArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11Plus =
    ::testing::ValuesIn(kTlsV11PlusArr);
static const uint16_t kTlsV12PlusArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12Plus =
    ::testing::ValuesIn(kTlsV12PlusArr);
// NOTE(review): unlike the neighboring lists, this array is not guarded by
// NSS_DISABLE_TLS_1_3; presumably tests using kTlsV13 are not built in that
// configuration — confirm.
static const uint16_t kTlsV13Arr[] = {SSL_LIBRARY_VERSION_TLS_1_3};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV13 =
    ::testing::ValuesIn(kTlsV13Arr);
static const uint16_t kTlsVAllArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1,
    SSL_LIBRARY_VERSION_TLS_1_0};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsVAll =
    ::testing::ValuesIn(kTlsVAllArr);

// Map a TLS version constant to a short label ("1.0" through "1.3").
// Zero maps to "(no version)". Any other value is logged to stderr, recorded
// as a test failure, and yields an empty string.
std::string VersionString(uint16_t version) {
  static const struct {
    uint16_t value;
    const char* label;
  } kLabels[] = {
      {0, "(no version)"},
      {SSL_LIBRARY_VERSION_TLS_1_0, "1.0"},
      {SSL_LIBRARY_VERSION_TLS_1_1, "1.1"},
      {SSL_LIBRARY_VERSION_TLS_1_2, "1.2"},
      {SSL_LIBRARY_VERSION_TLS_1_3, "1.3"},
  };

  for (const auto& entry : kLabels) {
    if (entry.value == version) {
      return entry.label;
    }
  }

  std::cerr << "Invalid version: " << version << std::endl;
  EXPECT_TRUE(false);
  return "";
}

// Create a client agent and an RSA server agent for |variant| and remember the
// requested |version| (zero means no specific version). The chosen
// configuration is logged to stderr.
TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant,
                                       uint16_t version)
    : variant_(variant),
      client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)),
      server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)),
      client_model_(nullptr),
      server_model_(nullptr),
      version_(version),
      expected_resumption_mode_(RESUME_NONE),
      session_ids_(),
      expect_extended_master_secret_(false),
      expect_early_data_accepted_(false),
      skip_version_checks_(false) {
  // DTLS 1.0 is based on TLS 1.1, so a datagram TLS 1.1 configuration is
  // reported as version 1.0.
  const bool is_dtls_1_0 = variant_ == ssl_variant_datagram &&
                           version_ == SSL_LIBRARY_VERSION_TLS_1_1;
  const std::string label =
      is_dtls_1_0 ? std::string("1.0") : VersionString(version_);
  std::cerr << "Version: " << variant_ << " " << label << std::endl;
}

// Nothing extra to release: the agents are owned by smart pointers, and
// TearDown() clears the library state.
TlsConnectTestBase::~TlsConnectTestBase() = default;

// Check the group of each of the supported groups
void TlsConnectTestBase::CheckGroups(
    const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group) {
  DuplicateGroupChecker group_set;
  uint32_t tmp = 0;
  EXPECT_TRUE(groups.Read(0, 2, &tmp));
  EXPECT_EQ(groups.len() - 2, static_cast<size_t>(tmp));
  for (size_t i = 2; i < groups.len(); i += 2) {
    EXPECT_TRUE(groups.Read(i, 2, &tmp));
    SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
    group_set.AddAndCheckGroup(group);
    check_group(group);
  }
}

// Check the group of each of the shares.
//
// |shares| is laid out as a 2-byte total length (which must cover the rest of
// the buffer) followed by entries of {2-byte group, 2-byte length, that many
// bytes}. Each group is checked for duplicates and passed to |check_group|.
// The walk must end exactly at the end of the buffer.
void TlsConnectTestBase::CheckShares(
    const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group) {
  DuplicateGroupChecker group_set;
  uint32_t tmp = 0;
  EXPECT_TRUE(shares.Read(0, 2, &tmp));
  EXPECT_EQ(shares.len() - 2, static_cast<size_t>(tmp));
  size_t i;
  // |tmp| is reused: in the body it first holds the group, then the entry's
  // length. The loop increment depends on it holding the length.
  for (i = 2; i < shares.len(); i += 4 + tmp) {
    ASSERT_TRUE(shares.Read(i, 2, &tmp));
    SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
    group_set.AddAndCheckGroup(group);
    check_group(group);
    ASSERT_TRUE(shares.Read(i + 2, 2, &tmp));
  }
  EXPECT_EQ(shares.len(), i);
}

void TlsConnectTestBase::ClearStats() {
  // Clear statistics.
  SSL3Statistics* stats = SSL_GetStatistics();
  memset(stats, 0, sizeof(*stats));
}

// Restart the server session ID cache from scratch: shut down the existing
// cache, clear the self-encrypt key, and configure a new cache in
// g_working_dir_path.
void TlsConnectTestBase::ClearServerCache() {
  SSL_ShutdownServerSessionIDCache();
  SSLInt_ClearSelfEncryptKey();
  // NOTE(review): 1024 is presumably the entry limit and the zero timeouts
  // select library defaults — confirm against SSL_ConfigServerSessionIDCache.
  SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
}

// Per-test setup: configure the server session ID cache, clear the
// self-encrypt key, set the ticket lifetime (30) and maximum early data size
// (1024) used by the tests, zero the statistics counters, then wire up the
// agents via Init().
void TlsConnectTestBase::SetUp() {
  SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
  SSLInt_ClearSelfEncryptKey();
  SSLInt_SetTicketLifetime(30);
  SSLInt_SetMaxEarlyDataSize(1024);
  ClearStats();
  Init();
}

// Per-test teardown. The agents are released first, before the session cache,
// the self-encrypt key, and the server session ID cache are cleared.
void TlsConnectTestBase::TearDown() {
  client_ = nullptr;
  server_ = nullptr;

  SSL_ClearSessionCache();
  SSLInt_ClearSelfEncryptKey();
  SSL_ShutdownServerSessionIDCache();
}

// Point each agent at the other. When a version was requested (non-zero),
// restrict both agents to exactly that version via ConfigureVersion().
void TlsConnectTestBase::Init() {
  client_->SetPeer(server_);
  server_->SetPeer(client_);

  if (version_ == 0) {
    return;  // No version requested; keep the agents' current range.
  }
  ConfigureVersion(version_);
}

void TlsConnectTestBase::Reset() {
  // Take a copy of the names because they are about to disappear.
  std::string server_name = server_->name();
  std::string client_name = client_->name();
  Reset(server_name, client_name);
}

// Replace both agents with fresh ones named |server_name| and |client_name|,
// carrying over the skip-version-checks setting, and re-run Init() so the new
// agents are peered and version-configured.
void TlsConnectTestBase::Reset(const std::string& server_name,
                               const std::string& client_name) {
  client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_));
  server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_));
  if (skip_version_checks_) {
    client_->SkipVersionChecks();
    server_->SkipVersionChecks();
  }

  Init();
}

// Record the resumption mode that CheckConnected() will verify. For any mode
// other than RESUME_NONE, also tell both agents to expect a resumed session.
void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) {
  expected_resumption_mode_ = expected;
  if (expected == RESUME_NONE) {
    return;
  }
  client_->ExpectResumption();
  server_->ExpectResumption();
}

// Ensure both agents have their TLS state set up (server first). Each agent
// uses the matching model agent's ssl_fd() when a model exists, and nullptr
// otherwise.
void TlsConnectTestBase::EnsureTlsSetup() {
  EXPECT_TRUE(server_->EnsureTlsSetup(server_model_ ? server_model_->ssl_fd()
                                                    : nullptr));
  EXPECT_TRUE(client_->EnsureTlsSetup(client_model_ ? client_model_->ssl_fd()
                                                    : nullptr));
}

// Drive a handshake: make sure both agents are set up, tell the client the
// server's key size, kick off both sides, then wait (timeout 5000, presumably
// milliseconds) until neither agent is still in STATE_CONNECTING.
void TlsConnectTestBase::Handshake() {
  EnsureTlsSetup();
  client_->SetServerKeyBits(server_->server_key_bits());
  client_->Handshake();
  server_->Handshake();

  ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) &&
                       (server_->state() != TlsAgent::STATE_CONNECTING),
                   5000);
}

// Enable the extended master secret on both agents and make CheckConnected()
// expect it to be used.
void TlsConnectTestBase::EnableExtendedMasterSecret() {
  client_->EnableExtendedMasterSecret();
  server_->EnableExtendedMasterSecret();
  ExpectExtendedMasterSecret(true);
}

// Start both agents (server first), passing the model agents' fds when models
// exist. Then run the handshake and verify the result with CheckConnected().
void TlsConnectTestBase::Connect() {
  server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr);
  client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr);
  Handshake();
  CheckConnected();
}

// Connect with only |cipher_suite| enabled on the client, exchange data, and
// verify that both sides negotiated that suite.
void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
  EnsureTlsSetup();
  client_->EnableSingleCipher(cipher_suite);

  Connect();
  SendReceive();

  // Check that we used the right cipher suite. Each side gets its own
  // zero-initialized variable. A failed lookup therefore never compares an
  // indeterminate value, and the server check cannot pass by reusing the
  // client's result.
  uint16_t client_suite = 0;
  EXPECT_TRUE(client_->cipher_suite(&client_suite));
  EXPECT_EQ(cipher_suite, client_suite);
  uint16_t server_suite = 0;
  EXPECT_TRUE(server_->cipher_suite(&server_suite));
  EXPECT_EQ(cipher_suite, server_suite);
}

// Verify a completed handshake. Both sides must agree on the version (the
// lower of the two maximum versions, unless version checks are skipped) and
// on the cipher suite, and both must be in STATE_CONNECTED. Before TLS 1.3,
// the matching 32-byte session IDs are recorded. Finally, the extended master
// secret, early data, resumption, and secret-destruction checks are run.
void TlsConnectTestBase::CheckConnected() {
  EXPECT_EQ(client_->version(), server_->version());
  if (!skip_version_checks_) {
    // Check the version is as expected
    EXPECT_EQ(std::min(client_->max_version(), server_->max_version()),
              client_->version());
  }

  EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
  EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());

  // Zero-initialize so that a failed cipher_suite() lookup never leads to
  // comparing indeterminate values.
  uint16_t cipher_suite1 = 0;
  uint16_t cipher_suite2 = 0;
  bool ret = client_->cipher_suite(&cipher_suite1);
  EXPECT_TRUE(ret);
  ret = server_->cipher_suite(&cipher_suite2);
  EXPECT_TRUE(ret);
  EXPECT_EQ(cipher_suite1, cipher_suite2);

  std::cerr << "Connected with version " << client_->version()
            << " cipher suite " << client_->cipher_suite_name() << std::endl;

  if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
    // Check and store session ids.
    std::vector<uint8_t> sid_c1 = client_->session_id();
    EXPECT_EQ(32U, sid_c1.size());
    std::vector<uint8_t> sid_s1 = server_->session_id();
    EXPECT_EQ(32U, sid_s1.size());
    EXPECT_EQ(sid_c1, sid_s1);
    session_ids_.push_back(sid_c1);
  }

  CheckExtendedMasterSecret();
  CheckEarlyDataAccepted();
  CheckResumption(expected_resumption_mode_);
  client_->CheckSecretsDestroyed();
  server_->CheckSecretsDestroyed();
}

// Check that both agents report key exchange |kea_type| over |kea_group| and
// authentication |auth_type| with signature scheme |sig_scheme|.
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
                                   SSLAuthType auth_type,
                                   SSLSignatureScheme sig_scheme) const {
  client_->CheckKEA(kea_type, kea_group);
  server_->CheckKEA(kea_type, kea_group);
  client_->CheckAuthType(auth_type, sig_scheme);
  server_->CheckAuthType(auth_type, sig_scheme);
}

// Check keys using the group and signature scheme the tests expect by default.
// Groups: curve25519 for ECDHE, ffdhe2048 for DHE, none for static RSA.
// RSA signing: RSA-PSS/SHA-256 when version_ is TLS 1.2 or later, otherwise
// PKCS#1 v1.5/SHA-256. An unknown |kea_type| or |auth_type| records a test
// failure.
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
                                   SSLAuthType auth_type) const {
  const SSLNamedGroup group = [kea_type]() -> SSLNamedGroup {
    switch (kea_type) {
      case ssl_kea_ecdh:
        return ssl_grp_ec_curve25519;
      case ssl_kea_dh:
        return ssl_grp_ffdhe_2048;
      case ssl_kea_rsa:
        return ssl_grp_none;
      default:
        EXPECT_TRUE(false) << "unexpected KEA";
        return ssl_grp_none;
    }
  }();

  const SSLSignatureScheme scheme = [this,
                                     auth_type]() -> SSLSignatureScheme {
    switch (auth_type) {
      case ssl_auth_rsa_decrypt:
        return ssl_sig_none;
      case ssl_auth_rsa_sign:
        return (version_ >= SSL_LIBRARY_VERSION_TLS_1_2)
                   ? ssl_sig_rsa_pss_sha256
                   : ssl_sig_rsa_pkcs1_sha256;
      case ssl_auth_rsa_pss:
        return ssl_sig_rsa_pss_sha256;
      case ssl_auth_ecdsa:
        return ssl_sig_ecdsa_secp256r1_sha256;
      case ssl_auth_dsa:
        return ssl_sig_dsa_sha1;
      default:
        EXPECT_TRUE(false) << "unexpected auth type";
        return static_cast<SSLSignatureScheme>(0x0100);
    }
  }();

  CheckKeys(kea_type, group, auth_type, scheme);
}

// Default key check: ECDHE key exchange with RSA signing, using the default
// group and scheme chosen by the two-argument overload.
void TlsConnectTestBase::CheckKeys() const {
  CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
}

void TlsConnectTestBase::ConnectExpectFail() {
  server_->StartConnect();
  client_->StartConnect();
  Handshake();
  ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state());
  ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
}

// Set up both agents, then arrange for |sender| to send |alert| and for the
// other agent to receive it. Any |sender| other than client_ is treated as the
// server side, which makes the client the receiver.
void TlsConnectTestBase::ExpectAlert(std::shared_ptr<TlsAgent>& sender,
                                     uint8_t alert) {
  EnsureTlsSetup();
  std::shared_ptr<TlsAgent> receiver = client_;
  if (sender == client_) {
    receiver = server_;
  }
  sender->ExpectSendAlert(alert);
  receiver->ExpectReceiveAlert(alert);
}

// Expect |sender| to send |alert| to its peer, then run a handshake that must
// fail on both sides.
void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender,
                                            uint8_t alert) {
  ExpectAlert(sender, alert);
  ConnectExpectFail();
}

// Run a handshake in which only |failing_side| is required to reach
// STATE_ERROR (waiting up to 5000, presumably milliseconds). The other side's
// final state is not checked.
//
// The model agents' fds (if any) are passed to StartConnect() just as
// Connect() does, so a configured model is not silently ignored.
void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) {
  server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr);
  client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr);
  client_->SetServerKeyBits(server_->server_key_bits());
  client_->Handshake();
  server_->Handshake();

  auto failing_agent = server_;
  if (failing_side == TlsAgent::CLIENT) {
    failing_agent = client_;
  }
  ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000);
}

// Record |version| as the fixture's version and restrict both agents to a
// range containing only that version.
void TlsConnectTestBase::ConfigureVersion(uint16_t version) {
  version_ = version;
  client_->SetVersionRange(version, version);
  server_->SetVersionRange(version, version);
}

// Forward |version| to both agents' SetExpectedVersion().
void TlsConnectTestBase::SetExpectedVersion(uint16_t version) {
  client_->SetExpectedVersion(version);
  server_->SetExpectedVersion(version);
}

// Make sure both agents are set up, then disable all cipher suites on both.
void TlsConnectTestBase::DisableAllCiphers() {
  EnsureTlsSetup();
  client_->DisableAllCiphers();
  server_->DisableAllCiphers();
}

// Disable everything, then re-enable only the cipher suites whose key exchange
// is ssl_kea_rsa, on both agents.
void TlsConnectTestBase::EnableOnlyStaticRsaCiphers() {
  DisableAllCiphers();

  client_->EnableCiphersByKeyExchange(ssl_kea_rsa);
  server_->EnableCiphersByKeyExchange(ssl_kea_rsa);
}

// Restrict both agents to finite-field DHE key exchange. At TLS 1.3 and later
// this is done by limiting the named groups to kFFDHEGroups. Below 1.3 all
// cipher suites are disabled and only the DH key-exchange suites are
// re-enabled.
void TlsConnectTestBase::EnableOnlyDheCiphers() {
  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    // TLS 1.3 cipher suites do not name a key exchange; groups select it.
    client_->ConfigNamedGroups(kFFDHEGroups);
    server_->ConfigNamedGroups(kFFDHEGroups);
    return;
  }
  DisableAllCiphers();
  client_->EnableCiphersByKeyExchange(ssl_kea_dh);
  server_->EnableCiphersByKeyExchange(ssl_kea_dh);
}

// Turn on ECDH key exchange for both agents. At TLS 1.3 and later the named
// groups are limited to kECDHEGroups. Below 1.3 the ECDH-RSA and ECDH-ECDSA
// suites are additionally enabled; other enabled suites are left untouched.
void TlsConnectTestBase::EnableSomeEcdhCiphers() {
  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    client_->ConfigNamedGroups(kECDHEGroups);
    server_->ConfigNamedGroups(kECDHEGroups);
    return;
  }
  client_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
  client_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
  server_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
  server_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
}

// Apply the session-cache mode |client| to the client agent and |server| to
// the server agent. When the server mode includes RESUME_TICKET, the
// kServerRsaDecrypt certificate's key pair is also installed with
// SSL_SetSessionTicketKeyPair (presumably to protect session tickets).
void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
                                               SessionResumptionMode server) {
  client_->ConfigureSessionCache(client);
  server_->ConfigureSessionCache(server);
  if ((server & RESUME_TICKET) != 0) {
    ScopedCERTCertificate cert;
    ScopedSECKEYPrivateKey privKey;
    ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert,
                                          &privKey));

    ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
    ASSERT_TRUE(pubKey);

    EXPECT_EQ(SECSuccess,
              SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
  }
}

// Verify the library statistics against the |expected| resumption mode.
// RESUME_BOTH is not a valid expectation here. Any non-RESUME_NONE mode
// expects exactly one cache hit on each side. Modes with RESUME_TICKET expect
// exactly one stateless resume on each side. For resumed pre-TLS 1.3
// connections, exactly two session IDs must have been recorded, and they must
// match. At TLS 1.3, the mode must include RESUME_TICKET.
void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
  EXPECT_NE(RESUME_BOTH, expected);

  int resume_count = expected ? 1 : 0;
  int stateless_count = (expected & RESUME_TICKET) ? 1 : 0;

  // Note: hch == server counter; hsh == client counter.
  SSL3Statistics* stats = SSL_GetStatistics();
  EXPECT_EQ(resume_count, stats->hch_sid_cache_hits);
  EXPECT_EQ(resume_count, stats->hsh_sid_cache_hits);

  EXPECT_EQ(stateless_count, stats->hch_sid_stateless_resumes);
  EXPECT_EQ(stateless_count, stats->hsh_sid_stateless_resumes);

  if (expected != RESUME_NONE) {
    if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
      // Check that the last two session ids match.
      ASSERT_EQ(2U, session_ids_.size());
      EXPECT_EQ(session_ids_[session_ids_.size() - 1],
                session_ids_[session_ids_.size() - 2]);
    } else {
      // TLS 1.3 only uses tickets.
      EXPECT_TRUE(expected & RESUME_TICKET);
    }
  }
}

// Enable ALPN on both agents with the fixture's default alpn_dummy_val_.
void TlsConnectTestBase::EnableAlpn() {
  EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
}

// Enable ALPN on both agents with the |len|-byte value |val|.
void TlsConnectTestBase::EnableAlpn(const uint8_t* val, size_t len) {
  client_->EnableAlpn(val, len);
  server_->EnableAlpn(val, len);
}

// Create the model client and server agents if they do not exist yet. Models
// are always created together, so a missing client model implies a missing
// server model. New models inherit the skip-version-checks setting.
void TlsConnectTestBase::EnsureModelSockets() {
  if (client_model_) {
    return;  // Models already exist.
  }
  ASSERT_EQ(server_model_, nullptr);
  client_model_.reset(
      new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_));
  server_model_.reset(
      new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_));
  if (!skip_version_checks_) {
    return;
  }
  client_model_->SkipVersionChecks();
  server_model_->SkipVersionChecks();
}

// Check that the client reports |val| with state SSL_NEXT_PROTO_SELECTED and
// the server reports it with SSL_NEXT_PROTO_NEGOTIATED.
void TlsConnectTestBase::CheckAlpn(const std::string& val) {
  client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, val);
  server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, val);
}

// Enable SRTP on both agents.
void TlsConnectTestBase::EnableSrtp() {
  client_->EnableSrtp();
  server_->EnableSrtp();
}

// Run each agent's SRTP check.
void TlsConnectTestBase::CheckSrtp() const {
  client_->CheckSrtp();
  server_->CheckSrtp();
}

// Each side sends 50 bytes, then wait for both sides to have received 50.
void TlsConnectTestBase::SendReceive() {
  client_->SendData(50);
  server_->SendData(50);
  Receive(50);
}

// Do a first connection so we can do 0-RTT on the second one.
void TlsConnectTestBase::SetupForZeroRtt() {
  ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
  client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
                           SSL_LIBRARY_VERSION_TLS_1_3);
  server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
                           SSL_LIBRARY_VERSION_TLS_1_3);
  server_->Set0RttEnabled(true);  // So we signal that we allow 0-RTT.
  Connect();
  SendReceive();  // Need to read so that we absorb the session ticket.
  CheckKeys();

  Reset();
  client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
                           SSL_LIBRARY_VERSION_TLS_1_3);
  server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
                           SSL_LIBRARY_VERSION_TLS_1_3);
  server_->StartConnect();
  client_->StartConnect();
}

// Do a first connection so we can do resumption.
//
// The client cache mode is RESUME_BOTH and the server mode is RESUME_TICKET.
// After connecting, data is exchanged so the client absorbs the ticket, the
// default keys are checked, and the agents are recreated via Reset().
void TlsConnectTestBase::SetupForResume() {
  EnsureTlsSetup();
  ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
  Connect();
  SendReceive();  // Need to read so that we absorb the session ticket.
  CheckKeys();

  Reset();
}

// Attempt a 0-RTT exchange of the 6-byte string "ABCDEF".
//
// |expect_writable|: the client's 0-RTT PR_Write must write all bytes;
//   otherwise it must return SECFailure.
// |expect_readable|: the server's first PR_Read must return all bytes;
//   otherwise it must fail with PR_WOULD_BLOCK_ERROR.
// When both are true, the client is also expected to send the
//   end_of_early_data alert.
// |post_clienthello_check|: if set, it runs after the ClientHello is sent;
//   returning false ends the exchange early.
// A second server read must always fail with PR_WOULD_BLOCK_ERROR.
void TlsConnectTestBase::ZeroRttSendReceive(
    bool expect_writable, bool expect_readable,
    std::function<bool()> post_clienthello_check) {
  const char* k0RttData = "ABCDEF";
  const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));

  if (expect_writable && expect_readable) {
    ExpectAlert(client_, kTlsAlertEndOfEarlyData);
  }

  client_->Handshake();  // Send ClientHello.
  if (post_clienthello_check) {
    if (!post_clienthello_check()) return;
  }
  PRInt32 rv =
      PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen);  // 0-RTT write.
  if (expect_writable) {
    EXPECT_EQ(k0RttDataLen, rv);
  } else {
    EXPECT_EQ(SECFailure, rv);
  }
  server_->Handshake();  // Consume ClientHello, EE, Finished.

  std::vector<uint8_t> buf(k0RttDataLen);
  rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen);  // 0-RTT read
  if (expect_readable) {
    std::cerr << "0-RTT read " << rv << " bytes\n";
    EXPECT_EQ(k0RttDataLen, rv);
  } else {
    EXPECT_EQ(SECFailure, rv);
    EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
  }

  // Do a second read. this should fail.
  rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen);
  EXPECT_EQ(SECFailure, rv);
  EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
}

// Wait (timeout 2000, presumably milliseconds) until both agents have received
// exactly |amount| bytes, then assert that they have.
void TlsConnectTestBase::Receive(size_t amount) {
  WAIT_(client_->received_bytes() == amount &&
            server_->received_bytes() == amount,
        2000);
  ASSERT_EQ(amount, client_->received_bytes());
  ASSERT_EQ(amount, server_->received_bytes());
}

// Set whether CheckConnected() should expect the extended master secret.
void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) {
  expect_extended_master_secret_ = expected;
}

// Check both agents against the recorded extended-master-secret expectation.
void TlsConnectTestBase::CheckExtendedMasterSecret() {
  client_->CheckExtendedMasterSecret(expect_extended_master_secret_);
  server_->CheckExtendedMasterSecret(expect_extended_master_secret_);
}

// Set whether CheckConnected() should expect early data to be accepted.
void TlsConnectTestBase::ExpectEarlyDataAccepted(bool expected) {
  expect_early_data_accepted_ = expected;
}

// Check both agents against the recorded early-data expectation.
void TlsConnectTestBase::CheckEarlyDataAccepted() {
  client_->CheckEarlyDataAccepted(expect_early_data_accepted_);
  server_->CheckEarlyDataAccepted(expect_early_data_accepted_);
}

// Forward DisableECDHEServerKeyReuse() to the server agent.
void TlsConnectTestBase::DisableECDHEServerKeyReuse() {
  server_->DisableECDHEServerKeyReuse();
}

// Skip version checks on both current agents. The flag is remembered so that
// agents recreated by Reset() or EnsureModelSockets() also skip them, and
// CheckConnected() skips its own version check.
void TlsConnectTestBase::SkipVersionChecks() {
  skip_version_checks_ = true;
  client_->SkipVersionChecks();
  server_->SkipVersionChecks();
}

// Fixtures parameterized by a (variant, version) tuple.
TlsConnectGeneric::TlsConnectGeneric()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}

TlsConnectPre12::TlsConnectPre12()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}

// Parameterized by variant only; the version is fixed at TLS 1.2.
TlsConnectTls12::TlsConnectTls12()
    : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_2) {}

TlsConnectTls12Plus::TlsConnectTls12Plus()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}

// Parameterized by variant only; the version is fixed at TLS 1.3.
TlsConnectTls13::TlsConnectTls13()
    : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}

// Set up the agents and install the captures used by CheckKEXDetails().
// Client side: supported_groups, key_share, and a second key_share capture
// constructed with |true| (presumably capturing a later occurrence, e.g. the
// ClientHello sent after a HelloRetryRequest — confirm).
// Server side: HelloRetryRequest handshake messages.
void TlsKeyExchangeTest::EnsureKeyShareSetup() {
  EnsureTlsSetup();
  groups_capture_ =
      std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
  shares_capture_ =
      std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
  shares_capture2_ =
      std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn, true);
  std::vector<std::shared_ptr<PacketFilter>> captures = {
      groups_capture_, shares_capture_, shares_capture2_};
  client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
  capture_hrr_ = std::make_shared<TlsInspectorRecordHandshakeMessage>(
      kTlsHandshakeHelloRetryRequest);
  server_->SetPacketFilter(capture_hrr_);
}

// Configure the same named-group list |groups| on both agents.
void TlsKeyExchangeTest::ConfigNamedGroups(
    const std::vector<SSLNamedGroup>& groups) {
  client_->ConfigNamedGroups(groups);
  server_->ConfigNamedGroups(groups);
}

// Decode a captured group-list extension body into its group codepoints.
// The body must be a 2-byte length covering the rest of the buffer, followed
// by 2-byte entries; mismatches are recorded as test failures.
std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
    const DataBuffer& ext) {
  uint32_t tmp = 0;
  EXPECT_TRUE(ext.Read(0, 2, &tmp));
  EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
  EXPECT_TRUE(ext.len() % 2 == 0);

  // Treat the buffer as 2-byte slots. Slot 0 is the length header, so
  // decoding starts at slot 1 (byte offset 2).
  std::vector<SSLNamedGroup> groups;
  const size_t slot_count = ext.len() / 2;
  size_t i = 1;
  while (i < slot_count) {
    EXPECT_TRUE(ext.Read(2 * i, 2, &tmp));
    groups.push_back(static_cast<SSLNamedGroup>(tmp));
    ++i;
  }
  return groups;
}

std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
    const DataBuffer& ext) {
  uint32_t tmp = 0;
  EXPECT_TRUE(ext.Read(0, 2, &tmp));
  EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
  std::vector<SSLNamedGroup> shares;
  size_t i = 2;
  while (i < ext.len()) {
    EXPECT_TRUE(ext.Read(i, 2, &tmp));
    shares.push_back(static_cast<SSLNamedGroup>(tmp));
    EXPECT_TRUE(ext.Read(i + 2, 2, &tmp));
    i += 4 + tmp;
  }
  EXPECT_EQ(ext.len(), i);
  return shares;
}

// Compare the client's captured supported_groups extension with
// |expected_groups|. At TLS 1.3 and later, at least one share must be expected
// and the captured key_share must carry exactly |expected_shares|. Below 1.3,
// no key_share may have been captured. |expect_hrr| states whether the server
// must have sent a HelloRetryRequest.
void TlsKeyExchangeTest::CheckKEXDetails(
    const std::vector<SSLNamedGroup>& expected_groups,
    const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) {
  std::vector<SSLNamedGroup> groups =
      GetGroupDetails(groups_capture_->extension());
  EXPECT_EQ(expected_groups, groups);

  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    ASSERT_LT(0U, expected_shares.size());
    std::vector<SSLNamedGroup> shares =
        GetShareDetails(shares_capture_->extension());
    EXPECT_EQ(expected_shares, shares);
  } else {
    EXPECT_EQ(0U, shares_capture_->extension().len());
  }

  EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0);
}

// Same as the three-argument overload, with no HelloRetryRequest expected.
void TlsKeyExchangeTest::CheckKEXDetails(
    const std::vector<SSLNamedGroup>& expected_groups,
    const std::vector<SSLNamedGroup>& expected_shares) {
  CheckKEXDetails(expected_groups, expected_shares, false);
}

// Check a handshake that involved a HelloRetryRequest. The first key_share
// must match |expected_shares|. The second captured key_share
// (shares_capture2_) must carry exactly |expected_share2|, which must differ
// from every group in |expected_shares|.
void TlsKeyExchangeTest::CheckKEXDetails(
    const std::vector<SSLNamedGroup>& expected_groups,
    const std::vector<SSLNamedGroup>& expected_shares,
    SSLNamedGroup expected_share2) {
  CheckKEXDetails(expected_groups, expected_shares, true);

  for (auto it : expected_shares) {
    EXPECT_NE(expected_share2, it);
  }
  std::vector<SSLNamedGroup> expected_shares2 = {expected_share2};
  std::vector<SSLNamedGroup> shares =
      GetShareDetails(shares_capture2_->extension());
  EXPECT_EQ(expected_shares2, shares);
}
}  // namespace nss_test