diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_agent.h | 137 |
1 files changed, 88 insertions, 49 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h index 78923c930..4bccb9a84 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.h +++ b/security/nss/gtests/ssl_gtest/tls_agent.h @@ -14,9 +14,11 @@ #include <iostream> #include "test_io.h" +#include "tls_filter.h" #define GTEST_HAS_RTTI 0 #include "gtest/gtest.h" +#include "scoped_ptrs.h" extern bool g_ssl_gtest_verbose; @@ -42,6 +44,8 @@ const extern std::vector<SSLNamedGroup> kECDHEGroups; const extern std::vector<SSLNamedGroup> kFFDHEGroups; const extern std::vector<SSLNamedGroup> kFasterDHEGroups; +// These functions are called from callbacks. They use bare pointers because +// TlsAgent sets up the callback and it doesn't know who owns it. typedef std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)> AuthCertificateCallbackFunction; @@ -70,25 +74,24 @@ class TlsAgent : public PollTarget { static const std::string kServerEcdhRsa; static const std::string kServerDsa; - TlsAgent(const std::string& name, Role role, Mode mode); + TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant); virtual ~TlsAgent(); - bool Init() { - pr_fd_ = DummyPrSocket::CreateFD(role_str(), mode_); - if (!pr_fd_) return false; - - adapter_ = DummyPrSocket::GetAdapter(pr_fd_); - if (!adapter_) return false; - - return true; + void SetPeer(std::shared_ptr<TlsAgent>& peer) { + adapter_->SetPeer(peer->adapter_); } - void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); } + void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) { + filter->SetAgent(this); + adapter_->SetPacketFilter(filter); + } - void SetPacketFilter(PacketFilter* filter) { + void SetPacketFilter(std::shared_ptr<PacketFilter> filter) { adapter_->SetPacketFilter(filter); } + void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); } + void StartConnect(PRFileDesc* model = nullptr); void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group, size_t kea_size = 0) const; @@ -107,6 +110,9 @@ class TlsAgent : public PollTarget { void PrepareForRenegotiate(); // Prepares for renegotiation, then actually triggers it. void StartRenegotiate(); + static bool LoadCertificate(const std::string& name, + ScopedCERTCertificate* cert, + ScopedSECKEYPrivateKey* priv); bool ConfigServerCert(const std::string& name, bool updateKeyBits = false, const SSLExtraServerCertData* serverCertData = nullptr); bool ConfigServerCertWithChain(const std::string& name); @@ -114,13 +120,12 @@ class TlsAgent : public PollTarget { void SetupClientAuth(); void RequestClientAuth(bool requireAuth); - bool GetClientAuthCredentials(CERTCertificate** cert, - SECKEYPrivateKey** priv) const; void ConfigureSessionCache(SessionResumptionMode mode); void SetSessionTicketsEnabled(bool en); void SetSessionCacheEnabled(bool en); void Set0RttEnabled(bool en); + void SetFallbackSCSVEnabled(bool en); void SetShortHeadersEnabled(); void SetVersionRange(uint16_t minver, uint16_t maxver); void GetVersionRange(uint16_t* minver, uint16_t* maxver); @@ -132,6 +137,7 @@ class TlsAgent : public PollTarget { void EnableFalseStart(); void ExpectResumption(); void ExpectShortHeaders(); + void SkipVersionChecks(); void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count); void EnableAlpn(const uint8_t* val, size_t len); void CheckAlpn(SSLNextProtoState expected_state, @@ -145,7 +151,7 @@ class TlsAgent : public PollTarget { void SendBuffer(const DataBuffer& buf); // Send data directly to the underlying socket, skipping the TLS layer. void SendDirect(const DataBuffer& buf); - void ReadBytes(); + void ReadBytes(size_t max = 16384U); void ResetSentBytes(); // Hack to test drops. void EnableExtendedMasterSecret(); void CheckExtendedMasterSecret(bool expected); @@ -157,6 +163,7 @@ class TlsAgent : public PollTarget { void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups); void DisableECDHEServerKeyReuse(); bool GetPeerChainLength(size_t* count); + void CheckCipherSuite(uint16_t cipher_suite); const std::string& name() const { return name_; } @@ -166,15 +173,15 @@ class TlsAgent : public PollTarget { State state() const { return state_; } const CERTCertificate* peer_cert() const { - return SSL_PeerCertificate(ssl_fd_); + return SSL_PeerCertificate(ssl_fd_.get()); } const char* state_str() const { return state_str(state()); } static const char* state_str(State state) { return states[state]; } - PRFileDesc* ssl_fd() { return ssl_fd_; } - DummyPrSocket* adapter() { return adapter_; } + PRFileDesc* ssl_fd() const { return ssl_fd_.get(); } + std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; } bool is_compressed() const { return info_.compressionMethod != ssl_compression_null; @@ -239,6 +246,9 @@ class TlsAgent : public PollTarget { sni_callback_ = sni_callback; } + void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0); + void ExpectSendAlert(uint8_t alert, uint8_t level = 0); + private: const static char* states[]; @@ -320,6 +330,18 @@ class TlsAgent : public PollTarget { return SECSuccess; } + void CheckAlert(bool sent, const SSLAlert* alert); + + static void AlertReceivedCallback(const PRFileDesc* fd, void* arg, + const SSLAlert* alert) { + reinterpret_cast<TlsAgent*>(arg)->CheckAlert(false, alert); + } + + static void AlertSentCallback(const PRFileDesc* fd, void* arg, + const SSLAlert* alert) { + reinterpret_cast<TlsAgent*>(arg)->CheckAlert(true, alert); + } + static void HandshakeCallback(PRFileDesc* fd, void* arg) { TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); agent->handshake_callback_called_ = true; @@ -336,14 +358,13 @@ class TlsAgent : public PollTarget { void Connected(); const std::string name_; - Mode mode_; - uint16_t server_key_bits_; - PRFileDesc* pr_fd_; - DummyPrSocket* adapter_; - PRFileDesc* ssl_fd_; + SSLProtocolVariant variant_; Role role_; + uint16_t server_key_bits_; + std::shared_ptr<DummyPrSocket> adapter_; + ScopedPRFileDesc ssl_fd_; State state_; - Poller::Timer* timer_handle_; + std::shared_ptr<Poller::Timer> timer_handle_; bool falsestart_enabled_; uint16_t expected_version_; uint16_t expected_cipher_suite_; @@ -352,6 +373,10 @@ class TlsAgent : public PollTarget { bool can_falsestart_hook_called_; bool sni_hook_called_; bool auth_certificate_hook_called_; + uint8_t expected_received_alert_; + uint8_t expected_received_alert_level_; + uint8_t expected_sent_alert_; + uint8_t expected_sent_alert_level_; bool handshake_callback_called_; SSLChannelInfo info_; SSLCipherSuiteInfo csinfo_; @@ -364,6 +389,7 @@ class TlsAgent : public PollTarget { AuthCertificateCallbackFunction auth_certificate_callback_; SniCallbackFunction sni_callback_; bool expect_short_headers_; + bool skip_version_checks_; }; inline std::ostream& operator<<(std::ostream& stream, @@ -375,20 +401,23 @@ class TlsAgentTestBase : public ::testing::Test { public: static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll; - TlsAgentTestBase(TlsAgent::Role role, Mode mode) - : agent_(nullptr), fd_(nullptr), role_(role), mode_(mode) {} - ~TlsAgentTestBase() { - if (fd_) { - PR_Close(fd_); - } - } + TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant, + uint16_t version = 0) + : agent_(nullptr), + role_(role), + variant_(variant), + version_(version), + sink_adapter_(new DummyPrSocket("sink", variant)) {} + virtual ~TlsAgentTestBase() {} void SetUp(); void TearDown(); - static void MakeRecord(Mode mode, uint8_t type, uint16_t version, - const uint8_t* buf, size_t len, DataBuffer* out, - uint64_t seq_num = 0); + void ExpectAlert(uint8_t alert); + + static void MakeRecord(SSLProtocolVariant variant, uint8_t type, + uint16_t version, const uint8_t* buf, size_t len, + DataBuffer* out, uint64_t seq_num = 0); void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, uint64_t seq_num = 0) const; void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len, @@ -403,10 +432,6 @@ class TlsAgentTestBase : public ::testing::Test { return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER; } - static inline Mode ToMode(const std::string& str) { - return str == "TLS" ? STREAM : DGRAM; - } - void Init(const std::string& server_name = TlsAgent::kServerRsa); void Reset(const std::string& server_name = TlsAgent::kServerRsa); @@ -415,43 +440,57 @@ class TlsAgentTestBase : public ::testing::Test { void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, int32_t error_code = 0); - TlsAgent* agent_; - PRFileDesc* fd_; + std::unique_ptr<TlsAgent> agent_; TlsAgent::Role role_; - Mode mode_; + SSLProtocolVariant variant_; + uint16_t version_; + // This adapter is here just to accept packets from this agent. + std::shared_ptr<DummyPrSocket> sink_adapter_; }; -class TlsAgentTest : public TlsAgentTestBase, - public ::testing::WithParamInterface< - std::tuple<std::string, std::string>> { +class TlsAgentTest + : public TlsAgentTestBase, + public ::testing::WithParamInterface< + std::tuple<std::string, SSLProtocolVariant, uint16_t>> { public: TlsAgentTest() : TlsAgentTestBase(ToRole(std::get<0>(GetParam())), - ToMode(std::get<1>(GetParam()))) {} + std::get<1>(GetParam()), std::get<2>(GetParam())) {} }; class TlsAgentTestClient : public TlsAgentTestBase, - public ::testing::WithParamInterface<std::string> { + public ::testing::WithParamInterface< + std::tuple<SSLProtocolVariant, uint16_t>> { public: TlsAgentTestClient() - : TlsAgentTestBase(TlsAgent::CLIENT, ToMode(GetParam())) {} + : TlsAgentTestBase(TlsAgent::CLIENT, std::get<0>(GetParam()), + std::get<1>(GetParam())) {} }; +class TlsAgentTestClient13 : public TlsAgentTestClient {}; + class TlsAgentStreamTestClient : public TlsAgentTestBase { public: - TlsAgentStreamTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, STREAM) {} + TlsAgentStreamTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_stream) {} }; class TlsAgentStreamTestServer : public TlsAgentTestBase { public: - TlsAgentStreamTestServer() : TlsAgentTestBase(TlsAgent::SERVER, STREAM) {} + TlsAgentStreamTestServer() + : TlsAgentTestBase(TlsAgent::SERVER, ssl_variant_stream) {} }; class TlsAgentDgramTestClient : public TlsAgentTestBase { public: - TlsAgentDgramTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, DGRAM) {} + TlsAgentDgramTestClient() + : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_datagram) {} }; +inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) { + return vr1.min == vr2.min && vr1.max == vr2.max; +} + } // namespace nss_test #endif |