Mini Shell
import asyncio
import functools
import traceback
import unittest
from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.locks import Event
from tornado.log import gen_log, app_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.web import Application, RequestHandler
try:
import tornado.websocket # noqa: F401
from tornado.util import _websocket_mask_python
except ImportError:
# The unittest module presents misleading errors on ImportError
# (it acts as if websocket_test could not be found, hiding the underlying
# error). If we get an ImportError here (which could happen due to
# TORNADO_EXTENSION=1), print some extra information before failing.
traceback.print_exc()
raise
from tornado.websocket import (
WebSocketHandler,
websocket_connect,
WebSocketError,
WebSocketClosedError,
)
try:
from tornado import speedups
except ImportError:
speedups = None # type: ignore
class TestWebSocketHandler(WebSocketHandler):
    """Common base for the websocket handlers used in these tests.

    If a ``close_future`` is supplied, ``on_close`` resolves it with a
    ``(close_code, close_reason)`` tuple so a test can see how the server
    side observed the end of the connection.  ``compression_options`` is
    returned unchanged from ``get_compression_options``.
    """

    def initialize(self, close_future=None, compression_options=None):
        # Both values arrive via the per-route kwargs dict in the Application.
        self.compression_options = compression_options
        self.close_future = close_future

    def get_compression_options(self):
        return self.compression_options

    def on_close(self):
        future = self.close_future
        if future is None:
            return
        future.set_result((self.close_code, self.close_reason))
class EchoHandler(TestWebSocketHandler):
    """Send every incoming message back to the client unchanged.

    Binary (bytes) messages are echoed as binary, text messages as text.
    """

    @gen.coroutine
    def on_message(self, message):
        as_binary = isinstance(message, bytes)
        try:
            yield self.write_message(message, as_binary)
        except (asyncio.CancelledError, WebSocketClosedError):
            # The echo may be cancelled, or the peer may already be gone
            # (e.g. the client closed right after sending). Ignore both.
            pass
class ErrorInOnMessageHandler(TestWebSocketHandler):
    """Handler whose ``on_message`` always raises ZeroDivisionError."""

    def on_message(self, message):
        # Deliberate failure. test_error_in_on_message expects an
        # "Uncaught exception" log entry and a closed connection.
        zero = 0
        return 1 / zero
class HeaderHandler(TestWebSocketHandler):
    """Check that plain-HTTP response methods are unusable after the upgrade.

    Each call listed in ``open`` must raise RuntimeError. If any call
    succeeds, an Exception is raised. Otherwise the handler sends the
    request's ``X-Test`` header value (or "") to the client.
    """

    def open(self):
        forbidden_calls = (
            functools.partial(self.write, "This should not work"),
            functools.partial(self.redirect, "http://localhost/elsewhere"),
            functools.partial(self.set_header, "X-Test", ""),
            functools.partial(self.set_cookie, "Chocolate", "Chip"),
            functools.partial(self.set_status, 503),
            self.flush,
            self.finish,
        )
        for call in forbidden_calls:
            try:
                # In a websocket context, many RequestHandler methods
                # raise RuntimeErrors.
                call()
            except RuntimeError:
                continue
            raise Exception("did not get expected exception")
        self.write_message(self.request.headers.get("X-Test", ""))
class HeaderEchoHandler(TestWebSocketHandler):
    """Copy ``X-Test*`` request headers into the handshake response.

    Every response also carries a fixed ``X-Extra-Response-Header``.
    """

    def set_default_headers(self):
        self.set_header("X-Extra-Response-Header", "Extra-Response-Value")

    def prepare(self):
        headers = self.request.headers
        for name, value in headers.get_all():
            # The name prefix is matched case-insensitively.
            if not name.lower().startswith("x-test"):
                continue
            self.set_header(name, value)
class NonWebSocketHandler(RequestHandler):
    """Ordinary HTTP endpoint: answers GET with the body ``ok``."""

    def get(self):
        body = "ok"
        self.write(body)
class CloseReasonHandler(TestWebSocketHandler):
    """Close from the server side, with code 1001 and reason "goodbye", on open."""

    def open(self):
        # NOTE(review): on_close_called is not read anywhere in this file.
        self.on_close_called = False
        code, reason = 1001, "goodbye"
        self.close(code, reason)
class AsyncPrepareHandler(TestWebSocketHandler):
    """Echo handler whose ``prepare`` is a coroutine that yields once."""

    @gen.coroutine
    def prepare(self):
        # One pass through the event loop makes prepare() truly asynchronous.
        yield gen.moment

    def on_message(self, message):
        # Echo without waiting for the write to finish.
        self.write_message(message)
class PathArgsHandler(TestWebSocketHandler):
    """On open, send the captured URL path group to the client."""

    def open(self, arg):
        captured = arg
        self.write_message(captured)
class CoroutineOnMessageHandler(TestWebSocketHandler):
    """Echo handler with a generator-based ``on_message`` that sleeps first.

    ``sleeping`` counts ``on_message`` calls that are currently paused. If
    a message arrives while one is paused, a warning message is sent, so
    the test can detect concurrent execution.
    """

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        self.sleeping = 0

    @gen.coroutine
    def on_message(self, message):
        already_busy = self.sleeping > 0
        if already_busy:
            self.write_message("another coroutine is already sleeping")
        self.sleeping += 1
        yield gen.sleep(0.01)
        self.sleeping -= 1
        self.write_message(message)
class RenderMessageHandler(TestWebSocketHandler):
    """Reply with each message rendered through the ``message.html`` template."""

    def on_message(self, message):
        rendered = self.render_string("message.html", message=message)
        self.write_message(rendered)
class SubprotocolHandler(TestWebSocketHandler):
    """Accept only the ``goodproto`` subprotocol and report which was chosen.

    Raises if ``select_subprotocol`` runs twice, or if ``open`` runs
    without it having been called.
    """

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        self.select_subprotocol_called = False

    def select_subprotocol(self, subprotocols):
        if self.select_subprotocol_called:
            raise Exception("select_subprotocol called twice")
        self.select_subprotocol_called = True
        return "goodproto" if "goodproto" in subprotocols else None

    def open(self):
        if not self.select_subprotocol_called:
            raise Exception("select_subprotocol not called")
        self.write_message("subprotocol=%s" % self.selected_subprotocol)
class OpenCoroutineHandler(TestWebSocketHandler):
    """Handler with a slow coroutine ``open``, used to check message ordering.

    ``open`` waits for the owning test's ``message_sent`` event and then
    sleeps briefly. ``on_message`` raises if it runs before ``open`` has
    finished; otherwise it replies "ok".
    """

    def initialize(self, test, **kwargs):
        super().initialize(**kwargs)
        self.open_finished = False
        self.test = test

    @gen.coroutine
    def open(self):
        # Stay inside open() until the client has already sent its message,
        # so a premature on_message dispatch would be detected.
        yield self.test.message_sent.wait()
        yield gen.sleep(0.010)
        self.open_finished = True

    def on_message(self, message):
        if self.open_finished:
            self.write_message("ok")
            return
        raise Exception("on_message called before open finished")
class ErrorInOpenHandler(TestWebSocketHandler):
    """Handler whose synchronous ``open`` always raises."""

    def open(self):
        failure = Exception("boom")
        raise failure
class ErrorInAsyncOpenHandler(TestWebSocketHandler):
    """Handler whose native-coroutine ``open`` raises after one loop turn."""

    async def open(self):
        # Suspend once first, so the failure happens after the coroutine
        # has already yielded control.
        await asyncio.sleep(0)
        raise Exception("boom")
class NoDelayHandler(TestWebSocketHandler):
    """Call ``set_nodelay(True)`` on open, then send "hello"."""

    def open(self):
        greeting = "hello"
        self.set_nodelay(True)
        self.write_message(greeting)
class WebSocketBaseTestCase(AsyncHTTPTestCase):
    """AsyncHTTPTestCase with a helper that opens websocket connections."""

    @gen.coroutine
    def ws_connect(self, path, **kwargs):
        """Open a websocket to ``path`` on this test's server at 127.0.0.1.

        Extra keyword arguments are passed through to ``websocket_connect``.
        The coroutine resolves to the client connection.
        """
        url = "ws://127.0.0.1:%d%s" % (self.get_http_port(), path)
        conn = yield websocket_connect(url, **kwargs)
        return conn
class WebSocketTest(WebSocketBaseTestCase):
    """End-to-end tests of WebSocketHandler and websocket_connect.

    ``get_app`` mounts each test handler defined above on its own route.
    Most routes share one ``close_future``, which TestWebSocketHandler
    resolves with ``(close_code, close_reason)`` when the server side of a
    connection closes.
    """

    def get_app(self):
        # TestWebSocketHandler.on_close resolves this with a
        # (close_code, close_reason) tuple, not None.
        self.close_future = Future()  # type: Future[tuple]
        return Application(
            [
                ("/echo", EchoHandler, dict(close_future=self.close_future)),
                ("/non_ws", NonWebSocketHandler),
                ("/header", HeaderHandler, dict(close_future=self.close_future)),
                (
                    "/header_echo",
                    HeaderEchoHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/close_reason",
                    CloseReasonHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/error_in_on_message",
                    ErrorInOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/async_prepare",
                    AsyncPrepareHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/path_args/(.*)",
                    PathArgsHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/coroutine",
                    CoroutineOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                ("/render", RenderMessageHandler, dict(close_future=self.close_future)),
                (
                    "/subprotocol",
                    SubprotocolHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/open_coroutine",
                    OpenCoroutineHandler,
                    dict(close_future=self.close_future, test=self),
                ),
                # The following routes are not given a close_future.
                ("/error_in_open", ErrorInOpenHandler),
                ("/error_in_async_open", ErrorInAsyncOpenHandler),
                ("/nodelay", NoDelayHandler),
            ],
            # Template used by RenderMessageHandler on /render.
            template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}),
        )

    def get_http_client(self):
        # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
        return SimpleAsyncHTTPClient()

    def tearDown(self):
        super(WebSocketTest, self).tearDown()
        # NOTE(review): _template_loaders appears to be a class-level cache on
        # RequestHandler; clearing it stops the DictLoader above from leaking
        # into other tests.
        RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client.
        response = self.fetch("/echo")
        self.assertEqual(response.code, 400)

    def test_missing_websocket_key(self):
        # An upgrade request that has no Sec-WebSocket-Key header is rejected.
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "13",
            },
        )
        self.assertEqual(response.code, 400)

    def test_bad_websocket_version(self):
        # An unsupported protocol version gets 426 (Upgrade Required).
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "12",
            },
        )
        self.assertEqual(response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        # Basic coroutine-style echo round trip.
        ws = yield self.ws_connect("/echo")
        yield ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    def test_websocket_callbacks(self):
        # The same round trip using the callback API with stop()/wait().
        websocket_connect(
            "ws://127.0.0.1:%d/echo" % self.get_http_port(), callback=self.stop
        )
        ws = self.wait().result()
        ws.write_message("hello")
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, "hello")
        # Wait until the server side has seen the close as well.
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        # Binary frames come back as bytes.
        ws = yield self.ws_connect("/echo")
        ws.write_message(b"hello \xe9", binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b"hello \xe9")

    @gen_test
    def test_unicode_message(self):
        # Non-ASCII text survives the round trip as str.
        ws = yield self.ws_connect("/echo")
        ws.write_message(u"hello \u00e9")
        response = yield ws.read_message()
        self.assertEqual(response, u"hello \u00e9")

    @gen_test
    def test_render_message(self):
        # render_string works inside a websocket handler.
        ws = yield self.ws_connect("/render")
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "<b>hello</b>")

    @gen_test
    def test_error_in_on_message(self):
        # An exception in on_message is logged, and the connection ends
        # (read_message yields None).
        ws = yield self.ws_connect("/error_in_on_message")
        ws.write_message("hello")
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)

    @gen_test
    def test_websocket_http_fail(self):
        # A missing route surfaces as an HTTPError carrying the HTTP status.
        with self.assertRaises(HTTPError) as cm:
            yield self.ws_connect("/notfound")
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        # A plain HTTP 200 response (no upgrade) is a WebSocketError.
        with self.assertRaises(WebSocketError):
            yield self.ws_connect("/non_ws")

    @gen_test
    def test_websocket_network_fail(self):
        # Bind and immediately release a port, so nothing is listening on it.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    "ws://127.0.0.1:%d/" % port, connect_timeout=3600
                )

    @gen_test
    def test_websocket_close_buffered_data(self):
        # No assertions: the test passes as long as closing the underlying
        # stream with writes still pending does not raise.
        ws = yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port())
        ws.write_message("hello")
        ws.write_message("world")
        # Close the underlying stream.
        ws.stream.close()

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        ws = yield websocket_connect(
            HTTPRequest(
                "ws://127.0.0.1:%d/header" % self.get_http_port(),
                headers={"X-Test": "hello"},
            )
        )
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that headers can be returned in the response.
        # Specifically, that arbitrary headers passed through websocket_connect
        # can be returned.
        ws = yield websocket_connect(
            HTTPRequest(
                "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
                headers={"X-Test-Hello": "hello"},
            )
        )
        self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
        self.assertEqual(
            ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
        )

    @gen_test
    def test_server_close_reason(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")
        # The on_close callback is called no matter which side closed.
        code, reason = yield self.close_future
        # The client echoed the close code it received to the server,
        # so the server's close code (returned via close_future) is
        # the same.
        self.assertEqual(code, 1001)

    @gen_test
    def test_client_close_reason(self):
        # The client's close code and reason are visible on the server side.
        ws = yield self.ws_connect("/echo")
        ws.close(1001, "goodbye")
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, "goodbye")

    @gen_test
    def test_write_after_close(self):
        # Writing after the server has closed the connection raises.
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        self.assertIs(msg, None)
        with self.assertRaises(WebSocketClosedError):
            ws.write_message("hello")

    @gen_test
    def test_async_prepare(self):
        # Previously, an async prepare method triggered a bug that would
        # result in a timeout on test shutdown (and a memory leak).
        ws = yield self.ws_connect("/async_prepare")
        ws.write_message("hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_path_args(self):
        # URL capture groups are passed through to open().
        ws = yield self.ws_connect("/path_args/hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_coroutine(self):
        ws = yield self.ws_connect("/coroutine")
        # Send both messages immediately, coroutine must process one at a time.
        yield ws.write_message("hello1")
        yield ws.write_message("hello2")
        res = yield ws.read_message()
        self.assertEqual(res, "hello1")
        res = yield ws.read_message()
        self.assertEqual(res, "hello2")

    @gen_test
    def test_check_origin_valid_no_path(self):
        # An Origin matching the request's host and port is accepted.
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d" % port}
        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_valid_with_path(self):
        # A path component in the Origin header does not affect the check.
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d/something" % port}
        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        # An Origin with no scheme is rejected with 403.
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "127.0.0.1:%d" % port}
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        # Host is 127.0.0.1, which should not be accessible from some other
        # domain
        headers = {"Origin": "http://somewhereelse.com"}
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()
        url = "ws://localhost:%d/echo" % port
        # Subdomains should be disallowed by default. If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {"Origin": "http://subtenant.localhost"}
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_subprotocols(self):
        # The server picks "goodproto" from the offered list.
        ws = yield self.ws_connect(
            "/subprotocol", subprotocols=["badproto", "goodproto"]
        )
        self.assertEqual(ws.selected_subprotocol, "goodproto")
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=goodproto")

    @gen_test
    def test_subprotocols_not_offered(self):
        # When no subprotocols are offered, none is selected.
        ws = yield self.ws_connect("/subprotocol")
        self.assertIs(ws.selected_subprotocol, None)
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=None")

    @gen_test
    def test_open_coroutine(self):
        # The handler's open() blocks on this event, so the message below is
        # sent while open() is still running. on_message must not run until
        # open() completes.
        self.message_sent = Event()
        ws = yield self.ws_connect("/open_coroutine")
        yield ws.write_message("hello")
        self.message_sent.set()
        res = yield ws.read_message()
        self.assertEqual(res, "ok")

    @gen_test
    def test_error_in_open(self):
        # An exception in a synchronous open() is logged, and the connection
        # ends.
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_open")
            res = yield ws.read_message()
        self.assertIsNone(res)

    @gen_test
    def test_error_in_async_open(self):
        # Same as above, but the exception comes from a native-coroutine open().
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_async_open")
            res = yield ws.read_message()
        self.assertIsNone(res)

    @gen_test
    def test_nodelay(self):
        # Calling set_nodelay in open() does not break message delivery.
        ws = yield self.ws_connect("/nodelay")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")
class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
    """``async def`` version of the sleeping echo handler.

    Replies to each message after a short sleep. If a second message is
    handled while an earlier one is still sleeping, it sends a warning
    message instead.
    """

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        self.sleeping = 0

    async def on_message(self, message):
        busy = self.sleeping > 0
        if busy:
            self.write_message("another coroutine is already sleeping")
        self.sleeping += 1
        await gen.sleep(0.01)
        self.sleeping -= 1
        self.write_message(message)
class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
    """Check that ``async def on_message`` calls are processed one at a time."""

    def get_app(self):
        routes = [("/native", NativeCoroutineOnMessageHandler)]
        return Application(routes)

    @gen_test
    def test_native_coroutine(self):
        ws = yield self.ws_connect("/native")
        # Both messages go out before any reply is read. If the handler ran
        # them concurrently, the first reply would be the "already sleeping"
        # warning instead of "hello1".
        for outgoing in ("hello1", "hello2"):
            yield ws.write_message(outgoing)
        for expected in ("hello1", "hello2"):
            reply = yield ws.read_message()
            self.assertEqual(reply, expected)
class CompressionTestMixin(object):
    """Shared compression tests; subclasses choose each side's options.

    Subclasses may override ``get_server_compression_options`` and
    ``get_client_compression_options`` (both return None here). They must
    also provide ``verify_wire_bytes(bytes_in, bytes_out)`` to check the
    byte counts seen on the wire.
    """

    MESSAGE = "Hello world. Testing 123 123"

    def get_app(self):
        class LimitedHandler(TestWebSocketHandler):
            """Reply with each message's length; limit messages to 1024."""

            @property
            def max_message_size(self):
                return 1024

            def on_message(self, message):
                self.write_message(str(len(message)))

        def handler_kwargs():
            # Called once per route, so each handler gets its own options object.
            return dict(compression_options=self.get_server_compression_options())

        return Application(
            [
                ("/echo", EchoHandler, handler_kwargs()),
                ("/limited", LimitedHandler, handler_kwargs()),
            ]
        )

    def get_server_compression_options(self):
        return None

    def get_client_compression_options(self):
        return None

    @gen_test
    def test_message_sizes(self):
        ws = yield self.ws_connect(
            "/echo", compression_options=self.get_client_compression_options()
        )
        # Three identical round trips, so subclasses can measure the effect
        # of the context_takeover options.
        rounds = 3
        for _ in range(rounds):
            ws.write_message(self.MESSAGE)
            reply = yield ws.read_message()
            self.assertEqual(reply, self.MESSAGE)
        protocol = ws.protocol
        expected_payload = len(self.MESSAGE) * rounds
        self.assertEqual(protocol._message_bytes_out, expected_payload)
        self.assertEqual(protocol._message_bytes_in, expected_payload)
        self.verify_wire_bytes(protocol._wire_bytes_in, protocol._wire_bytes_out)

    @gen_test
    def test_size_limit(self):
        ws = yield self.ws_connect(
            "/limited", compression_options=self.get_client_compression_options()
        )
        # A message well under the limit is delivered and measured.
        ws.write_message("a" * 128)
        reply = yield ws.read_message()
        self.assertEqual(reply, "128")
        # 2048 characters exceeds the limit once decompressed, even though
        # the compressed form is small enough to pass the initial checks.
        # The connection is expected to close (read_message yields None).
        ws.write_message("a" * 2048)
        reply = yield ws.read_message()
        self.assertIsNone(reply)
class UncompressedTestMixin(CompressionTestMixin):
    """CompressionTestMixin variant for cases where no compression happens."""

    def verify_wire_bytes(self, bytes_in, bytes_out):
        payload = len(self.MESSAGE)
        # Three frames each way. Every frame adds 2 bytes of overhead, and
        # outgoing frames also carry a 4-byte mask key.
        self.assertEqual(bytes_out, 3 * (payload + 6))
        self.assertEqual(bytes_in, 3 * (payload + 2))
class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    """Neither side supplies compression options."""
# If only one side tries to compress, the extension is not negotiated.
class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    """Only the server supplies (empty) compression options."""

    def get_server_compression_options(self):
        options = {}
        return options
class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    """Only the client supplies (empty) compression options."""

    def get_client_compression_options(self):
        options = {}
        return options
class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
    """Both sides supply empty (default) compression options."""

    def get_server_compression_options(self):
        return {}

    def get_client_compression_options(self):
        return {}

    def verify_wire_bytes(self, bytes_in, bytes_out):
        payload = len(self.MESSAGE)
        # Compressed traffic must be smaller than uncompressed framing.
        self.assertLess(bytes_out, 3 * (payload + 6))
        self.assertLess(bytes_in, 3 * (payload + 2))
        # Outgoing frames carry a 4-byte mask key each: 12 bytes over three
        # messages.
        self.assertEqual(bytes_out, bytes_in + 12)
class MaskFunctionMixin(object):
    """Known-answer tests for a websocket masking function.

    Subclasses define ``mask(mask, data)``. Each expected value below is
    ``data`` XORed with ``mask`` repeated cyclically.
    """

    def test_mask(self):
        cases = [
            (b"abcd", b"", b""),
            (b"abcd", b"b", b"\x03"),
            (b"abcd", b"54321", b"TVPVP"),
            (b"ZXCV", b"98765432", b"c`t`olpd"),
            # Cases with \x00 bytes (so the C extension can't depend on
            # null-terminated strings) and bytes with the high bit set
            # (to smoke out signedness issues).
            (
                b"\x00\x01\x02\x03",
                b"\xff\xfb\xfd\xfc\xfe\xfa",
                b"\xff\xfa\xff\xff\xfe\xfb",
            ),
            (
                b"\xff\xfb\xfd\xfc",
                b"\x00\x01\x02\x03\x04\x05",
                b"\xff\xfa\xff\xff\xfb\xfe",
            ),
        ]
        for key, data, expected in cases:
            self.assertEqual(self.mask(key, data), expected)
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Run the masking known-answer tests against the pure-Python version."""

    def mask(self, mask, data):
        masked = _websocket_mask_python(mask, data)
        return masked
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Run the masking known-answer tests against ``tornado.speedups``.

    Skipped when the speedups module could not be imported.
    """

    def mask(self, mask, data):
        masked = speedups.websocket_mask(mask, data)
        return masked
class ServerPeriodicPingTest(WebSocketBaseTestCase):
    """Server-initiated periodic pings (``websocket_ping_interval=0.01``)."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            # Tell the client about every pong the server receives.
            def on_pong(self, data):
                self.write_message("got pong")

        return Application([("/", PingHandler)], websocket_ping_interval=0.01)

    @gen_test
    def test_server_ping(self):
        ws = yield self.ws_connect("/")
        for _ in range(3):
            reply = yield ws.read_message()
            self.assertEqual(reply, "got pong")
        # TODO: test that the connection gets closed if ping responses stop.
class ClientPeriodicPingTest(WebSocketBaseTestCase):
    """Client-initiated periodic pings (``ping_interval=0.01``)."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            # Tell the client about every ping the server receives.
            def on_ping(self, data):
                self.write_message("got ping")

        return Application([("/", PingHandler)])

    @gen_test
    def test_client_ping(self):
        ws = yield self.ws_connect("/", ping_interval=0.01)
        for _ in range(3):
            reply = yield ws.read_message()
            self.assertEqual(reply, "got ping")
        # TODO: test that the connection gets closed if ping responses stop.
class ManualPingTest(WebSocketBaseTestCase):
    """Explicit client pings; the server echoes each ping payload back."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_ping(self, data):
                self.write_message(data, binary=isinstance(data, bytes))

        return Application([("/", PingHandler)])

    @gen_test
    def test_manual_ping(self):
        ws = yield self.ws_connect("/")
        # A 126-character payload is rejected before anything is sent.
        with self.assertRaises(ValueError):
            ws.ping("a" * 126)
        # on_ping always sees bytes, so a str ping also comes back as bytes.
        pings = (("hello", b"hello"), (b"binary hello", b"binary hello"))
        for payload, expected in pings:
            ws.ping(payload)
            reply = yield ws.read_message()
            self.assertEqual(reply, expected)
class MaxMessageSizeTest(WebSocketBaseTestCase):
    """Enforcement of the ``websocket_max_message_size`` application setting."""

    def get_app(self):
        return Application([("/", EchoHandler)], websocket_max_message_size=1024)

    @gen_test
    def test_large_message(self):
        ws = yield self.ws_connect("/")

        # Exactly at the limit: accepted and echoed.
        payload = "a" * 1024
        ws.write_message(payload)
        reply = yield ws.read_message()
        self.assertEqual(reply, payload)

        # One character over: the server closes the connection with 1009.
        ws.write_message(payload + "b")
        reply = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(reply, None)
        self.assertEqual(ws.close_code, 1009)
        self.assertEqual(ws.close_reason, "message too big")
        # TODO: Needs tests of messages split over multiple
        # continuation frames.
Zerion Mini Shell 1.0