Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 12 additions & 23 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import collapse_excgroups
from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import AsyncContentStream, Response
from starlette.responses import AsyncContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
Expand Down Expand Up @@ -192,33 +193,21 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -
raise NotImplementedError() # pragma: no cover


class _StreamingResponse(Response):
class _StreamingResponse(StreamingResponse):
def __init__(
self,
content: AsyncContentStream,
status_code: int = 200,
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
info: typing.Mapping[str, typing.Any] | None = None,
) -> None:
self.info = info
self.body_iterator = content
self.status_code = status_code
self.media_type = media_type
self.init_headers(headers)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.info is not None:
await send({"type": "http.response.debug", "info": self.info})
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)

async for chunk in self.body_iterator:
await send({"type": "http.response.body", "body": chunk, "more_body": True})

await send({"type": "http.response.body", "body": b"", "more_body": False})
self._info = info
# Disabling early disconnect to allow stacked middleware gracefull termination
super().__init__(content, status_code, headers, media_type, background, early_disconnect=False)

async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
17 changes: 11 additions & 6 deletions starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def __init__(
headers: typing.Mapping[str, str] | None = None,
media_type: str | None = None,
background: BackgroundTask | None = None,
early_disconnect: bool = True,
) -> None:
if isinstance(content, typing.AsyncIterable):
self.body_iterator = content
Expand All @@ -223,6 +224,7 @@ def __init__(
self.status_code = status_code
self.media_type = self.media_type if media_type is None else media_type
self.background = background
self.early_disconnect = early_disconnect
self.init_headers(headers)

async def listen_for_disconnect(self, receive: Receive) -> None:
Expand All @@ -247,14 +249,17 @@ async def stream_response(self, send: Send) -> None:
await send({"type": "http.response.body", "body": b"", "more_body": False})

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async with anyio.create_task_group() as task_group:
if self.early_disconnect:
async with anyio.create_task_group() as task_group:

async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()
async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
await func()
task_group.cancel_scope.cancel()

task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
task_group.start_soon(wrap, partial(self.stream_response, send))
await wrap(partial(self.listen_for_disconnect, receive))
else:
await self.stream_response(send)

if self.background is not None:
await self.background()
Expand Down
36 changes: 28 additions & 8 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,16 +1006,23 @@ async def endpoint(request: Request) -> Response:

@pytest.mark.anyio
async def test_multiple_middlewares_stacked_client_disconnected() -> None:
ordered_events: list[str] = []
unordered_events: list[str] = []

class MyMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, version: int, events: list[str]) -> None:
def __init__(self, app: ASGIApp, version: int) -> None:
self.version = version
self.events = events
super().__init__(app)

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
self.events.append(f"{self.version}:STARTED")
ordered_events.append(f"{self.version}:STARTED")
res = await call_next(request)
self.events.append(f"{self.version}:COMPLETED")
ordered_events.append(f"{self.version}:COMPLETED")

def background() -> None:
unordered_events.append(f"{self.version}:BACKGROUND")

res.background = BackgroundTask(background)
return res

async def sleepy(request: Request) -> Response:
Expand All @@ -1027,11 +1034,9 @@ async def sleepy(request: Request) -> Response:
raise AssertionError("Should have raised ClientDisconnect")
return Response(b"")

events: list[str] = []

app = Starlette(
routes=[Route("/", sleepy)],
middleware=[Middleware(MyMiddleware, version=_ + 1, events=events) for _ in range(10)],
middleware=[Middleware(MyMiddleware, version=_ + 1) for _ in range(10)],
)

scope = {
Expand All @@ -1051,7 +1056,7 @@ async def send(message: Message) -> None:

await app(scope, receive().__anext__, send)

assert events == [
assert ordered_events == [
"1:STARTED",
"2:STARTED",
"3:STARTED",
Expand All @@ -1074,6 +1079,21 @@ async def send(message: Message) -> None:
"1:COMPLETED",
]

assert sorted(unordered_events) == sorted(
[
"1:BACKGROUND",
"2:BACKGROUND",
"3:BACKGROUND",
"4:BACKGROUND",
"5:BACKGROUND",
"6:BACKGROUND",
"7:BACKGROUND",
"8:BACKGROUND",
"9:BACKGROUND",
"10:BACKGROUND",
]
)

assert sent == [
{
"type": "http.response.start",
Expand Down