summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_agent.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.cc375
1 files changed, 240 insertions, 135 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc
index b75bba567..d6d91f7f7 100644
--- a/security/nss/gtests/ssl_gtest/tls_agent.cc
+++ b/security/nss/gtests/ssl_gtest/tls_agent.cc
@@ -43,14 +43,14 @@ const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
const std::string TlsAgent::kServerDsa = "dsa";
-TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
+TlsAgent::TlsAgent(const std::string& name, Role role,
+ SSLProtocolVariant variant)
: name_(name),
- mode_(mode),
+ variant_(variant),
+ role_(role),
server_key_bits_(0),
- pr_fd_(nullptr),
- adapter_(nullptr),
+ adapter_(new DummyPrSocket(role_str(), variant)),
ssl_fd_(nullptr),
- role_(role),
state_(STATE_INIT),
timer_handle_(nullptr),
falsestart_enabled_(false),
@@ -61,6 +61,10 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
can_falsestart_hook_called_(false),
sni_hook_called_(false),
auth_certificate_hook_called_(false),
+ expected_received_alert_(kTlsAlertCloseNotify),
+ expected_received_alert_level_(kTlsAlertWarning),
+ expected_sent_alert_(kTlsAlertCloseNotify),
+ expected_sent_alert_level_(kTlsAlertWarning),
handshake_callback_called_(false),
error_code_(0),
send_ctr_(0),
@@ -69,29 +73,31 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
handshake_callback_(),
auth_certificate_callback_(),
sni_callback_(),
- expect_short_headers_(false) {
+ expect_short_headers_(false),
+ skip_version_checks_(false) {
memset(&info_, 0, sizeof(info_));
memset(&csinfo_, 0, sizeof(csinfo_));
- SECStatus rv = SSL_VersionRangeGetDefault(
- mode_ == STREAM ? ssl_variant_stream : ssl_variant_datagram, &vrange_);
+ SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_);
EXPECT_EQ(SECSuccess, rv);
}
TlsAgent::~TlsAgent() {
- if (adapter_) {
- Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
- // The adapter is closed when the FD closes.
- }
if (timer_handle_) {
timer_handle_->Cancel();
}
- if (pr_fd_) {
- PR_Close(pr_fd_);
+ if (adapter_) {
+ Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
}
- if (ssl_fd_) {
- PR_Close(ssl_fd_);
+ // Add failures manually, if any, so we don't throw in a destructor.
+ if (expected_received_alert_ != kTlsAlertCloseNotify ||
+ expected_received_alert_level_ != kTlsAlertWarning) {
+ ADD_FAILURE() << "Wrong expected_received_alert status";
+ }
+ if (expected_sent_alert_ != kTlsAlertCloseNotify ||
+ expected_sent_alert_level_ != kTlsAlertWarning) {
+ ADD_FAILURE() << "Wrong expected_sent_alert status";
}
}
@@ -102,27 +108,39 @@ void TlsAgent::SetState(State state) {
state_ = state;
}
+/*static*/ bool TlsAgent::LoadCertificate(const std::string& name,
+ ScopedCERTCertificate* cert,
+ ScopedSECKEYPrivateKey* priv) {
+ cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr));
+ EXPECT_NE(nullptr, cert->get());
+ if (!cert->get()) return false;
+
+ priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr));
+ EXPECT_NE(nullptr, priv->get());
+ if (!priv->get()) return false;
+
+ return true;
+}
+
bool TlsAgent::ConfigServerCert(const std::string& name, bool updateKeyBits,
const SSLExtraServerCertData* serverCertData) {
- ScopedCERTCertificate cert(PK11_FindCertFromNickname(name.c_str(), nullptr));
- EXPECT_NE(nullptr, cert.get());
- if (!cert.get()) return false;
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (!TlsAgent::LoadCertificate(name, &cert, &priv)) {
+ return false;
+ }
- ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
- EXPECT_NE(nullptr, pub.get());
- if (!pub.get()) return false;
if (updateKeyBits) {
+ ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
+ EXPECT_NE(nullptr, pub.get());
+ if (!pub.get()) return false;
server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get());
}
- ScopedSECKEYPrivateKey priv(PK11_FindKeyByAnyCert(cert.get(), nullptr));
- EXPECT_NE(nullptr, priv.get());
- if (!priv.get()) return false;
-
SECStatus rv =
- SSL_ConfigSecureServer(ssl_fd_, nullptr, nullptr, ssl_kea_null);
+ SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null);
EXPECT_EQ(SECFailure, rv);
- rv = SSL_ConfigServerCert(ssl_fd_, cert.get(), priv.get(), serverCertData,
+ rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData,
serverCertData ? sizeof(*serverCertData) : 0);
return rv == SECSuccess;
}
@@ -131,41 +149,59 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
// Don't set up twice
if (ssl_fd_) return true;
- if (adapter_->mode() == STREAM) {
- ssl_fd_ = SSL_ImportFD(modelSocket, pr_fd_);
+ ScopedPRFileDesc dummy_fd(adapter_->CreateFD());
+ EXPECT_NE(nullptr, dummy_fd);
+ if (!dummy_fd) {
+ return false;
+ }
+ if (adapter_->variant() == ssl_variant_stream) {
+ ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get()));
} else {
- ssl_fd_ = DTLS_ImportFD(modelSocket, pr_fd_);
+ ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get()));
}
EXPECT_NE(nullptr, ssl_fd_);
- if (!ssl_fd_) return false;
- pr_fd_ = nullptr;
+ if (!ssl_fd_) {
+ return false;
+ }
+ dummy_fd.release(); // Now subsumed by ssl_fd_.
- SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
- EXPECT_EQ(SECSuccess, rv);
- if (rv != SECSuccess) return false;
+ SECStatus rv;
+ if (!skip_version_checks_) {
+ rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+ }
if (role_ == SERVER) {
EXPECT_TRUE(ConfigServerCert(name_, true));
- rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, this);
+ rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this);
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;
ScopedCERTCertList anchors(CERT_NewCertList());
- rv = SSL_SetTrustAnchors(ssl_fd_, anchors.get());
+ rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get());
if (rv != SECSuccess) return false;
} else {
- rv = SSL_SetURL(ssl_fd_, "server");
+ rv = SSL_SetURL(ssl_fd(), "server");
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;
}
- rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
+ rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this);
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;
- rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this);
+ rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this);
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;
@@ -177,38 +213,31 @@ void TlsAgent::SetupClientAuth() {
ASSERT_EQ(CLIENT, role_);
EXPECT_EQ(SECSuccess,
- SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
+ SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook,
reinterpret_cast<void*>(this)));
}
-bool TlsAgent::GetClientAuthCredentials(CERTCertificate** cert,
- SECKEYPrivateKey** priv) const {
- *cert = PK11_FindCertFromNickname(name_.c_str(), nullptr);
- EXPECT_NE(nullptr, *cert);
- if (!*cert) return false;
-
- *priv = PK11_FindKeyByAnyCert(*cert, nullptr);
- EXPECT_NE(nullptr, *priv);
- if (!*priv) return false; // Leak cert.
-
- return true;
-}
-
SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
CERTDistNames* caNames,
- CERTCertificate** cert,
- SECKEYPrivateKey** privKey) {
+ CERTCertificate** clientCert,
+ SECKEYPrivateKey** clientKey) {
TlsAgent* agent = reinterpret_cast<TlsAgent*>(self);
ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";
- if (agent->GetClientAuthCredentials(cert, privKey)) {
- return SECSuccess;
+
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) {
+ return SECFailure;
}
- return SECFailure;
+
+ *clientCert = cert.release();
+ *clientKey = priv.release();
+ return SECSuccess;
}
bool TlsAgent::GetPeerChainLength(size_t* count) {
- CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd_);
+ CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd());
if (!chain) return false;
*count = 0;
@@ -224,17 +253,21 @@ bool TlsAgent::GetPeerChainLength(size_t* count) {
return true;
}
+void TlsAgent::CheckCipherSuite(uint16_t cipher_suite) {
+ EXPECT_EQ(csinfo_.cipherSuite, cipher_suite);
+}
+
void TlsAgent::RequestClientAuth(bool requireAuth) {
EXPECT_TRUE(EnsureTlsSetup());
ASSERT_EQ(SERVER, role_);
EXPECT_EQ(SECSuccess,
- SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE));
- EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE,
+ SSL_OptionSet(ssl_fd(), SSL_REQUEST_CERTIFICATE, PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_REQUIRE_CERTIFICATE,
requireAuth ? PR_TRUE : PR_FALSE));
EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
- ssl_fd_, &TlsAgent::ClientAuthenticated, this));
+ ssl_fd(), &TlsAgent::ClientAuthenticated, this));
expect_client_auth_ = true;
}
@@ -242,7 +275,7 @@ void TlsAgent::StartConnect(PRFileDesc* model) {
EXPECT_TRUE(EnsureTlsSetup(model));
SECStatus rv;
- rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE);
+ rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
SetState(STATE_CONNECTING);
}
@@ -250,7 +283,7 @@ void TlsAgent::StartConnect(PRFileDesc* model) {
void TlsAgent::DisableAllCiphers() {
for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
SECStatus rv =
- SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_FALSE);
+ SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
}
}
@@ -287,7 +320,7 @@ void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) {
EXPECT_EQ(sizeof(csinfo), csinfo.length);
if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) {
- rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE);
+ rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
}
@@ -325,7 +358,7 @@ void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
if ((csinfo.authType == authType) ||
(csinfo.keaType == ssl_kea_tls13_any)) {
- rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE);
+ rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
}
@@ -333,20 +366,20 @@ void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
void TlsAgent::EnableSingleCipher(uint16_t cipher) {
DisableAllCiphers();
- SECStatus rv = SSL_CipherPrefSet(ssl_fd_, cipher, PR_TRUE);
+ SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
EXPECT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_NamedGroupConfig(ssl_fd_, &groups[0], groups.size());
+ SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size());
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::SetSessionTicketsEnabled(bool en) {
EXPECT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS,
en ? PR_TRUE : PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
}
@@ -354,7 +387,7 @@ void TlsAgent::SetSessionTicketsEnabled(bool en) {
void TlsAgent::SetSessionCacheEnabled(bool en) {
EXPECT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE);
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
@@ -362,14 +395,22 @@ void TlsAgent::Set0RttEnabled(bool en) {
EXPECT_TRUE(EnsureTlsSetup());
SECStatus rv =
- SSL_OptionSet(ssl_fd_, SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
+ SSL_OptionSet(ssl_fd(), SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetFallbackSCSVEnabled(bool en) {
+ EXPECT_TRUE(role_ == CLIENT && EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALLBACK_SCSV,
+ en ? PR_TRUE : PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::SetShortHeadersEnabled() {
EXPECT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd_);
+ SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd());
EXPECT_EQ(SECSuccess, rv);
}
@@ -377,8 +418,8 @@ void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
vrange_.min = minver;
vrange_.max = maxver;
- if (ssl_fd_) {
- SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
+ if (ssl_fd()) {
+ SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
EXPECT_EQ(SECSuccess, rv);
}
}
@@ -398,32 +439,34 @@ void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; }
void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; }
+void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; }
+
void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
size_t count) {
EXPECT_TRUE(EnsureTlsSetup());
EXPECT_LE(count, SSL_SignatureMaxCount());
EXPECT_EQ(SECSuccess,
- SSL_SignatureSchemePrefSet(ssl_fd_, schemes,
+ SSL_SignatureSchemePrefSet(ssl_fd(), schemes,
static_cast<unsigned int>(count)));
- EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd_, schemes, 0))
+ EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0))
<< "setting no schemes should fail and do nothing";
std::vector<SSLSignatureScheme> configuredSchemes(count);
unsigned int configuredCount;
EXPECT_EQ(SECFailure,
- SSL_SignatureSchemePrefGet(ssl_fd_, nullptr, &configuredCount, 1))
+ SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1))
<< "get schemes, schemes is nullptr";
EXPECT_EQ(SECFailure,
- SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0],
+ SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0],
&configuredCount, 0))
<< "get schemes, too little space";
EXPECT_EQ(SECFailure,
- SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0], nullptr,
+ SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr,
configuredSchemes.size()))
<< "get schemes, countOut is nullptr";
EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet(
- ssl_fd_, &configuredSchemes[0], &configuredCount,
+ ssl_fd(), &configuredSchemes[0], &configuredCount,
configuredSchemes.size()));
// SignatureSchemePrefSet drops unsupported algorithms silently, so the
// number that are configured might be fewer.
@@ -524,10 +567,10 @@ void TlsAgent::EnableFalseStart() {
EXPECT_TRUE(EnsureTlsSetup());
falsestart_enabled_ = true;
+ EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback(
+ ssl_fd(), CanFalseStartCallback, this));
EXPECT_EQ(SECSuccess,
- SSL_SetCanFalseStartCallback(ssl_fd_, CanFalseStartCallback, this));
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(ssl_fd_, SSL_ENABLE_FALSE_START, PR_TRUE));
+ SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALSE_START, PR_TRUE));
}
void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
@@ -535,8 +578,8 @@ void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
EXPECT_TRUE(EnsureTlsSetup());
- EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE));
- EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len));
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_ENABLE_ALPN, PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len));
}
void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
@@ -544,7 +587,7 @@ void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
SSLNextProtoState state;
char chosen[10];
unsigned int chosen_len;
- SECStatus rv = SSL_GetNextProto(ssl_fd_, &state,
+ SECStatus rv = SSL_GetNextProto(ssl_fd(), &state,
reinterpret_cast<unsigned char*>(chosen),
&chosen_len, sizeof(chosen));
EXPECT_EQ(SECSuccess, rv);
@@ -562,12 +605,12 @@ void TlsAgent::EnableSrtp() {
const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80,
SRTP_AES128_CM_HMAC_SHA1_32};
EXPECT_EQ(SECSuccess,
- SSL_SetSRTPCiphers(ssl_fd_, ciphers, PR_ARRAY_SIZE(ciphers)));
+ SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers)));
}
void TlsAgent::CheckSrtp() const {
uint16_t actual;
- EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual));
+ EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual));
EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
}
@@ -578,6 +621,55 @@ void TlsAgent::CheckErrorCode(int32_t expected) const {
<< PORT_ErrorToName(expected) << std::endl;
}
+static uint8_t GetExpectedAlertLevel(uint8_t alert) {
+ switch (alert) {
+ case kTlsAlertCloseNotify:
+ case kTlsAlertEndOfEarlyData:
+ return kTlsAlertWarning;
+ default:
+ break;
+ }
+ return kTlsAlertFatal;
+}
+
+void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) {
+ expected_received_alert_ = alert;
+ if (level == 0) {
+ expected_received_alert_level_ = GetExpectedAlertLevel(alert);
+ } else {
+ expected_received_alert_level_ = level;
+ }
+}
+
+void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) {
+ expected_sent_alert_ = alert;
+ if (level == 0) {
+ expected_sent_alert_level_ = GetExpectedAlertLevel(alert);
+ } else {
+ expected_sent_alert_level_ = level;
+ }
+}
+
+void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) {
+ LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal")
+ << " alert " << (sent ? "sent" : "received") << ": "
+ << static_cast<int>(alert->description));
+
+ auto& expected = sent ? expected_sent_alert_ : expected_received_alert_;
+ auto& expected_level =
+ sent ? expected_sent_alert_level_ : expected_received_alert_level_;
+ /* Silently pass close_notify in case the test has already ended. */
+ if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning &&
+ alert->description == expected && alert->level == expected_level) {
+ return;
+ }
+
+ EXPECT_EQ(expected, alert->description);
+ EXPECT_EQ(expected_level, alert->level);
+ expected = kTlsAlertCloseNotify;
+ expected_level = kTlsAlertWarning;
+}
+
void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
ASSERT_EQ(0, error_code_);
WAIT_(error_code_ != 0, delay);
@@ -589,7 +681,7 @@ void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
void TlsAgent::CheckPreliminaryInfo() {
SSLPreliminaryChannelInfo info;
EXPECT_EQ(SECSuccess,
- SSL_GetPreliminaryChannelInfo(ssl_fd_, &info, sizeof(info)));
+ SSL_GetPreliminaryChannelInfo(ssl_fd(), &info, sizeof(info)));
EXPECT_EQ(sizeof(info), info.length);
EXPECT_TRUE(info.valuesSet & ssl_preinfo_version);
EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite);
@@ -619,7 +711,7 @@ void TlsAgent::CheckCallbacks() const {
// These callbacks shouldn't fire if we are resuming, except on TLS 1.3.
if (role_ == SERVER) {
- PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd_, ssl_server_name_xtn);
+ PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn);
EXPECT_EQ(((!expect_resumption_ && have_sni) ||
expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3),
sni_hook_called_);
@@ -639,11 +731,15 @@ void TlsAgent::ResetPreliminaryInfo() {
}
void TlsAgent::Connected() {
+ if (state_ == STATE_CONNECTED) {
+ return;
+ }
+
LOG("Handshake success");
CheckPreliminaryInfo();
CheckCallbacks();
- SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_));
+ SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_));
EXPECT_EQ(SECSuccess, rv);
EXPECT_EQ(sizeof(info_), info_.length);
@@ -658,18 +754,19 @@ void TlsAgent::Connected() {
EXPECT_EQ(sizeof(csinfo_), csinfo_.length);
if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
- PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd_);
+ PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd());
// We use one ciphersuite in each direction, plus one that's kept around
// by DTLS for retransmission.
- PRInt32 expected = ((mode_ == DGRAM) && (role_ == CLIENT)) ? 3 : 2;
+ PRInt32 expected =
+ ((variant_ == ssl_variant_datagram) && (role_ == CLIENT)) ? 3 : 2;
EXPECT_EQ(expected, cipherSuites);
if (expected != cipherSuites) {
- SSLInt_PrintTls13CipherSpecs(ssl_fd_);
+ SSLInt_PrintTls13CipherSpecs(ssl_fd());
}
}
PRBool short_headers;
- rv = SSLInt_UsingShortHeaders(ssl_fd_, &short_headers);
+ rv = SSLInt_UsingShortHeaders(ssl_fd(), &short_headers);
EXPECT_EQ(SECSuccess, rv);
EXPECT_EQ((PRBool)expect_short_headers_, short_headers);
SetState(STATE_CONNECTED);
@@ -679,7 +776,7 @@ void TlsAgent::EnableExtendedMasterSecret() {
ASSERT_TRUE(EnsureTlsSetup());
SECStatus rv =
- SSL_OptionSet(ssl_fd_, SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
+ SSL_OptionSet(ssl_fd(), SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
ASSERT_EQ(SECSuccess, rv);
}
@@ -701,13 +798,13 @@ void TlsAgent::CheckEarlyDataAccepted(bool expected) {
}
void TlsAgent::CheckSecretsDestroyed() {
- ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd_));
+ ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd()));
}
void TlsAgent::DisableRollbackDetection() {
ASSERT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ROLLBACK_DETECTION, PR_FALSE);
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ROLLBACK_DETECTION, PR_FALSE);
ASSERT_EQ(SECSuccess, rv);
}
@@ -715,23 +812,22 @@ void TlsAgent::DisableRollbackDetection() {
void TlsAgent::EnableCompression() {
ASSERT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_DEFLATE, PR_TRUE);
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_DEFLATE, PR_TRUE);
ASSERT_EQ(SECSuccess, rv);
}
void TlsAgent::SetDowngradeCheckVersion(uint16_t version) {
ASSERT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd_, version);
+ SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), version);
ASSERT_EQ(SECSuccess, rv);
}
void TlsAgent::Handshake() {
LOGV("Handshake");
- SECStatus rv = SSL_ForceHandshake(ssl_fd_);
+ SECStatus rv = SSL_ForceHandshake(ssl_fd());
if (rv == SECSuccess) {
Connected();
-
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
&TlsAgent::ReadableCallback);
return;
@@ -740,14 +836,14 @@ void TlsAgent::Handshake() {
int32_t err = PR_GetError();
if (err == PR_WOULD_BLOCK_ERROR) {
LOGV("Would have blocked");
- if (mode_ == DGRAM) {
+ if (variant_ == ssl_variant_datagram) {
if (timer_handle_) {
timer_handle_->Cancel();
timer_handle_ = nullptr;
}
PRIntervalTime timeout;
- rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
+ rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout);
if (rv == SECSuccess) {
Poller::Instance()->SetTimer(
timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_);
@@ -773,13 +869,18 @@ void TlsAgent::PrepareForRenegotiate() {
void TlsAgent::StartRenegotiate() {
PrepareForRenegotiate();
- SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE);
+ SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::SendDirect(const DataBuffer& buf) {
LOG("Send Direct " << buf);
- adapter_->peer()->PacketReceived(buf);
+ auto peer = adapter_->peer().lock();
+ if (peer) {
+ peer->PacketReceived(buf);
+ } else {
+ LOG("Send Direct peer absent");
+ }
}
static bool ErrorIsNonFatal(PRErrorCode code) {
@@ -806,7 +907,7 @@ void TlsAgent::SendData(size_t bytes, size_t blocksize) {
void TlsAgent::SendBuffer(const DataBuffer& buf) {
LOGV("Writing " << buf.len() << " bytes");
- int32_t rv = PR_Write(ssl_fd_, buf.data(), buf.len());
+ int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len());
if (expect_readwrite_error_) {
EXPECT_GT(0, rv);
EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_);
@@ -817,10 +918,10 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) {
}
}
-void TlsAgent::ReadBytes() {
- uint8_t block[1024];
+void TlsAgent::ReadBytes(size_t amount) {
+ uint8_t block[16384];
- int32_t rv = PR_Read(ssl_fd_, block, sizeof(block));
+ int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block)));
LOGV("ReadBytes " << rv);
int32_t err;
@@ -853,18 +954,19 @@ void TlsAgent::ResetSentBytes() { send_ctr_ = 0; }
void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
EXPECT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE,
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE,
mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
- rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
+ rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS,
mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::DisableECDHEServerKeyReuse() {
+ ASSERT_TRUE(EnsureTlsSetup());
ASSERT_EQ(TlsAgent::SERVER, role_);
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
}
@@ -877,29 +979,25 @@ void TlsAgentTestBase::SetUp() {
}
void TlsAgentTestBase::TearDown() {
- delete agent_;
+ agent_ = nullptr;
SSL_ClearSessionCache();
SSL_ShutdownServerSessionIDCache();
}
void TlsAgentTestBase::Reset(const std::string& server_name) {
- delete agent_;
- Init(server_name);
-}
-
-void TlsAgentTestBase::Init(const std::string& server_name) {
- agent_ =
+ agent_.reset(
new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
- role_, mode_);
- agent_->Init();
- fd_ = DummyPrSocket::CreateFD(agent_->role_str(), mode_);
- agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_));
+ role_, variant_));
+ if (version_) {
+ agent_->SetVersionRange(version_, version_);
+ }
+ agent_->adapter()->SetPeer(sink_adapter_);
agent_->StartConnect();
}
void TlsAgentTestBase::EnsureInit() {
if (!agent_) {
- Init();
+ Reset();
}
const std::vector<SSLNamedGroup> groups = {
ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
@@ -907,6 +1005,11 @@ void TlsAgentTestBase::EnsureInit() {
agent_->ConfigNamedGroups(groups);
}
+void TlsAgentTestBase::ExpectAlert(uint8_t alert) {
+ EnsureInit();
+ agent_->ExpectSendAlert(alert);
+}
+
void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
TlsAgent::State expected_state,
int32_t error_code) {
@@ -922,14 +1025,16 @@ void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
}
}
-void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version,
- const uint8_t* buf, size_t len,
- DataBuffer* out, uint64_t seq_num) {
+void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
+ uint16_t version, const uint8_t* buf,
+ size_t len, DataBuffer* out,
+ uint64_t seq_num) {
size_t index = 0;
index = out->Write(index, type, 1);
- index = out->Write(
- index, mode == STREAM ? version : TlsVersionToDtlsVersion(version), 2);
- if (mode == DGRAM) {
+ if (variant == ssl_variant_stream) {
+ index = out->Write(index, version, 2);
+ } else {
+ index = out->Write(index, TlsVersionToDtlsVersion(version), 2);
index = out->Write(index, seq_num >> 32, 4);
index = out->Write(index, seq_num & PR_UINT32_MAX, 4);
}
@@ -940,7 +1045,7 @@ void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version,
void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version,
const uint8_t* buf, size_t len,
DataBuffer* out, uint64_t seq_num) const {
- MakeRecord(mode_, type, version, buf, len, out, seq_num);
+ MakeRecord(variant_, type, version, buf, len, out, seq_num);
}
void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type,
@@ -959,7 +1064,7 @@ void TlsAgentTestBase::MakeHandshakeMessageFragment(
if (!fragment_length) fragment_length = hs_len;
index = out->Write(index, hs_type, 1); // Handshake record type.
index = out->Write(index, hs_len, 3); // Handshake length
- if (mode_ == DGRAM) {
+ if (variant_ == ssl_variant_datagram) {
index = out->Write(index, seq_num, 2);
index = out->Write(index, fragment_offset, 3);
index = out->Write(index, fragment_length, 3);