summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_connect.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_connect.cc190
1 files changed, 50 insertions, 140 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc
index 0af5123e9..c8de5a1fe 100644
--- a/security/nss/gtests/ssl_gtest/tls_connect.cc
+++ b/security/nss/gtests/ssl_gtest/tls_connect.cc
@@ -5,7 +5,6 @@
* You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "tls_connect.h"
-#include "sslexp.h"
extern "C" {
#include "libssl_internals.h"
}
@@ -89,8 +88,6 @@ std::string VersionString(uint16_t version) {
switch (version) {
case 0:
return "(no version)";
- case SSL_LIBRARY_VERSION_3_0:
- return "1.0";
case SSL_LIBRARY_VERSION_TLS_1_0:
return "1.0";
case SSL_LIBRARY_VERSION_TLS_1_1:
@@ -115,7 +112,6 @@ TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant,
server_model_(nullptr),
version_(version),
expected_resumption_mode_(RESUME_NONE),
- expected_resumptions_(0),
session_ids_(),
expect_extended_master_secret_(false),
expect_early_data_accepted_(false),
@@ -165,22 +161,6 @@ void TlsConnectTestBase::CheckShares(
EXPECT_EQ(shares.len(), i);
}
-void TlsConnectTestBase::CheckEpochs(uint16_t client_epoch,
- uint16_t server_epoch) const {
- uint16_t read_epoch = 0;
- uint16_t write_epoch = 0;
-
- EXPECT_EQ(SECSuccess,
- SSLInt_GetEpochs(client_->ssl_fd(), &read_epoch, &write_epoch));
- EXPECT_EQ(server_epoch, read_epoch) << "client read epoch";
- EXPECT_EQ(client_epoch, write_epoch) << "client write epoch";
-
- EXPECT_EQ(SECSuccess,
- SSLInt_GetEpochs(server_->ssl_fd(), &read_epoch, &write_epoch));
- EXPECT_EQ(client_epoch, read_epoch) << "server read epoch";
- EXPECT_EQ(server_epoch, write_epoch) << "server write epoch";
-}
-
void TlsConnectTestBase::ClearStats() {
// Clear statistics.
SSL3Statistics* stats = SSL_GetStatistics();
@@ -198,7 +178,6 @@ void TlsConnectTestBase::SetUp() {
SSLInt_ClearSelfEncryptKey();
SSLInt_SetTicketLifetime(30);
SSLInt_SetMaxEarlyDataSize(1024);
- SSL_SetupAntiReplay(1 * PR_USEC_PER_SEC, 1, 3);
ClearStats();
Init();
}
@@ -240,27 +219,12 @@ void TlsConnectTestBase::Reset(const std::string& server_name,
Init();
}
-void TlsConnectTestBase::MakeNewServer() {
- auto replacement = std::make_shared<TlsAgent>(
- server_->name(), TlsAgent::SERVER, server_->variant());
- server_ = replacement;
- if (version_) {
- server_->SetVersionRange(version_, version_);
- }
- client_->SetPeer(server_);
- server_->SetPeer(client_);
- server_->StartConnect();
-}
-
-void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected,
- uint8_t num_resumptions) {
+void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) {
expected_resumption_mode_ = expected;
if (expected != RESUME_NONE) {
client_->ExpectResumption();
server_->ExpectResumption();
- expected_resumptions_ = num_resumptions;
}
- EXPECT_EQ(expected_resumptions_ == 0, expected == RESUME_NONE);
}
void TlsConnectTestBase::EnsureTlsSetup() {
@@ -294,11 +258,6 @@ void TlsConnectTestBase::Connect() {
CheckConnected();
}
-void TlsConnectTestBase::StartConnect() {
- server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr);
- client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr);
-}
-
void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
EnsureTlsSetup();
client_->EnableSingleCipher(cipher_suite);
@@ -315,19 +274,6 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
}
void TlsConnectTestBase::CheckConnected() {
- // Have the client read handshake twice to make sure we get the
- // NST and the ACK.
- if (client_->version() >= SSL_LIBRARY_VERSION_TLS_1_3 &&
- variant_ == ssl_variant_datagram) {
- client_->Handshake();
- client_->Handshake();
- auto suites = SSLInt_CountCipherSpecs(client_->ssl_fd());
- // Verify that we dropped the client's retransmission cipher suites.
- EXPECT_EQ(2, suites) << "Client has the wrong number of suites";
- if (suites != 2) {
- SSLInt_PrintCipherSpecs("client", client_->ssl_fd());
- }
- }
EXPECT_EQ(client_->version(), server_->version());
if (!skip_version_checks_) {
// Check the version is as expected
@@ -368,12 +314,10 @@ void TlsConnectTestBase::CheckConnected() {
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
SSLAuthType auth_type,
SSLSignatureScheme sig_scheme) const {
- if (kea_group != ssl_grp_none) {
- client_->CheckKEA(kea_type, kea_group);
- server_->CheckKEA(kea_type, kea_group);
- }
- server_->CheckAuthType(auth_type, sig_scheme);
+ client_->CheckKEA(kea_type, kea_group);
+ server_->CheckKEA(kea_type, kea_group);
client_->CheckAuthType(auth_type, sig_scheme);
+ server_->CheckAuthType(auth_type, sig_scheme);
}
void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
@@ -428,19 +372,9 @@ void TlsConnectTestBase::CheckKeys() const {
CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
}
-void TlsConnectTestBase::CheckKeysResumption(SSLKEAType kea_type,
- SSLNamedGroup kea_group,
- SSLNamedGroup original_kea_group,
- SSLAuthType auth_type,
- SSLSignatureScheme sig_scheme) {
- CheckKeys(kea_type, kea_group, auth_type, sig_scheme);
- EXPECT_TRUE(expected_resumption_mode_ != RESUME_NONE);
- client_->CheckOriginalKEA(original_kea_group);
- server_->CheckOriginalKEA(original_kea_group);
-}
-
void TlsConnectTestBase::ConnectExpectFail() {
- StartConnect();
+ server_->StartConnect();
+ client_->StartConnect();
Handshake();
ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state());
ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
@@ -461,7 +395,8 @@ void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender,
}
void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) {
- StartConnect();
+ server_->StartConnect();
+ client_->StartConnect();
client_->SetServerKeyBits(server_->server_key_bits());
client_->Handshake();
server_->Handshake();
@@ -520,33 +455,29 @@ void TlsConnectTestBase::EnableSomeEcdhCiphers() {
}
}
-void TlsConnectTestBase::ConfigureSelfEncrypt() {
- ScopedCERTCertificate cert;
- ScopedSECKEYPrivateKey privKey;
- ASSERT_TRUE(
- TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert, &privKey));
-
- ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
- ASSERT_TRUE(pubKey);
-
- EXPECT_EQ(SECSuccess,
- SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
-}
-
void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
SessionResumptionMode server) {
client_->ConfigureSessionCache(client);
server_->ConfigureSessionCache(server);
if ((server & RESUME_TICKET) != 0) {
- ConfigureSelfEncrypt();
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey privKey;
+ ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert,
+ &privKey));
+
+ ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
+ ASSERT_TRUE(pubKey);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
}
}
void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
EXPECT_NE(RESUME_BOTH, expected);
- int resume_count = expected ? expected_resumptions_ : 0;
- int stateless_count = (expected & RESUME_TICKET) ? expected_resumptions_ : 0;
+ int resume_count = expected ? 1 : 0;
+ int stateless_count = (expected & RESUME_TICKET) ? 1 : 0;
// Note: hch == server counter; hsh == client counter.
SSL3Statistics* stats = SSL_GetStatistics();
@@ -559,7 +490,7 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
if (expected != RESUME_NONE) {
if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
// Check that the last two session ids match.
- ASSERT_EQ(1U + expected_resumptions_, session_ids_.size());
+ ASSERT_EQ(2U, session_ids_.size());
EXPECT_EQ(session_ids_[session_ids_.size() - 1],
session_ids_[session_ids_.size() - 2]);
} else {
@@ -609,28 +540,31 @@ void TlsConnectTestBase::CheckSrtp() const {
server_->CheckSrtp();
}
-void TlsConnectTestBase::SendReceive(size_t total) {
- ASSERT_GT(total, client_->received_bytes());
- ASSERT_GT(total, server_->received_bytes());
- client_->SendData(total - server_->received_bytes());
- server_->SendData(total - client_->received_bytes());
- Receive(total); // Receive() is cumulative
+void TlsConnectTestBase::SendReceive() {
+ client_->SendData(50);
+ server_->SendData(50);
+ Receive(50);
}
// Do a first connection so we can do 0-RTT on the second one.
void TlsConnectTestBase::SetupForZeroRtt() {
- // If we don't do this, then all 0-RTT attempts will be rejected.
- SSLInt_RolloverAntiReplay();
-
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT.
Connect();
SendReceive(); // Need to read so that we absorb the session ticket.
CheckKeys();
Reset();
- StartConnect();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->StartConnect();
+ client_->StartConnect();
}
// Do a first connection so we can do resumption
@@ -650,6 +584,10 @@ void TlsConnectTestBase::ZeroRttSendReceive(
const char* k0RttData = "ABCDEF";
const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+ if (expect_writable && expect_readable) {
+ ExpectAlert(client_, kTlsAlertEndOfEarlyData);
+ }
+
client_->Handshake(); // Send ClientHello.
if (post_clienthello_check) {
if (!post_clienthello_check()) return;
@@ -661,7 +599,7 @@ void TlsConnectTestBase::ZeroRttSendReceive(
} else {
EXPECT_EQ(SECFailure, rv);
}
- server_->Handshake(); // Consume ClientHello
+ server_->Handshake(); // Consume ClientHello, EE, Finished.
std::vector<uint8_t> buf(k0RttDataLen);
rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read
@@ -715,30 +653,6 @@ void TlsConnectTestBase::SkipVersionChecks() {
server_->SkipVersionChecks();
}
-// Shift the DTLS timers, to the minimum time necessary to let the next timer
-// run on either client or server. This allows tests to skip waiting without
-// having timers run out of order.
-void TlsConnectTestBase::ShiftDtlsTimers() {
- PRIntervalTime time_shift = PR_INTERVAL_NO_TIMEOUT;
- PRIntervalTime time;
- SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time);
- if (rv == SECSuccess) {
- time_shift = time;
- }
- rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time);
- if (rv == SECSuccess &&
- (time < time_shift || time_shift == PR_INTERVAL_NO_TIMEOUT)) {
- time_shift = time;
- }
-
- if (time_shift == PR_INTERVAL_NO_TIMEOUT) {
- return;
- }
-
- EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(client_->ssl_fd(), time_shift));
- EXPECT_EQ(SECSuccess, SSLInt_ShiftDtlsTimers(server_->ssl_fd(), time_shift));
-}
-
TlsConnectGeneric::TlsConnectGeneric()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
@@ -777,15 +691,11 @@ void TlsKeyExchangeTest::ConfigNamedGroups(
}
std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
- const std::shared_ptr<TlsExtensionCapture>& capture) {
- EXPECT_TRUE(capture->captured());
- const DataBuffer& ext = capture->extension();
-
+ const DataBuffer& ext) {
uint32_t tmp = 0;
EXPECT_TRUE(ext.Read(0, 2, &tmp));
EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
EXPECT_TRUE(ext.len() % 2 == 0);
-
std::vector<SSLNamedGroup> groups;
for (size_t i = 1; i < ext.len() / 2; i += 1) {
EXPECT_TRUE(ext.Read(2 * i, 2, &tmp));
@@ -795,14 +705,10 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
}
std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
- const std::shared_ptr<TlsExtensionCapture>& capture) {
- EXPECT_TRUE(capture->captured());
- const DataBuffer& ext = capture->extension();
-
+ const DataBuffer& ext) {
uint32_t tmp = 0;
EXPECT_TRUE(ext.Read(0, 2, &tmp));
EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
-
std::vector<SSLNamedGroup> shares;
size_t i = 2;
while (i < ext.len()) {
@@ -818,15 +724,17 @@ std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
void TlsKeyExchangeTest::CheckKEXDetails(
const std::vector<SSLNamedGroup>& expected_groups,
const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) {
- std::vector<SSLNamedGroup> groups = GetGroupDetails(groups_capture_);
+ std::vector<SSLNamedGroup> groups =
+ GetGroupDetails(groups_capture_->extension());
EXPECT_EQ(expected_groups, groups);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
ASSERT_LT(0U, expected_shares.size());
- std::vector<SSLNamedGroup> shares = GetShareDetails(shares_capture_);
+ std::vector<SSLNamedGroup> shares =
+ GetShareDetails(shares_capture_->extension());
EXPECT_EQ(expected_shares, shares);
} else {
- EXPECT_FALSE(shares_capture_->captured());
+ EXPECT_EQ(0U, shares_capture_->extension().len());
}
EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0);
@@ -848,6 +756,8 @@ void TlsKeyExchangeTest::CheckKEXDetails(
EXPECT_NE(expected_share2, it);
}
std::vector<SSLNamedGroup> expected_shares2 = {expected_share2};
- EXPECT_EQ(expected_shares2, GetShareDetails(shares_capture2_));
+ std::vector<SSLNamedGroup> shares =
+ GetShareDetails(shares_capture2_->extension());
+ EXPECT_EQ(expected_shares2, shares);
}
} // namespace nss_test