summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_connect.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_connect.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_connect.cc49
1 files changed, 46 insertions, 3 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_connect.cc b/security/nss/gtests/ssl_gtest/tls_connect.cc
index 8567b392f..68f6d21e9 100644
--- a/security/nss/gtests/ssl_gtest/tls_connect.cc
+++ b/security/nss/gtests/ssl_gtest/tls_connect.cc
@@ -571,14 +571,57 @@ void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
}
}
+static SECStatus NextProtoCallbackServer(void* arg, PRFileDesc* fd,
+ const unsigned char* protos,
+ unsigned int protos_len,
+ unsigned char* protoOut,
+ unsigned int* protoOutLen,
+ unsigned int protoMaxLen) {
+ EXPECT_EQ(protoMaxLen, 255U);
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ // Check that agent->alpn_value_to_use_ is in protos.
+ if (protos_len < 1) {
+ return SECFailure;
+ }
+ for (size_t i = 0; i < protos_len;) {
+ size_t l = protos[i];
+ EXPECT_LT(i + l, protos_len);
+ if (i + l >= protos_len) {
+ return SECFailure;
+ }
+ std::string protos_s(reinterpret_cast<const char*>(protos + i + 1), l);
+ if (protos_s == agent->alpn_value_to_use_) {
+ size_t s_len = agent->alpn_value_to_use_.size();
+ EXPECT_LE(s_len, 255U);
+ memcpy(protoOut, &agent->alpn_value_to_use_[0], s_len);
+ *protoOutLen = s_len;
+ return SECSuccess;
+ }
+ i += l + 1;
+ }
+ return SECFailure;
+}
+
void TlsConnectTestBase::EnableAlpn() {
client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
}
-void TlsConnectTestBase::EnableAlpn(const uint8_t* val, size_t len) {
- client_->EnableAlpn(val, len);
- server_->EnableAlpn(val, len);
+void TlsConnectTestBase::EnableAlpnWithCallback(
+ const std::vector<uint8_t>& client_vals, std::string server_choice) {
+ EnsureTlsSetup();
+ server_->alpn_value_to_use_ = server_choice;
+ EXPECT_EQ(SECSuccess,
+ SSL_SetNextProtoNego(client_->ssl_fd(), client_vals.data(),
+ client_vals.size()));
+ SECStatus rv = SSL_SetNextProtoCallback(
+ server_->ssl_fd(), NextProtoCallbackServer, server_.get());
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsConnectTestBase::EnableAlpn(const std::vector<uint8_t>& vals) {
+ client_->EnableAlpn(vals.data(), vals.size());
+ server_->EnableAlpn(vals.data(), vals.size());
}
void TlsConnectTestBase::EnsureModelSockets() {