diff --git a/starlette/_compat.py b/starlette/_compat.py index 116561917..760431745 100644 --- a/starlette/_compat.py +++ b/starlette/_compat.py @@ -1,4 +1,11 @@ import hashlib +import sys +from typing import Any, AsyncContextManager + +__all__ = [ + "md5_hexdigest", + "aclosing", +] # Compat wrapper to always include the `usedforsecurity=...` parameter, # which is only added from Python 3.9 onwards. @@ -27,3 +34,18 @@ def md5_hexdigest( def md5_hexdigest(data: bytes, *, usedforsecurity: bool = True) -> str: return hashlib.md5(data).hexdigest() + + +if sys.version_info >= (3, 10): # pragma: no cover + from contextlib import aclosing +else: # pragma: no cover + + class aclosing(AsyncContextManager): + def __init__(self, thing: Any) -> None: + self.thing = thing + + async def __aenter__(self) -> Any: + return self.thing + + async def __aexit__(self, *exc_info: Any) -> None: + await self.thing.aclose() diff --git a/starlette/middleware/http.py b/starlette/middleware/http.py new file mode 100644 index 000000000..4341458cc --- /dev/null +++ b/starlette/middleware/http.py @@ -0,0 +1,170 @@ +from typing import Any, AsyncGenerator, Callable, Optional, Union + +from .._compat import aclosing +from ..datastructures import Headers +from ..requests import HTTPConnection +from ..responses import Response +from ..types import ASGIApp, Message, Receive, Scope, Send + +# This type hint not exposed, as it exists mostly for our own documentation purposes. +# End users should use one of these type hints explicitly when overriding '.dispatch()'. +_DispatchFlow = Union[ + # Default case: + # response = yield + AsyncGenerator[None, Response], + # Early response and/or error handling: + # if condition: + # yield Response(...) + # return + # try: + # response = yield None + # except Exception: + # yield Response(...) + # else: + # ... + AsyncGenerator[Optional[Response], Response], +] + + +class HTTPMiddleware: + def __init__( + self, + app: ASGIApp, + dispatch: Optional[Callable[[HTTPConnection], _DispatchFlow]] = None, + ) -> None: + self.app = app + self._dispatch_func = self.dispatch if dispatch is None else dispatch + + def dispatch(self, __conn: HTTPConnection) -> _DispatchFlow: + raise NotImplementedError( + "No dispatch implementation was given. " + "Either pass 'dispatch=...' to HTTPMiddleware, " + "or subclass HTTPMiddleware and override the 'dispatch()' method." + ) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + conn = HTTPConnection(scope) + + async with aclosing(self._dispatch_func(conn)) as flow: + # Kick the flow until the first `yield`. + # Might respond early before we call into the app. + maybe_early_response = await flow.__anext__() + + if maybe_early_response is not None: + try: + await flow.__anext__() + except StopAsyncIteration: + pass + else: + raise RuntimeError("dispatch() should yield exactly once") + + await maybe_early_response(scope, receive, send) + return + + response_started = False + + async def wrapped_send(message: Message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + + headers = Headers(raw=message["headers"]) + response = _StubResponse( + status_code=message["status"], + media_type=headers.get("content-type"), + ) + response.raw_headers = headers.raw + + try: + await flow.asend(response) + except StopAsyncIteration: + pass + else: + raise RuntimeError("dispatch() should yield exactly once") + + message["headers"] = response.raw_headers + + await send(message) + + try: + await self.app(scope, receive, wrapped_send) + except Exception as exc: + if response_started: + raise + + try: + response = await flow.athrow(exc) + except StopAsyncIteration: + response = None + except Exception: + # Exception was not handled, or they raised another one. + raise + + if response is None: + raise RuntimeError( + f"dispatch() handled exception {exc!r}, " + "but no response was returned" + ) + + await response(scope, receive, send) + return + + +# This customized stub response helps prevent users from shooting themselves +# in the foot, doing things that don't actually have any effect. + + +class _StubResponse(Response): + def __init__(self, status_code: int, media_type: Optional[str] = None) -> None: + self._status_code = status_code + self._media_type = media_type + self.raw_headers = [] + + @property # type: ignore + def status_code(self) -> int: # type: ignore + return self._status_code + + @status_code.setter + def status_code(self, value: Any) -> None: + raise RuntimeError( + "Setting .status_code in HTTPMiddleware is not supported. " + "If you're writing middleware that requires modifying the response " + "status code or sending another response altogether, please consider " + "writing pure ASGI middleware. " + "See: https://starlette.io/middleware/#pure-asgi-middleware" + ) + + @property # type: ignore + def media_type(self) -> Optional[str]: # type: ignore + return self._media_type + + @media_type.setter + def media_type(self, value: Any) -> None: + raise RuntimeError( + "Setting .media_type in HTTPMiddleware is not supported, as it has " + "no effect. If you do need to tweak the response " + "content type, consider: response.headers['Content-Type'] = ..." + ) + + @property # type: ignore + def body(self) -> bytes: # type: ignore + raise RuntimeError( + "Accessing the response body in HTTPMiddleware is not supported. " + "If you're writing middleware that requires peeking into the response " + "body, please consider writing pure ASGI middleware and wrapping send(). " + "See: https://starlette.io/middleware/#pure-asgi-middleware" + ) + + @body.setter + def body(self, body: bytes) -> None: + raise RuntimeError( + "Setting the response body in HTTPMiddleware is not supported." + "If you're writing middleware that requires modifying the response " + "body, please consider writing pure ASGI middleware and wrapping send(). " + "See: https://starlette.io/middleware/#pure-asgi-middleware" + ) diff --git a/tests/middleware/test_http.py b/tests/middleware/test_http.py new file mode 100644 index 000000000..081625030 --- /dev/null +++ b/tests/middleware/test_http.py @@ -0,0 +1,288 @@ +from typing import AsyncGenerator, Callable, Iterator, Optional + +import pytest + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.http import HTTPMiddleware +from starlette.requests import HTTPConnection, Request +from starlette.responses import PlainTextResponse, Response, StreamingResponse +from starlette.routing import Route, WebSocketRoute +from starlette.testclient import TestClient +from starlette.types import ASGIApp +from starlette.websockets import WebSocket + + +async def homepage(request: Request) -> Response: + return PlainTextResponse("Homepage") + + +async def exc(request: Request) -> Response: + raise Exception("Exc") + + +async def exc_stream(request: Request) -> Response: + return StreamingResponse(_generate_faulty_stream()) + + +def _generate_faulty_stream() -> Iterator[bytes]: + yield b"Ok" + raise Exception("Faulty Stream") + + +async def websocket_endpoint(session: WebSocket) -> None: + await session.accept() + await session.send_text("Hello, world!") + await session.close() + + +class CustomMiddleware(HTTPMiddleware): + async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]: + response = yield + response.headers["Custom-Header"] = "Example" + + +app = Starlette( + routes=[ + Route("/", endpoint=homepage), + Route("/exc", endpoint=exc), + Route("/exc-stream", endpoint=exc_stream), + WebSocketRoute("/ws", endpoint=websocket_endpoint), + ], + middleware=[Middleware(CustomMiddleware)], +) + + +def test_custom_middleware( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + client = test_client_factory(app) + response = client.get("/") + assert response.headers["Custom-Header"] == "Example" + + with pytest.raises(Exception) as ctx: + response = client.get("/exc") + assert str(ctx.value) == "Exc" + + with pytest.raises(Exception) as ctx: + response = client.get("/exc-stream") + assert str(ctx.value) == "Faulty Stream" + + with client.websocket_connect("/ws") as session: + text = session.receive_text() + assert text == "Hello, world!" + + +def test_state_data_across_multiple_middlewares( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + async def homepage(request: Request) -> Response: + return PlainTextResponse("OK") + + expected_value1 = "foo" + expected_value2 = "bar" + + async def middleware_a(conn: HTTPConnection) -> AsyncGenerator[None, Response]: + conn.state.foo = expected_value1 + yield + + async def middleware_b(conn: HTTPConnection) -> AsyncGenerator[None, Response]: + conn.state.bar = expected_value2 + response = yield + response.headers["X-State-Foo"] = conn.state.foo + + async def middleware_c(conn: HTTPConnection) -> AsyncGenerator[None, Response]: + response = yield + response.headers["X-State-Bar"] = conn.state.bar + + app = Starlette( + routes=[Route("/", homepage)], + middleware=[ + Middleware(HTTPMiddleware, dispatch=middleware_a), + Middleware(HTTPMiddleware, dispatch=middleware_b), + Middleware(HTTPMiddleware, dispatch=middleware_c), + ], + ) + + client = test_client_factory(app) + response = client.get("/") + assert response.text == "OK" + assert response.headers["X-State-Foo"] == expected_value1 + assert response.headers["X-State-Bar"] == expected_value2 + + +def test_too_many_yields(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: + class CustomMiddleware(HTTPMiddleware): + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[None, Response]: + _ = yield + yield + + app = Starlette(middleware=[Middleware(CustomMiddleware)]) + + client = test_client_factory(app) + with pytest.raises(RuntimeError, match="should yield exactly once"): + client.get("/") + + +def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: + async def homepage(request: Request) -> Response: + return PlainTextResponse("OK") + + class CustomMiddleware(HTTPMiddleware): + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[Optional[Response], Response]: + if conn.headers.get("X-Early") == "true": + yield Response(status_code=401) + else: + yield None + + app = Starlette( + routes=[Route("/", homepage)], + middleware=[Middleware(CustomMiddleware)], + ) + + client = test_client_factory(app) + response = client.get("/") + assert response.status_code == 200 + assert response.text == "OK" + response = client.get("/", headers={"X-Early": "true"}) + assert response.status_code == 401 + + +def test_early_response_too_many_yields( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + class CustomMiddleware(HTTPMiddleware): + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[Optional[Response], Response]: + yield Response() + yield None + + app = Starlette(middleware=[Middleware(CustomMiddleware)]) + + client = test_client_factory(app) + with pytest.raises(RuntimeError, match="should yield exactly once"): + client.get("/") + + +def test_error_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: + class Failed(Exception): + pass + + async def failure(request: Request) -> Response: + raise Failed() + + class CustomMiddleware(HTTPMiddleware): + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[Optional[Response], Response]: + try: + yield None + except Failed: + yield Response("Failed", status_code=500) + + app = Starlette( + routes=[Route("/fail", failure)], + middleware=[Middleware(CustomMiddleware)], + ) + + client = test_client_factory(app) + response = client.get("/fail") + assert response.text == "Failed" + assert response.status_code == 500 + + +def test_error_handling_must_send_response( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + class Failed(Exception): + pass + + async def failure(request: Request) -> Response: + raise Failed() + + class CustomMiddleware(HTTPMiddleware): + async def dispatch( + self, conn: HTTPConnection + ) -> AsyncGenerator[None, Response]: + try: + yield + except Failed: + pass # `yield ` expected + + app = Starlette( + routes=[Route("/fail", failure)], + middleware=[Middleware(CustomMiddleware)], + ) + + client = test_client_factory(app) + with pytest.raises(RuntimeError, match="no response was returned"): + client.get("/fail") + + +def test_no_dispatch_given( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + app = Starlette(middleware=[Middleware(HTTPMiddleware)]) + + client = test_client_factory(app) + with pytest.raises(NotImplementedError, match="No dispatch implementation"): + client.get("/") + + +def test_response_stub_attributes( + test_client_factory: Callable[[ASGIApp], TestClient] +) -> None: + async def homepage(request: Request) -> Response: + return PlainTextResponse("OK") + + async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]: + response = yield + if conn.url.path == "/status_code": + assert response.status_code == 200 + response.status_code = 401 + if conn.url.path == "/media_type": + assert response.media_type == "text/plain; charset=utf-8" + response.media_type = "text/csv" + if conn.url.path == "/body-get": + response.body + if conn.url.path == "/body-set": + response.body = b"changed" + + app = Starlette( + routes=[ + Route("/status_code", homepage), + Route("/media_type", homepage), + Route("/body-get", homepage), + Route("/body-set", homepage), + ], + middleware=[Middleware(HTTPMiddleware, dispatch=dispatch)], + ) + + client = test_client_factory(app) + + with pytest.raises( + RuntimeError, match="Setting .status_code in HTTPMiddleware is not supported." + ): + client.get("/status_code") + + with pytest.raises( + RuntimeError, match="Setting .media_type in HTTPMiddleware is not supported" + ): + client.get("/media_type") + + with pytest.raises( + RuntimeError, + match="Accessing the response body in HTTPMiddleware is not supported", + ): + client.get("/body-get") + + with pytest.raises( + RuntimeError, + match="Setting the response body in HTTPMiddleware is not supported", + ): + client.get("/body-set")