diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_filter.h | 192 |
1 files changed, 23 insertions, 169 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h index 1db3b90f6..e4030e23f 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.h +++ b/security/nss/gtests/ssl_gtest/tls_filter.h @@ -50,13 +50,10 @@ class TlsRecordHeader : public TlsVersioned { uint8_t content_type() const { return content_type_; } uint64_t sequence_number() const { return sequence_number_; } - uint16_t epoch() const { - return static_cast<uint16_t>(sequence_number_ >> 48); - } - size_t header_length() const { return is_dtls() ? 13 : 5; } + size_t header_length() const { return is_dtls() ? 11 : 3; } // Parse the header; return true if successful; body in an outparam if OK. - bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body); + bool Parse(TlsParser* parser, DataBuffer* body); // Write the header and body to a buffer at the given offset. // Return the offset of the end of the write. size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; @@ -66,21 +63,10 @@ class TlsRecordHeader : public TlsVersioned { uint64_t sequence_number_; }; -struct TlsRecord { - const TlsRecordHeader header; - const DataBuffer buffer; -}; - // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: - TlsRecordFilter() - : agent_(nullptr), - count_(0), - cipher_spec_(), - dropped_record_(false), - in_sequence_number_(0), - out_sequence_number_(0) {} + TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {} void SetAgent(const TlsAgent* agent) { agent_ = agent; } const TlsAgent* agent() const { return agent_; } @@ -129,21 +115,14 @@ class TlsRecordFilter : public PacketFilter { const TlsAgent* agent_; size_t count_; std::unique_ptr<TlsCipherSpec> cipher_spec_; - // Whether we dropped a record since the cipher spec changed. - bool dropped_record_; - // The sequence number we use for reading records as they are written. - uint64_t in_sequence_number_; - // The sequence number we use for writing modified records. - uint64_t out_sequence_number_; }; -inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) { +inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) { v.WriteStream(stream); return stream; } -inline std::ostream& operator<<(std::ostream& stream, - const TlsRecordHeader& hdr) { +inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { hdr.WriteStream(stream); stream << ' '; switch (hdr.content_type()) { @@ -154,17 +133,13 @@ inline std::ostream& operator<<(std::ostream& stream, stream << "Alert"; break; case kTlsHandshakeType: - case kTlsAltHandshakeType: stream << "Handshake"; break; case kTlsApplicationDataType: stream << "Data"; break; - case kTlsAckType: - stream << "ACK"; - break; default: - stream << '<' << static_cast<int>(hdr.content_type()) << '>'; + stream << '<' << hdr.content_type() << '>'; break; } return stream << ' ' << std::hex << hdr.sequence_number() << std::dec; @@ -175,16 +150,7 @@ inline std::ostream& operator<<(std::ostream& stream, // records and that they don't span records or anything crazy like that. class TlsHandshakeFilter : public TlsRecordFilter { public: - TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {} - TlsHandshakeFilter(const std::set<uint8_t>& types) - : handshake_types_(types), preceding_fragment_() {} - - // This filter can be set to be selective based on handshake message type. If - // this function isn't used (or the set is empty), then all handshake messages - // will be filtered. - void SetHandshakeTypes(const std::set<uint8_t>& types) { - handshake_types_ = types; - } + TlsHandshakeFilter() {} class HandshakeHeader : public TlsVersioned { public: @@ -192,8 +158,7 @@ class TlsHandshakeFilter : public TlsRecordFilter { uint8_t handshake_type() const { return handshake_type_; } bool Parse(TlsParser* parser, const TlsRecordHeader& record_header, - const DataBuffer& preceding_fragment, DataBuffer* body, - bool* complete); + DataBuffer* body); size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; size_t WriteFragment(DataBuffer* buffer, size_t offset, @@ -204,8 +169,7 @@ class TlsHandshakeFilter : public TlsRecordFilter { // Reads the length from the record header. // This also reads the DTLS fragment information and checks it. bool ReadLength(TlsParser* parser, const TlsRecordHeader& header, - uint32_t expected_offset, uint32_t* length, - bool* last_fragment); + uint32_t* length); uint8_t handshake_type_; uint16_t message_seq_; @@ -221,30 +185,22 @@ class TlsHandshakeFilter : public TlsRecordFilter { DataBuffer* output) = 0; private: - bool IsFilteredType(const HandshakeHeader& header, - const DataBuffer& handshake); - - std::set<uint8_t> handshake_types_; - DataBuffer preceding_fragment_; }; // Make a copy of the first instance of a handshake message. class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { public: TlsInspectorRecordHandshakeMessage(uint8_t handshake_type) - : TlsHandshakeFilter({handshake_type}), buffer_() {} - TlsInspectorRecordHandshakeMessage(const std::set<uint8_t>& handshake_types) - : TlsHandshakeFilter(handshake_types), buffer_() {} + : handshake_type_(handshake_type), buffer_() {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); - void Reset() { buffer_.Truncate(0); } - const DataBuffer& buffer() const { return buffer_; } private: + uint8_t handshake_type_; DataBuffer buffer_; }; @@ -253,39 +209,17 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { public: TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type, const DataBuffer& replacement) - : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {} + : handshake_type_(handshake_type), buffer_(replacement) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: + uint8_t handshake_type_; DataBuffer buffer_; }; -// Make a copy of each record of a given type. -class TlsRecordRecorder : public TlsRecordFilter { - public: - TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {} - TlsRecordRecorder() - : filter_(false), - ct_(content_handshake), // dummy (<optional> is C++14) - records_() {} - virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, - const DataBuffer& input, - DataBuffer* output); - - size_t count() const { return records_.size(); } - void Clear() { records_.clear(); } - - const TlsRecord& record(size_t i) const { return records_[i]; } - - private: - bool filter_; - uint8_t ct_; - std::vector<TlsRecord> records_; -}; - // Make a copy of the complete conversation. class TlsConversationRecorder : public TlsRecordFilter { public: @@ -296,31 +230,15 @@ class TlsConversationRecorder : public TlsRecordFilter { DataBuffer* output); private: - DataBuffer buffer_; + DataBuffer& buffer_; }; -// Make a copy of the records -class TlsHeaderRecorder : public TlsRecordFilter { - public: - virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, - const DataBuffer& input, - DataBuffer* output); - const TlsRecordHeader* header(size_t index); - - private: - std::vector<TlsRecordHeader> headers_; -}; - -typedef std::initializer_list<std::shared_ptr<PacketFilter>> - ChainedPacketFilterInit; - // Runs multiple packet filters in series. class ChainedPacketFilter : public PacketFilter { public: ChainedPacketFilter() {} ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters) : filters_(filters.begin(), filters.end()) {} - ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {} virtual ~ChainedPacketFilter() {} virtual PacketFilter::Action Filter(const DataBuffer& input, @@ -338,13 +256,13 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)> class TlsExtensionFilter : public TlsHandshakeFilter { public: - TlsExtensionFilter() - : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello, - kTlsHandshakeHelloRetryRequest, - kTlsHandshakeEncryptedExtensions}) {} + TlsExtensionFilter() : handshake_types_() { + handshake_types_.insert(kTlsHandshakeClientHello); + handshake_types_.insert(kTlsHandshakeServerHello); + } TlsExtensionFilter(const std::set<uint8_t>& types) - : TlsHandshakeFilter(types) {} + : handshake_types_(types) {} static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); @@ -361,6 +279,8 @@ class TlsExtensionFilter : public TlsHandshakeFilter { PacketFilter::Action FilterExtensions(TlsParser* parser, const DataBuffer& input, DataBuffer* output); + + std::set<uint8_t> handshake_types_; }; class TlsExtensionCapture : public TlsExtensionFilter { @@ -406,21 +326,6 @@ class TlsExtensionDropper : public TlsExtensionFilter { uint16_t extension_; }; -class TlsExtensionInjector : public TlsHandshakeFilter { - public: - TlsExtensionInjector(uint16_t ext, const DataBuffer& data) - : extension_(ext), data_(data) {} - - protected: - PacketFilter::Action FilterHandshake(const HandshakeHeader& header, - const DataBuffer& input, - DataBuffer* output) override; - - private: - const uint16_t extension_; - const DataBuffer data_; -}; - class TlsAgent; typedef std::function<void(void)> VoidFunction; @@ -447,7 +352,7 @@ class AfterRecordN : public TlsRecordFilter { class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter { public: TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server) - : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {} + : server_(server) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -472,47 +377,10 @@ class SelectiveDropFilter : public PacketFilter { uint8_t counter_; }; -// This class selectively drops complete records. The difference from -// SelectiveDropFilter is that if multiple DTLS records are in the same -// datagram, we just drop one. -class SelectiveRecordDropFilter : public TlsRecordFilter { - public: - SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true) - : pattern_(pattern), counter_(0) { - if (!enabled) { - Disable(); - } - } - SelectiveRecordDropFilter(std::initializer_list<size_t> records) - : SelectiveRecordDropFilter(ToPattern(records), true) {} - - void Reset(uint32_t pattern) { - counter_ = 0; - PacketFilter::Enable(); - pattern_ = pattern; - } - - void Reset(std::initializer_list<size_t> records) { - Reset(ToPattern(records)); - } - - protected: - PacketFilter::Action FilterRecord(const TlsRecordHeader& header, - const DataBuffer& data, - DataBuffer* changed) override; - - private: - static uint32_t ToPattern(std::initializer_list<size_t> records); - - uint32_t pattern_; - uint8_t counter_; -}; - // Set the version number in the ClientHello. class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { public: - TlsInspectorClientHelloVersionSetter(uint16_t version) - : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {} + TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -543,20 +411,6 @@ class TlsLastByteDamager : public TlsHandshakeFilter { uint8_t type_; }; -class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { - public: - SelectedCipherSuiteReplacer(uint16_t suite) - : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {} - - protected: - PacketFilter::Action FilterHandshake(const HandshakeHeader& header, - const DataBuffer& input, - DataBuffer* output) override; - - private: - uint16_t cipher_suite_; -}; - } // namespace nss_test #endif |