From 08c518db6c6dd12d81890cd6239113cfd84e9eec Mon Sep 17 00:00:00 2001 From: Miguel Grinberg Date: Sun, 14 Feb 2016 17:35:46 -0800 Subject: [PATCH] Dispose of disconnected sockets --- engineio/server.py | 5 ++++- engineio/socket.py | 8 +++++--- tests/test_server.py | 15 +++++++++++++++ 3 files changed, 24 insertions(+), 4 deletions(-) mode change 100644 => 100755 engineio/socket.py diff --git a/engineio/server.py b/engineio/server.py index 2f78ba76..e02c27a8 100644 --- a/engineio/server.py +++ b/engineio/server.py @@ -229,8 +229,11 @@ def handle_request(self, environ, start_response): else: r = packets except IOError: - del self.sockets[sid] + if sid in self.sockets: # pragma: no cover + del self.sockets[sid] r = self._bad_request() + if sid in self.sockets and self.sockets[sid].closed: + del self.sockets[sid] elif method == 'POST': if sid is None or sid not in self.sockets: self.logger.warning('Invalid session %s', sid) diff --git a/engineio/socket.py b/engineio/socket.py old mode 100644 new mode 100755 index cf4a9030..e759dbcb --- a/engineio/socket.py +++ b/engineio/socket.py @@ -26,6 +26,8 @@ def poll(self): self.queue.task_done() except self.server.async['queue'].Empty: raise IOError() + if packets == [None]: + return [] try: packets.append(self.queue.get(block=False)) self.queue.task_done() @@ -157,7 +159,7 @@ def writer(): except: break - writer_task = self.server.start_background_task(writer) + self.server.start_background_task(writer) self.server.logger.info( '%s: Upgrade to websocket successful', self.sid) @@ -176,5 +178,5 @@ def writer(): self.receive(pkt) except ValueError: pass - self.close(wait=False, abort=True) - writer_task.join() + self.close(wait=True, abort=True) + self.queue.put(None) # unlock the writer task so that it can exit diff --git a/tests/test_server.py b/tests/test_server.py index 077f59ec..1bc3d29d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -456,6 +456,21 @@ def test_get_request_custom_response(self): start_response = mock.MagicMock() self.assertEqual(s.handle_request(environ, start_response), 'resp') + def test_get_request_closes_socket(self): + s = server.Server() + mock_socket = self._get_mock_socket() + + def mock_get_request(*args, **kwargs): + mock_socket.closed = True + return 'resp' + + mock_socket.handle_get_request = mock_get_request + s.sockets['foo'] = mock_socket + environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'} + start_response = mock.MagicMock() + self.assertEqual(s.handle_request(environ, start_response), 'resp') + self.assertNotIn('foo', s.sockets) + def test_get_request_error(self): s = server.Server() mock_socket = self._get_mock_socket()