/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ /* vim: set ts=2 et sw=2 tw=80: */ /* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at http://mozilla.org/MPL/2.0/. */ #ifndef tls_connect_h_ #define tls_connect_h_ #include #include "sslproto.h" #include "sslt.h" #include "tls_agent.h" #include "tls_filter.h" #define GTEST_HAS_RTTI 0 #include "gtest/gtest.h" namespace nss_test { extern std::string VersionString(uint16_t version); // A generic TLS connection test base. class TlsConnectTestBase : public ::testing::Test { public: static ::testing::internal::ParamGenerator kTlsVariantsStream; static ::testing::internal::ParamGenerator kTlsVariantsDatagram; static ::testing::internal::ParamGenerator kTlsVariantsAll; static ::testing::internal::ParamGenerator kTlsV10; static ::testing::internal::ParamGenerator kTlsV11; static ::testing::internal::ParamGenerator kTlsV12; static ::testing::internal::ParamGenerator kTlsV10V11; static ::testing::internal::ParamGenerator kTlsV11V12; static ::testing::internal::ParamGenerator kTlsV10ToV12; static ::testing::internal::ParamGenerator kTlsV13; static ::testing::internal::ParamGenerator kTlsV11Plus; static ::testing::internal::ParamGenerator kTlsV12Plus; static ::testing::internal::ParamGenerator kTlsVAll; TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version); virtual ~TlsConnectTestBase(); virtual void SetUp(); virtual void TearDown(); PRTime now() const { return now_; } // Initialize client and server. void Init(); // Clear the statistics. void ClearStats(); // Clear the server session cache. void ClearServerCache(); // Make sure TLS is configured for a connection. virtual void EnsureTlsSetup(); // Reset and keep the same certificate names void Reset(); // Reset, and update the certificate names on both peers void Reset(const std::string& server_name, const std::string& client_name = "client"); // Replace the server. void MakeNewServer(); // Set up void StartConnect(); // Run the handshake. void Handshake(); // Connect and check that it works. void Connect(); // Check that the connection was successfully established. void CheckConnected(); // Connect and expect it to fail. void ConnectExpectFail(); void ExpectAlert(std::shared_ptr& sender, uint8_t alert); void ConnectExpectAlert(std::shared_ptr& sender, uint8_t alert); void ConnectExpectFailOneSide(TlsAgent::Role failingSide); void ConnectWithCipherSuite(uint16_t cipher_suite); void CheckEarlyDataLimit(const std::shared_ptr& agent, size_t expected_size); // Check that the keys used in the handshake match expectations. void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const; // This version guesses some of the values. void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const; // This version assumes defaults. void CheckKeys() const; // Check that keys on resumed sessions. void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group, SSLNamedGroup original_kea_group, SSLAuthType auth_type, SSLSignatureScheme sig_scheme); void CheckGroups(const DataBuffer& groups, std::function check_group); void CheckShares(const DataBuffer& shares, std::function check_group); void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const; void ConfigureVersion(uint16_t version); void SetExpectedVersion(uint16_t version); // Expect resumption of a particular type. void ExpectResumption(SessionResumptionMode expected, uint8_t num_resumed = 1); void DisableAllCiphers(); void EnableOnlyStaticRsaCiphers(); void EnableOnlyDheCiphers(); void EnableSomeEcdhCiphers(); void EnableExtendedMasterSecret(); void ConfigureSelfEncrypt(); void ConfigureSessionCache(SessionResumptionMode client, SessionResumptionMode server); void EnableAlpn(); void EnableAlpnWithCallback(const std::vector& client, std::string server_choice); void EnableAlpn(const std::vector& vals); void EnsureModelSockets(); void CheckAlpn(const std::string& val); void EnableSrtp(); void CheckSrtp() const; void SendReceive(size_t total = 50); void AddPsk(const ScopedPK11SymKey& psk, std::string label, SSLHashType hash, uint16_t zeroRttSuite = TLS_NULL_WITH_NULL_NULL); void RemovePsk(std::string label); void SetupForZeroRtt(); void SetupForResume(); void ZeroRttSendReceive( bool expect_writable, bool expect_readable, std::function post_clienthello_check = nullptr); void Receive(size_t amount); void ExpectExtendedMasterSecret(bool expected); void ExpectEarlyDataAccepted(bool expected); void DisableECDHEServerKeyReuse(); void SkipVersionChecks(); // Move the DTLS timers for both endpoints to pop the next timer. void ShiftDtlsTimers(); void AdvanceTime(PRTime time_shift); void ResetAntiReplay(PRTime window); void RolloverAntiReplay(); void SaveAlgorithmPolicy(); void RestoreAlgorithmPolicy(); protected: SSLProtocolVariant variant_; std::shared_ptr client_; std::shared_ptr server_; std::unique_ptr client_model_; std::unique_ptr server_model_; uint16_t version_; SessionResumptionMode expected_resumption_mode_; uint8_t expected_resumptions_; std::vector> session_ids_; ScopedSSLAntiReplayContext anti_replay_; // A simple value of "a", "b". Note that the preferred value of "a" is placed // at the end, because the NSS API follows the now defunct NPN specification, // which places the preferred (and default) entry at the end of the list. // NSS will move this final entry to the front when used with ALPN. const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61}; // A list of algorithm IDs whose policies need to be preserved // around test cases. In particular, DSA is checked in // ssl_extension_unittest.cc. const std::vector algorithms_ = {SEC_OID_APPLY_SSL_POLICY, SEC_OID_ANSIX9_DSA_SIGNATURE, SEC_OID_CURVE25519, SEC_OID_SHA1}; std::vector> saved_policies_; private: void CheckResumption(SessionResumptionMode expected); void CheckExtendedMasterSecret(); void CheckEarlyDataAccepted(); static PRTime TimeFunc(void* arg); bool expect_extended_master_secret_; bool expect_early_data_accepted_; bool skip_version_checks_; PRTime now_; // Track groups and make sure that there are no duplicates. class DuplicateGroupChecker { public: void AddAndCheckGroup(SSLNamedGroup group) { EXPECT_EQ(groups_.end(), groups_.find(group)) << "Group " << group << " should not be duplicated"; groups_.insert(group); } private: std::set groups_; }; }; // A non-parametrized TLS test base. class TlsConnectTest : public TlsConnectTestBase { public: TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} }; // A non-parametrized DTLS-only test base. class DtlsConnectTest : public TlsConnectTestBase { public: DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {} }; // A TLS-only test base. class TlsConnectStream : public TlsConnectTestBase, public ::testing::WithParamInterface { public: TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {} }; // A TLS-only test base for tests before 1.3 class TlsConnectStreamPre13 : public TlsConnectStream {}; // A DTLS-only test base. class TlsConnectDatagram : public TlsConnectTestBase, public ::testing::WithParamInterface { public: TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {} }; // A generic test class that can be either stream or datagram and a single // version of TLS. This is configured in ssl_loopback_unittest.cc. class TlsConnectGeneric : public TlsConnectTestBase, public ::testing::WithParamInterface< std::tuple> { public: TlsConnectGeneric(); }; class TlsConnectGenericResumption : public TlsConnectTestBase, public ::testing::WithParamInterface< std::tuple> { private: bool external_cache_; public: TlsConnectGenericResumption(); virtual void EnsureTlsSetup() { TlsConnectTestBase::EnsureTlsSetup(); // Enable external resumption token cache. if (external_cache_) { client_->SetResumptionTokenCallback(); } } bool use_external_cache() const { return external_cache_; } }; class TlsConnectTls13ResumptionToken : public TlsConnectTestBase, public ::testing::WithParamInterface { public: TlsConnectTls13ResumptionToken(); virtual void EnsureTlsSetup() { TlsConnectTestBase::EnsureTlsSetup(); client_->SetResumptionTokenCallback(); } }; class TlsConnectGenericResumptionToken : public TlsConnectTestBase, public ::testing::WithParamInterface< std::tuple> { public: TlsConnectGenericResumptionToken(); virtual void EnsureTlsSetup() { TlsConnectTestBase::EnsureTlsSetup(); client_->SetResumptionTokenCallback(); } }; // A Pre TLS 1.2 generic test. class TlsConnectPre12 : public TlsConnectTestBase, public ::testing::WithParamInterface< std::tuple> { public: TlsConnectPre12(); }; // A TLS 1.2 only generic test. class TlsConnectTls12 : public TlsConnectTestBase, public ::testing::WithParamInterface { public: TlsConnectTls12(); }; // A TLS 1.2 only stream test. class TlsConnectStreamTls12 : public TlsConnectTestBase { public: TlsConnectStreamTls12() : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {} }; // A TLS 1.2+ generic test. class TlsConnectTls12Plus : public TlsConnectTestBase, public ::testing::WithParamInterface< std::tuple> { public: TlsConnectTls12Plus(); }; // A TLS 1.3 only generic test. class TlsConnectTls13 : public TlsConnectTestBase, public ::testing::WithParamInterface { public: TlsConnectTls13(); }; // A TLS 1.3 only stream test. class TlsConnectStreamTls13 : public TlsConnectTestBase { public: TlsConnectStreamTls13() : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {} }; class TlsConnectDatagram13 : public TlsConnectTestBase { public: TlsConnectDatagram13() : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {} }; class TlsConnectDatagramPre13 : public TlsConnectDatagram { public: TlsConnectDatagramPre13() {} }; // A variant that is used only with Pre13. class TlsConnectGenericPre13 : public TlsConnectGeneric {}; class TlsKeyExchangeTest : public TlsConnectGeneric { protected: std::shared_ptr groups_capture_; std::shared_ptr shares_capture_; std::shared_ptr shares_capture2_; std::shared_ptr capture_hrr_; void EnsureKeyShareSetup(); void ConfigNamedGroups(const std::vector& groups); std::vector GetGroupDetails( const std::shared_ptr& capture); std::vector GetShareDetails( const std::shared_ptr& capture); void CheckKEXDetails(const std::vector& expectedGroups, const std::vector& expectedShares); void CheckKEXDetails(const std::vector& expectedGroups, const std::vector& expectedShares, SSLNamedGroup expectedShare2); private: void CheckKEXDetails(const std::vector& expectedGroups, const std::vector& expectedShares, bool expect_hrr); }; class TlsKeyExchangeTest13 : public TlsKeyExchangeTest {}; class TlsKeyExchangeTestPre13 : public TlsKeyExchangeTest {}; } // namespace nss_test #endif