diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_filter.h | 231 |
1 files changed, 152 insertions, 79 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h index fa2e38785..e4030e23f 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.h +++ b/security/nss/gtests/ssl_gtest/tls_filter.h @@ -9,17 +9,67 @@ #include <functional> #include <memory> +#include <set> #include <vector> #include "test_io.h" #include "tls_parser.h" +#include "tls_protect.h" + +extern "C" { +#include "libssl_internals.h" +} namespace nss_test { +class TlsCipherSpec; +class TlsAgent; + +class TlsVersioned { + public: + TlsVersioned() : version_(0) {} + explicit TlsVersioned(uint16_t version) : version_(version) {} + + bool is_dtls() const { return IsDtls(version_); } + uint16_t version() const { return version_; } + + void WriteStream(std::ostream& stream) const; + + protected: + uint16_t version_; +}; + +class TlsRecordHeader : public TlsVersioned { + public: + TlsRecordHeader() : TlsVersioned(), content_type_(0), sequence_number_(0) {} + TlsRecordHeader(uint16_t version, uint8_t content_type, + uint64_t sequence_number) + : TlsVersioned(version), + content_type_(content_type), + sequence_number_(sequence_number) {} + + uint8_t content_type() const { return content_type_; } + uint64_t sequence_number() const { return sequence_number_; } + size_t header_length() const { return is_dtls() ? 11 : 3; } + + // Parse the header; return true if successful; body in an outparam if OK. + bool Parse(TlsParser* parser, DataBuffer* body); + // Write the header and body to a buffer at the given offset. + // Return the offset of the end of the write. + size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; + + private: + uint8_t content_type_; + uint64_t sequence_number_; +}; + // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: - TlsRecordFilter() : count_(0) {} + TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {} + + void SetAgent(const TlsAgent* agent) { agent_ = agent; } + const TlsAgent* agent() const { return agent_; } // External interface. Overrides PacketFilter. PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); @@ -27,42 +77,14 @@ class TlsRecordFilter : public PacketFilter { // Report how many packets were altered by the filter. size_t filtered_packets() const { return count_; } - class Versioned { - public: - Versioned() : version_(0) {} - explicit Versioned(uint16_t version) : version_(version) {} - - bool is_dtls() const { return IsDtls(version_); } - uint16_t version() const { return version_; } - - protected: - uint16_t version_; - }; - - class RecordHeader : public Versioned { - public: - RecordHeader() : Versioned(), content_type_(0), sequence_number_(0) {} - RecordHeader(uint16_t version, uint8_t content_type, - uint64_t sequence_number) - : Versioned(version), - content_type_(content_type), - sequence_number_(sequence_number) {} - - uint8_t content_type() const { return content_type_; } - uint64_t sequence_number() const { return sequence_number_; } - size_t header_length() const { return is_dtls() ? 11 : 3; } - - // Parse the header; return true if successful; body in an outparam if OK. - bool Parse(TlsParser* parser, DataBuffer* body); - // Write the header and body to a buffer at the given offset. - // Return the offset of the end of the write. - size_t Write(DataBuffer* buffer, size_t offset, - const DataBuffer& body) const; - - private: - uint8_t content_type_; - uint64_t sequence_number_; - }; + // Enable decryption. This only works properly for TLS 1.3 and above. + // Enabling it for lower version tests will cause undefined + // behavior. + void EnableDecryption(); + bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText, + uint8_t* inner_content_type, DataBuffer* plaintext); + bool Protect(const TlsRecordHeader& header, uint8_t inner_content_type, + const DataBuffer& plaintext, DataBuffer* ciphertext); protected: // There are two filter functions which can be overriden. Both are @@ -72,7 +94,7 @@ class TlsRecordFilter : public PacketFilter { // just lets you change the record contents. By default, the // outer one calls the inner one, so if you override the outer // one, the inner one is never called unless you call it yourself. - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& record, size_t* offset, DataBuffer* output); @@ -80,16 +102,49 @@ class TlsRecordFilter : public PacketFilter { // sequence number (which is zero for TLS), plus the existing record payload. // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed` // outparam with the new record contents if it chooses to CHANGE the record. - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& data, DataBuffer* changed) { return KEEP; } private: + static void CipherSpecChanged(void* arg, PRBool sending, + ssl3CipherSpec* newSpec); + + const TlsAgent* agent_; size_t count_; + std::unique_ptr<TlsCipherSpec> cipher_spec_; }; +inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) { + v.WriteStream(stream); + return stream; +} + +inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) { + hdr.WriteStream(stream); + stream << ' '; + switch (hdr.content_type()) { + case kTlsChangeCipherSpecType: + stream << "CCS"; + break; + case kTlsAlertType: + stream << "Alert"; + break; + case kTlsHandshakeType: + stream << "Handshake"; + break; + case kTlsApplicationDataType: + stream << "Data"; + break; + default: + stream << '<' << hdr.content_type() << '>'; + break; + } + return stream << ' ' << std::hex << hdr.sequence_number() << std::dec; +} + // Abstract filter that operates on handshake messages rather than records. // This assumes that the handshake messages are written in a block as entire // records and that they don't span records or anything crazy like that. @@ -97,20 +152,23 @@ class TlsHandshakeFilter : public TlsRecordFilter { public: TlsHandshakeFilter() {} - class HandshakeHeader : public Versioned { + class HandshakeHeader : public TlsVersioned { public: - HandshakeHeader() : Versioned(), handshake_type_(0), message_seq_(0) {} + HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {} uint8_t handshake_type() const { return handshake_type_; } - bool Parse(TlsParser* parser, const RecordHeader& record_header, + bool Parse(TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body); size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; + size_t WriteFragment(DataBuffer* buffer, size_t offset, + const DataBuffer& body, size_t fragment_offset, + size_t fragment_length) const; private: // Reads the length from the record header. // This also reads the DTLS fragment information and checks it. - bool ReadLength(TlsParser* parser, const RecordHeader& header, + bool ReadLength(TlsParser* parser, const TlsRecordHeader& header, uint32_t* length); uint8_t handshake_type_; @@ -119,7 +177,7 @@ class TlsHandshakeFilter : public TlsRecordFilter { }; protected: - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, @@ -167,7 +225,7 @@ class TlsConversationRecorder : public TlsRecordFilter { public: TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {} - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); @@ -175,43 +233,39 @@ class TlsConversationRecorder : public TlsRecordFilter { DataBuffer& buffer_; }; -// Records an alert. If an alert has already been recorded, it won't save the -// new alert unless the old alert is a warning and the new one is fatal. -class TlsAlertRecorder : public TlsRecordFilter { - public: - TlsAlertRecorder() : level_(255), description_(255) {} - - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, - const DataBuffer& input, - DataBuffer* output); - - uint8_t level() const { return level_; } - uint8_t description() const { return description_; } - - private: - uint8_t level_; - uint8_t description_; -}; - // Runs multiple packet filters in series. class ChainedPacketFilter : public PacketFilter { public: ChainedPacketFilter() {} - ChainedPacketFilter(const std::vector<PacketFilter*> filters) + ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters) : filters_(filters.begin(), filters.end()) {} - virtual ~ChainedPacketFilter(); + virtual ~ChainedPacketFilter() {} virtual PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); // Takes ownership of the filter. - void Add(PacketFilter* filter) { filters_.push_back(filter); } + void Add(std::shared_ptr<PacketFilter> filter) { filters_.push_back(filter); } private: - std::vector<PacketFilter*> filters_; + std::vector<std::shared_ptr<PacketFilter>> filters_; }; +typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)> + TlsExtensionFinder; + class TlsExtensionFilter : public TlsHandshakeFilter { + public: + TlsExtensionFilter() : handshake_types_() { + handshake_types_.insert(kTlsHandshakeClientHello); + handshake_types_.insert(kTlsHandshakeServerHello); + } + + TlsExtensionFilter(const std::set<uint8_t>& types) + : handshake_types_(types) {} + + static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); + protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -221,15 +275,12 @@ class TlsExtensionFilter : public TlsHandshakeFilter { const DataBuffer& input, DataBuffer* output) = 0; - public: - static bool FindClientHelloExtensions(TlsParser* parser, - const Versioned& header); - static bool FindServerHelloExtensions(TlsParser* parser); - private: PacketFilter::Action FilterExtensions(TlsParser* parser, const DataBuffer& input, DataBuffer* output); + + std::set<uint8_t> handshake_types_; }; class TlsExtensionCapture : public TlsExtensionFilter { @@ -280,17 +331,17 @@ typedef std::function<void(void)> VoidFunction; class AfterRecordN : public TlsRecordFilter { public: - AfterRecordN(TlsAgent* src, TlsAgent* dest, unsigned int record, - VoidFunction func) + AfterRecordN(std::shared_ptr<TlsAgent>& src, std::shared_ptr<TlsAgent>& dest, + unsigned int record, VoidFunction func) : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {} - virtual PacketFilter::Action FilterRecord(const RecordHeader& header, + virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) override; private: - TlsAgent* src_; - TlsAgent* dest_; + std::weak_ptr<TlsAgent> src_; + std::weak_ptr<TlsAgent> dest_; unsigned int record_; VoidFunction func_; unsigned int counter_; @@ -300,14 +351,15 @@ class AfterRecordN : public TlsRecordFilter { // ClientHelloVersion on |server|. class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter { public: - TlsInspectorClientHelloVersionChanger(TlsAgent* server) : server_(server) {} + TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server) + : server_(server) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output); private: - TlsAgent* server_; + std::weak_ptr<TlsAgent> server_; }; // This class selectively drops complete writes. This relies on the fact that @@ -338,6 +390,27 @@ class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { uint16_t version_; }; +// Damages the last byte of a handshake message. +class TlsLastByteDamager : public TlsHandshakeFilter { + public: + TlsLastByteDamager(uint8_t type) : type_(type) {} + PacketFilter::Action FilterHandshake( + const TlsHandshakeFilter::HandshakeHeader& header, + const DataBuffer& input, DataBuffer* output) override { + if (header.handshake_type() != type_) { + return KEEP; + } + + *output = input; + + output->data()[output->len() - 1]++; + return CHANGE; + } + + private: + uint8_t type_; +}; + } // namespace nss_test #endif |