summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_agent.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.cc131
1 files changed, 82 insertions, 49 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc
index 2f71caedb..9bed1ce1b 100644
--- a/security/nss/gtests/ssl_gtest/tls_agent.cc
+++ b/security/nss/gtests/ssl_gtest/tls_agent.cc
@@ -33,6 +33,7 @@ const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};
const std::string TlsAgent::kClient = "client"; // both sign and encrypt
const std::string TlsAgent::kRsa2048 = "rsa2048"; // bigger
+const std::string TlsAgent::kRsa8192 = "rsa8192"; // biggest allowed
const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt
const std::string TlsAgent::kServerRsaSign = "rsa_sign";
const std::string TlsAgent::kServerRsaPss = "rsa_pss";
@@ -44,13 +45,22 @@ const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
const std::string TlsAgent::kServerDsa = "dsa";
-TlsAgent::TlsAgent(const std::string& name, Role role,
- SSLProtocolVariant variant)
- : name_(name),
- variant_(variant),
- role_(role),
+static const uint8_t kCannedTls13ServerHello[] = {
+ 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3,
+ 0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b,
+ 0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76,
+ 0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24,
+ 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03,
+ 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9,
+ 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08,
+ 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x7f, kD13};
+
+TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var)
+ : name_(nm),
+ variant_(var),
+ role_(rl),
server_key_bits_(0),
- adapter_(new DummyPrSocket(role_str(), variant)),
+ adapter_(new DummyPrSocket(role_str(), var)),
ssl_fd_(nullptr),
state_(STATE_INIT),
timer_handle_(nullptr),
@@ -103,11 +113,11 @@ TlsAgent::~TlsAgent() {
}
}
-void TlsAgent::SetState(State state) {
- if (state_ == state) return;
+void TlsAgent::SetState(State s) {
+ if (state_ == s) return;
- LOG("Changing state from " << state_ << " to " << state);
- state_ = state;
+ LOG("Changing state from " << state_ << " to " << s);
+ state_ = s;
}
/*static*/ bool TlsAgent::LoadCertificate(const std::string& name,
@@ -124,11 +134,11 @@ void TlsAgent::SetState(State state) {
return true;
}
-bool TlsAgent::ConfigServerCert(const std::string& name, bool updateKeyBits,
+bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits,
const SSLExtraServerCertData* serverCertData) {
ScopedCERTCertificate cert;
ScopedSECKEYPrivateKey priv;
- if (!TlsAgent::LoadCertificate(name, &cert, &priv)) {
+ if (!TlsAgent::LoadCertificate(id, &cert, &priv)) {
return false;
}
@@ -175,6 +185,10 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
if (rv != SECSuccess) return false;
}
+ ScopedCERTCertList anchors(CERT_NewCertList());
+ rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get());
+ if (rv != SECSuccess) return false;
+
if (role_ == SERVER) {
EXPECT_TRUE(ConfigServerCert(name_, true));
@@ -182,10 +196,6 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;
- ScopedCERTCertList anchors(CERT_NewCertList());
- rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get());
- if (rv != SECSuccess) return false;
-
rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024);
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;
@@ -246,6 +256,17 @@ void TlsAgent::SetupClientAuth() {
reinterpret_cast<void*>(this)));
}
+void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) {
+ ScopedCERTDistNames expected(CERT_GetSSLCACerts(nullptr));
+
+ ASSERT_EQ(expected->nnames, caNames->nnames);
+
+ for (size_t i = 0; i < static_cast<size_t>(expected->nnames); ++i) {
+ EXPECT_EQ(SECEqual,
+ SECITEM_CompareItem(&(expected->names[i]), &(caNames->names[i])));
+ }
+}
+
SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
CERTDistNames* caNames,
CERTCertificate** clientCert,
@@ -254,6 +275,9 @@ SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";
+ // See bug 1457716
+ // CheckCertReqAgainstDefaultCAs(caNames);
+
ScopedCERTCertificate cert;
ScopedSECKEYPrivateKey priv;
if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) {
@@ -282,8 +306,8 @@ bool TlsAgent::GetPeerChainLength(size_t* count) {
return true;
}
-void TlsAgent::CheckCipherSuite(uint16_t cipher_suite) {
- EXPECT_EQ(csinfo_.cipherSuite, cipher_suite);
+void TlsAgent::CheckCipherSuite(uint16_t suite) {
+ EXPECT_EQ(csinfo_.cipherSuite, suite);
}
void TlsAgent::RequestClientAuth(bool requireAuth) {
@@ -442,9 +466,7 @@ void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) {
*maxver = vrange_.max;
}
-void TlsAgent::SetExpectedVersion(uint16_t version) {
- expected_version_ = version;
-}
+void TlsAgent::SetExpectedVersion(uint16_t ver) { expected_version_ = ver; }
void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; }
@@ -491,10 +513,10 @@ void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
EXPECT_EQ(i, configuredCount) << "schemes in use were all set";
}
-void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group,
+void TlsAgent::CheckKEA(SSLKEAType kea, SSLNamedGroup kea_group,
size_t kea_size) const {
EXPECT_EQ(STATE_CONNECTED, state_);
- EXPECT_EQ(kea_type, info_.keaType);
+ EXPECT_EQ(kea, info_.keaType);
if (kea_size == 0) {
switch (kea_group) {
case ssl_grp_ec_curve25519:
@@ -515,7 +537,7 @@ void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group,
case ssl_grp_ffdhe_custom:
break;
default:
- if (kea_type == ssl_kea_rsa) {
+ if (kea == ssl_kea_rsa) {
kea_size = server_key_bits_;
} else {
EXPECT_TRUE(false) << "need to update group sizes";
@@ -534,13 +556,13 @@ void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const {
}
}
-void TlsAgent::CheckAuthType(SSLAuthType auth_type,
+void TlsAgent::CheckAuthType(SSLAuthType auth,
SSLSignatureScheme sig_scheme) const {
EXPECT_EQ(STATE_CONNECTED, state_);
- EXPECT_EQ(auth_type, info_.authType);
+ EXPECT_EQ(auth, info_.authType);
EXPECT_EQ(server_key_bits_, info_.authKeyBits);
if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
- switch (auth_type) {
+ switch (auth) {
case ssl_auth_rsa_sign:
sig_scheme = ssl_sig_rsa_pkcs1_sha1md5;
break;
@@ -558,9 +580,8 @@ void TlsAgent::CheckAuthType(SSLAuthType auth_type,
}
// Check authAlgorithm, which is the old value for authType. This is a second
- // switch
- // statement because default label is different.
- switch (auth_type) {
+ // switch statement because default label is different.
+ switch (auth) {
case ssl_auth_rsa_sign:
EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
<< "authAlgorithm for RSA is always decrypt";
@@ -574,7 +595,7 @@ void TlsAgent::CheckAuthType(SSLAuthType auth_type,
<< "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)";
break;
default:
- EXPECT_EQ(auth_type, csinfo_.authAlgorithm)
+ EXPECT_EQ(auth, csinfo_.authAlgorithm)
<< "authAlgorithm is (usually) the same as authType";
break;
}
@@ -593,22 +614,20 @@ void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
EXPECT_TRUE(EnsureTlsSetup());
-
- SetOption(SSL_ENABLE_ALPN, PR_TRUE);
EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len));
}
void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
const std::string& expected) const {
- SSLNextProtoState state;
+ SSLNextProtoState alpn_state;
char chosen[10];
unsigned int chosen_len;
- SECStatus rv = SSL_GetNextProto(ssl_fd(), &state,
+ SECStatus rv = SSL_GetNextProto(ssl_fd(), &alpn_state,
reinterpret_cast<unsigned char*>(chosen),
&chosen_len, sizeof(chosen));
EXPECT_EQ(SECSuccess, rv);
- EXPECT_EQ(expected_state, state);
- if (state == SSL_NEXT_PROTO_NO_SUPPORT) {
+ EXPECT_EQ(expected_state, alpn_state);
+ if (alpn_state == SSL_NEXT_PROTO_NO_SUPPORT) {
EXPECT_EQ("", expected);
} else {
EXPECT_NE("", expected);
@@ -840,10 +859,10 @@ void TlsAgent::CheckSecretsDestroyed() {
ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd()));
}
-void TlsAgent::SetDowngradeCheckVersion(uint16_t version) {
+void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) {
ASSERT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), version);
+ SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), ver);
ASSERT_EQ(SECSuccess, rv);
}
@@ -920,9 +939,9 @@ static bool ErrorIsNonFatal(PRErrorCode code) {
}
void TlsAgent::SendData(size_t bytes, size_t blocksize) {
- uint8_t block[4096];
+ uint8_t block[16385]; // One larger than the maximum record size.
- ASSERT_LT(blocksize, sizeof(block));
+ ASSERT_LE(blocksize, sizeof(block));
while (bytes) {
size_t tosend = std::min(blocksize, bytes);
@@ -951,12 +970,13 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) {
}
bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
- uint16_t wireVersion, uint64_t seq,
- uint8_t ct, const DataBuffer& buf) {
- LOGV("Writing " << buf.len() << " bytes");
- // Ensure we are a TLS 1.3 cipher agent.
+ uint64_t seq, uint8_t ct,
+ const DataBuffer& buf) {
+ LOGV("Encrypting " << buf.len() << " bytes");
+ // Ensure that we are doing TLS 1.3.
EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
- TlsRecordHeader header(wireVersion, kTlsApplicationDataType, seq);
+ TlsRecordHeader header(variant_, expected_version_, kTlsApplicationDataType,
+ seq);
DataBuffer padded = buf;
padded.Write(padded.len(), ct, 1);
DataBuffer ciphertext;
@@ -1078,15 +1098,20 @@ void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
uint16_t version, const uint8_t* buf,
size_t len, DataBuffer* out,
- uint64_t seq_num) {
+ uint64_t sequence_number) {
size_t index = 0;
index = out->Write(index, type, 1);
if (variant == ssl_variant_stream) {
index = out->Write(index, version, 2);
+ } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ type == kTlsApplicationDataType) {
+ uint32_t epoch = (sequence_number >> 48) & 0x3;
+ uint32_t seqno = sequence_number & ((1ULL << 30) - 1);
+ index = out->Write(index, (epoch << 30) | seqno, 4);
} else {
index = out->Write(index, TlsVersionToDtlsVersion(version), 2);
- index = out->Write(index, seq_num >> 32, 4);
- index = out->Write(index, seq_num & PR_UINT32_MAX, 4);
+ index = out->Write(index, sequence_number >> 32, 4);
+ index = out->Write(index, sequence_number & PR_UINT32_MAX, 4);
}
index = out->Write(index, len, 2);
out->Write(index, buf, len);
@@ -1144,4 +1169,12 @@ void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type,
}
}
+DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() {
+ DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello));
+ if (variant_ == ssl_variant_datagram) {
+ sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2);
+ }
+ return sh;
+}
+
} // namespace nss_test