summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_filter.h
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.h141
1 files changed, 93 insertions, 48 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h
index 1db3b90f6..1bbe190ab 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.h
+++ b/security/nss/gtests/ssl_gtest/tls_filter.h
@@ -13,6 +13,7 @@
#include <vector>
#include "test_io.h"
+#include "tls_agent.h"
#include "tls_parser.h"
#include "tls_protect.h"
@@ -23,7 +24,6 @@ extern "C" {
namespace nss_test {
class TlsCipherSpec;
-class TlsAgent;
class TlsVersioned {
public:
@@ -71,19 +71,27 @@ struct TlsRecord {
const DataBuffer buffer;
};
+// Make a filter and install it on a TlsAgent.
+template <class T, typename... Args>
+inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent,
+ Args&&... args) {
+ auto filter = std::make_shared<T>(agent, std::forward<Args>(args)...);
+ agent->SetFilter(filter);
+ return filter;
+}
+
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
- TlsRecordFilter()
- : agent_(nullptr),
+ TlsRecordFilter(const std::shared_ptr<TlsAgent>& agent)
+ : agent_(agent),
count_(0),
cipher_spec_(),
dropped_record_(false),
in_sequence_number_(0),
out_sequence_number_(0) {}
- void SetAgent(const TlsAgent* agent) { agent_ = agent; }
- const TlsAgent* agent() const { return agent_; }
+ std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); }
// External interface. Overrides PacketFilter.
PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output);
@@ -126,7 +134,7 @@ class TlsRecordFilter : public PacketFilter {
static void CipherSpecChanged(void* arg, PRBool sending,
ssl3CipherSpec* newSpec);
- const TlsAgent* agent_;
+ std::weak_ptr<TlsAgent> agent_;
size_t count_;
std::unique_ptr<TlsCipherSpec> cipher_spec_;
// Whether we dropped a record since the cipher spec changed.
@@ -175,9 +183,13 @@ inline std::ostream& operator<<(std::ostream& stream,
// records and that they don't span records or anything crazy like that.
class TlsHandshakeFilter : public TlsRecordFilter {
public:
- TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {}
- TlsHandshakeFilter(const std::set<uint8_t>& types)
- : handshake_types_(types), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent), handshake_types_(), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& types)
+ : TlsRecordFilter(agent),
+ handshake_types_(types),
+ preceding_fragment_() {}
// This filter can be set to be selective based on handshake message type. If
// this function isn't used (or the set is empty), then all handshake messages
@@ -229,12 +241,14 @@ class TlsHandshakeFilter : public TlsRecordFilter {
};
// Make a copy of the first instance of a handshake message.
-class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
+class TlsHandshakeRecorder : public TlsHandshakeFilter {
public:
- TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
- : TlsHandshakeFilter({handshake_type}), buffer_() {}
- TlsInspectorRecordHandshakeMessage(const std::set<uint8_t>& handshake_types)
- : TlsHandshakeFilter(handshake_types), buffer_() {}
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type)
+ : TlsHandshakeFilter(agent, {handshake_type}), buffer_() {}
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& handshake_types)
+ : TlsHandshakeFilter(agent, handshake_types), buffer_() {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -251,9 +265,10 @@ class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
// Replace all instances of a handshake message.
class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
public:
- TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type,
+ TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& agent,
+ uint8_t handshake_type,
const DataBuffer& replacement)
- : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {}
+ : TlsHandshakeFilter(agent, {handshake_type}), buffer_(replacement) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -266,9 +281,11 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
// Make a copy of each record of a given type.
class TlsRecordRecorder : public TlsRecordFilter {
public:
- TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {}
- TlsRecordRecorder()
- : filter_(false),
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent, uint8_t ct)
+ : TlsRecordFilter(agent), filter_(true), ct_(ct), records_() {}
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent),
+ filter_(false),
ct_(content_handshake), // dummy (<optional> is C++14)
records_() {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
@@ -289,7 +306,9 @@ class TlsRecordRecorder : public TlsRecordFilter {
// Make a copy of the complete conversation.
class TlsConversationRecorder : public TlsRecordFilter {
public:
- TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {}
+ TlsConversationRecorder(const std::shared_ptr<TlsAgent>& agent,
+ DataBuffer& buffer)
+ : TlsRecordFilter(agent), buffer_(buffer) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
@@ -302,6 +321,8 @@ class TlsConversationRecorder : public TlsRecordFilter {
// Make a copy of the records
class TlsHeaderRecorder : public TlsRecordFilter {
public:
+ TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& agent)
+ : TlsRecordFilter(agent) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output);
@@ -338,13 +359,15 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
class TlsExtensionFilter : public TlsHandshakeFilter {
public:
- TlsExtensionFilter()
- : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello,
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent)
+ : TlsHandshakeFilter(agent,
+ {kTlsHandshakeClientHello, kTlsHandshakeServerHello,
kTlsHandshakeHelloRetryRequest,
kTlsHandshakeEncryptedExtensions}) {}
- TlsExtensionFilter(const std::set<uint8_t>& types)
- : TlsHandshakeFilter(types) {}
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent,
+ const std::set<uint8_t>& types)
+ : TlsHandshakeFilter(agent, types) {}
static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
@@ -365,8 +388,13 @@ class TlsExtensionFilter : public TlsHandshakeFilter {
class TlsExtensionCapture : public TlsExtensionFilter {
public:
- TlsExtensionCapture(uint16_t ext, bool last = false)
- : extension_(ext), captured_(false), last_(last), data_() {}
+ TlsExtensionCapture(const std::shared_ptr<TlsAgent>& agent, uint16_t ext,
+ bool last = false)
+ : TlsExtensionFilter(agent),
+ extension_(ext),
+ captured_(false),
+ last_(last),
+ data_() {}
const DataBuffer& extension() const { return data_; }
bool captured() const { return captured_; }
@@ -385,8 +413,9 @@ class TlsExtensionCapture : public TlsExtensionFilter {
class TlsExtensionReplacer : public TlsExtensionFilter {
public:
- TlsExtensionReplacer(uint16_t extension, const DataBuffer& data)
- : extension_(extension), data_(data) {}
+ TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension, const DataBuffer& data)
+ : TlsExtensionFilter(agent), extension_(extension), data_(data) {}
PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) override;
@@ -398,7 +427,9 @@ class TlsExtensionReplacer : public TlsExtensionFilter {
class TlsExtensionDropper : public TlsExtensionFilter {
public:
- TlsExtensionDropper(uint16_t extension) : extension_(extension) {}
+ TlsExtensionDropper(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t extension)
+ : TlsExtensionFilter(agent), extension_(extension) {}
PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer&, DataBuffer*) override;
@@ -408,8 +439,9 @@ class TlsExtensionDropper : public TlsExtensionFilter {
class TlsExtensionInjector : public TlsHandshakeFilter {
public:
- TlsExtensionInjector(uint16_t ext, const DataBuffer& data)
- : extension_(ext), data_(data) {}
+ TlsExtensionInjector(const std::shared_ptr<TlsAgent>& agent, uint16_t ext,
+ const DataBuffer& data)
+ : TlsHandshakeFilter(agent), extension_(ext), data_(data) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
@@ -426,16 +458,20 @@ typedef std::function<void(void)> VoidFunction;
class AfterRecordN : public TlsRecordFilter {
public:
- AfterRecordN(std::shared_ptr<TlsAgent>& src, std::shared_ptr<TlsAgent>& dest,
- unsigned int record, VoidFunction func)
- : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {}
+ AfterRecordN(const std::shared_ptr<TlsAgent>& src,
+ const std::shared_ptr<TlsAgent>& dest, unsigned int record,
+ VoidFunction func)
+ : TlsRecordFilter(src),
+ dest_(dest),
+ record_(record),
+ func_(func),
+ counter_(0) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) override;
private:
- std::weak_ptr<TlsAgent> src_;
std::weak_ptr<TlsAgent> dest_;
unsigned int record_;
VoidFunction func_;
@@ -444,10 +480,12 @@ class AfterRecordN : public TlsRecordFilter {
// When we see the ClientKeyExchange from |client|, increment the
// ClientHelloVersion on |server|.
-class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
+class TlsClientHelloVersionChanger : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server)
- : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {}
+ TlsClientHelloVersionChanger(const std::shared_ptr<TlsAgent>& client,
+ const std::shared_ptr<TlsAgent>& server)
+ : TlsHandshakeFilter(client, {kTlsHandshakeClientKeyExchange}),
+ server_(server) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -477,14 +515,16 @@ class SelectiveDropFilter : public PacketFilter {
// datagram, we just drop one.
class SelectiveRecordDropFilter : public TlsRecordFilter {
public:
- SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true)
- : pattern_(pattern), counter_(0) {
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent,
+ uint32_t pattern, bool enabled = true)
+ : TlsRecordFilter(agent), pattern_(pattern), counter_(0) {
if (!enabled) {
Disable();
}
}
- SelectiveRecordDropFilter(std::initializer_list<size_t> records)
- : SelectiveRecordDropFilter(ToPattern(records), true) {}
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent,
+ std::initializer_list<size_t> records)
+ : SelectiveRecordDropFilter(agent, ToPattern(records), true) {}
void Reset(uint32_t pattern) {
counter_ = 0;
@@ -509,10 +549,12 @@ class SelectiveRecordDropFilter : public TlsRecordFilter {
};
// Set the version number in the ClientHello.
-class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
+class TlsClientHelloVersionSetter : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionSetter(uint16_t version)
- : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {}
+ TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t version)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeClientHello}),
+ version_(version) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -525,7 +567,8 @@ class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
// Damages the last byte of a handshake message.
class TlsLastByteDamager : public TlsHandshakeFilter {
public:
- TlsLastByteDamager(uint8_t type) : type_(type) {}
+ TlsLastByteDamager(const std::shared_ptr<TlsAgent>& agent, uint8_t type)
+ : TlsHandshakeFilter(agent), type_(type) {}
PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) override {
@@ -545,8 +588,10 @@ class TlsLastByteDamager : public TlsHandshakeFilter {
class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
public:
- SelectedCipherSuiteReplacer(uint16_t suite)
- : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {}
+ SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& agent,
+ uint16_t suite)
+ : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}),
+ cipher_suite_(suite) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,