summaryrefslogtreecommitdiffstats
path: root/security/nss/gtests/ssl_gtest/test_io.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/test_io.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/test_io.cc536
1 files changed, 536 insertions, 0 deletions
diff --git a/security/nss/gtests/ssl_gtest/test_io.cc b/security/nss/gtests/ssl_gtest/test_io.cc
new file mode 100644
index 000000000..f3fd0b24c
--- /dev/null
+++ b/security/nss/gtests/ssl_gtest/test_io.cc
@@ -0,0 +1,536 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "test_io.h"
+
+#include <algorithm>
+#include <cassert>
+#include <iostream>
+#include <memory>
+
+#include "prerror.h"
+#include "prlog.h"
+#include "prthread.h"
+
+#include "databuffer.h"
+
+extern bool g_ssl_gtest_verbose;
+
+namespace nss_test {
+
+static PRDescIdentity test_fd_identity = PR_INVALID_IO_LAYER;
+
+#define UNIMPLEMENTED() \
+ std::cerr << "Call to unimplemented function " << __FUNCTION__ << std::endl; \
+ PR_ASSERT(PR_FALSE); \
+ PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0)
+
+#define LOG(a) std::cerr << name_ << ": " << a << std::endl
+#define LOGV(a) \
+ do { \
+ if (g_ssl_gtest_verbose) LOG(a); \
+ } while (false)
+
+class Packet : public DataBuffer {
+ public:
+ Packet(const DataBuffer &buf) : DataBuffer(buf), offset_(0) {}
+
+ void Advance(size_t delta) {
+ PR_ASSERT(offset_ + delta <= len());
+ offset_ = std::min(len(), offset_ + delta);
+ }
+
+ size_t offset() const { return offset_; }
+ size_t remaining() const { return len() - offset_; }
+
+ private:
+ size_t offset_;
+};
+
+// Implementation of NSPR methods
+static PRStatus DummyClose(PRFileDesc *f) {
+ DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
+ f->secret = nullptr;
+ f->dtor(f);
+ delete io;
+ return PR_SUCCESS;
+}
+
+static int32_t DummyRead(PRFileDesc *f, void *buf, int32_t length) {
+ DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
+ return io->Read(buf, length);
+}
+
+static int32_t DummyWrite(PRFileDesc *f, const void *buf, int32_t length) {
+ DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
+ return io->Write(buf, length);
+}
+
+static int32_t DummyAvailable(PRFileDesc *f) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+int64_t DummyAvailable64(PRFileDesc *f) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static PRStatus DummySync(PRFileDesc *f) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static int32_t DummySeek(PRFileDesc *f, int32_t offset, PRSeekWhence how) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static int64_t DummySeek64(PRFileDesc *f, int64_t offset, PRSeekWhence how) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static PRStatus DummyFileInfo(PRFileDesc *f, PRFileInfo *info) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static PRStatus DummyFileInfo64(PRFileDesc *f, PRFileInfo64 *info) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static int32_t DummyWritev(PRFileDesc *f, const PRIOVec *iov, int32_t iov_size,
+ PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static PRStatus DummyConnect(PRFileDesc *f, const PRNetAddr *addr,
+ PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static PRFileDesc *DummyAccept(PRFileDesc *sd, PRNetAddr *addr,
+ PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return nullptr;
+}
+
+static PRStatus DummyBind(PRFileDesc *f, const PRNetAddr *addr) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static PRStatus DummyListen(PRFileDesc *f, int32_t depth) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static PRStatus DummyShutdown(PRFileDesc *f, int32_t how) {
+ DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
+ io->Reset();
+ return PR_SUCCESS;
+}
+
+// This function does not support peek.
+static int32_t DummyRecv(PRFileDesc *f, void *buf, int32_t buflen,
+ int32_t flags, PRIntervalTime to) {
+ PR_ASSERT(flags == 0);
+ if (flags != 0) {
+ PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
+ return -1;
+ }
+
+ DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
+
+ if (io->mode() == DGRAM) {
+ return io->Recv(buf, buflen);
+ } else {
+ return io->Read(buf, buflen);
+ }
+}
+
+// Note: this is always nonblocking and assumes a zero timeout.
+static int32_t DummySend(PRFileDesc *f, const void *buf, int32_t amount,
+ int32_t flags, PRIntervalTime to) {
+ int32_t written = DummyWrite(f, buf, amount);
+ return written;
+}
+
+static int32_t DummyRecvfrom(PRFileDesc *f, void *buf, int32_t amount,
+ int32_t flags, PRNetAddr *addr,
+ PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static int32_t DummySendto(PRFileDesc *f, const void *buf, int32_t amount,
+ int32_t flags, const PRNetAddr *addr,
+ PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static int16_t DummyPoll(PRFileDesc *f, int16_t in_flags, int16_t *out_flags) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static int32_t DummyAcceptRead(PRFileDesc *sd, PRFileDesc **nd,
+ PRNetAddr **raddr, void *buf, int32_t amount,
+ PRIntervalTime t) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static int32_t DummyTransmitFile(PRFileDesc *sd, PRFileDesc *f,
+ const void *headers, int32_t hlen,
+ PRTransmitFileFlags flags, PRIntervalTime t) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static PRStatus DummyGetpeername(PRFileDesc *f, PRNetAddr *addr) {
+ // TODO: Modify to return unique names for each channel
+ // somehow, as opposed to always the same static address. The current
+ // implementation messes up the session cache, which is why it's off
+ // elsewhere
+ addr->inet.family = PR_AF_INET;
+ addr->inet.port = 0;
+ addr->inet.ip = 0;
+
+ return PR_SUCCESS;
+}
+
+static PRStatus DummyGetsockname(PRFileDesc *f, PRNetAddr *addr) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static PRStatus DummyGetsockoption(PRFileDesc *f, PRSocketOptionData *opt) {
+ switch (opt->option) {
+ case PR_SockOpt_Nonblocking:
+ opt->value.non_blocking = PR_TRUE;
+ return PR_SUCCESS;
+ default:
+ UNIMPLEMENTED();
+ break;
+ }
+
+ return PR_FAILURE;
+}
+
+// Imitate setting socket options. These are mostly noops.
+static PRStatus DummySetsockoption(PRFileDesc *f,
+ const PRSocketOptionData *opt) {
+ switch (opt->option) {
+ case PR_SockOpt_Nonblocking:
+ return PR_SUCCESS;
+ case PR_SockOpt_NoDelay:
+ return PR_SUCCESS;
+ default:
+ UNIMPLEMENTED();
+ break;
+ }
+
+ return PR_FAILURE;
+}
+
+static int32_t DummySendfile(PRFileDesc *out, PRSendFileData *in,
+ PRTransmitFileFlags flags, PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static PRStatus DummyConnectContinue(PRFileDesc *f, int16_t flags) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static int32_t DummyReserved(PRFileDesc *f) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+DummyPrSocket::~DummyPrSocket() { Reset(); }
+
+void DummyPrSocket::SetPacketFilter(PacketFilter *filter) {
+ if (filter_) {
+ delete filter_;
+ }
+ filter_ = filter;
+}
+
+void DummyPrSocket::Reset() {
+ delete filter_;
+ if (peer_) {
+ peer_->SetPeer(nullptr);
+ peer_ = nullptr;
+ }
+ while (!input_.empty()) {
+ Packet *front = input_.front();
+ input_.pop();
+ delete front;
+ }
+}
+
+static const struct PRIOMethods DummyMethods = {
+ PR_DESC_LAYERED, DummyClose,
+ DummyRead, DummyWrite,
+ DummyAvailable, DummyAvailable64,
+ DummySync, DummySeek,
+ DummySeek64, DummyFileInfo,
+ DummyFileInfo64, DummyWritev,
+ DummyConnect, DummyAccept,
+ DummyBind, DummyListen,
+ DummyShutdown, DummyRecv,
+ DummySend, DummyRecvfrom,
+ DummySendto, DummyPoll,
+ DummyAcceptRead, DummyTransmitFile,
+ DummyGetsockname, DummyGetpeername,
+ DummyReserved, DummyReserved,
+ DummyGetsockoption, DummySetsockoption,
+ DummySendfile, DummyConnectContinue,
+ DummyReserved, DummyReserved,
+ DummyReserved, DummyReserved};
+
+PRFileDesc *DummyPrSocket::CreateFD(const std::string &name, Mode mode) {
+ if (test_fd_identity == PR_INVALID_IO_LAYER) {
+ test_fd_identity = PR_GetUniqueIdentity("testtransportadapter");
+ }
+
+ PRFileDesc *fd = (PR_CreateIOLayerStub(test_fd_identity, &DummyMethods));
+ fd->secret = reinterpret_cast<PRFilePrivate *>(new DummyPrSocket(name, mode));
+
+ return fd;
+}
+
+DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) {
+ return reinterpret_cast<DummyPrSocket *>(fd->secret);
+}
+
+void DummyPrSocket::PacketReceived(const DataBuffer &packet) {
+ input_.push(new Packet(packet));
+}
+
+int32_t DummyPrSocket::Read(void *data, int32_t len) {
+ PR_ASSERT(mode_ == STREAM);
+
+ if (mode_ != STREAM) {
+ PR_SetError(PR_INVALID_METHOD_ERROR, 0);
+ return -1;
+ }
+
+ if (input_.empty()) {
+ LOGV("Read --> wouldblock " << len);
+ PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
+ return -1;
+ }
+
+ Packet *front = input_.front();
+ size_t to_read =
+ std::min(static_cast<size_t>(len), front->len() - front->offset());
+ memcpy(data, static_cast<const void *>(front->data() + front->offset()),
+ to_read);
+ front->Advance(to_read);
+
+ if (!front->remaining()) {
+ input_.pop();
+ delete front;
+ }
+
+ return static_cast<int32_t>(to_read);
+}
+
+int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) {
+ if (input_.empty()) {
+ PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
+ return -1;
+ }
+
+ Packet *front = input_.front();
+ if (static_cast<size_t>(buflen) < front->len()) {
+ PR_ASSERT(false);
+ PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0);
+ return -1;
+ }
+
+ size_t count = front->len();
+ memcpy(buf, front->data(), count);
+
+ input_.pop();
+ delete front;
+
+ return static_cast<int32_t>(count);
+}
+
+int32_t DummyPrSocket::Write(const void *buf, int32_t length) {
+ if (!peer_ || !writeable_) {
+ PR_SetError(PR_IO_ERROR, 0);
+ return -1;
+ }
+
+ DataBuffer packet(static_cast<const uint8_t *>(buf),
+ static_cast<size_t>(length));
+ DataBuffer filtered;
+ PacketFilter::Action action = PacketFilter::KEEP;
+ if (filter_) {
+ action = filter_->Filter(packet, &filtered);
+ }
+ switch (action) {
+ case PacketFilter::CHANGE:
+ LOG("Original packet: " << packet);
+ LOG("Filtered packet: " << filtered);
+ peer_->PacketReceived(filtered);
+ break;
+ case PacketFilter::DROP:
+ LOG("Droppped packet: " << packet);
+ break;
+ case PacketFilter::KEEP:
+ LOGV("Packet: " << packet);
+ peer_->PacketReceived(packet);
+ break;
+ }
+ // libssl can't handle it if this reports something other than the length
+ // of what was passed in (or less, but we're not doing partial writes).
+ return static_cast<int32_t>(packet.len());
+}
+
+Poller *Poller::instance;
+
+Poller *Poller::Instance() {
+ if (!instance) instance = new Poller();
+
+ return instance;
+}
+
+void Poller::Shutdown() {
+ delete instance;
+ instance = nullptr;
+}
+
+Poller::~Poller() {
+ while (!timers_.empty()) {
+ Timer *timer = timers_.top();
+ timers_.pop();
+ delete timer;
+ }
+}
+
+void Poller::Wait(Event event, DummyPrSocket *adapter, PollTarget *target,
+ PollCallback cb) {
+ auto it = waiters_.find(adapter);
+ Waiter *waiter;
+
+ if (it == waiters_.end()) {
+ waiter = new Waiter(adapter);
+ } else {
+ waiter = it->second;
+ }
+
+ assert(event < TIMER_EVENT);
+ if (event >= TIMER_EVENT) return;
+
+ waiter->targets_[event] = target;
+ waiter->callbacks_[event] = cb;
+ waiters_[adapter] = waiter;
+}
+
+void Poller::Cancel(Event event, DummyPrSocket *adapter) {
+ auto it = waiters_.find(adapter);
+ Waiter *waiter;
+
+ if (it == waiters_.end()) {
+ return;
+ }
+
+ waiter = it->second;
+
+ waiter->targets_[event] = nullptr;
+ waiter->callbacks_[event] = nullptr;
+
+ // Clean up if there are no callbacks.
+ for (size_t i = 0; i < TIMER_EVENT; ++i) {
+ if (waiter->callbacks_[i]) return;
+ }
+
+ delete waiter;
+ waiters_.erase(adapter);
+}
+
+void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb,
+ Timer **timer) {
+ Timer *t = new Timer(PR_Now() + timer_ms * 1000, target, cb);
+ timers_.push(t);
+ if (timer) *timer = t;
+}
+
+bool Poller::Poll() {
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "Poll() waiters = " << waiters_.size()
+ << " timers = " << timers_.size() << std::endl;
+ }
+ PRIntervalTime timeout = PR_INTERVAL_NO_TIMEOUT;
+ PRTime now = PR_Now();
+ bool fired = false;
+
+ // Figure out the timer for the select.
+ if (!timers_.empty()) {
+ Timer *first_timer = timers_.top();
+ if (now >= first_timer->deadline_) {
+ // Timer expired.
+ timeout = PR_INTERVAL_NO_WAIT;
+ } else {
+ timeout =
+ PR_MillisecondsToInterval((first_timer->deadline_ - now) / 1000);
+ }
+ }
+
+ for (auto it = waiters_.begin(); it != waiters_.end(); ++it) {
+ Waiter *waiter = it->second;
+
+ if (waiter->callbacks_[READABLE_EVENT]) {
+ if (waiter->io_->readable()) {
+ PollCallback callback = waiter->callbacks_[READABLE_EVENT];
+ PollTarget *target = waiter->targets_[READABLE_EVENT];
+ waiter->callbacks_[READABLE_EVENT] = nullptr;
+ waiter->targets_[READABLE_EVENT] = nullptr;
+ callback(target, READABLE_EVENT);
+ fired = true;
+ }
+ }
+ }
+
+ if (fired) timeout = PR_INTERVAL_NO_WAIT;
+
+ // Can't wait forever and also have nothing readable now.
+ if (timeout == PR_INTERVAL_NO_TIMEOUT) return false;
+
+ // Sleep.
+ if (timeout != PR_INTERVAL_NO_WAIT) {
+ PR_Sleep(timeout);
+ }
+
+ // Now process anything that timed out.
+ now = PR_Now();
+ while (!timers_.empty()) {
+ if (now < timers_.top()->deadline_) break;
+
+ Timer *timer = timers_.top();
+ timers_.pop();
+ if (timer->callback_) {
+ timer->callback_(timer->target_, TIMER_EVENT);
+ }
+ delete timer;
+ }
+
+ return true;
+}
+
+} // namespace nss_test