diff --git a/tests/benchmarks/http.py b/tests/benchmarks/http.py index fb41547ba..39e81fa7a 100644 --- a/tests/benchmarks/http.py +++ b/tests/benchmarks/http.py @@ -131,7 +131,7 @@ def __init__(self) -> None: self._tasks: list[asyncio.Task[Any]] = [] self._later: list[MockTimerHandle] = [] - def create_task(self, coroutine: Any, **kwargs: Any) -> Any: + def create_task(self, coroutine: Any) -> Any: self._tasks.insert(0, coroutine) return MockTask() diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index 1ae45f40d..d80fad8b2 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -226,7 +226,7 @@ def __init__(self): self._tasks: list[asyncio.Task[Any]] = [] self._later: list[MockTimerHandle] = [] - def create_task(self, coroutine: Any, **kwargs: Any) -> Any: + def create_task(self, coroutine: Any) -> Any: self._tasks.insert(0, coroutine) return MockTask() diff --git a/tests/test_server.py b/tests/test_server.py index 0492cc33e..46a6e2a34 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,8 +2,6 @@ import asyncio import contextlib -import contextvars -import json import logging import signal import sys @@ -13,14 +11,12 @@ import httpx import pytest -from tests.protocols.test_http import SIMPLE_GET_REQUEST from tests.utils import run_server -from uvicorn import Server -from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Scope +from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope from uvicorn.config import Config -from uvicorn.protocols.http.flow_control import HIGH_WATER_LIMIT from uvicorn.protocols.http.h11_impl import H11Protocol from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol +from uvicorn.server import Server pytestmark = pytest.mark.anyio @@ -154,81 +150,3 @@ async def test_limit_max_requests_jitter( await client.get(f"http://127.0.0.1:{unused_tcp_port}") await task assert f"Maximum request limit of {limit} exceeded. Terminating process." in caplog.text - - -@contextlib.asynccontextmanager -async def server(*, app: ASGIApplication, port: int, http_protocol_cls: type[H11Protocol | HttpToolsProtocol]): - config = Config(app=app, port=port, loop="asyncio", http=http_protocol_cls) - server = Server(config=config) - task = asyncio.create_task(server.serve()) - - while not server.started: - await asyncio.sleep(0.01) - - reader, writer = await asyncio.open_connection("127.0.0.1", port) - - async def extract_json_body(request: bytes): - writer.write(request) - await writer.drain() - - status, *headers = (await reader.readuntil(b"\r\n\r\n")).split(b"\r\n")[:-2] - assert status == b"HTTP/1.1 200 OK" - - content_length = next(int(h.split(b":", 1)[1]) for h in headers if h.lower().startswith(b"content-length:")) - return json.loads(await reader.readexactly(content_length)) - - try: - yield extract_json_body - finally: - writer.close() - await writer.wait_closed() - server.should_exit = True - await task - - -async def test_no_contextvars_pollution_asyncio( - http_protocol_cls: type[H11Protocol | HttpToolsProtocol], unused_tcp_port: int -): - """Non-regression test for https://github.com/encode/uvicorn/issues/2167.""" - default_contextvars = {c.name for c in contextvars.copy_context().keys()} - ctx: contextvars.ContextVar[str] = contextvars.ContextVar("ctx") - - async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): - assert scope["type"] == "http" - - # initial context should be empty - initial_context = { - n: v for c, v in contextvars.copy_context().items() if (n := c.name) not in default_contextvars - } - # set any contextvar before the body is read - ctx.set(scope["path"]) - - while True: - message = await receive() - assert message["type"] == "http.request" - if not message["more_body"]: - break - - # return the initial context for empty assertion - body = json.dumps(initial_context).encode("utf-8") - headers = [(b"content-type", b"application/json"), (b"content-length", str(len(body)).encode("utf-8"))] - await send({"type": "http.response.start", "status": 200, "headers": headers}) - await send({"type": "http.response.body", "body": body}) - - # body has to be larger than HIGH_WATER_LIMIT to trigger a reading pause on the main thread - # and a resumption inside the ASGI task - large_body = b"a" * (HIGH_WATER_LIMIT + 1) - large_request = b"\r\n".join( - [ - b"POST /large-body HTTP/1.1", - b"Host: example.org", - b"Content-Type: application/octet-stream", - f"Content-Length: {len(large_body)}".encode(), - b"", - large_body, - ] - ) - - async with server(app=app, http_protocol_cls=http_protocol_cls, port=unused_tcp_port) as extract_json_body: - assert await extract_json_body(large_request) == {} - assert await extract_json_body(SIMPLE_GET_REQUEST) == {} diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 2ad140c8b..9c78e0280 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -1,10 +1,8 @@ from __future__ import annotations import asyncio -import contextvars import http import logging -import sys from collections.abc import Callable from typing import Any, Literal from urllib.parse import unquote @@ -250,13 +248,7 @@ def handle_events(self) -> None: message_event=asyncio.Event(), on_response=self.on_response_complete, ) - # For the asyncio loop, we need to explicitly start with an empty context - # as it can be polluted from previous ASGI runs. - # See https://github.com/python/cpython/issues/140947 for details. - if sys.version_info >= (3, 11): # pragma: py-lt-311 - task = self.loop.create_task(self.cycle.run_asgi(app), context=contextvars.Context()) - else: # pragma: py-gte-311 - task = contextvars.Context().run(self.loop.create_task, self.cycle.run_asgi(app)) + task = self.loop.create_task(self.cycle.run_asgi(app)) task.add_done_callback(self.tasks.discard) self.tasks.add(task) diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index a23c2fb63..99c575b74 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -1,11 +1,9 @@ from __future__ import annotations import asyncio -import contextvars import http import logging import re -import sys import urllib from asyncio.events import TimerHandle from collections import deque @@ -289,13 +287,7 @@ def on_headers_complete(self) -> None: ) if existing_cycle is None or existing_cycle.response_complete: # Standard case - start processing the request. - # For the asyncio loop, we need to explicitly start with an empty context - # as it can be polluted from previous ASGI runs. - # See https://github.com/python/cpython/issues/140947 for details. - if sys.version_info >= (3, 11): # pragma: py-lt-311 - task = self.loop.create_task(self.cycle.run_asgi(app), context=contextvars.Context()) - else: # pragma: py-gte-311 - task = contextvars.Context().run(self.loop.create_task, self.cycle.run_asgi(app)) + task = self.loop.create_task(self.cycle.run_asgi(app)) task.add_done_callback(self.tasks.discard) self.tasks.add(task) else: