summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_connect.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_connect.cc190
1 files changed, 140 insertions, 50 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc
index c8de5a1fe..0af5123e9 100644
--- a/security/nss/gtests/ssl_gtest/tls_connect.cc
+++ b/security/nss/gtests/ssl_gtest/tls_connect.cc
@@ -5,6 +5,7 @@
* You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "tls_connect.h"
+#include "sslexp.h"
extern "C" {
#include "libssl_internals.h"
}
@@ -88,6 +89,8 @@ std::string VersionString(uint16_t version) {
switch (version) {
case 0:
return "(no version)";
+ case SSL_LIBRARY_VERSION_3_0:
+ return "1.0";
case SSL_LIBRARY_VERSION_TLS_1_0:
return "1.0";
case SSL_LIBRARY_VERSION_TLS_1_1:
@@ -112,6 +115,7 @@ TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant,
server_model_(nullptr),
version_(version),
expected_resumption_mode_(RESUME_NONE),
+ expected_resumptions_(0),
session_ids_(),
expect_extended_master_secret_(false),
expect_early_data_accepted_(false),
@@ -161,6 +165,22 @@ void TlsConnectTestBase::CheckShares(
EXPECT_EQ(shares.len(), i);
}
+void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch,
+ uint16_t server_epoch) const {
+ uint16_t read_epoch = 0;
+ uint16_t write_epoch = 0;
+
+ EXPECT_EQ(SECSuccess,
+ SSLInt_GetEpochs(client_->ssl_fd(), &read_epoch, &write_epoch));
+ EXPECT_EQ(server_epoch, read_epoch) << "client read epoch";
+ EXPECT_EQ(client_epoch, write_epoch) << "client write epoch";
+
+ EXPECT_EQ(SECSuccess,
+ SSLInt_GetEpochs(server_->ssl_fd(), &read_epoch, &write_epoch));
+ EXPECT_EQ(client_epoch, read_epoch) << "server read epoch";
+ EXPECT_EQ(server_epoch, write_epoch) << "server write epoch";
+}
+
void TlsConnectTestBase::ClearStats() {
// Clear statistics.
SSL3Statistics* stats = SSL_GetStatistics();
@@ -178,6 +198,7 @@ void TlsConnectTestBase::SetUp() {
SSLInt_ClearSelfEncryptKey();
SSLInt_SetTicketLifetime(30);
SSLInt_SetMaxEarlyDataSize(1024);
+ SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3);
ClearStats();
Init();
}
@@ -219,12 +240,27 @@ void TlsConnectTestBase::Reset(const std::string& server_name,
Init();
}
-void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) {
+void TlsConnectTestBase::MakeNewServer() {
+ auto replacement = std::make_shared<TlsAgent>(
+ server_->name(), TlsAgent::SERVER, server_->variant());
+ server_ = replacement;
+ if (version_) {
+ server_->SetVersionRange(version_, version_);
+ }
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+ server_->StartConnect();
+}
+
+void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected,
+ uint8_t num_resumptions) {
expected_resumption_mode_ = expected;
if (expected != RESUME_NONE) {
client_->ExpectResumption();
server_->ExpectResumption();
+ expected_resumptions_ = num_resumptions;
}
+ EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE);
}
void TlsConnectTestBase::EnsureTlsSetup() {
@@ -258,6 +294,11 @@ void TlsConnectTestBase::Connect() {
CheckConnected();
}
+void TlsConnectTestBase::StartConnect() {
+ server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr);
+ client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr);
+}
+
void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
EnsureTlsSetup();
client_->EnableSingleCipher(cipher_suite);
@@ -274,6 +315,19 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
}
void TlsConnectTestBase::CheckConnected() {
+ // Have the client read handshake twice to make sure we get the
+ // NST and the ACK.
+ if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ variant_ == ssl_variant_datagram) {
+ client_->Handshake();
+ client_->Handshake();
+ auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd());
+ // Verify that we dropped the client's retransmission cipher suites.
+ EXPECT_EQ(2, suites) << "Client has the wrong number of suites";
+ if (suites != 2) {
+ SSLInt_PrintCipherSpecs("client", client_->ssl_fd());
+ }
+ }
EXPECT_EQ(client_->version(), server_->version());
if (!skip_version_checks_) {
// Check the version is as expected
@@ -314,10 +368,12 @@ void TlsConnectTestBase::CheckConnected() {
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
SSLAuthType auth_type,
SSLSignatureScheme sig_scheme) const {
- client_->CheckKEA(kea_type, kea_group);
- server_->CheckKEA(kea_type, kea_group);
- client_->CheckAuthType(auth_type, sig_scheme);
+ if (kea_group != ssl_grp_none) {
+ client_->CheckKEA(kea_type, kea_group);
+ server_->CheckKEA(kea_type, kea_group);
+ }
server_->CheckAuthType(auth_type, sig_scheme);
+ client_->CheckAuthType(auth_type, sig_scheme);
}
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
@@ -372,9 +428,19 @@ void TlsConnectTestBase::CheckKeys() const {
CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
}
+void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type,
+ SSLNamedGroup kea_group,
+ SSLNamedGroup original_kea_group,
+ SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) {
+ CheckKeys(kea_type, kea_group, auth_type, sig_scheme);
+ EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE);
+ client_->CheckOriginalKEA(original_kea_group);
+ server_->CheckOriginalKEA(original_kea_group);
+}
+
void TlsConnectTestBase::ConnectExpectFail() {
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
Handshake();
ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state());
ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
@@ -395,8 +461,7 @@ void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender,
}
void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) {
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
client_->SetServerKeyBits(server_->server_key_bits());
client_->Handshake();
server_->Handshake();
@@ -455,29 +520,33 @@ void TlsConnectTestBase::EnableSomeEcdhCiphers() {
}
}
+void TlsConnectTestBase::ConfigureSelfEncrypt() {
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey privKey;
+ ASSERT_TRUE(
+ TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey));
+
+ ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
+ ASSERT_TRUE(pubKey);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
+}
+
void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
SessionResumptionMode server) {
client_->ConfigureSessionCache(client);
server_->ConfigureSessionCache(server);
if ((server & RESUME_TICKET) != 0) {
- ScopedCERTCertificate cert;
- ScopedSECKEYPrivateKey privKey;
- ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert,
- &privKey));
-
- ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
- ASSERT_TRUE(pubKey);
-
- EXPECT_EQ(SECSuccess,
- SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
+ ConfigureSelfEncrypt();
}
}
void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
EXPECT_NE(RESUME_BOTH, expected);
- int resume_count = expected ? 1 : 0;
- int stateless_count = (expected & RESUME_TICKET) ? 1 : 0;
+ int resume_count = expected ? expected_resumptions_ : 0;
+ int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0;
// Note: hch == server counter; hsh == client counter.
SSL3Statistics* stats = SSL_GetStatistics();
@@ -490,7 +559,7 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
if (expected != RESUME_NONE) {
if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
// Check that the last two session ids match.
- ASSERT_EQ(2U, session_ids_.size());
+ ASSERT_EQ(1U + expected_resumptions_, session_ids_.size());
EXPECT_EQ(session_ids_[session_ids_.size() - 1],
session_ids_[session_ids_.size() - 2]);
} else {
@@ -540,31 +609,28 @@ void TlsConnectTestBase::CheckSrtp() const {
server_->CheckSrtp();
}
-void TlsConnectTestBase::SendReceive() {
- client_->SendData(50);
- server_->SendData(50);
- Receive(50);
+void TlsConnectTestBase::SendReceive(size_t total) {
+ ASSERT_GT(total, client_->received_bytes());
+ ASSERT_GT(total, server_->received_bytes());
+ client_->SendData(total - server_->received_bytes());
+ server_->SendData(total - client_->received_bytes());
+ Receive(total); // Receive() is cumulative
}
// Do a first connection so we can do 0-RTT on the second one.
void TlsConnectTestBase::SetupForZeroRtt() {
+ // If we don't do this, then all 0-RTT attempts will be rejected.
+ SSLInt_RolloverAntiReplay();
+
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT.
Connect();
SendReceive(); // Need to read so that we absorb the session ticket.
CheckKeys();
Reset();
- client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
- server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
- SSL_LIBRARY_VERSION_TLS_1_3);
- server_->StartConnect();
- client_->StartConnect();
+ StartConnect();
}
// Do a first connection so we can do resumption
@@ -584,10 +650,6 @@ void TlsConnectTestBase::ZeroRttSendReceive(
const char* k0RttData = "ABCDEF";
const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
- if (expect_writable && expect_readable) {
- ExpectAlert(client_, kTlsAlertEndOfEarlyData);
- }
-
client_->Handshake(); // Send ClientHello.
if (post_clienthello_check) {
if (!post_clienthello_check()) return;
@@ -599,7 +661,7 @@ void TlsConnectTestBase::ZeroRttSendReceive(
} else {
EXPECT_EQ(SECFailure, rv);
}
- server_->Handshake(); // Consume ClientHello, EE, Finished.
+ server_->Handshake(); // Consume ClientHello
std::vector<uint8_t> buf(k0RttDataLen);
rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read
@@ -653,6 +715,30 @@ void TlsConnectTestBase::SkipVersionChecks() {
server_->SkipVersionChecks();
}
+// Shift the DTLS timers, to the minimum time necessary to let the next timer
+// run on either client or server. This allows tests to skip waiting without
+// having timers run out of order.
+void TlsConnectTestBase::ShiftDtlsTimers() {
+ PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT;
+ PRIntervalTime time;
+ SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time);
+ if (rv == SECSuccess) {
+ time_shift = time;
+ }
+ rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time);
+ if (rv == SECSuccess &&
+ (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) {
+ time_shift = time;
+ }
+
+ if (time_shift == PR_INTERVAL_NO_TIMEOUT) {
+ return;
+ }
+
+ EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift));
+ EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift));
+}
+
TlsConnectGeneric::TlsConnectGeneric()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
@@ -691,11 +777,15 @@ void TlsKeyExchangeTest::ConfigNamedGroups(
}
std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
- const DataBuffer& ext) {
+ const std::shared_ptr<TlsExtensionCapture>& capture) {
+ EXPECT_TRUE(capture->captured());
+ const DataBuffer& ext = capture->extension();
+
uint32_t tmp = 0;
EXPECT_TRUE(ext.Read(0, 2, &tmp));
EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
EXPECT_TRUE(ext.len() % 2 == 0);
+
std::vector<SSLNamedGroup> groups;
for (size_t i = 1; i < ext.len() / 2; i += 1) {
EXPECT_TRUE(ext.Read(2 * i, 2, &tmp));
@@ -705,10 +795,14 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
}
std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
- const DataBuffer& ext) {
+ const std::shared_ptr<TlsExtensionCapture>& capture) {
+ EXPECT_TRUE(capture->captured());
+ const DataBuffer& ext = capture->extension();
+
uint32_t tmp = 0;
EXPECT_TRUE(ext.Read(0, 2, &tmp));
EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
+
std::vector<SSLNamedGroup> shares;
size_t i = 2;
while (i < ext.len()) {
@@ -724,17 +818,15 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
void TlsKeyExchangeTest::CheckKEXDetails(
const std::vector<SSLNamedGroup>& expected_groups,
const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) {
- std::vector<SSLNamedGroup> groups =
- GetGroupDetails(groups_capture_->extension());
+ std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_);
EXPECT_EQ(expected_groups, groups);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
ASSERT_LT(0U, expected_shares.size());
- std::vector<SSLNamedGroup> shares =
- GetShareDetails(shares_capture_->extension());
+ std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_);
EXPECT_EQ(expected_shares, shares);
} else {
- EXPECT_EQ(0U, shares_capture_->extension().len());
+ EXPECT_FALSE(shares_capture_->captured());
}
EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0);
@@ -756,8 +848,6 @@ void TlsKeyExchangeTest::CheckKEXDetails(
EXPECT_NE(expected_share2, it);
}
std::vector<SSLNamedGroup> expected_shares2 = {expected_share2};
- std::vector<SSLNamedGroup> shares =
- GetShareDetails(shares_capture2_->extension());
- EXPECT_EQ(expected_shares2, shares);
+ EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_));
}
} // namespace nss_test