diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index bfb4a54a4..c1cdcc8ba 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -4,7 +4,7 @@ from starlette.requests import Request from starlette.responses import Response, StreamingResponse -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ @@ -29,9 +29,18 @@ async def call_next(request: Request) -> Response: async def coro() -> None: nonlocal app_exc + app_receive = request.receive + + if receive_called: + # middleware consumed the request body + async def error_receive() -> Message: + raise RuntimeError("Receive stream already consumed") + + app_receive = error_receive + async with send_stream: try: - await self.app(scope, request.receive, send_stream.send) + await self.app(scope, app_receive, send_stream.send) except Exception as exc: app_exc = exc @@ -61,8 +70,15 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: response.raw_headers = message["headers"] return response + receive_called = False + + async def tracked_receive() -> Message: + nonlocal receive_called + receive_called = True + return await receive() + async with anyio.create_task_group() as task_group: - request = Request(scope, receive=receive) + request = Request(scope, receive=tracked_receive) response = await self.dispatch_func(request, call_next) await response(scope, receive, send) task_group.cancel_scope.cancel() diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 04da3a961..5b1d42667 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -3,7 +3,8 @@ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import PlainTextResponse, StreamingResponse +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response, StreamingResponse from starlette.routing import Mount, Route, WebSocketRoute @@ -163,3 +164,23 @@ def test_exception_on_mounted_apps(test_client_factory): with pytest.raises(Exception) as ctx: client.get("/sub/") assert str(ctx.value) == "Exc" + + +def test_stream_consumed_in_middleware(test_client_factory) -> None: + class CustomMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + await request.body() + response = await call_next(request) + return response # pragma: no cover + + async def endpoint(request: Request) -> Response: + await request.body() + return Response() # pragma: no cover + + app = Starlette( + middleware=[Middleware(CustomMiddleware)], routes=[Route("/", endpoint)] + ) + + client = test_client_factory(app) + with pytest.raises(RuntimeError, match="Receive stream already consumed"): + client.get("/")