summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc394
1 files changed, 394 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
new file mode 100644
index 000000000..eda96831c
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
@@ -0,0 +1,394 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "nss.h"
+#include "secerr.h"
+#include "ssl.h"
+#include "ssl3prot.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+#include <iostream>
+
+namespace nss_test {
+
+std::string GetSSLVersionString(uint16_t v) {
+ switch (v) {
+ case SSL_LIBRARY_VERSION_3_0:
+ return "ssl3";
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ return "tls1.0";
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ return "tls1.1";
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ return "tls1.2";
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+ return "tls1.3";
+ case SSL_LIBRARY_VERSION_NONE:
+ return "NONE";
+ }
+ if (v < SSL_LIBRARY_VERSION_3_0) {
+ return "undefined-too-low";
+ }
+ return "undefined-too-high";
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const SSLVersionRange& vr) {
+ return stream << GetSSLVersionString(vr.min) << ","
+ << GetSSLVersionString(vr.max);
+}
+
+class VersionRangeWithLabel {
+ public:
+ VersionRangeWithLabel(const std::string& label, const SSLVersionRange& vr)
+ : label_(label), vr_(vr) {}
+ VersionRangeWithLabel(const std::string& label, uint16_t min, uint16_t max)
+ : label_(label) {
+ vr_.min = min;
+ vr_.max = max;
+ }
+ VersionRangeWithLabel(const std::string& label) : label_(label) {
+ vr_.min = vr_.max = SSL_LIBRARY_VERSION_NONE;
+ }
+
+ void WriteStream(std::ostream& stream) const {
+ stream << " " << label_ << ": " << vr_;
+ }
+
+ uint16_t min() const { return vr_.min; }
+ uint16_t max() const { return vr_.max; }
+ SSLVersionRange range() const { return vr_; }
+
+ private:
+ std::string label_;
+ SSLVersionRange vr_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const VersionRangeWithLabel& vrwl) {
+ vrwl.WriteStream(stream);
+ return stream;
+}
+
+typedef std::tuple<SSLProtocolVariant, // variant
+ uint16_t, // policy min
+ uint16_t, // policy max
+ uint16_t, // input min
+ uint16_t> // input max
+ PolicyVersionRangeInput;
+
+class TestPolicyVersionRange
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<PolicyVersionRangeInput> {
+ public:
+ TestPolicyVersionRange()
+ : TlsConnectTestBase(std::get<0>(GetParam()), 0),
+ variant_(std::get<0>(GetParam())),
+ policy_("policy", std::get<1>(GetParam()), std::get<2>(GetParam())),
+ input_("input", std::get<3>(GetParam()), std::get<4>(GetParam())),
+ library_("supported-by-library",
+ ((variant_ == ssl_variant_stream)
+ ? SSL_LIBRARY_VERSION_MIN_SUPPORTED_STREAM
+ : SSL_LIBRARY_VERSION_MIN_SUPPORTED_DATAGRAM),
+ SSL_LIBRARY_VERSION_MAX_SUPPORTED) {
+ TlsConnectTestBase::SkipVersionChecks();
+ }
+
+ void SetPolicy(const SSLVersionRange& policy) {
+ NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, 0);
+
+ SECStatus rv;
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, policy.min);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, policy.max);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, policy.min);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, policy.max);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+
+ void CreateDummySocket(std::shared_ptr<DummyPrSocket>* dummy_socket,
+ ScopedPRFileDesc* ssl_fd) {
+ (*dummy_socket).reset(new DummyPrSocket("dummy", variant_));
+ *ssl_fd = (*dummy_socket)->CreateFD();
+ if (variant_ == ssl_variant_stream) {
+ SSL_ImportFD(nullptr, ssl_fd->get());
+ } else {
+ DTLS_ImportFD(nullptr, ssl_fd->get());
+ }
+ }
+
+ bool GetOverlap(const SSLVersionRange& r1, const SSLVersionRange& r2,
+ SSLVersionRange* overlap) {
+ if (r1.min == SSL_LIBRARY_VERSION_NONE ||
+ r1.max == SSL_LIBRARY_VERSION_NONE ||
+ r2.min == SSL_LIBRARY_VERSION_NONE ||
+ r2.max == SSL_LIBRARY_VERSION_NONE) {
+ return false;
+ }
+
+ SSLVersionRange temp;
+ temp.min = PR_MAX(r1.min, r2.min);
+ temp.max = PR_MIN(r1.max, r2.max);
+
+ if (temp.min > temp.max) {
+ return false;
+ }
+
+ *overlap = temp;
+ return true;
+ }
+
+ bool IsValidInputForVersionRangeSet(SSLVersionRange* expectedEffectiveRange) {
+ if (input_.min() <= SSL_LIBRARY_VERSION_3_0 &&
+ input_.max() >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // This is always invalid input, independent of policy
+ return false;
+ }
+
+ if (input_.min() < library_.min() || input_.max() > library_.max() ||
+ input_.min() > input_.max()) {
+ // Asking for unsupported ranges is invalid input for VersionRangeSet
+ // APIs, regardless of overlap.
+ return false;
+ }
+
+ SSLVersionRange overlap_with_library;
+ if (!GetOverlap(input_.range(), library_.range(), &overlap_with_library)) {
+ return false;
+ }
+
+ SSLVersionRange overlap_with_library_and_policy;
+ if (!GetOverlap(overlap_with_library, policy_.range(),
+ &overlap_with_library_and_policy)) {
+ return false;
+ }
+
+ RemoveConflictingVersions(variant_, &overlap_with_library_and_policy);
+ *expectedEffectiveRange = overlap_with_library_and_policy;
+ return true;
+ }
+
+ void RemoveConflictingVersions(SSLProtocolVariant variant,
+ SSLVersionRange* r) {
+ ASSERT_TRUE(r != nullptr);
+ if (r->max >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ r->min < SSL_LIBRARY_VERSION_TLS_1_0) {
+ r->min = SSL_LIBRARY_VERSION_TLS_1_0;
+ }
+ }
+
+ void SetUp() {
+ SetPolicy(policy_.range());
+ TlsConnectTestBase::SetUp();
+ }
+
+ void TearDown() {
+ TlsConnectTestBase::TearDown();
+ saved_version_policy_.RestoreOriginalPolicy();
+ }
+
+ protected:
+ class VersionPolicy {
+ public:
+ VersionPolicy() { SaveOriginalPolicy(); }
+
+ void RestoreOriginalPolicy() {
+ SECStatus rv;
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, saved_min_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, saved_max_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, saved_min_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, saved_max_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ // If it wasn't set initially, clear the bit that we set.
+ if (!(saved_algorithm_policy_ & NSS_USE_POLICY_IN_SSL)) {
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, 0,
+ NSS_USE_POLICY_IN_SSL);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+ }
+
+ private:
+ void SaveOriginalPolicy() {
+ SECStatus rv;
+ rv = NSS_OptionGet(NSS_TLS_VERSION_MIN_POLICY, &saved_min_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionGet(NSS_TLS_VERSION_MAX_POLICY, &saved_max_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionGet(NSS_DTLS_VERSION_MIN_POLICY, &saved_min_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionGet(NSS_DTLS_VERSION_MAX_POLICY, &saved_max_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_GetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY,
+ &saved_algorithm_policy_);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+
+ int32_t saved_min_tls_;
+ int32_t saved_max_tls_;
+ int32_t saved_min_dtls_;
+ int32_t saved_max_dtls_;
+ uint32_t saved_algorithm_policy_;
+ };
+
+ VersionPolicy saved_version_policy_;
+
+ SSLProtocolVariant variant_;
+ const VersionRangeWithLabel policy_;
+ const VersionRangeWithLabel input_;
+ const VersionRangeWithLabel library_;
+};
+
+static const uint16_t kExpandedVersionsArr[] = {
+ /* clang-format off */
+ SSL_LIBRARY_VERSION_3_0 - 1,
+ SSL_LIBRARY_VERSION_3_0,
+ SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2,
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_MAX_SUPPORTED + 1
+ /* clang-format on */
+};
+static ::testing::internal::ParamGenerator<uint16_t> kExpandedVersions =
+ ::testing::ValuesIn(kExpandedVersionsArr);
+
+TEST_P(TestPolicyVersionRange, TestAllTLSVersionsAndPolicyCombinations) {
+ ASSERT_TRUE(variant_ == ssl_variant_stream ||
+ variant_ == ssl_variant_datagram)
+ << "testing unsupported ssl variant";
+
+ std::cerr << "testing: " << variant_ << policy_ << input_ << library_
+ << std::endl;
+
+ SSLVersionRange supported_range;
+ SECStatus rv = SSL_VersionRangeGetSupported(variant_, &supported_range);
+ VersionRangeWithLabel supported("SSL_VersionRangeGetSupported",
+ supported_range);
+
+ std::cerr << supported << std::endl;
+
+ std::shared_ptr<DummyPrSocket> dummy_socket;
+ ScopedPRFileDesc ssl_fd;
+ CreateDummySocket(&dummy_socket, &ssl_fd);
+
+ SECStatus rv_socket;
+ SSLVersionRange overlap_policy_and_lib;
+ if (!GetOverlap(policy_.range(), library_.range(), &overlap_policy_and_lib)) {
+ EXPECT_EQ(SECFailure, rv)
+ << "expected SSL_VersionRangeGetSupported to fail with invalid policy";
+
+ SSLVersionRange enabled_range;
+ rv = SSL_VersionRangeGetDefault(variant_, &enabled_range);
+ EXPECT_EQ(SECFailure, rv)
+ << "expected SSL_VersionRangeGetDefault to fail with invalid policy";
+
+ SSLVersionRange enabled_range_on_socket;
+ rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &enabled_range_on_socket);
+ EXPECT_EQ(SECFailure, rv_socket)
+ << "expected SSL_VersionRangeGet to fail with invalid policy";
+
+ ConnectExpectFail();
+ return;
+ }
+
+ EXPECT_EQ(SECSuccess, rv)
+ << "expected SSL_VersionRangeGetSupported to succeed with valid policy";
+
+ EXPECT_TRUE(supported_range.min != SSL_LIBRARY_VERSION_NONE &&
+ supported_range.max != SSL_LIBRARY_VERSION_NONE)
+ << "expected SSL_VersionRangeGetSupported to return real values with "
+ "valid policy";
+
+ RemoveConflictingVersions(variant_, &overlap_policy_and_lib);
+ VersionRangeWithLabel overlap_info("overlap", overlap_policy_and_lib);
+
+ EXPECT_TRUE(supported_range == overlap_policy_and_lib)
+ << "expected range from GetSupported to be identical with calculated "
+ "overlap "
+ << overlap_info;
+
+ // We don't know which versions are "enabled by default" by the library,
+ // therefore we don't know if there's overlap between the default
+ // and the policy, and therefore, we don't if TLS connections should
+ // be successful or fail in this combination.
+ // Therefore we don't test if we can connect, without having configured a
+ // version range explicitly.
+
+ // Now start testing with supplied input.
+
+ SSLVersionRange expected_effective_range;
+ bool is_valid_input =
+ IsValidInputForVersionRangeSet(&expected_effective_range);
+
+ SSLVersionRange temp_input = input_.range();
+ rv = SSL_VersionRangeSetDefault(variant_, &temp_input);
+ rv_socket = SSL_VersionRangeSet(ssl_fd.get(), &temp_input);
+
+ if (!is_valid_input) {
+ EXPECT_EQ(SECFailure, rv)
+ << "expected failure return from SSL_VersionRangeSetDefault";
+
+ EXPECT_EQ(SECFailure, rv_socket)
+ << "expected failure return from SSL_VersionRangeSet";
+ return;
+ }
+
+ EXPECT_EQ(SECSuccess, rv)
+ << "expected successful return from SSL_VersionRangeSetDefault";
+
+ EXPECT_EQ(SECSuccess, rv_socket)
+ << "expected successful return from SSL_VersionRangeSet";
+
+ SSLVersionRange effective;
+ SSLVersionRange effective_socket;
+
+ rv = SSL_VersionRangeGetDefault(variant_, &effective);
+ EXPECT_EQ(SECSuccess, rv)
+ << "expected successful return from SSL_VersionRangeGetDefault";
+
+ rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &effective_socket);
+ EXPECT_EQ(SECSuccess, rv_socket)
+ << "expected successful return from SSL_VersionRangeGet";
+
+ VersionRangeWithLabel expected_info("expectation", expected_effective_range);
+ VersionRangeWithLabel effective_info("effectively-enabled", effective);
+
+ EXPECT_TRUE(expected_effective_range == effective)
+ << "range returned by SSL_VersionRangeGetDefault doesn't match "
+ "expectation: "
+ << expected_info << effective_info;
+
+ EXPECT_TRUE(expected_effective_range == effective_socket)
+ << "range returned by SSL_VersionRangeGet doesn't match "
+ "expectation: "
+ << expected_info << effective_info;
+
+ // Because we found overlap between policy and supported versions,
+ // and because we have used SetDefault to enable at least one version,
+ // it should be possible to execute an SSL/TLS connection.
+ Connect();
+}
+
+INSTANTIATE_TEST_CASE_P(TLSVersionRanges, TestPolicyVersionRange,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ kExpandedVersions, kExpandedVersions,
+ kExpandedVersions,
+ kExpandedVersions));
+} // namespace nss_test