diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 2ac6f7f7f..619f3e2e0 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -6,8 +6,9 @@ from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette._utils import collapse_excgroups +from starlette.background import BackgroundTask from starlette.requests import ClientDisconnect, Request -from starlette.responses import AsyncContentStream, Response +from starlette.responses import AsyncContentStream, Response, StreamingResponse from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] @@ -192,33 +193,21 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) - raise NotImplementedError() # pragma: no cover -class _StreamingResponse(Response): +class _StreamingResponse(StreamingResponse): def __init__( self, content: AsyncContentStream, status_code: int = 200, headers: typing.Mapping[str, str] | None = None, media_type: str | None = None, + background: BackgroundTask | None = None, info: typing.Mapping[str, typing.Any] | None = None, ) -> None: - self.info = info - self.body_iterator = content - self.status_code = status_code - self.media_type = media_type - self.init_headers(headers) - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if self.info is not None: - await send({"type": "http.response.debug", "info": self.info}) - await send( - { - "type": "http.response.start", - "status": self.status_code, - "headers": self.raw_headers, - } - ) - - async for chunk in self.body_iterator: - await send({"type": "http.response.body", "body": chunk, "more_body": True}) - - await send({"type": "http.response.body", "body": b"", "more_body": False}) + self._info = info + # Disabling early disconnect to allow stacked middleware gracefull termination + super().__init__(content, status_code, headers, media_type, background, early_disconnect=False) + + async def stream_response(self, send: Send) -> None: + if self._info: + await send({"type": "http.response.debug", "info": self._info}) + return await super().stream_response(send) diff --git a/starlette/responses.py b/starlette/responses.py index 06d6ce5ca..654260150 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -215,6 +215,7 @@ def __init__( headers: typing.Mapping[str, str] | None = None, media_type: str | None = None, background: BackgroundTask | None = None, + early_disconnect: bool = True, ) -> None: if isinstance(content, typing.AsyncIterable): self.body_iterator = content @@ -223,6 +224,7 @@ def __init__( self.status_code = status_code self.media_type = self.media_type if media_type is None else media_type self.background = background + self.early_disconnect = early_disconnect self.init_headers(headers) async def listen_for_disconnect(self, receive: Receive) -> None: @@ -247,14 +249,17 @@ async def stream_response(self, send: Send) -> None: await send({"type": "http.response.body", "body": b"", "more_body": False}) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - async with anyio.create_task_group() as task_group: + if self.early_disconnect: + async with anyio.create_task_group() as task_group: - async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: - await func() - task_group.cancel_scope.cancel() + async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None: + await func() + task_group.cancel_scope.cancel() - task_group.start_soon(wrap, partial(self.stream_response, send)) - await wrap(partial(self.listen_for_disconnect, receive)) + task_group.start_soon(wrap, partial(self.stream_response, send)) + await wrap(partial(self.listen_for_disconnect, receive)) + else: + await self.stream_response(send) if self.background is not None: await self.background() diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 225038650..f38693be9 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1006,16 +1006,23 @@ async def endpoint(request: Request) -> Response: @pytest.mark.anyio async def test_multiple_middlewares_stacked_client_disconnected() -> None: + ordered_events: list[str] = [] + unordered_events: list[str] = [] + class MyMiddleware(BaseHTTPMiddleware): - def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None: + def __init__(self, app: ASGIApp, version: int) -> None: self.version = version - self.events = events super().__init__(app) async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: - self.events.append(f"{self.version}:STARTED") + ordered_events.append(f"{self.version}:STARTED") res = await call_next(request) - self.events.append(f"{self.version}:COMPLETED") + ordered_events.append(f"{self.version}:COMPLETED") + + def background() -> None: + unordered_events.append(f"{self.version}:BACKGROUND") + + res.background = BackgroundTask(background) return res async def sleepy(request: Request) -> Response: @@ -1027,11 +1034,9 @@ async def sleepy(request: Request) -> Response: raise AssertionError("Should have raised ClientDisconnect") return Response(b"") - events: list[str] = [] - app = Starlette( routes=[Route("/", sleepy)], - middleware=[Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)], + middleware=[Middleware(MyMiddleware, version=_ + 1) for _ in range(10)], ) scope = { @@ -1051,7 +1056,7 @@ async def send(message: Message) -> None: await app(scope, receive().__anext__, send) - assert events == [ + assert ordered_events == [ "1:STARTED", "2:STARTED", "3:STARTED", @@ -1074,6 +1079,21 @@ async def send(message: Message) -> None: "1:COMPLETED", ] + assert sorted(unordered_events) == sorted( + [ + "1:BACKGROUND", + "2:BACKGROUND", + "3:BACKGROUND", + "4:BACKGROUND", + "5:BACKGROUND", + "6:BACKGROUND", + "7:BACKGROUND", + "8:BACKGROUND", + "9:BACKGROUND", + "10:BACKGROUND", + ] + ) + assert sent == [ { "type": "http.response.start",