From 8051fc49f484585b786031a08617208efdc97f5a Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sat, 29 Jul 2023 17:20:32 +0100 Subject: [PATCH] Improvements in the connection rejected flow --- src/engineio/asyncio_socket.py | 2 ++ src/engineio/socket.py | 2 ++ tests/asyncio/test_asyncio_socket.py | 9 +++++++++ 3 files changed, 13 insertions(+) diff --git a/src/engineio/asyncio_socket.py b/src/engineio/asyncio_socket.py index 84479666..1fbd0ef3 100644 --- a/src/engineio/asyncio_socket.py +++ b/src/engineio/asyncio_socket.py @@ -204,6 +204,8 @@ async def writer(): await ws.send(pkt.encode()) except: break + await ws.close() + writer_task = asyncio.ensure_future(writer()) self.server.logger.info( diff --git a/src/engineio/socket.py b/src/engineio/socket.py index 2cf8b0aa..dfa414b0 100644 --- a/src/engineio/socket.py +++ b/src/engineio/socket.py @@ -220,6 +220,8 @@ def writer(): ws.send(pkt.encode()) except: break + ws.close() + writer_task = self.server.start_background_task(writer) self.server.logger.info( diff --git a/tests/asyncio/test_asyncio_socket.py b/tests/asyncio/test_asyncio_socket.py index 30e8f7a4..78c003e7 100644 --- a/tests/asyncio/test_asyncio_socket.py +++ b/tests/asyncio/test_asyncio_socket.py @@ -339,6 +339,7 @@ def test_websocket_read_write(self): packet.Packet(packet.MESSAGE, data=foo).encode(), None, ] + ws.close = AsyncMock() _run(s._websocket_handler(ws)) assert s.connected assert s.upgraded @@ -350,6 +351,7 @@ def test_websocket_read_write(self): ] ) ws.send.mock.assert_called_with('4bar') + ws.close.mock.assert_called() def test_websocket_upgrade_read_write(self): mock_server = self._get_mock_server() @@ -374,6 +376,7 @@ def test_websocket_upgrade_read_write(self): packet.Packet(packet.MESSAGE, data=foo).encode(), None, ] + ws.close = AsyncMock() _run(s._websocket_handler(ws)) assert s.upgraded assert mock_server._trigger_event.mock.call_count == 2 @@ -384,6 +387,7 @@ def test_websocket_upgrade_read_write(self): ] ) ws.send.mock.assert_called_with('4bar') + ws.close.mock.assert_called() def test_websocket_upgrade_with_payload(self): mock_server = self._get_mock_server() @@ -398,6 +402,7 @@ def test_websocket_upgrade_with_payload(self): packet.Packet(packet.PING, data=probe).encode(), packet.Packet(packet.UPGRADE, data='2').encode(), ] + ws.close = AsyncMock() _run(s._websocket_handler(ws)) assert s.upgraded @@ -415,6 +420,7 @@ def test_websocket_upgrade_with_backlog(self): packet.Packet(packet.PING, data=probe).encode(), packet.Packet(packet.UPGRADE, data='2').encode(), ] + ws.close = AsyncMock() s.upgrading = True _run(s.send(packet.Packet(packet.MESSAGE, data=foo))) environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=sid'} @@ -454,6 +460,7 @@ def test_websocket_read_write_wait_fail(self): RuntimeError, ] ws.send.mock.side_effect = [None, RuntimeError] + ws.close = AsyncMock() _run(s._websocket_handler(ws)) assert s.closed @@ -495,6 +502,7 @@ def test_websocket_ignore_invalid_packet(self): packet.Packet(packet.MESSAGE, data=foo).encode(), None, ] + ws.close = AsyncMock() _run(s._websocket_handler(ws)) assert s.connected assert mock_server._trigger_event.mock.call_count == 2 @@ -505,6 +513,7 @@ def test_websocket_ignore_invalid_packet(self): ] ) ws.send.mock.assert_called_with('4bar') + ws.close.mock.assert_called() def test_send_after_close(self): mock_server = self._get_mock_server()