summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_filter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.cc227
1 files changed, 154 insertions, 73 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc
index 25ad606fc..b2917274b 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.cc
+++ b/security/nss/gtests/ssl_gtest/tls_filter.cc
@@ -45,40 +45,65 @@ void TlsVersioned::WriteStream(std::ostream& stream) const {
}
}
+TlsRecordFilter::TlsRecordFilter(const std::shared_ptr<TlsAgent>& a)
+ : agent_(a) {
+ cipher_specs_.emplace_back(a->variant() == ssl_variant_datagram, 0);
+}
+
void TlsRecordFilter::EnableDecryption() {
- SSLInt_SetCipherSpecChangeFunc(agent()->ssl_fd(), CipherSpecChanged,
- (void*)this);
+ EXPECT_EQ(SECSuccess,
+ SSL_SecretCallback(agent()->ssl_fd(), SecretCallback, this));
+ decrypting_ = true;
}
-void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending,
- ssl3CipherSpec* newSpec) {
+void TlsRecordFilter::SecretCallback(PRFileDesc* fd, PRUint16 epoch,
+ SSLSecretDirection dir, PK11SymKey* secret,
+ void* arg) {
TlsRecordFilter* self = static_cast<TlsRecordFilter*>(arg);
- PRBool isServer = self->agent()->role() == TlsAgent::SERVER;
-
if (g_ssl_gtest_verbose) {
- std::cerr << (isServer ? "server" : "client") << ": "
- << (sending ? "send" : "receive")
- << " cipher spec changed: " << newSpec->epoch << " ("
- << newSpec->phase << ")" << std::endl;
+ std::cerr << self->agent()->role_str() << ": " << dir
+ << " secret changed for epoch " << epoch << std::endl;
}
- if (!sending) {
+
+ if (dir == ssl_secret_read) {
return;
}
- uint64_t seq_no;
- if (self->agent()->variant() == ssl_variant_datagram) {
- seq_no = static_cast<uint64_t>(SSLInt_CipherSpecToEpoch(newSpec)) << 48;
+ for (auto& spec : self->cipher_specs_) {
+ ASSERT_NE(spec.epoch(), epoch) << "duplicate spec for epoch " << epoch;
+ }
+
+ SSLPreliminaryChannelInfo preinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetPreliminaryChannelInfo(self->agent()->ssl_fd(), &preinfo,
+ sizeof(preinfo)));
+ EXPECT_EQ(sizeof(preinfo), preinfo.length);
+
+ // Check the version.
+ if (preinfo.valuesSet & ssl_preinfo_version) {
+ EXPECT_EQ(SSL_LIBRARY_VERSION_TLS_1_3, preinfo.protocolVersion);
+ } else {
+ EXPECT_EQ(1U, epoch);
+ }
+
+ uint16_t suite;
+ if (epoch == 1) {
+ // 0-RTT
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_0rtt_cipher_suite);
+ suite = preinfo.zeroRttCipherSuite;
} else {
- seq_no = 0;
+ EXPECT_TRUE(preinfo.valuesSet & ssl_preinfo_cipher_suite);
+ suite = preinfo.cipherSuite;
}
- self->in_sequence_number_ = seq_no;
- self->out_sequence_number_ = seq_no;
- self->dropped_record_ = false;
- self->cipher_spec_.reset(new TlsCipherSpec());
- bool ret = self->cipher_spec_->Init(
- SSLInt_CipherSpecToEpoch(newSpec), SSLInt_CipherSpecToAlgorithm(newSpec),
- SSLInt_CipherSpecToKey(newSpec), SSLInt_CipherSpecToIv(newSpec));
- EXPECT_EQ(true, ret);
+
+ SSLCipherSuiteInfo cipherinfo;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetCipherSuiteInfo(suite, &cipherinfo, sizeof(cipherinfo)));
+ EXPECT_EQ(sizeof(cipherinfo), cipherinfo.length);
+
+ bool is_dtls = self->agent()->variant() == ssl_variant_datagram;
+ self->cipher_specs_.emplace_back(is_dtls, epoch);
+ EXPECT_TRUE(self->cipher_specs_.back().SetKeys(&cipherinfo, secret));
}
bool TlsRecordFilter::is_dtls13() const {
@@ -95,6 +120,23 @@ bool TlsRecordFilter::is_dtls13() const {
info.canSendEarlyData;
}
+// Gets the cipher spec that matches the specified epoch.
+TlsCipherSpec& TlsRecordFilter::spec(uint16_t write_epoch) {
+ for (auto& sp : cipher_specs_) {
+ if (sp.epoch() == write_epoch) {
+ return sp;
+ }
+ }
+
+ // If we aren't decrypting, provide a cipher spec that does nothing other than
+ // count sequence numbers.
+ EXPECT_FALSE(decrypting_) << "No spec available for epoch " << write_epoch;
+ ;
+ bool is_dtls = agent()->variant() == ssl_variant_datagram;
+ cipher_specs_.emplace_back(is_dtls, write_epoch);
+ return cipher_specs_.back();
+}
+
PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
// Disable during shutdown.
@@ -108,34 +150,28 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
output->Allocate(input.len());
TlsParser parser(input);
+ // This uses the current write spec for the purposes of parsing the epoch and
+ // sequence number from the header. This might be wrong because we can
+ // receive records from older specs, but guessing is good enough:
+ // - In DTLS, parsing the sequence number corrects any errors.
+ // - In TLS, we don't use the sequence number unless decrypting, where we use
+ // trial decryption to get the right epoch.
+ uint16_t write_epoch = 0;
+ SECStatus rv = SSL_GetCurrentEpoch(agent()->ssl_fd(), nullptr, &write_epoch);
+ if (rv != SECSuccess) {
+ ADD_FAILURE() << "unable to read epoch";
+ return KEEP;
+ }
+ uint64_t guess_seqno = static_cast<uint64_t>(write_epoch) << 48;
+
while (parser.remaining()) {
TlsRecordHeader header;
DataBuffer record;
-
- if (!header.Parse(is_dtls13(), in_sequence_number_, &parser, &record)) {
+ if (!header.Parse(is_dtls13(), guess_seqno, &parser, &record)) {
ADD_FAILURE() << "not a valid record";
return KEEP;
}
- // Track the sequence number, which is necessary for stream mode when
- // decrypting and for TLS 1.3 datagram to recover the sequence number.
- //
- // We reset the counter when the cipher spec changes, but that notification
- // appears before a record is sent. If multiple records are sent with
- // different cipher specs, this would fail. This filters out cleartext
- // records, so we don't get confused by handshake messages that are sent at
- // the same time as encrypted records. Sequence numbers are therefore
- // likely to be incorrect for cleartext records.
- //
- // This isn't perfectly robust: if there is a change from an active cipher
- // spec to another active cipher spec (KeyUpdate for instance) AND writes
- // are consolidated across that change, this code could use the wrong
- // sequence numbers when re-encrypting records with the old keys.
- if (header.content_type() == ssl_ct_application_data) {
- in_sequence_number_ =
- (std::max)(in_sequence_number_, header.sequence_number() + 1);
- }
-
if (FilterRecord(header, record, &offset, output) != KEEP) {
changed = true;
} else {
@@ -159,14 +195,16 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
DataBuffer filtered;
uint8_t inner_content_type;
DataBuffer plaintext;
+ uint16_t protection_epoch = 0;
- if (!Unprotect(header, record, &inner_content_type, &plaintext)) {
- if (g_ssl_gtest_verbose) {
- std::cerr << "unprotect failed: " << header << ":" << record << std::endl;
- }
+ if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
+ &plaintext)) {
+ std::cerr << agent()->role_str() << ": unprotect failed: " << header << ":"
+ << record << std::endl;
return KEEP;
}
+ auto& protection_spec = spec(protection_epoch);
TlsRecordHeader real_header(header.variant(), header.version(),
inner_content_type, header.sequence_number());
@@ -174,7 +212,9 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
// In stream mode, even if something doesn't change we need to re-encrypt if
// previous packets were dropped.
if (action == KEEP) {
- if (header.is_dtls() || !dropped_record_) {
+ if (header.is_dtls() || !protection_spec.record_dropped()) {
+ // Count every outgoing packet.
+ protection_spec.RecordProtected();
return KEEP;
}
filtered = plaintext;
@@ -182,7 +222,7 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
if (action == DROP) {
std::cerr << "record drop: " << header << ":" << record << std::endl;
- dropped_record_ = true;
+ protection_spec.RecordDropped();
return DROP;
}
@@ -192,19 +232,18 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
std::cerr << "record new: " << filtered << std::endl;
}
- uint64_t seq_num;
- if (header.is_dtls() || !cipher_spec_ ||
- header.content_type() != ssl_ct_application_data) {
- seq_num = header.sequence_number();
- } else {
- seq_num = out_sequence_number_++;
+ uint64_t seq_num = protection_spec.next_out_seqno();
+ if (!decrypting_ && header.is_dtls()) {
+ // Copy over the epoch, which isn't tracked when not decrypting.
+ seq_num |= header.sequence_number() & (0xffffULL << 48);
}
+
TlsRecordHeader out_header(header.variant(), header.version(),
header.content_type(), seq_num);
DataBuffer ciphertext;
- bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext);
- EXPECT_TRUE(rv);
+ bool rv = Protect(protection_spec, out_header, inner_content_type, filtered,
+ &ciphertext);
if (!rv) {
return KEEP;
}
@@ -227,15 +266,20 @@ uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t expected,
uint32_t partial,
size_t partial_bits) {
EXPECT_GE(32U, partial_bits);
- uint64_t mask = (1 << partial_bits) - 1;
+ uint64_t mask = (1ULL << partial_bits) - 1;
// First we determine the highest possible value. This is half the
- // expressible range above the expected value.
- uint64_t cap = expected + (1ULL << (partial_bits - 1));
+ // expressible range above the expected value, less 1.
+ //
+ // We subtract the extra 1 from the cap so that when given a choice between
+ // the equidistant expected+N and expected-N we want to chose the lower. With
+ // 0-RTT, we sometimes have to recover an epoch of 1 when we expect an epoch
+ // of 3 and with 2 partial bits, the alternative result of 5 is wrong.
+ uint64_t cap = expected + (1ULL << (partial_bits - 1)) - 1;
// Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234.
uint64_t seq_no = (cap & ~mask) | partial;
// If the partial value is higher than the same partial piece from the cap,
// then the real value has to be lower. e.g., xxxx1234 can't become xxxx5678.
- if (partial > (cap & mask)) {
+ if (partial > (cap & mask) && (seq_no >= (1ULL << partial_bits))) {
seq_no -= 1ULL << partial_bits;
}
return seq_no;
@@ -375,16 +419,41 @@ size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset,
bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
const DataBuffer& ciphertext,
+ uint16_t* protection_epoch,
uint8_t* inner_content_type,
DataBuffer* plaintext) {
- if (!cipher_spec_ || header.content_type() != ssl_ct_application_data) {
+ if (!decrypting_ || header.content_type() != ssl_ct_application_data) {
+ // Maintain the epoch and sequence number for plaintext records.
+ uint16_t ep = 0;
+ if (agent()->variant() == ssl_variant_datagram) {
+ ep = static_cast<uint16_t>(header.sequence_number() >> 48);
+ }
+ spec(ep).RecordUnprotected(header.sequence_number());
+ *protection_epoch = ep;
*inner_content_type = header.content_type();
*plaintext = ciphertext;
return true;
}
- if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) {
- return false;
+ uint16_t ep = 0;
+ if (agent()->variant() == ssl_variant_datagram) {
+ ep = static_cast<uint16_t>(header.sequence_number() >> 48);
+ if (!spec(ep).Unprotect(header, ciphertext, plaintext)) {
+ return false;
+ }
+ } else {
+ // In TLS, records aren't clearly labelled with their epoch, and we
+ // can't just use the newest keys because the same flight of messages can
+ // contain multiple epochs. So... trial decrypt!
+ for (size_t i = cipher_specs_.size() - 1; i > 0; --i) {
+ if (cipher_specs_[i].Unprotect(header, ciphertext, plaintext)) {
+ ep = cipher_specs_[i].epoch();
+ break;
+ }
+ }
+ if (!ep) {
+ return false;
+ }
}
size_t len = plaintext->len();
@@ -396,33 +465,45 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
return false;
}
+ *protection_epoch = ep;
*inner_content_type = plaintext->data()[len - 1];
plaintext->Truncate(len - 1);
if (g_ssl_gtest_verbose) {
- std::cerr << "unprotect: " << std::hex << header.sequence_number()
- << std::dec << " type=" << static_cast<int>(*inner_content_type)
+ std::cerr << agent()->role_str() << ": unprotect: epoch=" << ep
+ << " seq=" << std::hex << header.sequence_number() << std::dec
<< " " << *plaintext << std::endl;
}
return true;
}
-bool TlsRecordFilter::Protect(const TlsRecordHeader& header,
+bool TlsRecordFilter::Protect(TlsCipherSpec& protection_spec,
+ const TlsRecordHeader& header,
uint8_t inner_content_type,
const DataBuffer& plaintext,
DataBuffer* ciphertext, size_t padding) {
- if (!cipher_spec_ || header.content_type() != ssl_ct_application_data) {
+ if (!protection_spec.is_protected()) {
+ // Not protected, just keep the sequence numbers updated.
+ protection_spec.RecordProtected();
*ciphertext = plaintext;
return true;
}
- if (g_ssl_gtest_verbose) {
- std::cerr << "protect: " << header.sequence_number() << std::endl;
- }
+
DataBuffer padded;
padded.Allocate(plaintext.len() + 1 + padding);
size_t offset = padded.Write(0, plaintext.data(), plaintext.len());
padded.Write(offset, inner_content_type, 1);
- return cipher_spec_->Protect(header, padded, ciphertext);
+
+ bool ok = protection_spec.Protect(header, padded, ciphertext);
+ if (!ok) {
+ ADD_FAILURE() << "protect fail";
+ } else if (g_ssl_gtest_verbose) {
+ std::cerr << agent()->role_str()
+ << ": protect: epoch=" << protection_spec.epoch()
+ << " seq=" << std::hex << header.sequence_number() << std::dec
+ << " " << *ciphertext << std::endl;
+ }
+ return ok;
}
bool IsHelloRetry(const DataBuffer& body) {