Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import contextvars

import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send


class CustomMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -163,3 +166,58 @@ def test_exception_on_mounted_apps(test_client_factory):
with pytest.raises(Exception) as ctx:
client.get("/sub/")
assert str(ctx.value) == "Exc"


ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")


class CustomMiddlewareWithoutBaseHTTPMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
ctxvar.set("set by middleware")
await self.app(scope, receive, send)
assert ctxvar.get() == "set by endpoint"


class CustomMiddlewareUsingBaseHTTPMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
ctxvar.set("set by middleware")
resp = await call_next(request)
assert ctxvar.get() == "set by endpoint"
return resp # pragma: no cover


@pytest.mark.parametrize(
"middleware_cls",
[
CustomMiddlewareWithoutBaseHTTPMiddleware,
pytest.param(
CustomMiddlewareUsingBaseHTTPMiddleware,
marks=pytest.mark.xfail(
reason=(
"BaseHTTPMiddleware creates a TaskGroup which copies the context"
"and erases any changes to it made within the TaskGroup"
),
raises=AssertionError,
),
),
],
)
def test_contextvars(test_client_factory, middleware_cls: type):
# this has to be an async endpoint because Starlette calls run_in_threadpool
# on sync endpoints which has it's own set of peculiarities w.r.t propagating
# contextvars (it propagates them forwards but not backwards)
async def homepage(request):
assert ctxvar.get() == "set by middleware"
ctxvar.set("set by endpoint")
return PlainTextResponse("Homepage")

app = Starlette(
middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
)

client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content