summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_agent.h
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.h')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.h53
1 files changed, 35 insertions, 18 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.h b/security/nss/gtests/ssl_gtest/tls_agent.h
index 4bccb9a84..6cd6d5073 100644
--- a/security/nss/gtests/ssl_gtest/tls_agent.h
+++ b/security/nss/gtests/ssl_gtest/tls_agent.h
@@ -14,7 +14,6 @@
#include <iostream>
#include "test_io.h"
-#include "tls_filter.h"
#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
@@ -37,7 +36,10 @@ enum SessionResumptionMode {
RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
};
+class PacketFilter;
class TlsAgent;
+class TlsCipherSpec;
+struct TlsRecord;
const extern std::vector<SSLNamedGroup> kAllDHEGroups;
const extern std::vector<SSLNamedGroup> kECDHEGroups;
@@ -66,7 +68,6 @@ class TlsAgent : public PollTarget {
static const std::string kServerRsaSign;
static const std::string kServerRsaPss;
static const std::string kServerRsaDecrypt;
- static const std::string kServerRsaChain; // A cert that requires a chain.
static const std::string kServerEcdsa256;
static const std::string kServerEcdsa384;
static const std::string kServerEcdsa521;
@@ -81,20 +82,15 @@ class TlsAgent : public PollTarget {
adapter_->SetPeer(peer->adapter_);
}
- void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) {
- filter->SetAgent(this);
+ void SetFilter(std::shared_ptr<PacketFilter> filter) {
adapter_->SetPacketFilter(filter);
}
-
- void SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
- adapter_->SetPacketFilter(filter);
- }
-
- void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); }
+ void ClearFilter() { adapter_->SetPacketFilter(nullptr); }
void StartConnect(PRFileDesc* model = nullptr);
void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
size_t kea_size = 0) const;
+ void CheckOriginalKEA(SSLNamedGroup kea_group) const;
void CheckAuthType(SSLAuthType auth_type,
SSLSignatureScheme sig_scheme) const;
@@ -121,12 +117,10 @@ class TlsAgent : public PollTarget {
void SetupClientAuth();
void RequestClientAuth(bool requireAuth);
+ void SetOption(int32_t option, int value);
void ConfigureSessionCache(SessionResumptionMode mode);
- void SetSessionTicketsEnabled(bool en);
- void SetSessionCacheEnabled(bool en);
void Set0RttEnabled(bool en);
void SetFallbackSCSVEnabled(bool en);
- void SetShortHeadersEnabled();
void SetVersionRange(uint16_t minver, uint16_t maxver);
void GetVersionRange(uint16_t* minver, uint16_t* maxver);
void CheckPreliminaryInfo();
@@ -136,7 +130,6 @@ class TlsAgent : public PollTarget {
void ExpectReadWriteError();
void EnableFalseStart();
void ExpectResumption();
- void ExpectShortHeaders();
void SkipVersionChecks();
void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
void EnableAlpn(const uint8_t* val, size_t len);
@@ -149,27 +142,49 @@ class TlsAgent : public PollTarget {
// Send data on the socket, encrypting it.
void SendData(size_t bytes, size_t blocksize = 1024);
void SendBuffer(const DataBuffer& buf);
+ bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
+ uint16_t wireVersion, uint64_t seq, uint8_t ct,
+ const DataBuffer& buf);
// Send data directly to the underlying socket, skipping the TLS layer.
void SendDirect(const DataBuffer& buf);
+ void SendRecordDirect(const TlsRecord& record);
void ReadBytes(size_t max = 16384U);
void ResetSentBytes(); // Hack to test drops.
void EnableExtendedMasterSecret();
void CheckExtendedMasterSecret(bool expected);
void CheckEarlyDataAccepted(bool expected);
- void DisableRollbackDetection();
- void EnableCompression();
void SetDowngradeCheckVersion(uint16_t version);
void CheckSecretsDestroyed();
void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
void DisableECDHEServerKeyReuse();
bool GetPeerChainLength(size_t* count);
void CheckCipherSuite(uint16_t cipher_suite);
+ void SetResumptionTokenCallback();
+ bool MaybeSetResumptionToken();
+ void SetResumptionToken(const std::vector<uint8_t>& resumption_token) {
+ resumption_token_ = resumption_token;
+ }
+ const std::vector<uint8_t>& GetResumptionToken() const {
+ return resumption_token_;
+ }
+ void GetTokenInfo(ScopedSSLResumptionTokenInfo& token) {
+ SECStatus rv = SSL_GetResumptionTokenInfo(
+ resumption_token_.data(), resumption_token_.size(), token.get(),
+ sizeof(SSLResumptionTokenInfo));
+ ASSERT_EQ(SECSuccess, rv);
+ }
+ void SetResumptionCallbackCalled() { resumption_callback_called_ = true; }
+ bool resumption_callback_called() const {
+ return resumption_callback_called_;
+ }
const std::string& name() const { return name_; }
Role role() const { return role_; }
std::string role_str() const { return role_ == SERVER ? "server" : "client"; }
+ SSLProtocolVariant variant() const { return variant_; }
+
State state() const { return state_; }
const CERTCertificate* peer_cert() const {
@@ -253,6 +268,7 @@ class TlsAgent : public PollTarget {
const static char* states[];
void SetState(State state);
+ void ValidateCipherSpecs();
// Dummy auth certificate hook.
static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
@@ -378,6 +394,7 @@ class TlsAgent : public PollTarget {
uint8_t expected_sent_alert_;
uint8_t expected_sent_alert_level_;
bool handshake_callback_called_;
+ bool resumption_callback_called_;
SSLChannelInfo info_;
SSLCipherSuiteInfo csinfo_;
SSLVersionRange vrange_;
@@ -388,8 +405,8 @@ class TlsAgent : public PollTarget {
HandshakeCallbackFunction handshake_callback_;
AuthCertificateCallbackFunction auth_certificate_callback_;
SniCallbackFunction sni_callback_;
- bool expect_short_headers_;
bool skip_version_checks_;
+ std::vector<uint8_t> resumption_token_;
};
inline std::ostream& operator<<(std::ostream& stream,
@@ -440,7 +457,7 @@ class TlsAgentTestBase : public ::testing::Test {
void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
int32_t error_code = 0);
- std::unique_ptr<TlsAgent> agent_;
+ std::shared_ptr<TlsAgent> agent_;
TlsAgent::Role role_;
SSLProtocolVariant variant_;
uint16_t version_;