Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/benchmarks/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __init__(self) -> None:
self._tasks: list[asyncio.Task[Any]] = []
self._later: list[MockTimerHandle] = []

def create_task(self, coroutine: Any, **kwargs: Any) -> Any:
def create_task(self, coroutine: Any) -> Any:
self._tasks.insert(0, coroutine)
return MockTask()

Expand Down
2 changes: 1 addition & 1 deletion tests/protocols/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def __init__(self):
self._tasks: list[asyncio.Task[Any]] = []
self._later: list[MockTimerHandle] = []

def create_task(self, coroutine: Any, **kwargs: Any) -> Any:
def create_task(self, coroutine: Any) -> Any:
self._tasks.insert(0, coroutine)
return MockTask()

Expand Down
86 changes: 2 additions & 84 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import asyncio
import contextlib
import contextvars
import json
import logging
import signal
import sys
Expand All @@ -13,14 +11,12 @@
import httpx
import pytest

from tests.protocols.test_http import SIMPLE_GET_REQUEST
from tests.utils import run_server
from uvicorn import Server
from uvicorn._types import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Scope
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
from uvicorn.config import Config
from uvicorn.protocols.http.flow_control import HIGH_WATER_LIMIT
from uvicorn.protocols.http.h11_impl import H11Protocol
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
from uvicorn.server import Server

pytestmark = pytest.mark.anyio

Expand Down Expand Up @@ -154,81 +150,3 @@ async def test_limit_max_requests_jitter(
await client.get(f"http://127.0.0.1:{unused_tcp_port}")
await task
assert f"Maximum request limit of {limit} exceeded. Terminating process." in caplog.text


@contextlib.asynccontextmanager
async def server(*, app: ASGIApplication, port: int, http_protocol_cls: type[H11Protocol | HttpToolsProtocol]):
config = Config(app=app, port=port, loop="asyncio", http=http_protocol_cls)
server = Server(config=config)
task = asyncio.create_task(server.serve())

while not server.started:
await asyncio.sleep(0.01)

reader, writer = await asyncio.open_connection("127.0.0.1", port)

async def extract_json_body(request: bytes):
writer.write(request)
await writer.drain()

status, *headers = (await reader.readuntil(b"\r\n\r\n")).split(b"\r\n")[:-2]
assert status == b"HTTP/1.1 200 OK"

content_length = next(int(h.split(b":", 1)[1]) for h in headers if h.lower().startswith(b"content-length:"))
return json.loads(await reader.readexactly(content_length))

try:
yield extract_json_body
finally:
writer.close()
await writer.wait_closed()
server.should_exit = True
await task


async def test_no_contextvars_pollution_asyncio(
http_protocol_cls: type[H11Protocol | HttpToolsProtocol], unused_tcp_port: int
):
"""Non-regression test for https://github.com/encode/uvicorn/issues/2167."""
default_contextvars = {c.name for c in contextvars.copy_context().keys()}
ctx: contextvars.ContextVar[str] = contextvars.ContextVar("ctx")

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"

# initial context should be empty
initial_context = {
n: v for c, v in contextvars.copy_context().items() if (n := c.name) not in default_contextvars
}
# set any contextvar before the body is read
ctx.set(scope["path"])

while True:
message = await receive()
assert message["type"] == "http.request"
if not message["more_body"]:
break

# return the initial context for empty assertion
body = json.dumps(initial_context).encode("utf-8")
headers = [(b"content-type", b"application/json"), (b"content-length", str(len(body)).encode("utf-8"))]
await send({"type": "http.response.start", "status": 200, "headers": headers})
await send({"type": "http.response.body", "body": body})

# body has to be larger than HIGH_WATER_LIMIT to trigger a reading pause on the main thread
# and a resumption inside the ASGI task
large_body = b"a" * (HIGH_WATER_LIMIT + 1)
large_request = b"\r\n".join(
[
b"POST /large-body HTTP/1.1",
b"Host: example.org",
b"Content-Type: application/octet-stream",
f"Content-Length: {len(large_body)}".encode(),
b"",
large_body,
]
)

async with server(app=app, http_protocol_cls=http_protocol_cls, port=unused_tcp_port) as extract_json_body:
assert await extract_json_body(large_request) == {}
assert await extract_json_body(SIMPLE_GET_REQUEST) == {}
10 changes: 1 addition & 9 deletions uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from __future__ import annotations

import asyncio
import contextvars
import http
import logging
import sys
from collections.abc import Callable
from typing import Any, Literal
from urllib.parse import unquote
Expand Down Expand Up @@ -250,13 +248,7 @@ def handle_events(self) -> None:
message_event=asyncio.Event(),
on_response=self.on_response_complete,
)
# For the asyncio loop, we need to explicitly start with an empty context
# as it can be polluted from previous ASGI runs.
# See https://github.com/python/cpython/issues/140947 for details.
if sys.version_info >= (3, 11): # pragma: py-lt-311
task = self.loop.create_task(self.cycle.run_asgi(app), context=contextvars.Context())
else: # pragma: py-gte-311
task = contextvars.Context().run(self.loop.create_task, self.cycle.run_asgi(app))
task = self.loop.create_task(self.cycle.run_asgi(app))
task.add_done_callback(self.tasks.discard)
self.tasks.add(task)

Expand Down
10 changes: 1 addition & 9 deletions uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from __future__ import annotations

import asyncio
import contextvars
import http
import logging
import re
import sys
import urllib
from asyncio.events import TimerHandle
from collections import deque
Expand Down Expand Up @@ -289,13 +287,7 @@ def on_headers_complete(self) -> None:
)
if existing_cycle is None or existing_cycle.response_complete:
# Standard case - start processing the request.
# For the asyncio loop, we need to explicitly start with an empty context
# as it can be polluted from previous ASGI runs.
# See https://github.com/python/cpython/issues/140947 for details.
if sys.version_info >= (3, 11): # pragma: py-lt-311
task = self.loop.create_task(self.cycle.run_asgi(app), context=contextvars.Context())
else: # pragma: py-gte-311
task = contextvars.Context().run(self.loop.create_task, self.cycle.run_asgi(app))
task = self.loop.create_task(self.cycle.run_asgi(app))
task.add_done_callback(self.tasks.discard)
self.tasks.add(task)
else:
Expand Down
Loading