/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ /* vim: set ts=2 et sw=2 tw=80: */ /* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at http://mozilla.org/MPL/2.0/. */ #ifndef tls_filter_h_ #define tls_filter_h_ #include #include #include #include #include "test_io.h" #include "tls_parser.h" #include "tls_protect.h" extern "C" { #include "libssl_internals.h" } namespace nss_test { class TlsCipherSpec; class TlsAgent; class TlsVersioned { public: TlsVersioned() : version_(0) {} explicit TlsVersioned(uint16_t version) : version_(version) {} bool is_dtls() const { return IsDtls(version_); } uint16_t version() const { return version_; } void WriteStream(std::ostream& stream) const; protected: uint16_t version_; }; class TlsRecordHeader : public TlsVersioned { public: TlsRecordHeader() : TlsVersioned(), content_type_(0), sequence_number_(0) {} TlsRecordHeader(uint16_t version, uint8_t content_type, uint64_t sequence_number) : TlsVersioned(version), content_type_(content_type), sequence_number_(sequence_number) {} uint8_t content_type() const { return content_type_; } uint64_t sequence_number() const { return sequence_number_; } uint16_t epoch() const { return static_cast(sequence_number_ >> 48); } size_t header_length() const { return is_dtls() ? 13 : 5; } // Parse the header; return true if successful; body in an outparam if OK. bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body); // Write the header and body to a buffer at the given offset. // Return the offset of the end of the write. size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; private: uint8_t content_type_; uint64_t sequence_number_; }; struct TlsRecord { const TlsRecordHeader header; const DataBuffer buffer; }; // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_(), dropped_record_(false), in_sequence_number_(0), out_sequence_number_(0) {} void SetAgent(const TlsAgent* agent) { agent_ = agent; } const TlsAgent* agent() const { return agent_; } // External interface. Overrides PacketFilter. PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); // Report how many packets were altered by the filter. size_t filtered_packets() const { return count_; } // Enable decryption. This only works properly for TLS 1.3 and above. // Enabling it for lower version tests will cause undefined // behavior. void EnableDecryption(); bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText, uint8_t* inner_content_type, DataBuffer* plaintext); bool Protect(const TlsRecordHeader& header, uint8_t inner_content_type, const DataBuffer& plaintext, DataBuffer* ciphertext); protected: // There are two filter functions which can be overriden. Both are // called with the header and the record but the outer one is called // with a raw pointer to let you write into the buffer and lets you // do anything with this section of the stream. The inner one // just lets you change the record contents. By default, the // outer one calls the inner one, so if you override the outer // one, the inner one is never called unless you call it yourself. virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, DataBuffer* output); // The record filter receives the record contentType, version and DTLS // sequence number (which is zero for TLS), plus the existing record payload. // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed` // outparam with the new record contents if it chooses to CHANGE the record. virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& data, DataBuffer* changed) { return KEEP; } private: static void CipherSpecChanged(void* arg, PRBool sending, ssl3CipherSpec* newSpec); const TlsAgent* agent_; size_t count_; std::unique_ptr cipher_spec_; // Whether we dropped a record since the cipher spec changed. bool dropped_record_; // The sequence number we use for reading records as they are written. uint64_t in_sequence_number_; // The sequence number we use for writing modified records. uint64_t out_sequence_number_; }; inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) { v.WriteStream(stream); return stream; } inline std::ostream& operator<<(std::ostream& stream, const TlsRecordHeader& hdr) { hdr.WriteStream(stream); stream << ' '; switch (hdr.content_type()) { case kTlsChangeCipherSpecType: stream << "CCS"; break; case kTlsAlertType: stream << "Alert"; break; case kTlsHandshakeType: case kTlsAltHandshakeType: stream << "Handshake"; break; case kTlsApplicationDataType: stream << "Data"; break; case kTlsAckType: stream << "ACK"; break; default: stream << '<' << static_cast(hdr.content_type()) << '>'; break; } return stream << ' ' << std::hex << hdr.sequence_number() << std::dec; } // Abstract filter that operates on handshake messages rather than records. // This assumes that the handshake messages are written in a block as entire // records and that they don't span records or anything crazy like that. class TlsHandshakeFilter : public TlsRecordFilter { public: TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {} TlsHandshakeFilter(const std::set& types) : handshake_types_(types), preceding_fragment_() {} // This filter can be set to be selective based on handshake message type. If // this function isn't used (or the set is empty), then all handshake messages // will be filtered. void SetHandshakeTypes(const std::set& types) { handshake_types_ = types; } class HandshakeHeader : public TlsVersioned { public: HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {} uint8_t handshake_type() const { return handshake_type_; } bool Parse(TlsParser* parser, const TlsRecordHeader& record_header, const DataBuffer& preceding_fragment, DataBuffer* body, bool* complete); size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; size_t WriteFragment(DataBuffer* buffer, size_t offset, const DataBuffer& body, size_t fragment_offset, size_t fragment_length) const; private: // Reads the length from the record header. // This also reads the DTLS fragment information and checks it. bool ReadLength(TlsParser* parser, const TlsRecordHeader& header, uint32_t expected_offset, uint32_t* length, bool* last_fragment); uint8_t handshake_type_; uint16_t message_seq_; // fragment_offset is always zero in these tests. }; protected: virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) = 0; private: bool IsFilteredType(const HandshakeHeader& header, const DataBuffer& handshake); std::set handshake_types_; DataBuffer preceding_fragment_; }; // Make a copy of the first instance of a handshake message. class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { public: TlsInspectorRecordHandshakeMessage(uint8_t handshake_type) : TlsHandshakeFilter({handshake_type}), buffer_() {} TlsInspectorRecordHandshakeMessage(const std::set& handshake_types) : TlsHandshakeFilter(handshake_types), buffer_() {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); void Reset() { buffer_.Truncate(0); } const DataBuffer& buffer() const { return buffer_; } private: DataBuffer buffer_; }; // Replace all instances of a handshake message. class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { public: TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type, const DataBuffer& replacement) : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: DataBuffer buffer_; }; // Make a copy of each record of a given type. class TlsRecordRecorder : public TlsRecordFilter { public: TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {} TlsRecordRecorder() : filter_(false), ct_(content_handshake), // dummy ( is C++14) records_() {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); size_t count() const { return records_.size(); } void Clear() { records_.clear(); } const TlsRecord& record(size_t i) const { return records_[i]; } private: bool filter_; uint8_t ct_; std::vector records_; }; // Make a copy of the complete conversation. class TlsConversationRecorder : public TlsRecordFilter { public: TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); private: DataBuffer buffer_; }; // Make a copy of the records class TlsHeaderRecorder : public TlsRecordFilter { public: virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); const TlsRecordHeader* header(size_t index); private: std::vector headers_; }; typedef std::initializer_list> ChainedPacketFilterInit; // Runs multiple packet filters in series. class ChainedPacketFilter : public PacketFilter { public: ChainedPacketFilter() {} ChainedPacketFilter(const std::vector> filters) : filters_(filters.begin(), filters.end()) {} ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {} virtual ~ChainedPacketFilter() {} virtual PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); // Takes ownership of the filter. void Add(std::shared_ptr filter) { filters_.push_back(filter); } private: std::vector> filters_; }; typedef std::function TlsExtensionFinder; class TlsExtensionFilter : public TlsHandshakeFilter { public: TlsExtensionFilter() : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello, kTlsHandshakeHelloRetryRequest, kTlsHandshakeEncryptedExtensions}) {} TlsExtensionFilter(const std::set& types) : TlsHandshakeFilter(types) {} static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override; virtual PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) = 0; private: PacketFilter::Action FilterExtensions(TlsParser* parser, const DataBuffer& input, DataBuffer* output); }; class TlsExtensionCapture : public TlsExtensionFilter { public: TlsExtensionCapture(uint16_t ext, bool last = false) : extension_(ext), captured_(false), last_(last), data_() {} const DataBuffer& extension() const { return data_; } bool captured() const { return captured_; } protected: PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) override; private: const uint16_t extension_; bool captured_; bool last_; DataBuffer data_; }; class TlsExtensionReplacer : public TlsExtensionFilter { public: TlsExtensionReplacer(uint16_t extension, const DataBuffer& data) : extension_(extension), data_(data) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) override; private: const uint16_t extension_; const DataBuffer data_; }; class TlsExtensionDropper : public TlsExtensionFilter { public: TlsExtensionDropper(uint16_t extension) : extension_(extension) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer&, DataBuffer*) override; private: uint16_t extension_; }; class TlsExtensionInjector : public TlsHandshakeFilter { public: TlsExtensionInjector(uint16_t ext, const DataBuffer& data) : extension_(ext), data_(data) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override; private: const uint16_t extension_; const DataBuffer data_; }; class TlsAgent; typedef std::function VoidFunction; class AfterRecordN : public TlsRecordFilter { public: AfterRecordN(std::shared_ptr& src, std::shared_ptr& dest, unsigned int record, VoidFunction func) : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) override; private: std::weak_ptr src_; std::weak_ptr dest_; unsigned int record_; VoidFunction func_; unsigned int counter_; }; // When we see the ClientKeyExchange from |client|, increment the // ClientHelloVersion on |server|. class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter { public: TlsInspectorClientHelloVersionChanger(std::shared_ptr& server) : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: std::weak_ptr server_; }; // This class selectively drops complete writes. This relies on the fact that // writes in libssl are on record boundaries. class SelectiveDropFilter : public PacketFilter { public: SelectiveDropFilter(uint32_t pattern) : pattern_(pattern), counter_(0) {} protected: virtual PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output) override; private: const uint32_t pattern_; uint8_t counter_; }; // This class selectively drops complete records. The difference from // SelectiveDropFilter is that if multiple DTLS records are in the same // datagram, we just drop one. class SelectiveRecordDropFilter : public TlsRecordFilter { public: SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true) : pattern_(pattern), counter_(0) { if (!enabled) { Disable(); } } SelectiveRecordDropFilter(std::initializer_list records) : SelectiveRecordDropFilter(ToPattern(records), true) {} void Reset(uint32_t pattern) { counter_ = 0; PacketFilter::Enable(); pattern_ = pattern; } void Reset(std::initializer_list records) { Reset(ToPattern(records)); } protected: PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& data, DataBuffer* changed) override; private: static uint32_t ToPattern(std::initializer_list records); uint32_t pattern_; uint8_t counter_; }; // Set the version number in the ClientHello. class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { public: TlsInspectorClientHelloVersionSetter(uint16_t version) : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: uint16_t version_; }; // Damages the last byte of a handshake message. class TlsLastByteDamager : public TlsHandshakeFilter { public: TlsLastByteDamager(uint8_t type) : type_(type) {} PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { if (header.handshake_type() != type_) { return KEEP; } *output = input; output->data()[output->len() - 1]++; return CHANGE; } private: uint8_t type_; }; class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { public: SelectedCipherSuiteReplacer(uint16_t suite) : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override; private: uint16_t cipher_suite_; }; } // namespace nss_test #endif