summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_filter.h
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.h149
1 files changed, 88 insertions, 61 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h
index 1bbe190ab..effda4aa0 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.h
+++ b/security/nss/gtests/ssl_gtest/tls_filter.h
@@ -11,7 +11,7 @@
#include <memory>
#include <set>
#include <vector>
-
+#include "sslt.h"
#include "test_io.h"
#include "tls_agent.h"
#include "tls_parser.h"
@@ -27,43 +27,57 @@ class TlsCipherSpec;
class TlsVersioned {
public:
- TlsVersioned() : version_(0) {}
- explicit TlsVersioned(uint16_t version) : version_(version) {}
+ TlsVersioned() : variant_(ssl_variant_stream), version_(0) {}
+ TlsVersioned(SSLProtocolVariant var, uint16_t ver)
+ : variant_(var), version_(ver) {}
- bool is_dtls() const { return IsDtls(version_); }
+ bool is_dtls() const { return variant_ == ssl_variant_datagram; }
+ SSLProtocolVariant variant() const { return variant_; }
uint16_t version() const { return version_; }
void WriteStream(std::ostream& stream) const;
protected:
+ SSLProtocolVariant variant_;
uint16_t version_;
};
class TlsRecordHeader : public TlsVersioned {
public:
- TlsRecordHeader() : TlsVersioned(), content_type_(0), sequence_number_(0) {}
- TlsRecordHeader(uint16_t version, uint8_t content_type,
- uint64_t sequence_number)
- : TlsVersioned(version),
- content_type_(content_type),
- sequence_number_(sequence_number) {}
+ TlsRecordHeader()
+ : TlsVersioned(), content_type_(0), sequence_number_(0), header_() {}
+ TlsRecordHeader(SSLProtocolVariant var, uint16_t ver, uint8_t ct,
+ uint64_t seqno)
+ : TlsVersioned(var, ver),
+ content_type_(ct),
+ sequence_number_(seqno),
+ header_() {}
uint8_t content_type() const { return content_type_; }
uint64_t sequence_number() const { return sequence_number_; }
uint16_t epoch() const {
return static_cast<uint16_t>(sequence_number_ >> 48);
}
- size_t header_length() const { return is_dtls() ? 13 : 5; }
+ size_t header_length() const;
+ const DataBuffer& header() const { return header_; }
// Parse the header; return true if successful; body in an outparam if OK.
- bool Parse(uint64_t sequence_number, TlsParser* parser, DataBuffer* body);
+ bool Parse(bool is_dtls13, uint64_t sequence_number, TlsParser* parser,
+ DataBuffer* body);
// Write the header and body to a buffer at the given offset.
// Return the offset of the end of the write.
size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
+ size_t WriteHeader(DataBuffer* buffer, size_t offset, size_t body_len) const;
private:
+ static uint64_t RecoverSequenceNumber(uint64_t expected, uint32_t partial,
+ size_t partial_bits);
+ static uint64_t ParseSequenceNumber(uint64_t expected, uint32_t raw,
+ size_t seq_no_bits, size_t epoch_bits);
+
uint8_t content_type_;
uint64_t sequence_number_;
+ DataBuffer header_;
};
struct TlsRecord {
@@ -83,8 +97,8 @@ inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent,
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
- TlsRecordFilter(const std::shared_ptr<TlsAgent>& agent)
- : agent_(agent),
+ TlsRecordFilter(const std::shared_ptr<TlsAgent>& a)
+ : agent_(a),
count_(0),
cipher_spec_(),
dropped_record_(false),
@@ -106,7 +120,8 @@ class TlsRecordFilter : public PacketFilter {
bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText,
uint8_t* inner_content_type, DataBuffer* plaintext);
bool Protect(const TlsRecordHeader& header, uint8_t inner_content_type,
- const DataBuffer& plaintext, DataBuffer* ciphertext);
+ const DataBuffer& plaintext, DataBuffer* ciphertext,
+ size_t padding = 0);
protected:
// There are two filter functions which can be overriden. Both are
@@ -130,6 +145,8 @@ class TlsRecordFilter : public PacketFilter {
return KEEP;
}
+ bool is_dtls13() const;
+
private:
static void CipherSpecChanged(void* arg, PRBool sending,
ssl3CipherSpec* newSpec);
@@ -183,13 +200,11 @@ inline std::ostream& operator<<(std::ostream& stream,
// records and that they don't span records or anything crazy like that.
class TlsHandshakeFilter : public TlsRecordFilter {
public:
- TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent)
- : TlsRecordFilter(agent), handshake_types_(), preceding_fragment_() {}
- TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& agent,
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a), handshake_types_(), preceding_fragment_() {}
+ TlsHandshakeFilter(const std::shared_ptr<TlsAgent>& a,
const std::set<uint8_t>& types)
- : TlsRecordFilter(agent),
- handshake_types_(types),
- preceding_fragment_() {}
+ : TlsRecordFilter(a), handshake_types_(types), preceding_fragment_() {}
// This filter can be set to be selective based on handshake message type. If
// this function isn't used (or the set is empty), then all handshake messages
@@ -243,12 +258,12 @@ class TlsHandshakeFilter : public TlsRecordFilter {
// Make a copy of the first instance of a handshake message.
class TlsHandshakeRecorder : public TlsHandshakeFilter {
public:
- TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent,
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a,
uint8_t handshake_type)
- : TlsHandshakeFilter(agent, {handshake_type}), buffer_() {}
- TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& agent,
+ : TlsHandshakeFilter(a, {handshake_type}), buffer_() {}
+ TlsHandshakeRecorder(const std::shared_ptr<TlsAgent>& a,
const std::set<uint8_t>& handshake_types)
- : TlsHandshakeFilter(agent, handshake_types), buffer_() {}
+ : TlsHandshakeFilter(a, handshake_types), buffer_() {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -265,10 +280,10 @@ class TlsHandshakeRecorder : public TlsHandshakeFilter {
// Replace all instances of a handshake message.
class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
public:
- TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& agent,
+ TlsInspectorReplaceHandshakeMessage(const std::shared_ptr<TlsAgent>& a,
uint8_t handshake_type,
const DataBuffer& replacement)
- : TlsHandshakeFilter(agent, {handshake_type}), buffer_(replacement) {}
+ : TlsHandshakeFilter(a, {handshake_type}), buffer_(replacement) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -281,10 +296,10 @@ class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
// Make a copy of each record of a given type.
class TlsRecordRecorder : public TlsRecordFilter {
public:
- TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent, uint8_t ct)
- : TlsRecordFilter(agent), filter_(true), ct_(ct), records_() {}
- TlsRecordRecorder(const std::shared_ptr<TlsAgent>& agent)
- : TlsRecordFilter(agent),
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a, uint8_t ct)
+ : TlsRecordFilter(a), filter_(true), ct_(ct), records_() {}
+ TlsRecordRecorder(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a),
filter_(false),
ct_(content_handshake), // dummy (<optional> is C++14)
records_() {}
@@ -306,9 +321,9 @@ class TlsRecordRecorder : public TlsRecordFilter {
// Make a copy of the complete conversation.
class TlsConversationRecorder : public TlsRecordFilter {
public:
- TlsConversationRecorder(const std::shared_ptr<TlsAgent>& agent,
+ TlsConversationRecorder(const std::shared_ptr<TlsAgent>& a,
DataBuffer& buffer)
- : TlsRecordFilter(agent), buffer_(buffer) {}
+ : TlsRecordFilter(a), buffer_(buffer) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
@@ -321,8 +336,7 @@ class TlsConversationRecorder : public TlsRecordFilter {
// Make a copy of the records
class TlsHeaderRecorder : public TlsRecordFilter {
public:
- TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& agent)
- : TlsRecordFilter(agent) {}
+ TlsHeaderRecorder(const std::shared_ptr<TlsAgent>& a) : TlsRecordFilter(a) {}
virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output);
@@ -359,15 +373,15 @@ typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
class TlsExtensionFilter : public TlsHandshakeFilter {
public:
- TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent)
- : TlsHandshakeFilter(agent,
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a,
{kTlsHandshakeClientHello, kTlsHandshakeServerHello,
kTlsHandshakeHelloRetryRequest,
kTlsHandshakeEncryptedExtensions}) {}
- TlsExtensionFilter(const std::shared_ptr<TlsAgent>& agent,
+ TlsExtensionFilter(const std::shared_ptr<TlsAgent>& a,
const std::set<uint8_t>& types)
- : TlsHandshakeFilter(agent, types) {}
+ : TlsHandshakeFilter(a, types) {}
static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
@@ -388,9 +402,9 @@ class TlsExtensionFilter : public TlsHandshakeFilter {
class TlsExtensionCapture : public TlsExtensionFilter {
public:
- TlsExtensionCapture(const std::shared_ptr<TlsAgent>& agent, uint16_t ext,
+ TlsExtensionCapture(const std::shared_ptr<TlsAgent>& a, uint16_t ext,
bool last = false)
- : TlsExtensionFilter(agent),
+ : TlsExtensionFilter(a),
extension_(ext),
captured_(false),
last_(last),
@@ -413,9 +427,9 @@ class TlsExtensionCapture : public TlsExtensionFilter {
class TlsExtensionReplacer : public TlsExtensionFilter {
public:
- TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& agent,
- uint16_t extension, const DataBuffer& data)
- : TlsExtensionFilter(agent), extension_(extension), data_(data) {}
+ TlsExtensionReplacer(const std::shared_ptr<TlsAgent>& a, uint16_t extension,
+ const DataBuffer& data)
+ : TlsExtensionFilter(a), extension_(extension), data_(data) {}
PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer& input,
DataBuffer* output) override;
@@ -427,9 +441,8 @@ class TlsExtensionReplacer : public TlsExtensionFilter {
class TlsExtensionDropper : public TlsExtensionFilter {
public:
- TlsExtensionDropper(const std::shared_ptr<TlsAgent>& agent,
- uint16_t extension)
- : TlsExtensionFilter(agent), extension_(extension) {}
+ TlsExtensionDropper(const std::shared_ptr<TlsAgent>& a, uint16_t extension)
+ : TlsExtensionFilter(a), extension_(extension) {}
PacketFilter::Action FilterExtension(uint16_t extension_type,
const DataBuffer&, DataBuffer*) override;
@@ -439,9 +452,9 @@ class TlsExtensionDropper : public TlsExtensionFilter {
class TlsExtensionInjector : public TlsHandshakeFilter {
public:
- TlsExtensionInjector(const std::shared_ptr<TlsAgent>& agent, uint16_t ext,
+ TlsExtensionInjector(const std::shared_ptr<TlsAgent>& a, uint16_t ext,
const DataBuffer& data)
- : TlsHandshakeFilter(agent), extension_(ext), data_(data) {}
+ : TlsHandshakeFilter(a), extension_(ext), data_(data) {}
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
@@ -453,7 +466,6 @@ class TlsExtensionInjector : public TlsHandshakeFilter {
const DataBuffer data_;
};
-class TlsAgent;
typedef std::function<void(void)> VoidFunction;
class AfterRecordN : public TlsRecordFilter {
@@ -495,6 +507,22 @@ class TlsClientHelloVersionChanger : public TlsHandshakeFilter {
std::weak_ptr<TlsAgent> server_;
};
+// Damage a record.
+class TlsRecordLastByteDamager : public TlsRecordFilter {
+ public:
+ TlsRecordLastByteDamager(const std::shared_ptr<TlsAgent>& a)
+ : TlsRecordFilter(a) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) override {
+ *changed = data;
+ changed->data()[changed->len() - 1]++;
+ return CHANGE;
+ }
+};
+
// This class selectively drops complete writes. This relies on the fact that
// writes in libssl are on record boundaries.
class SelectiveDropFilter : public PacketFilter {
@@ -515,16 +543,16 @@ class SelectiveDropFilter : public PacketFilter {
// datagram, we just drop one.
class SelectiveRecordDropFilter : public TlsRecordFilter {
public:
- SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent,
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a,
uint32_t pattern, bool enabled = true)
- : TlsRecordFilter(agent), pattern_(pattern), counter_(0) {
+ : TlsRecordFilter(a), pattern_(pattern), counter_(0) {
if (!enabled) {
Disable();
}
}
- SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& agent,
+ SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a,
std::initializer_list<size_t> records)
- : SelectiveRecordDropFilter(agent, ToPattern(records), true) {}
+ : SelectiveRecordDropFilter(a, ToPattern(records), true) {}
void Reset(uint32_t pattern) {
counter_ = 0;
@@ -551,10 +579,9 @@ class SelectiveRecordDropFilter : public TlsRecordFilter {
// Set the version number in the ClientHello.
class TlsClientHelloVersionSetter : public TlsHandshakeFilter {
public:
- TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& agent,
+ TlsClientHelloVersionSetter(const std::shared_ptr<TlsAgent>& a,
uint16_t version)
- : TlsHandshakeFilter(agent, {kTlsHandshakeClientHello}),
- version_(version) {}
+ : TlsHandshakeFilter(a, {kTlsHandshakeClientHello}), version_(version) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -567,8 +594,8 @@ class TlsClientHelloVersionSetter : public TlsHandshakeFilter {
// Damages the last byte of a handshake message.
class TlsLastByteDamager : public TlsHandshakeFilter {
public:
- TlsLastByteDamager(const std::shared_ptr<TlsAgent>& agent, uint8_t type)
- : TlsHandshakeFilter(agent), type_(type) {}
+ TlsLastByteDamager(const std::shared_ptr<TlsAgent>& a, uint8_t type)
+ : TlsHandshakeFilter(a), type_(type) {}
PacketFilter::Action FilterHandshake(
const TlsHandshakeFilter::HandshakeHeader& header,
const DataBuffer& input, DataBuffer* output) override {
@@ -588,9 +615,9 @@ class TlsLastByteDamager : public TlsHandshakeFilter {
class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
public:
- SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& agent,
+ SelectedCipherSuiteReplacer(const std::shared_ptr<TlsAgent>& a,
uint16_t suite)
- : TlsHandshakeFilter(agent, {kTlsHandshakeServerHello}),
+ : TlsHandshakeFilter(a, {kTlsHandshakeServerHello}),
cipher_suite_(suite) {}
protected: