Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,24 @@ def send( # type: ignore
request_complete = False
response_started = False
response_complete = False
timeout_called = False
raw_kwargs = {"body": io.BytesIO()} # type: typing.Dict[str, typing.Any]
template = None
context = None

def do_timeout() -> None:
nonlocal timeout_called, response_started, response_complete
timeout_called = True
if request_complete:
response_started = True
response_complete = True
response_complete_set.set()

async def receive() -> Message:
nonlocal request_complete, response_complete

if request_complete:
while not response_complete:
await asyncio.sleep(0.0001)
await response_complete_set.wait()
return {"type": "http.disconnect"}

body = request.body
Expand All @@ -187,17 +195,23 @@ async def receive() -> Message:
return {"type": "http.request", "body": chunk, "more_body": True}
except StopIteration:
request_complete = True
if timeout_called:
do_timeout()
return {"type": "http.request", "body": b""}
else:
body_bytes = body

request_complete = True
if timeout_called:
do_timeout()
return {"type": "http.request", "body": body_bytes}

async def send(message: Message) -> None:
nonlocal raw_kwargs, response_started, response_complete, template, context

if message["type"] == "http.response.start":
if timeout_called:
pass
elif message["type"] == "http.response.start":
assert (
not response_started
), 'Received multiple "http.response.start" messages.'
Expand Down Expand Up @@ -226,6 +240,7 @@ async def send(message: Message) -> None:
if not more_body:
raw_kwargs["body"].seek(0)
response_complete = True
response_complete_set.set()
elif message["type"] == "http.response.template":
template = message["template"]
context = message["context"]
Expand All @@ -236,13 +251,20 @@ async def send(message: Message) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

response_complete_set = asyncio.Event()
timeout = kwargs.get("timeout")
if timeout:
loop.call_later(timeout, do_timeout)

try:
loop.run_until_complete(self.app(scope, receive, send))
except BaseException as exc:
if self.raise_server_exceptions:
raise exc from None

if self.raise_server_exceptions:
if timeout_called:
raise requests.exceptions.ReadTimeout()
elif self.raise_server_exceptions:
assert response_started, "TestClient did not receive any response."
elif not response_started:
raw_kwargs = {
Expand Down
77 changes: 77 additions & 0 deletions tests/test_testclient.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio

import pytest
import requests

from starlette.applications import Starlette
from starlette.responses import JSONResponse
Expand Down Expand Up @@ -113,3 +114,79 @@ async def asgi(receive, send):
with client.websocket_connect("/") as websocket:
data = websocket.receive_json()
assert data == {"message": "test"}


def test_timeout():
done = False

async def app(scope, receive, send):
nonlocal done
assert (await receive())["type"] == "http.request"
assert (await receive())["type"] == "http.disconnect"
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello, world!"})
done = True

client = TestClient(app)
with pytest.raises(requests.ReadTimeout):
client.get("/", timeout=0.001)
assert done


def test_timeout_generator():
done = False

async def app(scope, receive, send):
nonlocal done
assert (await receive())["type"] == "http.request"
await asyncio.sleep(0.01)
assert (await receive())["type"] == "http.request"
assert (await receive())["type"] == "http.disconnect"
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello, world!"})
done = True

client = TestClient(app)

def gen():
yield "hello"

with pytest.raises(requests.ReadTimeout):
client.post("/", data=gen(), timeout=0.001)
assert done


def test_timeout_early_done():
done = False

async def app(scope, receive, send):
nonlocal done
await asyncio.sleep(0.01)
assert (await receive())["type"] == "http.request"
assert (await receive())["type"] == "http.disconnect"
await send(
{
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"text/plain"]],
}
)
await send({"type": "http.response.body", "body": b"Hello, world!"})
done = True

client = TestClient(app)
with pytest.raises(requests.ReadTimeout):
client.get("/", timeout=0.001)
assert done