Skip to content
86 changes: 84 additions & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import asyncio
import contextlib
import contextvars
import json
import logging
import signal
import sys
Expand All @@ -12,12 +14,14 @@
import httpx
import pytest

from tests.protocols.test_http import SIMPLE_GET_REQUEST
from tests.utils import run_server
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
from uvicorn import Server
from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Scope
from uvicorn.config import Config
from uvicorn.protocols.http.flow_control import HIGH_WATER_LIMIT
from uvicorn.protocols.http.h11_impl import H11Protocol
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
from uvicorn.server import Server

pytestmark = pytest.mark.anyio

Expand Down Expand Up @@ -95,3 +99,81 @@ async def test_request_than_limit_max_requests_warn_log(
responses = await asyncio.gather(*tasks)
assert len(responses) == 2
assert "Maximum request limit of 1 exceeded. Terminating process." in caplog.text


@contextlib.asynccontextmanager
async def server(*, app: ASGIApplication, port: int, http_protocol_cls: type[H11Protocol | HttpToolsProtocol]):
config = Config(app=app, port=port, loop="asyncio", http=http_protocol_cls)
server = Server(config=config)
task = asyncio.create_task(server.serve())

while not server.started:
await asyncio.sleep(0.01)

reader, writer = await asyncio.open_connection("127.0.0.1", port)

async def extract_json_body(request: bytes):
writer.write(request)
await writer.drain()

status, *headers = (await reader.readuntil(b"\r\n\r\n")).split(b"\r\n")[:-2]
assert status == b"HTTP/1.1 200 OK"

content_length = next(int(h.split(b":", 1)[1]) for h in headers if h.lower().startswith(b"content-length:"))
return json.loads(await reader.readexactly(content_length))

try:
yield extract_json_body
finally:
writer.close()
await writer.wait_closed()
server.should_exit = True
await task


async def test_no_contextvars_pollution_asyncio(
http_protocol_cls: type[H11Protocol | HttpToolsProtocol], unused_tcp_port: int
):
"""Non-regression test for https://github.com/encode/uvicorn/issues/2167."""
default_contextvars = {c.name for c in contextvars.copy_context().keys()}
ctx: contextvars.ContextVar[str] = contextvars.ContextVar("ctx")

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"

# initial context should be empty
initial_context = {
n: v for c, v in contextvars.copy_context().items() if (n := c.name) not in default_contextvars
}
# set any contextvar before the body is read
ctx.set(scope["path"])

while True:
message = await receive()
assert message["type"] == "http.request"
if not message["more_body"]:
break

# return the initial context for empty assertion
body = json.dumps(initial_context).encode("utf-8")
headers = [(b"content-type", b"application/json"), (b"content-length", str(len(body)).encode("utf-8"))]
await send({"type": "http.response.start", "status": 200, "headers": headers})
await send({"type": "http.response.body", "body": body})

# body has to be larger than HIGH_WATER_LIMIT to trigger a reading pause on the main thread
# and a resumption inside the ASGI task
large_body = b"a" * (HIGH_WATER_LIMIT + 1)
large_request = b"\r\n".join(
[
b"POST /large-body HTTP/1.1",
b"Host: example.org",
b"Content-Type: application/octet-stream",
f"Content-Length: {len(large_body)}".encode(),
b"",
large_body,
]
)

async with server(app=app, http_protocol_cls=http_protocol_cls, port=unused_tcp_port) as extract_json_body:
assert await extract_json_body(large_request) == {}
assert await extract_json_body(SIMPLE_GET_REQUEST) == {}
8 changes: 7 additions & 1 deletion uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextvars
import http
import logging
from typing import Any, Callable, Literal, cast
Expand Down Expand Up @@ -247,7 +248,12 @@ def handle_events(self) -> None:
message_event=asyncio.Event(),
on_response=self.on_response_complete,
)
task = self.loop.create_task(self.cycle.run_asgi(app))
# For the asyncio loop, we need to explicitly start with an empty context
# as it can be polluted from previous ASGI runs.
# See https://github.com/python/cpython/issues/140947 for details.
task = contextvars.Context().run(self.loop.create_task, self.cycle.run_asgi(app))
# TODO: Replace the line above with the line below for Python >= 3.11
# task = self.loop.create_task(self.cycle.run_asgi(app), context=contextvars.Context())
Comment on lines +251 to +256
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the lifespan is a sibling task, but that's not the case in Hypercorn. It may be the case in the future that we refactor the server to make the lifespan task a parent of the whole process instead of a sibling task. Which means that the context would need to come from there. We still don't want the context to leak between sibling tasks as it's currently happening in asyncio...

Since this is going to be fixed in Python 3.15, can we have a note about it? I would like to revert the context= when it lands, or at least remember that we can do it if we decide to make the lifespan task a parent from this one.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm aware of python/cpython#141158, but since this is not merged or ready to be merged, do we already want to assume that 3.15 will fix this?

task.add_done_callback(self.tasks.discard)
self.tasks.add(task)

Expand Down
8 changes: 7 additions & 1 deletion uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextvars
import http
import logging
import re
Expand Down Expand Up @@ -287,7 +288,12 @@ def on_headers_complete(self) -> None:
)
if existing_cycle is None or existing_cycle.response_complete:
# Standard case - start processing the request.
task = self.loop.create_task(self.cycle.run_asgi(app))
# For the asyncio loop, we need to explicitly start with an empty context
# as it can be polluted from previous ASGI runs.
# See https://github.com/python/cpython/issues/140947 for details.
task = contextvars.Context().run(self.loop.create_task, self.cycle.run_asgi(app))
# TODO: Replace the line above with the line below for Python >= 3.11
# task = self.loop.create_task(self.cycle.run_asgi(app), context=contextvars.Context())
task.add_done_callback(self.tasks.discard)
self.tasks.add(task)
else:
Expand Down