summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_filter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.cc321
1 files changed, 250 insertions, 71 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.cc b/security/nss/gtests/ssl_gtest/tls_filter.cc
index 76d9aaaff..89f201295 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.cc
+++ b/security/nss/gtests/ssl_gtest/tls_filter.cc
@@ -12,6 +12,7 @@ extern "C" {
#include "libssl_internals.h"
}
+#include <cassert>
#include <iostream>
#include "gtest_utils.h"
#include "tls_agent.h"
@@ -57,17 +58,22 @@ void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending,
PRBool isServer = self->agent()->role() == TlsAgent::SERVER;
if (g_ssl_gtest_verbose) {
- std::cerr << "Cipher spec changed. Role="
- << (isServer ? "server" : "client")
- << " direction=" << (sending ? "send" : "receive") << std::endl;
+ std::cerr << (isServer ? "server" : "client") << ": "
+ << (sending ? "send" : "receive")
+ << " cipher spec changed: " << newSpec->epoch << " ("
+ << newSpec->phase << ")" << std::endl;
+ }
+ if (!sending) {
+ return;
}
- if (!sending) return;
+ self->in_sequence_number_ = 0;
+ self->out_sequence_number_ = 0;
+ self->dropped_record_ = false;
self->cipher_spec_.reset(new TlsCipherSpec());
- bool ret =
- self->cipher_spec_->Init(SSLInt_CipherSpecToAlgorithm(isServer, newSpec),
- SSLInt_CipherSpecToKey(isServer, newSpec),
- SSLInt_CipherSpecToIv(isServer, newSpec));
+ bool ret = self->cipher_spec_->Init(
+ SSLInt_CipherSpecToEpoch(newSpec), SSLInt_CipherSpecToAlgorithm(newSpec),
+ SSLInt_CipherSpecToKey(newSpec), SSLInt_CipherSpecToIv(newSpec));
EXPECT_EQ(true, ret);
}
@@ -83,11 +89,23 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
TlsRecordHeader header;
DataBuffer record;
- if (!header.Parse(&parser, &record)) {
+ if (!header.Parse(in_sequence_number_, &parser, &record)) {
ADD_FAILURE() << "not a valid record";
return KEEP;
}
+ // Track the sequence number, which is necessary for stream mode (the
+ // sequence number is in the header for datagram).
+ //
+ // This isn't perfectly robust. If there is a change from an active cipher
+ // spec to another active cipher spec (KeyUpdate for instance) AND writes
+ // are consolidated across that change AND packets were dropped from the
+ // older epoch, we will not correctly re-encrypt records in the old epoch to
+ // update their sequence numbers.
+ if (cipher_spec_ && header.content_type() == kTlsApplicationDataType) {
+ ++in_sequence_number_;
+ }
+
if (FilterRecord(header, record, &offset, output) != KEEP) {
changed = true;
} else {
@@ -120,30 +138,49 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(
header.sequence_number()};
PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
+ // In stream mode, even if something doesn't change we need to re-encrypt if
+ // previous packets were dropped.
if (action == KEEP) {
- return KEEP;
+ if (header.is_dtls() || !dropped_record_) {
+ return KEEP;
+ }
+ filtered = plaintext;
}
if (action == DROP) {
- std::cerr << "record drop: " << record << std::endl;
+ std::cerr << "record drop: " << header << ":" << record << std::endl;
+ dropped_record_ = true;
return DROP;
}
EXPECT_GT(0x10000U, filtered.len());
- std::cerr << "record old: " << plaintext << std::endl;
- std::cerr << "record new: " << filtered << std::endl;
+ if (action != KEEP) {
+ std::cerr << "record old: " << plaintext << std::endl;
+ std::cerr << "record new: " << filtered << std::endl;
+ }
+
+ uint64_t seq_num;
+ if (header.is_dtls() || !cipher_spec_ ||
+ header.content_type() != kTlsApplicationDataType) {
+ seq_num = header.sequence_number();
+ } else {
+ seq_num = out_sequence_number_++;
+ }
+ TlsRecordHeader out_header = {header.version(), header.content_type(),
+ seq_num};
DataBuffer ciphertext;
- bool rv = Protect(header, inner_content_type, filtered, &ciphertext);
+ bool rv = Protect(out_header, inner_content_type, filtered, &ciphertext);
EXPECT_TRUE(rv);
if (!rv) {
return KEEP;
}
- *offset = header.Write(output, *offset, ciphertext);
+ *offset = out_header.Write(output, *offset, ciphertext);
return CHANGE;
}
-bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
+bool TlsRecordHeader::Parse(uint64_t sequence_number, TlsParser* parser,
+ DataBuffer* body) {
if (!parser->Read(&content_type_)) {
return false;
}
@@ -154,7 +191,7 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
}
version_ = version;
- sequence_number_ = 0;
+ // If this is DTLS, overwrite the sequence number.
if (IsDtls(version)) {
uint32_t tmp;
if (!parser->Read(&tmp, 4)) {
@@ -165,6 +202,8 @@ bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
return false;
}
sequence_number_ |= static_cast<uint64_t>(tmp);
+ } else {
+ sequence_number_ = sequence_number;
}
return parser->ReadVariable(body, 2);
}
@@ -193,7 +232,9 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
return true;
}
- if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) return false;
+ if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) {
+ return false;
+ }
size_t len = plaintext->len();
while (len > 0 && !plaintext->data()[len - 1]) {
@@ -206,6 +247,11 @@ bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
*inner_content_type = plaintext->data()[len - 1];
plaintext->Truncate(len - 1);
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "unprotect: " << std::hex << header.sequence_number()
+ << std::dec << " type=" << static_cast<int>(*inner_content_type)
+ << " " << *plaintext << std::endl;
+ }
return true;
}
@@ -218,16 +264,44 @@ bool TlsRecordFilter::Protect(const TlsRecordHeader& header,
*ciphertext = plaintext;
return true;
}
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "protect: " << header.sequence_number() << std::endl;
+ }
DataBuffer padded = plaintext;
padded.Write(padded.len(), inner_content_type, 1);
return cipher_spec_->Protect(header, padded, ciphertext);
}
+bool IsHelloRetry(const DataBuffer& body) {
+ static const uint8_t ssl_hello_retry_random[] = {
+ 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11, 0xBE, 0x1D, 0x8C,
+ 0x02, 0x1E, 0x65, 0xB8, 0x91, 0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB,
+ 0x8C, 0x5E, 0x07, 0x9E, 0x09, 0xE2, 0xC8, 0xA8, 0x33, 0x9C};
+ return memcmp(body.data() + 2, ssl_hello_retry_random,
+ sizeof(ssl_hello_retry_random)) == 0;
+}
+
+bool TlsHandshakeFilter::IsFilteredType(const HandshakeHeader& header,
+ const DataBuffer& body) {
+ if (handshake_types_.empty()) {
+ return true;
+ }
+
+ uint8_t type = header.handshake_type();
+ if (type == kTlsHandshakeServerHello) {
+ if (IsHelloRetry(body)) {
+ type = kTlsHandshakeHelloRetryRequest;
+ }
+ }
+ return handshake_types_.count(type) > 0U;
+}
+
PacketFilter::Action TlsHandshakeFilter::FilterRecord(
const TlsRecordHeader& record_header, const DataBuffer& input,
DataBuffer* output) {
// Check that the first byte is as requested.
- if (record_header.content_type() != kTlsHandshakeType) {
+ if ((record_header.content_type() != kTlsHandshakeType) &&
+ (record_header.content_type() != kTlsAltHandshakeType)) {
return KEEP;
}
@@ -239,12 +313,29 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord(
while (parser.remaining()) {
HandshakeHeader header;
DataBuffer handshake;
- if (!header.Parse(&parser, record_header, &handshake)) {
+ bool complete = false;
+ if (!header.Parse(&parser, record_header, preceding_fragment_, &handshake,
+ &complete)) {
return KEEP;
}
+ if (!complete) {
+ EXPECT_TRUE(record_header.is_dtls());
+ // Save the fragment and drop it from this record. Fragments are
+ // coalesced with the last fragment of the handshake message.
+ changed = true;
+ preceding_fragment_.Assign(handshake);
+ continue;
+ }
+ preceding_fragment_.Truncate(0);
+
DataBuffer filtered;
- PacketFilter::Action action = FilterHandshake(header, handshake, &filtered);
+ PacketFilter::Action action;
+ if (!IsFilteredType(header, handshake)) {
+ action = KEEP;
+ } else {
+ action = FilterHandshake(header, handshake, &filtered);
+ }
if (action == DROP) {
changed = true;
std::cerr << "handshake drop: " << handshake << std::endl;
@@ -258,6 +349,8 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord(
std::cerr << "handshake old: " << handshake << std::endl;
std::cerr << "handshake new: " << filtered << std::endl;
source = &filtered;
+ } else if (preceding_fragment_.len()) {
+ changed = true;
}
offset = header.Write(output, offset, *source);
@@ -267,12 +360,16 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord(
}
bool TlsHandshakeFilter::HandshakeHeader::ReadLength(
- TlsParser* parser, const TlsRecordHeader& header, uint32_t* length) {
- if (!parser->Read(length, 3)) {
+ TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset,
+ uint32_t* length, bool* last_fragment) {
+ uint32_t message_length;
+ if (!parser->Read(&message_length, 3)) {
return false; // malformed
}
if (!header.is_dtls()) {
+ *last_fragment = true;
+ *length = message_length;
return true; // nothing left to do
}
@@ -283,32 +380,50 @@ bool TlsHandshakeFilter::HandshakeHeader::ReadLength(
}
message_seq_ = message_seq_tmp;
- uint32_t fragment_offset;
- if (!parser->Read(&fragment_offset, 3)) {
+ uint32_t offset = 0;
+ if (!parser->Read(&offset, 3)) {
+ return false;
+ }
+ // We only parse if the fragments are all complete and in order.
+ if (offset != expected_offset) {
+ EXPECT_NE(0U, header.epoch())
+ << "Received out of order handshake fragment for epoch 0";
return false;
}
- uint32_t fragment_length;
- if (!parser->Read(&fragment_length, 3)) {
+ // For DTLS, we return the length of just this fragment.
+ if (!parser->Read(length, 3)) {
return false;
}
- // All current tests where we are using this code don't fragment.
- return (fragment_offset == 0 && fragment_length == *length);
+ // It's a fragment if the entire message is longer than what we have.
+ *last_fragment = message_length == (*length + offset);
+ return true;
}
bool TlsHandshakeFilter::HandshakeHeader::Parse(
- TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body) {
+ TlsParser* parser, const TlsRecordHeader& record_header,
+ const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete) {
+ *complete = false;
+
version_ = record_header.version();
if (!parser->Read(&handshake_type_)) {
return false; // malformed
}
+
uint32_t length;
- if (!ReadLength(parser, record_header, &length)) {
+ if (!ReadLength(parser, record_header, preceding_fragment.len(), &length,
+ complete)) {
return false;
}
- return parser->Read(body, length);
+ if (!parser->Read(body, length)) {
+ return false;
+ }
+ if (preceding_fragment.len()) {
+ body->Splice(preceding_fragment, 0);
+ }
+ return true;
}
size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment(
@@ -345,20 +460,23 @@ PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake(
return KEEP;
}
- if (header.handshake_type() == handshake_type_) {
- buffer_ = input;
- }
+ buffer_ = input;
return KEEP;
}
PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() == handshake_type_) {
- *output = buffer_;
- return CHANGE;
- }
+ *output = buffer_;
+ return CHANGE;
+}
+PacketFilter::Action TlsRecordRecorder::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (!filter_ || (header.content_type() == ct_)) {
+ records_.push_back({header, input});
+ }
return KEEP;
}
@@ -369,15 +487,30 @@ PacketFilter::Action TlsConversationRecorder::FilterRecord(
return KEEP;
}
+PacketFilter::Action TlsHeaderRecorder::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ headers_.push_back(header);
+ return KEEP;
+}
+
+const TlsRecordHeader* TlsHeaderRecorder::header(size_t index) {
+ if (index > headers_.size() + 1) {
+ return nullptr;
+ }
+ return &headers_[index];
+}
+
PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
DataBuffer in(input);
bool changed = false;
for (auto it = filters_.begin(); it != filters_.end(); ++it) {
- PacketFilter::Action action = (*it)->Filter(in, output);
+ PacketFilter::Action action = (*it)->Process(in, output);
if (action == DROP) {
return DROP;
}
+
if (action == CHANGE) {
in = *output;
changed = true;
@@ -430,15 +563,6 @@ bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
return true;
}
-static bool FindHelloRetryExtensions(TlsParser* parser,
- const TlsVersioned& header) {
- // TODO for -19 add cipher suite
- if (!parser->Skip(2)) { // version
- return false;
- }
- return true;
-}
-
bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) {
return true;
}
@@ -448,13 +572,6 @@ static bool FindCertReqExtensions(TlsParser* parser,
if (!parser->SkipVariable(1)) { // request context
return false;
}
- // TODO remove the next two for -19
- if (!parser->SkipVariable(2)) { // signature_algorithms
- return false;
- }
- if (!parser->SkipVariable(2)) { // certificate_authorities
- return false;
- }
return true;
}
@@ -478,6 +595,9 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser,
if (!parser->Skip(8)) { // lifetime, age add
return false;
}
+ if (!parser->SkipVariable(1)) { // ticket_nonce
+ return false;
+ }
if (!parser->SkipVariable(2)) { // ticket
return false;
}
@@ -487,7 +607,6 @@ static bool FindNewSessionTicketExtensions(TlsParser* parser,
static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = {
{kTlsHandshakeClientHello, FindClientHelloExtensions},
{kTlsHandshakeServerHello, FindServerHelloExtensions},
- {kTlsHandshakeHelloRetryRequest, FindHelloRetryExtensions},
{kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions},
{kTlsHandshakeCertificateRequest, FindCertReqExtensions},
{kTlsHandshakeCertificate, FindCertificateExtensions},
@@ -505,10 +624,6 @@ bool TlsExtensionFilter::FindExtensions(TlsParser* parser,
PacketFilter::Action TlsExtensionFilter::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (handshake_types_.count(header.handshake_type()) == 0) {
- return KEEP;
- }
-
TlsParser parser(input);
if (!FindExtensions(&parser, header)) {
return KEEP;
@@ -610,6 +725,38 @@ PacketFilter::Action TlsExtensionDropper::FilterExtension(
return KEEP;
}
+PacketFilter::Action TlsExtensionInjector::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
+ return KEEP;
+ }
+ size_t offset = parser.consumed();
+
+ *output = input;
+
+ // Increase the size of the extensions.
+ uint16_t ext_len;
+ memcpy(&ext_len, output->data() + offset, sizeof(ext_len));
+ ext_len = htons(ntohs(ext_len) + data_.len() + 4);
+ memcpy(output->data() + offset, &ext_len, sizeof(ext_len));
+
+ // Insert the extension type and length.
+ DataBuffer type_length;
+ type_length.Allocate(4);
+ type_length.Write(0, extension_, 2);
+ type_length.Write(2, data_.len(), 2);
+ output->Splice(type_length, offset + 2);
+
+ // Insert the payload.
+ if (data_.len() > 0) {
+ output->Splice(data_, offset + 6);
+ }
+
+ return CHANGE;
+}
+
PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) {
@@ -628,10 +775,8 @@ PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() == kTlsHandshakeClientKeyExchange) {
- EXPECT_EQ(SECSuccess,
- SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd()));
- }
+ EXPECT_EQ(SECSuccess,
+ SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd()));
return KEEP;
}
@@ -643,15 +788,49 @@ PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input,
return ((1 << counter_++) & pattern_) ? DROP : KEEP;
}
+PacketFilter::Action SelectiveRecordDropFilter::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& data,
+ DataBuffer* changed) {
+ if (counter_ >= 32) {
+ return KEEP;
+ }
+ return ((1 << counter_++) & pattern_) ? DROP : KEEP;
+}
+
+/* static */ uint32_t SelectiveRecordDropFilter::ToPattern(
+ std::initializer_list<size_t> records) {
+ uint32_t pattern = 0;
+ for (auto it = records.begin(); it != records.end(); ++it) {
+ EXPECT_GT(32U, *it);
+ assert(*it < 32U);
+ pattern |= 1 << *it;
+ }
+ return pattern;
+}
+
PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake(
const HandshakeHeader& header, const DataBuffer& input,
DataBuffer* output) {
- if (header.handshake_type() == kTlsHandshakeClientHello) {
- *output = input;
- output->Write(0, version_, 2);
- return CHANGE;
- }
- return KEEP;
+ *output = input;
+ output->Write(0, version_, 2);
+ return CHANGE;
+}
+
+PacketFilter::Action SelectedCipherSuiteReplacer::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ *output = input;
+ uint32_t temp = 0;
+ EXPECT_TRUE(input.Read(0, 2, &temp));
+ // Cipher suite is after version(2) and random(32).
+ size_t pos = 34;
+ if (temp < SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In old versions, we have to skip a session_id too.
+ EXPECT_TRUE(input.Read(pos, 1, &temp));
+ pos += 1 + temp;
+ }
+ output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
+ return CHANGE;
}
} // namespace nss_test