summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/test_io.h
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/test_io.h')
-rw-r--r--security/nss/gtests/ssl_gtest/test_io.h97
1 files changed, 57 insertions, 40 deletions
diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h
index b78db0dc6..ac2497222 100644
--- a/security/nss/gtests/ssl_gtest/test_io.h
+++ b/security/nss/gtests/ssl_gtest/test_io.h
@@ -14,12 +14,15 @@
#include <queue>
#include <string>
+#include "databuffer.h"
+#include "dummy_io.h"
#include "prio.h"
+#include "scoped_ptrs.h"
+#include "sslt.h"
namespace nss_test {
class DataBuffer;
-class Packet;
class DummyPrSocket; // Fwd decl.
// Allow us to inspect a packet before it is written.
@@ -42,49 +45,59 @@ class PacketFilter {
virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0;
};
-enum Mode { STREAM, DGRAM };
-
-inline std::ostream& operator<<(std::ostream& os, Mode m) {
- return os << ((m == STREAM) ? "TLS" : "DTLS");
-}
-
-class DummyPrSocket {
+class DummyPrSocket : public DummyIOLayerMethods {
public:
- ~DummyPrSocket();
+ DummyPrSocket(const std::string& name, SSLProtocolVariant variant)
+ : name_(name),
+ variant_(variant),
+ peer_(),
+ input_(),
+ filter_(nullptr),
+ writeable_(true) {}
+ virtual ~DummyPrSocket() {}
- static PRFileDesc* CreateFD(const std::string& name,
- Mode mode); // Returns an FD.
- static DummyPrSocket* GetAdapter(PRFileDesc* fd);
+ // Create a file descriptor that will reference this object. The fd must not
+ // live longer than this adapter; call PR_Close() before.
+ ScopedPRFileDesc CreateFD();
- DummyPrSocket* peer() const { return peer_; }
- void SetPeer(DummyPrSocket* peer) { peer_ = peer; }
- void SetPacketFilter(PacketFilter* filter);
+ std::weak_ptr<DummyPrSocket>& peer() { return peer_; }
+ void SetPeer(const std::shared_ptr<DummyPrSocket>& peer) { peer_ = peer; }
+ void SetPacketFilter(std::shared_ptr<PacketFilter> filter);
// Drops peer, packet filter and any outstanding packets.
void Reset();
void PacketReceived(const DataBuffer& data);
- int32_t Read(void* data, int32_t len);
- int32_t Recv(void* buf, int32_t buflen);
- int32_t Write(const void* buf, int32_t length);
+ int32_t Read(PRFileDesc* f, void* data, int32_t len) override;
+ int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags,
+ PRIntervalTime to) override;
+ int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override;
void CloseWrites() { writeable_ = false; }
- Mode mode() const { return mode_; }
+ SSLProtocolVariant variant() const { return variant_; }
bool readable() const { return !input_.empty(); }
private:
- DummyPrSocket(const std::string& name, Mode mode)
- : name_(name),
- mode_(mode),
- peer_(nullptr),
- input_(),
- filter_(nullptr),
- writeable_(true) {}
+ class Packet : public DataBuffer {
+ public:
+ Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {}
+
+ void Advance(size_t delta) {
+ PR_ASSERT(offset_ + delta <= len());
+ offset_ = std::min(len(), offset_ + delta);
+ }
+
+ size_t offset() const { return offset_; }
+ size_t remaining() const { return len() - offset_; }
+
+ private:
+ size_t offset_;
+ };
const std::string name_;
- Mode mode_;
- DummyPrSocket* peer_;
- std::queue<Packet*> input_;
- PacketFilter* filter_;
+ SSLProtocolVariant variant_;
+ std::weak_ptr<DummyPrSocket> peer_;
+ std::queue<Packet> input_;
+ std::shared_ptr<PacketFilter> filter_;
bool writeable_;
};
@@ -111,40 +124,44 @@ class Poller {
PollCallback callback_;
};
- void Wait(Event event, DummyPrSocket* adapter, PollTarget* target,
- PollCallback cb);
- void Cancel(Event event, DummyPrSocket* adapter);
+ void Wait(Event event, std::shared_ptr<DummyPrSocket>& adapter,
+ PollTarget* target, PollCallback cb);
+ void Cancel(Event event, std::shared_ptr<DummyPrSocket>& adapter);
void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb,
- Timer** handle);
+ std::shared_ptr<Timer>* handle);
bool Poll();
private:
Poller() : waiters_(), timers_() {}
- ~Poller();
+ ~Poller() {}
class Waiter {
public:
- Waiter(DummyPrSocket* io) : io_(io) {
+ Waiter(std::shared_ptr<DummyPrSocket> io) : io_(io) {
+ memset(&targets_[0], 0, sizeof(targets_));
memset(&callbacks_[0], 0, sizeof(callbacks_));
}
void WaitFor(Event event, PollCallback callback);
- DummyPrSocket* io_;
+ std::shared_ptr<DummyPrSocket> io_;
PollTarget* targets_[TIMER_EVENT];
PollCallback callbacks_[TIMER_EVENT];
};
class TimerComparator {
public:
- bool operator()(const Timer* lhs, const Timer* rhs) {
+ bool operator()(const std::shared_ptr<Timer> lhs,
+ const std::shared_ptr<Timer> rhs) {
return lhs->deadline_ > rhs->deadline_;
}
};
static Poller* instance;
- std::map<DummyPrSocket*, Waiter*> waiters_;
- std::priority_queue<Timer*, std::vector<Timer*>, TimerComparator> timers_;
+ std::map<std::shared_ptr<DummyPrSocket>, std::unique_ptr<Waiter>> waiters_;
+ std::priority_queue<std::shared_ptr<Timer>,
+ std::vector<std::shared_ptr<Timer>>, TimerComparator>
+ timers_;
};
} // end of namespace