diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_filter.h | 192 |
1 files changed, 169 insertions, 23 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h index e4030e23f..1db3b90f6 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.h +++ b/security/nss/gtests/ssl_gtest/tls_filter.h @@ -50,10 +50,13 @@ class TlsRecordHeader : public TlsVersioned { uint8_t content_type() const { return content_type_; } uint64_t sequence_number() const { return sequence_number_; } - size_t header_length() const { return is_dtls() ? 11 : 3; } + uint16_t epoch() const { + return static_cast<uint16_t>(sequence_number_ >> 48); + } + size_t header_length() const { return is_dtls() ? 13 : 5; } // Parse the header; return true if successful; body in an outparam if OK. - bool Parse(TlsParser* parser, DataBuffer* body); + bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body); // Write the header and body to a buffer at the given offset. // Return the offset of the end of the write. size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; @@ -63,10 +66,21 @@ class TlsRecordHeader : public TlsVersioned { uint64_t sequence_number_; }; +struct TlsRecord { + const TlsRecordHeader header; + const DataBuffer buffer; +}; + // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: - TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {} + TlsRecordFilter() + : agent_(nullptr), + count_(0), + cipher_spec_(), + dropped_record_(false), + in_sequence_number_(0), + out_sequence_number_(0) {} void SetAgent(const TlsAgent* agent) { agent_ = agent; } const TlsAgent* agent() const { return agent_; } @@ -115,14 +129,21 @@ class TlsRecordFilter : public PacketFilter { const TlsAgent* agent_; size_t count_; std::unique_ptr<TlsCipherSpec> cipher_spec_; + // Whether we dropped a record since the cipher spec changed. + bool dropped_record_; + // The sequence number we use for reading records as they are written. + uint64_t in_sequence_number_; + // The sequence number we use for writing modified records. + uint64_t out_sequence_number_; }; -inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) { +inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) { v.WriteStream(stream); return stream; } -inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { +inline std::ostream& operator<<(std::ostream& stream, + const TlsRecordHeader& hdr) { hdr.WriteStream(stream); stream << ' '; switch (hdr.content_type()) { @@ -133,13 +154,17 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { stream << "Alert"; break; case kTlsHandshakeType: + case kTlsAltHandshakeType: stream << "Handshake"; break; case kTlsApplicationDataType: stream << "Data"; break; + case kTlsAckType: + stream << "ACK"; + break; default: - stream << '<' << hdr.content_type() << '>'; + stream << '<' << static_cast<int>(hdr.content_type()) << '>'; break; } return stream << ' ' << std::hex << hdr.sequence_number() << std::dec; @@ -150,7 +175,16 @@ inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { // records and that they don't span records or anything crazy like that. class TlsHandshakeFilter : public TlsRecordFilter { public: - TlsHandshakeFilter() {} + TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {} + TlsHandshakeFilter(const std::set<uint8_t>& types) + : handshake_types_(types), preceding_fragment_() {} + + // This filter can be set to be selective based on handshake message type. If + // this function isn't used (or the set is empty), then all handshake messages + // will be filtered. + void SetHandshakeTypes(const std::set<uint8_t>& types) { + handshake_types_ = types; + } class HandshakeHeader : public TlsVersioned { public: @@ -158,7 +192,8 @@ class TlsHandshakeFilter : public TlsRecordFilter { uint8_t handshake_type() const { return handshake_type_; } bool Parse(TlsParser* parser, const TlsRecordHeader& record_header, - DataBuffer* body); + const DataBuffer& preceding_fragment, DataBuffer* body, + bool* complete); size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; size_t WriteFragment(DataBuffer* buffer, size_t offset, @@ -169,7 +204,8 @@ class TlsHandshakeFilter : public TlsRecordFilter { // Reads the length from the record header. // This also reads the DTLS fragment information and checks it. bool ReadLength(TlsParser* parser, const TlsRecordHeader& header, - uint32_t* length); + uint32_t expected_offset, uint32_t* length, + bool* last_fragment); uint8_t handshake_type_; uint16_t message_seq_; @@ -185,22 +221,30 @@ class TlsHandshakeFilter : public TlsRecordFilter { DataBuffer* output) = 0; private: + bool IsFilteredType(const HandshakeHeader& header, + const DataBuffer& handshake); + + std::set<uint8_t> handshake_types_; + DataBuffer preceding_fragment_; }; // Make a copy of the first instance of a handshake message. class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { public: TlsInspectorRecordHandshakeMessage(uint8_t handshake_type) - : handshake_type_(handshake_type), buffer_() {} + : TlsHandshakeFilter({handshake_type}), buffer_() {} + TlsInspectorRecordHandshakeMessage(const std::set<uint8_t>& handshake_types) + : TlsHandshakeFilter(handshake_types), buffer_() {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); + void Reset() { buffer_.Truncate(0); } + const DataBuffer& buffer() const { return buffer_; } private: - uint8_t handshake_type_; DataBuffer buffer_; }; @@ -209,17 +253,39 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { public: TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type, const DataBuffer& replacement) - : handshake_type_(handshake_type), buffer_(replacement) {} + : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: - uint8_t handshake_type_; DataBuffer buffer_; }; +// Make a copy of each record of a given type. +class TlsRecordRecorder : public TlsRecordFilter { + public: + TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {} + TlsRecordRecorder() + : filter_(false), + ct_(content_handshake), // dummy (<optional> is C++14) + records_() {} + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + + size_t count() const { return records_.size(); } + void Clear() { records_.clear(); } + + const TlsRecord& record(size_t i) const { return records_[i]; } + + private: + bool filter_; + uint8_t ct_; + std::vector<TlsRecord> records_; +}; + // Make a copy of the complete conversation. class TlsConversationRecorder : public TlsRecordFilter { public: @@ -230,15 +296,31 @@ class TlsConversationRecorder : public TlsRecordFilter { DataBuffer* output); private: - DataBuffer& buffer_; + DataBuffer buffer_; }; +// Make a copy of the records +class TlsHeaderRecorder : public TlsRecordFilter { + public: + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& input, + DataBuffer* output); + const TlsRecordHeader* header(size_t index); + + private: + std::vector<TlsRecordHeader> headers_; +}; + +typedef std::initializer_list<std::shared_ptr<PacketFilter>> + ChainedPacketFilterInit; + // Runs multiple packet filters in series. class ChainedPacketFilter : public PacketFilter { public: ChainedPacketFilter() {} ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters) : filters_(filters.begin(), filters.end()) {} + ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {} virtual ~ChainedPacketFilter() {} virtual PacketFilter::Action Filter(const DataBuffer& input, @@ -256,13 +338,13 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)> class TlsExtensionFilter : public TlsHandshakeFilter { public: - TlsExtensionFilter() : handshake_types_() { - handshake_types_.insert(kTlsHandshakeClientHello); - handshake_types_.insert(kTlsHandshakeServerHello); - } + TlsExtensionFilter() + : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello, + kTlsHandshakeHelloRetryRequest, + kTlsHandshakeEncryptedExtensions}) {} TlsExtensionFilter(const std::set<uint8_t>& types) - : handshake_types_(types) {} + : TlsHandshakeFilter(types) {} static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); @@ -279,8 +361,6 @@ class TlsExtensionFilter : public TlsHandshakeFilter { PacketFilter::Action FilterExtensions(TlsParser* parser, const DataBuffer& input, DataBuffer* output); - - std::set<uint8_t> handshake_types_; }; class TlsExtensionCapture : public TlsExtensionFilter { @@ -326,6 +406,21 @@ class TlsExtensionDropper : public TlsExtensionFilter { uint16_t extension_; }; +class TlsExtensionInjector : public TlsHandshakeFilter { + public: + TlsExtensionInjector(uint16_t ext, const DataBuffer& data) + : extension_(ext), data_(data) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + private: + const uint16_t extension_; + const DataBuffer data_; +}; + class TlsAgent; typedef std::function<void(void)> VoidFunction; @@ -352,7 +447,7 @@ class AfterRecordN : public TlsRecordFilter { class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter { public: TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server) - : server_(server) {} + : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -377,10 +472,47 @@ class SelectiveDropFilter : public PacketFilter { uint8_t counter_; }; +// This class selectively drops complete records. The difference from +// SelectiveDropFilter is that if multiple DTLS records are in the same +// datagram, we just drop one. +class SelectiveRecordDropFilter : public TlsRecordFilter { + public: + SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true) + : pattern_(pattern), counter_(0) { + if (!enabled) { + Disable(); + } + } + SelectiveRecordDropFilter(std::initializer_list<size_t> records) + : SelectiveRecordDropFilter(ToPattern(records), true) {} + + void Reset(uint32_t pattern) { + counter_ = 0; + PacketFilter::Enable(); + pattern_ = pattern; + } + + void Reset(std::initializer_list<size_t> records) { + Reset(ToPattern(records)); + } + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override; + + private: + static uint32_t ToPattern(std::initializer_list<size_t> records); + + uint32_t pattern_; + uint8_t counter_; +}; + // Set the version number in the ClientHello. class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { public: - TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {} + TlsInspectorClientHelloVersionSetter(uint16_t version) + : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -411,6 +543,20 @@ class TlsLastByteDamager : public TlsHandshakeFilter { uint8_t type_; }; +class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { + public: + SelectedCipherSuiteReplacer(uint16_t suite) + : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {} + + protected: + PacketFilter::Action FilterHandshake(const HandshakeHeader& header, + const DataBuffer& input, + DataBuffer* output) override; + + private: + uint16_t cipher_suite_; +}; + } // namespace nss_test #endif |