/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "tls_agent.h"
#include "databuffer.h"
#include "keyhi.h"
#include "pk11func.h"
#include "ssl.h"
#include "sslerr.h"
#include "sslproto.h"
#include "tls_parser.h"

extern "C" {
// This is not something that should make you happy.
#include "libssl_internals.h"
}

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
#include "gtest_utils.h"
#include "scoped_ptrs.h"

extern std::string g_working_dir_path;

namespace nss_test {

// Printable names for the agent states, in the order INIT, CONNECTING,
// CONNECTED, ERROR.
const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};

// Names handed to TlsAgent; LoadCertificate() uses the agent name as the
// certificate nickname.
const std::string TlsAgent::kClient = "client";    // both sign and encrypt
const std::string TlsAgent::kRsa2048 = "rsa2048";  // bigger
const std::string TlsAgent::kServerRsa = "rsa";    // both sign and encrypt
const std::string TlsAgent::kServerRsaSign = "rsa_sign";
const std::string TlsAgent::kServerRsaPss = "rsa_pss";
const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt";
const std::string TlsAgent::kServerRsaChain = "rsa_chain";
const std::string TlsAgent::kServerEcdsa256 = "ecdsa256";
const std::string TlsAgent::kServerEcdsa384 = "ecdsa384";
const std::string TlsAgent::kServerEcdsa521 = "ecdsa521";
const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
const std::string TlsAgent::kServerDsa = "dsa";

// Builds an agent in STATE_INIT with no SSL socket yet (ssl_fd_ is created
// lazily by EnsureTlsSetup()). The expected alerts start out as a
// warning-level close_notify, and the version range is seeded from the
// library default for |variant|.
TlsAgent::TlsAgent(const std::string& name, Role role,
                   SSLProtocolVariant variant)
    : name_(name),
      variant_(variant),
      role_(role),
      server_key_bits_(0),
      adapter_(new DummyPrSocket(role_str(), variant)),
      ssl_fd_(nullptr),
      state_(STATE_INIT),
      timer_handle_(nullptr),
      falsestart_enabled_(false),
      expected_version_(0),
      expected_cipher_suite_(0),
      expect_resumption_(false),
      expect_client_auth_(false),
      can_falsestart_hook_called_(false),
      sni_hook_called_(false),
      auth_certificate_hook_called_(false),
      expected_received_alert_(kTlsAlertCloseNotify),
      expected_received_alert_level_(kTlsAlertWarning),
      expected_sent_alert_(kTlsAlertCloseNotify),
      expected_sent_alert_level_(kTlsAlertWarning),
      handshake_callback_called_(false),
      error_code_(0),
      send_ctr_(0),
      recv_ctr_(0),
      expect_readwrite_error_(false),
      handshake_callback_(),
      auth_certificate_callback_(),
      sni_callback_(),
      expect_short_headers_(false),
      skip_version_checks_(false) {
  memset(&info_, 0, sizeof(info_));
  memset(&csinfo_, 0, sizeof(csinfo_));
  SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_);
  EXPECT_EQ(SECSuccess, rv);
}

// Cancels any pending timer and poller registration, then reports a test
// failure if an alert expectation set via ExpectReceiveAlert() or
// ExpectSendAlert() was never consumed (CheckAlert() resets the expectation
// to close_notify/warning once it has been matched).
TlsAgent::~TlsAgent() {
  if (timer_handle_) {
    timer_handle_->Cancel();
  }

  if (adapter_) {
    Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
  }

  // Add failures manually, if any, so we don't throw in a destructor.
  if (expected_received_alert_ != kTlsAlertCloseNotify ||
      expected_received_alert_level_ != kTlsAlertWarning) {
    ADD_FAILURE() << "Wrong expected_received_alert status";
  }
  if (expected_sent_alert_ != kTlsAlertCloseNotify ||
      expected_sent_alert_level_ != kTlsAlertWarning) {
    ADD_FAILURE() << "Wrong expected_sent_alert status";
  }
}

// Moves the agent to |state|, logging the transition. No-op when the state
// is unchanged.
void TlsAgent::SetState(State state) {
  if (state_ == state) return;

  LOG("Changing state from " << state_ << " to " << state);
  state_ = state;
}

// Looks up the certificate with nickname |name| and its private key, storing
// them in |cert| and |priv|. Records a test failure and returns false if
// either lookup yields nothing.
/*static*/ bool TlsAgent::LoadCertificate(const std::string& name,
                                          ScopedCERTCertificate* cert,
                                          ScopedSECKEYPrivateKey* priv) {
  cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr));
  EXPECT_NE(nullptr, cert->get());
  if (!cert->get()) return false;

  priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr));
  EXPECT_NE(nullptr, priv->get());
  if (!priv->get()) return false;

  return true;
}

// Loads the certificate/key pair named |name| and installs it on ssl_fd()
// with SSL_ConfigServerCert(), passing |serverCertData| through when it is
// non-null. When |updateKeyBits| is set, server_key_bits_ is refreshed from
// the certificate's public key. Returns true only if the final
// SSL_ConfigServerCert() call succeeds.
bool TlsAgent::ConfigServerCert(const std::string& name, bool updateKeyBits,
                                const SSLExtraServerCertData* serverCertData) {
  ScopedCERTCertificate cert;
  ScopedSECKEYPrivateKey priv;
  if (!TlsAgent::LoadCertificate(name, &cert, &priv)) {
    return false;
  }

  if (updateKeyBits) {
    ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
    EXPECT_NE(nullptr, pub.get());
    if (!pub.get()) return false;
    server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get());
  }

  // The legacy configuration call is expected to reject null arguments.
  SECStatus rv =
      SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null);
  EXPECT_EQ(SECFailure, rv);
  rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData,
                            serverCertData ? sizeof(*serverCertData) : 0);
  return rv == SECSuccess;
}

// Lazily creates the SSL socket on top of the adapter and registers the
// agent's callbacks. Returns true immediately if the socket already exists.
// Every failing step records a test failure and returns false.
bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
  // Don't set up twice
  if (ssl_fd_) return true;

  ScopedPRFileDesc dummy_fd(adapter_->CreateFD());
  EXPECT_NE(nullptr, dummy_fd);
  if (!dummy_fd) {
    return false;
  }
  if (adapter_->variant() == ssl_variant_stream) {
    ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get()));
  } else {
    ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get()));
  }

  EXPECT_NE(nullptr, ssl_fd_);
  if (!ssl_fd_) {
    return false;
  }
  dummy_fd.release();  // Now subsumed by ssl_fd_.

  SECStatus rv;
  if (!skip_version_checks_) {
    rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  }

  if (role_ == SERVER) {
    EXPECT_TRUE(ConfigServerCert(name_, true));

    rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this);
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;

    // Install an empty trust anchor list. Check the allocation and report a
    // failure from SSL_SetTrustAnchors() like every other step here does.
    ScopedCERTCertList anchors(CERT_NewCertList());
    EXPECT_NE(nullptr, anchors.get());
    if (!anchors.get()) return false;
    rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get());
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  } else {
    rv = SSL_SetURL(ssl_fd(), "server");
    EXPECT_EQ(SECSuccess, rv);
    if (rv != SECSuccess) return false;
  }

  rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this);
  EXPECT_EQ(SECSuccess, rv);
  if (rv != SECSuccess) return false;

  return true;
}

// Client only: registers GetClientAuthDataHook() so the agent can answer a
// certificate request.
void TlsAgent::SetupClientAuth() {
  EXPECT_TRUE(EnsureTlsSetup());
  ASSERT_EQ(CLIENT, role_);

  EXPECT_EQ(SECSuccess,
            SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook,
                                      reinterpret_cast<void*>(this)));
}

// Client-auth callback. |self| is the TlsAgent registered in
// SetupClientAuth(). Checks that the peer certificate is already visible,
// then hands ownership of the agent's own certificate and key to the caller
// through |clientCert| and |clientKey|.
SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
                                          CERTDistNames* caNames,
                                          CERTCertificate** clientCert,
                                          SECKEYPrivateKey** clientKey) {
  TlsAgent* agent = reinterpret_cast<TlsAgent*>(self);
  ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
  EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";

  ScopedCERTCertificate cert;
  ScopedSECKEYPrivateKey priv;
  if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) {
    return SECFailure;
  }

  *clientCert = cert.release();
  *clientKey = priv.release();
  return SECSuccess;
}

// Counts the certificates in the peer's chain, writing the total to |count|
// and printing each subject name to stderr. Returns false (leaving |count|
// untouched) when no chain is available.
bool TlsAgent::GetPeerChainLength(size_t* count) {
  // The scoped wrapper destroys the list on every exit path.
  ScopedCERTCertList chain(SSL_PeerCertificateChain(ssl_fd()));
  if (!chain) return false;
  *count = 0;

  for (PRCList* cursor = PR_NEXT_LINK(&chain->list); cursor != &chain->list;
       cursor = PR_NEXT_LINK(cursor)) {
    CERTCertListNode* node = reinterpret_cast<CERTCertListNode*>(cursor);
    std::cerr << node->cert->subjectName << std::endl;
    ++(*count);
  }

  return true;
}

// Checks that the cipher suite recorded in csinfo_ equals |cipher_suite|.
void TlsAgent::CheckCipherSuite(uint16_t cipher_suite) {
  EXPECT_EQ(csinfo_.cipherSuite, cipher_suite);
}

// Server only: asks the client for a certificate (mandatory when
// |requireAuth| is true) and replaces the auth-certificate hook with
// ClientAuthenticated.
void TlsAgent::RequestClientAuth(bool requireAuth) {
  EXPECT_TRUE(EnsureTlsSetup());
  ASSERT_EQ(SERVER, role_);

  EXPECT_EQ(SECSuccess,
            SSL_OptionSet(ssl_fd(), SSL_REQUEST_CERTIFICATE, PR_TRUE));
  EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_REQUIRE_CERTIFICATE,
                                      requireAuth ? PR_TRUE : PR_FALSE));

  EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
                            ssl_fd(), &TlsAgent::ClientAuthenticated, this));
  expect_client_auth_ = true;
}

// Makes sure the socket exists (importing from |model| if it has to be
// created), resets the handshake for this agent's role and enters
// STATE_CONNECTING.
void TlsAgent::StartConnect(PRFileDesc* model) {
  EXPECT_TRUE(EnsureTlsSetup(model));

  SECStatus rv;
  rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE);
  EXPECT_EQ(SECSuccess, rv);
  SetState(STATE_CONNECTING);
}

// Turns off every implemented cipher suite on this socket.
void TlsAgent::DisableAllCiphers() {
  for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
    SECStatus rv =
        SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE);
    EXPECT_EQ(SECSuccess, rv);
  }
}

// Not actually all groups, just the ones that we are actually willing
// to use.
const std::vector<SSLNamedGroup> kAllDHEGroups = {
    ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
    ssl_grp_ec_secp521r1,  ssl_grp_ffdhe_2048,   ssl_grp_ffdhe_3072,
    ssl_grp_ffdhe_4096,    ssl_grp_ffdhe_6144,   ssl_grp_ffdhe_8192};

// The elliptic-curve subset of kAllDHEGroups.
const std::vector<SSLNamedGroup> kECDHEGroups = {
    ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
    ssl_grp_ec_secp521r1};

// The finite-field subset of kAllDHEGroups.
const std::vector<SSLNamedGroup> kFFDHEGroups = {
    ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096,
    ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};

// Defined because the big DHE groups are ridiculously slow.
const std::vector<SSLNamedGroup> kFasterDHEGroups = { ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072}; void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) { EXPECT_TRUE(EnsureTlsSetup()); for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { SSLCipherSuiteInfo csinfo; SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, sizeof(csinfo)); ASSERT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(csinfo), csinfo.length); if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) { rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); EXPECT_EQ(SECSuccess, rv); } } } void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) { switch (kea) { case ssl_kea_dh: ConfigNamedGroups(kFFDHEGroups); break; case ssl_kea_ecdh: ConfigNamedGroups(kECDHEGroups); break; default: break; } } void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) { if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa || authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) { ConfigNamedGroups(kECDHEGroups); } } void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) { EXPECT_TRUE(EnsureTlsSetup()); for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { SSLCipherSuiteInfo csinfo; SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo, sizeof(csinfo)); ASSERT_EQ(SECSuccess, rv); if ((csinfo.authType == authType) || (csinfo.keaType == ssl_kea_tls13_any)) { rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); EXPECT_EQ(SECSuccess, rv); } } } void TlsAgent::EnableSingleCipher(uint16_t cipher) { DisableAllCiphers(); SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) { EXPECT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size()); EXPECT_EQ(SECSuccess, rv); } void 
TlsAgent::SetSessionTicketsEnabled(bool en) { EXPECT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS, en ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SetSessionCacheEnabled(bool en) { EXPECT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::Set0RttEnabled(bool en) { EXPECT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SetFallbackSCSVEnabled(bool en) { EXPECT_TRUE(role_ == CLIENT && EnsureTlsSetup()); SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALLBACK_SCSV, en ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SetShortHeadersEnabled() { EXPECT_TRUE(EnsureTlsSetup()); SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd()); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { vrange_.min = minver; vrange_.max = maxver; if (ssl_fd()) { SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); EXPECT_EQ(SECSuccess, rv); } } void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) { *minver = vrange_.min; *maxver = vrange_.max; } void TlsAgent::SetExpectedVersion(uint16_t version) { expected_version_ = version; } void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; } void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; } void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; } void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; } void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count) { EXPECT_TRUE(EnsureTlsSetup()); EXPECT_LE(count, SSL_SignatureMaxCount()); EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, static_cast<unsigned int>(count))); EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0)) << "setting 
no schemes should fail and do nothing"; std::vector<SSLSignatureScheme> configuredSchemes(count); unsigned int configuredCount; EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1)) << "get schemes, schemes is nullptr"; EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], &configuredCount, 0)) << "get schemes, too little space"; EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr, configuredSchemes.size())) << "get schemes, countOut is nullptr"; EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet( ssl_fd(), &configuredSchemes[0], &configuredCount, configuredSchemes.size())); // SignatureSchemePrefSet drops unsupported algorithms silently, so the // number that are configured might be fewer. EXPECT_LE(configuredCount, count); unsigned int i = 0; for (unsigned int j = 0; j < count && i < configuredCount; ++j) { if (i < configuredCount && schemes[j] == configuredSchemes[i]) { ++i; } } EXPECT_EQ(i, configuredCount) << "schemes in use were all set"; } void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group, size_t kea_size) const { EXPECT_EQ(STATE_CONNECTED, state_); EXPECT_EQ(kea_type, info_.keaType); if (kea_size == 0) { switch (kea_group) { case ssl_grp_ec_curve25519: kea_size = 255; break; case ssl_grp_ec_secp256r1: kea_size = 256; break; case ssl_grp_ec_secp384r1: kea_size = 384; break; case ssl_grp_ffdhe_2048: kea_size = 2048; break; case ssl_grp_ffdhe_3072: kea_size = 3072; break; case ssl_grp_ffdhe_custom: break; default: if (kea_type == ssl_kea_rsa) { kea_size = server_key_bits_; } else { EXPECT_TRUE(false) << "need to update group sizes"; } } } if (kea_group != ssl_grp_ffdhe_custom) { EXPECT_EQ(kea_size, info_.keaKeyBits); EXPECT_EQ(kea_group, info_.keaGroup); } } void TlsAgent::CheckAuthType(SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const { EXPECT_EQ(STATE_CONNECTED, state_); EXPECT_EQ(auth_type, info_.authType); 
EXPECT_EQ(server_key_bits_, info_.authKeyBits); if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) { switch (auth_type) { case ssl_auth_rsa_sign: sig_scheme = ssl_sig_rsa_pkcs1_sha1md5; break; case ssl_auth_ecdsa: sig_scheme = ssl_sig_ecdsa_sha1; break; default: break; } } EXPECT_EQ(sig_scheme, info_.signatureScheme); if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) { return; } // Check authAlgorithm, which is the old value for authType. This is a second // switch // statement because default label is different. switch (auth_type) { case ssl_auth_rsa_sign: EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) << "authAlgorithm for RSA is always decrypt"; break; case ssl_auth_ecdh_rsa: EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)"; break; case ssl_auth_ecdh_ecdsa: EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm) << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)"; break; default: EXPECT_EQ(auth_type, csinfo_.authAlgorithm) << "authAlgorithm is (usually) the same as authType"; break; } } void TlsAgent::EnableFalseStart() { EXPECT_TRUE(EnsureTlsSetup()); falsestart_enabled_ = true; EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback( ssl_fd(), CanFalseStartCallback, this)); EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALSE_START, PR_TRUE)); } void TlsAgent::ExpectResumption() { expect_resumption_ = true; } void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { EXPECT_TRUE(EnsureTlsSetup()); EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_ENABLE_ALPN, PR_TRUE)); EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len)); } void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, const std::string& expected) const { SSLNextProtoState state; char chosen[10]; unsigned int chosen_len; SECStatus rv = SSL_GetNextProto(ssl_fd(), &state, reinterpret_cast<unsigned char*>(chosen), &chosen_len, sizeof(chosen)); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(expected_state, 
state); if (state == SSL_NEXT_PROTO_NO_SUPPORT) { EXPECT_EQ("", expected); } else { EXPECT_NE("", expected); EXPECT_EQ(expected, std::string(chosen, chosen_len)); } } void TlsAgent::EnableSrtp() { EXPECT_TRUE(EnsureTlsSetup()); const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}; EXPECT_EQ(SECSuccess, SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers))); } void TlsAgent::CheckSrtp() const { uint16_t actual; EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual)); EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual); } void TlsAgent::CheckErrorCode(int32_t expected) const { EXPECT_EQ(STATE_ERROR, state_); EXPECT_EQ(expected, error_code_) << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " << PORT_ErrorToName(expected) << std::endl; } static uint8_t GetExpectedAlertLevel(uint8_t alert) { switch (alert) { case kTlsAlertCloseNotify: case kTlsAlertEndOfEarlyData: return kTlsAlertWarning; default: break; } return kTlsAlertFatal; } void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) { expected_received_alert_ = alert; if (level == 0) { expected_received_alert_level_ = GetExpectedAlertLevel(alert); } else { expected_received_alert_level_ = level; } } void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) { expected_sent_alert_ = alert; if (level == 0) { expected_sent_alert_level_ = GetExpectedAlertLevel(alert); } else { expected_sent_alert_level_ = level; } } void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) { LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal") << " alert " << (sent ? "sent" : "received") << ": " << static_cast<int>(alert->description)); auto& expected = sent ? expected_sent_alert_ : expected_received_alert_; auto& expected_level = sent ? expected_sent_alert_level_ : expected_received_alert_level_; /* Silently pass close_notify in case the test has already ended. 
*/ if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning && alert->description == expected && alert->level == expected_level) { return; } EXPECT_EQ(expected, alert->description); EXPECT_EQ(expected_level, alert->level); expected = kTlsAlertCloseNotify; expected_level = kTlsAlertWarning; } void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { ASSERT_EQ(0, error_code_); WAIT_(error_code_ != 0, delay); EXPECT_EQ(expected, error_code_) << "Got error code " << PORT_ErrorToName(error_code_) << " expecting " << PORT_ErrorToName(expected) << std::endl; } void TlsAgent::CheckPreliminaryInfo() { SSLPreliminaryChannelInfo info; EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(ssl_fd(), &info, sizeof(info))); EXPECT_EQ(sizeof(info), info.length); EXPECT_TRUE(info.valuesSet & ssl_preinfo_version); EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite); // A version of 0 is invalid and indicates no expectation. This value is // initialized to 0 so that tests that don't explicitly set an expected // version can negotiate a version. if (!expected_version_) { expected_version_ = info.protocolVersion; } EXPECT_EQ(expected_version_, info.protocolVersion); // As with the version; 0 is the null cipher suite (and also invalid). if (!expected_cipher_suite_) { expected_cipher_suite_ = info.cipherSuite; } EXPECT_EQ(expected_cipher_suite_, info.cipherSuite); } // Check that all the expected callbacks have been called. void TlsAgent::CheckCallbacks() const { // If false start happens, the handshake is reported as being complete at the // point that false start happens. if (expect_resumption_ || !falsestart_enabled_) { EXPECT_TRUE(handshake_callback_called_); } // These callbacks shouldn't fire if we are resuming, except on TLS 1.3. 
if (role_ == SERVER) { PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn); EXPECT_EQ(((!expect_resumption_ && have_sni) || expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3), sni_hook_called_); } else { EXPECT_EQ(!expect_resumption_, auth_certificate_hook_called_); // Note that this isn't unconditionally called, even with false start on. // But the callback is only skipped if a cipher that is ridiculously weak // (80 bits) is chosen. Don't test that: plan to remove bad ciphers. EXPECT_EQ(falsestart_enabled_ && !expect_resumption_, can_falsestart_hook_called_); } } void TlsAgent::ResetPreliminaryInfo() { expected_version_ = 0; expected_cipher_suite_ = 0; } void TlsAgent::Connected() { if (state_ == STATE_CONNECTED) { return; } LOG("Handshake success"); CheckPreliminaryInfo(); CheckCallbacks(); SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_)); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(info_), info_.length); // Preliminary values are exposed through callbacks during the handshake. // If either expected values were set or the callbacks were called, check // that the final values are correct. EXPECT_EQ(expected_version_, info_.protocolVersion); EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite); rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_)); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(csinfo_), csinfo_.length); if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd()); // We use one ciphersuite in each direction, plus one that's kept around // by DTLS for retransmission. PRInt32 expected = ((variant_ == ssl_variant_datagram) && (role_ == CLIENT)) ? 
3 : 2; EXPECT_EQ(expected, cipherSuites); if (expected != cipherSuites) { SSLInt_PrintTls13CipherSpecs(ssl_fd()); } } PRBool short_headers; rv = SSLInt_UsingShortHeaders(ssl_fd(), &short_headers); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ((PRBool)expect_short_headers_, short_headers); SetState(STATE_CONNECTED); } void TlsAgent::EnableExtendedMasterSecret() { ASSERT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); ASSERT_EQ(SECSuccess, rv); } void TlsAgent::CheckExtendedMasterSecret(bool expected) { if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) { expected = PR_TRUE; } ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE) << "unexpected extended master secret state for " << name_; } void TlsAgent::CheckEarlyDataAccepted(bool expected) { if (version() < SSL_LIBRARY_VERSION_TLS_1_3) { expected = false; } ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE) << "unexpected early data state for " << name_; } void TlsAgent::CheckSecretsDestroyed() { ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd())); } void TlsAgent::DisableRollbackDetection() { ASSERT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ROLLBACK_DETECTION, PR_FALSE); ASSERT_EQ(SECSuccess, rv); } void TlsAgent::EnableCompression() { ASSERT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_DEFLATE, PR_TRUE); ASSERT_EQ(SECSuccess, rv); } void TlsAgent::SetDowngradeCheckVersion(uint16_t version) { ASSERT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), version); ASSERT_EQ(SECSuccess, rv); } void TlsAgent::Handshake() { LOGV("Handshake"); SECStatus rv = SSL_ForceHandshake(ssl_fd()); if (rv == SECSuccess) { Connected(); Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, &TlsAgent::ReadableCallback); return; } int32_t err = PR_GetError(); if (err == PR_WOULD_BLOCK_ERROR) { LOGV("Would have blocked"); if (variant_ == ssl_variant_datagram) { if (timer_handle_) { 
timer_handle_->Cancel(); timer_handle_ = nullptr; } PRIntervalTime timeout; rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout); if (rv == SECSuccess) { Poller::Instance()->SetTimer( timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_); } } Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, &TlsAgent::ReadableCallback); return; } LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": " << PORT_ErrorToString(err)); error_code_ = err; SetState(STATE_ERROR); } void TlsAgent::PrepareForRenegotiate() { EXPECT_EQ(STATE_CONNECTED, state_); SetState(STATE_CONNECTING); } void TlsAgent::StartRenegotiate() { PrepareForRenegotiate(); SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SendDirect(const DataBuffer& buf) { LOG("Send Direct " << buf); auto peer = adapter_->peer().lock(); if (peer) { peer->PacketReceived(buf); } else { LOG("Send Direct peer absent"); } } static bool ErrorIsNonFatal(PRErrorCode code) { return code == PR_WOULD_BLOCK_ERROR || code == SSL_ERROR_RX_SHORT_DTLS_READ; } void TlsAgent::SendData(size_t bytes, size_t blocksize) { uint8_t block[4096]; ASSERT_LT(blocksize, sizeof(block)); while (bytes) { size_t tosend = std::min(blocksize, bytes); for (size_t i = 0; i < tosend; ++i) { block[i] = 0xff & send_ctr_; ++send_ctr_; } SendBuffer(DataBuffer(block, tosend)); bytes -= tosend; } } void TlsAgent::SendBuffer(const DataBuffer& buf) { LOGV("Writing " << buf.len() << " bytes"); int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len()); if (expect_readwrite_error_) { EXPECT_GT(0, rv); EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_); error_code_ = PR_GetError(); expect_readwrite_error_ = false; } else { ASSERT_EQ(buf.len(), static_cast<size_t>(rv)); } } void TlsAgent::ReadBytes(size_t amount) { uint8_t block[16384]; int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block))); LOGV("ReadBytes " << rv); int32_t err; if (rv >= 0) { size_t count = static_cast<size_t>(rv); for (size_t i = 0; 
i < count; ++i) { ASSERT_EQ(recv_ctr_ & 0xff, block[i]); recv_ctr_++; } } else { err = PR_GetError(); LOG("Read error " << PORT_ErrorToName(err) << ": " << PORT_ErrorToString(err)); if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) { error_code_ = err; expect_readwrite_error_ = false; } } // If closed, then don't bother waiting around. if (rv > 0 || (rv < 0 && ErrorIsNonFatal(err))) { LOGV("Re-arming"); Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, &TlsAgent::ReadableCallback); } } void TlsAgent::ResetSentBytes() { send_ctr_ = 0; } void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { EXPECT_TRUE(EnsureTlsSetup()); SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); EXPECT_EQ(SECSuccess, rv); rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS, mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::DisableECDHEServerKeyReuse() { ASSERT_TRUE(EnsureTlsSetup()); ASSERT_EQ(TlsAgent::SERVER, role_); SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); EXPECT_EQ(SECSuccess, rv); } static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"}; ::testing::internal::ParamGenerator<std::string> TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr); void TlsAgentTestBase::SetUp() { SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); } void TlsAgentTestBase::TearDown() { agent_ = nullptr; SSL_ClearSessionCache(); SSL_ShutdownServerSessionIDCache(); } void TlsAgentTestBase::Reset(const std::string& server_name) { agent_.reset( new TlsAgent(role_ == TlsAgent::CLIENT ? 
TlsAgent::kClient : server_name, role_, variant_)); if (version_) { agent_->SetVersionRange(version_, version_); } agent_->adapter()->SetPeer(sink_adapter_); agent_->StartConnect(); } void TlsAgentTestBase::EnsureInit() { if (!agent_) { Reset(); } const std::vector<SSLNamedGroup> groups = { ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ffdhe_2048}; agent_->ConfigNamedGroups(groups); } void TlsAgentTestBase::ExpectAlert(uint8_t alert) { EnsureInit(); agent_->ExpectSendAlert(alert); } void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, int32_t error_code) { std::cerr << "Process message: " << buffer << std::endl; EnsureInit(); agent_->adapter()->PacketReceived(buffer); agent_->Handshake(); ASSERT_EQ(expected_state, agent_->state()); if (expected_state == TlsAgent::STATE_ERROR) { ASSERT_EQ(error_code, agent_->error_code()); } } void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, uint64_t seq_num) { size_t index = 0; index = out->Write(index, type, 1); if (variant == ssl_variant_stream) { index = out->Write(index, version, 2); } else { index = out->Write(index, TlsVersionToDtlsVersion(version), 2); index = out->Write(index, seq_num >> 32, 4); index = out->Write(index, seq_num & PR_UINT32_MAX, 4); } index = out->Write(index, len, 2); out->Write(index, buf, len); } void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, uint64_t seq_num) const { MakeRecord(variant_, type, version, buf, len, out, seq_num); } void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out, uint64_t seq_num) const { return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0, 0); } void TlsAgentTestBase::MakeHandshakeMessageFragment( uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out, 
uint64_t seq_num, uint32_t fragment_offset, uint32_t fragment_length) const { size_t index = 0; if (!fragment_length) fragment_length = hs_len; index = out->Write(index, hs_type, 1); // Handshake record type. index = out->Write(index, hs_len, 3); // Handshake length if (variant_ == ssl_variant_datagram) { index = out->Write(index, seq_num, 2); index = out->Write(index, fragment_offset, 3); index = out->Write(index, fragment_length, 3); } if (data) { index = out->Write(index, data, fragment_length); } else { for (size_t i = 0; i < fragment_length; ++i) { index = out->Write(index, 1, 1); } } } void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len, DataBuffer* out) { size_t index = 0; index = out->Write(index, kTlsHandshakeType, 1); // Content Type index = out->Write(index, 3, 1); // Version high index = out->Write(index, 1, 1); // Version low index = out->Write(index, 4 + hs_len, 2); // Length index = out->Write(index, hs_type, 1); // Handshake record type. index = out->Write(index, hs_len, 3); // Handshake length for (size_t i = 0; i < hs_len; ++i) { index = out->Write(index, 1, 1); } } } // namespace nss_test