/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "tls_connect.h"
#include "sslexp.h"
extern "C" {
#include "libssl_internals.h"
}

#include <iostream>

#include "databuffer.h"
#include "gtest_utils.h"
#include "nss_scoped_ptrs.h"
#include "sslproto.h"

extern std::string g_working_dir_path;

namespace nss_test {

// Parameter generators over SSLProtocolVariant used to instantiate
// parameterized tests: stream only, datagram only, or both (stream first).
static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsStream =
        ::testing::ValuesIn(kTlsVariantsStreamArr);
static const SSLProtocolVariant kTlsVariantsDatagramArr[] = {
    ssl_variant_datagram};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsDatagram =
        ::testing::ValuesIn(kTlsVariantsDatagramArr);
static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream,
                                                        ssl_variant_datagram};
::testing::internal::ParamGenerator<SSLProtocolVariant>
    TlsConnectTestBase::kTlsVariantsAll =
        ::testing::ValuesIn(kTlsVariantsAllArr);

// Parameter generators over protocol versions. The fixed ranges (V10, V11,
// V12, V10V11, V10ToV12, V11V12) list versions oldest first. The open-ended
// lists (V11Plus, V12Plus, VAll) list versions newest first and omit TLS 1.3
// when NSS_DISABLE_TLS_1_3 is defined. kTlsV13 always contains TLS 1.3.
static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 =
    ::testing::ValuesIn(kTlsV10Arr);
static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11 =
    ::testing::ValuesIn(kTlsV11Arr);
static const uint16_t kTlsV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12 =
    ::testing::ValuesIn(kTlsV12Arr);
static const uint16_t kTlsV10V11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
                                         SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10V11 =
    ::testing::ValuesIn(kTlsV10V11Arr);
static const uint16_t kTlsV10ToV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
                                           SSL_LIBRARY_VERSION_TLS_1_1,
                                           SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10ToV12 =
    ::testing::ValuesIn(kTlsV10ToV12Arr);
static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1,
                                         SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11V12 =
    ::testing::ValuesIn(kTlsV11V12Arr);

static const uint16_t kTlsV11PlusArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11Plus =
    ::testing::ValuesIn(kTlsV11PlusArr);
static const uint16_t kTlsV12PlusArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12Plus =
    ::testing::ValuesIn(kTlsV12PlusArr);
static const uint16_t kTlsV13Arr[] = {SSL_LIBRARY_VERSION_TLS_1_3};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV13 =
    ::testing::ValuesIn(kTlsV13Arr);
static const uint16_t kTlsVAllArr[] = {
#ifndef NSS_DISABLE_TLS_1_3
    SSL_LIBRARY_VERSION_TLS_1_3,
#endif
    SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1,
    SSL_LIBRARY_VERSION_TLS_1_0};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsVAll =
    ::testing::ValuesIn(kTlsVAllArr);

// Returns a short label for a protocol version code, used in test log output.
// TLS versions are labelled by their TLS number ("1.0" ... "1.3"). SSL 3.0 is
// labelled "3.0" so it cannot be confused with TLS 1.0. A version of 0 means
// no version was configured. Any other value is flagged as a test failure and
// yields an empty string.
std::string VersionString(uint16_t version) {
  switch (version) {
    case 0:
      return "(no version)";
    case SSL_LIBRARY_VERSION_3_0:
      // Must not share TLS 1.0's label; the two are distinct protocols.
      return "3.0";
    case SSL_LIBRARY_VERSION_TLS_1_0:
      return "1.0";
    case SSL_LIBRARY_VERSION_TLS_1_1:
      return "1.1";
    case SSL_LIBRARY_VERSION_TLS_1_2:
      return "1.2";
    case SSL_LIBRARY_VERSION_TLS_1_3:
      return "1.3";
    default:
      std::cerr << "Invalid version: " << version << std::endl;
      EXPECT_TRUE(false);
      return "";
  }
}

TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant,
                                       uint16_t version)
    : variant_(variant),
      client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)),
      server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)),
      client_model_(nullptr),
      server_model_(nullptr),
      version_(version),
      expected_resumption_mode_(RESUME_NONE),
      expected_resumptions_(0),
      session_ids_(),
      expect_extended_master_secret_(false),
      expect_early_data_accepted_(false),
      skip_version_checks_(false) {
  // DTLS 1.0 is the datagram counterpart of TLS 1.1, so a datagram test
  // configured with TLS 1.1 is logged as version "1.0".
  const bool is_dtls_1_0 = variant_ == ssl_variant_datagram &&
                           version_ == SSL_LIBRARY_VERSION_TLS_1_1;
  const std::string label =
      is_dtls_1_0 ? std::string("1.0") : VersionString(version_);
  std::cerr << "Version: " << variant_ << " " << label << std::endl;
}

// Nothing to release explicitly; members clean up after themselves.
TlsConnectTestBase::~TlsConnectTestBase() = default;

// Check the group of each of the supported groups
void TlsConnectTestBase::CheckGroups(
    const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group) {
  DuplicateGroupChecker group_set;
  uint32_t tmp = 0;
  EXPECT_TRUE(groups.Read(0, 2, &tmp));
  EXPECT_EQ(groups.len() - 2, static_cast<size_t>(tmp));
  for (size_t i = 2; i < groups.len(); i += 2) {
    EXPECT_TRUE(groups.Read(i, 2, &tmp));
    SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
    group_set.AddAndCheckGroup(group);
    check_group(group);
  }
}

// Check the group of each of the shares
void TlsConnectTestBase::CheckShares(
    const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group) {
  DuplicateGroupChecker group_set;
  uint32_t tmp = 0;
  EXPECT_TRUE(shares.Read(0, 2, &tmp));
  EXPECT_EQ(shares.len() - 2, static_cast<size_t>(tmp));
  size_t i;
  for (i = 2; i < shares.len(); i += 4 + tmp) {
    ASSERT_TRUE(shares.Read(i, 2, &tmp));
    SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
    group_set.AddAndCheckGroup(group);
    check_group(group);
    ASSERT_TRUE(shares.Read(i + 2, 2, &tmp));
  }
  EXPECT_EQ(shares.len(), i);
}

// Verifies both endpoints' current epochs. Each side's write epoch must equal
// its own expected epoch, and its read epoch must equal the peer's.
void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch,
                                     uint16_t server_epoch) const {
  uint16_t rd = 0;
  uint16_t wr = 0;

  // Client: reads at the server's epoch, writes at its own.
  EXPECT_EQ(SECSuccess, SSLInt_GetEpochs(client_->ssl_fd(), &rd, &wr));
  EXPECT_EQ(server_epoch, rd) << "client read epoch";
  EXPECT_EQ(client_epoch, wr) << "client write epoch";

  // Server: the mirror image.
  EXPECT_EQ(SECSuccess, SSLInt_GetEpochs(server_->ssl_fd(), &rd, &wr));
  EXPECT_EQ(client_epoch, rd) << "server read epoch";
  EXPECT_EQ(server_epoch, wr) << "server write epoch";
}

void TlsConnectTestBase::ClearStats() {
  // Clear statistics.
  SSL3Statistics* stats = SSL_GetStatistics();
  memset(stats, 0, sizeof(*stats));
}

// Restarts the server session ID cache. It shuts the current cache down,
// clears the self-encrypt key, then configures a new cache in
// g_working_dir_path with the same arguments SetUp() uses (1024 entries; the
// zero timeouts presumably select library defaults).
void TlsConnectTestBase::ClearServerCache() {
  SSL_ShutdownServerSessionIDCache();
  SSLInt_ClearSelfEncryptKey();
  SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
}

// Per-test setup. Steps, in order:
// - configure the server session ID cache in g_working_dir_path;
// - clear the self-encrypt key;
// - set the ticket lifetime to 30 (units per SSLInt_SetTicketLifetime);
// - set up anti-replay with a one-second window;
// - zero the statistics counters;
// - link and configure the agents via Init().
void TlsConnectTestBase::SetUp() {
  SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
  SSLInt_ClearSelfEncryptKey();
  SSLInt_SetTicketLifetime(30);
  SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3);
  ClearStats();
  Init();
}

// Per-test teardown. Both agents are released before the client session
// cache and self-encrypt key are cleared and the server session ID cache is
// shut down. Keep this order; presumably no live socket should still
// reference the caches when they go away.
void TlsConnectTestBase::TearDown() {
  client_ = nullptr;
  server_ = nullptr;

  SSL_ClearSessionCache();
  SSLInt_ClearSelfEncryptKey();
  SSL_ShutdownServerSessionIDCache();
}

// Links the client and server agents to each other. If a version was
// requested (version_ != 0), pins both agents to it.
void TlsConnectTestBase::Init() {
  client_->SetPeer(server_);
  server_->SetPeer(client_);

  if (version_ == 0) {
    return;
  }
  ConfigureVersion(version_);
}

// Recreates both agents under their current names.
void TlsConnectTestBase::Reset() {
  // The names are copied because the agents that own them are replaced.
  const std::string server_name(server_->name());
  const std::string client_name(client_->name());
  Reset(server_name, client_name);
}

void TlsConnectTestBase::Reset(const std::string& server_name,
                               const std::string& client_name) {
  auto token = client_->GetResumptionToken();
  client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_));
  client_->SetResumptionToken(token);
  server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_));
  if (skip_version_checks_) {
    client_->SkipVersionChecks();
    server_->SkipVersionChecks();
  }

  Init();
}

// Replaces the server with a fresh agent that has the same name and variant.
// Pins the new server's version if one is configured, re-links the two
// agents, and starts the new server's connection.
void TlsConnectTestBase::MakeNewServer() {
  // The arguments are read from the old server before the assignment drops
  // it.
  server_ = std::make_shared<TlsAgent>(server_->name(), TlsAgent::SERVER,
                                       server_->variant());
  if (version_ != 0) {
    server_->SetVersionRange(version_, version_);
  }
  client_->SetPeer(server_);
  server_->SetPeer(client_);
  server_->StartConnect();
}

// Records the resumption mode that CheckConnected() should observe. For any
// mode other than RESUME_NONE, it also tells both agents to expect
// resumption and stores |num_resumptions|. The stored count must be zero
// exactly when the mode is RESUME_NONE.
void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected,
                                          uint8_t num_resumptions) {
  expected_resumption_mode_ = expected;
  const bool resuming = (expected != RESUME_NONE);
  if (resuming) {
    client_->ExpectResumption();
    server_->ExpectResumption();
    expected_resumptions_ = num_resumptions;
  }
  EXPECT_EQ(expected_resumptions_ == 0, !resuming);
}

void TlsConnectTestBase::EnsureTlsSetup() {
  EXPECT_TRUE(server_->EnsureTlsSetup(server_model_ ? server_model_->ssl_fd()
                                                    : nullptr));
  EXPECT_TRUE(client_->EnsureTlsSetup(client_model_ ? client_model_->ssl_fd()
                                                    : nullptr));
}

// Drives a handshake. It ensures both agents are set up, tells the client the
// server's key size, and steps the client and then the server once. It then
// waits (ASSERT_TRUE_WAIT, timeout 5000 — presumably ms) until neither side
// is still in STATE_CONNECTING.
void TlsConnectTestBase::Handshake() {
  EnsureTlsSetup();
  client_->SetServerKeyBits(server_->server_key_bits());
  client_->Handshake();
  server_->Handshake();

  ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) &&
                       (server_->state() != TlsAgent::STATE_CONNECTING),
                   5000);
}

// Enables extended master secret on both agents and records that
// CheckConnected() should expect it to be in use.
void TlsConnectTestBase::EnableExtendedMasterSecret() {
  client_->EnableExtendedMasterSecret();
  server_->EnableExtendedMasterSecret();
  ExpectExtendedMasterSecret(true);
}

// Performs a full connection and verifies it. It starts the server and then
// the client (each with its model socket, if any), and lets the client apply
// a stored resumption token. It then runs the handshake and checks the
// result with CheckConnected().
void TlsConnectTestBase::Connect() {
  server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr);
  client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr);
  client_->MaybeSetResumptionToken();
  Handshake();
  CheckConnected();
}

void TlsConnectTestBase::StartConnect() {
  server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr);
  client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr);
}

void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
  EnsureTlsSetup();
  client_->EnableSingleCipher(cipher_suite);

  Connect();
  SendReceive();

  // Check that we used the right cipher suite.
  uint16_t actual;
  EXPECT_TRUE(client_->cipher_suite(&actual));
  EXPECT_EQ(cipher_suite, actual);
  EXPECT_TRUE(server_->cipher_suite(&actual));
  EXPECT_EQ(cipher_suite, actual);
}

// Verifies that a completed handshake left both agents connected with
// matching versions and cipher suites. For pre-1.3 connections it records the
// session ID. It then runs the per-test expectation checks (extended master
// secret, early data, resumption) and checks that secrets were destroyed.
void TlsConnectTestBase::CheckConnected() {
  // Have the client read handshake twice to make sure we get the
  // NST and the ACK.
  if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
      variant_ == ssl_variant_datagram) {
    client_->Handshake();
    client_->Handshake();
    auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd());
    // Verify that we dropped the client's retransmission cipher suites.
    EXPECT_EQ(2, suites) << "Client has the wrong number of suites";
    if (suites != 2) {
      SSLInt_PrintCipherSpecs("client", client_->ssl_fd());
    }
  }
  EXPECT_EQ(client_->version(), server_->version());
  if (!skip_version_checks_) {
    // The negotiated version must be the highest one both sides allow.
    EXPECT_EQ(std::min(client_->max_version(), server_->max_version()),
              client_->version());
  }

  EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
  EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());

  // Both sides must agree on the negotiated cipher suite.
  uint16_t cipher_suite1, cipher_suite2;
  bool ret = client_->cipher_suite(&cipher_suite1);
  EXPECT_TRUE(ret);
  ret = server_->cipher_suite(&cipher_suite2);
  EXPECT_TRUE(ret);
  EXPECT_EQ(cipher_suite1, cipher_suite2);

  std::cerr << "Connected with version " << client_->version()
            << " cipher suite " << client_->cipher_suite_name() << std::endl;

  if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
    // Check and store session ids. They are 32 bytes long and must match on
    // both sides. CheckResumption() compares the last two stored IDs.
    std::vector<uint8_t> sid_c1 = client_->session_id();
    EXPECT_EQ(32U, sid_c1.size());
    std::vector<uint8_t> sid_s1 = server_->session_id();
    EXPECT_EQ(32U, sid_s1.size());
    EXPECT_EQ(sid_c1, sid_s1);
    session_ids_.push_back(sid_c1);
  }

  CheckExtendedMasterSecret();
  CheckEarlyDataAccepted();
  CheckResumption(expected_resumption_mode_);
  client_->CheckSecretsDestroyed();
  server_->CheckSecretsDestroyed();
}

// Checks the key exchange and authentication on both agents. The KEA check is
// skipped when |kea_group| is ssl_grp_none, which the two-argument overload
// passes for static RSA. The auth type and signature scheme are always
// checked, server first.
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
                                   SSLAuthType auth_type,
                                   SSLSignatureScheme sig_scheme) const {
  if (kea_group != ssl_grp_none) {
    client_->CheckKEA(kea_type, kea_group);
    server_->CheckKEA(kea_type, kea_group);
  }
  server_->CheckAuthType(auth_type, sig_scheme);
  client_->CheckAuthType(auth_type, sig_scheme);
}

// Checks keys using the default group for |kea_type| and the default
// signature scheme for |auth_type|. For RSA signing, the default scheme
// depends on version_: RSA-PSS from TLS 1.2 onwards, PKCS#1 v1.5 before that.
// Unknown KEA or auth types are reported as test failures.
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
                                   SSLAuthType auth_type) const {
  auto default_group = [](SSLKEAType kea) -> SSLNamedGroup {
    switch (kea) {
      case ssl_kea_ecdh:
        return ssl_grp_ec_curve25519;
      case ssl_kea_dh:
        return ssl_grp_ffdhe_2048;
      case ssl_kea_rsa:
        return ssl_grp_none;
      default:
        EXPECT_TRUE(false) << "unexpected KEA";
        return ssl_grp_none;
    }
  };

  auto default_scheme = [this](SSLAuthType auth) -> SSLSignatureScheme {
    switch (auth) {
      case ssl_auth_rsa_decrypt:
        return ssl_sig_none;
      case ssl_auth_rsa_sign:
        return (version_ >= SSL_LIBRARY_VERSION_TLS_1_2)
                   ? ssl_sig_rsa_pss_rsae_sha256
                   : ssl_sig_rsa_pkcs1_sha256;
      case ssl_auth_rsa_pss:
        return ssl_sig_rsa_pss_rsae_sha256;
      case ssl_auth_ecdsa:
        return ssl_sig_ecdsa_secp256r1_sha256;
      case ssl_auth_dsa:
        return ssl_sig_dsa_sha1;
      default:
        EXPECT_TRUE(false) << "unexpected auth type";
        return static_cast<SSLSignatureScheme>(0x0100);
    }
  };

  // Evaluate the group before the scheme, so failures are reported in a
  // fixed order.
  const SSLNamedGroup group = default_group(kea_type);
  const SSLSignatureScheme scheme = default_scheme(auth_type);
  CheckKeys(kea_type, group, auth_type, scheme);
}

// Default key check: ECDHE key exchange with RSA signing.
void TlsConnectTestBase::CheckKeys() const {
  CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
}

// Checks the keys of a resumed connection. It also requires that resumption
// was expected, and that both agents report |original_kea_group| as the
// group used by the original (pre-resumption) key exchange.
void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type,
                                             SSLNamedGroup kea_group,
                                             SSLNamedGroup original_kea_group,
                                             SSLAuthType auth_type,
                                             SSLSignatureScheme sig_scheme) {
  CheckKeys(kea_type, kea_group, auth_type, sig_scheme);
  EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE);
  client_->CheckOriginalKEA(original_kea_group);
  server_->CheckOriginalKEA(original_kea_group);
}

// Runs a handshake that must fail on both sides: both agents end up in
// STATE_ERROR.
void TlsConnectTestBase::ConnectExpectFail() {
  StartConnect();
  Handshake();
  ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state());
  ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
}

// Sets up the agents so that |sender| is expected to send |alert| and the
// other agent is expected to receive it.
void TlsConnectTestBase::ExpectAlert(std::shared_ptr<TlsAgent>& sender,
                                     uint8_t alert) {
  EnsureTlsSetup();
  std::shared_ptr<TlsAgent> receiver = client_;
  if (sender == client_) {
    receiver = server_;
  }
  sender->ExpectSendAlert(alert);
  receiver->ExpectReceiveAlert(alert);
}

// Expects |sender| to send |alert| to its peer, then runs a handshake that
// must fail on both sides.
void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender,
                                            uint8_t alert) {
  ExpectAlert(sender, alert);
  ConnectExpectFail();
}

// Starts a handshake and waits for only |failing_side| to reach STATE_ERROR.
// The other side's state is not checked.
void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) {
  StartConnect();
  client_->SetServerKeyBits(server_->server_key_bits());
  client_->Handshake();
  server_->Handshake();

  auto failing_agent = (failing_side == TlsAgent::CLIENT) ? client_ : server_;
  ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000);
}

// Records |version| in version_ and pins both agents to exactly that
// version (min == max).
void TlsConnectTestBase::ConfigureVersion(uint16_t version) {
  version_ = version;
  client_->SetVersionRange(version, version);
  server_->SetVersionRange(version, version);
}

// Tells both agents which version they should end up negotiating.
void TlsConnectTestBase::SetExpectedVersion(uint16_t version) {
  client_->SetExpectedVersion(version);
  server_->SetExpectedVersion(version);
}

// Makes sure both sockets exist, then disables every cipher suite on both
// agents.
void TlsConnectTestBase::DisableAllCiphers() {
  EnsureTlsSetup();
  client_->DisableAllCiphers();
  server_->DisableAllCiphers();
}

// Restricts both agents to cipher suites that use static RSA key exchange.
void TlsConnectTestBase::EnableOnlyStaticRsaCiphers() {
  DisableAllCiphers();

  client_->EnableCiphersByKeyExchange(ssl_kea_rsa);
  server_->EnableCiphersByKeyExchange(ssl_kea_rsa);
}

// Restricts both agents to finite-field DH. In TLS 1.3 the cipher suite does
// not select the key exchange, so the named groups are restricted instead.
void TlsConnectTestBase::EnableOnlyDheCiphers() {
  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    client_->ConfigNamedGroups(kFFDHEGroups);
    server_->ConfigNamedGroups(kFFDHEGroups);
    return;
  }
  DisableAllCiphers();
  client_->EnableCiphersByKeyExchange(ssl_kea_dh);
  server_->EnableCiphersByKeyExchange(ssl_kea_dh);
}

// Before TLS 1.3, additionally enables the ECDH-RSA and ECDH-ECDSA suites on
// both agents, leaving other suites as they are. From TLS 1.3, restricts both
// agents to the ECDHE named groups instead.
void TlsConnectTestBase::EnableSomeEcdhCiphers() {
  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    client_->ConfigNamedGroups(kECDHEGroups);
    server_->ConfigNamedGroups(kECDHEGroups);
    return;
  }
  client_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
  client_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
  server_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
  server_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
}

// Loads the kServerRsaDecrypt certificate and private key and installs them
// as the session ticket key pair via SSL_SetSessionTicketKeyPair.
void TlsConnectTestBase::ConfigureSelfEncrypt() {
  ScopedCERTCertificate cert;
  ScopedSECKEYPrivateKey privKey;
  ASSERT_TRUE(
      TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey));

  ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
  ASSERT_TRUE(pubKey);

  EXPECT_EQ(SECSuccess,
            SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
}

// Sets the session cache mode of each agent. If the server mode includes
// tickets, also installs the ticket key pair with ConfigureSelfEncrypt().
void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
                                               SessionResumptionMode server) {
  client_->ConfigureSessionCache(client);
  server_->ConfigureSessionCache(server);
  if ((server & RESUME_TICKET) != 0) {
    ConfigureSelfEncrypt();
  }
}

// Checks the global SSL statistics against the expected resumption mode and
// count. RESUME_BOTH is not a valid expectation here, because a connection
// resumes in exactly one way.
void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
  EXPECT_NE(RESUME_BOTH, expected);

  // |expected| is used as a boolean; assumes RESUME_NONE == 0 (TODO confirm).
  int resume_count = expected ? expected_resumptions_ : 0;
  int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0;

  // Note: hch == server counter; hsh == client counter.
  SSL3Statistics* stats = SSL_GetStatistics();
  EXPECT_EQ(resume_count, stats->hch_sid_cache_hits);
  EXPECT_EQ(resume_count, stats->hsh_sid_cache_hits);

  EXPECT_EQ(stateless_count, stats->hch_sid_stateless_resumes);
  EXPECT_EQ(stateless_count, stats->hsh_sid_stateless_resumes);

  if (expected != RESUME_NONE) {
    if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 &&
        client_->GetResumptionToken().size() == 0) {
      // Check that the last two session ids match: the resumed connection
      // must reuse the ID recorded by CheckConnected() for the original one.
      ASSERT_EQ(1U + expected_resumptions_, session_ids_.size());
      EXPECT_EQ(session_ids_[session_ids_.size() - 1],
                session_ids_[session_ids_.size() - 2]);
    } else {
      // We've either chosen TLS 1.3 or are using an external resumption token,
      // both of which only use tickets.
      EXPECT_TRUE(expected & RESUME_TICKET);
    }
  }
}

// Server-side ALPN selection callback installed by EnableAlpnWithCallback().
// |arg| is the server TlsAgent. The callback walks the client's
// length-prefixed protocol list in |protos|. If the list contains the agent's
// alpn_value_to_use_, it copies that value into |protoOut| (capacity
// |protoMaxLen|), stores its length in |protoOutLen| and returns SECSuccess.
// It returns SECFailure when the list is empty or malformed, when no entry
// matches, or when the value would not fit in |protoOut|.
static SECStatus NextProtoCallbackServer(void* arg, PRFileDesc* fd,
                                         const unsigned char* protos,
                                         unsigned int protos_len,
                                         unsigned char* protoOut,
                                         unsigned int* protoOutLen,
                                         unsigned int protoMaxLen) {
  EXPECT_EQ(protoMaxLen, 255U);
  TlsAgent* agent = static_cast<TlsAgent*>(arg);
  // Check that agent->alpn_value_to_use_ is in protos.
  if (protos_len < 1) {
    return SECFailure;
  }
  for (size_t i = 0; i < protos_len;) {
    // Each entry is a one-byte length |l| followed by |l| name bytes, so the
    // entry occupies bytes [i, i + l] and i + l must be in range.
    size_t l = protos[i];
    EXPECT_LT(i + l, protos_len);
    if (i + l >= protos_len) {
      return SECFailure;
    }
    std::string protos_s(reinterpret_cast<const char*>(protos + i + 1), l);
    if (protos_s == agent->alpn_value_to_use_) {
      size_t s_len = agent->alpn_value_to_use_.size();
      EXPECT_LE(s_len, 255U);
      // The expectations above are non-fatal, so bound the copy explicitly
      // to avoid writing past the caller's buffer.
      if (s_len > protoMaxLen) {
        return SECFailure;
      }
      memcpy(protoOut, agent->alpn_value_to_use_.data(), s_len);
      *protoOutLen = static_cast<unsigned int>(s_len);
      return SECSuccess;
    }
    i += l + 1;
  }
  return SECFailure;
}

// Enables ALPN on both agents with the fixture's alpn_dummy_val_ list.
void TlsConnectTestBase::EnableAlpn() {
  client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
  server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
}

// Offers |client_vals| (an ALPN protocol list passed as-is to
// SSL_SetNextProtoNego) from the client. The server selects through
// NextProtoCallbackServer, which picks |server_choice| if the client
// offered it.
void TlsConnectTestBase::EnableAlpnWithCallback(
    const std::vector<uint8_t>& client_vals, std::string server_choice) {
  EnsureTlsSetup();
  server_->alpn_value_to_use_ = server_choice;
  EXPECT_EQ(SECSuccess,
            SSL_SetNextProtoNego(client_->ssl_fd(), client_vals.data(),
                                 client_vals.size()));
  SECStatus rv = SSL_SetNextProtoCallback(
      server_->ssl_fd(), NextProtoCallbackServer, server_.get());
  EXPECT_EQ(SECSuccess, rv);
}

// Enables ALPN on both agents with the same protocol list |vals|.
void TlsConnectTestBase::EnableAlpn(const std::vector<uint8_t>& vals) {
  client_->EnableAlpn(vals.data(), vals.size());
  server_->EnableAlpn(vals.data(), vals.size());
}

// Creates the client and server model agents on first use; they always exist
// as a pair. The current skip-version-checks setting is applied to new
// models.
void TlsConnectTestBase::EnsureModelSockets() {
  if (client_model_) {
    return;
  }
  ASSERT_EQ(server_model_, nullptr);
  client_model_.reset(
      new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_));
  server_model_.reset(
      new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_));
  if (!skip_version_checks_) {
    return;
  }
  client_model_->SkipVersionChecks();
  server_model_->SkipVersionChecks();
}

// Checks that ALPN settled on |val|. The client must report
// SSL_NEXT_PROTO_SELECTED and the server SSL_NEXT_PROTO_NEGOTIATED.
void TlsConnectTestBase::CheckAlpn(const std::string& val) {
  client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, val);
  server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, val);
}

// Enables SRTP on both agents.
void TlsConnectTestBase::EnableSrtp() {
  client_->EnableSrtp();
  server_->EnableSrtp();
}

// Runs each agent's SRTP check.
void TlsConnectTestBase::CheckSrtp() const {
  client_->CheckSrtp();
  server_->CheckSrtp();
}

// Sends data both ways until each agent has received |total| bytes
// cumulatively. Each side sends only the difference still missing at its
// peer, so |total| must exceed what both have already received.
void TlsConnectTestBase::SendReceive(size_t total) {
  ASSERT_GT(total, client_->received_bytes());
  ASSERT_GT(total, server_->received_bytes());
  client_->SendData(total - server_->received_bytes());
  server_->SendData(total - client_->received_bytes());
  Receive(total);  // Receive() is cumulative
}

// Do a first connection so we can do 0-RTT on the second one. The first
// connection uses TLS 1.3 with tickets and a server that allows 0-RTT. The
// agents are then reset (keeping the client's resumption token) and the
// second connection is started but not handshaken.
void TlsConnectTestBase::SetupForZeroRtt() {
  // If we don't do this, then all 0-RTT attempts will be rejected.
  SSLInt_RolloverAntiReplay();

  ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
  ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
  server_->Set0RttEnabled(true);  // So we signal that we allow 0-RTT.
  Connect();
  SendReceive();  // Need to read so that we absorb the session ticket.
  CheckKeys();

  Reset();
  StartConnect();
}

// Do a first connection so we can do resumption. The client caches both ways
// and the server uses tickets. Afterwards the agents are reset, carrying the
// client's resumption token over; the next connection is not started here.
void TlsConnectTestBase::SetupForResume() {
  EnsureTlsSetup();
  ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
  Connect();
  SendReceive();  // Need to read so that we absorb the session ticket.
  CheckKeys();

  Reset();
}

// Exercises 0-RTT after SetupForZeroRtt(). The client sends its ClientHello.
// If |post_clienthello_check| is set and returns false, the function stops
// there. Otherwise the client writes 6 bytes of early data: the write must
// succeed iff |expect_writable|. The server then consumes the ClientHello and
// reads: the read must return all 6 bytes iff |expect_readable|, otherwise
// it must fail with PR_WOULD_BLOCK_ERROR. A second server read must always
// fail with PR_WOULD_BLOCK_ERROR.
void TlsConnectTestBase::ZeroRttSendReceive(
    bool expect_writable, bool expect_readable,
    std::function<bool()> post_clienthello_check) {
  const char* k0RttData = "ABCDEF";
  const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));

  client_->Handshake();  // Send ClientHello.
  if (post_clienthello_check) {
    if (!post_clienthello_check()) return;
  }
  PRInt32 rv =
      PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen);  // 0-RTT write.
  if (expect_writable) {
    EXPECT_EQ(k0RttDataLen, rv);
  } else {
    EXPECT_EQ(SECFailure, rv);
  }
  server_->Handshake();  // Consume ClientHello

  std::vector<uint8_t> buf(k0RttDataLen);
  rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen);  // 0-RTT read
  if (expect_readable) {
    std::cerr << "0-RTT read " << rv << " bytes\n";
    EXPECT_EQ(k0RttDataLen, rv);
  } else {
    EXPECT_EQ(SECFailure, rv);
    EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError())
        << "Unexpected error: " << PORT_ErrorToName(PORT_GetError());
  }

  // Do a second read. This should fail: nothing more is pending.
  rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen);
  EXPECT_EQ(SECFailure, rv);
  EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
}

// Waits (up to 2000, presumably ms) until both agents have received exactly
// |amount| bytes in total, then asserts that they have.
void TlsConnectTestBase::Receive(size_t amount) {
  WAIT_(client_->received_bytes() == amount &&
            server_->received_bytes() == amount,
        2000);
  ASSERT_EQ(amount, client_->received_bytes());
  ASSERT_EQ(amount, server_->received_bytes());
}

// Records whether CheckExtendedMasterSecret() should expect EMS in use.
void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) {
  expect_extended_master_secret_ = expected;
}

// Checks the recorded extended-master-secret expectation on both agents.
void TlsConnectTestBase::CheckExtendedMasterSecret() {
  client_->CheckExtendedMasterSecret(expect_extended_master_secret_);
  server_->CheckExtendedMasterSecret(expect_extended_master_secret_);
}

// Records whether CheckEarlyDataAccepted() should expect 0-RTT acceptance.
void TlsConnectTestBase::ExpectEarlyDataAccepted(bool expected) {
  expect_early_data_accepted_ = expected;
}

// Checks the recorded early-data expectation on both agents.
void TlsConnectTestBase::CheckEarlyDataAccepted() {
  client_->CheckEarlyDataAccepted(expect_early_data_accepted_);
  server_->CheckEarlyDataAccepted(expect_early_data_accepted_);
}

// Forwards to the server agent's DisableECDHEServerKeyReuse().
void TlsConnectTestBase::DisableECDHEServerKeyReuse() {
  server_->DisableECDHEServerKeyReuse();
}

// Turns off version checks on both current agents. The flag is remembered so
// that agents created later by Reset() or EnsureModelSockets() also skip
// them, and so that CheckConnected() skips its own version check.
void TlsConnectTestBase::SkipVersionChecks() {
  skip_version_checks_ = true;
  client_->SkipVersionChecks();
  server_->SkipVersionChecks();
}

// Shift the DTLS timers, to the minimum time necessary to let the next timer
// run on either client or server.  This allows tests to skip waiting without
// having timers run out of order.
void TlsConnectTestBase::ShiftDtlsTimers() {
  PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT;
  PRIntervalTime time;
  SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time);
  if (rv == SECSuccess) {
    time_shift = time;
  }
  rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time);
  if (rv == SECSuccess &&
      (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) {
    time_shift = time;
  }

  if (time_shift == PR_INTERVAL_NO_TIMEOUT) {
    return;
  }

  EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift));
  EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift));
}

// Fixture constructors.
// - Tuple-parameterized fixtures take (variant, version) from GetParam().
// - TlsConnectGenericResumption also takes external_cache_ from the third
//   tuple element.
// - The remaining fixtures take only the variant and pin the version to
//   TLS 1.2 or TLS 1.3.
TlsConnectGeneric::TlsConnectGeneric()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}

TlsConnectPre12::TlsConnectPre12()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}

TlsConnectTls12::TlsConnectTls12()
    : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_2) {}

TlsConnectTls12Plus::TlsConnectTls12Plus()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}

TlsConnectTls13::TlsConnectTls13()
    : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}

TlsConnectGenericResumption::TlsConnectGenericResumption()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())),
      external_cache_(std::get<2>(GetParam())) {}

TlsConnectTls13ResumptionToken::TlsConnectTls13ResumptionToken()
    : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}

TlsConnectGenericResumptionToken::TlsConnectGenericResumptionToken()
    : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}

// Installs the captures used by CheckKEXDetails():
// - the client's supported_groups extension;
// - two key_share captures (the second created with the extra |true| flag,
//   presumably to capture a later ClientHello — TODO confirm);
// - a server-side recorder for HelloRetryRequest messages.
void TlsKeyExchangeTest::EnsureKeyShareSetup() {
  EnsureTlsSetup();
  groups_capture_ =
      std::make_shared<TlsExtensionCapture>(client_, ssl_supported_groups_xtn);
  shares_capture_ =
      std::make_shared<TlsExtensionCapture>(client_, ssl_tls13_key_share_xtn);
  shares_capture2_ = std::make_shared<TlsExtensionCapture>(
      client_, ssl_tls13_key_share_xtn, true);
  std::vector<std::shared_ptr<PacketFilter>> captures = {
      groups_capture_, shares_capture_, shares_capture2_};
  client_->SetFilter(std::make_shared<ChainedPacketFilter>(captures));
  capture_hrr_ = MakeTlsFilter<TlsHandshakeRecorder>(
      server_, kTlsHandshakeHelloRetryRequest);
}

// Configures the same named-group list on both agents.
void TlsKeyExchangeTest::ConfigNamedGroups(
    const std::vector<SSLNamedGroup>& groups) {
  client_->ConfigNamedGroups(groups);
  server_->ConfigNamedGroups(groups);
}

// Parses a captured supported_groups extension into its list of groups. The
// body is a 2-byte length prefix, which must match the remaining bytes, then
// 2-byte group codes. A trailing odd byte is reported but not parsed.
std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
    const std::shared_ptr<TlsExtensionCapture>& capture) {
  EXPECT_TRUE(capture->captured());
  const DataBuffer& ext = capture->extension();

  uint32_t value = 0;
  EXPECT_TRUE(ext.Read(0, 2, &value));
  EXPECT_EQ(ext.len() - 2, static_cast<size_t>(value));
  EXPECT_TRUE(ext.len() % 2 == 0);

  std::vector<SSLNamedGroup> groups;
  for (size_t offset = 2; offset + 2 <= ext.len(); offset += 2) {
    EXPECT_TRUE(ext.Read(offset, 2, &value));
    groups.push_back(static_cast<SSLNamedGroup>(value));
  }
  return groups;
}

// Parses a captured key_share extension into the list of groups it carries.
// The body is a 2-byte length prefix, which must match the remaining bytes.
// Each following entry is a 2-byte group code, a 2-byte key length and the
// key bytes. The entries must exactly fill the buffer.
std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
    const std::shared_ptr<TlsExtensionCapture>& capture) {
  EXPECT_TRUE(capture->captured());
  const DataBuffer& ext = capture->extension();

  uint32_t tmp = 0;
  EXPECT_TRUE(ext.Read(0, 2, &tmp));
  EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));

  std::vector<SSLNamedGroup> shares;
  size_t i = 2;
  while (i < ext.len()) {
    EXPECT_TRUE(ext.Read(i, 2, &tmp));
    shares.push_back(static_cast<SSLNamedGroup>(tmp));
    // |tmp| becomes the key length; skip the 4-byte header plus the key.
    EXPECT_TRUE(ext.Read(i + 2, 2, &tmp));
    i += 4 + tmp;
  }
  EXPECT_EQ(ext.len(), i);
  return shares;
}

// Compares the captured supported_groups with |expected_groups|. From
// TLS 1.3 onwards, the first captured key_share must match the non-empty
// |expected_shares|; before TLS 1.3 no key_share may be captured. A
// HelloRetryRequest must have been recorded iff |expect_hrr|.
void TlsKeyExchangeTest::CheckKEXDetails(
    const std::vector<SSLNamedGroup>& expected_groups,
    const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) {
  std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_);
  EXPECT_EQ(expected_groups, groups);

  if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
    ASSERT_LT(0U, expected_shares.size());
    std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_);
    EXPECT_EQ(expected_shares, shares);
  } else {
    EXPECT_FALSE(shares_capture_->captured());
  }

  EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0);
}

// Same as the three-argument form, with no HelloRetryRequest expected.
void TlsKeyExchangeTest::CheckKEXDetails(
    const std::vector<SSLNamedGroup>& expected_groups,
    const std::vector<SSLNamedGroup>& expected_shares) {
  CheckKEXDetails(expected_groups, expected_shares, false);
}

void TlsKeyExchangeTest::CheckKEXDetails(
    const std::vector<SSLNamedGroup>& expected_groups,
    const std::vector<SSLNamedGroup>& expected_shares,
    SSLNamedGroup expected_share2) {
  CheckKEXDetails(expected_groups, expected_shares, true);

  for (auto it : expected_shares) {
    EXPECT_NE(expected_share2, it);
  }
  std::vector<SSLNamedGroup> expected_shares2 = {expected_share2};
  EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_));
}
}  // namespace nss_test