Skip to content

Commit

Permalink
Dispose of disconnected sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Feb 15, 2016
1 parent e3badf3 commit 08c518d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 4 deletions.
5 changes: 4 additions & 1 deletion engineio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,11 @@ def handle_request(self, environ, start_response):
else:
r = packets
except IOError:
del self.sockets[sid]
if sid in self.sockets: # pragma: no cover
del self.sockets[sid]
r = self._bad_request()
if sid in self.sockets and self.sockets[sid].closed:
del self.sockets[sid]
elif method == 'POST':
if sid is None or sid not in self.sockets:
self.logger.warning('Invalid session %s', sid)
Expand Down
8 changes: 5 additions & 3 deletions engineio/socket.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def poll(self):
self.queue.task_done()
except self.server.async['queue'].Empty:
raise IOError()
if packets == [None]:
return []
try:
packets.append(self.queue.get(block=False))
self.queue.task_done()
Expand Down Expand Up @@ -157,7 +159,7 @@ def writer():
except:
break

writer_task = self.server.start_background_task(writer)
self.server.start_background_task(writer)

self.server.logger.info(
'%s: Upgrade to websocket successful', self.sid)
Expand All @@ -176,5 +178,5 @@ def writer():
self.receive(pkt)
except ValueError:
pass
self.close(wait=False, abort=True)
writer_task.join()
self.close(wait=True, abort=True)
self.queue.put(None) # unlock the writer task so that it can exit
15 changes: 15 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,21 @@ def test_get_request_custom_response(self):
start_response = mock.MagicMock()
self.assertEqual(s.handle_request(environ, start_response), 'resp')

def test_get_request_closes_socket(self):
s = server.Server()
mock_socket = self._get_mock_socket()

def mock_get_request(*args, **kwargs):
mock_socket.closed = True
return 'resp'

mock_socket.handle_get_request = mock_get_request
s.sockets['foo'] = mock_socket
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
start_response = mock.MagicMock()
self.assertEqual(s.handle_request(environ, start_response), 'resp')
self.assertNotIn('foo', s.sockets)

def test_get_request_error(self):
s = server.Server()
mock_socket = self._get_mock_socket()
Expand Down

0 comments on commit 08c518d

Please sign in to comment.