From f017b749ea9f1586d2308504553d40bf4cc5439d Mon Sep 17 00:00:00 2001 From: wolfbeast Date: Tue, 6 Feb 2018 11:46:26 +0100 Subject: Update NSS to 3.32.1-RTM --- security/nss/gtests/ssl_gtest/tls_filter.cc | 326 ++++++++++++++++++++-------- 1 file changed, 240 insertions(+), 86 deletions(-) (limited to 'security/nss/gtests/ssl_gtest/tls_filter.cc') diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc index 4f7d195d0..76d9aaaff 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.cc +++ b/security/nss/gtests/ssl_gtest/tls_filter.cc @@ -15,9 +15,62 @@ extern "C" { #include #include "gtest_utils.h" #include "tls_agent.h" +#include "tls_filter.h" +#include "tls_protect.h" namespace nss_test { +void TlsVersioned::WriteStream(std::ostream& stream) const { + stream << (is_dtls() ? "DTLS " : "TLS "); + switch (version()) { + case 0: + stream << "(no version)"; + break; + case SSL_LIBRARY_VERSION_TLS_1_0: + stream << "1.0"; + break; + case SSL_LIBRARY_VERSION_DTLS_1_0_WIRE: + case SSL_LIBRARY_VERSION_TLS_1_1: + stream << (is_dtls() ? "1.0" : "1.1"); + break; + case SSL_LIBRARY_VERSION_DTLS_1_2_WIRE: + case SSL_LIBRARY_VERSION_TLS_1_2: + stream << "1.2"; + break; + case SSL_LIBRARY_VERSION_TLS_1_3: + stream << "1.3"; + break; + default: + stream << "Invalid version: " << version(); + break; + } +} + +void TlsRecordFilter::EnableDecryption() { + SSLInt_SetCipherSpecChangeFunc(agent()->ssl_fd(), CipherSpecChanged, + (void*)this); +} + +void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending, + ssl3CipherSpec* newSpec) { + TlsRecordFilter* self = static_cast(arg); + PRBool isServer = self->agent()->role() == TlsAgent::SERVER; + + if (g_ssl_gtest_verbose) { + std::cerr << "Cipher spec changed. Role=" + << (isServer ? "server" : "client") + << " direction=" << (sending ? "send" : "receive") << std::endl; + } + if (!sending) return; + + self->cipher_spec_.reset(new TlsCipherSpec()); + bool ret = + self->cipher_spec_->Init(SSLInt_CipherSpecToAlgorithm(isServer, newSpec), + SSLInt_CipherSpecToKey(isServer, newSpec), + SSLInt_CipherSpecToIv(isServer, newSpec)); + EXPECT_EQ(true, ret); +} + PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) { bool changed = false; @@ -25,10 +78,13 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, output->Allocate(input.len()); TlsParser parser(input); + while (parser.remaining()) { - RecordHeader header; + TlsRecordHeader header; DataBuffer record; + if (!header.Parse(&parser, &record)) { + ADD_FAILURE() << "not a valid record"; return KEEP; } @@ -49,12 +105,21 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input, return KEEP; } -PacketFilter::Action TlsRecordFilter::FilterRecord(const RecordHeader& header, - const DataBuffer& record, - size_t* offset, - DataBuffer* output) { +PacketFilter::Action TlsRecordFilter::FilterRecord( + const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, + DataBuffer* output) { DataBuffer filtered; - PacketFilter::Action action = FilterRecord(header, record, &filtered); + uint8_t inner_content_type; + DataBuffer plaintext; + + if (!Unprotect(header, record, &inner_content_type, &plaintext)) { + return KEEP; + } + + TlsRecordHeader real_header = {header.version(), inner_content_type, + header.sequence_number()}; + + PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered); if (action == KEEP) { return KEEP; } @@ -64,19 +129,21 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(const RecordHeader& header, return DROP; } - const DataBuffer* source = &record; - if (action == CHANGE) { - EXPECT_GT(0x10000U, filtered.len()); - std::cerr << "record old: " << record << std::endl; - std::cerr << "record new: " << filtered << std::endl; - source = &filtered; - } + EXPECT_GT(0x10000U, filtered.len()); + std::cerr << "record old: " << plaintext << std::endl; + std::cerr << "record new: " << filtered << std::endl; - *offset = header.Write(output, *offset, *source); + DataBuffer ciphertext; + bool rv = Protect(header, inner_content_type, filtered, &ciphertext); + EXPECT_TRUE(rv); + if (!rv) { + return KEEP; + } + *offset = header.Write(output, *offset, ciphertext); return CHANGE; } -bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) { +bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) { if (!parser->Read(&content_type_)) { return false; } @@ -102,8 +169,8 @@ bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) { return parser->ReadVariable(body, 2); } -size_t TlsRecordFilter::RecordHeader::Write(DataBuffer* buffer, size_t offset, - const DataBuffer& body) const { +size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset, + const DataBuffer& body) const { offset = buffer->Write(offset, content_type_, 1); offset = buffer->Write(offset, version_, 2); if (is_dtls()) { @@ -116,8 +183,48 @@ size_t TlsRecordFilter::RecordHeader::Write(DataBuffer* buffer, size_t offset, return offset; } +bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header, + const DataBuffer& ciphertext, + uint8_t* inner_content_type, + DataBuffer* plaintext) { + if (!cipher_spec_ || header.content_type() != kTlsApplicationDataType) { + *inner_content_type = header.content_type(); + *plaintext = ciphertext; + return true; + } + + if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) return false; + + size_t len = plaintext->len(); + while (len > 0 && !plaintext->data()[len - 1]) { + --len; + } + if (!len) { + // Bogus padding. + return false; + } + + *inner_content_type = plaintext->data()[len - 1]; + plaintext->Truncate(len - 1); + + return true; +} + +bool TlsRecordFilter::Protect(const TlsRecordHeader& header, + uint8_t inner_content_type, + const DataBuffer& plaintext, + DataBuffer* ciphertext) { + if (!cipher_spec_ || header.content_type() != kTlsApplicationDataType) { + *ciphertext = plaintext; + return true; + } + DataBuffer padded = plaintext; + padded.Write(padded.len(), inner_content_type, 1); + return cipher_spec_->Protect(header, padded, ciphertext); +} + PacketFilter::Action TlsHandshakeFilter::FilterRecord( - const RecordHeader& record_header, const DataBuffer& input, + const TlsRecordHeader& record_header, const DataBuffer& input, DataBuffer* output) { // Check that the first byte is as requested. if (record_header.content_type() != kTlsHandshakeType) { @@ -159,9 +266,8 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord( return changed ? (offset ? CHANGE : DROP) : KEEP; } -bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser, - const RecordHeader& header, - uint32_t* length) { +bool TlsHandshakeFilter::HandshakeHeader::ReadLength( + TlsParser* parser, const TlsRecordHeader& header, uint32_t* length) { if (!parser->Read(length, 3)) { return false; // malformed } @@ -192,7 +298,7 @@ bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser, } bool TlsHandshakeFilter::HandshakeHeader::Parse( - TlsParser* parser, const RecordHeader& record_header, DataBuffer* body) { + TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body) { version_ = record_header.version(); if (!parser->Read(&handshake_type_)) { return false; // malformed @@ -205,15 +311,28 @@ bool TlsHandshakeFilter::HandshakeHeader::Parse( return parser->Read(body, length); } -size_t TlsHandshakeFilter::HandshakeHeader::Write( - DataBuffer* buffer, size_t offset, const DataBuffer& body) const { +size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment( + DataBuffer* buffer, size_t offset, const DataBuffer& body, + size_t fragment_offset, size_t fragment_length) const { + EXPECT_TRUE(is_dtls()); + EXPECT_GE(body.len(), fragment_offset + fragment_length); offset = buffer->Write(offset, handshake_type(), 1); offset = buffer->Write(offset, body.len(), 3); + offset = buffer->Write(offset, message_seq_, 2); + offset = buffer->Write(offset, fragment_offset, 3); + offset = buffer->Write(offset, fragment_length, 3); + offset = + buffer->Write(offset, body.data() + fragment_offset, fragment_length); + return offset; +} + +size_t TlsHandshakeFilter::HandshakeHeader::Write( + DataBuffer* buffer, size_t offset, const DataBuffer& body) const { if (is_dtls()) { - offset = buffer->Write(offset, message_seq_, 2); - offset = buffer->Write(offset, 0U, 3); // fragment_offset - offset = buffer->Write(offset, body.len(), 3); + return WriteFragment(buffer, offset, body, 0U, body.len()); } + offset = buffer->Write(offset, handshake_type(), 1); + offset = buffer->Write(offset, body.len(), 3); offset = buffer->Write(offset, body); return offset; } @@ -244,42 +363,12 @@ PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake( } PacketFilter::Action TlsConversationRecorder::FilterRecord( - const RecordHeader& header, const DataBuffer& input, DataBuffer* output) { + const TlsRecordHeader& header, const DataBuffer& input, + DataBuffer* output) { buffer_.Append(input); return KEEP; } -PacketFilter::Action TlsAlertRecorder::FilterRecord(const RecordHeader& header, - const DataBuffer& input, - DataBuffer* output) { - if (level_ == kTlsAlertFatal) { // already fatal - return KEEP; - } - if (header.content_type() != kTlsAlertType) { - return KEEP; - } - - std::cerr << "Alert: " << input << std::endl; - - TlsParser parser(input); - uint8_t lvl; - if (!parser.Read(&lvl)) { - return KEEP; - } - if (lvl == kTlsAlertWarning) { // not strong enough - return KEEP; - } - level_ = lvl; - (void)parser.Read(&description_); - return KEEP; -} - -ChainedPacketFilter::~ChainedPacketFilter() { - for (auto it = filters_.begin(); it != filters_.end(); ++it) { - delete *it; - } -} - PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input, DataBuffer* output) { DataBuffer in(input); @@ -297,28 +386,7 @@ PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input, return changed ? CHANGE : KEEP; } -PacketFilter::Action TlsExtensionFilter::FilterHandshake( - const HandshakeHeader& header, const DataBuffer& input, - DataBuffer* output) { - if (header.handshake_type() == kTlsHandshakeClientHello) { - TlsParser parser(input); - if (!FindClientHelloExtensions(&parser, header)) { - return KEEP; - } - return FilterExtensions(&parser, input, output); - } - if (header.handshake_type() == kTlsHandshakeServerHello) { - TlsParser parser(input); - if (!FindServerHelloExtensions(&parser)) { - return KEEP; - } - return FilterExtensions(&parser, input, output); - } - return KEEP; -} - -bool TlsExtensionFilter::FindClientHelloExtensions(TlsParser* parser, - const Versioned& header) { +bool FindClientHelloExtensions(TlsParser* parser, const TlsVersioned& header) { if (!parser->Skip(2 + 32)) { // version + random return false; } @@ -337,7 +405,7 @@ bool TlsExtensionFilter::FindClientHelloExtensions(TlsParser* parser, return true; } -bool TlsExtensionFilter::FindServerHelloExtensions(TlsParser* parser) { +bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) { uint32_t vtmp; if (!parser->Read(&vtmp, 2)) { return false; @@ -362,6 +430,92 @@ bool TlsExtensionFilter::FindServerHelloExtensions(TlsParser* parser) { return true; } +static bool FindHelloRetryExtensions(TlsParser* parser, + const TlsVersioned& header) { + // TODO for -19 add cipher suite + if (!parser->Skip(2)) { // version + return false; + } + return true; +} + +bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) { + return true; +} + +static bool FindCertReqExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->SkipVariable(1)) { // request context + return false; + } + // TODO remove the next two for -19 + if (!parser->SkipVariable(2)) { // signature_algorithms + return false; + } + if (!parser->SkipVariable(2)) { // certificate_authorities + return false; + } + return true; +} + +// Only look at the EE cert for this one. +static bool FindCertificateExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->SkipVariable(1)) { // request context + return false; + } + if (!parser->Skip(3)) { // length of certificate list + return false; + } + if (!parser->SkipVariable(3)) { // ASN1Cert + return false; + } + return true; +} + +static bool FindNewSessionTicketExtensions(TlsParser* parser, + const TlsVersioned& header) { + if (!parser->Skip(8)) { // lifetime, age add + return false; + } + if (!parser->SkipVariable(2)) { // ticket + return false; + } + return true; +} + +static const std::map kExtensionFinders = { + {kTlsHandshakeClientHello, FindClientHelloExtensions}, + {kTlsHandshakeServerHello, FindServerHelloExtensions}, + {kTlsHandshakeHelloRetryRequest, FindHelloRetryExtensions}, + {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions}, + {kTlsHandshakeCertificateRequest, FindCertReqExtensions}, + {kTlsHandshakeCertificate, FindCertificateExtensions}, + {kTlsHandshakeNewSessionTicket, FindNewSessionTicketExtensions}}; + +bool TlsExtensionFilter::FindExtensions(TlsParser* parser, + const HandshakeHeader& header) { + auto it = kExtensionFinders.find(header.handshake_type()); + if (it == kExtensionFinders.end()) { + return false; + } + return (it->second)(parser, header); +} + +PacketFilter::Action TlsExtensionFilter::FilterHandshake( + const HandshakeHeader& header, const DataBuffer& input, + DataBuffer* output) { + if (handshake_types_.count(header.handshake_type()) == 0) { + return KEEP; + } + + TlsParser parser(input); + if (!FindExtensions(&parser, header)) { + return KEEP; + } + return FilterExtensions(&parser, input, output); +} + PacketFilter::Action TlsExtensionFilter::FilterExtensions( TlsParser* parser, const DataBuffer& input, DataBuffer* output) { size_t length_offset = parser->consumed(); @@ -456,14 +610,14 @@ PacketFilter::Action TlsExtensionDropper::FilterExtension( return KEEP; } -PacketFilter::Action AfterRecordN::FilterRecord(const RecordHeader& header, +PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) { if (counter_++ == record_) { DataBuffer buf; header.Write(&buf, 0, body); - src_->SendDirect(buf); - dest_->Handshake(); + src_.lock()->SendDirect(buf); + dest_.lock()->Handshake(); func_(); return DROP; } @@ -476,7 +630,7 @@ PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake( DataBuffer* output) { if (header.handshake_type() == kTlsHandshakeClientKeyExchange) { EXPECT_EQ(SECSuccess, - SSLInt_IncrementClientHandshakeVersion(server_->ssl_fd())); + SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd())); } return KEEP; } -- cgit v1.2.3