/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=8 sts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#include "mozilla/dom/HttpServer.h"
#include "nsISocketTransport.h"
#include "nsWhitespaceTokenizer.h"
#include "nsNetUtil.h"
#include "nsIStreamTransportService.h"
#include "nsIAsyncStreamCopier2.h"
#include "nsIPipe.h"
#include "nsIOService.h"
#include "nsIHttpChannelInternal.h"
#include "Base64.h"
#include "WebSocketChannel.h"
#include "nsCharSeparatedTokenizer.h"
#include "nsIX509Cert.h"

static NS_DEFINE_CID(kStreamTransportServiceCID, NS_STREAMTRANSPORTSERVICE_CID);

namespace mozilla {
namespace dom {

static LazyLogModule gHttpServerLog("HttpServer");
#undef LOG_I
#define LOG_I(...) MOZ_LOG(gHttpServerLog, mozilla::LogLevel::Debug, (__VA_ARGS__))
#undef LOG_V
#define LOG_V(...) MOZ_LOG(gHttpServerLog, mozilla::LogLevel::Verbose, (__VA_ARGS__))
#undef LOG_E
#define LOG_E(...) MOZ_LOG(gHttpServerLog, mozilla::LogLevel::Error, (__VA_ARGS__))


NS_IMPL_ISUPPORTS(HttpServer,
                  nsIServerSocketListener,
                  nsILocalCertGetCallback)

HttpServer::HttpServer()
  : mPort()
  , mHttps()
{
}

HttpServer::~HttpServer()
{
}

// Stores the configuration and starts the server. For HTTPS a certificate
// named "flyweb" is requested from the local cert service and the socket is
// started from HandleCert(); otherwise HandleCert() is called directly with
// no certificate. The outcome reaches the listener through NotifyStarted(),
// which always dispatches a runnable, so the listener is never called
// synchronously from Init().
void
HttpServer::Init(int32_t aPort, bool aHttps, HttpServerListener* aListener)
{
  mPort = aPort;
  mHttps = aHttps;
  mListener = aListener;

  if (mHttps) {
    nsCOMPtr<nsILocalCertService> lcs =
      do_CreateInstance("@mozilla.org/security/local-cert-service;1");
    // Report a missing cert service through the normal start-up path
    // instead of dereferencing a null pointer.
    nsresult rv = lcs
      ? lcs->GetOrCreateCert(NS_LITERAL_CSTRING("flyweb"), this)
      : NS_ERROR_FAILURE;
    if (NS_FAILED(rv)) {
      NotifyStarted(rv);
    }
  } else {
    // Make sure to always have an async step before notifying callbacks
    HandleCert(nullptr, NS_OK);
  }
}

// nsILocalCertGetCallback: starts the listening socket (TLS when aCert is
// non-null) and reports the result. A partially set-up socket is closed on
// failure.
NS_IMETHODIMP
HttpServer::HandleCert(nsIX509Cert* aCert, nsresult aResult)
{
  nsresult rv = aResult;
  if (NS_SUCCEEDED(rv)) {
    rv = StartServerSocket(aCert);
  }

  if (NS_FAILED(rv) && mServerSocket) {
    mServerSocket->Close();
    mServerSocket = nullptr;
  }

  NotifyStarted(rv);

  return NS_OK;
}

// Dispatches OnServerStarted(aStatus) to the listener on the current thread.
void
HttpServer::NotifyStarted(nsresult aStatus)
{
  RefPtr<HttpServerListener> listener = mListener;
  nsCOMPtr<nsIRunnable> event = NS_NewRunnableFunction([listener, aStatus] ()
  {
    listener->OnServerStarted(aStatus);
  });
  NS_DispatchToCurrentThread(event);
}

// Creates, configures and starts listening on the server socket. On success
// mPort is updated with the port actually bound (relevant when the requested
// port was 0 / ephemeral — presumably; depends on the socket implementation).
nsresult
HttpServer::StartServerSocket(nsIX509Cert* aCert)
{
  nsresult rv;
  mServerSocket =
    do_CreateInstance(aCert ? "@mozilla.org/network/tls-server-socket;1"
                            : "@mozilla.org/network/server-socket;1", &rv);
  NS_ENSURE_SUCCESS(rv, rv);

  rv = mServerSocket->Init(mPort, false, -1);
  NS_ENSURE_SUCCESS(rv, rv);

  if (aCert) {
    nsCOMPtr<nsITLSServerSocket> tls = do_QueryInterface(mServerSocket);
    // Guard against an implementation that does not expose the TLS
    // interface rather than crashing on a null pointer.
    NS_ENSURE_TRUE(tls, NS_ERROR_NO_INTERFACE);

    rv = tls->SetServerCert(aCert);
    NS_ENSURE_SUCCESS(rv, rv);

    rv = tls->SetSessionTickets(false);
    NS_ENSURE_SUCCESS(rv, rv);

    mCert = aCert;
  }

  rv = mServerSocket->AsyncListen(this);
  NS_ENSURE_SUCCESS(rv, rv);

  rv = mServerSocket->GetPort(&mPort);
  NS_ENSURE_SUCCESS(rv, rv);

  LOG_I("HttpServer::StartServerSocket(%p)", this);

  return NS_OK;
}

// nsIServerSocketListener: wraps each accepted transport in a Connection,
// which is kept alive by mConnections.
NS_IMETHODIMP
HttpServer::OnSocketAccepted(nsIServerSocket* aServ,
                             nsISocketTransport* aTransport)
{
  MOZ_ASSERT(SameCOMIdentity(aServ, mServerSocket));

  nsresult rv;
  RefPtr<Connection> conn = new Connection(aTransport, this, rv);
  NS_ENSURE_SUCCESS(rv, rv);

  LOG_I("HttpServer::OnSocketAccepted(%p) - Socket %p", this, conn.get());

  mConnections.AppendElement(conn.forget());

  return NS_OK;
}

// nsIServerSocketListener: the listening socket went away; shut everything
// down.
NS_IMETHODIMP
HttpServer::OnStopListening(nsIServerSocket* aServ,
                            nsresult aStatus)
{
  MOZ_ASSERT(aServ == mServerSocket || !mServerSocket);

  // nsresult is a 32-bit value; print it with a matching format specifier.
  LOG_I("HttpServer::OnStopListening(%p) - status 0x%08x",
        this, static_cast<uint32_t>(aStatus));

  Close();

  return NS_OK;
}

// Hands aResponse to whichever connection owns aRequest.
void
HttpServer::SendResponse(InternalRequest* aRequest, InternalResponse* aResponse)
{
  for (Connection* conn : mConnections) {
    if (conn->TryHandleResponse(aRequest, aResponse)) {
      return;
    }
  }

  MOZ_ASSERT(false, "Unknown request");
}

// Accepts the pending WebSocket upgrade for aConnectRequest. The connection
// is removed from mConnections either way: on success it now belongs to the
// returned transport provider, on failure it has been closed.
already_AddRefed<nsITransportProvider>
HttpServer::AcceptWebSocket(InternalRequest* aConnectRequest,
                            const Optional<nsAString>& aProtocol,
                            ErrorResult& aRv)
{
  for (Connection* conn : mConnections) {
    if (!conn->HasPendingWebSocketRequest(aConnectRequest)) {
      continue;
    }
    nsCOMPtr<nsITransportProvider> provider =
      conn->HandleAcceptWebSocket(aProtocol, aRv);
    if (aRv.Failed()) {
      conn->Close();
    }
    // This connection is now owned by the websocket, or we just closed it
    mConnections.RemoveElement(conn);
    return provider.forget();
  }

  aRv.Throw(NS_ERROR_UNEXPECTED);
  MOZ_ASSERT(false, "Unknown request");

  return nullptr;
}

// Sends a (non-101) response to a pending WebSocket upgrade request.
void
HttpServer::SendWebSocketResponse(InternalRequest* aConnectRequest,
                                  InternalResponse* aResponse)
{
  for (Connection* conn : mConnections) {
    if (conn->HasPendingWebSocketRequest(aConnectRequest)) {
      conn->HandleWebSocketResponse(aResponse);
      return;
    }
  }

  MOZ_ASSERT(false, "Unknown request");
}

// Stops listening, notifies the listener once (mListener is cleared) and
// closes every open connection.
void
HttpServer::Close()
{
  if (mServerSocket) {
    mServerSocket->Close();
    mServerSocket = nullptr;
  }

  if (mListener) {
    RefPtr<HttpServerListener> listener = mListener.forget();
    listener->OnServerClose();
  }

  for (Connection* conn : mConnections) {
    conn->Close();
  }
  mConnections.Clear();
}

// Returns the SHA-256 fingerprint of the server certificate, or an empty
// string when no certificate is in use.
void
HttpServer::GetCertKey(nsACString& aKey)
{
  nsAutoString tmp;
  if (mCert) {
    mCert->GetSha256Fingerprint(tmp);
  }
  LossyCopyUTF16toASCII(tmp, aKey);
}

NS_IMPL_ISUPPORTS(HttpServer::TransportProvider,
                  nsITransportProvider)

HttpServer::TransportProvider::~TransportProvider()
{
}

NS_IMETHODIMP
HttpServer::TransportProvider::SetListener(nsIHttpUpgradeListener* aListener)
{
  MOZ_ASSERT(!mListener);
  MOZ_ASSERT(aListener);

  mListener = aListener;

  MaybeNotify();

  return NS_OK;
}

NS_IMETHODIMP
HttpServer::TransportProvider::GetIPCChild(PTransportProviderChild** aChild)
{
  MOZ_CRASH("Don't call this in parent process");
  *aChild = nullptr;
  return NS_OK;
}

// Receives the socket and streams from the Connection once the 101 response
// has been written.
void
HttpServer::TransportProvider::SetTransport(nsISocketTransport* aTransport,
                                            nsIAsyncInputStream* aInput,
                                            nsIAsyncOutputStream* aOutput)
{
  MOZ_ASSERT(!mTransport);
  MOZ_ASSERT(aTransport && aInput && aOutput);

  mTransport = aTransport;
  mInput = aInput;
  mOutput = aOutput;

  MaybeNotify();
}

// Once both the transport and the upgrade listener are known, asynchronously
// hands the transport to the listener.
void
HttpServer::TransportProvider::MaybeNotify()
{
  if (mTransport && mListener) {
    RefPtr<TransportProvider> self = this;
    nsCOMPtr<nsIRunnable> event = NS_NewRunnableFunction([self, this] ()
    {
      mListener->OnTransportAvailable(mTransport, mInput, mOutput);
    });
    NS_DispatchToCurrentThread(event);
  }
}

NS_IMPL_ISUPPORTS(HttpServer::Connection,
                  nsIInputStreamCallback,
                  nsIOutputStreamCallback)
HttpServer::Connection::Connection(nsISocketTransport* aTransport, HttpServer* aServer, nsresult& rv) : mServer(aServer) , mTransport(aTransport) , mState(eRequestLine) , mPendingReqVersion() , mRemainingBodySize() , mCloseAfterRequest(false) { nsCOMPtr<nsIInputStream> input; rv = mTransport->OpenInputStream(0, 0, 0, getter_AddRefs(input)); NS_ENSURE_SUCCESS_VOID(rv); mInput = do_QueryInterface(input); nsCOMPtr<nsIOutputStream> output; rv = mTransport->OpenOutputStream(0, 0, 0, getter_AddRefs(output)); NS_ENSURE_SUCCESS_VOID(rv); mOutput = do_QueryInterface(output); if (mServer->mHttps) { SetSecurityObserver(true); } else { mInput->AsyncWait(this, 0, 0, NS_GetCurrentThread()); } } NS_IMETHODIMP HttpServer::Connection::OnHandshakeDone(nsITLSServerSocket* aServer, nsITLSClientStatus* aStatus) { LOG_I("HttpServer::Connection::OnHandshakeDone(%p)", this); // XXX Verify connection security SetSecurityObserver(false); mInput->AsyncWait(this, 0, 0, NS_GetCurrentThread()); return NS_OK; } void HttpServer::Connection::SetSecurityObserver(bool aListen) { LOG_I("HttpServer::Connection::SetSecurityObserver(%p) - %s", this, aListen ? "On" : "Off"); nsCOMPtr<nsISupports> secInfo; mTransport->GetSecurityInfo(getter_AddRefs(secInfo)); nsCOMPtr<nsITLSServerConnectionInfo> tlsConnInfo = do_QueryInterface(secInfo); MOZ_ASSERT(tlsConnInfo); tlsConnInfo->SetSecurityObserver(aListen ? this : nullptr); } HttpServer::Connection::~Connection() { } NS_IMETHODIMP HttpServer::Connection::OnInputStreamReady(nsIAsyncInputStream* aStream) { MOZ_ASSERT(!mInput || aStream == mInput); LOG_I("HttpServer::Connection::OnInputStreamReady(%p)", this); if (!mInput || mState == ePause) { return NS_OK; } uint64_t avail; nsresult rv = mInput->Available(&avail); if (NS_FAILED(rv)) { LOG_I("HttpServer::Connection::OnInputStreamReady(%p) - Connection closed", this); mServer->mConnections.RemoveElement(this); // Connection closed. Handle errors here. 
return NS_OK; } uint32_t numRead; rv = mInput->ReadSegments(ReadSegmentsFunc, this, UINT32_MAX, &numRead); NS_ENSURE_SUCCESS(rv, rv); rv = mInput->AsyncWait(this, 0, 0, NS_GetCurrentThread()); NS_ENSURE_SUCCESS(rv, rv); return NS_OK; } nsresult HttpServer::Connection::ReadSegmentsFunc(nsIInputStream* aIn, void* aClosure, const char* aBuffer, uint32_t aToOffset, uint32_t aCount, uint32_t* aWriteCount) { const char* buffer = aBuffer; nsresult rv = static_cast<HttpServer::Connection*>(aClosure)-> ConsumeInput(buffer, buffer + aCount); *aWriteCount = buffer - aBuffer; MOZ_ASSERT(*aWriteCount <= aCount); return rv; } static const char* findCRLF(const char* aBuffer, const char* aEnd) { if (aBuffer + 1 >= aEnd) { return nullptr; } const char* pos; while ((pos = static_cast<const char*>(memchr(aBuffer, '\r', aEnd - aBuffer - 1)))) { if (*(pos + 1) == '\n') { return pos; } aBuffer = pos + 1; } return nullptr; } nsresult HttpServer::Connection::ConsumeInput(const char*& aBuffer, const char* aEnd) { nsresult rv; while (mState == eRequestLine || mState == eHeaders) { // Consume line-by-line // Check if buffer boundry ended up right between the CR and LF if (!mInputBuffer.IsEmpty() && mInputBuffer.Last() == '\r' && *aBuffer == '\n') { aBuffer++; rv = ConsumeLine(mInputBuffer.BeginReading(), mInputBuffer.Length() - 1); NS_ENSURE_SUCCESS(rv, rv); mInputBuffer.Truncate(); } // Look for a CRLF const char* pos = findCRLF(aBuffer, aEnd); if (!pos) { mInputBuffer.Append(aBuffer, aEnd - aBuffer); aBuffer = aEnd; return NS_OK; } if (!mInputBuffer.IsEmpty()) { mInputBuffer.Append(aBuffer, pos - aBuffer); aBuffer = pos + 2; rv = ConsumeLine(mInputBuffer.BeginReading(), mInputBuffer.Length() - 1); NS_ENSURE_SUCCESS(rv, rv); mInputBuffer.Truncate(); } else { rv = ConsumeLine(aBuffer, pos - aBuffer); NS_ENSURE_SUCCESS(rv, rv); aBuffer = pos + 2; } } if (mState == eBody) { uint32_t size = std::min(mRemainingBodySize, static_cast<uint32_t>(aEnd - aBuffer)); uint32_t written = size; if 
(mCurrentRequestBody) { rv = mCurrentRequestBody->Write(aBuffer, size, &written); // Since we've given the pipe unlimited size, we should never // end up needing to block. MOZ_ASSERT(rv != NS_BASE_STREAM_WOULD_BLOCK); if (NS_FAILED(rv)) { written = size; mCurrentRequestBody = nullptr; } } aBuffer += written; mRemainingBodySize -= written; if (!mRemainingBodySize) { mCurrentRequestBody->Close(); mCurrentRequestBody = nullptr; mState = eRequestLine; } } return NS_OK; } bool ContainsToken(const nsCString& aList, const nsCString& aToken) { nsCCharSeparatedTokenizer tokens(aList, ','); bool found = false; while (!found && tokens.hasMoreTokens()) { found = tokens.nextToken().Equals(aToken); } return found; } static bool IsWebSocketRequest(InternalRequest* aRequest, uint32_t aHttpVersion) { if (aHttpVersion < 1) { return false; } nsAutoCString str; aRequest->GetMethod(str); if (!str.EqualsLiteral("GET")) { return false; } InternalHeaders* headers = aRequest->Headers(); ErrorResult res; headers->GetFirst(NS_LITERAL_CSTRING("upgrade"), str, res); MOZ_ASSERT(!res.Failed()); if (!str.EqualsLiteral("websocket")) { return false; } headers->GetFirst(NS_LITERAL_CSTRING("connection"), str, res); MOZ_ASSERT(!res.Failed()); if (!ContainsToken(str, NS_LITERAL_CSTRING("Upgrade"))) { return false; } headers->GetFirst(NS_LITERAL_CSTRING("sec-websocket-key"), str, res); MOZ_ASSERT(!res.Failed()); nsAutoCString binary; if (NS_FAILED(Base64Decode(str, binary)) || binary.Length() != 16) { return false; } nsresult rv; headers->GetFirst(NS_LITERAL_CSTRING("sec-websocket-version"), str, res); MOZ_ASSERT(!res.Failed()); if (str.ToInteger(&rv) != 13 || NS_FAILED(rv)) { return false; } return true; } nsresult HttpServer::Connection::ConsumeLine(const char* aBuffer, size_t aLength) { MOZ_ASSERT(mState == eRequestLine || mState == eHeaders); if (MOZ_LOG_TEST(gHttpServerLog, mozilla::LogLevel::Verbose)) { nsCString line(aBuffer, aLength); LOG_V("HttpServer::Connection::ConsumeLine(%p) - \"%s\"", 
this, line.get()); } if (mState == eRequestLine) { LOG_V("HttpServer::Connection::ConsumeLine(%p) - Parsing request line", this); NS_ENSURE_FALSE(mCloseAfterRequest, NS_ERROR_UNEXPECTED); if (aLength == 0) { // Ignore empty lines before the request line return NS_OK; } MOZ_ASSERT(!mPendingReq); // Process request line nsCWhitespaceTokenizer tokens(Substring(aBuffer, aLength)); NS_ENSURE_TRUE(tokens.hasMoreTokens(), NS_ERROR_UNEXPECTED); nsDependentCSubstring method = tokens.nextToken(); NS_ENSURE_TRUE(NS_IsValidHTTPToken(method), NS_ERROR_UNEXPECTED); NS_ENSURE_TRUE(tokens.hasMoreTokens(), NS_ERROR_UNEXPECTED); nsDependentCSubstring url = tokens.nextToken(); // Seems like it's also allowed to pass full urls with scheme+host+port. // May need to support that. NS_ENSURE_TRUE(url.First() == '/', NS_ERROR_UNEXPECTED); mPendingReq = new InternalRequest(url, /* aURLFragment */ EmptyCString()); mPendingReq->SetMethod(method); NS_ENSURE_TRUE(tokens.hasMoreTokens(), NS_ERROR_UNEXPECTED); nsDependentCSubstring version = tokens.nextToken(); NS_ENSURE_TRUE(StringBeginsWith(version, NS_LITERAL_CSTRING("HTTP/1.")), NS_ERROR_UNEXPECTED); nsresult rv; // This integer parsing is likely not strict enough. 
nsCString reqVersion; reqVersion = Substring(version, MOZ_ARRAY_LENGTH("HTTP/1.") - 1); mPendingReqVersion = reqVersion.ToInteger(&rv); NS_ENSURE_SUCCESS(rv, NS_ERROR_UNEXPECTED); NS_ENSURE_FALSE(tokens.hasMoreTokens(), NS_ERROR_UNEXPECTED); LOG_V("HttpServer::Connection::ConsumeLine(%p) - Parsed request line", this); mState = eHeaders; return NS_OK; } if (aLength == 0) { LOG_V("HttpServer::Connection::ConsumeLine(%p) - Found end of headers", this); MaybeAddPendingHeader(); ErrorResult res; mPendingReq->Headers()->SetGuard(HeadersGuardEnum::Immutable, res); // Check for WebSocket if (IsWebSocketRequest(mPendingReq, mPendingReqVersion)) { LOG_V("HttpServer::Connection::ConsumeLine(%p) - Fire OnWebSocket", this); mState = ePause; mPendingWebSocketRequest = mPendingReq.forget(); mPendingReqVersion = 0; RefPtr<HttpServerListener> listener = mServer->mListener; RefPtr<InternalRequest> request = mPendingWebSocketRequest; nsCOMPtr<nsIRunnable> event = NS_NewRunnableFunction([listener, request] () { listener->OnWebSocket(request); }); NS_DispatchToCurrentThread(event); return NS_OK; } nsAutoCString header; mPendingReq->Headers()->GetFirst(NS_LITERAL_CSTRING("connection"), header, res); MOZ_ASSERT(!res.Failed()); // 1.0 defaults to closing connections. // 1.1 and higher defaults to keep-alive. 
if (ContainsToken(header, NS_LITERAL_CSTRING("close")) || (mPendingReqVersion == 0 && !ContainsToken(header, NS_LITERAL_CSTRING("keep-alive")))) { mCloseAfterRequest = true; } mPendingReq->Headers()->GetFirst(NS_LITERAL_CSTRING("content-length"), header, res); MOZ_ASSERT(!res.Failed()); LOG_V("HttpServer::Connection::ConsumeLine(%p) - content-length is \"%s\"", this, header.get()); if (!header.IsEmpty()) { nsresult rv; mRemainingBodySize = header.ToInteger(&rv); NS_ENSURE_SUCCESS(rv, rv); } else { mRemainingBodySize = 0; } if (mRemainingBodySize) { LOG_V("HttpServer::Connection::ConsumeLine(%p) - Starting consume body", this); mState = eBody; // We use an unlimited buffer size here to ensure // that we get to the next request even if the webpage hangs on // to the request indefinitely without consuming the body. nsCOMPtr<nsIInputStream> input; nsCOMPtr<nsIOutputStream> output; nsresult rv = NS_NewPipe(getter_AddRefs(input), getter_AddRefs(output), 0, // Segment size UINT32_MAX, // Unlimited buffer size false, // not nonBlockingInput true); // nonBlockingOutput NS_ENSURE_SUCCESS(rv, rv); mCurrentRequestBody = do_QueryInterface(output); mPendingReq->SetBody(input); } else { LOG_V("HttpServer::Connection::ConsumeLine(%p) - No body", this); mState = eRequestLine; } mPendingRequests.AppendElement(PendingRequest(mPendingReq, nullptr)); LOG_V("HttpServer::Connection::ConsumeLine(%p) - Fire OnRequest", this); RefPtr<HttpServerListener> listener = mServer->mListener; RefPtr<InternalRequest> request = mPendingReq.forget(); nsCOMPtr<nsIRunnable> event = NS_NewRunnableFunction([listener, request] () { listener->OnRequest(request); }); NS_DispatchToCurrentThread(event); mPendingReqVersion = 0; return NS_OK; } // Parse header line if (aBuffer[0] == ' ' || aBuffer[0] == '\t') { LOG_V("HttpServer::Connection::ConsumeLine(%p) - Add to header %s", this, mPendingHeaderName.get()); NS_ENSURE_FALSE(mPendingHeaderName.IsEmpty(), NS_ERROR_UNEXPECTED); // We might need to do whitespace 
trimming/compression here. mPendingHeaderValue.Append(aBuffer, aLength); return NS_OK; } MaybeAddPendingHeader(); const char* colon = static_cast<const char*>(memchr(aBuffer, ':', aLength)); NS_ENSURE_TRUE(colon, NS_ERROR_UNEXPECTED); ToLowerCase(Substring(aBuffer, colon - aBuffer), mPendingHeaderName); mPendingHeaderValue.Assign(colon + 1, aLength - (colon - aBuffer) - 1); NS_ENSURE_TRUE(NS_IsValidHTTPToken(mPendingHeaderName), NS_ERROR_UNEXPECTED); LOG_V("HttpServer::Connection::ConsumeLine(%p) - Parsed header %s", this, mPendingHeaderName.get()); return NS_OK; } void HttpServer::Connection::MaybeAddPendingHeader() { if (mPendingHeaderName.IsEmpty()) { return; } // We might need to do more whitespace trimming/compression here. mPendingHeaderValue.Trim(" \t"); ErrorResult rv; mPendingReq->Headers()->Append(mPendingHeaderName, mPendingHeaderValue, rv); mPendingHeaderName.Truncate(); } bool HttpServer::Connection::TryHandleResponse(InternalRequest* aRequest, InternalResponse* aResponse) { bool handledResponse = false; for (uint32_t i = 0; i < mPendingRequests.Length(); ++i) { PendingRequest& pending = mPendingRequests[i]; if (pending.first() == aRequest) { MOZ_ASSERT(!handledResponse); MOZ_ASSERT(!pending.second()); pending.second() = aResponse; if (i != 0) { return true; } handledResponse = true; } if (handledResponse && !pending.second()) { // Shortcut if we've handled the response, and // we don't have more responses to send return true; } if (i == 0 && pending.second()) { RefPtr<InternalResponse> resp = pending.second().forget(); mPendingRequests.RemoveElementAt(0); QueueResponse(resp); --i; } } return handledResponse; } already_AddRefed<nsITransportProvider> HttpServer::Connection::HandleAcceptWebSocket(const Optional<nsAString>& aProtocol, ErrorResult& aRv) { MOZ_ASSERT(mPendingWebSocketRequest); RefPtr<InternalResponse> response = new InternalResponse(101, NS_LITERAL_CSTRING("Switching Protocols")); InternalHeaders* headers = response->Headers(); 
headers->Set(NS_LITERAL_CSTRING("Upgrade"), NS_LITERAL_CSTRING("websocket"), aRv); headers->Set(NS_LITERAL_CSTRING("Connection"), NS_LITERAL_CSTRING("Upgrade"), aRv); if (aProtocol.WasPassed()) { NS_ConvertUTF16toUTF8 protocol(aProtocol.Value()); nsAutoCString reqProtocols; mPendingWebSocketRequest->Headers()-> GetFirst(NS_LITERAL_CSTRING("Sec-WebSocket-Protocol"), reqProtocols, aRv); if (!ContainsToken(reqProtocols, protocol)) { // Should throw a better error here aRv.Throw(NS_ERROR_FAILURE); return nullptr; } headers->Set(NS_LITERAL_CSTRING("Sec-WebSocket-Protocol"), protocol, aRv); } nsAutoCString key, hash; mPendingWebSocketRequest->Headers()-> GetFirst(NS_LITERAL_CSTRING("Sec-WebSocket-Key"), key, aRv); nsresult rv = mozilla::net::CalculateWebSocketHashedSecret(key, hash); if (NS_FAILED(rv)) { aRv.Throw(rv); return nullptr; } headers->Set(NS_LITERAL_CSTRING("Sec-WebSocket-Accept"), hash, aRv); nsAutoCString extensions, negotiatedExtensions; mPendingWebSocketRequest->Headers()-> GetFirst(NS_LITERAL_CSTRING("Sec-WebSocket-Extensions"), extensions, aRv); mozilla::net::ProcessServerWebSocketExtensions(extensions, negotiatedExtensions); if (!negotiatedExtensions.IsEmpty()) { headers->Set(NS_LITERAL_CSTRING("Sec-WebSocket-Extensions"), negotiatedExtensions, aRv); } RefPtr<TransportProvider> result = new TransportProvider(); mWebSocketTransportProvider = result; QueueResponse(response); return result.forget(); } void HttpServer::Connection::HandleWebSocketResponse(InternalResponse* aResponse) { MOZ_ASSERT(mPendingWebSocketRequest); mState = eRequestLine; mPendingWebSocketRequest = nullptr; mInput->AsyncWait(this, 0, 0, NS_GetCurrentThread()); QueueResponse(aResponse); } void HttpServer::Connection::QueueResponse(InternalResponse* aResponse) { bool chunked = false; RefPtr<InternalHeaders> headers = new InternalHeaders(*aResponse->Headers()); { ErrorResult res; headers->SetGuard(HeadersGuardEnum::None, res); } nsCOMPtr<nsIInputStream> body; int64_t bodySize; 
aResponse->GetBody(getter_AddRefs(body), &bodySize); if (body && bodySize >= 0) { nsCString sizeStr; sizeStr.AppendInt(bodySize); LOG_V("HttpServer::Connection::QueueResponse(%p) - " "Setting content-length to %s", this, sizeStr.get()); ErrorResult res; headers->Set(NS_LITERAL_CSTRING("content-length"), sizeStr, res); } else if (body) { // Use chunked transfer encoding LOG_V("HttpServer::Connection::QueueResponse(%p) - Chunked transfer-encoding", this); ErrorResult res; headers->Set(NS_LITERAL_CSTRING("transfer-encoding"), NS_LITERAL_CSTRING("chunked"), res); headers->Delete(NS_LITERAL_CSTRING("content-length"), res); chunked = true; } else { LOG_V("HttpServer::Connection::QueueResponse(%p) - " "No body - setting content-length to 0", this); ErrorResult res; headers->Set(NS_LITERAL_CSTRING("content-length"), NS_LITERAL_CSTRING("0"), res); } nsCString head(NS_LITERAL_CSTRING("HTTP/1.1 ")); head.AppendInt(aResponse->GetStatus()); // XXX is the statustext security checked? head.Append(NS_LITERAL_CSTRING(" ") + aResponse->GetStatusText() + NS_LITERAL_CSTRING("\r\n")); AutoTArray<InternalHeaders::Entry, 16> entries; headers->GetEntries(entries); for (auto header : entries) { head.Append(header.mName + NS_LITERAL_CSTRING(": ") + header.mValue + NS_LITERAL_CSTRING("\r\n")); } head.Append(NS_LITERAL_CSTRING("\r\n")); mOutputBuffers.AppendElement()->mString = head; if (body) { OutputBuffer* bodyBuffer = mOutputBuffers.AppendElement(); bodyBuffer->mStream = body; bodyBuffer->mChunked = chunked; } OnOutputStreamReady(mOutput); } namespace { typedef MozPromise<nsresult, bool, false> StreamCopyPromise; class StreamCopier final : public nsIOutputStreamCallback , public nsIInputStreamCallback , public nsIRunnable { public: static RefPtr<StreamCopyPromise> Copy(nsIInputStream* aSource, nsIAsyncOutputStream* aSink, bool aChunked) { RefPtr<StreamCopier> copier = new StreamCopier(aSource, aSink, aChunked); RefPtr<StreamCopyPromise> p = copier->mPromise.Ensure(__func__); nsresult rv = 
copier->mTarget->Dispatch(copier, NS_DISPATCH_NORMAL); if (NS_FAILED(rv)) { copier->mPromise.Resolve(rv, __func__); } return p; } NS_DECL_THREADSAFE_ISUPPORTS NS_DECL_NSIINPUTSTREAMCALLBACK NS_DECL_NSIOUTPUTSTREAMCALLBACK NS_DECL_NSIRUNNABLE private: StreamCopier(nsIInputStream* aSource, nsIAsyncOutputStream* aSink, bool aChunked) : mSource(aSource) , mAsyncSource(do_QueryInterface(aSource)) , mSink(aSink) , mTarget(do_GetService(NS_STREAMTRANSPORTSERVICE_CONTRACTID)) , mChunkRemaining(0) , mChunked(aChunked) , mAddedFinalSeparator(false) , mFirstChunk(aChunked) { } ~StreamCopier() {} static nsresult FillOutputBufferHelper(nsIOutputStream* aOutStr, void* aClosure, char* aBuffer, uint32_t aOffset, uint32_t aCount, uint32_t* aCountRead); nsresult FillOutputBuffer(char* aBuffer, uint32_t aCount, uint32_t* aCountRead); nsCOMPtr<nsIInputStream> mSource; nsCOMPtr<nsIAsyncInputStream> mAsyncSource; nsCOMPtr<nsIAsyncOutputStream> mSink; MozPromiseHolder<StreamCopyPromise> mPromise; nsCOMPtr<nsIEventTarget> mTarget; // XXX we should cache this somewhere uint32_t mChunkRemaining; nsCString mSeparator; bool mChunked; bool mAddedFinalSeparator; bool mFirstChunk; }; NS_IMPL_ISUPPORTS(StreamCopier, nsIOutputStreamCallback, nsIInputStreamCallback, nsIRunnable) struct WriteState { StreamCopier* copier; nsresult sourceRv; }; // This function only exists to enable FillOutputBuffer to be a non-static // function where we can use member variables more easily. 
nsresult StreamCopier::FillOutputBufferHelper(nsIOutputStream* aOutStr, void* aClosure, char* aBuffer, uint32_t aOffset, uint32_t aCount, uint32_t* aCountRead) { WriteState* ws = static_cast<WriteState*>(aClosure); ws->sourceRv = ws->copier->FillOutputBuffer(aBuffer, aCount, aCountRead); return ws->sourceRv; } nsresult CheckForEOF(nsIInputStream* aIn, void* aClosure, const char* aBuffer, uint32_t aToOffset, uint32_t aCount, uint32_t* aWriteCount) { *static_cast<bool*>(aClosure) = true; *aWriteCount = 0; return NS_BINDING_ABORTED; } nsresult StreamCopier::FillOutputBuffer(char* aBuffer, uint32_t aCount, uint32_t* aCountRead) { nsresult rv = NS_OK; while (mChunked && mSeparator.IsEmpty() && !mChunkRemaining && !mAddedFinalSeparator) { uint64_t avail; rv = mSource->Available(&avail); if (rv == NS_BASE_STREAM_CLOSED) { avail = 0; rv = NS_OK; } NS_ENSURE_SUCCESS(rv, rv); mChunkRemaining = avail > UINT32_MAX ? UINT32_MAX : static_cast<uint32_t>(avail); if (!mChunkRemaining) { // Either it's an non-blocking stream without any data // currently available, or we're at EOF. Sadly there's no way // to tell other than to read from the stream. bool hadData = false; uint32_t numRead; rv = mSource->ReadSegments(CheckForEOF, &hadData, 1, &numRead); if (rv == NS_BASE_STREAM_CLOSED) { avail = 0; rv = NS_OK; } NS_ENSURE_SUCCESS(rv, rv); MOZ_ASSERT(numRead == 0); if (hadData) { // The source received data between the call to Available and the // call to ReadSegments. 
Restart with a new call to Available continue; } // We're at EOF, write a separator with 0 mAddedFinalSeparator = true; } if (mFirstChunk) { mFirstChunk = false; MOZ_ASSERT(mSeparator.IsEmpty()); } else { // For all chunks except the first, add the newline at the end // of the previous chunk of data mSeparator.AssignLiteral("\r\n"); } mSeparator.AppendInt(mChunkRemaining, 16); mSeparator.AppendLiteral("\r\n"); if (mAddedFinalSeparator) { mSeparator.AppendLiteral("\r\n"); } break; } // If we're doing chunked encoding, we should either have a chunk size, // or we should have reached the end of the input stream. MOZ_ASSERT_IF(mChunked, mChunkRemaining || mAddedFinalSeparator); // We should only have a separator if we're doing chunked encoding MOZ_ASSERT_IF(!mSeparator.IsEmpty(), mChunked); if (!mSeparator.IsEmpty()) { *aCountRead = std::min(mSeparator.Length(), aCount); memcpy(aBuffer, mSeparator.BeginReading(), *aCountRead); mSeparator.Cut(0, *aCountRead); rv = NS_OK; } else if (mChunked) { *aCountRead = 0; if (mChunkRemaining) { rv = mSource->Read(aBuffer, std::min(aCount, mChunkRemaining), aCountRead); mChunkRemaining -= *aCountRead; } } else { rv = mSource->Read(aBuffer, aCount, aCountRead); } if (NS_SUCCEEDED(rv) && *aCountRead == 0) { rv = NS_BASE_STREAM_CLOSED; } return rv; } NS_IMETHODIMP StreamCopier::Run() { nsresult rv; while (1) { WriteState state = { this, NS_OK }; uint32_t written; rv = mSink->WriteSegments(FillOutputBufferHelper, &state, mozilla::net::nsIOService::gDefaultSegmentSize, &written); MOZ_ASSERT(NS_SUCCEEDED(rv) || NS_SUCCEEDED(state.sourceRv)); if (rv == NS_BASE_STREAM_WOULD_BLOCK) { mSink->AsyncWait(this, 0, 0, mTarget); return NS_OK; } if (NS_FAILED(rv)) { mPromise.Resolve(rv, __func__); return NS_OK; } if (state.sourceRv == NS_BASE_STREAM_WOULD_BLOCK) { MOZ_ASSERT(mAsyncSource); mAsyncSource->AsyncWait(this, 0, 0, mTarget); mSink->AsyncWait(this, nsIAsyncInputStream::WAIT_CLOSURE_ONLY, 0, mTarget); return NS_OK; } if (state.sourceRv == 
NS_BASE_STREAM_CLOSED) { // We're done! // No longer interested in callbacks about either stream closing mSink->AsyncWait(nullptr, 0, 0, nullptr); if (mAsyncSource) { mAsyncSource->AsyncWait(nullptr, 0, 0, nullptr); } mSource->Close(); mSource = nullptr; mAsyncSource = nullptr; mSink = nullptr; mPromise.Resolve(NS_OK, __func__); return NS_OK; } if (NS_FAILED(state.sourceRv)) { mPromise.Resolve(state.sourceRv, __func__); return NS_OK; } } MOZ_ASSUME_UNREACHABLE_MARKER(); } NS_IMETHODIMP StreamCopier::OnInputStreamReady(nsIAsyncInputStream* aStream) { MOZ_ASSERT(aStream == mAsyncSource || (!mSource && !mAsyncSource && !mSink)); return mSource ? Run() : NS_OK; } NS_IMETHODIMP StreamCopier::OnOutputStreamReady(nsIAsyncOutputStream* aStream) { MOZ_ASSERT(aStream == mSink || (!mSource && !mAsyncSource && !mSink)); return mSource ? Run() : NS_OK; } } // namespace NS_IMETHODIMP HttpServer::Connection::OnOutputStreamReady(nsIAsyncOutputStream* aStream) { MOZ_ASSERT(aStream == mOutput || !mOutput); if (!mOutput) { return NS_OK; } nsresult rv; while (!mOutputBuffers.IsEmpty()) { if (!mOutputBuffers[0].mStream) { nsCString& buffer = mOutputBuffers[0].mString; while (!buffer.IsEmpty()) { uint32_t written = 0; rv = mOutput->Write(buffer.BeginReading(), buffer.Length(), &written); buffer.Cut(0, written); if (rv == NS_BASE_STREAM_WOULD_BLOCK) { return mOutput->AsyncWait(this, 0, 0, NS_GetCurrentThread()); } if (NS_FAILED(rv)) { Close(); return NS_OK; } } mOutputBuffers.RemoveElementAt(0); } else { if (mOutputCopy) { // we're already copying the stream return NS_OK; } mOutputCopy = StreamCopier::Copy(mOutputBuffers[0].mStream, mOutput, mOutputBuffers[0].mChunked); RefPtr<Connection> self = this; mOutputCopy-> Then(AbstractThread::MainThread(), __func__, [self, this] (nsresult aStatus) { MOZ_ASSERT(mOutputBuffers[0].mStream); LOG_V("HttpServer::Connection::OnOutputStreamReady(%p) - " "Sent body. 
Status 0x%lx", this, aStatus); mOutputBuffers.RemoveElementAt(0); mOutputCopy = nullptr; OnOutputStreamReady(mOutput); }, [] (bool) { MOZ_ASSERT_UNREACHABLE("Reject unexpected"); }); } } if (mPendingRequests.IsEmpty()) { if (mCloseAfterRequest) { LOG_V("HttpServer::Connection::OnOutputStreamReady(%p) - Closing channel", this); Close(); } else if (mWebSocketTransportProvider) { mInput->AsyncWait(nullptr, 0, 0, nullptr); mOutput->AsyncWait(nullptr, 0, 0, nullptr); mWebSocketTransportProvider->SetTransport(mTransport, mInput, mOutput); mTransport = nullptr; mInput = nullptr; mOutput = nullptr; mWebSocketTransportProvider = nullptr; } } return NS_OK; } void HttpServer::Connection::Close() { if (!mTransport) { MOZ_ASSERT(!mOutput && !mInput); return; } mTransport->Close(NS_BINDING_ABORTED); if (mInput) { mInput->Close(); mInput = nullptr; } if (mOutput) { mOutput->Close(); mOutput = nullptr; } mTransport = nullptr; mInputBuffer.Truncate(); mOutputBuffers.Clear(); mPendingRequests.Clear(); } } // namespace net } // namespace mozilla