diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_filter.cc | 234 |
1 files changed, 194 insertions, 40 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc index d34b13bcb..aa03cba70 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.cc +++ b/security/nss/gtests/ssl_gtest/tls_filter.cc @@ -30,11 +30,9 @@ void TlsVersioned::WriteStream(std::ostream& stream) const { case SSL_LIBRARY_VERSION_TLS_1_0: stream << "1.0"; break; - case SSL_LIBRARY_VERSION_DTLS_1_0_WIRE: case SSL_LIBRARY_VERSION_TLS_1_1: stream << (is_dtls() ? "1.0" : "1.1"); break; - case SSL_LIBRARY_VERSION_DTLS_1_2_WIRE: case SSL_LIBRARY_VERSION_TLS_1_2: stream << "1.2"; break; @@ -67,8 +65,14 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending, return; } - self->in_sequence_number_ = 0; - self->out_sequence_number_ = 0; + uint64_t seq_no; + if (self->agent()->variant() == ssl_variant_datagram) { + seq_no = static_cast<uint64_t>(SSLInt_CipherSpecToEpoch(newSpec)) << 48; + } else { + seq_no = 0; + } + self->in_sequence_number_ = seq_no; + self->out_sequence_number_ = seq_no; self->dropped_record_ = false; self->cipher_spec_.reset(new TlsCipherSpec()); bool ret = self->cipher_spec_->Init( @@ -77,33 +81,59 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending, EXPECT_EQ(true, ret); } +bool TlsRecordFilter::is_dtls13() const { + if (agent()->variant() != ssl_variant_datagram) { + return false; + } + if (agent()->state() == TlsAgent::STATE_CONNECTED) { + return agent()->version() >= SSL_LIBRARY_VERSION_TLS_1_3; + } + SSLPreliminaryChannelInfo info; + EXPECT_EQ(SECSuccess, SSL_GetPreliminaryChannelInfo(agent()->ssl_fd(), &info, + sizeof(info))); + return (info.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) || + info.canSendEarlyData; +} + PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) { + // Disable during shutdown. + if (!agent()) { + return KEEP; + } + bool changed = false; size_t offset = 0U; - output->Allocate(input.len()); + output->Allocate(input.len()); TlsParser parser(input); while (parser.remaining()) { TlsRecordHeader header; DataBuffer record; - if (!header.Parse(in_sequence_number_, &parser, &record)) { + if (!header.Parse(is_dtls13(), in_sequence_number_, &parser, &record)) { ADD_FAILURE() << "not a valid record"; return KEEP; } - // Track the sequence number, which is necessary for stream mode (the - // sequence number is in the header for datagram). + // Track the sequence number, which is necessary for stream mode when + // decrypting and for TLS 1.3 datagram to recover the sequence number. + // + // We reset the counter when the cipher spec changes, but that notification + // appears before a record is sent. If multiple records are sent with + // different cipher specs, this would fail. This filters out cleartext + // records, so we don't get confused by handshake messages that are sent at + // the same time as encrypted records. Sequence numbers are therefore + // likely to be incorrect for cleartext records. // - // This isn't perfectly robust. If there is a change from an active cipher + // This isn't perfectly robust: if there is a change from an active cipher // spec to another active cipher spec (KeyUpdate for instance) AND writes - // are consolidated across that change AND packets were dropped from the - // older epoch, we will not correctly re-encrypt records in the old epoch to - // update their sequence numbers. - if (cipher_spec_ && header.content_type() == kTlsApplicationDataType) { - ++in_sequence_number_; + // are consolidated across that change, this code could use the wrong + // sequence numbers when re-encrypting records with the old keys. + if (header.content_type() == kTlsApplicationDataType) { + in_sequence_number_ = + (std::max)(in_sequence_number_, header.sequence_number() + 1); } if (FilterRecord(header, record, &offset, output) != KEEP) { @@ -131,11 +161,14 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( DataBuffer plaintext; if (!Unprotect(header, record, &inner_content_type, &plaintext)) { + if (g_ssl_gtest_verbose) { + std::cerr << "unprotect failed: " << header << ":" << record << std::endl; + } return KEEP; } - TlsRecordHeader real_header = {header.version(), inner_content_type, - header.sequence_number()}; + TlsRecordHeader real_header(header.variant(), header.version(), + inner_content_type, header.sequence_number()); PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered); // In stream mode, even if something doesn't change we need to re-encrypt if @@ -166,8 +199,8 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( } else { seq_num = out_sequence_number_++; } - TlsRecordHeader out_header = {header.version(), header.content_type(), - seq_num}; + TlsRecordHeader out_header(header.variant(), header.version(), + header.content_type(), seq_num); DataBuffer ciphertext; bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext); @@ -179,20 +212,119 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( return CHANGE; } -bool TlsRecordHeader::Parse(uint64_t sequence_number, TlsParser* parser, +size_t TlsRecordHeader::header_length() const { + // If we have a header, return it's length. + if (header_.len()) { + return header_.len(); + } + + // Otherwise make a dummy header and return the length. + DataBuffer buf; + return WriteHeader(&buf, 0, 0); +} + +uint64_t TlsRecordHeader::RecoverSequenceNumber(uint64_t expected, + uint32_t partial, + size_t partial_bits) { + EXPECT_GE(32U, partial_bits); + uint64_t mask = (1 << partial_bits) - 1; + // First we determine the highest possible value. This is half the + // expressible range above the expected value. + uint64_t cap = expected + (1ULL << (partial_bits - 1)); + // Add the partial piece in. e.g., xxxx789a and 1234 becomes xxxx1234. + uint64_t seq_no = (cap & ~mask) | partial; + // If the partial value is higher than the same partial piece from the cap, + // then the real value has to be lower. e.g., xxxx1234 can't become xxxx5678. + if (partial > (cap & mask)) { + seq_no -= 1ULL << partial_bits; + } + return seq_no; +} + +// Determine the full epoch and sequence number from an expected and raw value. +// The expected and output values are packed as they are in DTLS 1.2 and +// earlier: with 16 bits of epoch and 48 bits of sequence number. +uint64_t TlsRecordHeader::ParseSequenceNumber(uint64_t expected, uint32_t raw, + size_t seq_no_bits, + size_t epoch_bits) { + uint64_t epoch_mask = (1ULL << epoch_bits) - 1; + uint64_t epoch = RecoverSequenceNumber( + expected >> 48, (raw >> seq_no_bits) & epoch_mask, epoch_bits); + if (epoch > (expected >> 48)) { + // If the epoch has changed, reset the expected sequence number. + expected = 0; + } else { + // Otherwise, retain just the sequence number part. + expected &= (1ULL << 48) - 1; + } + uint64_t seq_no_mask = (1ULL << seq_no_bits) - 1; + uint64_t seq_no = + RecoverSequenceNumber(expected, raw & seq_no_mask, seq_no_bits); + return (epoch << 48) | seq_no; +} + +bool TlsRecordHeader::Parse(bool is_dtls13, uint64_t seqno, TlsParser* parser, DataBuffer* body) { + auto mark = parser->consumed(); + if (!parser->Read(&content_type_)) { return false; } - uint32_t version; - if (!parser->Read(&version, 2)) { + if (is_dtls13) { + variant_ = ssl_variant_datagram; + version_ = SSL_LIBRARY_VERSION_TLS_1_3; + +#ifndef UNSAFE_FUZZER_MODE + // Deal with the 7 octet header. + if (content_type_ == kTlsApplicationDataType) { + uint32_t tmp; + if (!parser->Read(&tmp, 4)) { + return false; + } + sequence_number_ = ParseSequenceNumber(seqno, tmp, 30, 2); + if (!parser->ReadFromMark(&header_, parser->consumed() + 2 - mark, + mark)) { + return false; + } + return parser->ReadVariable(body, 2); + } + + // The short, 2 octet header. + if ((content_type_ & 0xe0) == 0x20) { + uint32_t tmp; + if (!parser->Read(&tmp, 1)) { + return false; + } + // Need to use the low 5 bits of the first octet too. + tmp |= (content_type_ & 0x1f) << 8; + content_type_ = kTlsApplicationDataType; + sequence_number_ = ParseSequenceNumber(seqno, tmp, 12, 1); + + if (!parser->ReadFromMark(&header_, parser->consumed() - mark, mark)) { + return false; + } + return parser->Read(body, parser->remaining()); + } + + // The full 13 octet header can only be used for a few types. + EXPECT_TRUE(content_type_ == kTlsAlertType || + content_type_ == kTlsHandshakeType || + content_type_ == kTlsAckType); +#endif + } + + uint32_t ver; + if (!parser->Read(&ver, 2)) { return false; } - version_ = version; + if (!is_dtls13) { + variant_ = IsDtls(ver) ? ssl_variant_datagram : ssl_variant_stream; + } + version_ = NormalizeTlsVersion(ver); - // If this is DTLS, overwrite the sequence number. - if (IsDtls(version)) { + if (is_dtls()) { + // If this is DTLS, read the sequence number. uint32_t tmp; if (!parser->Read(&tmp, 4)) { return false; @@ -203,21 +335,40 @@ bool TlsRecordHeader::Parse(uint64_t sequence_number, TlsParser* parser, } sequence_number_ |= static_cast<uint64_t>(tmp); } else { - sequence_number_ = sequence_number; + sequence_number_ = seqno; + } + if (!parser->ReadFromMark(&header_, parser->consumed() + 2 - mark, mark)) { + return false; } return parser->ReadVariable(body, 2); } -size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset, - const DataBuffer& body) const { +size_t TlsRecordHeader::WriteHeader(DataBuffer* buffer, size_t offset, + size_t body_len) const { offset = buffer->Write(offset, content_type_, 1); - offset = buffer->Write(offset, version_, 2); - if (is_dtls()) { - // write epoch (2 octet), and seqnum (6 octet) - offset = buffer->Write(offset, sequence_number_ >> 32, 4); - offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4); + if (is_dtls() && version_ >= SSL_LIBRARY_VERSION_TLS_1_3 && + content_type() == kTlsApplicationDataType) { + // application_data records in TLS 1.3 have a different header format. + // Always use the long header here for simplicity. + uint32_t e = (sequence_number_ >> 48) & 0x3; + uint32_t seqno = sequence_number_ & ((1ULL << 30) - 1); + offset = buffer->Write(offset, (e << 30) | seqno, 4); + } else { + uint16_t v = is_dtls() ? TlsVersionToDtlsVersion(version_) : version_; + offset = buffer->Write(offset, v, 2); + if (is_dtls()) { + // write epoch (2 octet), and seqnum (6 octet) + offset = buffer->Write(offset, sequence_number_ >> 32, 4); + offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4); + } } - offset = buffer->Write(offset, body.len(), 2); + offset = buffer->Write(offset, body_len, 2); + return offset; +} + +size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset, + const DataBuffer& body) const { + offset = WriteHeader(buffer, offset, body.len()); offset = buffer->Write(offset, body); return offset; } @@ -259,7 +410,7 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, bool TlsRecordFilter::Protect(const TlsRecordHeader& header, uint8_t inner_content_type, const DataBuffer& plaintext, - DataBuffer* ciphertext) { + DataBuffer* ciphertext, size_t padding) { if (!cipher_spec_ || header.content_type() != kTlsApplicationDataType) { *ciphertext = plaintext; return true; @@ -267,8 +418,10 @@ bool TlsRecordFilter::Protect(const TlsRecordHeader& header, if (g_ssl_gtest_verbose) { std::cerr << "protect: " << header.sequence_number() << std::endl; } - DataBuffer padded = plaintext; - padded.Write(padded.len(), inner_content_type, 1); + DataBuffer padded; + padded.Allocate(plaintext.len() + 1 + padding); + size_t offset = padded.Write(0, plaintext.data(), plaintext.len()); + padded.Write(offset, inner_content_type, 1); return cipher_spec_->Protect(header, padded, ciphertext); } @@ -406,6 +559,7 @@ bool TlsHandshakeFilter::HandshakeHeader::Parse( const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) { *complete = false; + variant_ = record_header.variant(); version_ = record_header.version(); if (!parser->Read(&handshake_type_)) { return false; // malformed @@ -487,10 +641,10 @@ PacketFilter::Action TlsConversationRecorder::FilterRecord( return KEEP; } -PacketFilter::Action TlsHeaderRecorder::FilterRecord( - const TlsRecordHeader& header, const DataBuffer& input, - DataBuffer* output) { - headers_.push_back(header); +PacketFilter::Action TlsHeaderRecorder::FilterRecord(const TlsRecordHeader& hdr, + const DataBuffer& input, + DataBuffer* output) { + headers_.push_back(hdr); return KEEP; } |