diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_filter.cc | 321 |
1 files changed, 250 insertions, 71 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc index 76d9aaaff..89f201295 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.cc +++ b/security/nss/gtests/ssl_gtest/tls_filter.cc @@ -12,6 +12,7 @@ extern "C" { #include "libssl_internals.h" } +#include <cassert> #include <iostream> #include "gtest_utils.h" #include "tls_agent.h" @@ -57,17 +58,22 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending, PRBool isServer = self->agent()->role() == TlsAgent::SERVER; if (g_ssl_gtest_verbose) { - std::cerr << "Cipher spec changed. Role=" - << (isServer ? "server" : "client") - << " direction=" << (sending ? "send" : "receive") << std::endl; + std::cerr << (isServer ? "server" : "client") << ": " + << (sending ? "send" : "receive") + << " cipher spec changed: " << newSpec->epoch << " (" + << newSpec->phase << ")" << std::endl; + } + if (!sending) { + return; } - if (!sending) return; + self->in_sequence_number_ = 0; + self->out_sequence_number_ = 0; + self->dropped_record_ = false; self->cipher_spec_.reset(new TlsCipherSpec()); - bool ret = - self->cipher_spec_->Init(SSLInt_CipherSpecToAlgorithm(isServer, newSpec), - SSLInt_CipherSpecToKey(isServer, newSpec), - SSLInt_CipherSpecToIv(isServer, newSpec)); + bool ret = self->cipher_spec_->Init( + SSLInt_CipherSpecToEpoch(newSpec), SSLInt_CipherSpecToAlgorithm(newSpec), + SSLInt_CipherSpecToKey(newSpec), SSLInt_CipherSpecToIv(newSpec)); EXPECT_EQ(true, ret); } @@ -83,11 +89,23 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, TlsRecordHeader header; DataBuffer record; - if (!header.Parse(&parser, &record)) { + if (!header.Parse(in_sequence_number_, &parser, &record)) { ADD_FAILURE() << "not a valid record"; return KEEP; } + // Track the sequence number, which is necessary for stream mode (the + // sequence number is in the header for datagram). + // + // This isn't perfectly robust. If there is a change from an active cipher + // spec to another active cipher spec (KeyUpdate for instance) AND writes + // are consolidated across that change AND packets were dropped from the + // older epoch, we will not correctly re-encrypt records in the old epoch to + // update their sequence numbers. + if (cipher_spec_ && header.content_type() == kTlsApplicationDataType) { + ++in_sequence_number_; + } + if (FilterRecord(header, record, &offset, output) != KEEP) { changed = true; } else { @@ -120,30 +138,49 @@ PacketFilter::Action TlsRecordFilter::FilterRecord( header.sequence_number()}; PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered); + // In stream mode, even if something doesn't change we need to re-encrypt if + // previous packets were dropped. if (action == KEEP) { - return KEEP; + if (header.is_dtls() || !dropped_record_) { + return KEEP; + } + filtered = plaintext; } if (action == DROP) { - std::cerr << "record drop: " << record << std::endl; + std::cerr << "record drop: " << header << ":" << record << std::endl; + dropped_record_ = true; return DROP; } EXPECT_GT(0x10000U, filtered.len()); - std::cerr << "record old: " << plaintext << std::endl; - std::cerr << "record new: " << filtered << std::endl; + if (action != KEEP) { + std::cerr << "record old: " << plaintext << std::endl; + std::cerr << "record new: " << filtered << std::endl; + } + + uint64_t seq_num; + if (header.is_dtls() || !cipher_spec_ || + header.content_type() != kTlsApplicationDataType) { + seq_num = header.sequence_number(); + } else { + seq_num = out_sequence_number_++; + } + TlsRecordHeader out_header = {header.version(), header.content_type(), + seq_num}; DataBuffer ciphertext; - bool rv = Protect(header, inner_content_type, filtered, &ciphertext); + bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext); EXPECT_TRUE(rv); if (!rv) { return KEEP; } - *offset = header.Write(output, *offset, ciphertext); + *offset = out_header.Write(output, *offset, ciphertext); return CHANGE; } -bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) { +bool TlsRecordHeader::Parse(uint64_t sequence_number, TlsParser* parser, + DataBuffer* body) { if (!parser->Read(&content_type_)) { return false; } @@ -154,7 +191,7 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) { } version_ = version; - sequence_number_ = 0; + // If this is DTLS, overwrite the sequence number. if (IsDtls(version)) { uint32_t tmp; if (!parser->Read(&tmp, 4)) { @@ -165,6 +202,8 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) { return false; } sequence_number_ |= static_cast<uint64_t>(tmp); + } else { + sequence_number_ = sequence_number; } return parser->ReadVariable(body, 2); } @@ -193,7 +232,9 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, return true; } - if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) return false; + if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) { + return false; + } size_t len = plaintext->len(); while (len > 0 && !plaintext->data()[len - 1]) { @@ -206,6 +247,11 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, *inner_content_type = plaintext->data()[len - 1]; plaintext->Truncate(len - 1); + if (g_ssl_gtest_verbose) { + std::cerr << "unprotect: " << std::hex << header.sequence_number() + << std::dec << " type=" << static_cast<int>(*inner_content_type) + << " " << *plaintext << std::endl; + } return true; } @@ -218,16 +264,44 @@ bool TlsRecordFilter::Protect(const TlsRecordHeader& header, *ciphertext = plaintext; return true; } + if (g_ssl_gtest_verbose) { + std::cerr << "protect: " << header.sequence_number() << std::endl; + } DataBuffer padded = plaintext; padded.Write(padded.len(), inner_content_type, 1); return cipher_spec_->Protect(header, padded, ciphertext); } +bool IsHelloRetry(const DataBuffer& body) { + static const uint8_t ssl_hello_retry_random[] = { + 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C, + 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, + 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C}; + return memcmp(body.data() + 2, ssl_hello_retry_random, + sizeof(ssl_hello_retry_random)) == 0; +} + +bool TlsHandshakeFilter::IsFilteredType(const HandshakeHeader& header, + const DataBuffer& body) { + if (handshake_types_.empty()) { + return true; + } + + uint8_t type = header.handshake_type(); + if (type == kTlsHandshakeServerHello) { + if (IsHelloRetry(body)) { + type = kTlsHandshakeHelloRetryRequest; + } + } + return handshake_types_.count(type) > 0U; +} + PacketFilter::Action TlsHandshakeFilter::FilterRecord( const TlsRecordHeader& record_header, const DataBuffer& input, DataBuffer* output) { // Check that the first byte is as requested. - if (record_header.content_type() != kTlsHandshakeType) { + if ((record_header.content_type() != kTlsHandshakeType) && + (record_header.content_type() != kTlsAltHandshakeType)) { return KEEP; } @@ -239,12 +313,29 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord( while (parser.remaining()) { HandshakeHeader header; DataBuffer handshake; - if (!header.Parse(&parser, record_header, &handshake)) { + bool complete = false; + if (!header.Parse(&parser, record_header, preceding_fragment_, &handshake, + &complete)) { return KEEP; } + if (!complete) { + EXPECT_TRUE(record_header.is_dtls()); + // Save the fragment and drop it from this record. Fragments are + // coalesced with the last fragment of the handshake message. + changed = true; + preceding_fragment_.Assign(handshake); + continue; + } + preceding_fragment_.Truncate(0); + DataBuffer filtered; - PacketFilter::Action action = FilterHandshake(header, handshake, &filtered); + PacketFilter::Action action; + if (!IsFilteredType(header, handshake)) { + action = KEEP; + } else { + action = FilterHandshake(header, handshake, &filtered); + } if (action == DROP) { changed = true; std::cerr << "handshake drop: " << handshake << std::endl; @@ -258,6 +349,8 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord( std::cerr << "handshake old: " << handshake << std::endl; std::cerr << "handshake new: " << filtered << std::endl; source = &filtered; + } else if (preceding_fragment_.len()) { + changed = true; } offset = header.Write(output, offset, *source); @@ -267,12 +360,16 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord( } bool TlsHandshakeFilter::HandshakeHeader::ReadLength( - TlsParser* parser, const TlsRecordHeader& header, uint32_t* length) { - if (!parser->Read(length, 3)) { + TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset, + uint32_t* length, bool* last_fragment) { + uint32_t message_length; + if (!parser->Read(&message_length, 3)) { return false; // malformed } if (!header.is_dtls()) { + *last_fragment = true; + *length = message_length; return true; // nothing left to do } @@ -283,32 +380,50 @@ bool TlsHandshakeFilter::HandshakeHeader::ReadLength( } message_seq_ = message_seq_tmp; - uint32_t fragment_offset; - if (!parser->Read(&fragment_offset, 3)) { + uint32_t offset = 0; + if (!parser->Read(&offset, 3)) { + return false; + } + // We only parse if the fragments are all complete and in order. + if (offset != expected_offset) { + EXPECT_NE(0U, header.epoch()) + << "Received out of order handshake fragment for epoch 0"; return false; } - uint32_t fragment_length; - if (!parser->Read(&fragment_length, 3)) { + // For DTLS, we return the length of just this fragment. + if (!parser->Read(length, 3)) { return false; } - // All current tests where we are using this code don't fragment. - return (fragment_offset == 0 && fragment_length == *length); + // It's a fragment if the entire message is longer than what we have. + *last_fragment = message_length == (*length + offset); + return true; } bool TlsHandshakeFilter::HandshakeHeader::Parse( - TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body) { + TlsParser* parser, const TlsRecordHeader& record_header, + const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) { + *complete = false; + version_ = record_header.version(); if (!parser->Read(&handshake_type_)) { return false; // malformed } + uint32_t length; - if (!ReadLength(parser, record_header, &length)) { + if (!ReadLength(parser, record_header, preceding_fragment.len(), &length, + complete)) { return false; } - return parser->Read(body, length); + if (!parser->Read(body, length)) { + return false; + } + if (preceding_fragment.len()) { + body->Splice(preceding_fragment, 0); + } + return true; } size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment( @@ -345,20 +460,23 @@ PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake( return KEEP; } - if (header.handshake_type() == handshake_type_) { - buffer_ = input; - } + buffer_ = input; return KEEP; } PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() == handshake_type_) { - *output = buffer_; - return CHANGE; - } + *output = buffer_; + return CHANGE; +} +PacketFilter::Action TlsRecordRecorder::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { + if (!filter_ || (header.content_type() == ct_)) { + records_.push_back({header, input}); + } return KEEP; } @@ -369,15 +487,30 @@ PacketFilter::Action TlsConversationRecorder::FilterRecord( return KEEP; } +PacketFilter::Action TlsHeaderRecorder::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { + headers_.push_back(header); + return KEEP; +} + +const TlsRecordHeader* TlsHeaderRecorder::header(size_t index) { + if (index > headers_.size() + 1) { + return nullptr; + } + return &headers_[index]; +} + PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input, DataBuffer* output) { DataBuffer in(input); bool changed = false; for (auto it = filters_.begin(); it != filters_.end(); ++it) { - PacketFilter::Action action = (*it)->Filter(in, output); + PacketFilter::Action action = (*it)->Process(in, output); if (action == DROP) { return DROP; } + if (action == CHANGE) { in = *output; changed = true; @@ -430,15 +563,6 @@ bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) { return true; } -static bool FindHelloRetryExtensions(TlsParser* parser, - const TlsVersioned& header) { - // TODO for -19 add cipher suite - if (!parser->Skip(2)) { // version - return false; - } - return true; -} - bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) { return true; } @@ -448,13 +572,6 @@ static bool FindCertReqExtensions(TlsParser* parser, if (!parser->SkipVariable(1)) { // request context return false; } - // TODO remove the next two for -19 - if (!parser->SkipVariable(2)) { // signature_algorithms - return false; - } - if (!parser->SkipVariable(2)) { // certificate_authorities - return false; - } return true; } @@ -478,6 +595,9 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser, if (!parser->Skip(8)) { // lifetime, age add return false; } + if (!parser->SkipVariable(1)) { // ticket_nonce + return false; + } if (!parser->SkipVariable(2)) { // ticket return false; } @@ -487,7 +607,6 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser, static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = { {kTlsHandshakeClientHello, FindClientHelloExtensions}, {kTlsHandshakeServerHello, FindServerHelloExtensions}, - {kTlsHandshakeHelloRetryRequest, FindHelloRetryExtensions}, {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions}, {kTlsHandshakeCertificateRequest, FindCertReqExtensions}, {kTlsHandshakeCertificate, FindCertificateExtensions}, @@ -505,10 +624,6 @@ bool TlsExtensionFilter::FindExtensions(TlsParser* parser, PacketFilter::Action TlsExtensionFilter::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (handshake_types_.count(header.handshake_type()) == 0) { - return KEEP; - } - TlsParser parser(input); if (!FindExtensions(&parser, header)) { return KEEP; @@ -610,6 +725,38 @@ PacketFilter::Action TlsExtensionDropper::FilterExtension( return KEEP; } +PacketFilter::Action TlsExtensionInjector::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + TlsParser parser(input); + if (!TlsExtensionFilter::FindExtensions(&parser, header)) { + return KEEP; + } + size_t offset = parser.consumed(); + + *output = input; + + // Increase the size of the extensions. + uint16_t ext_len; + memcpy(&ext_len, output->data() + offset, sizeof(ext_len)); + ext_len = htons(ntohs(ext_len) + data_.len() + 4); + memcpy(output->data() + offset, &ext_len, sizeof(ext_len)); + + // Insert the extension type and length. + DataBuffer type_length; + type_length.Allocate(4); + type_length.Write(0, extension_, 2); + type_length.Write(2, data_.len(), 2); + output->Splice(type_length, offset + 2); + + // Insert the payload. + if (data_.len() > 0) { + output->Splice(data_, offset + 6); + } + + return CHANGE; +} + PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) { @@ -628,10 +775,8 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() == kTlsHandshakeClientKeyExchange) { - EXPECT_EQ(SECSuccess, - SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd())); - } + EXPECT_EQ(SECSuccess, + SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd())); return KEEP; } @@ -643,15 +788,49 @@ PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input, return ((1 << counter_++) & pattern_) ? DROP : KEEP; } +PacketFilter::Action SelectiveRecordDropFilter::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& data, + DataBuffer* changed) { + if (counter_ >= 32) { + return KEEP; + } + return ((1 << counter_++) & pattern_) ? DROP : KEEP; +} + +/* static */ uint32_t SelectiveRecordDropFilter::ToPattern( + std::initializer_list<size_t> records) { + uint32_t pattern = 0; + for (auto it = records.begin(); it != records.end(); ++it) { + EXPECT_GT(32U, *it); + assert(*it < 32U); + pattern |= 1 << *it; + } + return pattern; +} + PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake( const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) { - if (header.handshake_type() == kTlsHandshakeClientHello) { - *output = input; - output->Write(0, version_, 2); - return CHANGE; - } - return KEEP; + *output = input; + output->Write(0, version_, 2); + return CHANGE; +} + +PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + *output = input; + uint32_t temp = 0; + EXPECT_TRUE(input.Read(0, 2, &temp)); + // Cipher suite is after version(2) and random(32). + size_t pos = 34; + if (temp < SSL_LIBRARY_VERSION_TLS_1_3) { + // In old versions, we have to skip a session_id too. + EXPECT_TRUE(input.Read(pos, 1, &temp)); + pos += 1 + temp; + } + output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2); + return CHANGE; } } // namespace nss_test |