diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_connect.h | 85 |
1 files changed, 77 insertions, 8 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h index 73e8dc81a..7dffe7f8a 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.h +++ b/security/nss/gtests/ssl_gtest/tls_connect.h @@ -45,8 +45,8 @@ class TlsConnectTestBase : public ::testing::Test { TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version); virtual ~TlsConnectTestBase(); - void SetUp(); - void TearDown(); + virtual void SetUp(); + virtual void TearDown(); // Initialize client and server. void Init(); @@ -55,13 +55,17 @@ class TlsConnectTestBase : public ::testing::Test { // Clear the server session cache. void ClearServerCache(); // Make sure TLS is configured for a connection. - void EnsureTlsSetup(); + virtual void EnsureTlsSetup(); // Reset and keep the same certificate names void Reset(); // Reset, and update the certificate names on both peers void Reset(const std::string& server_name, const std::string& client_name = "client"); + // Replace the server. + void MakeNewServer(); + // Set up + void StartConnect(); // Run the handshake. void Handshake(); // Connect and check that it works. @@ -81,20 +85,28 @@ class TlsConnectTestBase : public ::testing::Test { void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const; // This version assumes defaults. void CheckKeys() const; + // Check that keys on resumed sessions. + void CheckKeysResumption(SSLKEAType kea_type, SSLNamedGroup kea_group, + SSLNamedGroup original_kea_group, + SSLAuthType auth_type, + SSLSignatureScheme sig_scheme); void CheckGroups(const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group); void CheckShares(const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group); + void CheckEpochs(uint16_t client_epoch, uint16_t server_epoch) const; void ConfigureVersion(uint16_t version); void SetExpectedVersion(uint16_t version); // Expect resumption of a particular type. - void ExpectResumption(SessionResumptionMode expected); + void ExpectResumption(SessionResumptionMode expected, + uint8_t num_resumed = 1); void DisableAllCiphers(); void EnableOnlyStaticRsaCiphers(); void EnableOnlyDheCiphers(); void EnableSomeEcdhCiphers(); void EnableExtendedMasterSecret(); + void ConfigureSelfEncrypt(); void ConfigureSessionCache(SessionResumptionMode client, SessionResumptionMode server); void EnableAlpn(); @@ -103,7 +115,7 @@ class TlsConnectTestBase : public ::testing::Test { void CheckAlpn(const std::string& val); void EnableSrtp(); void CheckSrtp() const; - void SendReceive(); + void SendReceive(size_t total = 50); void SetupForZeroRtt(); void SetupForResume(); void ZeroRttSendReceive( @@ -115,6 +127,9 @@ class TlsConnectTestBase : public ::testing::Test { void DisableECDHEServerKeyReuse(); void SkipVersionChecks(); + // Move the DTLS timers for both endpoints to pop the next timer. + void ShiftDtlsTimers(); + protected: SSLProtocolVariant variant_; std::shared_ptr<TlsAgent> client_; @@ -123,6 +138,7 @@ class TlsConnectTestBase : public ::testing::Test { std::unique_ptr<TlsAgent> server_model_; uint16_t version_; SessionResumptionMode expected_resumption_mode_; + uint8_t expected_resumptions_; std::vector<std::vector<uint8_t>> session_ids_; // A simple value of "a", "b". Note that the preferred value of "a" is placed @@ -192,6 +208,52 @@ class TlsConnectGeneric : public TlsConnectTestBase, TlsConnectGeneric(); }; +class TlsConnectGenericResumption + : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t, bool>> { + private: + bool external_cache_; + + public: + TlsConnectGenericResumption(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + // Enable external resumption token cache. + if (external_cache_) { + client_->SetResumptionTokenCallback(); + } + } + + bool use_external_cache() const { return external_cache_; } +}; + +class TlsConnectTls13ResumptionToken + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { + public: + TlsConnectTls13ResumptionToken(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + client_->SetResumptionTokenCallback(); + } +}; + +class TlsConnectGenericResumptionToken + : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { + public: + TlsConnectGenericResumptionToken(); + + virtual void EnsureTlsSetup() { + TlsConnectTestBase::EnsureTlsSetup(); + client_->SetResumptionTokenCallback(); + } +}; + // A Pre TLS 1.2 generic test. class TlsConnectPre12 : public TlsConnectTestBase, public ::testing::WithParamInterface< @@ -244,6 +306,11 @@ class TlsConnectDatagram13 : public TlsConnectTestBase { : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {} }; +class TlsConnectDatagramPre13 : public TlsConnectDatagram { + public: + TlsConnectDatagramPre13() {} +}; + // A variant that is used only with Pre13. class TlsConnectGenericPre13 : public TlsConnectGeneric {}; @@ -252,12 +319,14 @@ class TlsKeyExchangeTest : public TlsConnectGeneric { std::shared_ptr<TlsExtensionCapture> groups_capture_; std::shared_ptr<TlsExtensionCapture> shares_capture_; std::shared_ptr<TlsExtensionCapture> shares_capture2_; - std::shared_ptr<TlsInspectorRecordHandshakeMessage> capture_hrr_; + std::shared_ptr<TlsHandshakeRecorder> capture_hrr_; void EnsureKeyShareSetup(); void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); - std::vector<SSLNamedGroup> GetGroupDetails(const DataBuffer& ext); - std::vector<SSLNamedGroup> GetShareDetails(const DataBuffer& ext); + std::vector<SSLNamedGroup> GetGroupDetails( + const std::shared_ptr<TlsExtensionCapture>& capture); + std::vector<SSLNamedGroup> GetShareDetails( + const std::shared_ptr<TlsExtensionCapture>& capture); void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, const std::vector<SSLNamedGroup>& expectedShares); void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups, |