From 0701493d1d893fded41cb5a82c193b02f516219e Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 15 May 2025 10:41:46 +0200 Subject: [PATCH 1/9] Add async test client --- starlette/testclient.py | 644 ++++++++++++++++++++++++++++++++++ tests/conftest.py | 17 +- tests/test_asynctestclient.py | 435 +++++++++++++++++++++++ tests/types.py | 18 +- 4 files changed, 1111 insertions(+), 3 deletions(-) create mode 100644 tests/test_asynctestclient.py diff --git a/starlette/testclient.py b/starlette/testclient.py index d54025e52..80581e58b 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -68,6 +68,11 @@ class _AsyncBackend(typing.TypedDict): backend_options: dict[str, typing.Any] +class _AsyncUpgrade(Exception): + def __init__(self, session: AsyncWebSocketTestSession) -> None: + self.session = session + + class _Upgrade(Exception): def __init__(self, session: WebSocketTestSession) -> None: self.session = session @@ -729,3 +734,642 @@ async def receive() -> typing.Any: ) if message["type"] == "lifespan.shutdown.failed": await receive() + + +class AsyncWebSocketTestSession: + def __init__( + self, + app: ASGI3App, + scope: Scope, + ) -> None: + self.app = app + self.scope = scope + self.accepted_subprotocol = None + self.extra_headers = None + + async def __aenter__(self) -> AsyncWebSocketTestSession: + async with contextlib.AsyncExitStack() as stack: + task_group = await stack.enter_async_context(anyio.create_task_group()) + self.done = anyio.Event() + + async def run(*, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: + await self._run(task_status=task_status) + self.done.set() + + await task_group.start(run) + stack.push_async_callback(self.done.wait) + stack.callback(task_group.cancel_scope.cancel) + await self.send({"type": "websocket.connect"}) + message = await self.receive() + await self._raise_on_close(message) + self.accepted_subprotocol = message.get("subprotocol", None) + self.extra_headers = message.get("headers", None) + stack.push_async_callback(self.aclose, 1000) + self.exit_stack = stack.pop_all() + return self + + async def __aexit__(self, *args: typing.Any) -> bool | None: + return await self.exit_stack.__aexit__(*args) + + async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: + send: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) + send_tx, send_rx = send + receive: anyio.create_memory_object_stream[Message] = anyio.create_memory_object_stream(math.inf) + receive_tx, receive_rx = receive + with send_tx, send_rx, receive_tx, receive_rx, anyio.CancelScope() as cs: + self._receive_tx = receive_tx + self._send_rx = send_rx + task_status.started(cs) + await self.app(self.scope, receive_rx.receive, send_tx.send) + + # wait for cs.cancel to be called before closing streams + await anyio.sleep_forever() + + async def _raise_on_close(self, message: Message) -> None: + if message["type"] == "websocket.close": + raise WebSocketDisconnect(code=message.get("code", 1000), reason=message.get("reason", "")) + elif message["type"] == "websocket.http.response.start": + status_code: int = message["status"] + headers: list[tuple[bytes, bytes]] = message["headers"] + body: list[bytes] = [] + while True: + message = await self.receive() + assert message["type"] == "websocket.http.response.body" + body.append(message["body"]) + if not message.get("more_body", False): + break + raise WebSocketDenialResponse(status_code=status_code, headers=headers, content=b"".join(body)) + + async def send(self, message: Message) -> None: + await self._receive_tx.send(message) + + async def send_text(self, data: str) -> None: + await self.send({"type": "websocket.receive", "text": data}) + + async def send_bytes(self, data: bytes) -> None: + await self.send({"type": "websocket.receive", "bytes": data}) + + async def send_json(self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text") -> None: + text = json.dumps(data, separators=(",", ":"), ensure_ascii=False) + if mode == "text": + await self.send({"type": "websocket.receive", "text": text}) + else: + await self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")}) + + async def aclose(self, code: int = 1000, reason: str | None = None) -> None: + await self.send({"type": "websocket.disconnect", "code": code, "reason": reason}) + + async def receive(self) -> Message: + return await self._send_rx.receive() + + async def receive_text(self) -> str: + message = await self.receive() + await self._raise_on_close(message) + return typing.cast(str, message["text"]) + + async def receive_bytes(self) -> bytes: + message = await self.receive() + await self._raise_on_close(message) + return typing.cast(bytes, message["bytes"]) + + async def receive_json(self, mode: typing.Literal["text", "binary"] = "text") -> typing.Any: + message = await self.receive() + await self._raise_on_close(message) + if mode == "text": + text = message["text"] + else: + text = message["bytes"].decode("utf-8") + return json.loads(text) + + +class _AsyncTestClientTransport(httpx.AsyncBaseTransport): + def __init__( + self, + app: ASGI3App, + raise_server_exceptions: bool = True, + root_path: str = "", + *, + client: tuple[str, int], + app_state: dict[str, typing.Any], + ) -> None: + self.app = app + self.raise_server_exceptions = raise_server_exceptions + self.root_path = root_path + self.app_state = app_state + self.client = client + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + scheme = request.url.scheme + netloc = request.url.netloc.decode(encoding="ascii") + path = request.url.path + raw_path = request.url.raw_path + query = request.url.query.decode(encoding="ascii") + + default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme] + + if ":" in netloc: + host, port_string = netloc.split(":", 1) + port = int(port_string) + else: + host = netloc + port = default_port + + # Include the 'host' header. + if "host" in request.headers: + headers: list[tuple[bytes, bytes]] = [] + elif port == default_port: # pragma: no cover + headers = [(b"host", host.encode())] + else: # pragma: no cover + headers = [(b"host", (f"{host}:{port}").encode())] + + # Include other request headers. + headers += [(key.lower().encode(), value.encode()) for key, value in request.headers.multi_items()] + + scope: dict[str, typing.Any] + + if scheme in {"ws", "wss"}: + subprotocol = request.headers.get("sec-websocket-protocol", None) + if subprotocol is None: + subprotocols: typing.Sequence[str] = [] + else: + subprotocols = [value.strip() for value in subprotocol.split(",")] + scope = { + "type": "websocket", + "path": unquote(path), + "raw_path": raw_path.split(b"?", 1)[0], + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": self.client, + "server": [host, port], + "subprotocols": subprotocols, + "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, + } + session = AsyncWebSocketTestSession(self.app, scope) + raise _AsyncUpgrade(session) + + scope = { + "type": "http", + "http_version": "1.1", + "method": request.method, + "path": unquote(path), + "raw_path": raw_path.split(b"?", 1)[0], + "root_path": self.root_path, + "scheme": scheme, + "query_string": query.encode(), + "headers": headers, + "client": self.client, + "server": [host, port], + "extensions": {"http.response.debug": {}}, + "state": self.app_state.copy(), + } + + request_complete = False + response_started = False + response_complete: anyio.Event + raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()} + template = None + context = None + + async def receive() -> Message: + nonlocal request_complete + + if request_complete: + if not response_complete.is_set(): + await response_complete.wait() + return {"type": "http.disconnect"} + + body = request.read() + if isinstance(body, str): + body_bytes: bytes = body.encode("utf-8") # pragma: no cover + elif body is None: + body_bytes = b"" # pragma: no cover + elif isinstance(body, GeneratorType): + try: # pragma: no cover + chunk = body.send(None) + if isinstance(chunk, str): + chunk = chunk.encode("utf-8") + return {"type": "http.request", "body": chunk, "more_body": True} + except StopIteration: # pragma: no cover + request_complete = True + return {"type": "http.request", "body": b""} + else: + body_bytes = body + + request_complete = True + return {"type": "http.request", "body": body_bytes} + + async def send(message: Message) -> None: + nonlocal raw_kwargs, response_started, template, context + + if message["type"] == "http.response.start": + assert not response_started, 'Received multiple "http.response.start" messages.' + raw_kwargs["status_code"] = message["status"] + raw_kwargs["headers"] = [(key.decode(), value.decode()) for key, value in message.get("headers", [])] + response_started = True + elif message["type"] == "http.response.body": + assert response_started, 'Received "http.response.body" without "http.response.start".' + assert not response_complete.is_set(), 'Received "http.response.body" after response completed.' + body = message.get("body", b"") + more_body = message.get("more_body", False) + if request.method != "HEAD": + raw_kwargs["stream"].write(body) + if not more_body: + raw_kwargs["stream"].seek(0) + response_complete.set() + elif message["type"] == "http.response.debug": + template = message["info"]["template"] + context = message["info"]["context"] + + try: + response_complete = anyio.Event() + await self.app(scope, receive, send) + except BaseException as exc: + if self.raise_server_exceptions: + raise exc + + if self.raise_server_exceptions: + assert response_started, "TestClient did not receive any response." + elif not response_started: + raw_kwargs = { + "status_code": 500, + "headers": [], + "stream": io.BytesIO(), + } + + raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read()) + + response = httpx.Response(**raw_kwargs, request=request) + if template is not None: + response.template = template # type: ignore[attr-defined] + response.context = context # type: ignore[attr-defined] + return response + + +class AsyncTestClient(httpx.AsyncClient): + __test__ = False + + def __init__( + self, + app: ASGIApp, + base_url: str = "http://testserver", + raise_server_exceptions: bool = True, + root_path: str = "", + backend: typing.Literal["asyncio", "trio"] = "asyncio", + backend_options: dict[str, typing.Any] | None = None, + cookies: httpx._types.CookieTypes | None = None, + headers: dict[str, str] | None = None, + follow_redirects: bool = True, + client: tuple[str, int] = ("testclient", 50000), + ) -> None: + self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {}) + if _is_asgi3(app): + asgi_app = app + else: + app = typing.cast(ASGI2App, app) # type: ignore[assignment] + asgi_app = _WrapASGI2(app) # type: ignore[arg-type] + self.app = asgi_app + self.app_state: dict[str, typing.Any] = {} + transport = _AsyncTestClientTransport( + self.app, + raise_server_exceptions=raise_server_exceptions, + root_path=root_path, + app_state=self.app_state, + client=client, + ) + if headers is None: + headers = {} + headers.setdefault("user-agent", "testclient") + super().__init__( + base_url=base_url, + headers=headers, + transport=transport, + follow_redirects=follow_redirects, + cookies=cookies, + ) + + async def request( # type: ignore[override] + self, + method: str, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + if timeout is not httpx.USE_CLIENT_DEFAULT: + warnings.warn( + "You should not use the 'timeout' argument with the TestClient. " + "See https://github.com/encode/starlette/issues/1108 for more information.", + DeprecationWarning, + ) + url = self._merge_url(url) + return await super().request( + method, + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def get( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().get( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + def options( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return super().options( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def head( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().head( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def post( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().post( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def put( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().put( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def patch( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + content: httpx._types.RequestContent | None = None, + data: _RequestData | None = None, + files: httpx._types.RequestFiles | None = None, + json: typing.Any = None, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().patch( + url, + content=content, + data=data, + files=files, + json=json, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def delete( # type: ignore[override] + self, + url: httpx._types.URLTypes, + *, + params: httpx._types.QueryParamTypes | None = None, + headers: httpx._types.HeaderTypes | None = None, + cookies: httpx._types.CookieTypes | None = None, + auth: httpx._types.AuthTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + follow_redirects: bool | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, + extensions: dict[str, typing.Any] | None = None, + ) -> httpx.Response: + return await super().delete( + url, + params=params, + headers=headers, + cookies=cookies, + auth=auth, + follow_redirects=follow_redirects, + timeout=timeout, + extensions=extensions, + ) + + async def websocket_connect( + self, + url: str, + subprotocols: typing.Sequence[str] | None = None, + **kwargs: typing.Any, + ) -> AsyncWebSocketTestSession: + url = urljoin("ws://testserver", url) + headers = kwargs.get("headers", {}) + headers.setdefault("connection", "upgrade") + headers.setdefault("sec-websocket-key", "testserver==") + headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) + kwargs["headers"] = headers + try: + await super().request("GET", url, **kwargs) + except _AsyncUpgrade as exc: + session = exc.session + else: + raise RuntimeError("Expected WebSocket upgrade") # pragma: no cover + + return session + + async def __aenter__(self) -> AsyncTestClient: + async with contextlib.AsyncExitStack() as stack: + task_group = await stack.enter_async_context(anyio.create_task_group()) + send: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any] | None] = ( + anyio.create_memory_object_stream(math.inf) + ) + receive: anyio.create_memory_object_stream[typing.MutableMapping[str, typing.Any]] = ( + anyio.create_memory_object_stream(math.inf) + ) + for channel in (*send, *receive): + stack.push_async_callback(channel.aclose) + self.stream_send = StapledObjectStream(*send) + self.stream_receive = StapledObjectStream(*receive) + self.task_done = anyio.Event() + + async def lifespan(): + await self.lifespan() + self.task_done.set() + + task_group.start_soon(lifespan) + await self.wait_startup() + + @stack.push_async_callback + async def wait_shutdown() -> None: + await self.wait_shutdown() + + self.exit_stack = stack.pop_all() + + return self + + async def __aexit__(self, *args: typing.Any) -> None: + await self.exit_stack.aclose() + + async def lifespan(self) -> None: + scope = {"type": "lifespan", "state": self.app_state} + try: + await self.app(scope, self.stream_receive.receive, self.stream_send.send) + finally: + try: + await self.stream_send.send(None) + except anyio.ClosedResourceError: + pass + + async def wait_startup(self) -> None: + await self.stream_receive.send({"type": "lifespan.startup"}) + + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + await self.task_done.wait() + return message + + message = await receive() + assert message["type"] in ( + "lifespan.startup.complete", + "lifespan.startup.failed", + ) + if message["type"] == "lifespan.startup.failed": + await receive() + + async def wait_shutdown(self) -> None: + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + await self.task_done.wait() + return message + + await self.stream_receive.send({"type": "lifespan.shutdown"}) + message = await receive() + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + await receive() diff --git a/tests/conftest.py b/tests/conftest.py index 4db3ae018..f67990cfe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,10 +5,23 @@ import pytest -from starlette.testclient import TestClient -from tests.types import TestClientFactory +from starlette.testclient import AsyncTestClient, TestClient +from tests.types import AsyncTestClientFactory, TestClientFactory +@pytest.fixture +def async_test_client_factory( + anyio_backend_name: Literal["asyncio", "trio"], + anyio_backend_options: dict[str, Any], +) -> AsyncTestClientFactory: + # anyio_backend_name defined by: + # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on + return functools.partial( + AsyncTestClient, + backend=anyio_backend_name, + backend_options=anyio_backend_options, + ) + @pytest.fixture def test_client_factory( anyio_backend_name: Literal["asyncio", "trio"], diff --git a/tests/test_asynctestclient.py b/tests/test_asynctestclient.py new file mode 100644 index 000000000..7fac538b7 --- /dev/null +++ b/tests/test_asynctestclient.py @@ -0,0 +1,435 @@ +from __future__ import annotations + +import itertools +import sys +from asyncio import Task, current_task as asyncio_current_task +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any + +import anyio +import anyio.lowlevel +import pytest +import sniffio +import trio.lowlevel + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.requests import Request +from starlette.responses import JSONResponse, RedirectResponse, Response +from starlette.routing import Route +from starlette.testclient import ASGIInstance, AsyncTestClient +from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.websockets import WebSocket, WebSocketDisconnect +from tests.types import AsyncTestClientFactory + + +def mock_service_endpoint(request: Request) -> JSONResponse: + return JSONResponse({"mock": "example"}) + + +mock_service = Starlette(routes=[Route("/", endpoint=mock_service_endpoint)]) + + +def current_task() -> Task[Any] | trio.lowlevel.Task: + # anyio's TaskInfo comparisons are invalid after their associated native + # task object is GC'd https://github.com/agronholm/anyio/issues/324 + asynclib_name = sniffio.current_async_library() + if asynclib_name == "trio": + return trio.lowlevel.current_task() + + if asynclib_name == "asyncio": + task = asyncio_current_task() + if task is None: + raise RuntimeError("must be called from a running task") # pragma: no cover + return task + raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover + + +def startup() -> None: + raise RuntimeError() + + +async def test_use_testclient_in_endpoint(async_test_client_factory: AsyncTestClientFactory) -> None: + """ + We should be able to use the test client within applications. + + This is useful if we need to mock out other services, + during tests or in development. + """ + + async def homepage(request: Request) -> JSONResponse: + client = async_test_client_factory(mock_service) + response = await client.get("/") + return JSONResponse(response.json()) + + app = Starlette(routes=[Route("/", endpoint=homepage)]) + + client = async_test_client_factory(app) + response = await client.get("/") + assert response.json() == {"mock": "example"} + + +def test_testclient_headers_behavior() -> None: + """ + We should be able to use the test client with user defined headers. + + This is useful if we need to set custom headers for authentication + during tests or in development. + """ + + client = AsyncTestClient(mock_service) + assert client.headers.get("user-agent") == "testclient" + + client = AsyncTestClient(mock_service, headers={"user-agent": "non-default-agent"}) + assert client.headers.get("user-agent") == "non-default-agent" + + client = AsyncTestClient(mock_service, headers={"Authentication": "Bearer 123"}) + assert client.headers.get("user-agent") == "testclient" + assert client.headers.get("Authentication") == "Bearer 123" + + +async def test_use_testclient_as_contextmanager(async_test_client_factory: AsyncTestClientFactory, anyio_backend_name: str) -> None: + """ + This test asserts a number of properties that are important for an + app level task_group + """ + counter = itertools.count() + identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar") + + def get_identity() -> int: + try: + return identity_runvar.get() + except LookupError: + token = next(counter) + identity_runvar.set(token) + return token + + startup_task = object() + startup_loop = None + shutdown_task = object() + shutdown_loop = None + + @asynccontextmanager + async def lifespan_context(app: Starlette) -> AsyncGenerator[None, None]: + nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop + + startup_task = current_task() + startup_loop = get_identity() + async with anyio.create_task_group(): + yield + shutdown_task = current_task() + shutdown_loop = get_identity() + + async def loop_id(request: Request) -> JSONResponse: + return JSONResponse(get_identity()) + + app = Starlette( + lifespan=lifespan_context, + routes=[Route("/loop_id", endpoint=loop_id)], + ) + + client = async_test_client_factory(app) + + async with client: + # within a TestClient context every async request runs in the same thread + assert (await client.get("/loop_id")).json() == 0 + assert (await client.get("/loop_id")).json() == 0 + + # that thread is also the same as the lifespan thread + assert startup_loop == 0 + assert shutdown_loop == 0 + + # lifespan events run in the same task, this is important because a task + # group must be entered and exited in the same task. + assert startup_task is shutdown_task + + # outside the TestClient context, new requests continue to spawn in new + # event loops in new threads + assert (await client.get("/loop_id")).json() == 0 + assert (await client.get("/loop_id")).json() == 0 + + first_task = startup_task + + async with client: + # the TestClient context can be re-used, starting a new lifespan task + # in a new thread + assert (await client.get("/loop_id")).json() == 0 + assert (await client.get("/loop_id")).json() == 0 + + assert startup_loop == 0 + assert shutdown_loop == 0 + + # lifespan events still run in the same task, with the context but... + assert startup_task is shutdown_task + + # ... the second TestClient context creates a new lifespan task. + assert first_task is not startup_task + + +async def test_error_on_startup(async_test_client_factory: AsyncTestClientFactory) -> None: + with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): + startup_error_app = Starlette(on_startup=[startup]) + + with pytest.raises(ExceptionGroup) as excinfo: + async with async_test_client_factory(startup_error_app): + pass # pragma: no cover + + assert excinfo.group_contains(RuntimeError) + + +async def test_exception_in_middleware(async_test_client_factory: AsyncTestClientFactory) -> None: + class MiddlewareException(Exception): + pass + + class BrokenMiddleware: + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + raise MiddlewareException() + + broken_middleware = Starlette(middleware=[Middleware(BrokenMiddleware)]) + + with pytest.raises(ExceptionGroup) as excinfo: + async with async_test_client_factory(broken_middleware): + pass # pragma: no cover + + assert excinfo.group_contains(MiddlewareException) + + +async def test_testclient_asgi2(async_test_client_factory: AsyncTestClientFactory) -> None: + def app(scope: Scope) -> ASGIInstance: + async def inner(receive: Receive, send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + + return inner + + client = async_test_client_factory(app) # type: ignore + response = await client.get("/") + assert response.text == "Hello, world!" + + +async def test_testclient_asgi3(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"text/plain"]], + } + ) + await send({"type": "http.response.body", "body": b"Hello, world!"}) + + client = async_test_client_factory(app) + response = await client.get("/") + assert response.text == "Hello, world!" + + +async def test_websocket_blocking_receive(async_test_client_factory: AsyncTestClientFactory) -> None: + def app(scope: Scope) -> ASGIInstance: + async def respond(websocket: WebSocket) -> None: + await websocket.send_json({"message": "test"}) + + async def asgi(receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + async with anyio.create_task_group() as task_group: + task_group.start_soon(respond, websocket) + try: + # this will block as the client does not send us data + # it should not prevent `respond` from executing though + await websocket.receive_json() + except WebSocketDisconnect: + pass + + return asgi + + client = async_test_client_factory(app) # type: ignore + async with await client.websocket_connect("/") as websocket: + data = await websocket.receive_json() + assert data == {"message": "test"} + + +async def test_websocket_not_block_on_close(async_test_client_factory: AsyncTestClientFactory) -> None: + cancelled = False + + def app(scope: Scope) -> ASGIInstance: + async def asgi(receive: Receive, send: Send) -> None: + nonlocal cancelled + try: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + await anyio.sleep_forever() + except anyio.get_cancelled_exc_class(): + cancelled = True + raise + + return asgi + + client = async_test_client_factory(app) # type: ignore + async with await client.websocket_connect("/"): + ... + assert cancelled + + +async def test_client(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + client = scope.get("client") + assert client is not None + host, port = client + response = JSONResponse({"host": host, "port": port}) + await response(scope, receive, send) + + client = async_test_client_factory(app) + response = await client.get("/") + assert response.json() == {"host": "testclient", "port": 50000} + + +async def test_client_custom_client(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + client = scope.get("client") + assert client is not None + host, port = client + response = JSONResponse({"host": host, "port": port}) + await response(scope, receive, send) + + client = async_test_client_factory(app, client=("192.168.0.1", 3000)) + response = await client.get("/") + assert response.json() == {"host": "192.168.0.1", "port": 3000} + + +@pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà")) +async def test_query_params(async_test_client_factory: AsyncTestClientFactory, param: str) -> None: + def homepage(request: Request) -> Response: + return Response(request.query_params["param"]) + + app = Starlette(routes=[Route("/", endpoint=homepage)]) + client = async_test_client_factory(app) + response = await client.get("/", params={"param": param}) + assert response.text == param + + +@pytest.mark.parametrize( + "domain, ok", + [ + pytest.param( + "testserver", + True, + marks=[ + pytest.mark.xfail( + sys.version_info < (3, 11), + reason="Fails due to domain handling in http.cookiejar module (see #2152)", + ), + ], + ), + ("testserver.local", True), + ("localhost", False), + ("example.com", False), + ], +) +async def test_domain_restricted_cookies(async_test_client_factory: AsyncTestClientFactory, domain: str, ok: bool) -> None: + """ + Test that test client discards domain restricted cookies which do not match the + base_url of the testclient (`http://testserver` by default). + + The domain `testserver.local` works because the Python http.cookiejar module derives + the "effective domain" by appending `.local` to non-dotted request domains + in accordance with RFC 2965. + """ + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + response = Response("Hello, world!", media_type="text/plain") + response.set_cookie( + "mycookie", + "myvalue", + path="/", + domain=domain, + ) + await response(scope, receive, send) + + client = async_test_client_factory(app) + response = await client.get("/") + cookie_set = len(response.cookies) == 1 + assert cookie_set == ok + + +async def test_forward_follow_redirects(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + if "/ok" in scope["path"]: + response = Response("ok") + else: + response = RedirectResponse("/ok") + await response(scope, receive, send) + + client = async_test_client_factory(app, follow_redirects=True) + response = await client.get("/") + assert response.status_code == 200 + + +async def test_forward_nofollow_redirects(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + response = RedirectResponse("/ok") + await response(scope, receive, send) + + client = async_test_client_factory(app, follow_redirects=False) + response = await client.get("/") + assert response.status_code == 307 + + +async def test_with_duplicate_headers(async_test_client_factory: AsyncTestClientFactory) -> None: + def homepage(request: Request) -> JSONResponse: + return JSONResponse({"x-token": request.headers.getlist("x-token")}) + + app = Starlette(routes=[Route("/", endpoint=homepage)]) + client = async_test_client_factory(app) + response = await client.get("/", headers=[("x-token", "foo"), ("x-token", "bar")]) + assert response.json() == {"x-token": ["foo", "bar"]} + + +async def test_merge_url(async_test_client_factory: AsyncTestClientFactory) -> None: + def homepage(request: Request) -> Response: + return Response(request.url.path) + + app = Starlette(routes=[Route("/api/v1/bar", endpoint=homepage)]) + client = async_test_client_factory(app, base_url="http://testserver/api/v1/") + response = await client.get("/bar") + assert response.text == "/api/v1/bar" + + +async def test_raw_path_with_querystring(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + response = Response(scope.get("raw_path")) + await response(scope, receive, send) + + client = async_test_client_factory(app) + response = await client.get("/hello-world", params={"foo": "bar"}) + assert response.content == b"/hello-world" + + +async def test_websocket_raw_path_without_params(async_test_client_factory: AsyncTestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + websocket = WebSocket(scope, receive=receive, send=send) + await websocket.accept() + raw_path = scope.get("raw_path") + assert raw_path is not None + await websocket.send_bytes(raw_path) + + client = async_test_client_factory(app) + async with await client.websocket_connect("/hello-world", params={"foo": "bar"}) as websocket: + data = await websocket.receive_bytes() + assert data == b"/hello-world" + + +@pytest.mark.anyio +async def test_timeout_deprecation() -> None: + with pytest.deprecated_call(match="You should not use the 'timeout' argument with the TestClient."): + client = AsyncTestClient(mock_service) + await client.get("/", timeout=1) diff --git a/tests/types.py b/tests/types.py index e4769d308..f0a3b44ed 100644 --- a/tests/types.py +++ b/tests/types.py @@ -4,11 +4,24 @@ import httpx -from starlette.testclient import TestClient +from starlette.testclient import AsyncTestClient, TestClient from starlette.types import ASGIApp if TYPE_CHECKING: + class AsyncTestClientFactory(Protocol): # pragma: no cover + def __call__( + self, + app: ASGIApp, + base_url: str = "http://testserver", + raise_server_exceptions: bool = True, + root_path: str = "", + cookies: httpx._types.CookieTypes | None = None, + headers: dict[str, str] | None = None, + follow_redirects: bool = True, + client: tuple[str, int] = ("testclient", 50000), + ) -> AsyncTestClient: ... + class TestClientFactory(Protocol): # pragma: no cover def __call__( self, @@ -23,5 +36,8 @@ def __call__( ) -> TestClient: ... else: # pragma: no cover + class AsyncTestClientFactory: + __test__ = False + class TestClientFactory: __test__ = False From e252ead923a2a930b2c115470715cd7b0c5cd250 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 15 May 2025 11:05:05 +0200 Subject: [PATCH 2/9] lint and types --- starlette/testclient.py | 8 ++++---- tests/conftest.py | 1 + tests/test_asynctestclient.py | 8 ++++++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 80581e58b..3564e8d7e 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -766,7 +766,7 @@ async def run(*, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None: self.extra_headers = message.get("headers", None) stack.push_async_callback(self.aclose, 1000) self.exit_stack = stack.pop_all() - return self + return self async def __aexit__(self, *args: typing.Any) -> bool | None: return await self.exit_stack.__aexit__(*args) @@ -1113,7 +1113,7 @@ async def get( # type: ignore[override] extensions=extensions, ) - def options( # type: ignore[override] + async def options( # type: ignore[override] self, url: httpx._types.URLTypes, *, @@ -1125,7 +1125,7 @@ def options( # type: ignore[override] timeout: httpx._types.TimeoutTypes | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT, extensions: dict[str, typing.Any] | None = None, ) -> httpx.Response: - return super().options( + return await super().options( url, params=params, headers=headers, @@ -1313,7 +1313,7 @@ async def __aenter__(self) -> AsyncTestClient: self.stream_receive = StapledObjectStream(*receive) self.task_done = anyio.Event() - async def lifespan(): + async def lifespan() -> None: await self.lifespan() self.task_done.set() diff --git a/tests/conftest.py b/tests/conftest.py index f67990cfe..021eabce2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,7 @@ def async_test_client_factory( backend_options=anyio_backend_options, ) + @pytest.fixture def test_client_factory( anyio_backend_name: Literal["asyncio", "trio"], diff --git a/tests/test_asynctestclient.py b/tests/test_asynctestclient.py index 7fac538b7..6017a629d 100644 --- a/tests/test_asynctestclient.py +++ b/tests/test_asynctestclient.py @@ -89,7 +89,9 @@ def test_testclient_headers_behavior() -> None: assert client.headers.get("Authentication") == "Bearer 123" -async def test_use_testclient_as_contextmanager(async_test_client_factory: AsyncTestClientFactory, anyio_backend_name: str) -> None: +async def test_use_testclient_as_contextmanager( + async_test_client_factory: AsyncTestClientFactory, anyio_backend_name: str +) -> None: """ This test asserts a number of properties that are important for an app level task_group @@ -335,7 +337,9 @@ def homepage(request: Request) -> Response: ("example.com", False), ], ) -async def test_domain_restricted_cookies(async_test_client_factory: AsyncTestClientFactory, domain: str, ok: bool) -> None: +async def test_domain_restricted_cookies( + async_test_client_factory: AsyncTestClientFactory, domain: str, ok: bool +) -> None: """ Test that test client discards domain restricted cookies which do not match the base_url of the testclient (`http://testserver` by default). From 6403b576d285d481bc30729ce41ecfe68a26af51 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 15 May 2025 11:26:25 +0200 Subject: [PATCH 3/9] Support ExceptionGroup for Python <3.11 --- requirements.txt | 1 + tests/test_asynctestclient.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/requirements.txt b/requirements.txt index 01c2016c1..a6078076f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ types-PyYAML==6.0.12.20250402 types-dataclasses==0.6.6 pytest==8.3.5 trio==0.30.0 +exceptiongroup; python_version<'3.11' # Documentation black==25.1.0 diff --git a/tests/test_asynctestclient.py b/tests/test_asynctestclient.py index 6017a629d..cba73b449 100644 --- a/tests/test_asynctestclient.py +++ b/tests/test_asynctestclient.py @@ -23,6 +23,9 @@ from starlette.websockets import WebSocket, WebSocketDisconnect from tests.types import AsyncTestClientFactory +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + def mock_service_endpoint(request: Request) -> JSONResponse: return JSONResponse({"mock": "example"}) From 64710c92380295bdf2221dbc03f990d97f2800d2 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 15 May 2025 11:49:32 +0200 Subject: [PATCH 4/9] Ruff requires python >=3.11 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 02a3820fc..2f66ccec5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,6 +51,7 @@ path = "starlette/__init__.py" [tool.ruff] line-length = 120 +requires-python = ">=3.11" [tool.ruff.lint] select = [ From 9761e7c43e6a8a31779b1d416a6940e299821111 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 15 May 2025 11:54:45 +0200 Subject: [PATCH 5/9] Ruff target-version py311 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2f66ccec5..6a70269ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,7 @@ path = "starlette/__init__.py" [tool.ruff] line-length = 120 -requires-python = ">=3.11" +target-version = "py311" [tool.ruff.lint] select = [ From e7100f43d5f756f30ab8bcfe64cf5553563ccf9a Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 15 May 2025 11:58:29 +0200 Subject: [PATCH 6/9] revert --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6a70269ad..02a3820fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ path = "starlette/__init__.py" [tool.ruff] line-length = 120 -target-version = "py311" [tool.ruff.lint] select = [ From c623f6ada1277a83f7523b476cf6e3e7e96b8fc7 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 15 May 2025 12:05:03 +0200 Subject: [PATCH 7/9] Temporarily don't test python 3.14 --- .github/workflows/test-suite.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 17d6b3e41..4e9f1b70a 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: "actions/checkout@v4" From fc6b132323d4b9ee52ec24e58ffe622d24e869f4 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 15 May 2025 14:21:02 +0200 Subject: [PATCH 8/9] Remove backend parameter --- starlette/testclient.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 3564e8d7e..ba10a72e7 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -1017,14 +1017,11 @@ def __init__( base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", - backend: typing.Literal["asyncio", "trio"] = "asyncio", - backend_options: dict[str, typing.Any] | None = None, cookies: httpx._types.CookieTypes | None = None, headers: dict[str, str] | None = None, follow_redirects: bool = True, client: tuple[str, int] = ("testclient", 50000), ) -> None: - self.async_backend = _AsyncBackend(backend=backend, backend_options=backend_options or {}) if _is_asgi3(app): asgi_app = app else: From f476e280d452d5c338ce681ae74976b7e3bd5300 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 15 May 2025 14:34:10 +0200 Subject: [PATCH 9/9] Fix tests --- tests/conftest.py | 9 +-------- tests/test_asynctestclient.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 021eabce2..87412e66e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,16 +10,9 @@ @pytest.fixture -def async_test_client_factory( - anyio_backend_name: Literal["asyncio", "trio"], - anyio_backend_options: dict[str, Any], -) -> AsyncTestClientFactory: - # anyio_backend_name defined by: - # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on +def async_test_client_factory() -> AsyncTestClientFactory: return functools.partial( AsyncTestClient, - backend=anyio_backend_name, - backend_options=anyio_backend_options, ) diff --git a/tests/test_asynctestclient.py b/tests/test_asynctestclient.py index cba73b449..e5d7b210b 100644 --- a/tests/test_asynctestclient.py +++ b/tests/test_asynctestclient.py @@ -53,6 +53,7 @@ def startup() -> None: raise RuntimeError() +@pytest.mark.anyio async def test_use_testclient_in_endpoint(async_test_client_factory: AsyncTestClientFactory) -> None: """ We should be able to use the test client within applications. @@ -172,6 +173,7 @@ async def loop_id(request: Request) -> JSONResponse: assert first_task is not startup_task +@pytest.mark.anyio async def test_error_on_startup(async_test_client_factory: AsyncTestClientFactory) -> None: with pytest.deprecated_call(match="The on_startup and on_shutdown parameters are deprecated"): startup_error_app = Starlette(on_startup=[startup]) @@ -183,6 +185,7 @@ async def test_error_on_startup(async_test_client_factory: AsyncTestClientFactor assert excinfo.group_contains(RuntimeError) +@pytest.mark.anyio async def test_exception_in_middleware(async_test_client_factory: AsyncTestClientFactory) -> None: class MiddlewareException(Exception): pass @@ -203,6 +206,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: assert excinfo.group_contains(MiddlewareException) +@pytest.mark.anyio async def test_testclient_asgi2(async_test_client_factory: AsyncTestClientFactory) -> None: def app(scope: Scope) -> ASGIInstance: async def inner(receive: Receive, send: Send) -> None: @@ -222,6 +226,7 @@ async def inner(receive: Receive, send: Send) -> None: assert response.text == "Hello, world!" +@pytest.mark.anyio async def test_testclient_asgi3(async_test_client_factory: AsyncTestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: await send( @@ -238,6 +243,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.text == "Hello, world!" +@pytest.mark.anyio async def test_websocket_blocking_receive(async_test_client_factory: AsyncTestClientFactory) -> None: def app(scope: Scope) -> ASGIInstance: async def respond(websocket: WebSocket) -> None: @@ -263,6 +269,7 @@ async def asgi(receive: Receive, send: Send) -> None: assert data == {"message": "test"} +@pytest.mark.anyio async def test_websocket_not_block_on_close(async_test_client_factory: AsyncTestClientFactory) -> None: cancelled = False @@ -285,6 +292,7 @@ async def asgi(receive: Receive, send: Send) -> None: assert cancelled +@pytest.mark.anyio async def test_client(async_test_client_factory: AsyncTestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: client = scope.get("client") @@ -298,6 +306,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.json() == {"host": "testclient", "port": 50000} +@pytest.mark.anyio async def test_client_custom_client(async_test_client_factory: AsyncTestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: client = scope.get("client") @@ -311,6 +320,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.json() == {"host": "192.168.0.1", "port": 3000} +@pytest.mark.anyio @pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà")) async def test_query_params(async_test_client_factory: AsyncTestClientFactory, param: str) -> None: def homepage(request: Request) -> Response: @@ -322,6 +332,7 @@ def homepage(request: Request) -> Response: assert response.text == param +@pytest.mark.anyio @pytest.mark.parametrize( "domain, ok", [ @@ -368,6 +379,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert cookie_set == ok +@pytest.mark.anyio async def test_forward_follow_redirects(async_test_client_factory: AsyncTestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: if "/ok" in scope["path"]: @@ -381,6 +393,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.status_code == 200 +@pytest.mark.anyio async def test_forward_nofollow_redirects(async_test_client_factory: AsyncTestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = RedirectResponse("/ok") @@ -391,6 +404,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.status_code == 307 +@pytest.mark.anyio async def test_with_duplicate_headers(async_test_client_factory: AsyncTestClientFactory) -> None: def homepage(request: Request) -> JSONResponse: return JSONResponse({"x-token": request.headers.getlist("x-token")}) @@ -401,6 +415,7 @@ def homepage(request: Request) -> JSONResponse: assert response.json() == {"x-token": ["foo", "bar"]} +@pytest.mark.anyio async def test_merge_url(async_test_client_factory: AsyncTestClientFactory) -> None: def homepage(request: Request) -> Response: return Response(request.url.path) @@ -411,6 +426,7 @@ def homepage(request: Request) -> Response: assert response.text == "/api/v1/bar" +@pytest.mark.anyio async def test_raw_path_with_querystring(async_test_client_factory: AsyncTestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: response = Response(scope.get("raw_path")) @@ -421,6 +437,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.content == b"/hello-world" +@pytest.mark.anyio async def test_websocket_raw_path_without_params(async_test_client_factory: AsyncTestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send)