diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 4e5054d7a..9c7dd0cc8 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -15,6 +15,9 @@ DispatchFunction = typing.Callable[ [Request, RequestResponseEndpoint], typing.Awaitable[Response] ] +BodyStreamGenerator = typing.AsyncGenerator[ + typing.Union[bytes, typing.MutableMapping[str, typing.Any]], None +] T = typing.TypeVar("T") @@ -167,9 +170,13 @@ async def coro() -> None: assert message["type"] == "http.response.start" - async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async def body_stream() -> BodyStreamGenerator: async with recv_stream: async for message in recv_stream: + if message["type"] == "http.response.pathsend": + # with pathsend we don't need to stream anything + yield message + break assert message["type"] == "http.response.body" body = message.get("body", b"") if body: diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index cbb0f4a5b..131541dc7 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -92,7 +92,6 @@ async def send_with_gzip(self, message: Message) -> None: await self.send(self.initial_message) await self.send(message) - elif message_type == "http.response.body": # Remaining body in streaming GZip response. body = message.get("body", b"") @@ -107,6 +106,10 @@ async def send_with_gzip(self, message: Message) -> None: self.gzip_buffer.truncate() await self.send(message) + elif message_type == "http.response.pathsend": + # Don't apply GZip to pathsend responses + await self.send(self.initial_message) + await self.send(message) async def unattached_send(message: Message) -> typing.NoReturn: diff --git a/starlette/responses.py b/starlette/responses.py index a6975747b..dae0a9dd1 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -207,7 +207,7 @@ def __init__( self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;") -Content = typing.Union[str, bytes, memoryview] +Content = typing.Union[str, bytes, memoryview, typing.MutableMapping[str, typing.Any]] SyncContentStream = typing.Iterable[Content] AsyncContentStream = typing.AsyncIterable[Content] ContentStream = typing.Union[AsyncContentStream, SyncContentStream] @@ -247,12 +247,19 @@ async def stream_response(self, send: Send) -> None: "headers": self.raw_headers, } ) + should_close_body = True async for chunk in self.body_iterator: - if not isinstance(chunk, (bytes, memoryview)): + if isinstance(chunk, dict): + # We got an ASGI message which is not response body (eg: pathsend) + should_close_body = False + await send(chunk) + continue + if isinstance(chunk, str): chunk = chunk.encode(self.charset) await send({"type": "http.response.body", "body": chunk, "more_body": True}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) + if should_close_body: + await send({"type": "http.response.body", "body": b"", "more_body": False}) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with anyio.create_task_group() as task_group: diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 2176404d8..34fc27c79 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -2,6 +2,7 @@ import contextvars from contextlib import AsyncExitStack +from pathlib import Path from typing import ( Any, AsyncGenerator, @@ -18,7 +19,12 @@ from starlette.middleware import Middleware, _MiddlewareClass from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request -from starlette.responses import PlainTextResponse, Response, StreamingResponse +from starlette.responses import ( + FileResponse, + PlainTextResponse, + Response, + StreamingResponse, +) from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -1035,3 +1041,54 @@ async def endpoint(request: Request) -> Response: resp.raise_for_status() assert bodies == [b"Hello, World!-foo"] + + +@pytest.mark.anyio +async def test_asgi_pathsend_events(tmpdir: Path) -> None: + path = tmpdir / "example.txt" + with path.open("w") as file: + file.write("") + + request_body_sent = False + response_complete = anyio.Event() + events: list[Message] = [] + + async def endpoint_with_pathsend(_: Request) -> FileResponse: + return FileResponse(path) + + async def passthrough( + request: Request, call_next: RequestResponseEndpoint + ) -> Response: + return await call_next(request) + + app = Starlette( + middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], + routes=[Route("/", endpoint_with_pathsend)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + "extensions": {"http.response.pathsend": {}}, + } + + async def receive() -> Message: + nonlocal request_body_sent + if not request_body_sent: + request_body_sent = True + return {"type": "http.request", "body": b"", "more_body": False} + await response_complete.wait() + return {"type": "http.disconnect"} + + async def send(message: Message) -> None: + events.append(message) + if message["type"] == "http.response.pathsend": + response_complete.set() + + await app(scope, receive, send) + + assert len(events) == 2 + assert events[0]["type"] == "http.response.start" + assert events[1]["type"] == "http.response.pathsend" diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index 5bfecadb7..7c99a0ee6 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,13 +1,23 @@ +from __future__ import annotations + +from pathlib import Path from typing import Callable +import pytest + from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.requests import Request -from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse +from starlette.responses import ( + ContentStream, + FileResponse, + PlainTextResponse, + StreamingResponse, +) from starlette.routing import Route from starlette.testclient import TestClient -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Message TestClientFactory = Callable[[ASGIApp], TestClient] @@ -111,3 +121,42 @@ async def generator(bytes: bytes, count: int) -> ContentStream: assert response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "text" assert "Content-Length" not in response.headers + + +@pytest.mark.anyio +async def test_gzip_ignored_for_pathsend_responses(tmpdir: Path) -> None: + path = tmpdir / "example.txt" + with path.open("w") as file: + file.write("") + + events: list[Message] = [] + + async def endpoint_with_pathsend(request: Request) -> FileResponse: + _ = await request.body() + return FileResponse(path) + + app = Starlette( + routes=[Route("/", endpoint=endpoint_with_pathsend)], + middleware=[Middleware(GZipMiddleware)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + "headers": [(b"accept-encoding", b"gzip, text")], + "extensions": {"http.response.pathsend": {}}, + } + + async def receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message: Message) -> None: + events.append(message) + + await app(scope, receive, send) + + assert len(events) == 2 + assert events[0]["type"] == "http.response.start" + assert events[1]["type"] == "http.response.pathsend"