diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_agent.h | 38 |
1 files changed, 26 insertions, 12 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h index b3fd892ae..6cd6d5073 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.h +++ b/security/nss/gtests/ssl_gtest/tls_agent.h @@ -14,7 +14,6 @@ #include <iostream> #include "test_io.h" -#include "tls_filter.h" #define GTEST_HAS_RTTI 0 #include "gtest/gtest.h" @@ -37,7 +36,10 @@ enum SessionResumptionMode { RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET }; +class PacketFilter; class TlsAgent; +class TlsCipherSpec; +struct TlsRecord; const extern std::vector<SSLNamedGroup> kAllDHEGroups; const extern std::vector<SSLNamedGroup> kECDHEGroups; @@ -80,18 +82,10 @@ class TlsAgent : public PollTarget { adapter_->SetPeer(peer->adapter_); } - // Set a filter that can access plaintext (TLS 1.3 only). - void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) { - filter->SetAgent(this); + void SetFilter(std::shared_ptr<PacketFilter> filter) { adapter_->SetPacketFilter(filter); - filter->EnableDecryption(); } - - void SetPacketFilter(std::shared_ptr<PacketFilter> filter) { - adapter_->SetPacketFilter(filter); - } - - void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); } + void ClearFilter() { adapter_->SetPacketFilter(nullptr); } void StartConnect(PRFileDesc* model = nullptr); void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group, @@ -165,6 +159,24 @@ class TlsAgent : public PollTarget { void DisableECDHEServerKeyReuse(); bool GetPeerChainLength(size_t* count); void CheckCipherSuite(uint16_t cipher_suite); + void SetResumptionTokenCallback(); + bool MaybeSetResumptionToken(); + void SetResumptionToken(const std::vector<uint8_t>& resumption_token) { + resumption_token_ = resumption_token; + } + const std::vector<uint8_t>& GetResumptionToken() const { + return resumption_token_; + } + void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) { + SECStatus rv = SSL_GetResumptionTokenInfo( + resumption_token_.data(), resumption_token_.size(), token.get(), + sizeof(SSLResumptionTokenInfo)); + ASSERT_EQ(SECSuccess, rv); + } + void SetResumptionCallbackCalled() { resumption_callback_called_ = true; } + bool resumption_callback_called() const { + return resumption_callback_called_; + } const std::string& name() const { return name_; } @@ -382,6 +394,7 @@ class TlsAgent : public PollTarget { uint8_t expected_sent_alert_; uint8_t expected_sent_alert_level_; bool handshake_callback_called_; + bool resumption_callback_called_; SSLChannelInfo info_; SSLCipherSuiteInfo csinfo_; SSLVersionRange vrange_; @@ -393,6 +406,7 @@ class TlsAgent : public PollTarget { AuthCertificateCallbackFunction auth_certificate_callback_; SniCallbackFunction sni_callback_; bool skip_version_checks_; + std::vector<uint8_t> resumption_token_; }; inline std::ostream& operator<<(std::ostream& stream, @@ -443,7 +457,7 @@ class TlsAgentTestBase : public ::testing::Test { void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, int32_t error_code = 0); - std::unique_ptr<TlsAgent> agent_; + std::shared_ptr<TlsAgent> agent_; TlsAgent::Role role_; SSLProtocolVariant variant_; uint16_t version_; |