summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/tls_agent.h
blob: b3fd892ae46f322082c82848178761bf70639bbb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#ifndef tls_agent_h_
#define tls_agent_h_

#include "prio.h"
#include "ssl.h"

#include <functional>
#include <iostream>

#include "test_io.h"
#include "tls_filter.h"

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
#include "scoped_ptrs.h"

extern bool g_ssl_gtest_verbose;

namespace nss_test {

// Writes "<role>: <msg>" followed by std::endl to std::cerr.  |msg| is spliced
// into the stream expression without parentheses so that callers can chain
// several values with <<.  The expansion calls role_str() unqualified, so one
// must be visible wherever the macro is used.
#define LOG(msg) \
  std::cerr << role_str() << ": " << msg << std::endl

// Same output as LOG, but only when g_ssl_gtest_verbose is true.  Wrapped in
// do/while(false) so the expansion behaves as a single statement.
#define LOGV(msg)              \
  do {                         \
    if (g_ssl_gtest_verbose) { \
      LOG(msg);                \
    }                          \
  } while (false)

// Bit flags naming the session resumption mechanisms; RESUME_BOTH is the union
// of the session-ID and ticket flags.
enum SessionResumptionMode {
  RESUME_NONE = 0,
  RESUME_SESSIONID = 1 << 0,
  RESUME_TICKET = 1 << 1,
  RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
};

// Forward declaration; the callback signatures below only need TlsAgent*.
class TlsAgent;

// Named-group lists declared here and defined in another translation unit.
extern const std::vector<SSLNamedGroup> kAllDHEGroups;
extern const std::vector<SSLNamedGroup> kECDHEGroups;
extern const std::vector<SSLNamedGroup> kFFDHEGroups;
extern const std::vector<SSLNamedGroup> kFasterDHEGroups;

// Signatures of the test-supplied callbacks.  They receive the agent as a bare
// pointer because TlsAgent installs the callback itself and has no knowledge
// of who owns the agent.

// Invoked from the certificate authentication hooks; the returned SECStatus is
// passed back as the hook's result.
using AuthCertificateCallbackFunction =
    std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)>;

// Invoked from the handshake callback.
using HandshakeCallbackFunction = std::function<void(TlsAgent* agent)>;

// Invoked from the SNI hook; the returned value becomes the hook's result.
using SniCallbackFunction =
    std::function<int32_t(TlsAgent* agent, const SECItem* srvNameArr,
                          PRUint32 srvNameArrSize)>;

// Test-side wrapper around one end (client or server) of a TLS/DTLS
// connection.
//
// The agent keeps the SSL file descriptor (ssl_fd_), the DummyPrSocket adapter
// it exchanges packets through, a small State machine, and a collection of
// expectations and "hook was called" flags.  The static *Hook / *Callback
// members have NSS callback signatures: each one receives the agent as an
// opaque pointer, recovers it, and forwards to the optional std::function that
// the test installed.  (Where those hooks get registered is not visible in
// this header.)
class TlsAgent : public PollTarget {
 public:
  enum Role { CLIENT, SERVER };
  enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR };

  // Well-known names; presumably the certificate names accepted by
  // LoadCertificate() / ConfigServerCert() -- confirm against the definitions.
  static const std::string kClient;     // the client key is sign only
  static const std::string kRsa2048;    // bigger sign and encrypt for either
  static const std::string kServerRsa;  // both sign and encrypt
  static const std::string kServerRsaSign;
  static const std::string kServerRsaPss;
  static const std::string kServerRsaDecrypt;
  static const std::string kServerEcdsa256;
  static const std::string kServerEcdsa384;
  static const std::string kServerEcdsa521;
  static const std::string kServerEcdhEcdsa;
  static const std::string kServerEcdhRsa;
  static const std::string kServerDsa;

  TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant);
  virtual ~TlsAgent();

  // Links this agent's adapter to the adapter of |peer|.
  void SetPeer(std::shared_ptr<TlsAgent>& peer) {
    adapter_->SetPeer(peer->adapter_);
  }

  // Set a filter that can access plaintext (TLS 1.3 only).
  // The filter is bound to this agent before it is installed on the adapter,
  // and decryption is switched on afterwards.
  void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) {
    filter->SetAgent(this);
    adapter_->SetPacketFilter(filter);
    filter->EnableDecryption();
  }

  // Installs |filter| on the adapter.
  void SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
    adapter_->SetPacketFilter(filter);
  }

  // Clears the adapter's filter by installing a null one.
  void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); }

  void StartConnect(PRFileDesc* model = nullptr);
  void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
                size_t kea_size = 0) const;
  void CheckOriginalKEA(SSLNamedGroup kea_group) const;
  void CheckAuthType(SSLAuthType auth_type,
                     SSLSignatureScheme sig_scheme) const;

  void DisableAllCiphers();
  void EnableCiphersByAuthType(SSLAuthType authType);
  void EnableCiphersByKeyExchange(SSLKEAType kea);
  void EnableGroupsByKeyExchange(SSLKEAType kea);
  void EnableGroupsByAuthType(SSLAuthType authType);
  void EnableSingleCipher(uint16_t cipher);

  void Handshake();
  // Marks the internal state as CONNECTING in anticipation of renegotiation.
  void PrepareForRenegotiate();
  // Prepares for renegotiation, then actually triggers it.
  void StartRenegotiate();
  static bool LoadCertificate(const std::string& name,
                              ScopedCERTCertificate* cert,
                              ScopedSECKEYPrivateKey* priv);
  bool ConfigServerCert(const std::string& name, bool updateKeyBits = false,
                        const SSLExtraServerCertData* serverCertData = nullptr);
  bool ConfigServerCertWithChain(const std::string& name);
  bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr);

  void SetupClientAuth();
  void RequestClientAuth(bool requireAuth);

  void SetOption(int32_t option, int value);
  void ConfigureSessionCache(SessionResumptionMode mode);
  void Set0RttEnabled(bool en);
  void SetFallbackSCSVEnabled(bool en);
  void SetVersionRange(uint16_t minver, uint16_t maxver);
  void GetVersionRange(uint16_t* minver, uint16_t* maxver);
  void CheckPreliminaryInfo();
  void ResetPreliminaryInfo();
  void SetExpectedVersion(uint16_t version);
  void SetServerKeyBits(uint16_t bits);
  void ExpectReadWriteError();
  void EnableFalseStart();
  void ExpectResumption();
  void SkipVersionChecks();
  void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
  void EnableAlpn(const uint8_t* val, size_t len);
  void CheckAlpn(SSLNextProtoState expected_state,
                 const std::string& expected = "") const;
  void EnableSrtp();
  void CheckSrtp() const;
  void CheckErrorCode(int32_t expected) const;
  void WaitForErrorCode(int32_t expected, uint32_t delay) const;
  // Send data on the socket, encrypting it.
  void SendData(size_t bytes, size_t blocksize = 1024);
  void SendBuffer(const DataBuffer& buf);
  bool SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
                           uint16_t wireVersion, uint64_t seq, uint8_t ct,
                           const DataBuffer& buf);
  // Send data directly to the underlying socket, skipping the TLS layer.
  void SendDirect(const DataBuffer& buf);
  void SendRecordDirect(const TlsRecord& record);
  void ReadBytes(size_t max = 16384U);
  void ResetSentBytes();  // Hack to test drops.
  void EnableExtendedMasterSecret();
  void CheckExtendedMasterSecret(bool expected);
  void CheckEarlyDataAccepted(bool expected);
  void SetDowngradeCheckVersion(uint16_t version);
  void CheckSecretsDestroyed();
  void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
  void DisableECDHEServerKeyReuse();
  bool GetPeerChainLength(size_t* count);
  void CheckCipherSuite(uint16_t cipher_suite);

  const std::string& name() const { return name_; }

  Role role() const { return role_; }

  // Lower-case role name; this is what the LOG macro prefixes messages with.
  std::string role_str() const {
    if (role_ == SERVER) {
      return "server";
    }
    return "client";
  }

  SSLProtocolVariant variant() const { return variant_; }

  State state() const { return state_; }

  // Returns whatever SSL_PeerCertificate() yields for this agent's socket.
  // NOTE(review): NSS documents SSL_PeerCertificate as handing back a
  // reference the caller must release -- confirm that callers do so.
  const CERTCertificate* peer_cert() const {
    return SSL_PeerCertificate(ssl_fd_.get());
  }

  const char* state_str() const { return state_str(state()); }

  // Looks |state| up in the states[] table; no range check is performed.
  static const char* state_str(State state) { return states[state]; }

  PRFileDesc* ssl_fd() const { return ssl_fd_.get(); }
  std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; }

  bool is_compressed() const {
    return info_.compressionMethod != ssl_compression_null;
  }
  uint16_t server_key_bits() const { return server_key_bits_; }
  uint16_t min_version() const { return vrange_.min; }
  uint16_t max_version() const { return vrange_.max; }

  // Records a test failure (non-fatal) when called before the agent is
  // connected, then returns the version stored in info_ regardless.
  uint16_t version() const {
    EXPECT_EQ(STATE_CONNECTED, state_);
    return info_.protocolVersion;
  }

  // The three accessors below write their out-parameter and return true only
  // while connected; otherwise the out-parameter is left untouched.
  bool cipher_suite(uint16_t* cipher_suite) const {
    const bool connected = (state_ == STATE_CONNECTED);
    if (connected) {
      *cipher_suite = info_.cipherSuite;
    }
    return connected;
  }

  // Name of the negotiated suite, or "UNKNOWN" when not connected.
  std::string cipher_suite_name() const {
    if (state_ == STATE_CONNECTED) {
      return csinfo_.cipherSuiteName;
    }
    return "UNKNOWN";
  }

  // Copies the first sessionIDLength bytes of info_.sessionID.
  std::vector<uint8_t> session_id() const {
    const auto* first = info_.sessionID;
    return std::vector<uint8_t>(first, first + info_.sessionIDLength);
  }

  bool auth_type(SSLAuthType* auth_type) const {
    const bool connected = (state_ == STATE_CONNECTED);
    if (connected) {
      *auth_type = info_.authType;
    }
    return connected;
  }

  bool kea_type(SSLKEAType* kea_type) const {
    const bool connected = (state_ == STATE_CONNECTED);
    if (connected) {
      *kea_type = info_.keaType;
    }
    return connected;
  }

  size_t received_bytes() const { return recv_ctr_; }
  PRErrorCode error_code() const { return error_code_; }

  bool can_falsestart_hook_called() const {
    return can_falsestart_hook_called_;
  }

  // Setters for the optional test callbacks that the static hooks forward to.
  void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) {
    handshake_callback_ = handshake_callback;
  }

  void SetAuthCertificateCallback(
      AuthCertificateCallbackFunction auth_certificate_callback) {
    auth_certificate_callback_ = auth_certificate_callback;
  }

  void SetSniCallback(SniCallbackFunction sni_callback) {
    sni_callback_ = sni_callback;
  }

  void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0);
  void ExpectSendAlert(uint8_t alert, uint8_t level = 0);

 private:
  // Printable names indexed by State; defined out of line.
  static const char* states[];

  void SetState(State state);
  void ValidateCipherSpecs();

  // Dummy auth certificate hook.
  // Checks preliminary info, notes that the hook ran, then defers to the
  // test's callback when one is installed; otherwise reports success.
  static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
                                       PRBool checksig, PRBool isServer) {
    TlsAgent* agent = static_cast<TlsAgent*>(arg);
    agent->CheckPreliminaryInfo();
    agent->auth_certificate_hook_called_ = true;
    if (!agent->auth_certificate_callback_) {
      return SECSuccess;
    }
    return agent->auth_certificate_callback_(agent, checksig != 0,
                                             isServer != 0);
  }

  // Client auth certificate hook.
  // Expects that client auth was anticipated and that the hook is running on
  // the server side, then defers to the test's callback when one is installed.
  static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd,
                                       PRBool checksig, PRBool isServer) {
    TlsAgent* agent = static_cast<TlsAgent*>(arg);
    EXPECT_TRUE(agent->expect_client_auth_);
    EXPECT_EQ(PR_TRUE, isServer);
    if (!agent->auth_certificate_callback_) {
      return SECSuccess;
    }
    return agent->auth_certificate_callback_(agent, checksig != 0,
                                             isServer != 0);
  }

  static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd,
                                         CERTDistNames* caNames,
                                         CERTCertificate** cert,
                                         SECKEYPrivateKey** privKey);

  // Poller entry point.  A timer event drops the stored timer handle first;
  // every event then goes through ReadableCallback_int().
  static void ReadableCallback(PollTarget* self, Event event) {
    TlsAgent* agent = static_cast<TlsAgent*>(self);
    if (event == TIMER_EVENT) {
      agent->timer_handle_.reset();
    }
    agent->ReadableCallback_int();
  }

  // Continues the handshake while connecting, reads application data once
  // connected, and does nothing in any other state.
  void ReadableCallback_int() {
    LOGV("Readable");
    if (state_ == STATE_CONNECTING) {
      Handshake();
    } else if (state_ == STATE_CONNECTED) {
      ReadBytes();
    }
  }

  // SNI hook: expects exactly one server name, and lets the test's callback
  // pick the result when one is installed.
  static PRInt32 SniHook(PRFileDesc* fd, const SECItem* srvNameArr,
                         PRUint32 srvNameArrSize, void* arg) {
    TlsAgent* agent = static_cast<TlsAgent*>(arg);
    agent->CheckPreliminaryInfo();
    agent->sni_hook_called_ = true;
    EXPECT_EQ(1UL, srvNameArrSize);
    if (!agent->sni_callback_) {
      return 0;  // First configuration.
    }
    return agent->sni_callback_(agent, srvNameArr, srvNameArrSize);
  }

  // False start hook: expects false start to be enabled and the hook not to
  // have run before, then always permits false start.
  static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg,
                                         PRBool* canFalseStart) {
    TlsAgent* agent = static_cast<TlsAgent*>(arg);
    agent->CheckPreliminaryInfo();
    EXPECT_TRUE(agent->falsestart_enabled_);
    EXPECT_FALSE(agent->can_falsestart_hook_called_);
    agent->can_falsestart_hook_called_ = true;
    *canFalseStart = PR_TRUE;
    return SECSuccess;
  }

  void CheckAlert(bool sent, const SSLAlert* alert);

  // Alert hooks: both funnel into CheckAlert(), differing only in |sent|.
  static void AlertReceivedCallback(const PRFileDesc* fd, void* arg,
                                    const SSLAlert* alert) {
    TlsAgent* agent = static_cast<TlsAgent*>(arg);
    agent->CheckAlert(false, alert);
  }

  static void AlertSentCallback(const PRFileDesc* fd, void* arg,
                                const SSLAlert* alert) {
    TlsAgent* agent = static_cast<TlsAgent*>(arg);
    agent->CheckAlert(true, alert);
  }

  // Handshake hook: notes that it ran, runs Connected(), and only then
  // notifies the test's callback, if any.
  static void HandshakeCallback(PRFileDesc* fd, void* arg) {
    TlsAgent* agent = static_cast<TlsAgent*>(arg);
    agent->handshake_callback_called_ = true;
    agent->Connected();
    if (agent->handshake_callback_) {
      agent->handshake_callback_(agent);
    }
  }

  void DisableLameGroups();
  void ConfigStrongECGroups(bool en);
  void ConfigAllDHGroups(bool en);
  void CheckCallbacks() const;
  void Connected();

  // NOTE: the declaration order below is kept as-is because it fixes the
  // order in which the members are initialized.
  const std::string name_;
  SSLProtocolVariant variant_;
  Role role_;
  uint16_t server_key_bits_;
  std::shared_ptr<DummyPrSocket> adapter_;
  ScopedPRFileDesc ssl_fd_;
  State state_;
  std::shared_ptr<Poller::Timer> timer_handle_;
  bool falsestart_enabled_;
  uint16_t expected_version_;
  uint16_t expected_cipher_suite_;
  bool expect_resumption_;
  bool expect_client_auth_;
  bool can_falsestart_hook_called_;
  bool sni_hook_called_;
  bool auth_certificate_hook_called_;
  uint8_t expected_received_alert_;
  uint8_t expected_received_alert_level_;
  uint8_t expected_sent_alert_;
  uint8_t expected_sent_alert_level_;
  bool handshake_callback_called_;
  SSLChannelInfo info_;
  SSLCipherSuiteInfo csinfo_;
  SSLVersionRange vrange_;
  PRErrorCode error_code_;
  size_t send_ctr_;
  size_t recv_ctr_;
  bool expect_readwrite_error_;
  HandshakeCallbackFunction handshake_callback_;
  AuthCertificateCallbackFunction auth_certificate_callback_;
  SniCallbackFunction sni_callback_;
  bool skip_version_checks_;
};

// Streams the printable name of |state| (as given by TlsAgent::state_str), so
// that states show up readably in test output.
inline std::ostream& operator<<(std::ostream& stream,
                                const TlsAgent::State& state) {
  stream << TlsAgent::state_str(state);
  return stream;
}

class TlsAgentTestBase : public ::testing::Test {
 public:
  static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;

  TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant,
                   uint16_t version = 0)
      : agent_(nullptr),
        role_(role),
        variant_(variant),
        version_(version),
        sink_adapter_(new DummyPrSocket("sink", variant)) {}
  virtual ~TlsAgentTestBase() {}

  void SetUp();
  void TearDown();

  void ExpectAlert(uint8_t alert);

  static void MakeRecord(SSLProtocolVariant variant, uint8_t type,
                         uint16_t version, const uint8_t* buf, size_t len,
                         DataBuffer* out, uint64_t seq_num = 0);
  void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf,
                  size_t len, DataBuffer* out, uint64_t seq_num = 0) const;
  void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len,
                            DataBuffer* out, uint64_t seq_num = 0) const;
  void MakeHandshakeMessageFragment(uint8_t hs_type, const uint8_t* data,
                                    size_t hs_len, DataBuffer* out,
                                    uint64_t seq_num, uint32_t fragment_offset,
                                    uint32_t fragment_length) const;
  static void MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len,
                                         DataBuffer* out);
  static inline TlsAgent::Role ToRole(const std::string& str) {
    return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER;
  }

  void Init(const std::string& server_name = TlsAgent::kServerRsa);
  void Reset(const std::string& server_name = TlsAgent::kServerRsa);

 protected:
  void EnsureInit();
  void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
                      int32_t error_code = 0);

  std::unique_ptr<TlsAgent> agent_;
  TlsAgent::Role role_;
  SSLProtocolVariant variant_;
  uint16_t version_;
  // This adapter is here just to accept packets from this agent.
  std::shared_ptr<DummyPrSocket> sink_adapter_;
};

// Parameterized fixture whose parameter is a (role name, protocol variant,
// version) tuple.  The role name is converted with ToRole(), so any string
// other than "CLIENT" selects the server role.
class TlsAgentTest
    : public TlsAgentTestBase,
      public ::testing::WithParamInterface<
          std::tuple<std::string, SSLProtocolVariant, uint16_t>> {
 public:
  TlsAgentTest()
      : TlsAgentTestBase{ToRole(std::get<0>(GetParam())),
                         std::get<1>(GetParam()), std::get<2>(GetParam())} {}
};

// Parameterized client-only fixture; the parameter is a (protocol variant,
// version) tuple and the role is fixed to TlsAgent::CLIENT.
class TlsAgentTestClient : public TlsAgentTestBase,
                           public ::testing::WithParamInterface<
                               std::tuple<SSLProtocolVariant, uint16_t>> {
 public:
  TlsAgentTestClient()
      : TlsAgentTestBase{TlsAgent::CLIENT, std::get<0>(GetParam()),
                         std::get<1>(GetParam())} {}
};

// Distinct fixture name with the same behavior as TlsAgentTestClient.
// Presumably instantiated with TLS 1.3 parameters only -- confirm against the
// test instantiations.
class TlsAgentTestClient13 : public TlsAgentTestClient {};

// Non-parameterized client fixture over the stream variant; the version is
// left at the base class default.
class TlsAgentStreamTestClient : public TlsAgentTestBase {
 public:
  TlsAgentStreamTestClient()
      : TlsAgentTestBase{TlsAgent::CLIENT, ssl_variant_stream} {}
};

// Non-parameterized server fixture over the stream variant; the version is
// left at the base class default.
class TlsAgentStreamTestServer : public TlsAgentTestBase {
 public:
  TlsAgentStreamTestServer()
      : TlsAgentTestBase{TlsAgent::SERVER, ssl_variant_stream} {}
};

// Non-parameterized client fixture over the datagram variant; the version is
// left at the base class default.
class TlsAgentDgramTestClient : public TlsAgentTestBase {
 public:
  TlsAgentDgramTestClient()
      : TlsAgentTestBase{TlsAgent::CLIENT, ssl_variant_datagram} {}
};

// Two version ranges compare equal when both their lower and upper bounds
// match.
inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) {
  if (vr1.min != vr2.min) {
    return false;
  }
  return vr1.max == vr2.max;
}

}  // namespace nss_test

#endif