diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_connect.cc | 49 |
1 files changed, 46 insertions, 3 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc index 8567b392f..68f6d21e9 100644 --- a/security/nss/gtests/ssl_gtest/tls_connect.cc +++ b/security/nss/gtests/ssl_gtest/tls_connect.cc @@ -571,14 +571,57 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { } } +static SECStatus NextProtoCallbackServer(void* arg, PRFileDesc* fd, + const unsigned char* protos, + unsigned int protos_len, + unsigned char* protoOut, + unsigned int* protoOutLen, + unsigned int protoMaxLen) { + EXPECT_EQ(protoMaxLen, 255U); + TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); + // Check that agent->alpn_value_to_use_ is in protos. + if (protos_len < 1) { + return SECFailure; + } + for (size_t i = 0; i < protos_len;) { + size_t l = protos[i]; + EXPECT_LT(i + l, protos_len); + if (i + l >= protos_len) { + return SECFailure; + } + std::string protos_s(reinterpret_cast<const char*>(protos + i + 1), l); + if (protos_s == agent->alpn_value_to_use_) { + size_t s_len = agent->alpn_value_to_use_.size(); + EXPECT_LE(s_len, 255U); + memcpy(protoOut, &agent->alpn_value_to_use_[0], s_len); + *protoOutLen = s_len; + return SECSuccess; + } + i += l + 1; + } + return SECFailure; +} + void TlsConnectTestBase::EnableAlpn() { client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_)); } -void TlsConnectTestBase::EnableAlpn(const uint8_t* val, size_t len) { - client_->EnableAlpn(val, len); - server_->EnableAlpn(val, len); +void TlsConnectTestBase::EnableAlpnWithCallback( + const std::vector<uint8_t>& client_vals, std::string server_choice) { + EnsureTlsSetup(); + server_->alpn_value_to_use_ = server_choice; + EXPECT_EQ(SECSuccess, + SSL_SetNextProtoNego(client_->ssl_fd(), client_vals.data(), + client_vals.size())); + SECStatus rv = SSL_SetNextProtoCallback( + server_->ssl_fd(), NextProtoCallbackServer, server_.get()); + EXPECT_EQ(SECSuccess, rv); +} + +void TlsConnectTestBase::EnableAlpn(const std::vector<uint8_t>& vals) { + client_->EnableAlpn(vals.data(), vals.size()); + server_->EnableAlpn(vals.data(), vals.size()); } void TlsConnectTestBase::EnsureModelSockets() { |