diff --git a/docs/settings.md b/docs/settings.md index 8c9ab8e88..40dd66033 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -39,6 +39,7 @@ uvicorn itself. * `APP` - The ASGI application to run, in the format `"<module>:<attribute>"`. * `--factory` - Treat `APP` as an application factory, i.e. a `() -> <ASGI app>` callable. * `--app-dir <path>` - Look for APP in the specified directory by adding it to the PYTHONPATH. **Default:** *Current working directory*. +* `--reset-contextvars` - Run each ASGI request in a fresh `contextvars.Context`. Workaround for a [context leak in asyncio](https://github.com/python/cpython/issues/140947); only relevant when using the `asyncio` event loop (uvloop is not affected). Enabling this hides any context set in the lifespan or by external instrumentation from ASGI handlers. **Default:** *False*. ## Socket Binding diff --git a/tests/test_server.py b/tests/test_server.py index 46a6e2a34..f57c1f236 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -2,6 +2,8 @@ import asyncio import contextlib +import contextvars +import json import logging import signal import sys @@ -11,9 +13,11 @@ import httpx import pytest +from tests.protocols.test_http import SIMPLE_GET_REQUEST from tests.utils import run_server -from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope +from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Scope from uvicorn.config import Config +from uvicorn.protocols.http.flow_control import HIGH_WATER_LIMIT from uvicorn.protocols.http.h11_impl import H11Protocol from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol from uvicorn.server import Server @@ -150,3 +154,114 @@ async def test_limit_max_requests_jitter( await client.get(f"http://127.0.0.1:{unused_tcp_port}") await task assert f"Maximum request limit of {limit} exceeded. Terminating process." 
@contextlib.asynccontextmanager
async def _raw_server(
    *,
    app: ASGIApplication,
    port: int,
    http_protocol_cls: type[H11Protocol | HttpToolsProtocol],
    reset_contextvars: bool = False,
):
    """Run a uvicorn server on the asyncio loop and yield a raw-socket request helper.

    The server is started as a background task on ``port`` using
    ``http_protocol_cls``, then one TCP connection to it is opened. The yielded
    coroutine function writes raw request bytes on that connection and returns
    the JSON-decoded response body, so every request made through it shares
    that single connection.

    On exit the connection is closed and the server is asked to stop. The
    server is also stopped when startup or the connection attempt fails.
    """
    config = Config(app=app, port=port, loop="asyncio", http=http_protocol_cls, reset_contextvars=reset_contextvars)
    server = Server(config=config)
    task = asyncio.create_task(server.serve())

    try:
        while not server.started:
            if task.done():
                # The server stopped before it ever reported ready: surface its
                # error (if any) instead of polling forever.
                task.result()
                raise RuntimeError("uvicorn server exited before startup completed")
            await asyncio.sleep(0.01)

        reader, writer = await asyncio.open_connection("127.0.0.1", port)
        try:

            async def extract_json_body(request: bytes):
                """Send ``request`` and return the JSON-decoded body of the 200 response."""
                writer.write(request)
                await writer.drain()

                # Splitting the response head on CRLF leaves two empty trailing
                # items for the terminating blank line, hence the ``[:-2]``.
                status, *headers = (await reader.readuntil(b"\r\n\r\n")).split(b"\r\n")[:-2]
                assert status == b"HTTP/1.1 200 OK"

                content_length = next(
                    int(h.split(b":", 1)[1]) for h in headers if h.lower().startswith(b"content-length:")
                )
                return json.loads(await reader.readexactly(content_length))

            yield extract_json_body
        finally:
            writer.close()
            await writer.wait_closed()
    finally:
        # Always stop the server, including when startup or connecting failed.
        server.should_exit = True
        await task


async def _drain_request_body(receive: ASGIReceiveCallable) -> None:
    """Consume ``http.request`` messages until the final body chunk has arrived."""
    while True:
        message = await receive()
        assert message["type"] == "http.request"
        if not message["more_body"]:
            break


async def _send_json(send: ASGISendCallable, payload: object) -> None:
    """Send ``payload`` as a complete ``200`` JSON response with a content-length header."""
    body = json.dumps(payload).encode("utf-8")
    headers = [(b"content-type", b"application/json"), (b"content-length", str(len(body)).encode("utf-8"))]
    await send({"type": "http.response.start", "status": 200, "headers": headers})
    await send({"type": "http.response.body", "body": body})


async def test_contextvars_preserved_by_default(
    http_protocol_cls: type[H11Protocol | HttpToolsProtocol], unused_tcp_port: int
):
    """By default, context set outside the ASGI task is visible inside it."""
    ctx: contextvars.ContextVar[str] = contextvars.ContextVar("ctx")
    ctx.set("outer-value")

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert scope["type"] == "http"
        await _drain_request_body(receive)
        # Report what the ASGI task sees for the variable set by the test itself.
        await _send_json(send, {"ctx": ctx.get("MISSING")})

    async with _raw_server(app=app, http_protocol_cls=http_protocol_cls, port=unused_tcp_port) as extract_json_body:
        assert await extract_json_body(SIMPLE_GET_REQUEST) == {"ctx": "outer-value"}


async def test_reset_contextvars_asyncio(
    http_protocol_cls: type[H11Protocol | HttpToolsProtocol], unused_tcp_port: int
):
    """With reset_contextvars=True, each ASGI run starts with a fresh context.

    Non-regression test for https://github.com/encode/uvicorn/issues/2167.
    """
    # Variables already present before the server runs are not counted as leaks.
    default_contextvars = {var.name for var in contextvars.copy_context()}
    ctx: contextvars.ContextVar[str] = contextvars.ContextVar("ctx")

    async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
        assert scope["type"] == "http"

        # initial context should be empty
        initial_context = {
            var.name: value
            for var, value in contextvars.copy_context().items()
            if var.name not in default_contextvars
        }
        # set any contextvar before the body is read
        ctx.set(scope["path"])

        await _drain_request_body(receive)
        await _send_json(send, initial_context)

    # body larger than HIGH_WATER_LIMIT forces a reading pause on the main thread
    # and a resumption inside the ASGI task, which is where the original pollution showed up.
    large_body = b"a" * (HIGH_WATER_LIMIT + 1)
    large_request = b"\r\n".join(
        [
            b"POST /large-body HTTP/1.1",
            b"Host: example.org",
            b"Content-Type: application/octet-stream",
            f"Content-Length: {len(large_body)}".encode(),
            b"",
            large_body,
        ]
    )

    async with _raw_server(
        app=app, http_protocol_cls=http_protocol_cls, port=unused_tcp_port, reset_contextvars=True
    ) as extract_json_body:
        # The second request reuses the connection: nothing set while handling
        # the first request may be visible to it.
        assert await extract_json_body(large_request) == {}
        assert await extract_json_body(SIMPLE_GET_REQUEST) == {}
Hides context set in the lifespan.", + show_default=True, +) @click.option( "--factory", is_flag=True, @@ -428,6 +435,7 @@ def main( use_colors: bool, app_dir: str, h11_max_incomplete_event_size: int | None, + reset_contextvars: bool, factory: bool, ) -> None: run( @@ -480,6 +488,7 @@ def main( factory=factory, app_dir=app_dir, h11_max_incomplete_event_size=h11_max_incomplete_event_size, + reset_contextvars=reset_contextvars, ) @@ -534,6 +543,7 @@ def run( app_dir: str | None = None, factory: bool = False, h11_max_incomplete_event_size: int | None = None, + reset_contextvars: bool = False, ) -> None: if app_dir is not None: sys.path.insert(0, app_dir) @@ -587,6 +597,7 @@ def run( use_colors=use_colors, factory=factory, h11_max_incomplete_event_size=h11_max_incomplete_event_size, + reset_contextvars=reset_contextvars, ) server = Server(config=config) diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index 9c78e0280..3f660cae5 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -1,8 +1,10 @@ from __future__ import annotations import asyncio +import contextvars import http import logging +import sys from collections.abc import Callable from typing import Any, Literal from urllib.parse import unquote @@ -248,7 +250,16 @@ def handle_events(self) -> None: message_event=asyncio.Event(), on_response=self.on_response_complete, ) - task = self.loop.create_task(self.cycle.run_asgi(app)) + if self.config.reset_contextvars: + # Opt-in workaround for https://github.com/python/cpython/issues/140947: + # asyncio can leak context vars between tasks. Hides context set in the + # lifespan or by external instrumentation. 
def _start_asgi_task(self, cycle: RequestResponseCycle, app: ASGI3Application) -> None:
    """Schedule ``cycle.run_asgi(app)`` as a task on the loop and track it in ``self.tasks``.

    When ``config.reset_contextvars`` is enabled the task is started in an
    empty ``contextvars.Context``. This is an opt-in workaround for
    https://github.com/python/cpython/issues/140947 (asyncio can leak context
    vars between tasks); it hides context set in the lifespan or by external
    instrumentation.
    """
    coroutine = cycle.run_asgi(app)
    if not self.config.reset_contextvars:
        task = self.loop.create_task(coroutine)
    elif sys.version_info >= (3, 11):  # pragma: py-lt-311
        task = self.loop.create_task(coroutine, context=contextvars.Context())
    else:  # pragma: py-gte-311
        # ``create_task`` has no ``context`` argument before 3.11, so create the
        # task from inside a fresh, empty context instead.
        task = contextvars.Context().run(self.loop.create_task, coroutine)
    # The task removes itself from the tracking set once it completes.
    task.add_done_callback(self.tasks.discard)
    self.tasks.add(task)