From 280aa0f00c0ca3d099c2a693f6c2ce7919d2dc86 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sun, 24 Nov 2019 09:00:22 +0000 Subject: [PATCH] Use aiohttp's WebSocket client (Fixes https://github.com/miguelgrinberg/python-socketio/issues/324) --- engineio/asyncio_client.py | 53 +++---- examples/client/asyncio/latency_client.py | 4 +- examples/client/asyncio/simple_client.py | 4 +- setup.py | 1 - tests/asyncio/test_asyncio_client.py | 166 +++++++++++----------- tox.ini | 1 - 6 files changed, 111 insertions(+), 118 deletions(-) diff --git a/engineio/asyncio_client.py b/engineio/asyncio_client.py index b187bf9c..5a65fb8d 100644 --- a/engineio/asyncio_client.py +++ b/engineio/asyncio_client.py @@ -6,10 +6,6 @@ except ImportError: # pragma: no cover aiohttp = None import six -try: - import websockets -except ImportError: # pragma: no cover - websockets = None from . import client from . import exceptions @@ -227,8 +223,8 @@ async def _connect_polling(self, url, headers, engineio_path): async def _connect_websocket(self, url, headers, engineio_path): """Establish or upgrade to a WebSocket connection with the server.""" - if websockets is None: # pragma: no cover - self.logger.error('websockets package not installed') + if aiohttp is None: # pragma: no cover + self.logger.error('aiohttp package not installed') return False websocket_url = self._get_engineio_url(url, engineio_path, 'websocket') @@ -243,30 +239,20 @@ async def _connect_websocket(self, url, headers, engineio_path): self.logger.info( 'Attempting WebSocket connection to ' + websocket_url) - # get the cookies from the long-polling connection so that they can - # also be sent the the WebSocket route - cookies = None - if self.http: - cookies = '; '.join(["{}={}".format(cookie.key, cookie.value) - for cookie in self.http._cookie_jar]) - headers = headers.copy() - headers['Cookie'] = cookies - try: if not self.ssl_verify: ssl_context = ssl.create_default_context() ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE - ws = await websockets.connect( + ws = await self.http.ws_connect( websocket_url + self._get_url_timestamp(), - extra_headers=headers, ssl=ssl_context) + headers=headers, ssl=ssl_context) else: - ws = await websockets.connect( + ws = await self.http.ws_connect( websocket_url + self._get_url_timestamp(), - extra_headers=headers) - except (websockets.exceptions.InvalidURI, - websockets.exceptions.InvalidHandshake, - OSError): + headers=headers) + except (aiohttp.client_exceptions.WSServerHandshakeError, + aiohttp.client_exceptions.ServerConnectionError): if upgrade: self.logger.warning( 'WebSocket upgrade failed: connection error') @@ -277,14 +263,14 @@ async def _connect_websocket(self, url, headers, engineio_path): p = packet.Packet(packet.PING, data='probe').encode( always_bytes=False) try: - await ws.send(p) + await ws.send_str(p) except Exception as e: # pragma: no cover self.logger.warning( 'WebSocket upgrade failed: unexpected send exception: %s', str(e)) return False try: - p = await ws.recv() + p = (await ws.receive()).data except Exception as e: # pragma: no cover self.logger.warning( 'WebSocket upgrade failed: unexpected recv exception: %s', @@ -297,19 +283,17 @@ async def _connect_websocket(self, url, headers, engineio_path): return False p = packet.Packet(packet.UPGRADE).encode(always_bytes=False) try: - await ws.send(p) + await ws.send_str(p) except Exception as e: # pragma: no cover self.logger.warning( 'WebSocket upgrade failed: unexpected send exception: %s', str(e)) return False self.current_transport = 'websocket' - if self.http: # pragma: no cover - await self.http.close() self.logger.info('WebSocket upgrade was successful') else: try: - p = await ws.recv() + p = (await ws.receive()).data except Exception as e: # pragma: no cover raise exceptions.ConnectionError( 'Unexpected recv exception: ' + str(e)) @@ -495,8 +479,8 @@ async def _read_loop_websocket(self): while self.state == 'connected': p = None try: - p = await self.ws.recv() - except websockets.exceptions.ConnectionClosed: + p = (await self.ws.receive()).data + except aiohttp.client_exceptions.ServerDisconnectedError: self.logger.info( 'Read loop: WebSocket connection was closed, aborting') await self.queue.put(None) @@ -579,9 +563,14 @@ async def _write_loop(self): # websocket try: for pkt in packets: - await self.ws.send(pkt.encode(always_bytes=False)) + if pkt.binary: + await self.ws.send_bytes(pkt.encode( + always_bytes=False)) + else: + await self.ws.send_str(pkt.encode( + always_bytes=False)) self.queue.task_done() - except websockets.exceptions.ConnectionClosed: + except aiohttp.client_exceptions.ServerDisconnectedError: self.logger.info( 'Write loop: WebSocket connection was closed, ' 'aborting') diff --git a/examples/client/asyncio/latency_client.py b/examples/client/asyncio/latency_client.py index 4eb8c83d..1a28e4d8 100644 --- a/examples/client/asyncio/latency_client.py +++ b/examples/client/asyncio/latency_client.py @@ -28,10 +28,10 @@ async def on_message(data): await send_ping() -async def start_server(): +async def start_client(): await eio.connect('http://localhost:5000') await eio.wait() if __name__ == '__main__': - loop.run_until_complete(start_server()) + loop.run_until_complete(start_client()) diff --git a/examples/client/asyncio/simple_client.py b/examples/client/asyncio/simple_client.py index 13149a96..4197f8d1 100644 --- a/examples/client/asyncio/simple_client.py +++ b/examples/client/asyncio/simple_client.py @@ -35,11 +35,11 @@ def signal_handler(sig, frame): print('exiting') -async def start_server(): +async def start_client(): await eio.connect('http://localhost:5000') await eio.wait() if __name__ == '__main__': signal.signal(signal.SIGINT, signal_handler) - loop.run_until_complete(start_server()) + loop.run_until_complete(start_client()) diff --git a/setup.py b/setup.py index b0885d18..c77399c2 100755 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ ], 'asyncio_client': [ 'aiohttp>=3.4', - 'websockets>=7.0', ] }, tests_require=[ diff --git a/tests/asyncio/test_asyncio_client.py b/tests/asyncio/test_asyncio_client.py index 4be3914e..236e90d4 100644 --- a/tests/asyncio/test_asyncio_client.py +++ b/tests/asyncio/test_asyncio_client.py @@ -10,14 +10,9 @@ else: import mock try: - import websockets + import aiohttp except ImportError: - # weirdness to avoid errors in PY2 test run - class _dummy(): - pass - websockets = _dummy() - websockets.exceptions = _dummy() - websockets.exceptions.InvalidURI = _dummy() + aiohttp = None from engineio import asyncio_client from engineio import client @@ -458,49 +453,55 @@ def test_polling_connection_not_upgraded(self): self.assertIn(c, client.connected_clients) @mock.patch('engineio.client.time.time', return_value=123.456) - @mock.patch('engineio.asyncio_client.websockets.connect', new=AsyncMock( - side_effect=[websockets.exceptions.InvalidURI('foo')])) def test_websocket_connection_failed(self, _time): c = asyncio_client.AsyncClient() + c.http = mock.MagicMock(closed=False) + c.http.ws_connect = AsyncMock(side_effect=[ + aiohttp.client_exceptions.ServerConnectionError()]) self.assertRaises( exceptions.ConnectionError, _run, c.connect('http://foo', transports=['websocket'], headers={'Foo': 'Bar'})) - asyncio_client.websockets.connect.mock.assert_called_once_with( + c.http.ws_connect.mock.assert_called_once_with( 'ws://foo/engine.io/?transport=websocket&EIO=3&t=123.456', - extra_headers={'Foo': 'Bar'}) + headers={'Foo': 'Bar'}) @mock.patch('engineio.client.time.time', return_value=123.456) - @mock.patch('engineio.asyncio_client.websockets.connect', new=AsyncMock( - side_effect=[websockets.exceptions.InvalidURI('foo')])) def test_websocket_upgrade_failed(self, _time): c = asyncio_client.AsyncClient() + c.http = mock.MagicMock(closed=False) + c.http.ws_connect = AsyncMock(side_effect=[ + aiohttp.client_exceptions.ServerConnectionError()]) c.sid = '123' self.assertFalse(_run(c.connect( 'http://foo', transports=['websocket']))) - asyncio_client.websockets.connect.mock.assert_called_once_with( + c.http.ws_connect.mock.assert_called_once_with( 'ws://foo/engine.io/?transport=websocket&EIO=3&sid=123&t=123.456', - extra_headers={}) + headers={}) - @mock.patch('engineio.asyncio_client.websockets.connect', new=AsyncMock()) def test_websocket_connection_no_open_packet(self): - asyncio_client.websockets.connect.mock.return_value.recv = AsyncMock( - return_value=packet.Packet(packet.CLOSE).encode()) c = asyncio_client.AsyncClient() + c.http = mock.MagicMock(closed=False) + c.http.ws_connect = AsyncMock() + ws = c.http.ws_connect.mock.return_value + ws.receive = AsyncMock() + ws.receive.mock.return_value.data = packet.Packet( + packet.CLOSE).encode() self.assertRaises( exceptions.ConnectionError, _run, c.connect('http://foo', transports=['websocket'])) @mock.patch('engineio.client.time.time', return_value=123.456) - @mock.patch('engineio.asyncio_client.websockets.connect', new=AsyncMock()) def test_websocket_connection_successful(self, _time): - ws = asyncio_client.websockets.connect.mock.return_value - ws.recv = AsyncMock(return_value=packet.Packet( - packet.OPEN, { - 'sid': '123', 'upgrades': [], 'pingInterval': 1000, - 'pingTimeout': 2000 - }).encode()) c = asyncio_client.AsyncClient() + c.http = mock.MagicMock(closed=False) + c.http.ws_connect = AsyncMock() + ws = c.http.ws_connect.mock.return_value + ws.receive = AsyncMock() + ws.receive.mock.return_value.data = packet.Packet(packet.OPEN, { + 'sid': '123', 'upgrades': [], 'pingInterval': 1000, + 'pingTimeout': 2000 + }).encode() c._ping_loop = AsyncMock() c._read_loop_polling = AsyncMock() c._read_loop_websocket = AsyncMock() @@ -525,20 +526,21 @@ def test_websocket_connection_successful(self, _time): self.assertEqual(c.upgrades, []) self.assertEqual(c.transport(), 'websocket') self.assertEqual(c.ws, ws) - asyncio_client.websockets.connect.mock.assert_called_once_with( + c.http.ws_connect.mock.assert_called_once_with( 'ws://foo/engine.io/?transport=websocket&EIO=3&t=123.456', - extra_headers={}) + headers={}) @mock.patch('engineio.client.time.time', return_value=123.456) - @mock.patch('engineio.asyncio_client.websockets.connect', new=AsyncMock()) def test_websocket_https_noverify_connection_successful(self, _time): - ws = asyncio_client.websockets.connect.mock.return_value - ws.recv = AsyncMock(return_value=packet.Packet( - packet.OPEN, { - 'sid': '123', 'upgrades': [], 'pingInterval': 1000, - 'pingTimeout': 2000 - }).encode()) c = asyncio_client.AsyncClient(ssl_verify=False) + c.http = mock.MagicMock(closed=False) + c.http.ws_connect = AsyncMock() + ws = c.http.ws_connect.mock.return_value + ws.receive = AsyncMock() + ws.receive.mock.return_value.data = packet.Packet(packet.OPEN, { + 'sid': '123', 'upgrades': [], 'pingInterval': 1000, + 'pingTimeout': 2000 + }).encode() c._ping_loop = AsyncMock() c._read_loop_polling = AsyncMock() c._read_loop_websocket = AsyncMock() @@ -563,22 +565,22 @@ def test_websocket_https_noverify_connection_successful(self, _time): self.assertEqual(c.upgrades, []) self.assertEqual(c.transport(), 'websocket') self.assertEqual(c.ws, ws) - _, kwargs = asyncio_client.websockets.connect.mock.call_args + _, kwargs = c.http.ws_connect.mock.call_args self.assertTrue('ssl' in kwargs) self.assertTrue(isinstance(kwargs['ssl'], ssl.SSLContext)) self.assertEqual(kwargs['ssl'].verify_mode, ssl.CERT_NONE) @mock.patch('engineio.client.time.time', return_value=123.456) - @mock.patch('engineio.asyncio_client.websockets.connect', new=AsyncMock()) def test_websocket_connection_with_cookies(self, _time): - ws = asyncio_client.websockets.connect.mock.return_value - ws.recv = AsyncMock(return_value=packet.Packet( - packet.OPEN, { - 'sid': '123', 'upgrades': [], 'pingInterval': 1000, - 'pingTimeout': 2000 - }).encode()) - c = asyncio_client.AsyncClient() - c.http = mock.MagicMock() + c = asyncio_client.AsyncClient() + c.http = mock.MagicMock(closed=False) + c.http.ws_connect = AsyncMock() + ws = c.http.ws_connect.mock.return_value + ws.receive = AsyncMock() + ws.receive.mock.return_value.data = packet.Packet(packet.OPEN, { + 'sid': '123', 'upgrades': [], 'pingInterval': 1000, + 'pingTimeout': 2000 + }).encode() c.http._cookie_jar = [mock.MagicMock(), mock.MagicMock()] c.http._cookie_jar[0].key = 'key' c.http._cookie_jar[0].value = 'value' @@ -592,20 +594,21 @@ def test_websocket_connection_with_cookies(self, _time): c.on('connect', on_connect) _run(c.connect('ws://foo', transports=['websocket'])) time.sleep(0.1) - asyncio_client.websockets.connect.mock.assert_called_once_with( + c.http.ws_connect.mock.assert_called_once_with( 'ws://foo/engine.io/?transport=websocket&EIO=3&t=123.456', - extra_headers={'Cookie': 'key=value; key2=value2'}) + headers={}) - @mock.patch('engineio.asyncio_client.websockets.connect', new=AsyncMock()) def test_websocket_upgrade_no_pong(self): - ws = asyncio_client.websockets.connect.mock.return_value - ws.recv = AsyncMock(return_value=packet.Packet( - packet.OPEN, { - 'sid': '123', 'upgrades': [], 'pingInterval': 1000, - 'pingTimeout': 2000 - }).encode()) - ws.send = AsyncMock() c = asyncio_client.AsyncClient() + c.http = mock.MagicMock(closed=False) + c.http.ws_connect = AsyncMock() + ws = c.http.ws_connect.mock.return_value + ws.receive = AsyncMock() + ws.receive.mock.return_value.data = packet.Packet(packet.OPEN, { + 'sid': '123', 'upgrades': [], 'pingInterval': 1000, + 'pingTimeout': 2000 + }).encode() + ws.send_str = AsyncMock() c.sid = '123' c.current_transport = 'polling' c._ping_loop = AsyncMock() @@ -623,15 +626,17 @@ def test_websocket_upgrade_no_pong(self): c._write_loop.mock.assert_not_called() on_connect.assert_not_called() self.assertEqual(c.transport(), 'polling') - ws.send.mock.assert_called_once_with('2probe') + ws.send_str.mock.assert_called_once_with('2probe') - @mock.patch('engineio.asyncio_client.websockets.connect', new=AsyncMock()) def test_websocket_upgrade_successful(self): - ws = asyncio_client.websockets.connect.mock.return_value - ws.recv = AsyncMock(return_value=packet.Packet( - packet.PONG, 'probe').encode()) - ws.send = AsyncMock() c = asyncio_client.AsyncClient() + c.http = mock.MagicMock(closed=False) + c.http.ws_connect = AsyncMock() + ws = c.http.ws_connect.mock.return_value + ws.receive = AsyncMock() + ws.receive.mock.return_value.data = packet.Packet( + packet.PONG, 'probe').encode() + ws.send_str = AsyncMock() c.sid = '123' c.base_url = 'http://foo' c.current_transport = 'polling' @@ -656,10 +661,10 @@ def test_websocket_upgrade_successful(self): self.assertEqual(c.transport(), 'websocket') self.assertEqual(c.ws, ws) self.assertEqual( - ws.send.mock.call_args_list[0], + ws.send_str.mock.call_args_list[0], (('2probe',),)) # ping self.assertEqual( - ws.send.mock.call_args_list[1], + ws.send_str.mock.call_args_list[1], (('5',),)) # upgrade def test_receive_unknown_packet(self): @@ -1004,8 +1009,8 @@ def test_read_loop_websocket_no_response(self): c.queue = mock.MagicMock() c.queue.put = AsyncMock() c.ws = mock.MagicMock() - c.ws.recv = AsyncMock( - side_effect=websockets.exceptions.ConnectionClosed(1, 'foo')) + c.ws.receive = AsyncMock( + side_effect=aiohttp.client_exceptions.ServerDisconnectedError()) c.write_loop_task = AsyncMock()() c.ping_loop_task = AsyncMock()() _run(c._read_loop_websocket()) @@ -1019,7 +1024,7 @@ def test_read_loop_websocket_unexpected_error(self): c.queue = mock.MagicMock() c.queue.put = AsyncMock() c.ws = mock.MagicMock() - c.ws.recv = AsyncMock(side_effect=ValueError) + c.ws.receive = AsyncMock(side_effect=ValueError) c.write_loop_task = AsyncMock()() c.ping_loop_task = AsyncMock()() _run(c._read_loop_websocket()) @@ -1033,8 +1038,9 @@ def test_read_loop_websocket(self): c.queue = mock.MagicMock() c.queue.put = AsyncMock() c.ws = mock.MagicMock() - c.ws.recv = AsyncMock(side_effect=[ - packet.Packet(packet.PING).encode(), ValueError]) + c.ws.receive = AsyncMock(side_effect=[ + mock.MagicMock(data=packet.Packet(packet.PING).encode()), + ValueError]) c.write_loop_task = AsyncMock()() c.ping_loop_task = AsyncMock()() c._receive_packet = AsyncMock() @@ -1224,11 +1230,11 @@ def test_write_loop_websocket_one_packet(self): RuntimeError ]) c.ws = mock.MagicMock() - c.ws.send = AsyncMock() + c.ws.send_str = AsyncMock() _run(c._write_loop()) self.assertEqual(c.queue.task_done.call_count, 1) - self.assertEqual(c.ws.send.mock.call_count, 1) - c.ws.send.mock.assert_called_once_with('4{"foo":"bar"}') + self.assertEqual(c.ws.send_str.mock.call_count, 1) + c.ws.send_str.mock.assert_called_once_with('4{"foo":"bar"}') def test_write_loop_websocket_three_packets(self): c = asyncio_client.AsyncClient() @@ -1248,14 +1254,14 @@ def test_write_loop_websocket_three_packets(self): RuntimeError ]) c.ws = mock.MagicMock() - c.ws.send = AsyncMock() + c.ws.send_str = AsyncMock() _run(c._write_loop()) self.assertEqual(c.queue.task_done.call_count, 3) - self.assertEqual(c.ws.send.mock.call_count, 3) - self.assertEqual(c.ws.send.mock.call_args_list[0][0][0], + self.assertEqual(c.ws.send_str.mock.call_count, 3) + self.assertEqual(c.ws.send_str.mock.call_args_list[0][0][0], '4{"foo":"bar"}') - self.assertEqual(c.ws.send.mock.call_args_list[1][0][0], '2') - self.assertEqual(c.ws.send.mock.call_args_list[2][0][0], '6') + self.assertEqual(c.ws.send_str.mock.call_args_list[1][0][0], '2') + self.assertEqual(c.ws.send_str.mock.call_args_list[2][0][0], '6') def test_write_loop_websocket_one_packet_binary(self): c = asyncio_client.AsyncClient() @@ -1273,11 +1279,11 @@ def test_write_loop_websocket_one_packet_binary(self): RuntimeError ]) c.ws = mock.MagicMock() - c.ws.send = AsyncMock() + c.ws.send_bytes = AsyncMock() _run(c._write_loop()) self.assertEqual(c.queue.task_done.call_count, 1) - self.assertEqual(c.ws.send.mock.call_count, 1) - c.ws.send.mock.assert_called_once_with(b'\x04foo') + self.assertEqual(c.ws.send_bytes.mock.call_count, 1) + c.ws.send_bytes.mock.assert_called_once_with(b'\x04foo') def test_write_loop_websocket_bad_connection(self): c = asyncio_client.AsyncClient() @@ -1295,7 +1301,7 @@ def test_write_loop_websocket_bad_connection(self): RuntimeError ]) c.ws = mock.MagicMock() - c.ws.send = AsyncMock( - side_effect=websockets.exceptions.ConnectionClosed(1, 'foo')) + c.ws.send_str = AsyncMock( + side_effect=aiohttp.client_exceptions.ServerDisconnectedError()) _run(c._write_loop()) self.assertEqual(c.state, 'connected') diff --git a/tox.ini b/tox.ini index dc5320b5..74eb04f6 100644 --- a/tox.ini +++ b/tox.ini @@ -15,7 +15,6 @@ deps= tornado requests websocket-client - websockets basepython = flake8: python3.7 py27: python2.7