diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/test_io.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/test_io.cc | 386 |
1 files changed, 50 insertions, 336 deletions
diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc index f3fd0b24c..b9f0c672e 100644 --- a/security/nss/gtests/ssl_gtest/test_io.cc +++ b/security/nss/gtests/ssl_gtest/test_io.cc @@ -15,314 +15,33 @@ #include "prlog.h" #include "prthread.h" -#include "databuffer.h" - extern bool g_ssl_gtest_verbose; namespace nss_test { -static PRDescIdentity test_fd_identity = PR_INVALID_IO_LAYER; - -#define UNIMPLEMENTED() \ - std::cerr << "Call to unimplemented function " << __FUNCTION__ << std::endl; \ - PR_ASSERT(PR_FALSE); \ - PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0) - #define LOG(a) std::cerr << name_ << ": " << a << std::endl #define LOGV(a) \ do { \ if (g_ssl_gtest_verbose) LOG(a); \ } while (false) -class Packet : public DataBuffer { - public: - Packet(const DataBuffer &buf) : DataBuffer(buf), offset_(0) {} - - void Advance(size_t delta) { - PR_ASSERT(offset_ + delta <= len()); - offset_ = std::min(len(), offset_ + delta); - } - - size_t offset() const { return offset_; } - size_t remaining() const { return len() - offset_; } - - private: - size_t offset_; -}; - -// Implementation of NSPR methods -static PRStatus DummyClose(PRFileDesc *f) { - DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); - f->secret = nullptr; - f->dtor(f); - delete io; - return PR_SUCCESS; -} - -static int32_t DummyRead(PRFileDesc *f, void *buf, int32_t length) { - DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); - return io->Read(buf, length); -} - -static int32_t DummyWrite(PRFileDesc *f, const void *buf, int32_t length) { - DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); - return io->Write(buf, length); -} - -static int32_t DummyAvailable(PRFileDesc *f) { - UNIMPLEMENTED(); - return -1; -} - -int64_t DummyAvailable64(PRFileDesc *f) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummySync(PRFileDesc *f) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static int32_t DummySeek(PRFileDesc *f, int32_t offset, PRSeekWhence how) { - UNIMPLEMENTED(); - return -1; -} - -static int64_t DummySeek64(PRFileDesc *f, int64_t offset, PRSeekWhence how) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyFileInfo(PRFileDesc *f, PRFileInfo *info) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyFileInfo64(PRFileDesc *f, PRFileInfo64 *info) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static int32_t DummyWritev(PRFileDesc *f, const PRIOVec *iov, int32_t iov_size, - PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyConnect(PRFileDesc *f, const PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRFileDesc *DummyAccept(PRFileDesc *sd, PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return nullptr; -} - -static PRStatus DummyBind(PRFileDesc *f, const PRNetAddr *addr) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyListen(PRFileDesc *f, int32_t depth) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyShutdown(PRFileDesc *f, int32_t how) { - DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); - io->Reset(); - return PR_SUCCESS; -} - -// This function does not support peek. -static int32_t DummyRecv(PRFileDesc *f, void *buf, int32_t buflen, - int32_t flags, PRIntervalTime to) { - PR_ASSERT(flags == 0); - if (flags != 0) { - PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); - return -1; - } - - DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); - - if (io->mode() == DGRAM) { - return io->Recv(buf, buflen); - } else { - return io->Read(buf, buflen); - } -} - -// Note: this is always nonblocking and assumes a zero timeout. -static int32_t DummySend(PRFileDesc *f, const void *buf, int32_t amount, - int32_t flags, PRIntervalTime to) { - int32_t written = DummyWrite(f, buf, amount); - return written; -} - -static int32_t DummyRecvfrom(PRFileDesc *f, void *buf, int32_t amount, - int32_t flags, PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static int32_t DummySendto(PRFileDesc *f, const void *buf, int32_t amount, - int32_t flags, const PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static int16_t DummyPoll(PRFileDesc *f, int16_t in_flags, int16_t *out_flags) { - UNIMPLEMENTED(); - return -1; -} - -static int32_t DummyAcceptRead(PRFileDesc *sd, PRFileDesc **nd, - PRNetAddr **raddr, void *buf, int32_t amount, - PRIntervalTime t) { - UNIMPLEMENTED(); - return -1; -} - -static int32_t DummyTransmitFile(PRFileDesc *sd, PRFileDesc *f, - const void *headers, int32_t hlen, - PRTransmitFileFlags flags, PRIntervalTime t) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyGetpeername(PRFileDesc *f, PRNetAddr *addr) { - // TODO: Modify to return unique names for each channel - // somehow, as opposed to always the same static address. The current - // implementation messes up the session cache, which is why it's off - // elsewhere - addr->inet.family = PR_AF_INET; - addr->inet.port = 0; - addr->inet.ip = 0; - - return PR_SUCCESS; -} - -static PRStatus DummyGetsockname(PRFileDesc *f, PRNetAddr *addr) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyGetsockoption(PRFileDesc *f, PRSocketOptionData *opt) { - switch (opt->option) { - case PR_SockOpt_Nonblocking: - opt->value.non_blocking = PR_TRUE; - return PR_SUCCESS; - default: - UNIMPLEMENTED(); - break; - } - - return PR_FAILURE; -} - -// Imitate setting socket options. These are mostly noops. -static PRStatus DummySetsockoption(PRFileDesc *f, - const PRSocketOptionData *opt) { - switch (opt->option) { - case PR_SockOpt_Nonblocking: - return PR_SUCCESS; - case PR_SockOpt_NoDelay: - return PR_SUCCESS; - default: - UNIMPLEMENTED(); - break; - } - - return PR_FAILURE; -} - -static int32_t DummySendfile(PRFileDesc *out, PRSendFileData *in, - PRTransmitFileFlags flags, PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyConnectContinue(PRFileDesc *f, int16_t flags) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static int32_t DummyReserved(PRFileDesc *f) { - UNIMPLEMENTED(); - return -1; -} - -DummyPrSocket::~DummyPrSocket() { Reset(); } - -void DummyPrSocket::SetPacketFilter(PacketFilter *filter) { - if (filter_) { - delete filter_; - } +void DummyPrSocket::SetPacketFilter(std::shared_ptr<PacketFilter> filter) { filter_ = filter; } -void DummyPrSocket::Reset() { - delete filter_; - if (peer_) { - peer_->SetPeer(nullptr); - peer_ = nullptr; - } - while (!input_.empty()) { - Packet *front = input_.front(); - input_.pop(); - delete front; - } -} - -static const struct PRIOMethods DummyMethods = { - PR_DESC_LAYERED, DummyClose, - DummyRead, DummyWrite, - DummyAvailable, DummyAvailable64, - DummySync, DummySeek, - DummySeek64, DummyFileInfo, - DummyFileInfo64, DummyWritev, - DummyConnect, DummyAccept, - DummyBind, DummyListen, - DummyShutdown, DummyRecv, - DummySend, DummyRecvfrom, - DummySendto, DummyPoll, - DummyAcceptRead, DummyTransmitFile, - DummyGetsockname, DummyGetpeername, - DummyReserved, DummyReserved, - DummyGetsockoption, DummySetsockoption, - DummySendfile, DummyConnectContinue, - DummyReserved, DummyReserved, - DummyReserved, DummyReserved}; - -PRFileDesc *DummyPrSocket::CreateFD(const std::string &name, Mode mode) { - if (test_fd_identity == PR_INVALID_IO_LAYER) { - test_fd_identity = PR_GetUniqueIdentity("testtransportadapter"); - } - - PRFileDesc *fd = (PR_CreateIOLayerStub(test_fd_identity, &DummyMethods)); - fd->secret = reinterpret_cast<PRFilePrivate *>(new DummyPrSocket(name, mode)); - - return fd; -} - -DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) { - return reinterpret_cast<DummyPrSocket *>(fd->secret); +ScopedPRFileDesc DummyPrSocket::CreateFD() { + static PRDescIdentity test_fd_identity = + PR_GetUniqueIdentity("testtransportadapter"); + return DummyIOLayerMethods::CreateFD(test_fd_identity, this); } void DummyPrSocket::PacketReceived(const DataBuffer &packet) { - input_.push(new Packet(packet)); + input_.push(Packet(packet)); } -int32_t DummyPrSocket::Read(void *data, int32_t len) { - PR_ASSERT(mode_ == STREAM); - - if (mode_ != STREAM) { +int32_t DummyPrSocket::Read(PRFileDesc *f, void *data, int32_t len) { + PR_ASSERT(variant_ == ssl_variant_stream); + if (variant_ != ssl_variant_stream) { PR_SetError(PR_INVALID_METHOD_ERROR, 0); return -1; } @@ -333,45 +52,54 @@ int32_t DummyPrSocket::Read(void *data, int32_t len) { return -1; } - Packet *front = input_.front(); + auto &front = input_.front(); size_t to_read = - std::min(static_cast<size_t>(len), front->len() - front->offset()); - memcpy(data, static_cast<const void *>(front->data() + front->offset()), + std::min(static_cast<size_t>(len), front.len() - front.offset()); + memcpy(data, static_cast<const void *>(front.data() + front.offset()), to_read); - front->Advance(to_read); + front.Advance(to_read); - if (!front->remaining()) { + if (!front.remaining()) { input_.pop(); - delete front; } return static_cast<int32_t>(to_read); } -int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) { +int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen, + int32_t flags, PRIntervalTime to) { + PR_ASSERT(flags == 0); + if (flags != 0) { + PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); + return -1; + } + + if (variant() != ssl_variant_datagram) { + return Read(f, buf, buflen); + } + if (input_.empty()) { PR_SetError(PR_WOULD_BLOCK_ERROR, 0); return -1; } - Packet *front = input_.front(); - if (static_cast<size_t>(buflen) < front->len()) { + auto &front = input_.front(); + if (static_cast<size_t>(buflen) < front.len()) { PR_ASSERT(false); PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0); return -1; } - size_t count = front->len(); - memcpy(buf, front->data(), count); + size_t count = front.len(); + memcpy(buf, front.data(), count); input_.pop(); - delete front; - return static_cast<int32_t>(count); } -int32_t DummyPrSocket::Write(const void *buf, int32_t length) { - if (!peer_ || !writeable_) { +int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) { + auto peer = peer_.lock(); + if (!peer || !writeable_) { PR_SetError(PR_IO_ERROR, 0); return -1; } @@ -387,14 +115,14 @@ int32_t DummyPrSocket::Write(const void *buf, int32_t length) { case PacketFilter::CHANGE: LOG("Original packet: " << packet); LOG("Filtered packet: " << filtered); - peer_->PacketReceived(filtered); + peer->PacketReceived(filtered); break; case PacketFilter::DROP: LOG("Droppped packet: " << packet); break; case PacketFilter::KEEP: LOGV("Packet: " << packet); - peer_->PacketReceived(packet); + peer->PacketReceived(packet); break; } // libssl can't handle it if this reports something other than the length @@ -415,43 +143,31 @@ void Poller::Shutdown() { instance = nullptr; } -Poller::~Poller() { - while (!timers_.empty()) { - Timer *timer = timers_.top(); - timers_.pop(); - delete timer; - } -} +void Poller::Wait(Event event, std::shared_ptr<DummyPrSocket> &adapter, + PollTarget *target, PollCallback cb) { + assert(event < TIMER_EVENT); + if (event >= TIMER_EVENT) return; -void Poller::Wait(Event event, DummyPrSocket *adapter, PollTarget *target, - PollCallback cb) { + std::unique_ptr<Waiter> waiter; auto it = waiters_.find(adapter); - Waiter *waiter; - if (it == waiters_.end()) { - waiter = new Waiter(adapter); + waiter.reset(new Waiter(adapter)); } else { - waiter = it->second; + waiter = std::move(it->second); } - assert(event < TIMER_EVENT); - if (event >= TIMER_EVENT) return; - waiter->targets_[event] = target; waiter->callbacks_[event] = cb; - waiters_[adapter] = waiter; + waiters_[adapter] = std::move(waiter); } -void Poller::Cancel(Event event, DummyPrSocket *adapter) { +void Poller::Cancel(Event event, std::shared_ptr<DummyPrSocket> &adapter) { auto it = waiters_.find(adapter); - Waiter *waiter; - if (it == waiters_.end()) { return; } - waiter = it->second; - + auto &waiter = it->second; waiter->targets_[event] = nullptr; waiter->callbacks_[event] = nullptr; @@ -460,13 +176,12 @@ void Poller::Cancel(Event event, DummyPrSocket *adapter) { if (waiter->callbacks_[i]) return; } - delete waiter; waiters_.erase(adapter); } void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb, - Timer **timer) { - Timer *t = new Timer(PR_Now() + timer_ms * 1000, target, cb); + std::shared_ptr<Timer> *timer) { + auto t = std::make_shared<Timer>(PR_Now() + timer_ms * 1000, target, cb); timers_.push(t); if (timer) *timer = t; } @@ -482,7 +197,7 @@ bool Poller::Poll() { // Figure out the timer for the select. if (!timers_.empty()) { - Timer *first_timer = timers_.top(); + auto first_timer = timers_.top(); if (now >= first_timer->deadline_) { // Timer expired. timeout = PR_INTERVAL_NO_WAIT; @@ -493,7 +208,7 @@ bool Poller::Poll() { } for (auto it = waiters_.begin(); it != waiters_.end(); ++it) { - Waiter *waiter = it->second; + auto &waiter = it->second; if (waiter->callbacks_[READABLE_EVENT]) { if (waiter->io_->readable()) { @@ -522,12 +237,11 @@ bool Poller::Poll() { while (!timers_.empty()) { if (now < timers_.top()->deadline_) break; - Timer *timer = timers_.top(); + auto timer = timers_.top(); timers_.pop(); if (timer->callback_) { timer->callback_(timer->target_, TIMER_EVENT); } - delete timer; } return true; |