diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_connect.h | 86 |
1 files changed, 45 insertions, 41 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.h b/security/nss/gtests/ssl_gtest/tls_connect.h index aa4a32d96..73e8dc81a 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.h +++ b/security/nss/gtests/ssl_gtest/tls_connect.h @@ -25,9 +25,12 @@ extern std::string VersionString(uint16_t version); // A generic TLS connection test base. class TlsConnectTestBase : public ::testing::Test { public: - static ::testing::internal::ParamGenerator<std::string> kTlsModesStream; - static ::testing::internal::ParamGenerator<std::string> kTlsModesDatagram; - static ::testing::internal::ParamGenerator<std::string> kTlsModesAll; + static ::testing::internal::ParamGenerator<SSLProtocolVariant> + kTlsVariantsStream; + static ::testing::internal::ParamGenerator<SSLProtocolVariant> + kTlsVariantsDatagram; + static ::testing::internal::ParamGenerator<SSLProtocolVariant> + kTlsVariantsAll; static ::testing::internal::ParamGenerator<uint16_t> kTlsV10; static ::testing::internal::ParamGenerator<uint16_t> kTlsV11; static ::testing::internal::ParamGenerator<uint16_t> kTlsV12; @@ -39,8 +42,7 @@ class TlsConnectTestBase : public ::testing::Test { static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus; static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll; - TlsConnectTestBase(Mode mode, uint16_t version); - TlsConnectTestBase(const std::string& mode, uint16_t version); + TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version); virtual ~TlsConnectTestBase(); void SetUp(); @@ -68,6 +70,9 @@ class TlsConnectTestBase : public ::testing::Test { void CheckConnected(); // Connect and expect it to fail. void ConnectExpectFail(); + void ExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert); + void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert); + void ConnectExpectFailOneSide(TlsAgent::Role failingSide); void ConnectWithCipherSuite(uint16_t cipher_suite); // Check that the keys used in the handshake match expectations. void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group, @@ -108,13 +113,14 @@ class TlsConnectTestBase : public ::testing::Test { void ExpectExtendedMasterSecret(bool expected); void ExpectEarlyDataAccepted(bool expected); void DisableECDHEServerKeyReuse(); + void SkipVersionChecks(); protected: - Mode mode_; - TlsAgent* client_; - TlsAgent* server_; - TlsAgent* client_model_; - TlsAgent* server_model_; + SSLProtocolVariant variant_; + std::shared_ptr<TlsAgent> client_; + std::shared_ptr<TlsAgent> server_; + std::unique_ptr<TlsAgent> client_model_; + std::unique_ptr<TlsAgent> server_model_; uint16_t version_; SessionResumptionMode expected_resumption_mode_; std::vector<std::vector<uint8_t>> session_ids_; @@ -126,16 +132,13 @@ class TlsConnectTestBase : public ::testing::Test { const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61}; private: - static inline Mode ToMode(const std::string& str) { - return str == "TLS" ? STREAM : DGRAM; - } - void CheckResumption(SessionResumptionMode expected); void CheckExtendedMasterSecret(); void CheckEarlyDataAccepted(); bool expect_extended_master_secret_; bool expect_early_data_accepted_; + bool skip_version_checks_; // Track groups and make sure that there are no duplicates. class DuplicateGroupChecker { @@ -154,20 +157,20 @@ class TlsConnectTestBase : public ::testing::Test { // A non-parametrized TLS test base. class TlsConnectTest : public TlsConnectTestBase { public: - TlsConnectTest() : TlsConnectTestBase(STREAM, 0) {} + TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {} }; // A non-parametrized DTLS-only test base. class DtlsConnectTest : public TlsConnectTestBase { public: - DtlsConnectTest() : TlsConnectTestBase(DGRAM, 0) {} + DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {} }; // A TLS-only test base. class TlsConnectStream : public TlsConnectTestBase, public ::testing::WithParamInterface<uint16_t> { public: - TlsConnectStream() : TlsConnectTestBase(STREAM, GetParam()) {} + TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {} }; // A TLS-only test base for tests before 1.3 @@ -177,30 +180,30 @@ class TlsConnectStreamPre13 : public TlsConnectStream {}; class TlsConnectDatagram : public TlsConnectTestBase, public ::testing::WithParamInterface<uint16_t> { public: - TlsConnectDatagram() : TlsConnectTestBase(DGRAM, GetParam()) {} + TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {} }; -// A generic test class that can be either STREAM or DGRAM and a single version -// of TLS. This is configured in ssl_loopback_unittest.cc. All uses of this -// should use TEST_P(). -class TlsConnectGeneric - : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { +// A generic test class that can be either stream or datagram and a single +// version of TLS. This is configured in ssl_loopback_unittest.cc. +class TlsConnectGeneric : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { public: TlsConnectGeneric(); }; // A Pre TLS 1.2 generic test. -class TlsConnectPre12 - : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { +class TlsConnectPre12 : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { public: TlsConnectPre12(); }; // A TLS 1.2 only generic test. -class TlsConnectTls12 : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::string> { +class TlsConnectTls12 + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { public: TlsConnectTls12(); }; @@ -209,20 +212,21 @@ class TlsConnectTls12 : public TlsConnectTestBase, class TlsConnectStreamTls12 : public TlsConnectTestBase { public: TlsConnectStreamTls12() - : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_2) {} + : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {} }; // A TLS 1.2+ generic test. -class TlsConnectTls12Plus - : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { +class TlsConnectTls12Plus : public TlsConnectTestBase, + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { public: TlsConnectTls12Plus(); }; // A TLS 1.3 only generic test. -class TlsConnectTls13 : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::string> { +class TlsConnectTls13 + : public TlsConnectTestBase, + public ::testing::WithParamInterface<SSLProtocolVariant> { public: TlsConnectTls13(); }; @@ -231,13 +235,13 @@ class TlsConnectTls13 : public TlsConnectTestBase, class TlsConnectStreamTls13 : public TlsConnectTestBase { public: TlsConnectStreamTls13() - : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {} + : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {} }; class TlsConnectDatagram13 : public TlsConnectTestBase { public: TlsConnectDatagram13() - : TlsConnectTestBase(DGRAM, SSL_LIBRARY_VERSION_TLS_1_3) {} + : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {} }; // A variant that is used only with Pre13. @@ -245,10 +249,10 @@ class TlsConnectGenericPre13 : public TlsConnectGeneric {}; class TlsKeyExchangeTest : public TlsConnectGeneric { protected: - TlsExtensionCapture* groups_capture_; - TlsExtensionCapture* shares_capture_; - TlsExtensionCapture* shares_capture2_; - TlsInspectorRecordHandshakeMessage* capture_hrr_; + std::shared_ptr<TlsExtensionCapture> groups_capture_; + std::shared_ptr<TlsExtensionCapture> shares_capture_; + std::shared_ptr<TlsExtensionCapture> shares_capture2_; + std::shared_ptr<TlsInspectorRecordHandshakeMessage> capture_hrr_; void EnsureKeyShareSetup(); void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); |