summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_filter.h
diff options
context:
space:
mode:
authorwolfbeast <mcwerewolf@gmail.com>2018-02-23 11:04:39 +0100
committerwolfbeast <mcwerewolf@gmail.com>2018-06-05 22:24:08 +0200
commite10349ab8dda8a3f11be6aa19f2b6e29fe814044 (patch)
tree1a9b078b06a76af06839d407b7267880890afccc /security/nss/gtests/ssl_gtest/tls_filter.h
parent75b3dd4cbffb6e4534128278300ed6c8a3ab7506 (diff)
downloadUXP-e10349ab8dda8a3f11be6aa19f2b6e29fe814044.tar
UXP-e10349ab8dda8a3f11be6aa19f2b6e29fe814044.tar.gz
UXP-e10349ab8dda8a3f11be6aa19f2b6e29fe814044.tar.lz
UXP-e10349ab8dda8a3f11be6aa19f2b6e29fe814044.tar.xz
UXP-e10349ab8dda8a3f11be6aa19f2b6e29fe814044.zip
Update NSS to 3.35-RTM
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.h192
1 files changed, 169 insertions, 23 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h
index e4030e23f..1db3b90f6 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.h
+++ b/security/nss/gtests/ssl_gtest/tls_filter.h
@@ -50,10 +50,13 @@ class TlsRecordHeader : public TlsVersioned {
uint8_t content_type() const { return content_type_; }
uint64_t sequence_number() const { return sequence_number_; }
- size_t header_length() const { return is_dtls() ? 11 : 3; }
+ uint16_t epoch() const {
+ return static_cast<uint16_t>(sequence_number_ >> 48);
+ }
+ size_t header_length() const { return is_dtls() ? 13 : 5; }
// Parse the header; return true if successful; body in an outparam if OK.
- bool Parse(TlsParser* parser, DataBuffer* body);
+ bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body);
// Write the header and body to a buffer at the given offset.
// Return the offset of the end of the write.
size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
@@ -63,10 +66,21 @@ class TlsRecordHeader : public TlsVersioned {
uint64_t sequence_number_;
};
+struct TlsRecord {
+ const TlsRecordHeader header;
+ const DataBuffer buffer;
+};
+
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
- TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {}
+ TlsRecordFilter()
+ : agent_(nullptr),
+ count_(0),
+ cipher_spec_(),
+ dropped_record_(false),
+ in_sequence_number_(0),
+ out_sequence_number_(0) {}
void SetAgent(const TlsAgent* agent) { agent_ = agent; }
const TlsAgent* agent() const { return agent_; }
@@ -115,14 +129,21 @@ class TlsRecordFilter : public PacketFilter {
const TlsAgent* agent_;
size_t count_;
std::unique_ptr<TlsCipherSpec> cipher_spec_;
+ // Whether we dropped a record since the cipher spec changed.
+ bool dropped_record_;
+ // The sequence number we use for reading records as they are written.
+ uint64_t in_sequence_number_;
+ // The sequence number we use for writing modified records.
+ uint64_t out_sequence_number_;
};
-inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) {
+inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) {
v.WriteStream(stream);
return stream;
}
-inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
+inline std::ostream& operator<<(std::ostream& stream,
+ const TlsRecordHeader& hdr) {
hdr.WriteStream(stream);
stream << ' ';
switch (hdr.content_type()) {
@@ -133,13 +154,17 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
stream << "Alert";
break;
case kTlsHandshakeType:
+ case kTlsAltHandshakeType:
stream << "Handshake";
break;
case kTlsApplicationDataType:
stream << "Data";
break;
+ case kTlsAckType:
+ stream << "ACK";
+ break;
default:
- stream << '<' << hdr.content_type() << '>';
+ stream << '<' << static_cast<int>(hdr.content_type()) << '>';
break;
}
return stream << ' ' << std::hex << hdr.sequence_number() << std::dec;
@@ -150,7 +175,16 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
// records and that they don't span records or anything crazy like that.
class TlsHandshakeFilter : public TlsRecordFilter {
public:
- TlsHandshakeFilter() {}
+ TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::set<uint8_t>& types)
+ : handshake_types_(types), preceding_fragment_() {}
+
+ // This filter can be set to be selective based on handshake message type. If
+ // this function isn't used (or the set is empty), then all handshake messages
+ // will be filtered.
+ void SetHandshakeTypes(const std::set<uint8_t>& types) {
+ handshake_types_ = types;
+ }
class HandshakeHeader : public TlsVersioned {
public:
@@ -158,7 +192,8 @@ class TlsHandshakeFilter : public TlsRecordFilter {
uint8_t handshake_type() const { return handshake_type_; }
bool Parse(TlsParser* parser, const TlsRecordHeader& record_header,
- DataBuffer* body);
+ const DataBuffer& preceding_fragment, DataBuffer* body,
+ bool* complete);
size_t Write(DataBuffer* buffer, size_t offset,
const DataBuffer& body) const;
size_t WriteFragment(DataBuffer* buffer, size_t offset,
@@ -169,7 +204,8 @@ class TlsHandshakeFilter : public TlsRecordFilter {
// Reads the length from the record header.
// This also reads the DTLS fragment information and checks it.
bool ReadLength(TlsParser* parser, const TlsRecordHeader& header,
- uint32_t* length);
+ uint32_t expected_offset, uint32_t* length,
+ bool* last_fragment);
uint8_t handshake_type_;
uint16_t message_seq_;
@@ -185,22 +221,30 @@ class TlsHandshakeFilter : public TlsRecordFilter {
DataBuffer* output) = 0;
private:
+ bool IsFilteredType(const HandshakeHeader& header,
+ const DataBuffer& handshake);
+
+ std::set<uint8_t> handshake_types_;
+ DataBuffer preceding_fragment_;
};
// Make a copy of the first instance of a handshake message.
class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
public:
TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
- : handshake_type_(handshake_type), buffer_() {}
+ : TlsHandshakeFilter({handshake_type}), buffer_() {}
+ TlsInspectorRecordHandshakeMessage(const std::set<uint8_t>& handshake_types)
+ : TlsHandshakeFilter(handshake_types), buffer_() {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output);
+ void Reset() { buffer_.Truncate(0); }
+
const DataBuffer& buffer() const { return buffer_; }
private:
- uint8_t handshake_type_;
DataBuffer buffer_;
};
@@ -209,17 +253,39 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
public:
TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type,
const DataBuffer& replacement)
- : handshake_type_(handshake_type), buffer_(replacement) {}
+ : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output);
private:
- uint8_t handshake_type_;
DataBuffer buffer_;
};
+// Make a copy of each record of a given type.
+class TlsRecordRecorder : public TlsRecordFilter {
+ public:
+ TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {}
+ TlsRecordRecorder()
+ : filter_(false),
+ ct_(content_handshake), // dummy (<optional> is C++14)
+ records_() {}
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ size_t count() const { return records_.size(); }
+ void Clear() { records_.clear(); }
+
+ const TlsRecord& record(size_t i) const { return records_[i]; }
+
+ private:
+ bool filter_;
+ uint8_t ct_;
+ std::vector<TlsRecord> records_;
+};
+
// Make a copy of the complete conversation.
class TlsConversationRecorder : public TlsRecordFilter {
public:
@@ -230,15 +296,31 @@ class TlsConversationRecorder : public TlsRecordFilter {
DataBuffer* output);
private:
- DataBuffer& buffer_;
+ DataBuffer buffer_;
};
+// Make a copy of the records
+class TlsHeaderRecorder : public TlsRecordFilter {
+ public:
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+ const TlsRecordHeader* header(size_t index);
+
+ private:
+ std::vector<TlsRecordHeader> headers_;
+};
+
+typedef std::initializer_list<std::shared_ptr<PacketFilter>>
+ ChainedPacketFilterInit;
+
// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
public:
ChainedPacketFilter() {}
ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters)
: filters_(filters.begin(), filters.end()) {}
+ ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {}
virtual ~ChainedPacketFilter() {}
virtual PacketFilter::Action Filter(const DataBuffer& input,
@@ -256,13 +338,13 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
class TlsExtensionFilter : public TlsHandshakeFilter {
public:
- TlsExtensionFilter() : handshake_types_() {
- handshake_types_.insert(kTlsHandshakeClientHello);
- handshake_types_.insert(kTlsHandshakeServerHello);
- }
+ TlsExtensionFilter()
+ : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello,
+ kTlsHandshakeHelloRetryRequest,
+ kTlsHandshakeEncryptedExtensions}) {}
TlsExtensionFilter(const std::set<uint8_t>& types)
- : handshake_types_(types) {}
+ : TlsHandshakeFilter(types) {}
static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
@@ -279,8 +361,6 @@ class TlsExtensionFilter : public TlsHandshakeFilter {
PacketFilter::Action FilterExtensions(TlsParser* parser,
const DataBuffer& input,
DataBuffer* output);
-
- std::set<uint8_t> handshake_types_;
};
class TlsExtensionCapture : public TlsExtensionFilter {
@@ -326,6 +406,21 @@ class TlsExtensionDropper : public TlsExtensionFilter {
uint16_t extension_;
};
+class TlsExtensionInjector : public TlsHandshakeFilter {
+ public:
+ TlsExtensionInjector(uint16_t ext, const DataBuffer& data)
+ : extension_(ext), data_(data) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
class TlsAgent;
typedef std::function<void(void)> VoidFunction;
@@ -352,7 +447,7 @@ class AfterRecordN : public TlsRecordFilter {
class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
public:
TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server)
- : server_(server) {}
+ : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -377,10 +472,47 @@ class SelectiveDropFilter : public PacketFilter {
uint8_t counter_;
};
+// This class selectively drops complete records. The difference from
+// SelectiveDropFilter is that if multiple DTLS records are in the same
+// datagram, we just drop one.
+class SelectiveRecordDropFilter : public TlsRecordFilter {
+ public:
+ SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true)
+ : pattern_(pattern), counter_(0) {
+ if (!enabled) {
+ Disable();
+ }
+ }
+ SelectiveRecordDropFilter(std::initializer_list<size_t> records)
+ : SelectiveRecordDropFilter(ToPattern(records), true) {}
+
+ void Reset(uint32_t pattern) {
+ counter_ = 0;
+ PacketFilter::Enable();
+ pattern_ = pattern;
+ }
+
+ void Reset(std::initializer_list<size_t> records) {
+ Reset(ToPattern(records));
+ }
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override;
+
+ private:
+ static uint32_t ToPattern(std::initializer_list<size_t> records);
+
+ uint32_t pattern_;
+ uint8_t counter_;
+};
+
// Set the version number in the ClientHello.
class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {}
+ TlsInspectorClientHelloVersionSetter(uint16_t version)
+ : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -411,6 +543,20 @@ class TlsLastByteDamager : public TlsHandshakeFilter {
uint8_t type_;
};
+class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
+ public:
+ SelectedCipherSuiteReplacer(uint16_t suite)
+ : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ uint16_t cipher_suite_;
+};
+
} // namespace nss_test
#endif