summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/tools/pywebsocket/src/test/client_for_testing.py
diff options
context:
space:
mode:
Diffstat (limited to 'testing/web-platform/tests/tools/pywebsocket/src/test/client_for_testing.py')
-rw-r--r--testing/web-platform/tests/tools/pywebsocket/src/test/client_for_testing.py1100
1 files changed, 1100 insertions, 0 deletions
diff --git a/testing/web-platform/tests/tools/pywebsocket/src/test/client_for_testing.py b/testing/web-platform/tests/tools/pywebsocket/src/test/client_for_testing.py
new file mode 100644
index 000000000..c7f805ee9
--- /dev/null
+++ b/testing/web-platform/tests/tools/pywebsocket/src/test/client_for_testing.py
@@ -0,0 +1,1100 @@
+#!/usr/bin/env python
+#
+# Copyright 2012, Google Inc.
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+"""WebSocket client utility for testing.
+
+This module contains helper methods for performing handshake, frame
+sending/receiving as a WebSocket client.
+
+This is code for testing mod_pywebsocket. Keep this code independent from
+mod_pywebsocket. Don't import e.g. Stream class for generating frame for
+testing. Using util.hexify, etc. that are not related to protocol processing
+is allowed.
+
+Note:
+This code is far from robust, e.g., we cut corners in handshake.
+"""
+
+
+import base64
+import errno
+import logging
+import os
+import random
+import re
+import socket
+import struct
+import time
+
+from mod_pywebsocket import common
+from mod_pywebsocket import util
+
+
+DEFAULT_PORT = 80
+DEFAULT_SECURE_PORT = 443
+
+# Opcodes introduced in IETF HyBi 01 for the new framing format
+OPCODE_CONTINUATION = 0x0
+OPCODE_CLOSE = 0x8
+OPCODE_PING = 0x9
+OPCODE_PONG = 0xa
+OPCODE_TEXT = 0x1
+OPCODE_BINARY = 0x2
+
+# Strings used for handshake
+_UPGRADE_HEADER = 'Upgrade: websocket\r\n'
+_UPGRADE_HEADER_HIXIE75 = 'Upgrade: WebSocket\r\n'
+_CONNECTION_HEADER = 'Connection: Upgrade\r\n'
+
+WEBSOCKET_ACCEPT_UUID = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
+
+# Status codes
+STATUS_NORMAL_CLOSURE = 1000
+STATUS_GOING_AWAY = 1001
+STATUS_PROTOCOL_ERROR = 1002
+STATUS_UNSUPPORTED_DATA = 1003
+STATUS_NO_STATUS_RECEIVED = 1005
+STATUS_ABNORMAL_CLOSURE = 1006
+STATUS_INVALID_FRAME_PAYLOAD_DATA = 1007
+STATUS_POLICY_VIOLATION = 1008
+STATUS_MESSAGE_TOO_BIG = 1009
+STATUS_MANDATORY_EXT = 1010
+STATUS_INTERNAL_ENDPOINT_ERROR = 1011
+STATUS_TLS_HANDSHAKE = 1015
+
+# Extension tokens
+_DEFLATE_FRAME_EXTENSION = 'deflate-frame'
+# TODO(bashi): Update after mux implementation finished.
+_MUX_EXTENSION = 'mux_DO_NOT_USE'
+_PERMESSAGE_DEFLATE_EXTENSION = 'permessage-deflate'
+
+def _method_line(resource):
+ return 'GET %s HTTP/1.1\r\n' % resource
+
+
+def _sec_origin_header(origin):
+ return 'Sec-WebSocket-Origin: %s\r\n' % origin.lower()
+
+
+def _origin_header(origin):
+ # 4.1 13. concatenation of the string "Origin:", a U+0020 SPACE character,
+ # and the /origin/ value, converted to ASCII lowercase, to /fields/.
+ return 'Origin: %s\r\n' % origin.lower()
+
+
+def _format_host_header(host, port, secure):
+ # 4.1 9. Let /hostport/ be an empty string.
+ # 4.1 10. Append the /host/ value, converted to ASCII lowercase, to
+ # /hostport/
+ hostport = host.lower()
+ # 4.1 11. If /secure/ is false, and /port/ is not 80, or if /secure/
+ # is true, and /port/ is not 443, then append a U+003A COLON character
+ # (:) followed by the value of /port/, expressed as a base-ten integer,
+ # to /hostport/
+ if ((not secure and port != DEFAULT_PORT) or
+ (secure and port != DEFAULT_SECURE_PORT)):
+ hostport += ':' + str(port)
+ # 4.1 12. concatenation of the string "Host:", a U+0020 SPACE
+ # character, and /hostport/, to /fields/.
+ return 'Host: %s\r\n' % hostport
+
+
+# TODO(tyoshino): Define a base class and move these shared methods to that.
+
+
+def receive_bytes(socket, length):
+ received_bytes = []
+ remaining = length
+ while remaining > 0:
+ new_received_bytes = socket.recv(remaining)
+ if not new_received_bytes:
+ raise Exception(
+ 'Connection closed before receiving requested length '
+ '(requested %d bytes but received only %d bytes)' %
+ (length, length - remaining))
+ received_bytes.append(new_received_bytes)
+ remaining -= len(new_received_bytes)
+ return ''.join(received_bytes)
+
+
+# TODO(tyoshino): Now the WebSocketHandshake class diverts these methods. We
+# should move to HTTP parser as specified in RFC 6455. For HyBi 00 and
+# Hixie 75, pack these methods as some parser class.
+
+
+def _read_fields(socket):
+ # 4.1 32. let /fields/ be a list of name-value pairs, initially empty.
+ fields = {}
+ while True:
+ # 4.1 33. let /name/ and /value/ be empty byte arrays
+ name = ''
+ value = ''
+ # 4.1 34. read /name/
+ name = _read_name(socket)
+ if name is None:
+ break
+ # 4.1 35. read spaces
+ # TODO(tyoshino): Skip only one space as described in the spec.
+ ch = _skip_spaces(socket)
+ # 4.1 36. read /value/
+ value = _read_value(socket, ch)
+ # 4.1 37. read a byte from the server
+ ch = receive_bytes(socket, 1)
+ if ch != '\n': # 0x0A
+ raise Exception(
+ 'Expected LF but found %r while reading value %r for header '
+ '%r' % (ch, name, value))
+ # 4.1 38. append an entry to the /fields/ list that has the name
+ # given by the string obtained by interpreting the /name/ byte
+ # array as a UTF-8 stream and the value given by the string
+ # obtained by interpreting the /value/ byte array as a UTF-8 byte
+ # stream.
+ fields.setdefault(name, []).append(value)
+ # 4.1 39. return to the "Field" step above
+ return fields
+
+
+def _read_name(socket):
+ # 4.1 33. let /name/ be empty byte arrays
+ name = ''
+ while True:
+ # 4.1 34. read a byte from the server
+ ch = receive_bytes(socket, 1)
+ if ch == '\r': # 0x0D
+ return None
+ elif ch == '\n': # 0x0A
+ raise Exception(
+ 'Unexpected LF when reading header name %r' % name)
+ elif ch == ':': # 0x3A
+ return name
+ elif ch >= 'A' and ch <= 'Z': # range 0x31 to 0x5A
+ ch = chr(ord(ch) + 0x20)
+ name += ch
+ else:
+ name += ch
+
+
+def _skip_spaces(socket):
+ # 4.1 35. read a byte from the server
+ while True:
+ ch = receive_bytes(socket, 1)
+ if ch == ' ': # 0x20
+ continue
+ return ch
+
+
+def _read_value(socket, ch):
+ # 4.1 33. let /value/ be empty byte arrays
+ value = ''
+ # 4.1 36. read a byte from server.
+ while True:
+ if ch == '\r': # 0x0D
+ return value
+ elif ch == '\n': # 0x0A
+ raise Exception(
+ 'Unexpected LF when reading header value %r' % value)
+ else:
+ value += ch
+ ch = receive_bytes(socket, 1)
+
+
+def read_frame_header(socket):
+ received = receive_bytes(socket, 2)
+
+ first_byte = ord(received[0])
+ fin = (first_byte >> 7) & 1
+ rsv1 = (first_byte >> 6) & 1
+ rsv2 = (first_byte >> 5) & 1
+ rsv3 = (first_byte >> 4) & 1
+ opcode = first_byte & 0xf
+
+ second_byte = ord(received[1])
+ mask = (second_byte >> 7) & 1
+ payload_length = second_byte & 0x7f
+
+ if mask != 0:
+ raise Exception(
+ 'Mask bit must be 0 for frames coming from server')
+
+ if payload_length == 127:
+ extended_payload_length = receive_bytes(socket, 8)
+ payload_length = struct.unpack(
+ '!Q', extended_payload_length)[0]
+ if payload_length > 0x7FFFFFFFFFFFFFFF:
+ raise Exception('Extended payload length >= 2^63')
+ elif payload_length == 126:
+ extended_payload_length = receive_bytes(socket, 2)
+ payload_length = struct.unpack(
+ '!H', extended_payload_length)[0]
+
+ return fin, rsv1, rsv2, rsv3, opcode, payload_length
+
+
+class _TLSSocket(object):
+ """Wrapper for a TLS connection."""
+
+ def __init__(self, raw_socket):
+ self._ssl = socket.ssl(raw_socket)
+
+ def send(self, bytes):
+ return self._ssl.write(bytes)
+
+ def recv(self, size=-1):
+ return self._ssl.read(size)
+
+ def close(self):
+ # Nothing to do.
+ pass
+
+
+class HttpStatusException(Exception):
+ """This exception will be raised when unexpected http status code was
+ received as a result of handshake.
+ """
+
+ def __init__(self, name, status):
+ super(HttpStatusException, self).__init__(name)
+ self.status = status
+
+
+class WebSocketHandshake(object):
+ """Opening handshake processor for the WebSocket protocol (RFC 6455)."""
+
+ def __init__(self, options):
+ self._logger = util.get_class_logger(self)
+
+ self._options = options
+
+ def handshake(self, socket):
+ """Handshake WebSocket.
+
+ Raises:
+ Exception: handshake failed.
+ """
+
+ self._socket = socket
+
+ request_line = _method_line(self._options.resource)
+ self._logger.debug('Opening handshake Request-Line: %r', request_line)
+ self._socket.sendall(request_line)
+
+ fields = []
+ fields.append(_UPGRADE_HEADER)
+ fields.append(_CONNECTION_HEADER)
+
+ fields.append(_format_host_header(
+ self._options.server_host,
+ self._options.server_port,
+ self._options.use_tls))
+
+ if self._options.version is 8:
+ fields.append(_sec_origin_header(self._options.origin))
+ else:
+ fields.append(_origin_header(self._options.origin))
+
+ original_key = os.urandom(16)
+ key = base64.b64encode(original_key)
+ self._logger.debug(
+ 'Sec-WebSocket-Key: %s (%s)', key, util.hexify(original_key))
+ fields.append('Sec-WebSocket-Key: %s\r\n' % key)
+
+ fields.append('Sec-WebSocket-Version: %d\r\n' % self._options.version)
+
+ # Setting up extensions.
+ if len(self._options.extensions) > 0:
+ fields.append('Sec-WebSocket-Extensions: %s\r\n' %
+ ', '.join(self._options.extensions))
+
+ self._logger.debug('Opening handshake request headers: %r', fields)
+
+ for field in fields:
+ self._socket.sendall(field)
+ self._socket.sendall('\r\n')
+
+ self._logger.info('Sent opening handshake request')
+
+ field = ''
+ while True:
+ ch = receive_bytes(self._socket, 1)
+ field += ch
+ if ch == '\n':
+ break
+
+ self._logger.debug('Opening handshake Response-Line: %r', field)
+
+ if len(field) < 7 or not field.endswith('\r\n'):
+ raise Exception('Wrong status line: %r' % field)
+ m = re.match('[^ ]* ([^ ]*) .*', field)
+ if m is None:
+ raise Exception(
+ 'No HTTP status code found in status line: %r' % field)
+ code = m.group(1)
+ if not re.match('[0-9][0-9][0-9]', code):
+ raise Exception(
+ 'HTTP status code %r is not three digit in status line: %r' %
+ (code, field))
+ if code != '101':
+ raise HttpStatusException(
+ 'Expected HTTP status code 101 but found %r in status line: '
+ '%r' % (code, field), int(code))
+ fields = _read_fields(self._socket)
+ ch = receive_bytes(self._socket, 1)
+ if ch != '\n': # 0x0A
+ raise Exception('Expected LF but found: %r' % ch)
+
+ self._logger.debug('Opening handshake response headers: %r', fields)
+
+ # Check /fields/
+ if len(fields['upgrade']) != 1:
+ raise Exception(
+ 'Multiple Upgrade headers found: %s' % fields['upgrade'])
+ if len(fields['connection']) != 1:
+ raise Exception(
+ 'Multiple Connection headers found: %s' % fields['connection'])
+ if fields['upgrade'][0] != 'websocket':
+ raise Exception(
+ 'Unexpected Upgrade header value: %s' % fields['upgrade'][0])
+ if fields['connection'][0].lower() != 'upgrade':
+ raise Exception(
+ 'Unexpected Connection header value: %s' %
+ fields['connection'][0])
+
+ if len(fields['sec-websocket-accept']) != 1:
+ raise Exception(
+ 'Multiple Sec-WebSocket-Accept headers found: %s' %
+ fields['sec-websocket-accept'])
+
+ accept = fields['sec-websocket-accept'][0]
+
+ # Validate
+ try:
+ decoded_accept = base64.b64decode(accept)
+ except TypeError, e:
+ raise HandshakeException(
+ 'Illegal value for header Sec-WebSocket-Accept: ' + accept)
+
+ if len(decoded_accept) != 20:
+ raise HandshakeException(
+ 'Decoded value of Sec-WebSocket-Accept is not 20-byte long')
+
+ self._logger.debug('Actual Sec-WebSocket-Accept: %r (%s)',
+ accept, util.hexify(decoded_accept))
+
+ original_expected_accept = util.sha1_hash(
+ key + WEBSOCKET_ACCEPT_UUID).digest()
+ expected_accept = base64.b64encode(original_expected_accept)
+
+ self._logger.debug('Expected Sec-WebSocket-Accept: %r (%s)',
+ expected_accept,
+ util.hexify(original_expected_accept))
+
+ if accept != expected_accept:
+ raise Exception(
+ 'Invalid Sec-WebSocket-Accept header: %r (expected) != %r '
+ '(actual)' % (accept, expected_accept))
+
+ server_extensions_header = fields.get('sec-websocket-extensions')
+ accepted_extensions = []
+ if server_extensions_header is not None:
+ accepted_extensions = common.parse_extensions(
+ ', '.join(server_extensions_header))
+
+ # Scan accepted extension list to check if there is any unrecognized
+ # extensions or extensions we didn't request in it. Then, for
+ # extensions we request, parse them and store parameters. They will be
+ # used later by each extension.
+ deflate_frame_accepted = False
+ mux_accepted = False
+ for extension in accepted_extensions:
+ if extension.name() == _DEFLATE_FRAME_EXTENSION:
+ if self._options.use_deflate_frame:
+ deflate_frame_accepted = True
+ continue
+ if extension.name() == _MUX_EXTENSION:
+ if self._options.use_mux:
+ mux_accepted = True
+ continue
+ if extension.name() == _PERMESSAGE_DEFLATE_EXTENSION:
+ checker = self._options.check_permessage_deflate
+ if checker:
+ checker(extension)
+ continue
+
+ raise Exception(
+ 'Received unrecognized extension: %s' % extension.name())
+
+ # Let all extensions check the response for extension request.
+
+ if (self._options.use_deflate_frame and
+ not deflate_frame_accepted):
+ raise Exception('%s extension not accepted' %
+ _DEFLATE_FRAME_EXTENSION)
+
+ if self._options.use_mux and not mux_accepted:
+ raise Exception('%s extension not accepted' % _MUX_EXTENSION)
+
+
+class WebSocketHybi00Handshake(object):
+ """Opening handshake processor for the WebSocket protocol version HyBi 00.
+ """
+
+ def __init__(self, options, draft_field):
+ self._logger = util.get_class_logger(self)
+
+ self._options = options
+ self._draft_field = draft_field
+
+ def handshake(self, socket):
+ """Handshake WebSocket.
+
+ Raises:
+ Exception: handshake failed.
+ """
+
+ self._socket = socket
+
+ # 4.1 5. send request line.
+ request_line = _method_line(self._options.resource)
+ self._logger.debug('Opening handshake Request-Line: %r', request_line)
+ self._socket.sendall(request_line)
+ # 4.1 6. Let /fields/ be an empty list of strings.
+ fields = []
+ # 4.1 7. Add the string "Upgrade: WebSocket" to /fields/.
+ fields.append(_UPGRADE_HEADER_HIXIE75)
+ # 4.1 8. Add the string "Connection: Upgrade" to /fields/.
+ fields.append(_CONNECTION_HEADER)
+ # 4.1 9-12. Add Host: field to /fields/.
+ fields.append(_format_host_header(
+ self._options.server_host,
+ self._options.server_port,
+ self._options.use_tls))
+ # 4.1 13. Add Origin: field to /fields/.
+ fields.append(_origin_header(self._options.origin))
+ # TODO: 4.1 14 Add Sec-WebSocket-Protocol: field to /fields/.
+ # TODO: 4.1 15 Add cookie headers to /fields/.
+
+ # 4.1 16-23. Add Sec-WebSocket-Key<n> to /fields/.
+ self._number1, key1 = self._generate_sec_websocket_key()
+ self._logger.debug('Number1: %d', self._number1)
+ fields.append('Sec-WebSocket-Key1: %s\r\n' % key1)
+ self._number2, key2 = self._generate_sec_websocket_key()
+ self._logger.debug('Number2: %d', self._number1)
+ fields.append('Sec-WebSocket-Key2: %s\r\n' % key2)
+
+ fields.append('Sec-WebSocket-Draft: %s\r\n' % self._draft_field)
+
+ # 4.1 24. For each string in /fields/, in a random order: send the
+ # string, encoded as UTF-8, followed by a UTF-8 encoded U+000D CARRIAGE
+ # RETURN U+000A LINE FEED character pair (CRLF).
+ random.shuffle(fields)
+
+ self._logger.debug('Opening handshake request headers: %r', fields)
+ for field in fields:
+ self._socket.sendall(field)
+
+ # 4.1 25. send a UTF-8-encoded U+000D CARRIAGE RETURN U+000A LINE FEED
+ # character pair (CRLF).
+ self._socket.sendall('\r\n')
+ # 4.1 26. let /key3/ be a string consisting of eight random bytes (or
+ # equivalently, a random 64 bit integer encoded in a big-endian order).
+ self._key3 = self._generate_key3()
+ # 4.1 27. send /key3/ to the server.
+ self._socket.sendall(self._key3)
+ self._logger.debug(
+ 'Key3: %r (%s)', self._key3, util.hexify(self._key3))
+
+ self._logger.info('Sent opening handshake request')
+
+ # 4.1 28. Read bytes from the server until either the connection
+ # closes, or a 0x0A byte is read. let /field/ be these bytes, including
+ # the 0x0A bytes.
+ field = ''
+ while True:
+ ch = receive_bytes(self._socket, 1)
+ field += ch
+ if ch == '\n':
+ break
+
+ self._logger.debug('Opening handshake Response-Line: %r', field)
+
+ # if /field/ is not at least seven bytes long, or if the last
+ # two bytes aren't 0x0D and 0x0A respectively, or if it does not
+ # contain at least two 0x20 bytes, then fail the WebSocket connection
+ # and abort these steps.
+ if len(field) < 7 or not field.endswith('\r\n'):
+ raise Exception('Wrong status line: %r' % field)
+ m = re.match('[^ ]* ([^ ]*) .*', field)
+ if m is None:
+ raise Exception('No code found in status line: %r' % field)
+ # 4.1 29. let /code/ be the substring of /field/ that starts from the
+ # byte after the first 0x20 byte, and ends with the byte before the
+ # second 0x20 byte.
+ code = m.group(1)
+ # 4.1 30. if /code/ is not three bytes long, or if any of the bytes in
+ # /code/ are not in the range 0x30 to 0x90, then fail the WebSocket
+ # connection and abort these steps.
+ if not re.match('[0-9][0-9][0-9]', code):
+ raise Exception(
+ 'HTTP status code %r is not three digit in status line: %r' %
+ (code, field))
+ # 4.1 31. if /code/, interpreted as UTF-8, is "101", then move to the
+ # next step.
+ if code != '101':
+ raise HttpStatusException(
+ 'Expected HTTP status code 101 but found %r in status line: '
+ '%r' % (code, field), int(code))
+ # 4.1 32-39. read fields into /fields/
+ fields = _read_fields(self._socket)
+
+ self._logger.debug('Opening handshake response headers: %r', fields)
+
+ # 4.1 40. _Fields processing_
+ # read a byte from server
+ ch = receive_bytes(self._socket, 1)
+ if ch != '\n': # 0x0A
+ raise Exception('Expected LF but found %r' % ch)
+ # 4.1 41. check /fields/
+ if len(fields['upgrade']) != 1:
+ raise Exception(
+ 'Multiple Upgrade headers found: %s' % fields['upgrade'])
+ if len(fields['connection']) != 1:
+ raise Exception(
+ 'Multiple Connection headers found: %s' % fields['connection'])
+ if len(fields['sec-websocket-origin']) != 1:
+ raise Exception(
+ 'Multiple Sec-WebSocket-Origin headers found: %s' %
+ fields['sec-sebsocket-origin'])
+ if len(fields['sec-websocket-location']) != 1:
+ raise Exception(
+ 'Multiple Sec-WebSocket-Location headers found: %s' %
+ fields['sec-sebsocket-location'])
+ # TODO(ukai): protocol
+ # if the entry's name is "upgrade"
+ # if the value is not exactly equal to the string "WebSocket",
+ # then fail the WebSocket connection and abort these steps.
+ if fields['upgrade'][0] != 'WebSocket':
+ raise Exception(
+ 'Unexpected Upgrade header value: %s' % fields['upgrade'][0])
+ # if the entry's name is "connection"
+ # if the value, converted to ASCII lowercase, is not exactly equal
+ # to the string "upgrade", then fail the WebSocket connection and
+ # abort these steps.
+ if fields['connection'][0].lower() != 'upgrade':
+ raise Exception(
+ 'Unexpected Connection header value: %s' %
+ fields['connection'][0])
+ # TODO(ukai): check origin, location, cookie, ..
+
+ # 4.1 42. let /challenge/ be the concatenation of /number_1/,
+ # expressed as a big endian 32 bit integer, /number_2/, expressed
+ # as big endian 32 bit integer, and the eight bytes of /key_3/ in the
+ # order they were sent on the wire.
+ challenge = struct.pack('!I', self._number1)
+ challenge += struct.pack('!I', self._number2)
+ challenge += self._key3
+
+ self._logger.debug(
+ 'Challenge: %r (%s)', challenge, util.hexify(challenge))
+
+ # 4.1 43. let /expected/ be the MD5 fingerprint of /challenge/ as a
+ # big-endian 128 bit string.
+ expected = util.md5_hash(challenge).digest()
+ self._logger.debug(
+ 'Expected challenge response: %r (%s)',
+ expected, util.hexify(expected))
+
+ # 4.1 44. read sixteen bytes from the server.
+ # let /reply/ be those bytes.
+ reply = receive_bytes(self._socket, 16)
+ self._logger.debug(
+ 'Actual challenge response: %r (%s)', reply, util.hexify(reply))
+
+ # 4.1 45. if /reply/ does not exactly equal /expected/, then fail
+ # the WebSocket connection and abort these steps.
+ if expected != reply:
+ raise Exception(
+ 'Bad challenge response: %r (expected) != %r (actual)' %
+ (expected, reply))
+ # 4.1 46. The *WebSocket connection is established*.
+
+ def _generate_sec_websocket_key(self):
+ # 4.1 16. let /spaces_n/ be a random integer from 1 to 12 inclusive.
+ spaces = random.randint(1, 12)
+ # 4.1 17. let /max_n/ be the largest integer not greater than
+ # 4,294,967,295 divided by /spaces_n/.
+ maxnum = 4294967295 / spaces
+ # 4.1 18. let /number_n/ be a random integer from 0 to /max_n/
+ # inclusive.
+ number = random.randint(0, maxnum)
+ # 4.1 19. let /product_n/ be the result of multiplying /number_n/ and
+ # /spaces_n/ together.
+ product = number * spaces
+ # 4.1 20. let /key_n/ be a string consisting of /product_n/, expressed
+ # in base ten using the numerals in the range U+0030 DIGIT ZERO (0) to
+ # U+0039 DIGIT NINE (9).
+ key = str(product)
+ # 4.1 21. insert between one and twelve random characters from the
+ # range U+0021 to U+002F and U+003A to U+007E into /key_n/ at random
+ # positions.
+ available_chars = range(0x21, 0x2f + 1) + range(0x3a, 0x7e + 1)
+ n = random.randint(1, 12)
+ for _ in xrange(n):
+ ch = random.choice(available_chars)
+ pos = random.randint(0, len(key))
+ key = key[0:pos] + chr(ch) + key[pos:]
+ # 4.1 22. insert /spaces_n/ U+0020 SPACE characters into /key_n/ at
+ # random positions other than start or end of the string.
+ for _ in xrange(spaces):
+ pos = random.randint(1, len(key) - 1)
+ key = key[0:pos] + ' ' + key[pos:]
+ return number, key
+
+ def _generate_key3(self):
+ # 4.1 26. let /key3/ be a string consisting of eight random bytes (or
+ # equivalently, a random 64 bit integer encoded in a big-endian order).
+ return ''.join([chr(random.randint(0, 255)) for _ in xrange(8)])
+
+
+class WebSocketHixie75Handshake(object):
+ """WebSocket handshake processor for IETF Hixie 75."""
+
+ _EXPECTED_RESPONSE = (
+ 'HTTP/1.1 101 Web Socket Protocol Handshake\r\n' +
+ _UPGRADE_HEADER_HIXIE75 +
+ _CONNECTION_HEADER)
+
+ def __init__(self, options):
+ self._logger = util.get_class_logger(self)
+
+ self._options = options
+
+ def _skip_headers(self):
+ terminator = '\r\n\r\n'
+ pos = 0
+ while pos < len(terminator):
+ received = receive_bytes(self._socket, 1)
+ if received == terminator[pos]:
+ pos += 1
+ elif received == terminator[0]:
+ pos = 1
+ else:
+ pos = 0
+
+ def handshake(self, socket):
+ self._socket = socket
+
+ request_line = _method_line(self._options.resource)
+ self._logger.debug('Opening handshake Request-Line: %r', request_line)
+ self._socket.sendall(request_line)
+
+ headers = _UPGRADE_HEADER_HIXIE75 + _CONNECTION_HEADER
+ headers += _format_host_header(
+ self._options.server_host,
+ self._options.server_port,
+ self._options.use_tls)
+ headers += _origin_header(self._options.origin)
+ self._logger.debug('Opening handshake request headers: %r', headers)
+ self._socket.sendall(headers)
+
+ self._socket.sendall('\r\n')
+
+ self._logger.info('Sent opening handshake request')
+
+ for expected_char in WebSocketHixie75Handshake._EXPECTED_RESPONSE:
+ received = receive_bytes(self._socket, 1)
+ if expected_char != received:
+ raise Exception('Handshake failure')
+ # We cut corners and skip other headers.
+ self._skip_headers()
+
+
+class WebSocketStream(object):
+ """Frame processor for the WebSocket protocol (RFC 6455)."""
+
+ def __init__(self, socket, handshake):
+ self._handshake = handshake
+ self._socket = socket
+
+ # Filters applied to application data part of data frames.
+ self._outgoing_frame_filter = None
+ self._incoming_frame_filter = None
+
+ if self._handshake._options.use_deflate_frame:
+ self._outgoing_frame_filter = (
+ util._RFC1979Deflater(None, False))
+ self._incoming_frame_filter = util._RFC1979Inflater()
+
+ self._fragmented = False
+
+ def _mask_hybi(self, s):
+ # TODO(tyoshino): os.urandom does open/read/close for every call. If
+ # performance matters, change this to some library call that generates
+ # cryptographically secure pseudo random number sequence.
+ masking_nonce = os.urandom(4)
+ result = [masking_nonce]
+ count = 0
+ for c in s:
+ result.append(chr(ord(c) ^ ord(masking_nonce[count])))
+ count = (count + 1) % len(masking_nonce)
+ return ''.join(result)
+
+ def send_frame_of_arbitrary_bytes(self, header, body):
+ self._socket.sendall(header + self._mask_hybi(body))
+
+ def send_data(self, payload, frame_type, end=True, mask=True,
+ rsv1=0, rsv2=0, rsv3=0):
+ if self._outgoing_frame_filter is not None:
+ payload = self._outgoing_frame_filter.filter(payload)
+
+ if self._fragmented:
+ opcode = OPCODE_CONTINUATION
+ else:
+ opcode = frame_type
+
+ if end:
+ self._fragmented = False
+ fin = 1
+ else:
+ self._fragmented = True
+ fin = 0
+
+ if self._handshake._options.use_deflate_frame:
+ rsv1 = 1
+
+ if mask:
+ mask_bit = 1 << 7
+ else:
+ mask_bit = 0
+
+ header = chr(fin << 7 | rsv1 << 6 | rsv2 << 5 | rsv3 << 4 | opcode)
+ payload_length = len(payload)
+ if payload_length <= 125:
+ header += chr(mask_bit | payload_length)
+ elif payload_length < 1 << 16:
+ header += chr(mask_bit | 126) + struct.pack('!H', payload_length)
+ elif payload_length < 1 << 63:
+ header += chr(mask_bit | 127) + struct.pack('!Q', payload_length)
+ else:
+ raise Exception('Too long payload (%d byte)' % payload_length)
+ if mask:
+ payload = self._mask_hybi(payload)
+ self._socket.sendall(header + payload)
+
+ def send_binary(self, payload, end=True, mask=True):
+ self.send_data(payload, OPCODE_BINARY, end, mask)
+
+ def send_text(self, payload, end=True, mask=True):
+ self.send_data(payload.encode('utf-8'), OPCODE_TEXT, end, mask)
+
+ def _assert_receive_data(self, payload, opcode, fin, rsv1, rsv2, rsv3):
+ (actual_fin, actual_rsv1, actual_rsv2, actual_rsv3, actual_opcode,
+ payload_length) = read_frame_header(self._socket)
+
+ if actual_opcode != opcode:
+ raise Exception(
+ 'Unexpected opcode: %d (expected) vs %d (actual)' %
+ (opcode, actual_opcode))
+
+ if actual_fin != fin:
+ raise Exception(
+ 'Unexpected fin: %d (expected) vs %d (actual)' %
+ (fin, actual_fin))
+
+ if rsv1 is None:
+ rsv1 = 0
+ if self._handshake._options.use_deflate_frame:
+ rsv1 = 1
+
+ if rsv2 is None:
+ rsv2 = 0
+
+ if rsv3 is None:
+ rsv3 = 0
+
+ if actual_rsv1 != rsv1:
+ raise Exception(
+ 'Unexpected rsv1: %r (expected) vs %r (actual)' %
+ (rsv1, actual_rsv1))
+
+ if actual_rsv2 != rsv2:
+ raise Exception(
+ 'Unexpected rsv2: %r (expected) vs %r (actual)' %
+ (rsv2, actual_rsv2))
+
+ if actual_rsv3 != rsv3:
+ raise Exception(
+ 'Unexpected rsv3: %r (expected) vs %r (actual)' %
+ (rsv3, actual_rsv3))
+
+ received = receive_bytes(self._socket, payload_length)
+
+ if self._incoming_frame_filter is not None:
+ received = self._incoming_frame_filter.filter(received)
+
+ if len(received) != len(payload):
+ raise Exception(
+ 'Unexpected payload length: %d (expected) vs %d (actual)' %
+ (len(payload), len(received)))
+
+ if payload != received:
+ raise Exception(
+ 'Unexpected payload: %r (expected) vs %r (actual)' %
+ (payload, received))
+
+ def assert_receive_binary(self, payload, opcode=OPCODE_BINARY, fin=1,
+ rsv1=None, rsv2=None, rsv3=None):
+ self._assert_receive_data(payload, opcode, fin, rsv1, rsv2, rsv3)
+
+ def assert_receive_text(self, payload, opcode=OPCODE_TEXT, fin=1,
+ rsv1=None, rsv2=None, rsv3=None):
+ self._assert_receive_data(payload.encode('utf-8'), opcode, fin, rsv1,
+ rsv2, rsv3)
+
+ def _build_close_frame(self, code, reason, mask):
+ frame = chr(1 << 7 | OPCODE_CLOSE)
+
+ if code is not None:
+ body = struct.pack('!H', code) + reason.encode('utf-8')
+ else:
+ body = ''
+ if mask:
+ frame += chr(1 << 7 | len(body)) + self._mask_hybi(body)
+ else:
+ frame += chr(len(body)) + body
+ return frame
+
+ def send_close(self, code, reason):
+ self._socket.sendall(
+ self._build_close_frame(code, reason, True))
+
+ def assert_receive_close(self, code, reason):
+ expected_frame = self._build_close_frame(code, reason, False)
+ actual_frame = receive_bytes(self._socket, len(expected_frame))
+ if actual_frame != expected_frame:
+ raise Exception(
+ 'Unexpected close frame: %r (expected) vs %r (actual)' %
+ (expected_frame, actual_frame))
+
+
+class WebSocketStreamHixie75(object):
+ """Frame processor for the WebSocket protocol version Hixie 75 and HyBi 00.
+ """
+
+ _CLOSE_FRAME = '\xff\x00'
+
+ def __init__(self, socket, unused_handshake):
+ self._socket = socket
+
+ def send_frame_of_arbitrary_bytes(self, header, body):
+ self._socket.sendall(header + body)
+
+ def send_data(self, payload, unused_frame_typem, unused_end, unused_mask):
+ frame = ''.join(['\x00', payload, '\xff'])
+ self._socket.sendall(frame)
+
+ def send_binary(self, unused_payload, unused_end, unused_mask):
+ pass
+
+ def send_text(self, payload, unused_end, unused_mask):
+ encoded_payload = payload.encode('utf-8')
+ frame = ''.join(['\x00', encoded_payload, '\xff'])
+ self._socket.sendall(frame)
+
+ def assert_receive_binary(self, payload, opcode=OPCODE_BINARY, fin=1,
+ rsv1=0, rsv2=0, rsv3=0):
+ raise Exception('Binary frame is not supported in hixie75')
+
+ def assert_receive_text(self, payload):
+ received = receive_bytes(self._socket, 1)
+
+ if received != '\x00':
+ raise Exception(
+ 'Unexpected frame type: %d (expected) vs %d (actual)' %
+ (0, ord(received)))
+
+ received = receive_bytes(self._socket, len(payload) + 1)
+ if received[-1] != '\xff':
+ raise Exception(
+ 'Termination expected: 0xff (expected) vs %r (actual)' %
+ received)
+
+ if received[0:-1] != payload:
+ raise Exception(
+ 'Unexpected payload: %r (expected) vs %r (actual)' %
+ (payload, received[0:-1]))
+
+ def send_close(self, code, reason):
+ self._socket.sendall(self._CLOSE_FRAME)
+
+ def assert_receive_close(self, unused_code, unused_reason):
+ closing = receive_bytes(self._socket, len(self._CLOSE_FRAME))
+ if closing != self._CLOSE_FRAME:
+ raise Exception('Didn\'t receive closing handshake')
+
+
+class ClientOptions(object):
+ """Holds option values to configure the Client object."""
+
+ def __init__(self):
+ self.version = 13
+ self.server_host = ''
+ self.origin = ''
+ self.resource = ''
+ self.server_port = -1
+ self.socket_timeout = 1000
+ self.use_tls = False
+ self.extensions = []
+ # Enable deflate-application-data.
+ self.use_deflate_frame = False
+ # Enable mux
+ self.use_mux = False
+
+ def enable_deflate_frame(self):
+ self.use_deflate_frame = True
+ self.extensions.append(_DEFLATE_FRAME_EXTENSION)
+
+ def enable_mux(self):
+ self.use_mux = True
+ self.extensions.append(_MUX_EXTENSION)
+
+
+def connect_socket_with_retry(host, port, timeout, use_tls,
+ retry=10, sleep_sec=0.1):
+ retry_count = 0
+ while retry_count < retry:
+ try:
+ s = socket.socket()
+ s.settimeout(timeout)
+ s.connect((host, port))
+ if use_tls:
+ return _TLSSocket(s)
+ return s
+ except socket.error, e:
+ if e.errno != errno.ECONNREFUSED:
+ raise
+ else:
+ retry_count = retry_count + 1
+ time.sleep(sleep_sec)
+
+ return None
+
+
+class Client(object):
+ """WebSocket client."""
+
+ def __init__(self, options, handshake, stream_class):
+ self._logger = util.get_class_logger(self)
+
+ self._options = options
+ self._socket = None
+
+ self._handshake = handshake
+ self._stream_class = stream_class
+
+ def connect(self):
+ self._socket = connect_socket_with_retry(
+ self._options.server_host,
+ self._options.server_port,
+ self._options.socket_timeout,
+ self._options.use_tls)
+
+ self._handshake.handshake(self._socket)
+
+ self._stream = self._stream_class(self._socket, self._handshake)
+
+ self._logger.info('Connection established')
+
+ def send_frame_of_arbitrary_bytes(self, header, body):
+ self._stream.send_frame_of_arbitrary_bytes(header, body)
+
+ def send_message(self, message, end=True, binary=False, raw=False,
+ mask=True):
+ if binary:
+ self._stream.send_binary(message, end, mask)
+ elif raw:
+ self._stream.send_data(message, OPCODE_TEXT, end, mask)
+ else:
+ self._stream.send_text(message, end, mask)
+
+ def assert_receive(self, payload, binary=False):
+ if binary:
+ self._stream.assert_receive_binary(payload)
+ else:
+ self._stream.assert_receive_text(payload)
+
+ def send_close(self, code=STATUS_NORMAL_CLOSURE, reason=''):
+ self._stream.send_close(code, reason)
+
+ def assert_receive_close(self, code=STATUS_NORMAL_CLOSURE, reason=''):
+ self._stream.assert_receive_close(code, reason)
+
+ def close_socket(self):
+ self._socket.close()
+
+ def assert_connection_closed(self):
+ try:
+ read_data = receive_bytes(self._socket, 1)
+ except Exception, e:
+ if str(e).find(
+ 'Connection closed before receiving requested length ') == 0:
+ return
+ try:
+ error_number, message = e
+ for error_name in ['ECONNRESET', 'WSAECONNRESET']:
+ if (error_name in dir(errno) and
+ error_number == getattr(errno, error_name)):
+ return
+ except:
+ raise e
+ raise e
+
+ raise Exception('Connection is not closed (Read: %r)' % read_data)
+
+
+def create_client(options):
+ return Client(
+ options, WebSocketHandshake(options), WebSocketStream)
+
+
+def create_client_hybi00(options):
+ return Client(
+ options,
+ WebSocketHybi00Handshake(options, '0'),
+ WebSocketStreamHixie75)
+
+
+def create_client_hixie75(options):
+ return Client(
+ options, WebSocketHixie75Handshake(options), WebSocketStreamHixie75)
+
+
+# vi:sts=4 sw=4 et