diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 49a5e3e2d..2f9292c5b 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -24,9 +24,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return + call_next_response = None + send_stream, recv_stream = anyio.create_memory_object_stream() + async def call_next(request: Request) -> Response: app_exc: typing.Optional[Exception] = None - send_stream, recv_stream = anyio.create_memory_object_stream() async def coro() -> None: nonlocal app_exc @@ -61,17 +63,22 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: if app_exc is not None: raise app_exc - response = StreamingResponse( + nonlocal call_next_response + + call_next_response = StreamingResponse( status_code=message["status"], content=body_stream() ) - response.raw_headers = message["headers"] - return response + call_next_response.raw_headers = message["headers"] + return call_next_response async with anyio.create_task_group() as task_group: request = Request(scope, receive=receive) response = await self.dispatch_func(request, call_next) + if call_next_response and response is not call_next_response: + async with recv_stream: + async for _ in recv_stream: + ... # pragma: no cover await response(scope, receive, send) - task_group.cancel_scope.cancel() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint