summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_filter.h
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.h192
1 files changed, 23 insertions, 169 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h
index 1db3b90f6..e4030e23f 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.h
+++ b/security/nss/gtests/ssl_gtest/tls_filter.h
@@ -50,13 +50,10 @@ class TlsRecordHeader : public TlsVersioned {
uint8_t content_type() const { return content_type_; }
uint64_t sequence_number() const { return sequence_number_; }
- uint16_t epoch() const {
- return static_cast<uint16_t>(sequence_number_ >> 48);
- }
- size_t header_length() const { return is_dtls() ? 13 : 5; }
+ size_t header_length() const { return is_dtls() ? 11 : 3; }
// Parse the header; return true if successful; body in an outparam if OK.
- bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body);
+ bool Parse(TlsParser* parser, DataBuffer* body);
// Write the header and body to a buffer at the given offset.
// Return the offset of the end of the write.
size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
@@ -66,21 +63,10 @@ class TlsRecordHeader : public TlsVersioned {
uint64_t sequence_number_;
};
-struct TlsRecord {
- const TlsRecordHeader header;
- const DataBuffer buffer;
-};
-
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
- TlsRecordFilter()
- : agent_(nullptr),
- count_(0),
- cipher_spec_(),
- dropped_record_(false),
- in_sequence_number_(0),
- out_sequence_number_(0) {}
+ TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {}
void SetAgent(const TlsAgent* agent) { agent_ = agent; }
const TlsAgent* agent() const { return agent_; }
@@ -129,21 +115,14 @@ class TlsRecordFilter : public PacketFilter {
const TlsAgent* agent_;
size_t count_;
std::unique_ptr<TlsCipherSpec> cipher_spec_;
- // Whether we dropped a record since the cipher spec changed.
- bool dropped_record_;
- // The sequence number we use for reading records as they are written.
- uint64_t in_sequence_number_;
- // The sequence number we use for writing modified records.
- uint64_t out_sequence_number_;
};
-inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) {
+inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) {
v.WriteStream(stream);
return stream;
}
-inline std::ostream& operator<<(std::ostream& stream,
- const TlsRecordHeader& hdr) {
+inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
hdr.WriteStream(stream);
stream << ' ';
switch (hdr.content_type()) {
@@ -154,17 +133,13 @@ inline std::ostream& operator<<(std::ostream& stream,
stream << "Alert";
break;
case kTlsHandshakeType:
- case kTlsAltHandshakeType:
stream << "Handshake";
break;
case kTlsApplicationDataType:
stream << "Data";
break;
- case kTlsAckType:
- stream << "ACK";
- break;
default:
- stream << '<' << static_cast<int>(hdr.content_type()) << '>';
+ stream << '<' << hdr.content_type() << '>';
break;
}
return stream << ' ' << std::hex << hdr.sequence_number() << std::dec;
@@ -175,16 +150,7 @@ inline std::ostream& operator<<(std::ostream& stream,
// records and that they don't span records or anything crazy like that.
class TlsHandshakeFilter : public TlsRecordFilter {
public:
- TlsHandshakeFilter() : handshake_types_(), preceding_fragment_() {}
- TlsHandshakeFilter(const std::set<uint8_t>& types)
- : handshake_types_(types), preceding_fragment_() {}
-
- // This filter can be set to be selective based on handshake message type. If
- // this function isn't used (or the set is empty), then all handshake messages
- // will be filtered.
- void SetHandshakeTypes(const std::set<uint8_t>& types) {
- handshake_types_ = types;
- }
+ TlsHandshakeFilter() {}
class HandshakeHeader : public TlsVersioned {
public:
@@ -192,8 +158,7 @@ class TlsHandshakeFilter : public TlsRecordFilter {
uint8_t handshake_type() const { return handshake_type_; }
bool Parse(TlsParser* parser, const TlsRecordHeader& record_header,
- const DataBuffer& preceding_fragment, DataBuffer* body,
- bool* complete);
+ DataBuffer* body);
size_t Write(DataBuffer* buffer, size_t offset,
const DataBuffer& body) const;
size_t WriteFragment(DataBuffer* buffer, size_t offset,
@@ -204,8 +169,7 @@ class TlsHandshakeFilter : public TlsRecordFilter {
// Reads the length from the record header.
// This also reads the DTLS fragment information and checks it.
bool ReadLength(TlsParser* parser, const TlsRecordHeader& header,
- uint32_t expected_offset, uint32_t* length,
- bool* last_fragment);
+ uint32_t* length);
uint8_t handshake_type_;
uint16_t message_seq_;
@@ -221,30 +185,22 @@ class TlsHandshakeFilter : public TlsRecordFilter {
DataBuffer* output) = 0;
private:
- bool IsFilteredType(const HandshakeHeader& header,
- const DataBuffer& handshake);
-
- std::set<uint8_t> handshake_types_;
- DataBuffer preceding_fragment_;
};
// Make a copy of the first instance of a handshake message.
class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
public:
TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
- : TlsHandshakeFilter({handshake_type}), buffer_() {}
- TlsInspectorRecordHandshakeMessage(const std::set<uint8_t>& handshake_types)
- : TlsHandshakeFilter(handshake_types), buffer_() {}
+ : handshake_type_(handshake_type), buffer_() {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output);
- void Reset() { buffer_.Truncate(0); }
-
const DataBuffer& buffer() const { return buffer_; }
private:
+ uint8_t handshake_type_;
DataBuffer buffer_;
};
@@ -253,39 +209,17 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
public:
TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type,
const DataBuffer& replacement)
- : TlsHandshakeFilter({handshake_type}), buffer_(replacement) {}
+ : handshake_type_(handshake_type), buffer_(replacement) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output);
private:
+ uint8_t handshake_type_;
DataBuffer buffer_;
};
-// Make a copy of each record of a given type.
-class TlsRecordRecorder : public TlsRecordFilter {
- public:
- TlsRecordRecorder(uint8_t ct) : filter_(true), ct_(ct), records_() {}
- TlsRecordRecorder()
- : filter_(false),
- ct_(content_handshake), // dummy (<optional> is C++14)
- records_() {}
- virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
- const DataBuffer& input,
- DataBuffer* output);
-
- size_t count() const { return records_.size(); }
- void Clear() { records_.clear(); }
-
- const TlsRecord& record(size_t i) const { return records_[i]; }
-
- private:
- bool filter_;
- uint8_t ct_;
- std::vector<TlsRecord> records_;
-};
-
// Make a copy of the complete conversation.
class TlsConversationRecorder : public TlsRecordFilter {
public:
@@ -296,31 +230,15 @@ class TlsConversationRecorder : public TlsRecordFilter {
DataBuffer* output);
private:
- DataBuffer buffer_;
+ DataBuffer& buffer_;
};
-// Make a copy of the records
-class TlsHeaderRecorder : public TlsRecordFilter {
- public:
- virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
- const DataBuffer& input,
- DataBuffer* output);
- const TlsRecordHeader* header(size_t index);
-
- private:
- std::vector<TlsRecordHeader> headers_;
-};
-
-typedef std::initializer_list<std::shared_ptr<PacketFilter>>
- ChainedPacketFilterInit;
-
// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
public:
ChainedPacketFilter() {}
ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters)
: filters_(filters.begin(), filters.end()) {}
- ChainedPacketFilter(ChainedPacketFilterInit il) : filters_(il) {}
virtual ~ChainedPacketFilter() {}
virtual PacketFilter::Action Filter(const DataBuffer& input,
@@ -338,13 +256,13 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
class TlsExtensionFilter : public TlsHandshakeFilter {
public:
- TlsExtensionFilter()
- : TlsHandshakeFilter({kTlsHandshakeClientHello, kTlsHandshakeServerHello,
- kTlsHandshakeHelloRetryRequest,
- kTlsHandshakeEncryptedExtensions}) {}
+ TlsExtensionFilter() : handshake_types_() {
+ handshake_types_.insert(kTlsHandshakeClientHello);
+ handshake_types_.insert(kTlsHandshakeServerHello);
+ }
TlsExtensionFilter(const std::set<uint8_t>& types)
- : TlsHandshakeFilter(types) {}
+ : handshake_types_(types) {}
static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
@@ -361,6 +279,8 @@ class TlsExtensionFilter : public TlsHandshakeFilter {
PacketFilter::Action FilterExtensions(TlsParser* parser,
const DataBuffer& input,
DataBuffer* output);
+
+ std::set<uint8_t> handshake_types_;
};
class TlsExtensionCapture : public TlsExtensionFilter {
@@ -406,21 +326,6 @@ class TlsExtensionDropper : public TlsExtensionFilter {
uint16_t extension_;
};
-class TlsExtensionInjector : public TlsHandshakeFilter {
- public:
- TlsExtensionInjector(uint16_t ext, const DataBuffer& data)
- : extension_(ext), data_(data) {}
-
- protected:
- PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
- const DataBuffer& input,
- DataBuffer* output) override;
-
- private:
- const uint16_t extension_;
- const DataBuffer data_;
-};
-
class TlsAgent;
typedef std::function<void(void)> VoidFunction;
@@ -447,7 +352,7 @@ class AfterRecordN : public TlsRecordFilter {
class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
public:
TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server)
- : TlsHandshakeFilter({kTlsHandshakeClientKeyExchange}), server_(server) {}
+ : server_(server) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -472,47 +377,10 @@ class SelectiveDropFilter : public PacketFilter {
uint8_t counter_;
};
-// This class selectively drops complete records. The difference from
-// SelectiveDropFilter is that if multiple DTLS records are in the same
-// datagram, we just drop one.
-class SelectiveRecordDropFilter : public TlsRecordFilter {
- public:
- SelectiveRecordDropFilter(uint32_t pattern, bool enabled = true)
- : pattern_(pattern), counter_(0) {
- if (!enabled) {
- Disable();
- }
- }
- SelectiveRecordDropFilter(std::initializer_list<size_t> records)
- : SelectiveRecordDropFilter(ToPattern(records), true) {}
-
- void Reset(uint32_t pattern) {
- counter_ = 0;
- PacketFilter::Enable();
- pattern_ = pattern;
- }
-
- void Reset(std::initializer_list<size_t> records) {
- Reset(ToPattern(records));
- }
-
- protected:
- PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
- const DataBuffer& data,
- DataBuffer* changed) override;
-
- private:
- static uint32_t ToPattern(std::initializer_list<size_t> records);
-
- uint32_t pattern_;
- uint8_t counter_;
-};
-
// Set the version number in the ClientHello.
class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionSetter(uint16_t version)
- : TlsHandshakeFilter({kTlsHandshakeClientHello}), version_(version) {}
+ TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -543,20 +411,6 @@ class TlsLastByteDamager : public TlsHandshakeFilter {
uint8_t type_;
};
-class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
- public:
- SelectedCipherSuiteReplacer(uint16_t suite)
- : TlsHandshakeFilter({kTlsHandshakeServerHello}), cipher_suite_(suite) {}
-
- protected:
- PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
- const DataBuffer& input,
- DataBuffer* output) override;
-
- private:
- uint16_t cipher_suite_;
-};
-
} // namespace nss_test
#endif