summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_filter.h
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_filter.h')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_filter.h112
1 files changed, 89 insertions, 23 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_filter.h b/security/nss/gtests/ssl_gtest/tls_filter.h
index 2b6e88645..64ee71c89 100644
--- a/security/nss/gtests/ssl_gtest/tls_filter.h
+++ b/security/nss/gtests/ssl_gtest/tls_filter.h
@@ -97,13 +97,7 @@ inline std::shared_ptr<T> MakeTlsFilter(const std::shared_ptr<TlsAgent>& agent,
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
- TlsRecordFilter(const std::shared_ptr<TlsAgent>& a)
- : agent_(a),
- count_(0),
- cipher_spec_(),
- dropped_record_(false),
- in_sequence_number_(0),
- out_sequence_number_(0) {}
+ TlsRecordFilter(const std::shared_ptr<TlsAgent>& a);
std::shared_ptr<TlsAgent> agent() const { return agent_.lock(); }
@@ -118,10 +112,11 @@ class TlsRecordFilter : public PacketFilter {
// behavior.
void EnableDecryption();
bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText,
- uint8_t* inner_content_type, DataBuffer* plaintext);
- bool Protect(const TlsRecordHeader& header, uint8_t inner_content_type,
- const DataBuffer& plaintext, DataBuffer* ciphertext,
- size_t padding = 0);
+ uint16_t* protection_epoch, uint8_t* inner_content_type,
+ DataBuffer* plaintext);
+ bool Protect(TlsCipherSpec& protection_spec, const TlsRecordHeader& header,
+ uint8_t inner_content_type, const DataBuffer& plaintext,
+ DataBuffer* ciphertext, size_t padding = 0);
protected:
// There are two filter functions which can be overriden. Both are
@@ -146,20 +141,17 @@ class TlsRecordFilter : public PacketFilter {
}
bool is_dtls13() const;
+ TlsCipherSpec& spec(uint16_t epoch);
private:
- static void CipherSpecChanged(void* arg, PRBool sending,
- ssl3CipherSpec* newSpec);
+ static void SecretCallback(PRFileDesc* fd, PRUint16 epoch,
+ SSLSecretDirection dir, PK11SymKey* secret,
+ void* arg);
std::weak_ptr<TlsAgent> agent_;
- size_t count_;
- std::unique_ptr<TlsCipherSpec> cipher_spec_;
- // Whether we dropped a record since the cipher spec changed.
- bool dropped_record_;
- // The sequence number we use for reading records as they are written.
- uint64_t in_sequence_number_;
- // The sequence number we use for writing modified records.
- uint64_t out_sequence_number_;
+ size_t count_ = 0;
+ std::vector<TlsCipherSpec> cipher_specs_;
+ bool decrypting_ = false;
};
inline std::ostream& operator<<(std::ostream& stream, const TlsVersioned& v) {
@@ -449,6 +441,80 @@ class TlsExtensionDropper : public TlsExtensionFilter {
uint16_t extension_;
};
+class TlsHandshakeDropper : public TlsHandshakeFilter {
+ public:
+ TlsHandshakeDropper(const std::shared_ptr<TlsAgent>& a)
+ : TlsHandshakeFilter(a) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ return DROP;
+ }
+};
+
+class TlsEncryptedHandshakeMessageReplacer : public TlsRecordFilter {
+ public:
+ TlsEncryptedHandshakeMessageReplacer(const std::shared_ptr<TlsAgent>& a,
+ uint8_t old_ct, uint8_t new_ct)
+ : TlsRecordFilter(a), old_ct_(old_ct), new_ct_(new_ct) {}
+
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record, size_t* offset,
+ DataBuffer* output) override {
+ if (header.content_type() != ssl_ct_application_data) {
+ return KEEP;
+ }
+
+ uint16_t protection_epoch = 0;
+ uint8_t inner_content_type;
+ DataBuffer plaintext;
+ if (!Unprotect(header, record, &protection_epoch, &inner_content_type,
+ &plaintext) ||
+ !plaintext.len()) {
+ return KEEP;
+ }
+
+ if (inner_content_type != ssl_ct_handshake) {
+ return KEEP;
+ }
+
+ size_t off = 0;
+ uint32_t msg_len = 0;
+ uint32_t msg_type = 255; // Not a real message
+ do {
+ if (!plaintext.Read(off, 1, &msg_type) || msg_type == old_ct_) {
+ break;
+ }
+
+ // Increment and check next messages
+ if (!plaintext.Read(++off, 3, &msg_len)) {
+ break;
+ }
+ off += 3 + msg_len;
+ } while (msg_type != old_ct_);
+
+ if (msg_type == old_ct_) {
+ plaintext.Write(off, new_ct_, 1);
+ }
+
+ DataBuffer ciphertext;
+ bool ok = Protect(spec(protection_epoch), header, inner_content_type,
+ plaintext, &ciphertext, 0);
+ if (!ok) {
+ return KEEP;
+ }
+ *offset = header.Write(output, *offset, ciphertext);
+ return CHANGE;
+ }
+
+ private:
+ uint8_t old_ct_;
+ uint8_t new_ct_;
+};
+
class TlsExtensionInjector : public TlsHandshakeFilter {
public:
TlsExtensionInjector(const std::shared_ptr<TlsAgent>& a, uint16_t ext,
@@ -557,9 +623,9 @@ class SelectiveDropFilter : public PacketFilter {
class SelectiveRecordDropFilter : public TlsRecordFilter {
public:
SelectiveRecordDropFilter(const std::shared_ptr<TlsAgent>& a,
- uint32_t pattern, bool enabled = true)
+ uint32_t pattern, bool on = true)
: TlsRecordFilter(a), pattern_(pattern), counter_(0) {
- if (!enabled) {
+ if (!on) {
Disable();
}
}