summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_filter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.cc234
1 files changed, 194 insertions, 40 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc
index d34b13bcb..aa03cba70 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.cc
+++ b/security/nss/gtests/ssl_gtest/tls_filter.cc
@@ -30,11 +30,9 @@ void TlsVersioned::WriteStream(std::ostream& stream) const {
case SSL_LIBRARY_VERSION_TLS_1_0:
stream << "1.0";
break;
- case SSL_LIBRARY_VERSION_DTLS_1_0_WIRE:
case SSL_LIBRARY_VERSION_TLS_1_1:
stream << (is_dtls() ? "1.0" : "1.1");
break;
- case SSL_LIBRARY_VERSION_DTLS_1_2_WIRE:
case SSL_LIBRARY_VERSION_TLS_1_2:
stream << "1.2";
break;
@@ -67,8 +65,14 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending,
return;
}
- self->in_sequence_number_ = 0;
- self->out_sequence_number_ = 0;
+ uint64_t seq_no;
+ if (self->agent()->variant() == ssl_variant_datagram) {
+ seq_no = static_cast<uint64_t>(SSLInt_CipherSpecToEpoch(newSpec)) << 48;
+ } else {
+ seq_no = 0;
+ }
+ self->in_sequence_number_ = seq_no;
+ self->out_sequence_number_ = seq_no;
self->dropped_record_ = false;
self->cipher_spec_.reset(new TlsCipherSpec());
bool ret = self->cipher_spec_->Init(
@@ -77,33 +81,59 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending,
EXPECT_EQ(true, ret);
}
+bool TlsRecordFilter::is_dtls13() const {
+ if (agent()->variant() != ssl_variant_datagram) {
+ return false;
+ }
+ if (agent()->state() == TlsAgent::STATE_CONNECTED) {
+ return agent()->version() >= SSL_LIBRARY_VERSION_TLS_1_3;
+ }
+ SSLPreliminaryChannelInfo info;
+ EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(agent()->ssl_fd(), &info,
+ sizeof(info)));
+ return (info.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) ||
+ info.canSendEarlyData;
+}
+
PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
+ // Disable during shutdown.
+ if (!agent()) {
+ return KEEP;
+ }
+
bool changed = false;
size_t offset = 0U;
- output->Allocate(input.len());
+ output->Allocate(input.len());
TlsParser parser(input);
while (parser.remaining()) {
TlsRecordHeader header;
DataBuffer record;
- if (!header.Parse(in_sequence_number_, &parser, &record)) {
+ if (!header.Parse(is_dtls13(), in_sequence_number_, &parser, &record)) {
ADD_FAILURE() << "not a valid record";
return KEEP;
}
- // Track the sequence number, which is necessary for stream mode (the
- // sequence number is in the header for datagram).
+ // Track the sequence number, which is necessary for stream mode when
+ // decrypting and for TLS 1.3 datagram to recover the sequence number.
+ //
+ // We reset the counter when the cipher spec changes, but that notification
+ // appears before a record is sent. If multiple records are sent with
+ // different cipher specs, this would fail. This filters out cleartext
+ // records, so we don't get confused by handshake messages that are sent at
+ // the same time as encrypted records. Sequence numbers are therefore
+ // likely to be incorrect for cleartext records.
//
- // This isn't perfectly robust. If there is a change from an active cipher
+ // This isn't perfectly robust: if there is a change from an active cipher
// spec to another active cipher spec (KeyUpdate for instance) AND writes
- // are consolidated across that change AND packets were dropped from the
- // older epoch, we will not correctly re-encrypt records in the old epoch to
- // update their sequence numbers.
- if (cipher_spec_ && header.content_type() == kTlsApplicationDataType) {
- ++in_sequence_number_;
+ // are consolidated across that change, this code could use the wrong
+ // sequence numbers when re-encrypting records with the old keys.
+ if (header.content_type() == kTlsApplicationDataType) {
+ in_sequence_number_ =
+ (std::max)(in_sequence_number_, header.sequence_number() + 1);
}
if (FilterRecord(header, record, &offset, output) != KEEP) {
@@ -131,11 +161,14 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
DataBuffer plaintext;
if (!Unprotect(header, record, &inner_content_type, &plaintext)) {
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "unprotect failed: " << header << ":" << record << std::endl;
+ }
return KEEP;
}
- TlsRecordHeader real_header = {header.version(), inner_content_type,
- header.sequence_number()};
+ TlsRecordHeader real_header(header.variant(), header.version(),
+ inner_content_type, header.sequence_number());
PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
// In stream mode, even if something doesn't change we need to re-encrypt if
@@ -166,8 +199,8 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
} else {
seq_num = out_sequence_number_++;
}
- TlsRecordHeader out_header = {header.version(), header.content_type(),
- seq_num};
+ TlsRecordHeader out_header(header.variant(), header.version(),
+ header.content_type(), seq_num);
DataBuffer ciphertext;
bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext);
@@ -179,20 +212,119 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
return CHANGE;
}
-bool TlsRecordHeader::Parse(uint64_t sequence_number, TlsParser* parser,
+size_t TlsRecordHeader::header_length() const {
+ // If we have a header, return it's length.
+ if (header_.len()) {
+ return header_.len();
+ }
+
+ // Otherwise make a dummy header and return the length.
+ DataBuffer buf;
+ return WriteHeader(&buf, 0, 0);
+}
+
+uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t expected,
+ uint32_t partial,
+ size_t partial_bits) {
+ EXPECT_GE(32U, partial_bits);
+ uint64_t mask = (1 << partial_bits) - 1;
+ // First we determine the highest possible value. This is half the
+ // expressible range above the expected value.
+ uint64_t cap = expected + (1ULL << (partial_bits - 1));
+ // Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234.
+ uint64_t seq_no = (cap & ~mask) | partial;
+ // If the partial value is higher than the same partial piece from the cap,
+ // then the real value has to be lower. e.g., xxxx1234 can't become xxxx5678.
+ if (partial > (cap & mask)) {
+ seq_no -= 1ULL << partial_bits;
+ }
+ return seq_no;
+}
+
+// Determine the full epoch and sequence number from an expected and raw value.
+// The expected and output values are packed as they are in DTLS 1.2 and
+// earlier: with 16 bits of epoch and 48 bits of sequence number.
+uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint32_t raw,
+ size_t seq_no_bits,
+ size_t epoch_bits) {
+ uint64_t epoch_mask = (1ULL << epoch_bits) - 1;
+ uint64_t epoch = RecoverSequenceNumber(
+ expected >> 48, (raw >> seq_no_bits) & epoch_mask, epoch_bits);
+ if (epoch > (expected >> 48)) {
+ // If the epoch has changed, reset the expected sequence number.
+ expected = 0;
+ } else {
+ // Otherwise, retain just the sequence number part.
+ expected &= (1ULL << 48) - 1;
+ }
+ uint64_t seq_no_mask = (1ULL << seq_no_bits) - 1;
+ uint64_t seq_no =
+ RecoverSequenceNumber(expected, raw & seq_no_mask, seq_no_bits);
+ return (epoch << 48) | seq_no;
+}
+
+bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser,
DataBuffer* body) {
+ auto mark = parser->consumed();
+
if (!parser->Read(&content_type_)) {
return false;
}
- uint32_t version;
- if (!parser->Read(&version, 2)) {
+ if (is_dtls13) {
+ variant_ = ssl_variant_datagram;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_3;
+
+#ifndef UNSAFE_FUZZER_MODE
+ // Deal with the 7 octet header.
+ if (content_type_ == kTlsApplicationDataType) {
+ uint32_t tmp;
+ if (!parser->Read(&tmp, 4)) {
+ return false;
+ }
+ sequence_number_ = ParseSequenceNumber(seqno, tmp, 30, 2);
+ if (!parser->ReadFromMark(&header_, parser->consumed() + 2 - mark,
+ mark)) {
+ return false;
+ }
+ return parser->ReadVariable(body, 2);
+ }
+
+ // The short, 2 octet header.
+ if ((content_type_ & 0xe0) == 0x20) {
+ uint32_t tmp;
+ if (!parser->Read(&tmp, 1)) {
+ return false;
+ }
+ // Need to use the low 5 bits of the first octet too.
+ tmp |= (content_type_ & 0x1f) << 8;
+ content_type_ = kTlsApplicationDataType;
+ sequence_number_ = ParseSequenceNumber(seqno, tmp, 12, 1);
+
+ if (!parser->ReadFromMark(&header_, parser->consumed() - mark, mark)) {
+ return false;
+ }
+ return parser->Read(body, parser->remaining());
+ }
+
+ // The full 13 octet header can only be used for a few types.
+ EXPECT_TRUE(content_type_ == kTlsAlertType ||
+ content_type_ == kTlsHandshakeType ||
+ content_type_ == kTlsAckType);
+#endif
+ }
+
+ uint32_t ver;
+ if (!parser->Read(&ver, 2)) {
return false;
}
- version_ = version;
+ if (!is_dtls13) {
+ variant_ = IsDtls(ver) ? ssl_variant_datagram : ssl_variant_stream;
+ }
+ version_ = NormalizeTlsVersion(ver);
- // If this is DTLS, overwrite the sequence number.
- if (IsDtls(version)) {
+ if (is_dtls()) {
+ // If this is DTLS, read the sequence number.
uint32_t tmp;
if (!parser->Read(&tmp, 4)) {
return false;
@@ -203,21 +335,40 @@ bool TlsRecordHeader::Parse(uint64_t sequence_number, TlsParser* parser,
}
sequence_number_ |= static_cast<uint64_t>(tmp);
} else {
- sequence_number_ = sequence_number;
+ sequence_number_ = seqno;
+ }
+ if (!parser->ReadFromMark(&header_, parser->consumed() + 2 - mark, mark)) {
+ return false;
}
return parser->ReadVariable(body, 2);
}
-size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset,
- const DataBuffer& body) const {
+size_t TlsRecordHeader::WriteHeader(DataBuffer* buffer, size_t offset,
+ size_t body_len) const {
offset = buffer->Write(offset, content_type_, 1);
- offset = buffer->Write(offset, version_, 2);
- if (is_dtls()) {
- // write epoch (2 octet), and seqnum (6 octet)
- offset = buffer->Write(offset, sequence_number_ >> 32, 4);
- offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4);
+ if (is_dtls() && version_ >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ content_type() == kTlsApplicationDataType) {
+ // application_data records in TLS 1.3 have a different header format.
+ // Always use the long header here for simplicity.
+ uint32_t e = (sequence_number_ >> 48) & 0x3;
+ uint32_t seqno = sequence_number_ & ((1ULL << 30) - 1);
+ offset = buffer->Write(offset, (e << 30) | seqno, 4);
+ } else {
+ uint16_t v = is_dtls() ? TlsVersionToDtlsVersion(version_) : version_;
+ offset = buffer->Write(offset, v, 2);
+ if (is_dtls()) {
+ // write epoch (2 octet), and seqnum (6 octet)
+ offset = buffer->Write(offset, sequence_number_ >> 32, 4);
+ offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4);
+ }
}
- offset = buffer->Write(offset, body.len(), 2);
+ offset = buffer->Write(offset, body_len, 2);
+ return offset;
+}
+
+size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body) const {
+ offset = WriteHeader(buffer, offset, body.len());
offset = buffer->Write(offset, body);
return offset;
}
@@ -259,7 +410,7 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
bool TlsRecordFilter::Protect(const TlsRecordHeader& header,
uint8_t inner_content_type,
const DataBuffer& plaintext,
- DataBuffer* ciphertext) {
+ DataBuffer* ciphertext, size_t padding) {
if (!cipher_spec_ || header.content_type() != kTlsApplicationDataType) {
*ciphertext = plaintext;
return true;
@@ -267,8 +418,10 @@ bool TlsRecordFilter::Protect(const TlsRecordHeader& header,
if (g_ssl_gtest_verbose) {
std::cerr << "protect: " << header.sequence_number() << std::endl;
}
- DataBuffer padded = plaintext;
- padded.Write(padded.len(), inner_content_type, 1);
+ DataBuffer padded;
+ padded.Allocate(plaintext.len() + 1 + padding);
+ size_t offset = padded.Write(0, plaintext.data(), plaintext.len());
+ padded.Write(offset, inner_content_type, 1);
return cipher_spec_->Protect(header, padded, ciphertext);
}
@@ -406,6 +559,7 @@ bool TlsHandshakeFilter::HandshakeHeader::Parse(
const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) {
*complete = false;
+ variant_ = record_header.variant();
version_ = record_header.version();
if (!parser->Read(&handshake_type_)) {
return false; // malformed
@@ -487,10 +641,10 @@ PacketFilter::Action TlsConversationRecorder::FilterRecord(
return KEEP;
}
-PacketFilter::Action TlsHeaderRecorder::FilterRecord(
- const TlsRecordHeader& header, const DataBuffer& input,
- DataBuffer* output) {
- headers_.push_back(header);
+PacketFilter::Action TlsHeaderRecorder::FilterRecord(const TlsRecordHeader& hdr,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ headers_.push_back(hdr);
return KEEP;
}