diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc | 304 |
1 files changed, 248 insertions, 56 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc index 77703dd8e..f1b78f52f 100644 --- a/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc +++ b/security/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc @@ -6,6 +6,7 @@ #include <functional> #include <memory> +#include <vector> #include "secerr.h" #include "ssl.h" #include "sslerr.h" @@ -55,7 +56,8 @@ TEST_P(TlsConnectGeneric, CipherSuiteMismatch) { class TlsAlertRecorder : public TlsRecordFilter { public: - TlsAlertRecorder() : level_(255), description_(255) {} + TlsAlertRecorder(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent), level_(255), description_(255) {} PacketFilter::Action FilterRecord(const TlsRecordHeader& header, const DataBuffer& input, @@ -84,13 +86,13 @@ class TlsAlertRecorder : public TlsRecordFilter { }; class HelloTruncator : public TlsHandshakeFilter { + public: + HelloTruncator(const std::shared_ptr<TlsAgent>& agent) + : TlsHandshakeFilter( + agent, {kTlsHandshakeClientHello, kTlsHandshakeServerHello}) {} PacketFilter::Action FilterHandshake(const HandshakeHeader& header, const DataBuffer& input, DataBuffer* output) override { - if (header.handshake_type() != kTlsHandshakeClientHello && - header.handshake_type() != kTlsHandshakeServerHello) { - return KEEP; - } output->Assign(input.data(), input.len() - 1); return CHANGE; } @@ -98,19 +100,17 @@ class HelloTruncator : public TlsHandshakeFilter { // Verify that when NSS reports that an alert is sent, it is actually sent. TEST_P(TlsConnectGeneric, CaptureAlertServer) { - client_->SetPacketFilter(std::make_shared<HelloTruncator>()); - auto alert_recorder = std::make_shared<TlsAlertRecorder>(); - server_->SetPacketFilter(alert_recorder); + MakeTlsFilter<HelloTruncator>(client_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(server_); - ConnectExpectAlert(server_, kTlsAlertIllegalParameter); + ConnectExpectAlert(server_, kTlsAlertDecodeError); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description()); + EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description()); } TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { - server_->SetPacketFilter(std::make_shared<HelloTruncator>()); - auto alert_recorder = std::make_shared<TlsAlertRecorder>(); - client_->SetPacketFilter(alert_recorder); + MakeTlsFilter<HelloTruncator>(server_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_); ConnectExpectAlert(client_, kTlsAlertDecodeError); EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); @@ -119,12 +119,10 @@ TEST_P(TlsConnectGenericPre13, CaptureAlertClient) { // In TLS 1.3, the server can't read the client alert. TEST_P(TlsConnectTls13, CaptureAlertClient) { - server_->SetPacketFilter(std::make_shared<HelloTruncator>()); - auto alert_recorder = std::make_shared<TlsAlertRecorder>(); - client_->SetPacketFilter(alert_recorder); + MakeTlsFilter<HelloTruncator>(server_); + auto alert_recorder = MakeTlsFilter<TlsAlertRecorder>(client_); - server_->StartConnect(); - client_->StartConnect(); + StartConnect(); client_->Handshake(); client_->ExpectSendAlert(kTlsAlertDecodeError); @@ -166,26 +164,111 @@ TEST_P(TlsConnectDatagram, ConnectSrtp) { SendReceive(); } -// 1.3 is disabled in the next few tests because we don't -// presently support resumption in 1.3. -TEST_P(TlsConnectStreamPre13, ConnectAndClientRenegotiate) { +TEST_P(TlsConnectGeneric, ConnectSendReceive) { Connect(); - server_->PrepareForRenegotiate(); - client_->StartRenegotiate(); - Handshake(); - CheckConnected(); + SendReceive(); } -TEST_P(TlsConnectStreamPre13, ConnectAndServerRenegotiate) { +class SaveTlsRecord : public TlsRecordFilter { + public: + SaveTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index) + : TlsRecordFilter(agent), index_(index), count_(0), contents_() {} + + const DataBuffer& contents() const { return contents_; } + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (count_++ == index_) { + contents_ = data; + } + return KEEP; + } + + private: + const size_t index_; + size_t count_; + DataBuffer contents_; +}; + +// Check that decrypting filters work and can read any record. +// This test (currently) only works in TLS 1.3 where we can decrypt. +TEST_F(TlsConnectStreamTls13, DecryptRecordClient) { + EnsureTlsSetup(); + // 0 = ClientHello, 1 = Finished, 2 = SendReceive, 3 = SendBuffer + auto saved = MakeTlsFilter<SaveTlsRecord>(client_, 3); + saved->EnableDecryption(); Connect(); - client_->PrepareForRenegotiate(); - server_->StartRenegotiate(); - Handshake(); - CheckConnected(); + SendReceive(); + + static const uint8_t data[] = {0xde, 0xad, 0xdc}; + DataBuffer buf(data, sizeof(data)); + client_->SendBuffer(buf); + EXPECT_EQ(buf, saved->contents()); } -TEST_P(TlsConnectGeneric, ConnectSendReceive) { +TEST_F(TlsConnectStreamTls13, DecryptRecordServer) { + EnsureTlsSetup(); + // Disable tickets so that we are sure to not get NewSessionTicket. + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + // 0 = ServerHello, 1 = other handshake, 2 = SendReceive, 3 = SendBuffer + auto saved = MakeTlsFilter<SaveTlsRecord>(server_, 3); + saved->EnableDecryption(); + Connect(); + SendReceive(); + + static const uint8_t data[] = {0xde, 0xad, 0xd5}; + DataBuffer buf(data, sizeof(data)); + server_->SendBuffer(buf); + EXPECT_EQ(buf, saved->contents()); +} + +class DropTlsRecord : public TlsRecordFilter { + public: + DropTlsRecord(const std::shared_ptr<TlsAgent>& agent, size_t index) + : TlsRecordFilter(agent), index_(index), count_(0) {} + + protected: + PacketFilter::Action FilterRecord(const TlsRecordHeader& header, + const DataBuffer& data, + DataBuffer* changed) override { + if (count_++ == index_) { + return DROP; + } + return KEEP; + } + + private: + const size_t index_; + size_t count_; +}; + +// Test that decrypting filters work correctly and are able to drop records. +TEST_F(TlsConnectStreamTls13, DropRecordServer) { + EnsureTlsSetup(); + // Disable session tickets so that the server doesn't send an extra record. + EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(), + SSL_ENABLE_SESSION_TICKETS, PR_FALSE)); + + // 0 = ServerHello, 1 = other handshake, 2 = first write + auto filter = MakeTlsFilter<DropTlsRecord>(server_, 2); + filter->EnableDecryption(); + Connect(); + server_->SendData(23, 23); // This should be dropped, so it won't be counted. + server_->ResetSentBytes(); + SendReceive(); +} + +TEST_F(TlsConnectStreamTls13, DropRecordClient) { + EnsureTlsSetup(); + // 0 = ClientHello, 1 = Finished, 2 = first write + auto filter = MakeTlsFilter<DropTlsRecord>(client_, 2); + filter->EnableDecryption(); Connect(); + client_->SendData(26, 26); // This should be dropped, so it won't be counted. + client_->ResetSentBytes(); SendReceive(); } @@ -224,32 +307,74 @@ TEST_P(TlsConnectStream, ShortRead) { ASSERT_EQ(50U, client_->received_bytes()); } -TEST_P(TlsConnectGeneric, ConnectWithCompressionMaybe) { +// We enable compression via the API but it's disabled internally, +// so we should never get it. +TEST_P(TlsConnectGeneric, ConnectWithCompressionEnabled) { EnsureTlsSetup(); - client_->EnableCompression(); - server_->EnableCompression(); + client_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); + server_->SetOption(SSL_ENABLE_DEFLATE, PR_TRUE); Connect(); - EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && - variant_ != ssl_variant_datagram, - client_->is_compressed()); + EXPECT_FALSE(client_->is_compressed()); SendReceive(); } -TEST_P(TlsConnectDatagram, TestDtlsHolddownExpiry) { +class TlsHolddownTest : public TlsConnectDatagram { + protected: + // This causes all timers to run to completion. It advances the clock and + // handshakes on both peers until both peers have no more timers pending, + // which should happen at the end of a handshake. This is necessary to ensure + // that the relatively long holddown timer expires, but that any other timers + // also expire and run correctly. + void RunAllTimersDown() { + while (true) { + PRIntervalTime time; + SECStatus rv = DTLS_GetHandshakeTimeout(client_->ssl_fd(), &time); + if (rv != SECSuccess) { + rv = DTLS_GetHandshakeTimeout(server_->ssl_fd(), &time); + if (rv != SECSuccess) { + break; // Neither peer has an outstanding timer. + } + } + + if (g_ssl_gtest_verbose) { + std::cerr << "Shifting timers" << std::endl; + } + ShiftDtlsTimers(); + Handshake(); + } + } +}; + +TEST_P(TlsHolddownTest, TestDtlsHolddownExpiry) { Connect(); - std::cerr << "Expiring holddown timer\n"; - SSLInt_ForceTimerExpiry(client_->ssl_fd()); - SSLInt_ForceTimerExpiry(server_->ssl_fd()); + std::cerr << "Expiring holddown timer" << std::endl; + RunAllTimersDown(); SendReceive(); if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) { // One for send, one for receive. - EXPECT_EQ(2, SSLInt_CountTls13CipherSpecs(client_->ssl_fd())); + EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); } } +TEST_P(TlsHolddownTest, TestDtlsHolddownExpiryResumption) { + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + Connect(); + SendReceive(); + + Reset(); + ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); + ExpectResumption(RESUME_TICKET); + Connect(); + RunAllTimersDown(); + SendReceive(); + // One for send, one for receive. + EXPECT_EQ(2, SSLInt_CountCipherSpecs(client_->ssl_fd())); +} + class TlsPreCCSHeaderInjector : public TlsRecordFilter { public: - TlsPreCCSHeaderInjector() {} + TlsPreCCSHeaderInjector(const std::shared_ptr<TlsAgent>& agent) + : TlsRecordFilter(agent) {} virtual PacketFilter::Action FilterRecord( const TlsRecordHeader& record_header, const DataBuffer& input, size_t* offset, DataBuffer* output) override { @@ -266,16 +391,15 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter { }; TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) { - client_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>()); + MakeTlsFilter<TlsPreCCSHeaderInjector>(client_); ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage); client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT); server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER); } TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) { - server_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>()); - client_->StartConnect(); - server_->StartConnect(); + MakeTlsFilter<TlsPreCCSHeaderInjector>(server_); + StartConnect(); ExpectAlert(client_, kTlsAlertUnexpectedMessage); Handshake(); EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state()); @@ -306,21 +430,64 @@ TEST_P(TlsConnectTls13, AlertWrongLevel) { TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) { EnsureTlsSetup(); - client_->StartConnect(); - server_->StartConnect(); + StartConnect(); client_->Handshake(); server_->Handshake(); // Send first flight. - client_->adapter()->CloseWrites(); + client_->adapter()->SetWriteError(PR_IO_ERROR); client_->Handshake(); // This will get an error, but shouldn't crash. client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE); } -TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) { - client_->SetShortHeadersEnabled(); - server_->SetShortHeadersEnabled(); - client_->ExpectShortHeaders(); - server_->ExpectShortHeaders(); +TEST_P(TlsConnectDatagram, BlockedWrite) { Connect(); + + // Mark the socket as blocked. + client_->adapter()->SetWriteError(PR_WOULD_BLOCK_ERROR); + static const uint8_t data[] = {1, 2, 3}; + int32_t rv = PR_Write(client_->ssl_fd(), data, sizeof(data)); + EXPECT_GT(0, rv); + EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError()); + + // Remove the write error and though the previous write failed, future reads + // and writes should just work as if it never happened. + client_->adapter()->SetWriteError(0); + SendReceive(); +} + +TEST_F(TlsConnectTest, ConnectSSLv3) { + ConfigureVersion(SSL_LIBRARY_VERSION_3_0); + EnableOnlyStaticRsaCiphers(); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +TEST_F(TlsConnectTest, ConnectSSLv3ClientAuth) { + ConfigureVersion(SSL_LIBRARY_VERSION_3_0); + EnableOnlyStaticRsaCiphers(); + client_->SetupClientAuth(); + server_->RequestClientAuth(true); + Connect(); + CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none); +} + +static size_t ExpectedCbcLen(size_t in, size_t hmac = 20, size_t block = 16) { + // MAC-then-Encrypt expansion formula: + return ((in + hmac + (block - 1)) / block) * block; +} + +TEST_F(TlsConnectTest, OneNRecordSplitting) { + ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_0); + EnsureTlsSetup(); + ConnectWithCipherSuite(TLS_RSA_WITH_AES_128_CBC_SHA); + auto records = MakeTlsFilter<TlsRecordRecorder>(server_); + // This should be split into 1, 16384 and 20. + DataBuffer big_buffer; + big_buffer.Allocate(1 + 16384 + 20); + server_->SendBuffer(big_buffer); + ASSERT_EQ(3U, records->count()); + EXPECT_EQ(ExpectedCbcLen(1), records->record(0).buffer.len()); + EXPECT_EQ(ExpectedCbcLen(16384), records->record(1).buffer.len()); + EXPECT_EQ(ExpectedCbcLen(20), records->record(2).buffer.len()); } INSTANTIATE_TEST_CASE_P( @@ -336,6 +503,8 @@ INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream, TlsConnectTestBase::kTlsVAll); INSTANTIATE_TEST_CASE_P(DatagramOnly, TlsConnectDatagram, TlsConnectTestBase::kTlsV11Plus); +INSTANTIATE_TEST_CASE_P(DatagramHolddown, TlsHolddownTest, + TlsConnectTestBase::kTlsV11Plus); INSTANTIATE_TEST_CASE_P( Pre12Stream, TlsConnectPre12, @@ -368,4 +537,27 @@ INSTANTIATE_TEST_CASE_P(Version12Plus, TlsConnectTls12Plus, ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus)); -} // namespace nspr_test +INSTANTIATE_TEST_CASE_P( + GenericStream, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll, + ::testing::Values(true, false))); +INSTANTIATE_TEST_CASE_P( + GenericDatagram, TlsConnectGenericResumption, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus, + ::testing::Values(true, false))); + +INSTANTIATE_TEST_CASE_P( + GenericStream, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream, + TlsConnectTestBase::kTlsVAll)); +INSTANTIATE_TEST_CASE_P( + GenericDatagram, TlsConnectGenericResumptionToken, + ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram, + TlsConnectTestBase::kTlsV11Plus)); + +INSTANTIATE_TEST_CASE_P(GenericDatagram, TlsConnectTls13ResumptionToken, + TlsConnectTestBase::kTlsVariantsAll); + +} // namespace nss_test |