diff --git a/tests/test_server.py b/tests/test_server.py index ad82cf139..e0e5b8b35 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,6 +2,8 @@ import asyncio import contextlib +import contextvars +import json import logging import signal import sys @@ -12,12 +14,14 @@ import httpx import pytest +from tests.protocols.test_http import SIMPLE_GET_REQUEST from tests.utils import run_server -from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope +from uvicorn import Server +from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Scope from uvicorn.config import Config +from uvicorn.protocols.http.flow_control import HIGH_WATER_LIMIT from uvicorn.protocols.http.h11_impl import H11Protocol from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol -from uvicorn.server import Server pytestmark = pytest.mark.anyio @@ -95,3 +99,81 @@ async def test_request_than_limit_max_requests_warn_log( responses = await asyncio.gather(*tasks) assert len(responses) == 2 assert "Maximum request limit of 1 exceeded. Terminating process." in caplog.text + + +@contextlib.asynccontextmanager +async def server(*, app: ASGIApplication, port: int, http_protocol_cls: type[H11Protocol | HttpToolsProtocol]): + config = Config(app=app, port=port, loop="asyncio", http=http_protocol_cls) + server = Server(config=config) + task = asyncio.create_task(server.serve()) + + while not server.started: + await asyncio.sleep(0.01) + + reader, writer = await asyncio.open_connection("127.0.0.1", port) + + async def extract_json_body(request: bytes): + writer.write(request) + await writer.drain() + + status, *headers = (await reader.readuntil(b"\r\n\r\n")).split(b"\r\n")[:-2] + assert status == b"HTTP/1.1 200 OK" + + content_length = next(int(h.split(b":", 1)[1]) for h in headers if h.lower().startswith(b"content-length:")) + return json.loads(await reader.readexactly(content_length)) + + try: + yield extract_json_body + finally: + writer.close() + await writer.wait_closed() + server.should_exit = True + await task + + +async def test_no_contextvars_pollution_asyncio( + http_protocol_cls: type[H11Protocol | HttpToolsProtocol], unused_tcp_port: int +): + """Non-regression test for https://github.com/encode/uvicorn/issues/2167.""" + default_contextvars = {c.name for c in contextvars.copy_context().keys()} + ctx: contextvars.ContextVar[str] = contextvars.ContextVar("ctx") + + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): + assert scope["type"] == "http" + + # initial context should be empty + initial_context = { + n: v for c, v in contextvars.copy_context().items() if (n := c.name) not in default_contextvars + } + # set any contextvar before the body is read + ctx.set(scope["path"]) + + while True: + message = await receive() + assert message["type"] == "http.request" + if not message["more_body"]: + break + + # return the initial context for empty assertion + body = json.dumps(initial_context).encode("utf-8") + headers = [(b"content-type", b"application/json"), (b"content-length", str(len(body)).encode("utf-8"))] + await send({"type": "http.response.start", "status": 200, "headers": headers}) + await send({"type": "http.response.body", "body": body}) + + # body has to be larger than HIGH_WATER_LIMIT to trigger a reading pause on the main thread + # and a resumption inside the ASGI task + large_body = b"a" * (HIGH_WATER_LIMIT + 1) + large_request = b"\r\n".join( + [ + b"POST /large-body HTTP/1.1", + b"Host: example.org", + b"Content-Type: application/octet-stream", + f"Content-Length: {len(large_body)}".encode(), + b"", + large_body, + ] + ) + + async with server(app=app, http_protocol_cls=http_protocol_cls, port=unused_tcp_port) as extract_json_body: + assert await extract_json_body(large_request) == {} + assert await extract_json_body(SIMPLE_GET_REQUEST) == {} diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index b8cdde3ab..0c3ef1fa5 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextvars import http import logging from typing import Any, Callable, Literal, cast @@ -247,7 +248,12 @@ def handle_events(self) -> None: message_event=asyncio.Event(), on_response=self.on_response_complete, ) - task = self.loop.create_task(self.cycle.run_asgi(app)) + # For the asyncio loop, we need to explicitly start with an empty context + # as it can be polluted from previous ASGI runs. + # See https://github.com/python/cpython/issues/140947 for details. + task = contextvars.Context().run(self.loop.create_task, self.cycle.run_asgi(app)) + # TODO: Replace the line above with the line below for Python >= 3.11 + # task = self.loop.create_task(self.cycle.run_asgi(app), context=contextvars.Context()) task.add_done_callback(self.tasks.discard) self.tasks.add(task) diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index e8795ed35..dba5eedce 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextvars import http import logging import re @@ -287,7 +288,12 @@ def on_headers_complete(self) -> None: ) if existing_cycle is None or existing_cycle.response_complete: # Standard case - start processing the request. - task = self.loop.create_task(self.cycle.run_asgi(app)) + # For the asyncio loop, we need to explicitly start with an empty context + # as it can be polluted from previous ASGI runs. + # See https://github.com/python/cpython/issues/140947 for details. + task = contextvars.Context().run(self.loop.create_task, self.cycle.run_asgi(app)) + # TODO: Replace the line above with the line below for Python >= 3.11 + # task = self.loop.create_task(self.cycle.run_asgi(app), context=contextvars.Context()) task.add_done_callback(self.tasks.discard) self.tasks.add(task) else: