diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_filter.h | 141 |
1 files changed, 93 insertions, 48 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h index 1db3b90f6..1bbe190ab 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.h +++ b/security/nss/gtests/ssl_gtest/tls_filter.h @@ -13,6 +13,7 @@ #include <vector> #include "test_io.h" +#include "tls_agent.h" #include "tls_parser.h" #include "tls_protect.h" @@ -23,7 +24,6 @@ extern "C" { namespace nss_test { class TlsCipherSpec; -class TlsAgent; class TlsVersioned { public: @@ -71,19 +71,27 @@ struct TlsRecord { const DataBuffer buffer; }; +// Make a filter and install it on a TlsAgent. +template <class T, typename... Args> +inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent, + Args&&... args) { + auto filter = std::make_shared<T>(agent, std::forward<Args>(args)...); + agent->SetFilter(filter); + return filter; +} + // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: - TlsRecordFilter() - : agent_(nullptr), + TlsRecordFilter(const std::shared_ptr<TlsAgent>& agent) + : agent_(agent), count_(0), cipher_spec_(), dropped_record_(false), in_sequence_number_(0), out_sequence_number_(0) {} - void SetAgent(const TlsAgent* agent) { agent_ = agent; } - const TlsAgent* agent() const { return agent_; } + std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); } // External interface. Overrides PacketFilter. PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output); @@ -126,7 +134,7 @@ class TlsRecordFilter : public PacketFilter { static void CipherSpecChanged(void* arg, PRBool sending, ssl3CipherSpec* newSpec); - const TlsAgent* agent_; + std::weak_ptr<TlsAgent> agent_; size_t count_; std::unique_ptr<TlsCipherSpec> cipher_spec_; // Whether we dropped a record since the cipher spec changed. @@ -175,9 +183,13 @@ inline std::ostream& operator<<(std::ostream& stream, // records and that they don't span records or anything crazy like that. class TlsHandshakeFilter : public TlsRecordFilter { public: - TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {} - TlsHandshakeFilter(const std::set<uint8_t>& types) - : handshake_types_(types), preceding_fragment_() {} + TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent), handshake_types_(), preceding_fragment_() {} + TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent, + const std::set<uint8_t>& types) + : TlsRecordFilter(agent), + handshake_types_(types), + preceding_fragment_() {} // This filter can be set to be selective based on handshake message type. If // this function isn't used (or the set is empty), then all handshake messages @@ -229,12 +241,14 @@ class TlsHandshakeFilter : public TlsRecordFilter { }; // Make a copy of the first instance of a handshake message. -class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { +class TlsHandshakeRecorder : public TlsHandshakeFilter { public: - TlsInspectorRecordHandshakeMessage(uint8_t handshake_type) - : TlsHandshakeFilter({handshake_type}), buffer_() {} - TlsInspectorRecordHandshakeMessage(const std::set<uint8_t>& handshake_types) - : TlsHandshakeFilter(handshake_types), buffer_() {} + TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type) + : TlsHandshakeFilter(agent, {handshake_type}), buffer_() {} + TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent, + const std::set<uint8_t>& handshake_types) + : TlsHandshakeFilter(agent, handshake_types), buffer_() {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -251,9 +265,10 @@ class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { // Replace all instances of a handshake message. class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { public: - TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type, + TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& agent, + uint8_t handshake_type, const DataBuffer& replacement) - : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {} + : TlsHandshakeFilter(agent, {handshake_type}), buffer_(replacement) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -266,9 +281,11 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { // Make a copy of each record of a given type. class TlsRecordRecorder : public TlsRecordFilter { public: - TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {} - TlsRecordRecorder() - : filter_(false), + TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent, uint8_t ct) + : TlsRecordFilter(agent), filter_(true), ct_(ct), records_() {} + TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent), + filter_(false), ct_(content_handshake), // dummy (<optional> is C++14) records_() {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, @@ -289,7 +306,9 @@ class TlsRecordRecorder : public TlsRecordFilter { // Make a copy of the complete conversation. class TlsConversationRecorder : public TlsRecordFilter { public: - TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {} + TlsConversationRecorder(const std::shared_ptr<TlsAgent>& agent, + DataBuffer& buffer) + : TlsRecordFilter(agent), buffer_(buffer) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, @@ -302,6 +321,8 @@ class TlsConversationRecorder : public TlsRecordFilter { // Make a copy of the records class TlsHeaderRecorder : public TlsRecordFilter { public: + TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); @@ -338,13 +359,15 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)> class TlsExtensionFilter : public TlsHandshakeFilter { public: - TlsExtensionFilter() - : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello, + TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent) + : TlsHandshakeFilter(agent, + {kTlsHandshakeClientHello, kTlsHandshakeServerHello, kTlsHandshakeHelloRetryRequest, kTlsHandshakeEncryptedExtensions}) {} - TlsExtensionFilter(const std::set<uint8_t>& types) - : TlsHandshakeFilter(types) {} + TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent, + const std::set<uint8_t>& types) + : TlsHandshakeFilter(agent, types) {} static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); @@ -365,8 +388,13 @@ class TlsExtensionFilter : public TlsHandshakeFilter { class TlsExtensionCapture : public TlsExtensionFilter { public: - TlsExtensionCapture(uint16_t ext, bool last = false) - : extension_(ext), captured_(false), last_(last), data_() {} + TlsExtensionCapture(const std::shared_ptr<TlsAgent>& agent, uint16_t ext, + bool last = false) + : TlsExtensionFilter(agent), + extension_(ext), + captured_(false), + last_(last), + data_() {} const DataBuffer& extension() const { return data_; } bool captured() const { return captured_; } @@ -385,8 +413,9 @@ class TlsExtensionCapture : public TlsExtensionFilter { class TlsExtensionReplacer : public TlsExtensionFilter { public: - TlsExtensionReplacer(uint16_t extension, const DataBuffer& data) - : extension_(extension), data_(data) {} + TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& agent, + uint16_t extension, const DataBuffer& data) + : TlsExtensionFilter(agent), extension_(extension), data_(data) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) override; @@ -398,7 +427,9 @@ class TlsExtensionReplacer : public TlsExtensionFilter { class TlsExtensionDropper : public TlsExtensionFilter { public: - TlsExtensionDropper(uint16_t extension) : extension_(extension) {} + TlsExtensionDropper(const std::shared_ptr<TlsAgent>& agent, + uint16_t extension) + : TlsExtensionFilter(agent), extension_(extension) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer&, DataBuffer*) override; @@ -408,8 +439,9 @@ class TlsExtensionDropper : public TlsExtensionFilter { class TlsExtensionInjector : public TlsHandshakeFilter { public: - TlsExtensionInjector(uint16_t ext, const DataBuffer& data) - : extension_(ext), data_(data) {} + TlsExtensionInjector(const std::shared_ptr<TlsAgent>& agent, uint16_t ext, + const DataBuffer& data) + : TlsHandshakeFilter(agent), extension_(ext), data_(data) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, @@ -426,16 +458,20 @@ typedef std::function<void(void)> VoidFunction; class AfterRecordN : public TlsRecordFilter { public: - AfterRecordN(std::shared_ptr<TlsAgent>& src, std::shared_ptr<TlsAgent>& dest, - unsigned int record, VoidFunction func) - : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {} + AfterRecordN(const std::shared_ptr<TlsAgent>& src, + const std::shared_ptr<TlsAgent>& dest, unsigned int record, + VoidFunction func) + : TlsRecordFilter(src), + dest_(dest), + record_(record), + func_(func), + counter_(0) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& body, DataBuffer* out) override; private: - std::weak_ptr<TlsAgent> src_; std::weak_ptr<TlsAgent> dest_; unsigned int record_; VoidFunction func_; @@ -444,10 +480,12 @@ class AfterRecordN : public TlsRecordFilter { // When we see the ClientKeyExchange from |client|, increment the // ClientHelloVersion on |server|. -class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter { +class TlsClientHelloVersionChanger : public TlsHandshakeFilter { public: - TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server) - : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {} + TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client, + const std::shared_ptr<TlsAgent>& server) + : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}), + server_(server) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -477,14 +515,16 @@ class SelectiveDropFilter : public PacketFilter { // datagram, we just drop one. class SelectiveRecordDropFilter : public TlsRecordFilter { public: - SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true) - : pattern_(pattern), counter_(0) { + SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent, + uint32_t pattern, bool enabled = true) + : TlsRecordFilter(agent), pattern_(pattern), counter_(0) { if (!enabled) { Disable(); } } - SelectiveRecordDropFilter(std::initializer_list<size_t> records) - : SelectiveRecordDropFilter(ToPattern(records), true) {} + SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent, + std::initializer_list<size_t> records) + : SelectiveRecordDropFilter(agent, ToPattern(records), true) {} void Reset(uint32_t pattern) { counter_ = 0; @@ -509,10 +549,12 @@ class SelectiveRecordDropFilter : public TlsRecordFilter { }; // Set the version number in the ClientHello. -class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { +class TlsClientHelloVersionSetter : public TlsHandshakeFilter { public: - TlsInspectorClientHelloVersionSetter(uint16_t version) - : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {} + TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& agent, + uint16_t version) + : TlsHandshakeFilter(agent, {kTlsHandshakeClientHello}), + version_(version) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -525,7 +567,8 @@ class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter { // Damages the last byte of a handshake message. class TlsLastByteDamager : public TlsHandshakeFilter { public: - TlsLastByteDamager(uint8_t type) : type_(type) {} + TlsLastByteDamager(const std::shared_ptr<TlsAgent>& agent, uint8_t type) + : TlsHandshakeFilter(agent), type_(type) {} PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { @@ -545,8 +588,10 @@ class TlsLastByteDamager : public TlsHandshakeFilter { class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { public: - SelectedCipherSuiteReplacer(uint16_t suite) - : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {} + SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& agent, + uint16_t suite) + : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}), + cipher_suite_(suite) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, |