summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_agent.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.cc992
1 files changed, 992 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc
new file mode 100644
index 000000000..b75bba567
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/tls_agent.cc
@@ -0,0 +1,992 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_agent.h"
+#include "databuffer.h"
+#include "keyhi.h"
+#include "pk11func.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+#include "tls_parser.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+
+extern std::string g_working_dir_path;
+
+namespace nss_test {
+
+const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};
+
+const std::string TlsAgent::kClient = "client"; // both sign and encrypt
+const std::string TlsAgent::kRsa2048 = "rsa2048"; // bigger
+const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt
+const std::string TlsAgent::kServerRsaSign = "rsa_sign";
+const std::string TlsAgent::kServerRsaPss = "rsa_pss";
+const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt";
+const std::string TlsAgent::kServerRsaChain = "rsa_chain";
+const std::string TlsAgent::kServerEcdsa256 = "ecdsa256";
+const std::string TlsAgent::kServerEcdsa384 = "ecdsa384";
+const std::string TlsAgent::kServerEcdsa521 = "ecdsa521";
+const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
+const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
+const std::string TlsAgent::kServerDsa = "dsa";
+
+TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
+ : name_(name),
+ mode_(mode),
+ server_key_bits_(0),
+ pr_fd_(nullptr),
+ adapter_(nullptr),
+ ssl_fd_(nullptr),
+ role_(role),
+ state_(STATE_INIT),
+ timer_handle_(nullptr),
+ falsestart_enabled_(false),
+ expected_version_(0),
+ expected_cipher_suite_(0),
+ expect_resumption_(false),
+ expect_client_auth_(false),
+ can_falsestart_hook_called_(false),
+ sni_hook_called_(false),
+ auth_certificate_hook_called_(false),
+ handshake_callback_called_(false),
+ error_code_(0),
+ send_ctr_(0),
+ recv_ctr_(0),
+ expect_readwrite_error_(false),
+ handshake_callback_(),
+ auth_certificate_callback_(),
+ sni_callback_(),
+ expect_short_headers_(false) {
+ memset(&info_, 0, sizeof(info_));
+ memset(&csinfo_, 0, sizeof(csinfo_));
+ SECStatus rv = SSL_VersionRangeGetDefault(
+ mode_ == STREAM ? ssl_variant_stream : ssl_variant_datagram, &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+TlsAgent::~TlsAgent() {
+ if (adapter_) {
+ Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
+ // The adapter is closed when the FD closes.
+ }
+ if (timer_handle_) {
+ timer_handle_->Cancel();
+ }
+
+ if (pr_fd_) {
+ PR_Close(pr_fd_);
+ }
+
+ if (ssl_fd_) {
+ PR_Close(ssl_fd_);
+ }
+}
+
+void TlsAgent::SetState(State state) {
+ if (state_ == state) return;
+
+ LOG("Changing state from " << state_ << " to " << state);
+ state_ = state;
+}
+
+bool TlsAgent::ConfigServerCert(const std::string& name, bool updateKeyBits,
+ const SSLExtraServerCertData* serverCertData) {
+ ScopedCERTCertificate cert(PK11_FindCertFromNickname(name.c_str(), nullptr));
+ EXPECT_NE(nullptr, cert.get());
+ if (!cert.get()) return false;
+
+ ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
+ EXPECT_NE(nullptr, pub.get());
+ if (!pub.get()) return false;
+ if (updateKeyBits) {
+ server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get());
+ }
+
+ ScopedSECKEYPrivateKey priv(PK11_FindKeyByAnyCert(cert.get(), nullptr));
+ EXPECT_NE(nullptr, priv.get());
+ if (!priv.get()) return false;
+
+ SECStatus rv =
+ SSL_ConfigSecureServer(ssl_fd_, nullptr, nullptr, ssl_kea_null);
+ EXPECT_EQ(SECFailure, rv);
+ rv = SSL_ConfigServerCert(ssl_fd_, cert.get(), priv.get(), serverCertData,
+ serverCertData ? sizeof(*serverCertData) : 0);
+ return rv == SECSuccess;
+}
+
+bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
+ // Don't set up twice
+ if (ssl_fd_) return true;
+
+ if (adapter_->mode() == STREAM) {
+ ssl_fd_ = SSL_ImportFD(modelSocket, pr_fd_);
+ } else {
+ ssl_fd_ = DTLS_ImportFD(modelSocket, pr_fd_);
+ }
+
+ EXPECT_NE(nullptr, ssl_fd_);
+ if (!ssl_fd_) return false;
+ pr_fd_ = nullptr;
+
+ SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ if (role_ == SERVER) {
+ EXPECT_TRUE(ConfigServerCert(name_, true));
+
+ rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ ScopedCERTCertList anchors(CERT_NewCertList());
+ rv = SSL_SetTrustAnchors(ssl_fd_, anchors.get());
+ if (rv != SECSuccess) return false;
+ } else {
+ rv = SSL_SetURL(ssl_fd_, "server");
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+ }
+
+ rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ return true;
+}
+
+void TlsAgent::SetupClientAuth() {
+ EXPECT_TRUE(EnsureTlsSetup());
+ ASSERT_EQ(CLIENT, role_);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
+ reinterpret_cast<void*>(this)));
+}
+
+bool TlsAgent::GetClientAuthCredentials(CERTCertificate** cert,
+ SECKEYPrivateKey** priv) const {
+ *cert = PK11_FindCertFromNickname(name_.c_str(), nullptr);
+ EXPECT_NE(nullptr, *cert);
+ if (!*cert) return false;
+
+ *priv = PK11_FindKeyByAnyCert(*cert, nullptr);
+ EXPECT_NE(nullptr, *priv);
+ if (!*priv) return false; // Leak cert.
+
+ return true;
+}
+
+SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
+ CERTDistNames* caNames,
+ CERTCertificate** cert,
+ SECKEYPrivateKey** privKey) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(self);
+ ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
+ EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";
+ if (agent->GetClientAuthCredentials(cert, privKey)) {
+ return SECSuccess;
+ }
+ return SECFailure;
+}
+
+bool TlsAgent::GetPeerChainLength(size_t* count) {
+ CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd_);
+ if (!chain) return false;
+ *count = 0;
+
+ for (PRCList* cursor = PR_NEXT_LINK(&chain->list); cursor != &chain->list;
+ cursor = PR_NEXT_LINK(cursor)) {
+ CERTCertListNode* node = (CERTCertListNode*)cursor;
+ std::cerr << node->cert->subjectName << std::endl;
+ ++(*count);
+ }
+
+ CERT_DestroyCertList(chain);
+
+ return true;
+}
+
+void TlsAgent::RequestClientAuth(bool requireAuth) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ ASSERT_EQ(SERVER, role_);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE,
+ requireAuth ? PR_TRUE : PR_FALSE));
+
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
+ ssl_fd_, &TlsAgent::ClientAuthenticated, this));
+ expect_client_auth_ = true;
+}
+
+void TlsAgent::StartConnect(PRFileDesc* model) {
+ EXPECT_TRUE(EnsureTlsSetup(model));
+
+ SECStatus rv;
+ rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+ SetState(STATE_CONNECTING);
+}
+
+void TlsAgent::DisableAllCiphers() {
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SECStatus rv =
+ SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+}
+
+// Not actually all groups, just the onece that we are actually willing
+// to use.
+const std::vector<SSLNamedGroup> kAllDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072,
+ ssl_grp_ffdhe_4096, ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};
+
+const std::vector<SSLNamedGroup> kECDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+
+const std::vector<SSLNamedGroup> kFFDHEGroups = {
+ ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096,
+ ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};
+
+// Defined because the big DHE groups are ridiculously slow.
+const std::vector<SSLNamedGroup> kFasterDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072};
+
+void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SSLCipherSuiteInfo csinfo;
+
+ SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
+ sizeof(csinfo));
+ ASSERT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(csinfo), csinfo.length);
+
+ if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) {
+ rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ }
+}
+
+void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) {
+ switch (kea) {
+ case ssl_kea_dh:
+ ConfigNamedGroups(kFFDHEGroups);
+ break;
+ case ssl_kea_ecdh:
+ ConfigNamedGroups(kECDHEGroups);
+ break;
+ default:
+ break;
+ }
+}
+
+void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) {
+ if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa ||
+ authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) {
+ ConfigNamedGroups(kECDHEGroups);
+ }
+}
+
+void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SSLCipherSuiteInfo csinfo;
+
+ SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
+ sizeof(csinfo));
+ ASSERT_EQ(SECSuccess, rv);
+
+ if ((csinfo.authType == authType) ||
+ (csinfo.keaType == ssl_kea_tls13_any)) {
+ rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ }
+}
+
+void TlsAgent::EnableSingleCipher(uint16_t cipher) {
+ DisableAllCiphers();
+ SECStatus rv = SSL_CipherPrefSet(ssl_fd_, cipher, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ SECStatus rv = SSL_NamedGroupConfig(ssl_fd_, &groups[0], groups.size());
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetSessionTicketsEnabled(bool en) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
+ en ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetSessionCacheEnabled(bool en) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::Set0RttEnabled(bool en) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv =
+ SSL_OptionSet(ssl_fd_, SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetShortHeadersEnabled() {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd_);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
+ vrange_.min = minver;
+ vrange_.max = maxver;
+
+ if (ssl_fd_) {
+ SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+}
+
+void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) {
+ *minver = vrange_.min;
+ *maxver = vrange_.max;
+}
+
+void TlsAgent::SetExpectedVersion(uint16_t version) {
+ expected_version_ = version;
+}
+
+void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; }
+
+void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; }
+
+void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; }
+
+void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
+ size_t count) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ EXPECT_LE(count, SSL_SignatureMaxCount());
+ EXPECT_EQ(SECSuccess,
+ SSL_SignatureSchemePrefSet(ssl_fd_, schemes,
+ static_cast<unsigned int>(count)));
+ EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd_, schemes, 0))
+ << "setting no schemes should fail and do nothing";
+
+ std::vector<SSLSignatureScheme> configuredSchemes(count);
+ unsigned int configuredCount;
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd_, nullptr, &configuredCount, 1))
+ << "get schemes, schemes is nullptr";
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0],
+ &configuredCount, 0))
+ << "get schemes, too little space";
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0], nullptr,
+ configuredSchemes.size()))
+ << "get schemes, countOut is nullptr";
+
+ EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet(
+ ssl_fd_, &configuredSchemes[0], &configuredCount,
+ configuredSchemes.size()));
+ // SignatureSchemePrefSet drops unsupported algorithms silently, so the
+ // number that are configured might be fewer.
+ EXPECT_LE(configuredCount, count);
+ unsigned int i = 0;
+ for (unsigned int j = 0; j < count && i < configuredCount; ++j) {
+ if (i < configuredCount && schemes[j] == configuredSchemes[i]) {
+ ++i;
+ }
+ }
+ EXPECT_EQ(i, configuredCount) << "schemes in use were all set";
+}
+
+void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ size_t kea_size) const {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+ EXPECT_EQ(kea_type, info_.keaType);
+ if (kea_size == 0) {
+ switch (kea_group) {
+ case ssl_grp_ec_curve25519:
+ kea_size = 255;
+ break;
+ case ssl_grp_ec_secp256r1:
+ kea_size = 256;
+ break;
+ case ssl_grp_ec_secp384r1:
+ kea_size = 384;
+ break;
+ case ssl_grp_ffdhe_2048:
+ kea_size = 2048;
+ break;
+ case ssl_grp_ffdhe_3072:
+ kea_size = 3072;
+ break;
+ case ssl_grp_ffdhe_custom:
+ break;
+ default:
+ if (kea_type == ssl_kea_rsa) {
+ kea_size = server_key_bits_;
+ } else {
+ EXPECT_TRUE(false) << "need to update group sizes";
+ }
+ }
+ }
+ if (kea_group != ssl_grp_ffdhe_custom) {
+ EXPECT_EQ(kea_size, info_.keaKeyBits);
+ EXPECT_EQ(kea_group, info_.keaGroup);
+ }
+}
+
+void TlsAgent::CheckAuthType(SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) const {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+ EXPECT_EQ(auth_type, info_.authType);
+ EXPECT_EQ(server_key_bits_, info_.authKeyBits);
+ if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
+ switch (auth_type) {
+ case ssl_auth_rsa_sign:
+ sig_scheme = ssl_sig_rsa_pkcs1_sha1md5;
+ break;
+ case ssl_auth_ecdsa:
+ sig_scheme = ssl_sig_ecdsa_sha1;
+ break;
+ default:
+ break;
+ }
+ }
+ EXPECT_EQ(sig_scheme, info_.signatureScheme);
+
+ if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ return;
+ }
+
+ // Check authAlgorithm, which is the old value for authType. This is a second
+ // switch
+ // statement because default label is different.
+ switch (auth_type) {
+ case ssl_auth_rsa_sign:
+ EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
+ << "authAlgorithm for RSA is always decrypt";
+ break;
+ case ssl_auth_ecdh_rsa:
+ EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
+ << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)";
+ break;
+ case ssl_auth_ecdh_ecdsa:
+ EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm)
+ << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)";
+ break;
+ default:
+ EXPECT_EQ(auth_type, csinfo_.authAlgorithm)
+ << "authAlgorithm is (usually) the same as authType";
+ break;
+ }
+}
+
+void TlsAgent::EnableFalseStart() {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ falsestart_enabled_ = true;
+ EXPECT_EQ(SECSuccess,
+ SSL_SetCanFalseStartCallback(ssl_fd_, CanFalseStartCallback, this));
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionSet(ssl_fd_, SSL_ENABLE_FALSE_START, PR_TRUE));
+}
+
+void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
+
+void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len));
+}
+
+void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
+ const std::string& expected) const {
+ SSLNextProtoState state;
+ char chosen[10];
+ unsigned int chosen_len;
+ SECStatus rv = SSL_GetNextProto(ssl_fd_, &state,
+ reinterpret_cast<unsigned char*>(chosen),
+ &chosen_len, sizeof(chosen));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(expected_state, state);
+ if (state == SSL_NEXT_PROTO_NO_SUPPORT) {
+ EXPECT_EQ("", expected);
+ } else {
+ EXPECT_NE("", expected);
+ EXPECT_EQ(expected, std::string(chosen, chosen_len));
+ }
+}
+
+void TlsAgent::EnableSrtp() {
+ EXPECT_TRUE(EnsureTlsSetup());
+ const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80,
+ SRTP_AES128_CM_HMAC_SHA1_32};
+ EXPECT_EQ(SECSuccess,
+ SSL_SetSRTPCiphers(ssl_fd_, ciphers, PR_ARRAY_SIZE(ciphers)));
+}
+
+void TlsAgent::CheckSrtp() const {
+ uint16_t actual;
+ EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual));
+ EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
+}
+
+void TlsAgent::CheckErrorCode(int32_t expected) const {
+ EXPECT_EQ(STATE_ERROR, state_);
+ EXPECT_EQ(expected, error_code_)
+ << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
+ << PORT_ErrorToName(expected) << std::endl;
+}
+
+void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
+ ASSERT_EQ(0, error_code_);
+ WAIT_(error_code_ != 0, delay);
+ EXPECT_EQ(expected, error_code_)
+ << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
+ << PORT_ErrorToName(expected) << std::endl;
+}
+
+void TlsAgent::CheckPreliminaryInfo() {
+ SSLPreliminaryChannelInfo info;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetPreliminaryChannelInfo(ssl_fd_, &info, sizeof(info)));
+ EXPECT_EQ(sizeof(info), info.length);
+ EXPECT_TRUE(info.valuesSet & ssl_preinfo_version);
+ EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite);
+
+ // A version of 0 is invalid and indicates no expectation. This value is
+ // initialized to 0 so that tests that don't explicitly set an expected
+ // version can negotiate a version.
+ if (!expected_version_) {
+ expected_version_ = info.protocolVersion;
+ }
+ EXPECT_EQ(expected_version_, info.protocolVersion);
+
+ // As with the version; 0 is the null cipher suite (and also invalid).
+ if (!expected_cipher_suite_) {
+ expected_cipher_suite_ = info.cipherSuite;
+ }
+ EXPECT_EQ(expected_cipher_suite_, info.cipherSuite);
+}
+
+// Check that all the expected callbacks have been called.
+void TlsAgent::CheckCallbacks() const {
+ // If false start happens, the handshake is reported as being complete at the
+ // point that false start happens.
+ if (expect_resumption_ || !falsestart_enabled_) {
+ EXPECT_TRUE(handshake_callback_called_);
+ }
+
+ // These callbacks shouldn't fire if we are resuming, except on TLS 1.3.
+ if (role_ == SERVER) {
+ PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd_, ssl_server_name_xtn);
+ EXPECT_EQ(((!expect_resumption_ && have_sni) ||
+ expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3),
+ sni_hook_called_);
+ } else {
+ EXPECT_EQ(!expect_resumption_, auth_certificate_hook_called_);
+ // Note that this isn't unconditionally called, even with false start on.
+ // But the callback is only skipped if a cipher that is ridiculously weak
+ // (80 bits) is chosen. Don't test that: plan to remove bad ciphers.
+ EXPECT_EQ(falsestart_enabled_ && !expect_resumption_,
+ can_falsestart_hook_called_);
+ }
+}
+
+void TlsAgent::ResetPreliminaryInfo() {
+ expected_version_ = 0;
+ expected_cipher_suite_ = 0;
+}
+
+void TlsAgent::Connected() {
+ LOG("Handshake success");
+ CheckPreliminaryInfo();
+ CheckCallbacks();
+
+ SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(info_), info_.length);
+
+ // Preliminary values are exposed through callbacks during the handshake.
+ // If either expected values were set or the callbacks were called, check
+ // that the final values are correct.
+ EXPECT_EQ(expected_version_, info_.protocolVersion);
+ EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite);
+
+ rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(csinfo_), csinfo_.length);
+
+ if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd_);
+ // We use one ciphersuite in each direction, plus one that's kept around
+ // by DTLS for retransmission.
+ PRInt32 expected = ((mode_ == DGRAM) && (role_ == CLIENT)) ? 3 : 2;
+ EXPECT_EQ(expected, cipherSuites);
+ if (expected != cipherSuites) {
+ SSLInt_PrintTls13CipherSpecs(ssl_fd_);
+ }
+ }
+
+ PRBool short_headers;
+ rv = SSLInt_UsingShortHeaders(ssl_fd_, &short_headers);
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ((PRBool)expect_short_headers_, short_headers);
+ SetState(STATE_CONNECTED);
+}
+
+void TlsAgent::EnableExtendedMasterSecret() {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv =
+ SSL_OptionSet(ssl_fd_, SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
+
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::CheckExtendedMasterSecret(bool expected) {
+ if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ expected = PR_TRUE;
+ }
+ ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE)
+ << "unexpected extended master secret state for " << name_;
+}
+
+void TlsAgent::CheckEarlyDataAccepted(bool expected) {
+ if (version() < SSL_LIBRARY_VERSION_TLS_1_3) {
+ expected = false;
+ }
+ ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE)
+ << "unexpected early data state for " << name_;
+}
+
+void TlsAgent::CheckSecretsDestroyed() {
+ ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd_));
+}
+
+void TlsAgent::DisableRollbackDetection() {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ROLLBACK_DETECTION, PR_FALSE);
+
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::EnableCompression() {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_DEFLATE, PR_TRUE);
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetDowngradeCheckVersion(uint16_t version) {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd_, version);
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::Handshake() {
+ LOGV("Handshake");
+ SECStatus rv = SSL_ForceHandshake(ssl_fd_);
+ if (rv == SECSuccess) {
+ Connected();
+
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ return;
+ }
+
+ int32_t err = PR_GetError();
+ if (err == PR_WOULD_BLOCK_ERROR) {
+ LOGV("Would have blocked");
+ if (mode_ == DGRAM) {
+ if (timer_handle_) {
+ timer_handle_->Cancel();
+ timer_handle_ = nullptr;
+ }
+
+ PRIntervalTime timeout;
+ rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
+ if (rv == SECSuccess) {
+ Poller::Instance()->SetTimer(
+ timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_);
+ }
+ }
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ return;
+ }
+
+ LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": "
+ << PORT_ErrorToString(err));
+ error_code_ = err;
+ SetState(STATE_ERROR);
+}
+
+void TlsAgent::PrepareForRenegotiate() {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+
+ SetState(STATE_CONNECTING);
+}
+
+void TlsAgent::StartRenegotiate() {
+ PrepareForRenegotiate();
+
+ SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SendDirect(const DataBuffer& buf) {
+ LOG("Send Direct " << buf);
+ adapter_->peer()->PacketReceived(buf);
+}
+
+static bool ErrorIsNonFatal(PRErrorCode code) {
+ return code == PR_WOULD_BLOCK_ERROR || code == SSL_ERROR_RX_SHORT_DTLS_READ;
+}
+
+void TlsAgent::SendData(size_t bytes, size_t blocksize) {
+ uint8_t block[4096];
+
+ ASSERT_LT(blocksize, sizeof(block));
+
+ while (bytes) {
+ size_t tosend = std::min(blocksize, bytes);
+
+ for (size_t i = 0; i < tosend; ++i) {
+ block[i] = 0xff & send_ctr_;
+ ++send_ctr_;
+ }
+
+ SendBuffer(DataBuffer(block, tosend));
+ bytes -= tosend;
+ }
+}
+
+void TlsAgent::SendBuffer(const DataBuffer& buf) {
+ LOGV("Writing " << buf.len() << " bytes");
+ int32_t rv = PR_Write(ssl_fd_, buf.data(), buf.len());
+ if (expect_readwrite_error_) {
+ EXPECT_GT(0, rv);
+ EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_);
+ error_code_ = PR_GetError();
+ expect_readwrite_error_ = false;
+ } else {
+ ASSERT_EQ(buf.len(), static_cast<size_t>(rv));
+ }
+}
+
+void TlsAgent::ReadBytes() {
+ uint8_t block[1024];
+
+ int32_t rv = PR_Read(ssl_fd_, block, sizeof(block));
+ LOGV("ReadBytes " << rv);
+ int32_t err;
+
+ if (rv >= 0) {
+ size_t count = static_cast<size_t>(rv);
+ for (size_t i = 0; i < count; ++i) {
+ ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
+ recv_ctr_++;
+ }
+ } else {
+ err = PR_GetError();
+ LOG("Read error " << PORT_ErrorToName(err) << ": "
+ << PORT_ErrorToString(err));
+ if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) {
+ error_code_ = err;
+ expect_readwrite_error_ = false;
+ }
+ }
+
+ // If closed, then don't bother waiting around.
+ if (rv > 0 || (rv < 0 && ErrorIsNonFatal(err))) {
+ LOGV("Re-arming");
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ }
+}
+
+void TlsAgent::ResetSentBytes() { send_ctr_ = 0; }
+
+void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE,
+ mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+
+ rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
+ mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::DisableECDHEServerKeyReuse() {
+ ASSERT_EQ(TlsAgent::SERVER, role_);
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
+::testing::internal::ParamGenerator<std::string>
+ TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr);
+
+void TlsAgentTestBase::SetUp() {
+ SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
+}
+
+void TlsAgentTestBase::TearDown() {
+ delete agent_;
+ SSL_ClearSessionCache();
+ SSL_ShutdownServerSessionIDCache();
+}
+
+void TlsAgentTestBase::Reset(const std::string& server_name) {
+ delete agent_;
+ Init(server_name);
+}
+
+void TlsAgentTestBase::Init(const std::string& server_name) {
+ agent_ =
+ new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
+ role_, mode_);
+ agent_->Init();
+ fd_ = DummyPrSocket::CreateFD(agent_->role_str(), mode_);
+ agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_));
+ agent_->StartConnect();
+}
+
+void TlsAgentTestBase::EnsureInit() {
+ if (!agent_) {
+ Init();
+ }
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ffdhe_2048};
+ agent_->ConfigNamedGroups(groups);
+}
+
+void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
+ TlsAgent::State expected_state,
+ int32_t error_code) {
+ std::cerr << "Process message: " << buffer << std::endl;
+ EnsureInit();
+ agent_->adapter()->PacketReceived(buffer);
+ agent_->Handshake();
+
+ ASSERT_EQ(expected_state, agent_->state());
+
+ if (expected_state == TlsAgent::STATE_ERROR) {
+ ASSERT_EQ(error_code, agent_->error_code());
+ }
+}
+
+void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version,
+ const uint8_t* buf, size_t len,
+ DataBuffer* out, uint64_t seq_num) {
+ size_t index = 0;
+ index = out->Write(index, type, 1);
+ index = out->Write(
+ index, mode == STREAM ? version : TlsVersionToDtlsVersion(version), 2);
+ if (mode == DGRAM) {
+ index = out->Write(index, seq_num >> 32, 4);
+ index = out->Write(index, seq_num & PR_UINT32_MAX, 4);
+ }
+ index = out->Write(index, len, 2);
+ out->Write(index, buf, len);
+}
+
+void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version,
+ const uint8_t* buf, size_t len,
+ DataBuffer* out, uint64_t seq_num) const {
+ MakeRecord(mode_, type, version, buf, len, out, seq_num);
+}
+
+void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type,
+ const uint8_t* data, size_t hs_len,
+ DataBuffer* out,
+ uint64_t seq_num) const {
+ return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0,
+ 0);
+}
+
+void TlsAgentTestBase::MakeHandshakeMessageFragment(
+ uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out,
+ uint64_t seq_num, uint32_t fragment_offset,
+ uint32_t fragment_length) const {
+ size_t index = 0;
+ if (!fragment_length) fragment_length = hs_len;
+ index = out->Write(index, hs_type, 1); // Handshake record type.
+ index = out->Write(index, hs_len, 3); // Handshake length
+ if (mode_ == DGRAM) {
+ index = out->Write(index, seq_num, 2);
+ index = out->Write(index, fragment_offset, 3);
+ index = out->Write(index, fragment_length, 3);
+ }
+ if (data) {
+ index = out->Write(index, data, fragment_length);
+ } else {
+ for (size_t i = 0; i < fragment_length; ++i) {
+ index = out->Write(index, 1, 1);
+ }
+ }
+}
+
+void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type,
+ size_t hs_len,
+ DataBuffer* out) {
+ size_t index = 0;
+ index = out->Write(index, kTlsHandshakeType, 1); // Content Type
+ index = out->Write(index, 3, 1); // Version high
+ index = out->Write(index, 1, 1); // Version low
+ index = out->Write(index, 4 + hs_len, 2); // Length
+
+ index = out->Write(index, hs_type, 1); // Handshake record type.
+ index = out->Write(index, hs_len, 3); // Handshake length
+ for (size_t i = 0; i < hs_len; ++i) {
+ index = out->Write(index, 1, 1);
+ }
+}
+
+} // namespace nss_test