import os
import asyncio
import pytest
import txaio
# because py.test tries to collect it as a test-case
from unittest.mock import Mock
from autobahn.asyncio.websocket import WebSocketServerFactory
async def echo_async(what, when):
    """Hand a value back after an asynchronous pause.

    :param what: Object to return unchanged.
    :param when: Delay in seconds, forwarded to ``asyncio.sleep`` before
        returning.
    :returns: ``what`` itself (the identical object, not a copy).
    """
    delay_seconds = when
    await asyncio.sleep(delay_seconds)
    return what
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
@pytest.mark.asyncio
async def test_echo_async():
    """``echo_async`` with a zero delay returns its first argument."""
    expected = 'Hello!'
    echoed = await echo_async(expected, 0)
    assert expected == echoed
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
def test_websocket_custom_loop(event_loop):
    """Build a server factory with an explicit ``loop`` and open a connection.

    NOTE(review): nothing is asserted here. The test passes as long as the
    factory construction, protocol creation and ``connection_made`` do not
    raise; it does not verify that ``event_loop`` is the loop actually used.
    """
    server_protocol = WebSocketServerFactory(loop=event_loop)()
    # A bare Mock stands in for the asyncio transport.
    server_protocol.connection_made(Mock())
@pytest.mark.skipif(not os.environ.get('USE_ASYNCIO', False), reason='test runs on asyncio only')
@pytest.mark.asyncio
async def test_async_on_connect_server(event_loop):
    """A coroutine ``onConnect`` handler is driven by the server handshake.

    A raw HTTP upgrade request is fed to a fresh server protocol. The handler
    suspends, records ``num * num`` and resolves ``done``. The test then checks
    that the handler ran exactly once with the expected value.
    """
    # NOTE(review): ``event_loop`` is not referenced in the body. It is kept
    # because it is part of this test's pytest fixture request.
    num = 42
    done = txaio.create_future()
    values = []

    async def foo(x):
        # Yield to the event loop so the handler genuinely suspends. A zero
        # delay is enough for that and avoids a needless one-second stall.
        await asyncio.sleep(0)
        return x * x

    async def on_connect(req):
        v = await foo(num)
        values.append(v)
        txaio.resolve(done, req)

    factory = WebSocketServerFactory()
    server = factory()
    server.onConnect = on_connect
    transport = Mock()

    server.connection_made(transport)
    server.data = b'\r\n'.join([
        b'GET /ws HTTP/1.1',
        b'Host: www.example.com',
        b'Sec-WebSocket-Version: 13',
        b'Origin: http://www.example.com.malicious.com',
        b'Sec-WebSocket-Extensions: permessage-deflate',
        b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
        b'Connection: keep-alive, Upgrade',
        b'Upgrade: websocket',
        # join() puts a separator before this element but none after it.
        # Together they yield the "\r\n\r\n" that terminates the header block.
        b'\r\n',
    ])
    server.processHandshake()

    # Bound the wait so a handshake that never reaches onConnect fails the
    # test instead of hanging the entire test run.
    await asyncio.wait_for(done, timeout=10)

    assert values == [num * num]