diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/test_io.h')
-rw-r--r-- | security/nss/gtests/ssl_gtest/test_io.h | 97 |
1 files changed, 57 insertions, 40 deletions
diff --git a/security/nss/gtests/ssl_gtest/test_io.h b/security/nss/gtests/ssl_gtest/test_io.h index b78db0dc6..ac2497222 100644 --- a/security/nss/gtests/ssl_gtest/test_io.h +++ b/security/nss/gtests/ssl_gtest/test_io.h @@ -14,12 +14,15 @@ #include <queue> #include <string> +#include "databuffer.h" +#include "dummy_io.h" #include "prio.h" +#include "scoped_ptrs.h" +#include "sslt.h" namespace nss_test { class DataBuffer; -class Packet; class DummyPrSocket; // Fwd decl. // Allow us to inspect a packet before it is written. @@ -42,49 +45,59 @@ class PacketFilter { virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0; }; -enum Mode { STREAM, DGRAM }; - -inline std::ostream& operator<<(std::ostream& os, Mode m) { - return os << ((m == STREAM) ? "TLS" : "DTLS"); -} - -class DummyPrSocket { +class DummyPrSocket : public DummyIOLayerMethods { public: - ~DummyPrSocket(); + DummyPrSocket(const std::string& name, SSLProtocolVariant variant) + : name_(name), + variant_(variant), + peer_(), + input_(), + filter_(nullptr), + writeable_(true) {} + virtual ~DummyPrSocket() {} - static PRFileDesc* CreateFD(const std::string& name, - Mode mode); // Returns an FD. - static DummyPrSocket* GetAdapter(PRFileDesc* fd); + // Create a file descriptor that will reference this object. The fd must not + // live longer than this adapter; call PR_Close() before. + ScopedPRFileDesc CreateFD(); - DummyPrSocket* peer() const { return peer_; } - void SetPeer(DummyPrSocket* peer) { peer_ = peer; } - void SetPacketFilter(PacketFilter* filter); + std::weak_ptr<DummyPrSocket>& peer() { return peer_; } + void SetPeer(const std::shared_ptr<DummyPrSocket>& peer) { peer_ = peer; } + void SetPacketFilter(std::shared_ptr<PacketFilter> filter); // Drops peer, packet filter and any outstanding packets. void Reset(); void PacketReceived(const DataBuffer& data); - int32_t Read(void* data, int32_t len); - int32_t Recv(void* buf, int32_t buflen); - int32_t Write(const void* buf, int32_t length); + int32_t Read(PRFileDesc* f, void* data, int32_t len) override; + int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags, + PRIntervalTime to) override; + int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override; void CloseWrites() { writeable_ = false; } - Mode mode() const { return mode_; } + SSLProtocolVariant variant() const { return variant_; } bool readable() const { return !input_.empty(); } private: - DummyPrSocket(const std::string& name, Mode mode) - : name_(name), - mode_(mode), - peer_(nullptr), - input_(), - filter_(nullptr), - writeable_(true) {} + class Packet : public DataBuffer { + public: + Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {} + + void Advance(size_t delta) { + PR_ASSERT(offset_ + delta <= len()); + offset_ = std::min(len(), offset_ + delta); + } + + size_t offset() const { return offset_; } + size_t remaining() const { return len() - offset_; } + + private: + size_t offset_; + }; const std::string name_; - Mode mode_; - DummyPrSocket* peer_; - std::queue<Packet*> input_; - PacketFilter* filter_; + SSLProtocolVariant variant_; + std::weak_ptr<DummyPrSocket> peer_; + std::queue<Packet> input_; + std::shared_ptr<PacketFilter> filter_; bool writeable_; }; @@ -111,40 +124,44 @@ class Poller { PollCallback callback_; }; - void Wait(Event event, DummyPrSocket* adapter, PollTarget* target, - PollCallback cb); - void Cancel(Event event, DummyPrSocket* adapter); + void Wait(Event event, std::shared_ptr<DummyPrSocket>& adapter, + PollTarget* target, PollCallback cb); + void Cancel(Event event, std::shared_ptr<DummyPrSocket>& adapter); void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb, - Timer** handle); + std::shared_ptr<Timer>* handle); bool Poll(); private: Poller() : waiters_(), timers_() {} - ~Poller(); + ~Poller() {} class Waiter { public: - Waiter(DummyPrSocket* io) : io_(io) { + Waiter(std::shared_ptr<DummyPrSocket> io) : io_(io) { + memset(&targets_[0], 0, sizeof(targets_)); memset(&callbacks_[0], 0, sizeof(callbacks_)); } void WaitFor(Event event, PollCallback callback); - DummyPrSocket* io_; + std::shared_ptr<DummyPrSocket> io_; PollTarget* targets_[TIMER_EVENT]; PollCallback callbacks_[TIMER_EVENT]; }; class TimerComparator { public: - bool operator()(const Timer* lhs, const Timer* rhs) { + bool operator()(const std::shared_ptr<Timer> lhs, + const std::shared_ptr<Timer> rhs) { return lhs->deadline_ > rhs->deadline_; } }; static Poller* instance; - std::map<DummyPrSocket*, Waiter*> waiters_; - std::priority_queue<Timer*, std::vector<Timer*>, TimerComparator> timers_; + std::map<std::shared_ptr<DummyPrSocket>, std::unique_ptr<Waiter>> waiters_; + std::priority_queue<std::shared_ptr<Timer>, + std::vector<std::shared_ptr<Timer>>, TimerComparator> + timers_; }; } // end of namespace |