From 3f583c88449f88200fa5f484954248bfad517aa8 Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sun, 23 Sep 2018 15:26:53 +0100 Subject: [PATCH] Actively monitor clients for disconnections --- engineio/asyncio_server.py | 27 +++++++++++++++++++++++++++ engineio/asyncio_socket.py | 26 +++++++++++++++++++++----- engineio/server.py | 36 +++++++++++++++++++++++++++++++++++- engineio/socket.py | 13 +++++++++++-- tests/test_asyncio_server.py | 18 ++++++++++++++++++ tests/test_server.py | 16 ++++++++++++++++ 6 files changed, 128 insertions(+), 8 deletions(-) diff --git a/engineio/asyncio_server.py b/engineio/asyncio_server.py index e1afe8a4..2bf36fca 100644 --- a/engineio/asyncio_server.py +++ b/engineio/asyncio_server.py @@ -225,6 +225,11 @@ async def sleep(self, seconds=0): async def _handle_connect(self, environ, transport, b64=False): """Handle a client connection request.""" + if self.start_service_task: + # start the service task to monitor connected clients + self.start_service_task = False + self.start_background_task(self._service_task) + sid = self._generate_id() s = asyncio_socket.AsyncSocket(self, sid) self.sockets[sid] = s @@ -295,3 +300,25 @@ async def async_handler(): # connection return False return ret + + async def _service_task(self): # pragma: no cover + """Monitor connected clients and clean up those that time out.""" + while True: + if len(self.sockets) == 0: + # nothing to do + await self.sleep(self.ping_timeout) + continue + + # go through the entire client list in a ping interval cycle + sleep_interval = self.ping_timeout / len(self.sockets) + + try: + # iterate over the current clients + for socket in self.sockets.copy().values(): + if socket.closed: + continue + await socket.check_ping_timeout() + await self.sleep(sleep_interval) + except: + # an unexpected exception has occurred, log it and continue + self.logger.exception('service task exception') diff --git a/engineio/asyncio_socket.py b/engineio/asyncio_socket.py index acef5135..0076f6d2 100644 --- a/engineio/asyncio_socket.py +++ b/engineio/asyncio_socket.py @@ -50,14 +50,24 @@ async def receive(self, pkt): else: raise exceptions.UnknownPacketError() - async def send(self, pkt): - """Send a packet to the client.""" + async def check_ping_timeout(self): + """Make sure the client is still sending pings. + + This helps detect disconnections for long-polling clients. + """ if self.closed: raise exceptions.SocketIsClosedError() if time.time() - self.last_ping > self.server.ping_timeout: self.server.logger.info('%s: Client is gone, closing socket', self.sid) - return await self.close(wait=False, abort=True) + await self.close(wait=False, abort=True) + return False + return True + + async def send(self, pkt): + """Send a packet to the client.""" + if not await self.check_ping_timeout(): + return self.server.logger.info('%s: Sending packet %s data %s', self.sid, packet.packet_names[pkt.packet_type], pkt.data if not isinstance(pkt.data, bytes) @@ -123,7 +133,10 @@ async def _websocket_handler(self, ws): # the socket was already connected, so this is an upgrade await self.queue.join() # flush the queue first - pkt = await ws.wait() + try: + pkt = await ws.wait() + except IOError: # pragma: no cover + return if pkt != packet.Packet(packet.PING, data=six.text_type('probe')).encode( always_bytes=False): @@ -135,7 +148,10 @@ async def _websocket_handler(self, ws): data=six.text_type('probe')).encode(always_bytes=False)) await self.send(packet.Packet(packet.NOOP)) - pkt = await ws.wait() + try: + pkt = await ws.wait() + except IOError: # pragma: no cover + return decoded_pkt = packet.Packet(encoded_packet=pkt) if decoded_pkt.packet_type != packet.UPGRADE: self.upgraded = False diff --git a/engineio/server.py b/engineio/server.py index f10e390d..451b3091 100644 --- a/engineio/server.py +++ b/engineio/server.py @@ -64,18 +64,23 @@ class Server(object): :param async_handlers: If set to ``True``, run message event handlers in non-blocking threads. To run handlers synchronously, set to ``False``. The default is ``True``. + :param monitor_clients: If set to ``True``, a background task will ensure + inactive clients are closed. Set to ``False`` to + disable the monitoring task (not recommended). The + default is ``True``. :param kwargs: Reserved for future extensions, any additional parameters given as keyword arguments will be silently ignored. """ compression_methods = ['gzip', 'deflate'] event_names = ['connect', 'disconnect', 'message'] + _default_monitor_clients = True def __init__(self, async_mode=None, ping_timeout=60, ping_interval=25, max_http_buffer_size=100000000, allow_upgrades=True, http_compression=True, compression_threshold=1024, cookie='io', cors_allowed_origins=None, cors_credentials=True, logger=False, json=None, - async_handlers=True, **kwargs): + async_handlers=True, monitor_clients=None, **kwargs): self.ping_timeout = ping_timeout self.ping_interval = ping_interval self.max_http_buffer_size = max_http_buffer_size @@ -88,6 +93,8 @@ def __init__(self, async_mode=None, ping_timeout=60, ping_interval=25, self.async_handlers = async_handlers self.sockets = {} self.handlers = {} + self.start_service_task = monitor_clients \ + if monitor_clients is not None else self._default_monitor_clients if json is not None: packet.Packet.json = json if not isinstance(logger, bool): @@ -359,6 +366,11 @@ def _generate_id(self): def _handle_connect(self, environ, start_response, transport, b64=False): """Handle a client connection request.""" + if self.start_service_task: + # start the service task to monitor connected clients + self.start_service_task = False + self.start_background_task(self._service_task) + sid = self._generate_id() s = socket.Socket(self, sid) self.sockets[sid] = s @@ -497,3 +509,25 @@ def _gzip(self, response): def _deflate(self, response): """Apply deflate compression to a response.""" return zlib.compress(response) + + def _service_task(self): # pragma: no cover + """Monitor connected clients and clean up those that time out.""" + while True: + if len(self.sockets) == 0: + # nothing to do + self.sleep(self.ping_timeout) + continue + + # go through the entire client list in a ping interval cycle + sleep_interval = self.ping_timeout / len(self.sockets) + + try: + # iterate over the current clients + for s in self.sockets.copy().values(): + if s.closed: + continue + s.check_ping_timeout() + self.sleep(sleep_interval) + except: + # an unexpected exception has occurred, log it and continue + self.logger.exception('service task exception') diff --git a/engineio/socket.py b/engineio/socket.py index 65cff1c7..82fa2da3 100644 --- a/engineio/socket.py +++ b/engineio/socket.py @@ -63,14 +63,23 @@ def receive(self, pkt): else: raise exceptions.UnknownPacketError() - def send(self, pkt): - """Send a packet to the client.""" + def check_ping_timeout(self): + """Make sure the client is still sending pings. + + This helps detect disconnections for long-polling clients. + """ if self.closed: raise exceptions.SocketIsClosedError() if time.time() - self.last_ping > self.server.ping_timeout: self.server.logger.info('%s: Client is gone, closing socket', self.sid) self.close(wait=False, abort=True) + return False + return True + + def send(self, pkt): + """Send a packet to the client.""" + if not self.check_ping_timeout(): return self.queue.put(pkt) self.server.logger.info('%s: Sending packet %s data %s', diff --git a/tests/test_asyncio_server.py b/tests/test_asyncio_server.py index d0bd0185..e999f95d 100644 --- a/tests/test_asyncio_server.py +++ b/tests/test_asyncio_server.py @@ -68,9 +68,18 @@ def _get_mock_socket(self): mock_socket.send = AsyncMock() mock_socket.handle_get_request = AsyncMock() mock_socket.handle_post_request = AsyncMock() + mock_socket.check_ping_timeout = AsyncMock() mock_socket.close = AsyncMock() return mock_socket + @classmethod + def setUpClass(cls): + asyncio_server.AsyncServer._default_monitor_clients = False + + @classmethod + def tearDownClass(cls): + asyncio_server.AsyncServer._default_monitor_clients = True + def setUp(self): logging.getLogger('engineio').setLevel(logging.NOTSET) @@ -839,3 +848,12 @@ def foo_handler(arg): ZeroDivisionError, asyncio.get_event_loop().run_until_complete, fut) self.assertEqual(result, ['bar']) + + @mock.patch('importlib.import_module') + def test_service_task_started(self, import_module): + a = self.get_async_mock() + import_module.side_effect = [a] + s = asyncio_server.AsyncServer(monitor_clients=True) + s._service_task = AsyncMock() + _run(s.handle_request('request')) + s._service_task.mock.assert_called_once_with() diff --git a/tests/test_server.py b/tests/test_server.py index 6c295421..e970cac9 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -46,6 +46,14 @@ def _get_mock_socket(self): mock_socket.upgraded = False return mock_socket + @classmethod + def setUpClass(cls): + server.Server._default_monitor_clients = False + + @classmethod + def tearDownClass(cls): + server.Server._default_monitor_clients = True + def setUp(self): logging.getLogger('engineio').setLevel(logging.NOTSET) @@ -863,3 +871,11 @@ def test_sleep(self): t = time.time() s.sleep(0.1) self.assertTrue(time.time() - t > 0.1) + + def test_service_task_started(self): + s = server.Server(monitor_clients=True) + s._service_task = mock.MagicMock() + environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''} + start_response = mock.MagicMock() + s.handle_request(environ, start_response) + s._service_task.assert_called_once_with()