summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_agent.h
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.h')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.h457
1 files changed, 457 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h
new file mode 100644
index 000000000..78923c930
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_agent.h
@@ -0,0 +1,457 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_agent_h_
+#define tls_agent_h_
+
+#include "prio.h"
+#include "ssl.h"
+
+#include <functional>
+#include <iostream>
+
+#include "test_io.h"
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+
+extern bool g_ssl_gtest_verbose;
+
+namespace nss_test {
+
+#define LOG(msg) std::cerr << role_str() << ": " << msg << std::endl
+#define LOGV(msg) \
+ do { \
+ if (g_ssl_gtest_verbose) LOG(msg); \
+ } while (false)
+
+enum SessionResumptionMode {
+ RESUME_NONE = 0,
+ RESUME_SESSIONID = 1,
+ RESUME_TICKET = 2,
+ RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
+};
+
+class TlsAgent;
+
+const extern std::vector<SSLNamedGroup> kAllDHEGroups;
+const extern std::vector<SSLNamedGroup> kECDHEGroups;
+const extern std::vector<SSLNamedGroup> kFFDHEGroups;
+const extern std::vector<SSLNamedGroup> kFasterDHEGroups;
+
+typedef std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)>
+ AuthCertificateCallbackFunction;
+
+typedef std::function<void(TlsAgent* agent)> HandshakeCallbackFunction;
+
+typedef std::function<int32_t(TlsAgent* agent, const SECItem* srvNameArr,
+ PRUint32 srvNameArrSize)>
+ SniCallbackFunction;
+
+class TlsAgent : public PollTarget {
+ public:
+ enum Role { CLIENT, SERVER };
+ enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR };
+
+ static const std::string kClient; // the client key is sign only
+ static const std::string kRsa2048; // bigger sign and encrypt for either
+ static const std::string kServerRsa; // both sign and encrypt
+ static const std::string kServerRsaSign;
+ static const std::string kServerRsaPss;
+ static const std::string kServerRsaDecrypt;
+ static const std::string kServerRsaChain; // A cert that requires a chain.
+ static const std::string kServerEcdsa256;
+ static const std::string kServerEcdsa384;
+ static const std::string kServerEcdsa521;
+ static const std::string kServerEcdhEcdsa;
+ static const std::string kServerEcdhRsa;
+ static const std::string kServerDsa;
+
+ TlsAgent(const std::string& name, Role role, Mode mode);
+ virtual ~TlsAgent();
+
+ bool Init() {
+ pr_fd_ = DummyPrSocket::CreateFD(role_str(), mode_);
+ if (!pr_fd_) return false;
+
+ adapter_ = DummyPrSocket::GetAdapter(pr_fd_);
+ if (!adapter_) return false;
+
+ return true;
+ }
+
+ void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); }
+
+ void SetPacketFilter(PacketFilter* filter) {
+ adapter_->SetPacketFilter(filter);
+ }
+
+ void StartConnect(PRFileDesc* model = nullptr);
+ void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
+ size_t kea_size = 0) const;
+ void CheckAuthType(SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) const;
+
+ void DisableAllCiphers();
+ void EnableCiphersByAuthType(SSLAuthType authType);
+ void EnableCiphersByKeyExchange(SSLKEAType kea);
+ void EnableGroupsByKeyExchange(SSLKEAType kea);
+ void EnableGroupsByAuthType(SSLAuthType authType);
+ void EnableSingleCipher(uint16_t cipher);
+
+ void Handshake();
+ // Marks the internal state as CONNECTING in anticipation of renegotiation.
+ void PrepareForRenegotiate();
+ // Prepares for renegotiation, then actually triggers it.
+ void StartRenegotiate();
+ bool ConfigServerCert(const std::string& name, bool updateKeyBits = false,
+ const SSLExtraServerCertData* serverCertData = nullptr);
+ bool ConfigServerCertWithChain(const std::string& name);
+ bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr);
+
+ void SetupClientAuth();
+ void RequestClientAuth(bool requireAuth);
+ bool GetClientAuthCredentials(CERTCertificate** cert,
+ SECKEYPrivateKey** priv) const;
+
+ void ConfigureSessionCache(SessionResumptionMode mode);
+ void SetSessionTicketsEnabled(bool en);
+ void SetSessionCacheEnabled(bool en);
+ void Set0RttEnabled(bool en);
+ void SetShortHeadersEnabled();
+ void SetVersionRange(uint16_t minver, uint16_t maxver);
+ void GetVersionRange(uint16_t* minver, uint16_t* maxver);
+ void CheckPreliminaryInfo();
+ void ResetPreliminaryInfo();
+ void SetExpectedVersion(uint16_t version);
+ void SetServerKeyBits(uint16_t bits);
+ void ExpectReadWriteError();
+ void EnableFalseStart();
+ void ExpectResumption();
+ void ExpectShortHeaders();
+ void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
+ void EnableAlpn(const uint8_t* val, size_t len);
+ void CheckAlpn(SSLNextProtoState expected_state,
+ const std::string& expected = "") const;
+ void EnableSrtp();
+ void CheckSrtp() const;
+ void CheckErrorCode(int32_t expected) const;
+ void WaitForErrorCode(int32_t expected, uint32_t delay) const;
+ // Send data on the socket, encrypting it.
+ void SendData(size_t bytes, size_t blocksize = 1024);
+ void SendBuffer(const DataBuffer& buf);
+ // Send data directly to the underlying socket, skipping the TLS layer.
+ void SendDirect(const DataBuffer& buf);
+ void ReadBytes();
+ void ResetSentBytes(); // Hack to test drops.
+ void EnableExtendedMasterSecret();
+ void CheckExtendedMasterSecret(bool expected);
+ void CheckEarlyDataAccepted(bool expected);
+ void DisableRollbackDetection();
+ void EnableCompression();
+ void SetDowngradeCheckVersion(uint16_t version);
+ void CheckSecretsDestroyed();
+ void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
+ void DisableECDHEServerKeyReuse();
+ bool GetPeerChainLength(size_t* count);
+
+ const std::string& name() const { return name_; }
+
+ Role role() const { return role_; }
+ std::string role_str() const { return role_ == SERVER ? "server" : "client"; }
+
+ State state() const { return state_; }
+
+ const CERTCertificate* peer_cert() const {
+ return SSL_PeerCertificate(ssl_fd_);
+ }
+
+ const char* state_str() const { return state_str(state()); }
+
+ static const char* state_str(State state) { return states[state]; }
+
+ PRFileDesc* ssl_fd() { return ssl_fd_; }
+ DummyPrSocket* adapter() { return adapter_; }
+
+ bool is_compressed() const {
+ return info_.compressionMethod != ssl_compression_null;
+ }
+ uint16_t server_key_bits() const { return server_key_bits_; }
+ uint16_t min_version() const { return vrange_.min; }
+ uint16_t max_version() const { return vrange_.max; }
+ uint16_t version() const {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+ return info_.protocolVersion;
+ }
+
+ bool cipher_suite(uint16_t* cipher_suite) const {
+ if (state_ != STATE_CONNECTED) return false;
+
+ *cipher_suite = info_.cipherSuite;
+ return true;
+ }
+
+ std::string cipher_suite_name() const {
+ if (state_ != STATE_CONNECTED) return "UNKNOWN";
+
+ return csinfo_.cipherSuiteName;
+ }
+
+ std::vector<uint8_t> session_id() const {
+ return std::vector<uint8_t>(info_.sessionID,
+ info_.sessionID + info_.sessionIDLength);
+ }
+
+ bool auth_type(SSLAuthType* auth_type) const {
+ if (state_ != STATE_CONNECTED) return false;
+
+ *auth_type = info_.authType;
+ return true;
+ }
+
+ bool kea_type(SSLKEAType* kea_type) const {
+ if (state_ != STATE_CONNECTED) return false;
+
+ *kea_type = info_.keaType;
+ return true;
+ }
+
+ size_t received_bytes() const { return recv_ctr_; }
+ PRErrorCode error_code() const { return error_code_; }
+
+ bool can_falsestart_hook_called() const {
+ return can_falsestart_hook_called_;
+ }
+
+ void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) {
+ handshake_callback_ = handshake_callback;
+ }
+
+ void SetAuthCertificateCallback(
+ AuthCertificateCallbackFunction auth_certificate_callback) {
+ auth_certificate_callback_ = auth_certificate_callback;
+ }
+
+ void SetSniCallback(SniCallbackFunction sni_callback) {
+ sni_callback_ = sni_callback;
+ }
+
+ private:
+ const static char* states[];
+
+ void SetState(State state);
+
+ // Dummy auth certificate hook.
+ static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
+ PRBool checksig, PRBool isServer) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->CheckPreliminaryInfo();
+ agent->auth_certificate_hook_called_ = true;
+ if (agent->auth_certificate_callback_) {
+ return agent->auth_certificate_callback_(agent, checksig ? true : false,
+ isServer ? true : false);
+ }
+ return SECSuccess;
+ }
+
+ // Client auth certificate hook.
+ static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd,
+ PRBool checksig, PRBool isServer) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ EXPECT_TRUE(agent->expect_client_auth_);
+ EXPECT_EQ(PR_TRUE, isServer);
+ if (agent->auth_certificate_callback_) {
+ return agent->auth_certificate_callback_(agent, checksig ? true : false,
+ isServer ? true : false);
+ }
+ return SECSuccess;
+ }
+
+ static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd,
+ CERTDistNames* caNames,
+ CERTCertificate** cert,
+ SECKEYPrivateKey** privKey);
+
+ static void ReadableCallback(PollTarget* self, Event event) {
+ TlsAgent* agent = static_cast<TlsAgent*>(self);
+ if (event == TIMER_EVENT) {
+ agent->timer_handle_ = nullptr;
+ }
+ agent->ReadableCallback_int();
+ }
+
+ void ReadableCallback_int() {
+ LOGV("Readable");
+ switch (state_) {
+ case STATE_CONNECTING:
+ Handshake();
+ break;
+ case STATE_CONNECTED:
+ ReadBytes();
+ break;
+ default:
+ break;
+ }
+ }
+
+ static PRInt32 SniHook(PRFileDesc* fd, const SECItem* srvNameArr,
+ PRUint32 srvNameArrSize, void* arg) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->CheckPreliminaryInfo();
+ agent->sni_hook_called_ = true;
+ EXPECT_EQ(1UL, srvNameArrSize);
+ if (agent->sni_callback_) {
+ return agent->sni_callback_(agent, srvNameArr, srvNameArrSize);
+ }
+ return 0; // First configuration.
+ }
+
+ static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg,
+ PRBool* canFalseStart) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->CheckPreliminaryInfo();
+ EXPECT_TRUE(agent->falsestart_enabled_);
+ EXPECT_FALSE(agent->can_falsestart_hook_called_);
+ agent->can_falsestart_hook_called_ = true;
+ *canFalseStart = true;
+ return SECSuccess;
+ }
+
+ static void HandshakeCallback(PRFileDesc* fd, void* arg) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->handshake_callback_called_ = true;
+ agent->Connected();
+ if (agent->handshake_callback_) {
+ agent->handshake_callback_(agent);
+ }
+ }
+
+ void DisableLameGroups();
+ void ConfigStrongECGroups(bool en);
+ void ConfigAllDHGroups(bool en);
+ void CheckCallbacks() const;
+ void Connected();
+
+ const std::string name_;
+ Mode mode_;
+ uint16_t server_key_bits_;
+ PRFileDesc* pr_fd_;
+ DummyPrSocket* adapter_;
+ PRFileDesc* ssl_fd_;
+ Role role_;
+ State state_;
+ Poller::Timer* timer_handle_;
+ bool falsestart_enabled_;
+ uint16_t expected_version_;
+ uint16_t expected_cipher_suite_;
+ bool expect_resumption_;
+ bool expect_client_auth_;
+ bool can_falsestart_hook_called_;
+ bool sni_hook_called_;
+ bool auth_certificate_hook_called_;
+ bool handshake_callback_called_;
+ SSLChannelInfo info_;
+ SSLCipherSuiteInfo csinfo_;
+ SSLVersionRange vrange_;
+ PRErrorCode error_code_;
+ size_t send_ctr_;
+ size_t recv_ctr_;
+ bool expect_readwrite_error_;
+ HandshakeCallbackFunction handshake_callback_;
+ AuthCertificateCallbackFunction auth_certificate_callback_;
+ SniCallbackFunction sni_callback_;
+ bool expect_short_headers_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const TlsAgent::State& state) {
+ return stream << TlsAgent::state_str(state);
+}
+
+class TlsAgentTestBase : public ::testing::Test {
+ public:
+ static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;
+
+ TlsAgentTestBase(TlsAgent::Role role, Mode mode)
+ : agent_(nullptr), fd_(nullptr), role_(role), mode_(mode) {}
+ ~TlsAgentTestBase() {
+ if (fd_) {
+ PR_Close(fd_);
+ }
+ }
+
+ void SetUp();
+ void TearDown();
+
+ static void MakeRecord(Mode mode, uint8_t type, uint16_t version,
+ const uint8_t* buf, size_t len, DataBuffer* out,
+ uint64_t seq_num = 0);
+ void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf,
+ size_t len, DataBuffer* out, uint64_t seq_num = 0) const;
+ void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len,
+ DataBuffer* out, uint64_t seq_num = 0) const;
+ void MakeHandshakeMessageFragment(uint8_t hs_type, const uint8_t* data,
+ size_t hs_len, DataBuffer* out,
+ uint64_t seq_num, uint32_t fragment_offset,
+ uint32_t fragment_length) const;
+ static void MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len,
+ DataBuffer* out);
+ static inline TlsAgent::Role ToRole(const std::string& str) {
+ return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER;
+ }
+
+ static inline Mode ToMode(const std::string& str) {
+ return str == "TLS" ? STREAM : DGRAM;
+ }
+
+ void Init(const std::string& server_name = TlsAgent::kServerRsa);
+ void Reset(const std::string& server_name = TlsAgent::kServerRsa);
+
+ protected:
+ void EnsureInit();
+ void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
+ int32_t error_code = 0);
+
+ TlsAgent* agent_;
+ PRFileDesc* fd_;
+ TlsAgent::Role role_;
+ Mode mode_;
+};
+
+class TlsAgentTest : public TlsAgentTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<std::string, std::string>> {
+ public:
+ TlsAgentTest()
+ : TlsAgentTestBase(ToRole(std::get<0>(GetParam())),
+ ToMode(std::get<1>(GetParam()))) {}
+};
+
+class TlsAgentTestClient : public TlsAgentTestBase,
+ public ::testing::WithParamInterface<std::string> {
+ public:
+ TlsAgentTestClient()
+ : TlsAgentTestBase(TlsAgent::CLIENT, ToMode(GetParam())) {}
+};
+
+class TlsAgentStreamTestClient : public TlsAgentTestBase {
+ public:
+ TlsAgentStreamTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, STREAM) {}
+};
+
+class TlsAgentStreamTestServer : public TlsAgentTestBase {
+ public:
+ TlsAgentStreamTestServer() : TlsAgentTestBase(TlsAgent::SERVER, STREAM) {}
+};
+
+class TlsAgentDgramTestClient : public TlsAgentTestBase {
+ public:
+ TlsAgentDgramTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, DGRAM) {}
+};
+
+} // namespace nss_test
+
+#endif