summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/tools/pywebsocket/src/test/test_mux.py
diff options
context:
space:
mode:
Diffstat (limited to 'testing/web-platform/tests/tools/pywebsocket/src/test/test_mux.py')
-rw-r--r--testing/web-platform/tests/tools/pywebsocket/src/test/test_mux.py2089
1 files changed, 2089 insertions, 0 deletions
diff --git a/testing/web-platform/tests/tools/pywebsocket/src/test/test_mux.py b/testing/web-platform/tests/tools/pywebsocket/src/test/test_mux.py
new file mode 100644
index 000000000..d4598944e
--- /dev/null
+++ b/testing/web-platform/tests/tools/pywebsocket/src/test/test_mux.py
@@ -0,0 +1,2089 @@
+#!/usr/bin/env python
+#
+# Copyright 2012, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+"""Tests for mux module."""
+
+import Queue
+import copy
+import logging
+import optparse
+import struct
+import sys
+import unittest
+import time
+import zlib
+
+import set_sys_path # Update sys.path to locate mod_pywebsocket module.
+
+from mod_pywebsocket import common
+from mod_pywebsocket import mux
+from mod_pywebsocket._stream_base import ConnectionTerminatedException
+from mod_pywebsocket._stream_base import UnsupportedFrameException
+from mod_pywebsocket._stream_hybi import Frame
+from mod_pywebsocket._stream_hybi import Stream
+from mod_pywebsocket._stream_hybi import StreamOptions
+from mod_pywebsocket._stream_hybi import create_binary_frame
+from mod_pywebsocket._stream_hybi import create_close_frame
+from mod_pywebsocket._stream_hybi import create_closing_handshake_body
+from mod_pywebsocket._stream_hybi import parse_frame
+from mod_pywebsocket.extensions import MuxExtensionProcessor
+
+
+import mock
+
+
+_TEST_HEADERS = {'Host': 'server.example.com',
+ 'Upgrade': 'websocket',
+ 'Connection': 'Upgrade',
+ 'Sec-WebSocket-Key': 'dGhlIHNhbXBsZSBub25jZQ==',
+ 'Sec-WebSocket-Version': '13',
+ 'Origin': 'http://example.com'}
+
+
+class _OutgoingChannelData(object):
+ def __init__(self):
+ self.messages = []
+ self.control_messages = []
+
+ self.builder = mux._InnerMessageBuilder()
+
+class _MockMuxConnection(mock.MockBlockingConn):
+ """Mock class of mod_python connection for mux."""
+
+ def __init__(self):
+ mock.MockBlockingConn.__init__(self)
+ self._control_blocks = []
+ self._channel_data = {}
+
+ self._current_opcode = None
+ self._pending_fragments = []
+
+ self.server_close_code = None
+
+ def write(self, data):
+ """Override MockBlockingConn.write."""
+
+ self._current_data = data
+ self._position = 0
+
+ def _receive_bytes(length):
+ if self._position + length > len(self._current_data):
+ raise ConnectionTerminatedException(
+ 'Failed to receive %d bytes from encapsulated '
+ 'frame' % length)
+ data = self._current_data[self._position:self._position+length]
+ self._position += length
+ return data
+
+ # Parse physical frames and assemble a message if the message is
+ # fragmented.
+ opcode, payload, fin, rsv1, rsv2, rsv3 = (
+ parse_frame(_receive_bytes, unmask_receive=False))
+
+ self._pending_fragments.append(payload)
+
+ if self._current_opcode is None:
+ if opcode == common.OPCODE_CONTINUATION:
+ raise Exception('Sending invalid continuation opcode')
+ self._current_opcode = opcode
+ else:
+ if opcode != common.OPCODE_CONTINUATION:
+ raise Exception('Sending invalid opcode %d' % opcode)
+ if not fin:
+ return
+
+ inner_frame_data = ''.join(self._pending_fragments)
+ self._pending_fragments = []
+ self._current_opcode = None
+
+ # Handle a control message on the physical channel.
+ # TODO(bashi): Support other opcodes if needed.
+ if opcode == common.OPCODE_CLOSE:
+ if len(payload) >= 2:
+ self.server_close_code = struct.unpack('!H', payload[:2])[0]
+ close_body = create_closing_handshake_body(
+ common.STATUS_NORMAL_CLOSURE, '')
+ close_frame = create_close_frame(close_body, mask=True)
+ self.put_bytes(close_frame)
+ return
+
+ # Parse the payload of the message on physical channel.
+ parser = mux._MuxFramePayloadParser(inner_frame_data)
+ channel_id = parser.read_channel_id()
+ if channel_id == mux._CONTROL_CHANNEL_ID:
+ self._control_blocks.extend(list(parser.read_control_blocks()))
+ return
+
+ if not channel_id in self._channel_data:
+ self._channel_data[channel_id] = _OutgoingChannelData()
+ channel_data = self._channel_data[channel_id]
+
+ # Parse logical frames and assemble an inner (logical) message.
+ (inner_fin, inner_rsv1, inner_rsv2, inner_rsv3, inner_opcode,
+ inner_payload) = parser.read_inner_frame()
+ inner_frame = Frame(inner_fin, inner_rsv1, inner_rsv2, inner_rsv3,
+ inner_opcode, inner_payload)
+ message = channel_data.builder.build(inner_frame)
+ if message is None:
+ return
+
+ if (message.opcode == common.OPCODE_TEXT or
+ message.opcode == common.OPCODE_BINARY):
+ channel_data.messages.append(message.payload)
+
+ self.on_data_message(message.payload)
+ else:
+ channel_data.control_messages.append(
+ {'opcode': message.opcode,
+ 'message': message.payload})
+
+ def on_data_message(self, message):
+ pass
+
+ def get_written_control_blocks(self):
+ return self._control_blocks
+
+ def get_written_messages(self, channel_id):
+ return self._channel_data[channel_id].messages
+
+ def get_written_control_messages(self, channel_id):
+ return self._channel_data[channel_id].control_messages
+
+
+class _FailOnWriteConnection(_MockMuxConnection):
+ """Specicialized version of _MockMuxConnection. Its write() method raises
+ an exception for testing when a data message is written.
+ """
+
+ def on_data_message(self, message):
+ """Override to raise an exception."""
+
+ raise Exception('Intentional failure')
+
+
+class _ChannelEvent(object):
+ """A structure that records channel events."""
+
+ def __init__(self):
+ self.request = None
+ self.messages = []
+ self.exception = None
+ self.client_initiated_closing = False
+
+
+class _MuxMockDispatcher(object):
+ """Mock class of dispatch.Dispatcher for mux."""
+
+ def __init__(self):
+ self.channel_events = {}
+
+ def do_extra_handshake(self, request):
+ if request.ws_requested_protocols is not None:
+ request.ws_protocol = request.ws_requested_protocols[0]
+
+ def _do_echo(self, request, channel_events):
+ while True:
+ message = request.ws_stream.receive_message()
+ if message == None:
+ channel_events.client_initiated_closing = True
+ return
+ if message == 'Goodbye':
+ return
+ channel_events.messages.append(message)
+ # echo back
+ request.ws_stream.send_message(message)
+
+ def _do_ping(self, request, channel_events):
+ request.ws_stream.send_ping('Ping!')
+
+ def _do_ping_while_hello_world(self, request, channel_events):
+ request.ws_stream.send_message('Hello ', end=False)
+ request.ws_stream.send_ping('Ping!')
+ request.ws_stream.send_message('World!', end=True)
+
+ def _do_two_ping_while_hello_world(self, request, channel_events):
+ request.ws_stream.send_message('Hello ', end=False)
+ request.ws_stream.send_ping('Ping!')
+ request.ws_stream.send_ping('Pong!')
+ request.ws_stream.send_message('World!', end=True)
+
+ def transfer_data(self, request):
+ self.channel_events[request.channel_id] = _ChannelEvent()
+ self.channel_events[request.channel_id].request = request
+
+ try:
+ # Note: more handler will be added.
+ if request.uri.endswith('echo'):
+ self._do_echo(request,
+ self.channel_events[request.channel_id])
+ elif request.uri.endswith('ping'):
+ self._do_ping(request,
+ self.channel_events[request.channel_id])
+ elif request.uri.endswith('two_ping_while_hello_world'):
+ self._do_two_ping_while_hello_world(
+ request, self.channel_events[request.channel_id])
+ elif request.uri.endswith('ping_while_hello_world'):
+ self._do_ping_while_hello_world(
+ request, self.channel_events[request.channel_id])
+ else:
+ raise ValueError('Cannot handle path %r' % request.path)
+ if not request.server_terminated:
+ request.ws_stream.close_connection()
+ except ConnectionTerminatedException, e:
+ self.channel_events[request.channel_id].exception = e
+ except Exception, e:
+ self.channel_events[request.channel_id].exception = e
+ raise
+
+
+def _create_mock_request(connection=None, logical_channel_extensions=None):
+ if connection is None:
+ connection = _MockMuxConnection()
+
+ request = mock.MockRequest(uri='/echo',
+ headers_in=_TEST_HEADERS,
+ connection=connection)
+ request.ws_stream = Stream(request, options=StreamOptions())
+ request.mux_processor = MuxExtensionProcessor(
+ common.ExtensionParameter(common.MUX_EXTENSION))
+ if logical_channel_extensions is not None:
+ request.mux_processor.set_extensions(logical_channel_extensions)
+ request.mux_processor.set_quota(8 * 1024)
+ return request
+
+
+def _create_add_channel_request_frame(channel_id, encoding, encoded_handshake):
+ # Allow invalid encoding for testing.
+ first_byte = ((mux._MUX_OPCODE_ADD_CHANNEL_REQUEST << 5) | encoding)
+ payload = (chr(first_byte) +
+ mux._encode_channel_id(channel_id) +
+ mux._encode_number(len(encoded_handshake)) +
+ encoded_handshake)
+ return create_binary_frame(
+ (mux._encode_channel_id(mux._CONTROL_CHANNEL_ID) + payload), mask=True)
+
+
+def _create_drop_channel_frame(channel_id, code=None, message=''):
+ payload = mux._create_drop_channel(channel_id, code, message)
+ return create_binary_frame(
+ (mux._encode_channel_id(mux._CONTROL_CHANNEL_ID) + payload), mask=True)
+
+
+def _create_flow_control_frame(channel_id, replenished_quota):
+ payload = mux._create_flow_control(channel_id, replenished_quota)
+ return create_binary_frame(
+ (mux._encode_channel_id(mux._CONTROL_CHANNEL_ID) + payload), mask=True)
+
+
+def _create_logical_frame(channel_id, message, opcode=common.OPCODE_BINARY,
+ fin=True, rsv1=False, rsv2=False, rsv3=False,
+ mask=True):
+ bits = chr((fin << 7) | (rsv1 << 6) | (rsv2 << 5) | (rsv3 << 4) | opcode)
+ payload = mux._encode_channel_id(channel_id) + bits + message
+ return create_binary_frame(payload, mask=True)
+
+
+def _create_request_header(path='/echo', extensions=None):
+ headers = (
+ 'GET %s HTTP/1.1\r\n'
+ 'Host: server.example.com\r\n'
+ 'Connection: Upgrade\r\n'
+ 'Origin: http://example.com\r\n') % path
+ if extensions:
+ headers += '%s: %s' % (
+ common.SEC_WEBSOCKET_EXTENSIONS_HEADER, extensions)
+ return headers
+
+
+class MuxTest(unittest.TestCase):
+ """A unittest for mux module."""
+
+ def test_channel_id_decode(self):
+ data = '\x00\x01\xbf\xff\xdf\xff\xff\xff\xff\xff\xff'
+ parser = mux._MuxFramePayloadParser(data)
+ channel_id = parser.read_channel_id()
+ self.assertEqual(0, channel_id)
+ channel_id = parser.read_channel_id()
+ self.assertEqual(1, channel_id)
+ channel_id = parser.read_channel_id()
+ self.assertEqual(2 ** 14 - 1, channel_id)
+ channel_id = parser.read_channel_id()
+ self.assertEqual(2 ** 21 - 1, channel_id)
+ channel_id = parser.read_channel_id()
+ self.assertEqual(2 ** 29 - 1, channel_id)
+ self.assertEqual(len(data), parser._read_position)
+
+ def test_channel_id_encode(self):
+ encoded = mux._encode_channel_id(0)
+ self.assertEqual('\x00', encoded)
+ encoded = mux._encode_channel_id(2 ** 14 - 1)
+ self.assertEqual('\xbf\xff', encoded)
+ encoded = mux._encode_channel_id(2 ** 14)
+ self.assertEqual('\xc0@\x00', encoded)
+ encoded = mux._encode_channel_id(2 ** 21 - 1)
+ self.assertEqual('\xdf\xff\xff', encoded)
+ encoded = mux._encode_channel_id(2 ** 21)
+ self.assertEqual('\xe0 \x00\x00', encoded)
+ encoded = mux._encode_channel_id(2 ** 29 - 1)
+ self.assertEqual('\xff\xff\xff\xff', encoded)
+ # channel_id is too large
+ self.assertRaises(ValueError,
+ mux._encode_channel_id,
+ 2 ** 29)
+
+ def test_read_multiple_control_blocks(self):
+ # Use AddChannelRequest because it can contain arbitrary length of data
+ data = ('\x00\x01\x01a'
+ '\x00\x02\x7d%s'
+ '\x00\x03\x7e\xff\xff%s'
+ '\x00\x04\x7f\x00\x00\x00\x00\x00\x01\x00\x00%s') % (
+ 'a' * 0x7d, 'b' * 0xffff, 'c' * 0x10000)
+ parser = mux._MuxFramePayloadParser(data)
+ blocks = list(parser.read_control_blocks())
+ self.assertEqual(4, len(blocks))
+
+ self.assertEqual(mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, blocks[0].opcode)
+ self.assertEqual(1, blocks[0].channel_id)
+ self.assertEqual(1, len(blocks[0].encoded_handshake))
+
+ self.assertEqual(mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, blocks[1].opcode)
+ self.assertEqual(2, blocks[1].channel_id)
+ self.assertEqual(0x7d, len(blocks[1].encoded_handshake))
+
+ self.assertEqual(mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, blocks[2].opcode)
+ self.assertEqual(3, blocks[2].channel_id)
+ self.assertEqual(0xffff, len(blocks[2].encoded_handshake))
+
+ self.assertEqual(mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, blocks[3].opcode)
+ self.assertEqual(4, blocks[3].channel_id)
+ self.assertEqual(0x10000, len(blocks[3].encoded_handshake))
+
+ self.assertEqual(len(data), parser._read_position)
+
+ def test_read_add_channel_request(self):
+ data = '\x00\x01\x01a'
+ parser = mux._MuxFramePayloadParser(data)
+ blocks = list(parser.read_control_blocks())
+ self.assertEqual(mux._MUX_OPCODE_ADD_CHANNEL_REQUEST, blocks[0].opcode)
+ self.assertEqual(1, blocks[0].channel_id)
+ self.assertEqual(1, len(blocks[0].encoded_handshake))
+
+ def test_read_drop_channel(self):
+ data = '\x60\x01\x00'
+ parser = mux._MuxFramePayloadParser(data)
+ blocks = list(parser.read_control_blocks())
+ self.assertEqual(1, len(blocks))
+ self.assertEqual(1, blocks[0].channel_id)
+ self.assertEqual(mux._MUX_OPCODE_DROP_CHANNEL, blocks[0].opcode)
+ self.assertEqual(None, blocks[0].drop_code)
+ self.assertEqual(0, len(blocks[0].drop_message))
+
+ data = '\x60\x02\x09\x03\xe8Success'
+ parser = mux._MuxFramePayloadParser(data)
+ blocks = list(parser.read_control_blocks())
+ self.assertEqual(1, len(blocks))
+ self.assertEqual(2, blocks[0].channel_id)
+ self.assertEqual(mux._MUX_OPCODE_DROP_CHANNEL, blocks[0].opcode)
+ self.assertEqual(1000, blocks[0].drop_code)
+ self.assertEqual('Success', blocks[0].drop_message)
+
+ # Reason is too short.
+ data = '\x60\x01\x01\x00'
+ parser = mux._MuxFramePayloadParser(data)
+ self.assertRaises(mux.PhysicalConnectionError,
+ lambda: list(parser.read_control_blocks()))
+
+ def test_read_flow_control(self):
+ data = '\x40\x01\x02'
+ parser = mux._MuxFramePayloadParser(data)
+ blocks = list(parser.read_control_blocks())
+ self.assertEqual(1, len(blocks))
+ self.assertEqual(1, blocks[0].channel_id)
+ self.assertEqual(mux._MUX_OPCODE_FLOW_CONTROL, blocks[0].opcode)
+ self.assertEqual(2, blocks[0].send_quota)
+
+ def test_read_new_channel_slot(self):
+ data = '\x80\x01\x02\x02\x03'
+ parser = mux._MuxFramePayloadParser(data)
+ # TODO(bashi): Implement
+ self.assertRaises(mux.PhysicalConnectionError,
+ lambda: list(parser.read_control_blocks()))
+
+ def test_read_invalid_number_field_in_control_block(self):
+ # No number field.
+ data = ''
+ parser = mux._MuxFramePayloadParser(data)
+ self.assertRaises(ValueError, parser._read_number)
+
+ # The last two bytes are missing.
+ data = '\x7e'
+ parser = mux._MuxFramePayloadParser(data)
+ self.assertRaises(ValueError, parser._read_number)
+
+ # Missing the last one byte.
+ data = '\x7f\x00\x00\x00\x00\x00\x01\x00'
+ parser = mux._MuxFramePayloadParser(data)
+ self.assertRaises(ValueError, parser._read_number)
+
+ # The length of number field is too large.
+ data = '\x7f\xff\xff\xff\xff\xff\xff\xff\xff'
+ parser = mux._MuxFramePayloadParser(data)
+ self.assertRaises(ValueError, parser._read_number)
+
+ # The msb of the first byte is set.
+ data = '\x80'
+ parser = mux._MuxFramePayloadParser(data)
+ self.assertRaises(ValueError, parser._read_number)
+
+ # Using 3 bytes encoding for 125.
+ data = '\x7e\x00\x7d'
+ parser = mux._MuxFramePayloadParser(data)
+ self.assertRaises(ValueError, parser._read_number)
+
+ # Using 9 bytes encoding for 0xffff
+ data = '\x7f\x00\x00\x00\x00\x00\x00\xff\xff'
+ parser = mux._MuxFramePayloadParser(data)
+ self.assertRaises(ValueError, parser._read_number)
+
+ def test_read_invalid_size_and_contents(self):
+ # Only contain number field.
+ data = '\x01'
+ parser = mux._MuxFramePayloadParser(data)
+ self.assertRaises(mux.PhysicalConnectionError,
+ parser._read_size_and_contents)
+
+ def test_create_add_channel_response(self):
+ data = mux._create_add_channel_response(channel_id=1,
+ encoded_handshake='FooBar',
+ encoding=0,
+ rejected=False)
+ self.assertEqual('\x20\x01\x06FooBar', data)
+
+ data = mux._create_add_channel_response(channel_id=2,
+ encoded_handshake='Hello',
+ encoding=1,
+ rejected=True)
+ self.assertEqual('\x31\x02\x05Hello', data)
+
+ def test_create_drop_channel(self):
+ data = mux._create_drop_channel(channel_id=1)
+ self.assertEqual('\x60\x01\x00', data)
+
+ data = mux._create_drop_channel(channel_id=1,
+ code=2000,
+ message='error')
+ self.assertEqual('\x60\x01\x07\x07\xd0error', data)
+
+ # reason must be empty if code is None
+ self.assertRaises(ValueError,
+ mux._create_drop_channel,
+ 1, None, 'FooBar')
+
+ def test_parse_request_text(self):
+ request_text = _create_request_header()
+ command, path, version, headers = mux._parse_request_text(request_text)
+ self.assertEqual('GET', command)
+ self.assertEqual('/echo', path)
+ self.assertEqual('HTTP/1.1', version)
+ self.assertEqual(3, len(headers))
+ self.assertEqual('server.example.com', headers['Host'])
+ self.assertEqual('http://example.com', headers['Origin'])
+
+
+class MuxHandlerTest(unittest.TestCase):
+
+ def test_add_channel(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=3, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=3,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Hello'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=3, message='World'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=3, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ self.assertEqual([], dispatcher.channel_events[1].messages)
+ self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)
+ self.assertEqual(['World'], dispatcher.channel_events[3].messages)
+ # Channel 2
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(1, len(messages))
+ self.assertEqual('Hello', messages[0])
+ # Channel 3
+ messages = request.connection.get_written_messages(3)
+ self.assertEqual(1, len(messages))
+ self.assertEqual('World', messages[0])
+ control_blocks = request.connection.get_written_control_blocks()
+ # There should be 8 control blocks:
+ # - 1 NewChannelSlot
+ # - 2 AddChannelResponses for channel id 2 and 3
+ # - 6 FlowControls for channel id 1 (initialize), 'Hello', 'World',
+ # and 3 'Goodbye's
+ self.assertEqual(9, len(control_blocks))
+
+ def test_physical_connection_write_failure(self):
+ # Use _FailOnWriteConnection.
+ request = _create_mock_request(connection=_FailOnWriteConnection())
+
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+
+ # Let the worker echo back 'Hello'. It causes _FailOnWriteConnection
+ # raising an exception.
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Hello'))
+
+ # Let the worker exit. This will be unnecessary when
+ # _LogicalConnection.write() is changed to throw an exception if
+ # woke up by on_writer_done.
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ # All threads should be done.
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ def test_send_blocked(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ # On receiving this 'Hello', the server tries to echo back 'Hello',
+ # but it will be blocked since there's no send quota available for the
+ # channel 2.
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Hello'))
+
+ # Wait until the worker is blocked due to send quota shortage.
+ time.sleep(1)
+
+ # Close the channel 2. The worker should be notified of the end of
+ # writer thread and stop waiting for send quota to be replenished.
+ drop_channel = _create_drop_channel_frame(channel_id=2)
+
+ request.connection.put_bytes(drop_channel)
+
+ # Make sure the channel 1 is also closed.
+ drop_channel = _create_drop_channel_frame(channel_id=1)
+ request.connection.put_bytes(drop_channel)
+
+ # All threads should be done.
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ def test_add_channel_delta_encoding(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ delta = 'GET /echo HTTP/1.1\r\n\r\n'
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=1, encoded_handshake=delta)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Hello'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(1, len(messages))
+ self.assertEqual('Hello', messages[0])
+
+ def test_add_channel_delta_encoding_override(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ # Override Sec-WebSocket-Protocol.
+ delta = ('GET /echo HTTP/1.1\r\n'
+ 'Sec-WebSocket-Protocol: x-foo\r\n'
+ '\r\n')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=1, encoded_handshake=delta)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Hello'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(1, len(messages))
+ self.assertEqual('Hello', messages[0])
+ self.assertEqual('x-foo',
+ dispatcher.channel_events[2].request.ws_protocol)
+
+ def test_add_channel_delta_after_identity(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+ # Sec-WebSocket-Protocol is different from client's opening handshake
+ # of the physical connection.
+ # TODO(bashi): Remove Upgrade, Connection, Sec-WebSocket-Key and
+ # Sec-WebSocket-Version.
+ encoded_handshake = (
+ 'GET /echo HTTP/1.1\r\n'
+ 'Host: server.example.com\r\n'
+ 'Sec-WebSocket-Protocol: x-foo\r\n'
+ 'Connection: Upgrade\r\n'
+ 'Origin: http://example.com\r\n'
+ '\r\n')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ delta = 'GET /echo HTTP/1.1\r\n\r\n'
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=3, encoding=1, encoded_handshake=delta)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=3,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Hello'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=3, message='World'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=3, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ self.assertEqual([], dispatcher.channel_events[1].messages)
+ self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)
+ self.assertEqual(['World'], dispatcher.channel_events[3].messages)
+ # Channel 2
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(1, len(messages))
+ self.assertEqual('Hello', messages[0])
+ # Channel 3
+ messages = request.connection.get_written_messages(3)
+ self.assertEqual(1, len(messages))
+ self.assertEqual('World', messages[0])
+ # Handshake base should be updated.
+ self.assertEqual(
+ 'x-foo',
+ mux_handler._handshake_base._headers['Sec-WebSocket-Protocol'])
+
+ def test_add_channel_delta_remove_header(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+ # Override handshake delta base.
+ encoded_handshake = (
+ 'GET /echo HTTP/1.1\r\n'
+ 'Host: server.example.com\r\n'
+ 'Sec-WebSocket-Protocol: x-foo\r\n'
+ 'Connection: Upgrade\r\n'
+ 'Origin: http://example.com\r\n'
+ '\r\n')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ # Remove Sec-WebSocket-Protocol header.
+ delta = ('GET /echo HTTP/1.1\r\n'
+ 'Sec-WebSocket-Protocol:'
+ '\r\n')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=3, encoding=1, encoded_handshake=delta)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=3,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Hello'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=3, message='World'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=3, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ self.assertEqual([], dispatcher.channel_events[1].messages)
+ self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)
+ self.assertEqual(['World'], dispatcher.channel_events[3].messages)
+ # Channel 2
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(1, len(messages))
+ self.assertEqual('Hello', messages[0])
+ # Channel 3
+ messages = request.connection.get_written_messages(3)
+ self.assertEqual(1, len(messages))
+ self.assertEqual('World', messages[0])
+ self.assertEqual(
+ 'x-foo',
+ dispatcher.channel_events[2].request.ws_protocol)
+ self.assertEqual(
+ None,
+ dispatcher.channel_events[3].request.ws_protocol)
+
+ def test_add_channel_delta_encoding_permessage_compress(self):
+ # Enable permessage compress extension on the implicitly opened channel.
+ extensions = common.parse_extensions(
+ '%s; method=deflate' % common.PERMESSAGE_COMPRESSION_EXTENSION)
+ request = _create_mock_request(
+ logical_channel_extensions=extensions)
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ delta = 'GET /echo HTTP/1.1\r\n\r\n'
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=1, encoded_handshake=delta)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=20)
+ request.connection.put_bytes(flow_control)
+
+ # Send compressed 'Hello' on logical channel 1 and 2.
+ compress = zlib.compressobj(
+ zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
+ compressed_hello = compress.compress('Hello')
+ compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
+ compressed_hello = compressed_hello[:-4]
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message=compressed_hello,
+ rsv1=True))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message=compressed_hello,
+ rsv1=True))
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ self.assertEqual(['Hello'], dispatcher.channel_events[1].messages)
+ self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)
+ # Written 'Hello's should be compressed.
+ messages = request.connection.get_written_messages(1)
+ self.assertEqual(1, len(messages))
+ self.assertEqual(compressed_hello, messages[0])
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(1, len(messages))
+ self.assertEqual(compressed_hello, messages[0])
+
+ def test_add_channel_delta_encoding_remove_extensions(self):
+ # Enable permessage compress extension on the implicitly opened channel.
+ extensions = common.parse_extensions(
+ '%s; method=deflate' % common.PERMESSAGE_COMPRESSION_EXTENSION)
+ request = _create_mock_request(
+ logical_channel_extensions=extensions)
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ # Remove permessage compress extension.
+ delta = ('GET /echo HTTP/1.1\r\n'
+ 'Sec-WebSocket-Extensions:\r\n'
+ '\r\n')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=1, encoded_handshake=delta)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=20)
+ request.connection.put_bytes(flow_control)
+
+ # Send compressed message on logical channel 2. The message should
+ # be rejected (since rsv1 is set).
+ compress = zlib.compressobj(
+ zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
+ compressed_hello = compress.compress('Hello')
+ compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
+ compressed_hello = compressed_hello[:-4]
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message=compressed_hello,
+ rsv1=True))
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(mux._DROP_CODE_NORMAL_CLOSURE, drop_channel.drop_code)
+ self.assertEqual(2, drop_channel.channel_id)
+ # UnsupportedFrameException should be raised on logical channel 2.
+ self.assertTrue(isinstance(dispatcher.channel_events[2].exception,
+ UnsupportedFrameException))
+
+ def test_add_channel_invalid_encoding(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=3,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(mux._DROP_CODE_UNKNOWN_REQUEST_ENCODING,
+ drop_channel.drop_code)
+ self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
+ request.connection.server_close_code)
+
+ def test_add_channel_incomplete_handshake(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ incomplete_encoded_handshake = 'GET /echo HTTP/1.1'
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=incomplete_encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ self.assertTrue(1 in dispatcher.channel_events)
+ self.assertTrue(not 2 in dispatcher.channel_events)
+
+ def test_add_channel_duplicate_channel_id(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(mux._DROP_CODE_CHANNEL_ALREADY_EXISTS,
+ drop_channel.drop_code)
+ self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
+ request.connection.server_close_code)
+
+ def test_receive_drop_channel(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ drop_channel = _create_drop_channel_frame(channel_id=2)
+ request.connection.put_bytes(drop_channel)
+
+ # Terminate implicitly opened channel.
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ exception = dispatcher.channel_events[2].exception
+ self.assertTrue(exception.__class__ == ConnectionTerminatedException)
+
+ def test_receive_ping_frame(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=13)
+ request.connection.put_bytes(flow_control)
+
+ ping_frame = _create_logical_frame(channel_id=2,
+ message='Hello World!',
+ opcode=common.OPCODE_PING)
+ request.connection.put_bytes(ping_frame)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ messages = request.connection.get_written_control_messages(2)
+ self.assertEqual(common.OPCODE_PONG, messages[0]['opcode'])
+ self.assertEqual('Hello World!', messages[0]['message'])
+
+ def test_receive_fragmented_ping(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=13)
+ request.connection.put_bytes(flow_control)
+
+ # Send a ping with message 'Hello world!' in two fragmented frames.
+ ping_frame1 = _create_logical_frame(channel_id=2,
+ message='Hello ',
+ fin=False,
+ opcode=common.OPCODE_PING)
+ request.connection.put_bytes(ping_frame1)
+ ping_frame2 = _create_logical_frame(channel_id=2,
+ message='World!',
+ fin=True,
+ opcode=common.OPCODE_CONTINUATION)
+ request.connection.put_bytes(ping_frame2)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ messages = request.connection.get_written_control_messages(2)
+ self.assertEqual(common.OPCODE_PONG, messages[0]['opcode'])
+ self.assertEqual('Hello World!', messages[0]['message'])
+
+ def test_receive_fragmented_ping_while_receiving_fragmented_message(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=19)
+ request.connection.put_bytes(flow_control)
+
+ # Send a fragmented frame of message 'Hello '.
+ hello = _create_logical_frame(channel_id=2,
+ message='Hello ',
+ fin=False)
+ request.connection.put_bytes(hello)
+
+ # Before sending the last fragmented frame of the message, send a
+ # fragmented ping.
+ ping1 = _create_logical_frame(channel_id=2,
+ message='Pi',
+ fin=False,
+ opcode=common.OPCODE_PING)
+ request.connection.put_bytes(ping1)
+ ping2 = _create_logical_frame(channel_id=2,
+ message='ng!',
+ fin=True,
+ opcode=common.OPCODE_CONTINUATION)
+ request.connection.put_bytes(ping2)
+
+ # Send the last fragmented frame of the message.
+ world = _create_logical_frame(channel_id=2,
+ message='World!',
+ fin=True,
+ opcode=common.OPCODE_CONTINUATION)
+ request.connection.put_bytes(world)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(['Hello World!'], messages)
+ control_messages = request.connection.get_written_control_messages(2)
+ self.assertEqual(common.OPCODE_PONG, control_messages[0]['opcode'])
+ self.assertEqual('Ping!', control_messages[0]['message'])
+
+ def test_receive_two_ping_while_receiving_fragmented_message(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=25)
+ request.connection.put_bytes(flow_control)
+
+ # Send a fragmented frame of message 'Hello '.
+ hello = _create_logical_frame(channel_id=2,
+ message='Hello ',
+ fin=False)
+ request.connection.put_bytes(hello)
+
+ # Before sending the last fragmented frame of the message, send a
+ # fragmented ping and a non-fragmented ping.
+ ping1 = _create_logical_frame(channel_id=2,
+ message='Pi',
+ fin=False,
+ opcode=common.OPCODE_PING)
+ request.connection.put_bytes(ping1)
+ ping2 = _create_logical_frame(channel_id=2,
+ message='ng!',
+ fin=True,
+ opcode=common.OPCODE_CONTINUATION)
+ request.connection.put_bytes(ping2)
+ ping3 = _create_logical_frame(channel_id=2,
+ message='Pong!',
+ fin=True,
+ opcode=common.OPCODE_PING)
+ request.connection.put_bytes(ping3)
+
+ # Send the last fragmented frame of the message.
+ world = _create_logical_frame(channel_id=2,
+ message='World!',
+ fin=True,
+ opcode=common.OPCODE_CONTINUATION)
+ request.connection.put_bytes(world)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(['Hello World!'], messages)
+ control_messages = request.connection.get_written_control_messages(2)
+ self.assertEqual(common.OPCODE_PONG, control_messages[0]['opcode'])
+ self.assertEqual('Ping!', control_messages[0]['message'])
+ self.assertEqual(common.OPCODE_PONG, control_messages[1]['opcode'])
+ self.assertEqual('Pong!', control_messages[1]['message'])
+
+ def test_receive_message_while_receiving_fragmented_ping(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=19)
+ request.connection.put_bytes(flow_control)
+
+ # Send a fragmented ping.
+ ping1 = _create_logical_frame(channel_id=2,
+ message='Pi',
+ fin=False,
+ opcode=common.OPCODE_PING)
+ request.connection.put_bytes(ping1)
+
+ # Before sending the last fragmented ping, send a message.
+ # The logical channel (2) should be dropped.
+ message = _create_logical_frame(channel_id=2,
+ message='Hello world!',
+ fin=True)
+ request.connection.put_bytes(message)
+
+ # Send the last fragmented frame of the message.
+ ping2 = _create_logical_frame(channel_id=2,
+ message='ng!',
+ fin=True,
+ opcode=common.OPCODE_CONTINUATION)
+ request.connection.put_bytes(ping2)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(2, drop_channel.channel_id)
+ # No message should be sent on channel 2.
+ self.assertRaises(KeyError,
+ request.connection.get_written_messages,
+ 2)
+ self.assertRaises(KeyError,
+ request.connection.get_written_control_messages,
+ 2)
+
+ def test_send_ping(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/ping')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ messages = request.connection.get_written_control_messages(2)
+ self.assertEqual(common.OPCODE_PING, messages[0]['opcode'])
+ self.assertEqual('Ping!', messages[0]['message'])
+
+ def test_send_fragmented_ping(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/ping')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ # Replenish 3 bytes. This isn't enough to send the whole ping frame
+ # because the frame will have 5 bytes message('Ping!'). The frame
+ # should be fragmented.
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=3)
+ request.connection.put_bytes(flow_control)
+
+ # Wait until the worker is blocked due to send quota shortage.
+ time.sleep(1)
+
+ # Replenish remaining 2 + 1 bytes (including extra cost).
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=3)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ messages = request.connection.get_written_control_messages(2)
+ self.assertEqual(common.OPCODE_PING, messages[0]['opcode'])
+ self.assertEqual('Ping!', messages[0]['message'])
+
+ def test_send_fragmented_ping_while_sending_fragmented_message(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(
+ path='/ping_while_hello_world')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ # Application will send:
+ # - text message 'Hello ' with fin=0
+ # - ping with 'Ping!' message
+ # - text message 'World!' with fin=1
+ # Replenish (6 + 1) + (2 + 1) bytes so that the ping will be
+ # fragmented on the logical channel.
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=10)
+ request.connection.put_bytes(flow_control)
+
+ time.sleep(1)
+
+ # Replenish remaining 3 + 6 bytes.
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=9)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(['Hello World!'], messages)
+ control_messages = request.connection.get_written_control_messages(2)
+ self.assertEqual(common.OPCODE_PING, control_messages[0]['opcode'])
+ self.assertEqual('Ping!', control_messages[0]['message'])
+
+ def test_send_fragmented_two_ping_while_sending_fragmented_message(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(
+ path='/two_ping_while_hello_world')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ # Application will send:
+ # - text message 'Hello ' with fin=0
+ # - ping with 'Ping!' message
+ # - ping with 'Pong!' message
+ # - text message 'World!' with fin=1
+ # Replenish (6 + 1) + (2 + 1) bytes so that the first ping will be
+ # fragmented on the logical channel.
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=10)
+ request.connection.put_bytes(flow_control)
+
+ time.sleep(1)
+
+ # Replenish remaining 3 + (5 + 1) + 6 bytes. The second ping won't
+ # be fragmented on the logical channel.
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=15)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(['Hello World!'], messages)
+ control_messages = request.connection.get_written_control_messages(2)
+ self.assertEqual(common.OPCODE_PING, control_messages[0]['opcode'])
+ self.assertEqual('Ping!', control_messages[0]['message'])
+ self.assertEqual(common.OPCODE_PING, control_messages[1]['opcode'])
+ self.assertEqual('Pong!', control_messages[1]['message'])
+
+ def test_send_drop_channel(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+
+ # DropChannel for channel id 1 which doesn't have reason.
+ frame = create_binary_frame('\x00\x60\x01\x00', mask=True)
+ request.connection.put_bytes(frame)
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(mux._DROP_CODE_ACKNOWLEDGED,
+ drop_channel.drop_code)
+ self.assertEqual(1, drop_channel.channel_id)
+
+ def test_two_flow_control(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ # Replenish 5 bytes.
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=5)
+ request.connection.put_bytes(flow_control)
+
+ # Send 10 bytes. The server will try echo back 10 bytes.
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='HelloWorld'))
+
+ # Replenish 5 + 1 (per-message extra cost) bytes.
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(['HelloWorld'], messages)
+ received_flow_controls = [
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_FLOW_CONTROL and b.channel_id == 2]
+ # Replenishment for 'HelloWorld' + 1
+ self.assertEqual(11, received_flow_controls[0].send_quota)
+ # Replenishment for 'Goodbye' + 1
+ self.assertEqual(8, received_flow_controls[1].send_quota)
+
+ def test_no_send_quota_on_server(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='HelloWorld'))
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ # Just wait for 1 sec so that the server attempts to echo back
+ # 'HelloWorld'.
+ self.assertFalse(mux_handler.wait_until_done(timeout=1))
+
+ # No message should be sent on channel 2.
+ self.assertRaises(KeyError,
+ request.connection.get_written_messages,
+ 2)
+
+ def test_no_send_quota_on_server_for_permessage_extra_cost(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Hello'))
+ # Replenish only len('World') bytes.
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=5)
+ request.connection.put_bytes(flow_control)
+ # Server should not callback for this message.
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='World'))
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ # Just wait for 1 sec so that the server attempts to echo back
+ # 'World'.
+ self.assertFalse(mux_handler.wait_until_done(timeout=1))
+
+ # Only one message should be sent on channel 2.
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(['Hello'], messages)
+
+ def test_quota_violation_by_client(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, 0)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='HelloWorld'))
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ control_blocks = request.connection.get_written_control_blocks()
+ self.assertEqual(5, len(control_blocks))
+ drop_channel = next(
+ b for b in control_blocks
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(mux._DROP_CODE_SEND_QUOTA_VIOLATION,
+ drop_channel.drop_code)
+
+ def test_consume_quota_empty_message(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ # Client has 1 byte quota.
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, 1)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=2)
+ request.connection.put_bytes(flow_control)
+ # Send an empty message. Pywebsocket always replenishes 1 byte quota
+ # for empty message
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message=''))
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ # This message violates quota on channel id 2.
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ self.assertEqual(1, len(dispatcher.channel_events[2].messages))
+ self.assertEqual('', dispatcher.channel_events[2].messages[0])
+
+ received_flow_controls = [
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_FLOW_CONTROL and b.channel_id == 2]
+ self.assertEqual(1, len(received_flow_controls))
+ self.assertEqual(1, received_flow_controls[0].send_quota)
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(2, drop_channel.channel_id)
+ self.assertEqual(mux._DROP_CODE_SEND_QUOTA_VIOLATION,
+ drop_channel.drop_code)
+
+ def test_consume_quota_fragmented_message(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ # Client has len('Hello') + len('Goodbye') + 2 bytes quota.
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS, 14)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='He', fin=False,
+ opcode=common.OPCODE_TEXT))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='llo', fin=True,
+ opcode=common.OPCODE_CONTINUATION))
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(['Hello'], messages)
+
+ def test_fragmented_control_message(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/ping')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ # Replenish total 6 bytes in 3 FlowControls.
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=1)
+ request.connection.put_bytes(flow_control)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=2)
+ request.connection.put_bytes(flow_control)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=3)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ messages = request.connection.get_written_control_messages(2)
+ self.assertEqual(common.OPCODE_PING, messages[0]['opcode'])
+ self.assertEqual('Ping!', messages[0]['message'])
+
+ def test_channel_slot_violation_by_client(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(slots=1,
+ send_quota=mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Hello'))
+
+ # This request should be rejected.
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=3, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+ flow_control = _create_flow_control_frame(channel_id=3,
+ replenished_quota=6)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=3, message='Hello'))
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ self.assertEqual([], dispatcher.channel_events[1].messages)
+ self.assertEqual(['Hello'], dispatcher.channel_events[2].messages)
+ self.assertFalse(dispatcher.channel_events.has_key(3))
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(3, drop_channel.channel_id)
+ self.assertEqual(mux._DROP_CODE_NEW_CHANNEL_SLOT_VIOLATION,
+ drop_channel.drop_code)
+
+ def test_quota_overflow_by_client(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(slots=1,
+ send_quota=mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+ # Replenish 0x7FFFFFFFFFFFFFFF bytes twice.
+ flow_control = _create_flow_control_frame(
+ channel_id=2,
+ replenished_quota=0x7FFFFFFFFFFFFFFF)
+ request.connection.put_bytes(flow_control)
+ request.connection.put_bytes(flow_control)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(2, drop_channel.channel_id)
+ self.assertEqual(mux._DROP_CODE_SEND_QUOTA_OVERFLOW,
+ drop_channel.drop_code)
+
+ def test_invalid_encapsulated_message(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+
+ first_byte = (mux._MUX_OPCODE_ADD_CHANNEL_REQUEST << 5)
+ block = (chr(first_byte) +
+ mux._encode_channel_id(1) +
+ mux._encode_number(0))
+ payload = mux._encode_channel_id(mux._CONTROL_CHANNEL_ID) + block
+ text_frame = create_binary_frame(payload, opcode=common.OPCODE_TEXT,
+ mask=True)
+ request.connection.put_bytes(text_frame)
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(mux._DROP_CODE_INVALID_ENCAPSULATING_MESSAGE,
+ drop_channel.drop_code)
+ self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
+ request.connection.server_close_code)
+
+ def test_channel_id_truncated(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+
+ # The last byte of the channel id is missing.
+ frame = create_binary_frame('\x80', mask=True)
+ request.connection.put_bytes(frame)
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(mux._DROP_CODE_CHANNEL_ID_TRUNCATED,
+ drop_channel.drop_code)
+ self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
+ request.connection.server_close_code)
+
+ def test_inner_frame_truncated(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+
+ # Just contain channel id 1.
+ frame = create_binary_frame('\x01', mask=True)
+ request.connection.put_bytes(frame)
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(mux._DROP_CODE_ENCAPSULATED_FRAME_IS_TRUNCATED,
+ drop_channel.drop_code)
+ self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
+ request.connection.server_close_code)
+
+ def test_unknown_mux_opcode(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+
+ # Undefined opcode 5
+ frame = create_binary_frame('\x00\xa0', mask=True)
+ request.connection.put_bytes(frame)
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(mux._DROP_CODE_UNKNOWN_MUX_OPCODE,
+ drop_channel.drop_code)
+ self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
+ request.connection.server_close_code)
+
+ def test_invalid_mux_control_block(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+
+ # DropChannel contains 1 byte reason
+ frame = create_binary_frame('\x00\x60\x00\x01\x00', mask=True)
+ request.connection.put_bytes(frame)
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channel = next(
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL)
+ self.assertEqual(mux._DROP_CODE_INVALID_MUX_CONTROL_BLOCK,
+ drop_channel.drop_code)
+ self.assertEqual(common.STATUS_INTERNAL_ENDPOINT_ERROR,
+ request.connection.server_close_code)
+
+ def test_permessage_compress(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ # Enable permessage compress extension on logical channel 2.
+ extensions = '%s; method=deflate' % (
+ common.PERMESSAGE_COMPRESSION_EXTENSION)
+ encoded_handshake = _create_request_header(path='/echo',
+ extensions=extensions)
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ flow_control = _create_flow_control_frame(channel_id=2,
+ replenished_quota=20)
+ request.connection.put_bytes(flow_control)
+
+ # Send compressed 'Hello' twice.
+ compress = zlib.compressobj(
+ zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
+ compressed_hello1 = compress.compress('Hello')
+ compressed_hello1 += compress.flush(zlib.Z_SYNC_FLUSH)
+ compressed_hello1 = compressed_hello1[:-4]
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message=compressed_hello1,
+ rsv1=True))
+ compressed_hello2 = compress.compress('Hello')
+ compressed_hello2 += compress.flush(zlib.Z_SYNC_FLUSH)
+ compressed_hello2 = compressed_hello2[:-4]
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message=compressed_hello2,
+ rsv1=True))
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=2, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ self.assertEqual(['Hello', 'Hello'],
+ dispatcher.channel_events[2].messages)
+ # Written 'Hello's should be compressed.
+ messages = request.connection.get_written_messages(2)
+ self.assertEqual(2, len(messages))
+ self.assertEqual(compressed_hello1, messages[0])
+ self.assertEqual(compressed_hello2, messages[1])
+
+
+ def test_permessage_compress_fragmented_message(self):
+ extensions = common.parse_extensions(
+ '%s; method=deflate' % common.PERMESSAGE_COMPRESSION_EXTENSION)
+ request = _create_mock_request(
+ logical_channel_extensions=extensions)
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ # Send compressed 'HelloHelloHello' as fragmented message.
+ compress = zlib.compressobj(
+ zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -zlib.MAX_WBITS)
+ compressed_hello = compress.compress('HelloHelloHello')
+ compressed_hello += compress.flush(zlib.Z_SYNC_FLUSH)
+ compressed_hello = compressed_hello[:-4]
+
+ m = len(compressed_hello) / 2
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1,
+ message=compressed_hello[:m],
+ fin=False, rsv1=True,
+ opcode=common.OPCODE_TEXT))
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1,
+ message=compressed_hello[m:],
+ fin=True, rsv1=False,
+ opcode=common.OPCODE_CONTINUATION))
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ self.assertEqual(['HelloHelloHello'],
+ dispatcher.channel_events[1].messages)
+ messages = request.connection.get_written_messages(1)
+ self.assertEqual(1, len(messages))
+ self.assertEqual(compressed_hello, messages[0])
+
+ def test_receive_bad_fragmented_message(self):
+ request = _create_mock_request()
+ dispatcher = _MuxMockDispatcher()
+ mux_handler = mux._MuxHandler(request, dispatcher)
+ mux_handler.start()
+ mux_handler.add_channel_slots(mux._INITIAL_NUMBER_OF_CHANNEL_SLOTS,
+ mux._INITIAL_QUOTA_FOR_CLIENT)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=2, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ # Send a frame with fin=False, and then send a frame with
+ # opcode=TEXT (not CONTINUATION). Logical channel 2 should be dropped.
+ frame1 = _create_logical_frame(channel_id=2,
+ message='Hello ',
+ fin=False,
+ opcode=common.OPCODE_TEXT)
+ request.connection.put_bytes(frame1)
+ frame2 = _create_logical_frame(channel_id=2,
+ message='World!',
+ fin=True,
+ opcode=common.OPCODE_TEXT)
+ request.connection.put_bytes(frame2)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=3, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ # Send a frame with opcode=CONTINUATION without a preceding frame
+ # the fin of which is not set. Logical channel 3 should be dropped.
+ frame3 = _create_logical_frame(channel_id=3,
+ message='Hello',
+ fin=True,
+ opcode=common.OPCODE_CONTINUATION)
+ request.connection.put_bytes(frame3)
+
+ encoded_handshake = _create_request_header(path='/echo')
+ add_channel_request = _create_add_channel_request_frame(
+ channel_id=4, encoding=0,
+ encoded_handshake=encoded_handshake)
+ request.connection.put_bytes(add_channel_request)
+
+ # Send a frame with opcode=PING and fin=False, and then send a frame
+ # with opcode=TEXT (not CONTINUATION). Logical channel 4 should be
+ # dropped.
+ frame4 = _create_logical_frame(channel_id=4,
+ message='Ping',
+ fin=False,
+ opcode=common.OPCODE_PING)
+ request.connection.put_bytes(frame4)
+ frame5 = _create_logical_frame(channel_id=4,
+ message='Hello',
+ fin=True,
+ opcode=common.OPCODE_TEXT)
+ request.connection.put_bytes(frame5)
+
+ request.connection.put_bytes(
+ _create_logical_frame(channel_id=1, message='Goodbye'))
+
+ self.assertTrue(mux_handler.wait_until_done(timeout=2))
+
+ drop_channels = [
+ b for b in request.connection.get_written_control_blocks()
+ if b.opcode == mux._MUX_OPCODE_DROP_CHANNEL]
+ self.assertEqual(3, len(drop_channels))
+ for d in drop_channels:
+ self.assertEqual(mux._DROP_CODE_BAD_FRAGMENTATION,
+ d.drop_code)
+
+
+if __name__ == '__main__':
+ unittest.main()
+
+
+# vi:sts=4 sw=4 et