Skip to content

Commit

Permalink
Add a new test to validate the lookahead race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
digitalresistor committed Oct 27, 2024
1 parent 6943dcf commit 7e7f11e
Showing 1 changed file with 54 additions and 1 deletion.
55 changes: 54 additions & 1 deletion tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,11 +805,12 @@ def app_check_disconnect(self, environ, start_response):
)
return [body]

def _make_app_with_lookahead(self):
def _make_app_with_lookahead(self, recv_bytes=8192):
"""
Setup a channel with lookahead and store it and the socket in self
"""
adj = DummyAdjustments()
adj.recv_bytes = recv_bytes
adj.channel_request_lookahead = 5
channel, sock, map = self._makeOneWithMap(adj=adj)
channel.server.application = self.app_check_disconnect
Expand Down Expand Up @@ -901,6 +902,58 @@ def test_lookahead_continue(self):
self.assertEqual(data.split("\r\n")[-1], "finished")
self.assertEqual(self.request_body, b"x")

def test_lookahead_bad_request_drop_extra_data(self):
"""
Send two requests, the first one being bad, split on the recv_bytes
limit, then emulate a race that could happen whereby we read data from
the socket while the service thread is cleaning up due to an error
processing the request.
"""

invalid_request = [
"GET / HTTP/1.1",
"Host: localhost:8080",
"Content-length: -1",
"",
]

invalid_request_len = len("".join([x + "\r\n" for x in invalid_request]))

second_request = [
"POST / HTTP/1.1",
"Host: localhost:8080",
"Content-Length: 1",
"",
"x",
]

full_request = invalid_request + second_request

self._make_app_with_lookahead(recv_bytes=invalid_request_len)
self._send(*full_request)
self.channel.handle_read()
self.assertEqual(len(self.channel.requests), 1)
self.channel.server.tasks[0].service()
self.assertTrue(self.channel.close_when_flushed)
# Read all of the next request
self.channel.handle_read()
self.channel.handle_read()
# Validate that there is no more data to be read
self.assertEqual(self.sock.remote.local_sent, b"")
# Validate that we dropped the data from the second read, and did not
# create a new request
self.assertEqual(len(self.channel.requests), 0)
data = self.sock.recv(256).decode("ascii")
self.assertFalse(self.channel.readable())
self.assertTrue(self.channel.writable())

# Handle the write, which will close the socket
self.channel.handle_write()
self.assertTrue(self.sock.closed)

data = self.sock.recv(256)
self.assertEqual(len(data), 0)


class DummySock:
blocking = False
Expand Down

0 comments on commit 7e7f11e

Please sign in to comment.