diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc index f4940bf28..92947c2c7 100644 --- a/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc @@ -20,14 +20,16 @@ namespace nss_test { // This class cuts every unencrypted handshake record into two parts. class RecordFragmenter : public PacketFilter { public: - RecordFragmenter() : sequence_number_(0), splitting_(true) {} + RecordFragmenter(bool is_dtls13) + : is_dtls13_(is_dtls13), sequence_number_(0), splitting_(true) {} private: class HandshakeSplitter { public: - HandshakeSplitter(const DataBuffer& input, DataBuffer* output, - uint64_t* sequence_number) - : input_(input), + HandshakeSplitter(bool is_dtls13, const DataBuffer& input, + DataBuffer* output, uint64_t* sequence_number) + : is_dtls13_(is_dtls13), + input_(input), output_(output), cursor_(0), sequence_number_(sequence_number) {} @@ -35,9 +37,9 @@ class RecordFragmenter : public PacketFilter { private: void WriteRecord(TlsRecordHeader& record_header, DataBuffer& record_fragment) { - TlsRecordHeader fragment_header(record_header.version(), - record_header.content_type(), - *sequence_number_); + TlsRecordHeader fragment_header( + record_header.variant(), record_header.version(), + record_header.content_type(), *sequence_number_); ++*sequence_number_; if (::g_ssl_gtest_verbose) { std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment @@ -88,7 +90,7 @@ class RecordFragmenter : public PacketFilter { while (parser.remaining()) { TlsRecordHeader header; DataBuffer record; - if (!header.Parse(0, &parser, &record)) { + if (!header.Parse(is_dtls13_, 0, &parser, &record)) { ADD_FAILURE() << "bad record header"; return false; } @@ -118,6 +120,7 @@ class RecordFragmenter : public PacketFilter { } private: + bool is_dtls13_; const DataBuffer& input_; DataBuffer* output_; size_t cursor_; @@ -132,7 +135,7 @@ class RecordFragmenter : public PacketFilter { } output->Allocate(input.len()); - HandshakeSplitter splitter(input, output, &sequence_number_); + HandshakeSplitter splitter(is_dtls13_, input, output, &sequence_number_); if (!splitter.Split()) { // If splitting fails, we obviously reached encrypted packets. // Stop splitting from that point onward. @@ -144,18 +147,21 @@ class RecordFragmenter : public PacketFilter { } private: + bool is_dtls13_; uint64_t sequence_number_; bool splitting_; }; TEST_P(TlsConnectDatagram, FragmentClientPackets) { - client_->SetFilter(std::make_shared<RecordFragmenter>()); + bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3; + client_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13)); Connect(); SendReceive(); } TEST_P(TlsConnectDatagram, FragmentServerPackets) { - server_->SetFilter(std::make_shared<RecordFragmenter>()); + bool is_dtls13 = version_ >= SSL_LIBRARY_VERSION_TLS_1_3; + server_->SetFilter(std::make_shared<RecordFragmenter>(is_dtls13)); Connect(); SendReceive(); } |