diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_filter.h | 149 |
1 files changed, 88 insertions, 61 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h index 1bbe190ab..effda4aa0 100644 --- a/security/nss/gtests/ssl_gtest/tls_filter.h +++ b/security/nss/gtests/ssl_gtest/tls_filter.h @@ -11,7 +11,7 @@ #include <memory> #include <set> #include <vector> - +#include "sslt.h" #include "test_io.h" #include "tls_agent.h" #include "tls_parser.h" @@ -27,43 +27,57 @@ class TlsCipherSpec; class TlsVersioned { public: - TlsVersioned() : version_(0) {} - explicit TlsVersioned(uint16_t version) : version_(version) {} + TlsVersioned() : variant_(ssl_variant_stream), version_(0) {} + TlsVersioned(SSLProtocolVariant var, uint16_t ver) + : variant_(var), version_(ver) {} - bool is_dtls() const { return IsDtls(version_); } + bool is_dtls() const { return variant_ == ssl_variant_datagram; } + SSLProtocolVariant variant() const { return variant_; } uint16_t version() const { return version_; } void WriteStream(std::ostream& stream) const; protected: + SSLProtocolVariant variant_; uint16_t version_; }; class TlsRecordHeader : public TlsVersioned { public: - TlsRecordHeader() : TlsVersioned(), content_type_(0), sequence_number_(0) {} - TlsRecordHeader(uint16_t version, uint8_t content_type, - uint64_t sequence_number) - : TlsVersioned(version), - content_type_(content_type), - sequence_number_(sequence_number) {} + TlsRecordHeader() + : TlsVersioned(), content_type_(0), sequence_number_(0), header_() {} + TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct, + uint64_t seqno) + : TlsVersioned(var, ver), + content_type_(ct), + sequence_number_(seqno), + header_() {} uint8_t content_type() const { return content_type_; } uint64_t sequence_number() const { return sequence_number_; } uint16_t epoch() const { return static_cast<uint16_t>(sequence_number_ >> 48); } - size_t header_length() const { return is_dtls() ? 13 : 5; } + size_t header_length() const; + const DataBuffer& header() const { return header_; } // Parse the header; return true if successful; body in an outparam if OK. - bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body); + bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser, + DataBuffer* body); // Write the header and body to a buffer at the given offset. // Return the offset of the end of the write. size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const; + size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const; private: + static uint64_t RecoverSequenceNumber(uint64_t expected, uint32_t partial, + size_t partial_bits); + static uint64_t ParseSequenceNumber(uint64_t expected, uint32_t raw, + size_t seq_no_bits, size_t epoch_bits); + uint8_t content_type_; uint64_t sequence_number_; + DataBuffer header_; }; struct TlsRecord { @@ -83,8 +97,8 @@ inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent, // Abstract filter that operates on entire (D)TLS records. class TlsRecordFilter : public PacketFilter { public: - TlsRecordFilter(const std::shared_ptr<TlsAgent>& agent) - : agent_(agent), + TlsRecordFilter(const std::shared_ptr<TlsAgent>& a) + : agent_(a), count_(0), cipher_spec_(), dropped_record_(false), @@ -106,7 +120,8 @@ class TlsRecordFilter : public PacketFilter { bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText, uint8_t* inner_content_type, DataBuffer* plaintext); bool Protect(const TlsRecordHeader& header, uint8_t inner_content_type, - const DataBuffer& plaintext, DataBuffer* ciphertext); + const DataBuffer& plaintext, DataBuffer* ciphertext, + size_t padding = 0); protected: // There are two filter functions which can be overriden. Both are @@ -130,6 +145,8 @@ class TlsRecordFilter : public PacketFilter { return KEEP; } + bool is_dtls13() const; + private: static void CipherSpecChanged(void* arg, PRBool sending, ssl3CipherSpec* newSpec); @@ -183,13 +200,11 @@ inline std::ostream& operator<<(std::ostream& stream, // records and that they don't span records or anything crazy like that. class TlsHandshakeFilter : public TlsRecordFilter { public: - TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent) - : TlsRecordFilter(agent), handshake_types_(), preceding_fragment_() {} - TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent, + TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a), handshake_types_(), preceding_fragment_() {} + TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a, const std::set<uint8_t>& types) - : TlsRecordFilter(agent), - handshake_types_(types), - preceding_fragment_() {} + : TlsRecordFilter(a), handshake_types_(types), preceding_fragment_() {} // This filter can be set to be selective based on handshake message type. If // this function isn't used (or the set is empty), then all handshake messages @@ -243,12 +258,12 @@ class TlsHandshakeFilter : public TlsRecordFilter { // Make a copy of the first instance of a handshake message. class TlsHandshakeRecorder : public TlsHandshakeFilter { public: - TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent, + TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a, uint8_t handshake_type) - : TlsHandshakeFilter(agent, {handshake_type}), buffer_() {} - TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent, + : TlsHandshakeFilter(a, {handshake_type}), buffer_() {} + TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a, const std::set<uint8_t>& handshake_types) - : TlsHandshakeFilter(agent, handshake_types), buffer_() {} + : TlsHandshakeFilter(a, handshake_types), buffer_() {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -265,10 +280,10 @@ class TlsHandshakeRecorder : public TlsHandshakeFilter { // Replace all instances of a handshake message. class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { public: - TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& agent, + TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& a, uint8_t handshake_type, const DataBuffer& replacement) - : TlsHandshakeFilter(agent, {handshake_type}), buffer_(replacement) {} + : TlsHandshakeFilter(a, {handshake_type}), buffer_(replacement) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -281,10 +296,10 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { // Make a copy of each record of a given type. class TlsRecordRecorder : public TlsRecordFilter { public: - TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent, uint8_t ct) - : TlsRecordFilter(agent), filter_(true), ct_(ct), records_() {} - TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent) - : TlsRecordFilter(agent), + TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a, uint8_t ct) + : TlsRecordFilter(a), filter_(true), ct_(ct), records_() {} + TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a), filter_(false), ct_(content_handshake), // dummy (<optional> is C++14) records_() {} @@ -306,9 +321,9 @@ class TlsRecordRecorder : public TlsRecordFilter { // Make a copy of the complete conversation. class TlsConversationRecorder : public TlsRecordFilter { public: - TlsConversationRecorder(const std::shared_ptr<TlsAgent>& agent, + TlsConversationRecorder(const std::shared_ptr<TlsAgent>& a, DataBuffer& buffer) - : TlsRecordFilter(agent), buffer_(buffer) {} + : TlsRecordFilter(a), buffer_(buffer) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, @@ -321,8 +336,7 @@ class TlsConversationRecorder : public TlsRecordFilter { // Make a copy of the records class TlsHeaderRecorder : public TlsRecordFilter { public: - TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& agent) - : TlsRecordFilter(agent) {} + TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& a) : TlsRecordFilter(a) {} virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, DataBuffer* output); @@ -359,15 +373,15 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)> class TlsExtensionFilter : public TlsHandshakeFilter { public: - TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent) - : TlsHandshakeFilter(agent, + TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a) + : TlsHandshakeFilter(a, {kTlsHandshakeClientHello, kTlsHandshakeServerHello, kTlsHandshakeHelloRetryRequest, kTlsHandshakeEncryptedExtensions}) {} - TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent, + TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a, const std::set<uint8_t>& types) - : TlsHandshakeFilter(agent, types) {} + : TlsHandshakeFilter(a, types) {} static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header); @@ -388,9 +402,9 @@ class TlsExtensionFilter : public TlsHandshakeFilter { class TlsExtensionCapture : public TlsExtensionFilter { public: - TlsExtensionCapture(const std::shared_ptr<TlsAgent>& agent, uint16_t ext, + TlsExtensionCapture(const std::shared_ptr<TlsAgent>& a, uint16_t ext, bool last = false) - : TlsExtensionFilter(agent), + : TlsExtensionFilter(a), extension_(ext), captured_(false), last_(last), @@ -413,9 +427,9 @@ class TlsExtensionCapture : public TlsExtensionFilter { class TlsExtensionReplacer : public TlsExtensionFilter { public: - TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& agent, - uint16_t extension, const DataBuffer& data) - : TlsExtensionFilter(agent), extension_(extension), data_(data) {} + TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& a, uint16_t extension, + const DataBuffer& data) + : TlsExtensionFilter(a), extension_(extension), data_(data) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer& input, DataBuffer* output) override; @@ -427,9 +441,8 @@ class TlsExtensionReplacer : public TlsExtensionFilter { class TlsExtensionDropper : public TlsExtensionFilter { public: - TlsExtensionDropper(const std::shared_ptr<TlsAgent>& agent, - uint16_t extension) - : TlsExtensionFilter(agent), extension_(extension) {} + TlsExtensionDropper(const std::shared_ptr<TlsAgent>& a, uint16_t extension) + : TlsExtensionFilter(a), extension_(extension) {} PacketFilter::Action FilterExtension(uint16_t extension_type, const DataBuffer&, DataBuffer*) override; @@ -439,9 +452,9 @@ class TlsExtensionDropper : public TlsExtensionFilter { class TlsExtensionInjector : public TlsHandshakeFilter { public: - TlsExtensionInjector(const std::shared_ptr<TlsAgent>& agent, uint16_t ext, + TlsExtensionInjector(const std::shared_ptr<TlsAgent>& a, uint16_t ext, const DataBuffer& data) - : TlsHandshakeFilter(agent), extension_(ext), data_(data) {} + : TlsHandshakeFilter(a), extension_(ext), data_(data) {} protected: PacketFilter::Action FilterHandshake(const HandshakeHeader& header, @@ -453,7 +466,6 @@ class TlsExtensionInjector : public TlsHandshakeFilter { const DataBuffer data_; }; -class TlsAgent; typedef std::function<void(void)> VoidFunction; class AfterRecordN : public TlsRecordFilter { @@ -495,6 +507,22 @@ class TlsClientHelloVersionChanger : public TlsHandshakeFilter { std::weak_ptr<TlsAgent> server_; }; +// Damage a record. +class TlsRecordLastByteDamager : public TlsRecordFilter { + public: + TlsRecordLastByteDamager(const std::shared_ptr<TlsAgent>& a) + : TlsRecordFilter(a) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + *changed = data; + changed->data()[changed->len() - 1]++; + return CHANGE; + } +}; + // This class selectively drops complete writes. This relies on the fact that // writes in libssl are on record boundaries. class SelectiveDropFilter : public PacketFilter { @@ -515,16 +543,16 @@ class SelectiveDropFilter : public PacketFilter { // datagram, we just drop one. class SelectiveRecordDropFilter : public TlsRecordFilter { public: - SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent, + SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a, uint32_t pattern, bool enabled = true) - : TlsRecordFilter(agent), pattern_(pattern), counter_(0) { + : TlsRecordFilter(a), pattern_(pattern), counter_(0) { if (!enabled) { Disable(); } } - SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent, + SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a, std::initializer_list<size_t> records) - : SelectiveRecordDropFilter(agent, ToPattern(records), true) {} + : SelectiveRecordDropFilter(a, ToPattern(records), true) {} void Reset(uint32_t pattern) { counter_ = 0; @@ -551,10 +579,9 @@ class SelectiveRecordDropFilter : public TlsRecordFilter { // Set the version number in the ClientHello. class TlsClientHelloVersionSetter : public TlsHandshakeFilter { public: - TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& agent, + TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& a, uint16_t version) - : TlsHandshakeFilter(agent, {kTlsHandshakeClientHello}), - version_(version) {} + : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}), version_(version) {} virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, @@ -567,8 +594,8 @@ class TlsClientHelloVersionSetter : public TlsHandshakeFilter { // Damages the last byte of a handshake message. class TlsLastByteDamager : public TlsHandshakeFilter { public: - TlsLastByteDamager(const std::shared_ptr<TlsAgent>& agent, uint8_t type) - : TlsHandshakeFilter(agent), type_(type) {} + TlsLastByteDamager(const std::shared_ptr<TlsAgent>& a, uint8_t type) + : TlsHandshakeFilter(a), type_(type) {} PacketFilter::Action FilterHandshake( const TlsHandshakeFilter::HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { @@ -588,9 +615,9 @@ class TlsLastByteDamager : public TlsHandshakeFilter { class SelectedCipherSuiteReplacer : public TlsHandshakeFilter { public: - SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& agent, + SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& a, uint16_t suite) - : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}), + : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}), cipher_suite_(suite) {} protected: |