diff options
author | wolfbeast <mcwerewolf@gmail.com> | 2018-02-06 11:46:26 +0100 |
---|---|---|
committer | wolfbeast <mcwerewolf@gmail.com> | 2018-02-06 11:46:26 +0100 |
commit | f017b749ea9f1586d2308504553d40bf4cc5439d (patch) | |
tree | c6033924a0de9be1ab140596e305898c651bf57e /security/nss/gtests/ssl_gtest/tls_agent.cc | |
parent | 7c728b3c7680662fc4e92b5d03697b8339560b08 (diff) | |
download | UXP-f017b749ea9f1586d2308504553d40bf4cc5439d.tar UXP-f017b749ea9f1586d2308504553d40bf4cc5439d.tar.gz UXP-f017b749ea9f1586d2308504553d40bf4cc5439d.tar.lz UXP-f017b749ea9f1586d2308504553d40bf4cc5439d.tar.xz UXP-f017b749ea9f1586d2308504553d40bf4cc5439d.zip |
Update NSS to 3.32.1-RTM
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_agent.cc | 375 |
1 files changed, 240 insertions, 135 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc index b75bba567..d6d91f7f7 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.cc +++ b/security/nss/gtests/ssl_gtest/tls_agent.cc @@ -43,14 +43,14 @@ const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa"; const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa"; const std::string TlsAgent::kServerDsa = "dsa"; -TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode) +TlsAgent::TlsAgent(const std::string& name, Role role, + SSLProtocolVariant variant) : name_(name), - mode_(mode), + variant_(variant), + role_(role), server_key_bits_(0), - pr_fd_(nullptr), - adapter_(nullptr), + adapter_(new DummyPrSocket(role_str(), variant)), ssl_fd_(nullptr), - role_(role), state_(STATE_INIT), timer_handle_(nullptr), falsestart_enabled_(false), @@ -61,6 +61,10 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode) can_falsestart_hook_called_(false), sni_hook_called_(false), auth_certificate_hook_called_(false), + expected_received_alert_(kTlsAlertCloseNotify), + expected_received_alert_level_(kTlsAlertWarning), + expected_sent_alert_(kTlsAlertCloseNotify), + expected_sent_alert_level_(kTlsAlertWarning), handshake_callback_called_(false), error_code_(0), send_ctr_(0), @@ -69,29 +73,31 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode) handshake_callback_(), auth_certificate_callback_(), sni_callback_(), - expect_short_headers_(false) { + expect_short_headers_(false), + skip_version_checks_(false) { memset(&info_, 0, sizeof(info_)); memset(&csinfo_, 0, sizeof(csinfo_)); - SECStatus rv = SSL_VersionRangeGetDefault( - mode_ == STREAM ? ssl_variant_stream : ssl_variant_datagram, &vrange_); + SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_); EXPECT_EQ(SECSuccess, rv); } TlsAgent::~TlsAgent() { - if (adapter_) { - Poller::Instance()->Cancel(READABLE_EVENT, adapter_); - // The adapter is closed when the FD closes. - } if (timer_handle_) { timer_handle_->Cancel(); } - if (pr_fd_) { - PR_Close(pr_fd_); + if (adapter_) { + Poller::Instance()->Cancel(READABLE_EVENT, adapter_); } - if (ssl_fd_) { - PR_Close(ssl_fd_); + // Add failures manually, if any, so we don't throw in a destructor. + if (expected_received_alert_ != kTlsAlertCloseNotify || + expected_received_alert_level_ != kTlsAlertWarning) { + ADD_FAILURE() << "Wrong expected_received_alert status"; + } + if (expected_sent_alert_ != kTlsAlertCloseNotify || + expected_sent_alert_level_ != kTlsAlertWarning) { + ADD_FAILURE() << "Wrong expected_sent_alert status"; } } @@ -102,27 +108,39 @@ void TlsAgent::SetState(State state) { state_ = state; } +/*static*/ bool TlsAgent::LoadCertificate(const std::string& name, + ScopedCERTCertificate* cert, + ScopedSECKEYPrivateKey* priv) { + cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr)); + EXPECT_NE(nullptr, cert->get()); + if (!cert->get()) return false; + + priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr)); + EXPECT_NE(nullptr, priv->get()); + if (!priv->get()) return false; + + return true; +} + bool TlsAgent::ConfigServerCert(const std::string& name, bool updateKeyBits, const SSLExtraServerCertData* serverCertData) { - ScopedCERTCertificate cert(PK11_FindCertFromNickname(name.c_str(), nullptr)); - EXPECT_NE(nullptr, cert.get()); - if (!cert.get()) return false; + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + if (!TlsAgent::LoadCertificate(name, &cert, &priv)) { + return false; + } - ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get())); - EXPECT_NE(nullptr, pub.get()); - if (!pub.get()) return false; if (updateKeyBits) { + ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get())); + EXPECT_NE(nullptr, pub.get()); + if (!pub.get()) return false; server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get()); } - ScopedSECKEYPrivateKey priv(PK11_FindKeyByAnyCert(cert.get(), nullptr)); - EXPECT_NE(nullptr, priv.get()); - if (!priv.get()) return false; - SECStatus rv = - SSL_ConfigSecureServer(ssl_fd_, nullptr, nullptr, ssl_kea_null); + SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null); EXPECT_EQ(SECFailure, rv); - rv = SSL_ConfigServerCert(ssl_fd_, cert.get(), priv.get(), serverCertData, + rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData, serverCertData ? sizeof(*serverCertData) : 0); return rv == SECSuccess; } @@ -131,41 +149,59 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { // Don't set up twice if (ssl_fd_) return true; - if (adapter_->mode() == STREAM) { - ssl_fd_ = SSL_ImportFD(modelSocket, pr_fd_); + ScopedPRFileDesc dummy_fd(adapter_->CreateFD()); + EXPECT_NE(nullptr, dummy_fd); + if (!dummy_fd) { + return false; + } + if (adapter_->variant() == ssl_variant_stream) { + ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get())); } else { - ssl_fd_ = DTLS_ImportFD(modelSocket, pr_fd_); + ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get())); } EXPECT_NE(nullptr, ssl_fd_); - if (!ssl_fd_) return false; - pr_fd_ = nullptr; + if (!ssl_fd_) { + return false; + } + dummy_fd.release(); // Now subsumed by ssl_fd_. - SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_); - EXPECT_EQ(SECSuccess, rv); - if (rv != SECSuccess) return false; + SECStatus rv; + if (!skip_version_checks_) { + rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + } if (role_ == SERVER) { EXPECT_TRUE(ConfigServerCert(name_, true)); - rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, this); + rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; ScopedCERTCertList anchors(CERT_NewCertList()); - rv = SSL_SetTrustAnchors(ssl_fd_, anchors.get()); + rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get()); if (rv != SECSuccess) return false; } else { - rv = SSL_SetURL(ssl_fd_, "server"); + rv = SSL_SetURL(ssl_fd(), "server"); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; } - rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this); + rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; - rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this); + rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this); + EXPECT_EQ(SECSuccess, rv); + if (rv != SECSuccess) return false; + + rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; @@ -177,38 +213,31 @@ void TlsAgent::SetupClientAuth() { ASSERT_EQ(CLIENT, role_); EXPECT_EQ(SECSuccess, - SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook, + SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook, reinterpret_cast<void*>(this))); } -bool TlsAgent::GetClientAuthCredentials(CERTCertificate** cert, - SECKEYPrivateKey** priv) const { - *cert = PK11_FindCertFromNickname(name_.c_str(), nullptr); - EXPECT_NE(nullptr, *cert); - if (!*cert) return false; - - *priv = PK11_FindKeyByAnyCert(*cert, nullptr); - EXPECT_NE(nullptr, *priv); - if (!*priv) return false; // Leak cert. - - return true; -} - SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, CERTDistNames* caNames, - CERTCertificate** cert, - SECKEYPrivateKey** privKey) { + CERTCertificate** clientCert, + SECKEYPrivateKey** clientKey) { TlsAgent* agent = reinterpret_cast<TlsAgent*>(self); ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd())); EXPECT_TRUE(peerCert) << "Client should be able to see the server cert"; - if (agent->GetClientAuthCredentials(cert, privKey)) { - return SECSuccess; + + ScopedCERTCertificate cert; + ScopedSECKEYPrivateKey priv; + if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) { + return SECFailure; } - return SECFailure; + + *clientCert = cert.release(); + *clientKey = priv.release(); + return SECSuccess; } bool TlsAgent::GetPeerChainLength(size_t* count) { - CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd_); + CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd()); if (!chain) return false; *count = 0; @@ -224,17 +253,21 @@ bool TlsAgent::GetPeerChainLength(size_t* count) { return true; } +void TlsAgent::CheckCipherSuite(uint16_t cipher_suite) { + EXPECT_EQ(csinfo_.cipherSuite, cipher_suite); +} + void TlsAgent::RequestClientAuth(bool requireAuth) { EXPECT_TRUE(EnsureTlsSetup()); ASSERT_EQ(SERVER, role_); EXPECT_EQ(SECSuccess, - SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE)); - EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE, + SSL_OptionSet(ssl_fd(), SSL_REQUEST_CERTIFICATE, PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_REQUIRE_CERTIFICATE, requireAuth ? PR_TRUE : PR_FALSE)); EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook( - ssl_fd_, &TlsAgent::ClientAuthenticated, this)); + ssl_fd(), &TlsAgent::ClientAuthenticated, this)); expect_client_auth_ = true; } @@ -242,7 +275,7 @@ void TlsAgent::StartConnect(PRFileDesc* model) { EXPECT_TRUE(EnsureTlsSetup(model)); SECStatus rv; - rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE); + rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); SetState(STATE_CONNECTING); } @@ -250,7 +283,7 @@ void TlsAgent::StartConnect(PRFileDesc* model) { void TlsAgent::DisableAllCiphers() { for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { SECStatus rv = - SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_FALSE); + SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE); EXPECT_EQ(SECSuccess, rv); } } @@ -287,7 +320,7 @@ void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) { EXPECT_EQ(sizeof(csinfo), csinfo.length); if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) { - rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE); + rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); EXPECT_EQ(SECSuccess, rv); } } @@ -325,7 +358,7 @@ void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) { if ((csinfo.authType == authType) || (csinfo.keaType == ssl_kea_tls13_any)) { - rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE); + rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE); EXPECT_EQ(SECSuccess, rv); } } @@ -333,20 +366,20 @@ void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) { void TlsAgent::EnableSingleCipher(uint16_t cipher) { DisableAllCiphers(); - SECStatus rv = SSL_CipherPrefSet(ssl_fd_, cipher, PR_TRUE); + SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) { EXPECT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_NamedGroupConfig(ssl_fd_, &groups[0], groups.size()); + SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size()); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SetSessionTicketsEnabled(bool en) { EXPECT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS, + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS, en ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); } @@ -354,7 +387,7 @@ void TlsAgent::SetSessionTicketsEnabled(bool en) { void TlsAgent::SetSessionCacheEnabled(bool en) { EXPECT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE); + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE); EXPECT_EQ(SECSuccess, rv); } @@ -362,14 +395,22 @@ void TlsAgent::Set0RttEnabled(bool en) { EXPECT_TRUE(EnsureTlsSetup()); SECStatus rv = - SSL_OptionSet(ssl_fd_, SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); + SSL_OptionSet(ssl_fd(), SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsAgent::SetFallbackSCSVEnabled(bool en) { + EXPECT_TRUE(role_ == CLIENT && EnsureTlsSetup()); + + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALLBACK_SCSV, + en ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SetShortHeadersEnabled() { EXPECT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd_); + SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd()); EXPECT_EQ(SECSuccess, rv); } @@ -377,8 +418,8 @@ void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { vrange_.min = minver; vrange_.max = maxver; - if (ssl_fd_) { - SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_); + if (ssl_fd()) { + SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_); EXPECT_EQ(SECSuccess, rv); } } @@ -398,32 +439,34 @@ void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; } void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; } +void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; } + void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count) { EXPECT_TRUE(EnsureTlsSetup()); EXPECT_LE(count, SSL_SignatureMaxCount()); EXPECT_EQ(SECSuccess, - SSL_SignatureSchemePrefSet(ssl_fd_, schemes, + SSL_SignatureSchemePrefSet(ssl_fd(), schemes, static_cast<unsigned int>(count))); - EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd_, schemes, 0)) + EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0)) << "setting no schemes should fail and do nothing"; std::vector<SSLSignatureScheme> configuredSchemes(count); unsigned int configuredCount; EXPECT_EQ(SECFailure, - SSL_SignatureSchemePrefGet(ssl_fd_, nullptr, &configuredCount, 1)) + SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1)) << "get schemes, schemes is nullptr"; EXPECT_EQ(SECFailure, - SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0], + SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], &configuredCount, 0)) << "get schemes, too little space"; EXPECT_EQ(SECFailure, - SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0], nullptr, + SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr, configuredSchemes.size())) << "get schemes, countOut is nullptr"; EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet( - ssl_fd_, &configuredSchemes[0], &configuredCount, + ssl_fd(), &configuredSchemes[0], &configuredCount, configuredSchemes.size())); // SignatureSchemePrefSet drops unsupported algorithms silently, so the // number that are configured might be fewer. @@ -524,10 +567,10 @@ void TlsAgent::EnableFalseStart() { EXPECT_TRUE(EnsureTlsSetup()); falsestart_enabled_ = true; + EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback( + ssl_fd(), CanFalseStartCallback, this)); EXPECT_EQ(SECSuccess, - SSL_SetCanFalseStartCallback(ssl_fd_, CanFalseStartCallback, this)); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(ssl_fd_, SSL_ENABLE_FALSE_START, PR_TRUE)); + SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALSE_START, PR_TRUE)); } void TlsAgent::ExpectResumption() { expect_resumption_ = true; } @@ -535,8 +578,8 @@ void TlsAgent::ExpectResumption() { expect_resumption_ = true; } void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { EXPECT_TRUE(EnsureTlsSetup()); - EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE)); - EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len)); + EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_ENABLE_ALPN, PR_TRUE)); + EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len)); } void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, @@ -544,7 +587,7 @@ void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, SSLNextProtoState state; char chosen[10]; unsigned int chosen_len; - SECStatus rv = SSL_GetNextProto(ssl_fd_, &state, + SECStatus rv = SSL_GetNextProto(ssl_fd(), &state, reinterpret_cast<unsigned char*>(chosen), &chosen_len, sizeof(chosen)); EXPECT_EQ(SECSuccess, rv); @@ -562,12 +605,12 @@ void TlsAgent::EnableSrtp() { const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}; EXPECT_EQ(SECSuccess, - SSL_SetSRTPCiphers(ssl_fd_, ciphers, PR_ARRAY_SIZE(ciphers))); + SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers))); } void TlsAgent::CheckSrtp() const { uint16_t actual; - EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual)); + EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual)); EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual); } @@ -578,6 +621,55 @@ void TlsAgent::CheckErrorCode(int32_t expected) const { << PORT_ErrorToName(expected) << std::endl; } +static uint8_t GetExpectedAlertLevel(uint8_t alert) { + switch (alert) { + case kTlsAlertCloseNotify: + case kTlsAlertEndOfEarlyData: + return kTlsAlertWarning; + default: + break; + } + return kTlsAlertFatal; +} + +void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) { + expected_received_alert_ = alert; + if (level == 0) { + expected_received_alert_level_ = GetExpectedAlertLevel(alert); + } else { + expected_received_alert_level_ = level; + } +} + +void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) { + expected_sent_alert_ = alert; + if (level == 0) { + expected_sent_alert_level_ = GetExpectedAlertLevel(alert); + } else { + expected_sent_alert_level_ = level; + } +} + +void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) { + LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal") + << " alert " << (sent ? "sent" : "received") << ": " + << static_cast<int>(alert->description)); + + auto& expected = sent ? expected_sent_alert_ : expected_received_alert_; + auto& expected_level = + sent ? expected_sent_alert_level_ : expected_received_alert_level_; + /* Silently pass close_notify in case the test has already ended. */ + if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning && + alert->description == expected && alert->level == expected_level) { + return; + } + + EXPECT_EQ(expected, alert->description); + EXPECT_EQ(expected_level, alert->level); + expected = kTlsAlertCloseNotify; + expected_level = kTlsAlertWarning; +} + void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { ASSERT_EQ(0, error_code_); WAIT_(error_code_ != 0, delay); @@ -589,7 +681,7 @@ void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const { void TlsAgent::CheckPreliminaryInfo() { SSLPreliminaryChannelInfo info; EXPECT_EQ(SECSuccess, - SSL_GetPreliminaryChannelInfo(ssl_fd_, &info, sizeof(info))); + SSL_GetPreliminaryChannelInfo(ssl_fd(), &info, sizeof(info))); EXPECT_EQ(sizeof(info), info.length); EXPECT_TRUE(info.valuesSet & ssl_preinfo_version); EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite); @@ -619,7 +711,7 @@ void TlsAgent::CheckCallbacks() const { // These callbacks shouldn't fire if we are resuming, except on TLS 1.3. if (role_ == SERVER) { - PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd_, ssl_server_name_xtn); + PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn); EXPECT_EQ(((!expect_resumption_ && have_sni) || expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3), sni_hook_called_); @@ -639,11 +731,15 @@ void TlsAgent::ResetPreliminaryInfo() { } void TlsAgent::Connected() { + if (state_ == STATE_CONNECTED) { + return; + } + LOG("Handshake success"); CheckPreliminaryInfo(); CheckCallbacks(); - SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_)); + SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_)); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ(sizeof(info_), info_.length); @@ -658,18 +754,19 @@ void TlsAgent::Connected() { EXPECT_EQ(sizeof(csinfo_), csinfo_.length); if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { - PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd_); + PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd()); // We use one ciphersuite in each direction, plus one that's kept around // by DTLS for retransmission. - PRInt32 expected = ((mode_ == DGRAM) && (role_ == CLIENT)) ? 3 : 2; + PRInt32 expected = + ((variant_ == ssl_variant_datagram) && (role_ == CLIENT)) ? 3 : 2; EXPECT_EQ(expected, cipherSuites); if (expected != cipherSuites) { - SSLInt_PrintTls13CipherSpecs(ssl_fd_); + SSLInt_PrintTls13CipherSpecs(ssl_fd()); } } PRBool short_headers; - rv = SSLInt_UsingShortHeaders(ssl_fd_, &short_headers); + rv = SSLInt_UsingShortHeaders(ssl_fd(), &short_headers); EXPECT_EQ(SECSuccess, rv); EXPECT_EQ((PRBool)expect_short_headers_, short_headers); SetState(STATE_CONNECTED); @@ -679,7 +776,7 @@ void TlsAgent::EnableExtendedMasterSecret() { ASSERT_TRUE(EnsureTlsSetup()); SECStatus rv = - SSL_OptionSet(ssl_fd_, SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); + SSL_OptionSet(ssl_fd(), SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE); ASSERT_EQ(SECSuccess, rv); } @@ -701,13 +798,13 @@ void TlsAgent::CheckEarlyDataAccepted(bool expected) { } void TlsAgent::CheckSecretsDestroyed() { - ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd_)); + ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd())); } void TlsAgent::DisableRollbackDetection() { ASSERT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ROLLBACK_DETECTION, PR_FALSE); + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ROLLBACK_DETECTION, PR_FALSE); ASSERT_EQ(SECSuccess, rv); } @@ -715,23 +812,22 @@ void TlsAgent::DisableRollbackDetection() { void TlsAgent::EnableCompression() { ASSERT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_DEFLATE, PR_TRUE); + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_DEFLATE, PR_TRUE); ASSERT_EQ(SECSuccess, rv); } void TlsAgent::SetDowngradeCheckVersion(uint16_t version) { ASSERT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd_, version); + SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), version); ASSERT_EQ(SECSuccess, rv); } void TlsAgent::Handshake() { LOGV("Handshake"); - SECStatus rv = SSL_ForceHandshake(ssl_fd_); + SECStatus rv = SSL_ForceHandshake(ssl_fd()); if (rv == SECSuccess) { Connected(); - Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, &TlsAgent::ReadableCallback); return; @@ -740,14 +836,14 @@ void TlsAgent::Handshake() { int32_t err = PR_GetError(); if (err == PR_WOULD_BLOCK_ERROR) { LOGV("Would have blocked"); - if (mode_ == DGRAM) { + if (variant_ == ssl_variant_datagram) { if (timer_handle_) { timer_handle_->Cancel(); timer_handle_ = nullptr; } PRIntervalTime timeout; - rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout); + rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout); if (rv == SECSuccess) { Poller::Instance()->SetTimer( timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_); @@ -773,13 +869,18 @@ void TlsAgent::PrepareForRenegotiate() { void TlsAgent::StartRenegotiate() { PrepareForRenegotiate(); - SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE); + SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::SendDirect(const DataBuffer& buf) { LOG("Send Direct " << buf); - adapter_->peer()->PacketReceived(buf); + auto peer = adapter_->peer().lock(); + if (peer) { + peer->PacketReceived(buf); + } else { + LOG("Send Direct peer absent"); + } } static bool ErrorIsNonFatal(PRErrorCode code) { @@ -806,7 +907,7 @@ void TlsAgent::SendData(size_t bytes, size_t blocksize) { void TlsAgent::SendBuffer(const DataBuffer& buf) { LOGV("Writing " << buf.len() << " bytes"); - int32_t rv = PR_Write(ssl_fd_, buf.data(), buf.len()); + int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len()); if (expect_readwrite_error_) { EXPECT_GT(0, rv); EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_); @@ -817,10 +918,10 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) { } } -void TlsAgent::ReadBytes() { - uint8_t block[1024]; +void TlsAgent::ReadBytes(size_t amount) { + uint8_t block[16384]; - int32_t rv = PR_Read(ssl_fd_, block, sizeof(block)); + int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block))); LOGV("ReadBytes " << rv); int32_t err; @@ -853,18 +954,19 @@ void TlsAgent::ResetSentBytes() { send_ctr_ = 0; } void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { EXPECT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE); EXPECT_EQ(SECSuccess, rv); - rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS, + rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS, mode & RESUME_TICKET ? PR_TRUE : PR_FALSE); EXPECT_EQ(SECSuccess, rv); } void TlsAgent::DisableECDHEServerKeyReuse() { + ASSERT_TRUE(EnsureTlsSetup()); ASSERT_EQ(TlsAgent::SERVER, role_); - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); + SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); EXPECT_EQ(SECSuccess, rv); } @@ -877,29 +979,25 @@ void TlsAgentTestBase::SetUp() { } void TlsAgentTestBase::TearDown() { - delete agent_; + agent_ = nullptr; SSL_ClearSessionCache(); SSL_ShutdownServerSessionIDCache(); } void TlsAgentTestBase::Reset(const std::string& server_name) { - delete agent_; - Init(server_name); -} - -void TlsAgentTestBase::Init(const std::string& server_name) { - agent_ = + agent_.reset( new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name, - role_, mode_); - agent_->Init(); - fd_ = DummyPrSocket::CreateFD(agent_->role_str(), mode_); - agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_)); + role_, variant_)); + if (version_) { + agent_->SetVersionRange(version_, version_); + } + agent_->adapter()->SetPeer(sink_adapter_); agent_->StartConnect(); } void TlsAgentTestBase::EnsureInit() { if (!agent_) { - Init(); + Reset(); } const std::vector<SSLNamedGroup> groups = { ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, @@ -907,6 +1005,11 @@ void TlsAgentTestBase::EnsureInit() { agent_->ConfigNamedGroups(groups); } +void TlsAgentTestBase::ExpectAlert(uint8_t alert) { + EnsureInit(); + agent_->ExpectSendAlert(alert); +} + void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state, int32_t error_code) { @@ -922,14 +1025,16 @@ void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, } } -void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version, - const uint8_t* buf, size_t len, - DataBuffer* out, uint64_t seq_num) { +void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, + uint16_t version, const uint8_t* buf, + size_t len, DataBuffer* out, + uint64_t seq_num) { size_t index = 0; index = out->Write(index, type, 1); - index = out->Write( - index, mode == STREAM ? version : TlsVersionToDtlsVersion(version), 2); - if (mode == DGRAM) { + if (variant == ssl_variant_stream) { + index = out->Write(index, version, 2); + } else { + index = out->Write(index, TlsVersionToDtlsVersion(version), 2); index = out->Write(index, seq_num >> 32, 4); index = out->Write(index, seq_num & PR_UINT32_MAX, 4); } @@ -940,7 +1045,7 @@ void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version, void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, uint64_t seq_num) const { - MakeRecord(mode_, type, version, buf, len, out, seq_num); + MakeRecord(variant_, type, version, buf, len, out, seq_num); } void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type, @@ -959,7 +1064,7 @@ void TlsAgentTestBase::MakeHandshakeMessageFragment( if (!fragment_length) fragment_length = hs_len; index = out->Write(index, hs_type, 1); // Handshake record type. index = out->Write(index, hs_len, 3); // Handshake length - if (mode_ == DGRAM) { + if (variant_ == ssl_variant_datagram) { index = out->Write(index, seq_num, 2); index = out->Write(index, fragment_offset, 3); index = out->Write(index, fragment_length, 3); |