From b61ea57b3d85c4c624a5538d139e340d09fc68ba Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 29 Jan 2022 16:11:30 +0100 Subject: [PATCH] TestClient timeout simulates http.disconnect Co-authored-by: Fantix King --- starlette/testclient.py | 28 +++++++++++++-- tests/test_testclient.py | 77 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 2 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index c951767b4..60a4a326c 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -12,6 +12,7 @@ from concurrent.futures import Future from urllib.parse import unquote, urljoin, urlsplit +import anyio import anyio.abc import requests from anyio.streams.stapled import StapledObjectStream @@ -190,10 +191,18 @@ def send( request_complete = False response_started = False response_complete: anyio.Event + timeout_called = False raw_kwargs: typing.Dict[str, typing.Any] = {"body": io.BytesIO()} template = None context = None + def do_timeout() -> None: + nonlocal timeout_called, response_complete, response_started + timeout_called = True + if request_complete: + response_started = True + response_complete.set() + async def receive() -> Message: nonlocal request_complete @@ -215,17 +224,23 @@ async def receive() -> Message: return {"type": "http.request", "body": chunk, "more_body": True} except StopIteration: request_complete = True + if timeout_called: + do_timeout() return {"type": "http.request", "body": b""} else: body_bytes = body request_complete = True + if timeout_called: + do_timeout() return {"type": "http.request", "body": body_bytes} async def send(message: Message) -> None: nonlocal raw_kwargs, response_started, template, context - if message["type"] == "http.response.start": + if timeout_called: + pass + elif message["type"] == "http.response.start": assert ( not response_started ), 'Received multiple "http.response.start" messages.' @@ -259,15 +274,24 @@ async def send(message: Message) -> None: template = message["template"] context = message["context"] + async def timeout_task(delay: float) -> None: + await anyio.sleep(delay) + do_timeout() + + timeout: typing.Optional[float] = kwargs.get("timeout") try: with self.portal_factory() as portal: response_complete = portal.call(anyio.Event) + if timeout: + portal.start_task_soon(timeout_task, timeout) portal.call(self.app, scope, receive, send) except BaseException as exc: if self.raise_server_exceptions: raise exc - if self.raise_server_exceptions: + if timeout_called: + raise requests.exceptions.ReadTimeout() + elif self.raise_server_exceptions: assert response_started, "TestClient did not receive any response." elif not response_started: raw_kwargs = { diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 8c0666789..b48d1ceb8 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -4,6 +4,7 @@ import anyio import pytest +import requests import sniffio import trio.lowlevel @@ -229,3 +230,79 @@ async def asgi(receive, send): with client.websocket_connect("/") as websocket: data = websocket.receive_json() assert data == {"message": "test"} + + +def test_timeout(test_client_factory): + done = False + + async def app(scope, receive, send): + nonlocal done + assert (await receive())["type"] == "http.request" + assert (await receive())["type"] == "http.disconnect" + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + done = True + + client = test_client_factory(app) + with pytest.raises(requests.ReadTimeout): + client.get("/", timeout=0.001) + assert done + + +def test_timeout_generator(test_client_factory): + done = False + + async def app(scope, receive, send): + nonlocal done + assert (await receive())["type"] == "http.request" + await anyio.sleep(0.01) + assert (await receive())["type"] == "http.request" + assert (await receive())["type"] == "http.disconnect" + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + done = True + + client = test_client_factory(app) + + def gen(): + yield "hello" + + with pytest.raises(requests.ReadTimeout): + client.post("/", data=gen(), timeout=0.001) + assert done + + +def test_timeout_early_done(test_client_factory): + done = False + + async def app(scope, receive, send): + nonlocal done + await anyio.sleep(0.01) + assert (await receive())["type"] == "http.request" + assert (await receive())["type"] == "http.disconnect" + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + done = True + + client = test_client_factory(app) + with pytest.raises(requests.ReadTimeout): + client.get("/", timeout=0.001) + assert done