diff --git a/starlette/middleware/http.py b/starlette/middleware/http.py index 16b7b35cb..95e2119de 100644 --- a/starlette/middleware/http.py +++ b/starlette/middleware/http.py @@ -36,7 +36,7 @@ def __init__( self.app = app self._dispatch_func = dispatch - def dispatch(self, conn: HTTPConnection) -> _DispatchFlow: + def dispatch(self, __conn: HTTPConnection) -> _DispatchFlow: raise NotImplementedError # pragma: no cover async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: @@ -52,6 +52,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: maybe_early_response = await flow.__anext__() if maybe_early_response is not None: + try: + await flow.__anext__() + except StopAsyncIteration: + pass + else: + raise RuntimeError("dispatch() should yield exactly once") + await maybe_early_response(scope, receive, send) return @@ -98,6 +105,3 @@ async def wrapped_send(message: Message) -> None: await response(scope, receive, send) return - - if not response_started: - raise RuntimeError("No response returned.") diff --git a/tests/middleware/test_http.py b/tests/middleware/test_http.py index 5f83d313f..ba58acea7 100644 --- a/tests/middleware/test_http.py +++ b/tests/middleware/test_http.py @@ -1,32 +1,27 @@ -import contextvars -from typing import AsyncGenerator, Optional +from typing import Any, AsyncGenerator, Callable, Optional import pytest from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.http import HTTPMiddleware -from starlette.requests import HTTPConnection +from starlette.requests import HTTPConnection, Request from starlette.responses import PlainTextResponse, Response, StreamingResponse from starlette.routing import Route, WebSocketRoute +from starlette.testclient import TestClient from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.websockets import WebSocket -class CustomMiddleware(HTTPMiddleware): - async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]: - response = yield - response.headers["Custom-Header"] = "Example" - - -def homepage(request): +def homepage(request: Request) -> Response: return PlainTextResponse("Homepage") -def exc(request): +def exc(request: Request) -> Response: raise Exception("Exc") -def exc_stream(request): +def exc_stream(request: Request) -> Response: return StreamingResponse(_generate_faulty_stream()) @@ -36,22 +31,28 @@ def _generate_faulty_stream(): class NoResponse: - def __init__(self, scope, receive, send): + def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: pass - def __await__(self): + def __await__(self) -> Any: return self.dispatch().__await__() - async def dispatch(self): + async def dispatch(self) -> None: pass -async def websocket_endpoint(session): +async def websocket_endpoint(session: WebSocket): await session.accept() await session.send_text("Hello, world!") await session.close() +class CustomMiddleware(HTTPMiddleware): + async def dispatch(self, request: HTTPConnection) -> AsyncGenerator[None, Response]: + response = yield + response.headers["Custom-Header"] = "Example" + + app = Starlette( routes=[ Route("/", endpoint=homepage), @@ -64,7 +65,9 @@ async def websocket_endpoint(session): ) -def test_custom_middleware(test_client_factory): +def test_custom_middleware( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: client = test_client_factory(app) response = client.get("/") assert response.headers["Custom-Header"] == "Example" @@ -77,7 +80,7 @@ def test_custom_middleware(test_client_factory): response = client.get("/exc-stream") assert str(ctx.value) == "Faulty Stream" - with pytest.raises(RuntimeError): + with pytest.raises(AssertionError): # from TestClient response = client.get("/no-response") with client.websocket_connect("/ws") as session: @@ -85,41 +88,34 @@ def test_custom_middleware(test_client_factory): assert text == "Hello, world!" -def test_state_data_across_multiple_middlewares(test_client_factory): +def test_state_data_across_multiple_middlewares( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: expected_value1 = "foo" expected_value2 = "bar" - class aMiddleware(HTTPMiddleware): - async def dispatch( - self, conn: HTTPConnection - ) -> AsyncGenerator[None, Response]: - conn.state.foo = expected_value1 - yield + async def middleware_a(request: HTTPConnection) -> AsyncGenerator[None, None]: + request.state.foo = expected_value1 + yield - class bMiddleware(HTTPMiddleware): - async def dispatch( - self, conn: HTTPConnection - ) -> AsyncGenerator[None, Response]: - conn.state.bar = expected_value2 - response = yield - response.headers["X-State-Foo"] = conn.state.foo + async def middleware_b(request: HTTPConnection) -> AsyncGenerator[None, Response]: + request.state.bar = expected_value2 + response = yield + response.headers["X-State-Foo"] = request.state.foo - class cMiddleware(HTTPMiddleware): - async def dispatch( - self, conn: HTTPConnection - ) -> AsyncGenerator[None, Response]: - response = yield - response.headers["X-State-Bar"] = conn.state.bar + async def middleware_c(request: HTTPConnection) -> AsyncGenerator[None, Response]: + response = yield + response.headers["X-State-Bar"] = request.state.bar - def homepage(request): + def homepage(request: Request) -> Response: return PlainTextResponse("OK") app = Starlette( routes=[Route("/", homepage)], middleware=[ - Middleware(aMiddleware), - Middleware(bMiddleware), - Middleware(cMiddleware), + Middleware(HTTPMiddleware, dispatch=middleware_a), + Middleware(HTTPMiddleware, dispatch=middleware_b), + Middleware(HTTPMiddleware, dispatch=middleware_c), ], ) @@ -130,8 +126,10 @@ def homepage(request): assert response.headers["X-State-Bar"] == expected_value2 -def test_dispatch_argument(test_client_factory): - def homepage(request): +def test_dispatch_argument( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + def homepage(request: Request): return PlainTextResponse("Homepage") async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]: @@ -148,13 +146,8 @@ async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]: assert response.headers["Custom-Header"] == "Example" -def test_middleware_repr(): - middleware = Middleware(CustomMiddleware) - assert repr(middleware) == "Middleware(CustomMiddleware)" - - -def test_early_response(test_client_factory): - async def index(request): +def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: + async def index(request: Request): return PlainTextResponse("Hello, world!") class CustomMiddleware(HTTPMiddleware): @@ -179,12 +172,12 @@ async def dispatch( assert response.status_code == 401 -def test_too_many_yields(test_client_factory) -> None: +def test_too_many_yields(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: class CustomMiddleware(HTTPMiddleware): async def dispatch( self, conn: HTTPConnection ) -> AsyncGenerator[None, Response]: - _ = yield + yield yield app = Starlette(middleware=[Middleware(CustomMiddleware)]) @@ -194,11 +187,28 @@ async def dispatch( client.get("/") -def test_error_response(test_client_factory) -> None: +def test_too_many_yields_early_response( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + class CustomMiddleware(HTTPMiddleware): + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[Optional[Response], Response]: + yield Response() + yield None + + app = Starlette(middleware=[Middleware(CustomMiddleware)]) + + client = test_client_factory(app) + with pytest.raises(RuntimeError, match="should yield exactly once"): + client.get("/") + + +def test_error_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: class Failed(Exception): pass - async def failure(request): + async def failure(request: Request): raise Failed() class CustomMiddleware(HTTPMiddleware): @@ -221,11 +231,13 @@ async def dispatch( assert response.status_code == 500 -def test_no_error_response(test_client_factory) -> None: +def test_no_error_response( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: class Failed(Exception): pass - async def index(request): + async def index(request: Request): raise Failed() class CustomMiddleware(HTTPMiddleware): @@ -247,43 +259,24 @@ async def dispatch( client.get("/") -ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") - - -class PureASGICustomMiddleware: - def __init__(self, app: ASGIApp) -> None: - self.app = app +def test_modify_content_type( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + async def dispatch(request: HTTPConnection) -> AsyncGenerator[None, Response]: + resp = yield + resp.media_type = "text/csv" - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - ctxvar.set("set by middleware") - await self.app(scope, receive, send) - assert ctxvar.get() == "set by endpoint" - - -class HTTPCustomMiddleware(HTTPMiddleware): - async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]: - ctxvar.set("set by middleware") - yield - assert ctxvar.get() == "set by endpoint" - - -@pytest.mark.parametrize( - "middleware_cls", - [ - PureASGICustomMiddleware, - HTTPCustomMiddleware, - ], -) -def test_contextvars(test_client_factory, middleware_cls: type): - async def homepage(request): - assert ctxvar.get() == "set by middleware" - ctxvar.set("set by endpoint") - return PlainTextResponse("Homepage") + def homepage(request: Request) -> Response: + return PlainTextResponse("OK") app = Starlette( - middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] + routes=[Route("/", homepage)], + middleware=[ + Middleware(HTTPMiddleware, dispatch=dispatch), + ], ) client = test_client_factory(app) response = client.get("/") - assert response.status_code == 200, response.content + assert response.text == "OK" + assert response.headers["Content-Type"] == "text/csv; charset=utf-8"