/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#ifndef tls_agent_h_
#define tls_agent_h_

#include "prio.h"
#include "ssl.h"

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "test_io.h"

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
#include "nss_scoped_ptrs.h"
#include "scoped_ptrs_ssl.h"

extern bool g_ssl_gtest_verbose;

namespace nss_test {

// Writes |msg| to stderr, prefixed with "<role_str()>: " and terminated by a
// newline.  Only usable where a role_str() member is in scope (e.g. inside
// TlsAgent member functions).  |msg| is spliced directly into a stream
// expression, so callers may pass a chain such as LOG("got " << n << " bytes");
// for that reason |msg| is intentionally not parenthesized.
#define LOG(msg) std::cerr << role_str() << ": " << msg << std::endl
// Same as LOG, but only emits output when g_ssl_gtest_verbose is true.  The
// do/while(false) wrapper makes the macro behave like a single statement.
#define LOGV(msg)                      \
  do {                                 \
    if (g_ssl_gtest_verbose) LOG(msg); \
  } while (false)

// Resumption mechanisms accepted by TlsAgent::ConfigureSessionCache().
// The enumerators are bit flags; RESUME_BOTH is the union of the two.
enum SessionResumptionMode {
  RESUME_NONE = 0,                                // no resumption
  RESUME_SESSIONID = 1 << 0,                      // session-ID resumption
  RESUME_TICKET = 1 << 1,                         // session-ticket resumption
  RESUME_BOTH = RESUME_TICKET | RESUME_SESSIONID  // either mechanism
};

// Types this header refers to only through pointers, references or smart
// pointers; their definitions live elsewhere.
struct TlsRecord;
class PacketFilter;
class TlsCipherSpec;
class TlsAgent;

// Named-group lists shared by the tests (defined elsewhere).
extern const std::vector<SSLNamedGroup> kAllDHEGroups;
extern const std::vector<SSLNamedGroup> kECDHEGroups;
extern const std::vector<SSLNamedGroup> kFFDHEGroups;
extern const std::vector<SSLNamedGroup> kFasterDHEGroups;

// Callback types a test can register on a TlsAgent.  They are invoked from
// the agent's NSS hooks, so they receive a bare TlsAgent pointer: the agent
// installs the hook but does not know who owns the agent.
using AuthCertificateCallbackFunction =
    std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)>;

using HandshakeCallbackFunction = std::function<void(TlsAgent* agent)>;

using SniCallbackFunction =
    std::function<int32_t(TlsAgent* agent, const SECItem* srvNameArr,
                          PRUint32 srvNameArrSize)>;

class TlsAgent : public PollTarget {
 public:
  enum Role { CLIENT, SERVER };
  enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR };

  static const std::string kClient;     // the client key is sign only
  static const std::string kRsa2048;    // bigger sign and encrypt for either
  static const std::string kRsa8192;    // biggest sign and encrypt for either
  static const std::string kServerRsa;  // both sign and encrypt
  static const std::string kServerRsaSign;
  static const std::string kServerRsaPss;
  static const std::string kServerRsaDecrypt;
  static const std::string kServerEcdsa256;
  static const std::string kServerEcdsa384;
  static const std::string kServerEcdsa521;
  static const std::string kServerEcdhEcdsa;
  static const std::string kServerEcdhRsa;
  static const std::string kServerDsa;

  TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant);
  virtual ~TlsAgent();

  void SetPeer(std::shared_ptr<TlsAgent>& peer) {
    adapter_->SetPeer(peer->adapter_);
  }

  void SetFilter(std::shared_ptr<PacketFilter> filter) {
    adapter_->SetPacketFilter(filter);
  }
  void ClearFilter() { adapter_->SetPacketFilter(nullptr); }

  void StartConnect(PRFileDesc* model = nullptr);
  void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
                size_t kea_size = 0) const;
  void CheckOriginalKEA(SSLNamedGroup kea_group) const;
  void CheckAuthType(SSLAuthType auth_type,
                     SSLSignatureScheme sig_scheme) const;

  void DisableAllCiphers();
  void EnableCiphersByAuthType(SSLAuthType authType);
  void EnableCiphersByKeyExchange(SSLKEAType kea);
  void EnableGroupsByKeyExchange(SSLKEAType kea);
  void EnableGroupsByAuthType(SSLAuthType authType);
  void EnableSingleCipher(uint16_t cipher);

  void Handshake();
  // Marks the internal state as CONNECTING in anticipation of renegotiation.
  void PrepareForRenegotiate();
  // Prepares for renegotiation, then actually triggers it.
  void StartRenegotiate();
  static bool LoadCertificate(const std::string& name,
                              ScopedCERTCertificate* cert,
                              ScopedSECKEYPrivateKey* priv);
  bool ConfigServerCert(const std::string& name, bool updateKeyBits = false,
                        const SSLExtraServerCertData* serverCertData = nullptr);
  bool ConfigServerCertWithChain(const std::string& name);
  bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr);

  void SetupClientAuth();
  void RequestClientAuth(bool requireAuth);

  void SetOption(int32_t option, int value);
  void ConfigureSessionCache(SessionResumptionMode mode);
  void Set0RttEnabled(bool en);
  void SetFallbackSCSVEnabled(bool en);
  void SetVersionRange(uint16_t minver, uint16_t maxver);
  void GetVersionRange(uint16_t* minver, uint16_t* maxver);
  void CheckPreliminaryInfo();
  void ResetPreliminaryInfo();
  void SetExpectedVersion(uint16_t version);
  void SetServerKeyBits(uint16_t bits);
  void ExpectReadWriteError();
  void EnableFalseStart();
  void ExpectResumption();
  void SkipVersionChecks();
  void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
  void EnableAlpn(const uint8_t* val, size_t len);
  void CheckAlpn(SSLNextProtoState expected_state,
                 const std::string& expected = "") const;
  void EnableSrtp();
  void CheckSrtp() const;
  void CheckErrorCode(int32_t expected) const;
  void WaitForErrorCode(int32_t expected, uint32_t delay) const;
  // Send data on the socket, encrypting it.
  void SendData(size_t bytes, size_t blocksize = 1024);
  void SendBuffer(const DataBuffer& buf);
  bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
                           uint64_t seq, uint8_t ct, const DataBuffer& buf);
  // Send data directly to the underlying socket, skipping the TLS layer.
  void SendDirect(const DataBuffer& buf);
  void SendRecordDirect(const TlsRecord& record);
  void ReadBytes(size_t max = 16384U);
  void ResetSentBytes();  // Hack to test drops.
  void EnableExtendedMasterSecret();
  void CheckExtendedMasterSecret(bool expected);
  void CheckEarlyDataAccepted(bool expected);
  void SetDowngradeCheckVersion(uint16_t version);
  void CheckSecretsDestroyed();
  void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
  void DisableECDHEServerKeyReuse();
  bool GetPeerChainLength(size_t* count);
  void CheckCipherSuite(uint16_t cipher_suite);
  void SetResumptionTokenCallback();
  bool MaybeSetResumptionToken();
  void SetResumptionToken(const std::vector<uint8_t>& resumption_token) {
    resumption_token_ = resumption_token;
  }
  const std::vector<uint8_t>& GetResumptionToken() const {
    return resumption_token_;
  }
  void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) {
    SECStatus rv = SSL_GetResumptionTokenInfo(
        resumption_token_.data(), resumption_token_.size(), token.get(),
        sizeof(SSLResumptionTokenInfo));
    ASSERT_EQ(SECSuccess, rv);
  }
  void SetResumptionCallbackCalled() { resumption_callback_called_ = true; }
  bool resumption_callback_called() const {
    return resumption_callback_called_;
  }

  const std::string& name() const { return name_; }

  Role role() const { return role_; }
  std::string role_str() const { return role_ == SERVER ? "server" : "client"; }

  SSLProtocolVariant variant() const { return variant_; }

  State state() const { return state_; }

  const CERTCertificate* peer_cert() const {
    return SSL_PeerCertificate(ssl_fd_.get());
  }

  const char* state_str() const { return state_str(state()); }

  static const char* state_str(State state) { return states[state]; }

  PRFileDesc* ssl_fd() const { return ssl_fd_.get(); }
  std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; }

  bool is_compressed() const {
    return info_.compressionMethod != ssl_compression_null;
  }
  uint16_t server_key_bits() const { return server_key_bits_; }
  uint16_t min_version() const { return vrange_.min; }
  uint16_t max_version() const { return vrange_.max; }
  uint16_t version() const {
    EXPECT_EQ(STATE_CONNECTED, state_);
    return info_.protocolVersion;
  }

  bool cipher_suite(uint16_t* suite) const {
    if (state_ != STATE_CONNECTED) return false;

    *suite = info_.cipherSuite;
    return true;
  }

  std::string cipher_suite_name() const {
    if (state_ != STATE_CONNECTED) return "UNKNOWN";

    return csinfo_.cipherSuiteName;
  }

  std::vector<uint8_t> session_id() const {
    return std::vector<uint8_t>(info_.sessionID,
                                info_.sessionID + info_.sessionIDLength);
  }

  bool auth_type(SSLAuthType* a) const {
    if (state_ != STATE_CONNECTED) return false;

    *a = info_.authType;
    return true;
  }

  bool kea_type(SSLKEAType* k) const {
    if (state_ != STATE_CONNECTED) return false;

    *k = info_.keaType;
    return true;
  }

  size_t received_bytes() const { return recv_ctr_; }
  PRErrorCode error_code() const { return error_code_; }

  bool can_falsestart_hook_called() const {
    return can_falsestart_hook_called_;
  }

  void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) {
    handshake_callback_ = handshake_callback;
  }

  void SetAuthCertificateCallback(
      AuthCertificateCallbackFunction auth_certificate_callback) {
    auth_certificate_callback_ = auth_certificate_callback;
  }

  void SetSniCallback(SniCallbackFunction sni_callback) {
    sni_callback_ = sni_callback;
  }

  void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0);
  void ExpectSendAlert(uint8_t alert, uint8_t level = 0);

  std::string alpn_value_to_use_ = "";

 private:
  const static char* states[];

  void SetState(State state);
  void ValidateCipherSpecs();

  // Dummy auth certificate hook.
  static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
                                       PRBool checksig, PRBool isServer) {
    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    agent->CheckPreliminaryInfo();
    agent->auth_certificate_hook_called_ = true;
    if (agent->auth_certificate_callback_) {
      return agent->auth_certificate_callback_(agent, checksig ? true : false,
                                               isServer ? true : false);
    }
    return SECSuccess;
  }

  // Client auth certificate hook.
  static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd,
                                       PRBool checksig, PRBool isServer) {
    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    EXPECT_TRUE(agent->expect_client_auth_);
    EXPECT_EQ(PR_TRUE, isServer);
    if (agent->auth_certificate_callback_) {
      return agent->auth_certificate_callback_(agent, checksig ? true : false,
                                               isServer ? true : false);
    }
    return SECSuccess;
  }

  static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd,
                                         CERTDistNames* caNames,
                                         CERTCertificate** cert,
                                         SECKEYPrivateKey** privKey);

  static void ReadableCallback(PollTarget* self, Event event) {
    TlsAgent* agent = static_cast<TlsAgent*>(self);
    if (event == TIMER_EVENT) {
      agent->timer_handle_ = nullptr;
    }
    agent->ReadableCallback_int();
  }

  void ReadableCallback_int() {
    LOGV("Readable");
    switch (state_) {
      case STATE_CONNECTING:
        Handshake();
        break;
      case STATE_CONNECTED:
        ReadBytes();
        break;
      default:
        break;
    }
  }

  static PRInt32 SniHook(PRFileDesc* fd, const SECItem* srvNameArr,
                         PRUint32 srvNameArrSize, void* arg) {
    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    agent->CheckPreliminaryInfo();
    agent->sni_hook_called_ = true;
    EXPECT_EQ(1UL, srvNameArrSize);
    if (agent->sni_callback_) {
      return agent->sni_callback_(agent, srvNameArr, srvNameArrSize);
    }
    return 0;  // First configuration.
  }

  static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg,
                                         PRBool* canFalseStart) {
    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    agent->CheckPreliminaryInfo();
    EXPECT_TRUE(agent->falsestart_enabled_);
    EXPECT_FALSE(agent->can_falsestart_hook_called_);
    agent->can_falsestart_hook_called_ = true;
    *canFalseStart = true;
    return SECSuccess;
  }

  void CheckAlert(bool sent, const SSLAlert* alert);

  static void AlertReceivedCallback(const PRFileDesc* fd, void* arg,
                                    const SSLAlert* alert) {
    reinterpret_cast<TlsAgent*>(arg)->CheckAlert(false, alert);
  }

  static void AlertSentCallback(const PRFileDesc* fd, void* arg,
                                const SSLAlert* alert) {
    reinterpret_cast<TlsAgent*>(arg)->CheckAlert(true, alert);
  }

  static void HandshakeCallback(PRFileDesc* fd, void* arg) {
    TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
    agent->handshake_callback_called_ = true;
    agent->Connected();
    if (agent->handshake_callback_) {
      agent->handshake_callback_(agent);
    }
  }

  void DisableLameGroups();
  void ConfigStrongECGroups(bool en);
  void ConfigAllDHGroups(bool en);
  void CheckCallbacks() const;
  void Connected();

  const std::string name_;
  SSLProtocolVariant variant_;
  Role role_;
  uint16_t server_key_bits_;
  std::shared_ptr<DummyPrSocket> adapter_;
  ScopedPRFileDesc ssl_fd_;
  State state_;
  std::shared_ptr<Poller::Timer> timer_handle_;
  bool falsestart_enabled_;
  uint16_t expected_version_;
  uint16_t expected_cipher_suite_;
  bool expect_resumption_;
  bool expect_client_auth_;
  bool can_falsestart_hook_called_;
  bool sni_hook_called_;
  bool auth_certificate_hook_called_;
  uint8_t expected_received_alert_;
  uint8_t expected_received_alert_level_;
  uint8_t expected_sent_alert_;
  uint8_t expected_sent_alert_level_;
  bool handshake_callback_called_;
  bool resumption_callback_called_;
  SSLChannelInfo info_;
  SSLCipherSuiteInfo csinfo_;
  SSLVersionRange vrange_;
  PRErrorCode error_code_;
  size_t send_ctr_;
  size_t recv_ctr_;
  bool expect_readwrite_error_;
  HandshakeCallbackFunction handshake_callback_;
  AuthCertificateCallbackFunction auth_certificate_callback_;
  SniCallbackFunction sni_callback_;
  bool skip_version_checks_;
  std::vector<uint8_t> resumption_token_;
};

// Writes a TlsAgent::State by name (via TlsAgent::state_str), e.g. so that
// gtest failure messages show readable states.
inline std::ostream& operator<<(std::ostream& stream,
                                const TlsAgent::State& state) {
  const char* const label = TlsAgent::state_str(state);
  stream << label;
  return stream;
}

class TlsAgentTestBase : public ::testing::Test {
 public:
  static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;

  TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant,
                   uint16_t version = 0)
      : agent_(nullptr),
        role_(role),
        variant_(variant),
        version_(version),
        sink_adapter_(new DummyPrSocket("sink", variant)) {}
  virtual ~TlsAgentTestBase() {}

  void SetUp();
  void TearDown();

  void ExpectAlert(uint8_t alert);

  static void MakeRecord(SSLProtocolVariant variant, uint8_t type,
                         uint16_t version, const uint8_t* buf, size_t len,
                         DataBuffer* out, uint64_t seq_num = 0);
  void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf,
                  size_t len, DataBuffer* out, uint64_t seq_num = 0) const;
  void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len,
                            DataBuffer* out, uint64_t seq_num = 0) const;
  void MakeHandshakeMessageFragment(uint8_t hs_type, const uint8_t* data,
                                    size_t hs_len, DataBuffer* out,
                                    uint64_t seq_num, uint32_t fragment_offset,
                                    uint32_t fragment_length) const;
  DataBuffer MakeCannedTls13ServerHello();
  static void MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len,
                                         DataBuffer* out);
  static inline TlsAgent::Role ToRole(const std::string& str) {
    return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER;
  }

  void Init(const std::string& server_name = TlsAgent::kServerRsa);
  void Reset(const std::string& server_name = TlsAgent::kServerRsa);

 protected:
  void EnsureInit();
  void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
                      int32_t error_code = 0);

  std::shared_ptr<TlsAgent> agent_;
  TlsAgent::Role role_;
  SSLProtocolVariant variant_;
  uint16_t version_;
  // This adapter is here just to accept packets from this agent.
  std::shared_ptr<DummyPrSocket> sink_adapter_;
};

// Value-parameterized fixture over both roles.  The parameter tuple is
// (role string understood by ToRole(), protocol variant, version).
class TlsAgentTest
    : public TlsAgentTestBase,
      public ::testing::WithParamInterface<
          std::tuple<std::string, SSLProtocolVariant, uint16_t>> {
 public:
  TlsAgentTest()
      : TlsAgentTestBase(TlsAgentTestBase::ToRole(std::get<0>(GetParam())),
                         /* variant= */ std::get<1>(GetParam()),
                         /* version= */ std::get<2>(GetParam())) {}
};

// Value-parameterized client-only fixture.  The parameter tuple is
// (protocol variant, version).
class TlsAgentTestClient : public TlsAgentTestBase,
                           public ::testing::WithParamInterface<
                               std::tuple<SSLProtocolVariant, uint16_t>> {
 public:
  TlsAgentTestClient()
      : TlsAgentTestBase(TlsAgent::CLIENT,
                         /* variant= */ std::get<0>(GetParam()),
                         /* version= */ std::get<1>(GetParam())) {}
};

// Identical to TlsAgentTestClient.  The separate fixture name lets its tests
// be instantiated with their own parameter set (presumably TLS 1.3 only).
class TlsAgentTestClient13 : public TlsAgentTestClient {};

// Non-parameterized fixtures with a fixed role and protocol variant.  The
// version argument is omitted, so TlsAgentTestBase's default is used.

// Stream-variant client.
class TlsAgentStreamTestClient : public TlsAgentTestBase {
 public:
  TlsAgentStreamTestClient()
      : TlsAgentTestBase(/* role= */ TlsAgent::CLIENT,
                         /* variant= */ ssl_variant_stream) {}
};

// Stream-variant server.
class TlsAgentStreamTestServer : public TlsAgentTestBase {
 public:
  TlsAgentStreamTestServer()
      : TlsAgentTestBase(/* role= */ TlsAgent::SERVER,
                         /* variant= */ ssl_variant_stream) {}
};

// Datagram-variant client.
class TlsAgentDgramTestClient : public TlsAgentTestBase {
 public:
  TlsAgentDgramTestClient()
      : TlsAgentTestBase(/* role= */ TlsAgent::CLIENT,
                         /* variant= */ ssl_variant_datagram) {}
};

// Two version ranges are equal when both their min and max versions match.
inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) {
  if (vr1.min != vr2.min) {
    return false;
  }
  return vr1.max == vr2.max;
}

}  // namespace nss_test

#endif