From 474536683f7dd38d8ddc20ddefb9392fb4d3a812 Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Mon, 10 Jun 2024 17:04:54 +0200 Subject: [PATCH 1/3] fix: pathsend handling in middlewares (#2613) --- starlette/middleware/base.py | 4 ++++ starlette/middleware/gzip.py | 5 ++++- starlette/responses.py | 9 ++++++++- 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 4e5054d7a..94bff3034 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -170,6 +170,10 @@ async def coro() -> None: async def body_stream() -> typing.AsyncGenerator[bytes, None]: async with recv_stream: async for message in recv_stream: + if message["type"] == "http.response.pathsend": + # with pathsend we don't need to stream anything + yield message + break assert message["type"] == "http.response.body" body = message.get("body", b"") if body: diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index cbb0f4a5b..131541dc7 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -92,7 +92,6 @@ async def send_with_gzip(self, message: Message) -> None: await self.send(self.initial_message) await self.send(message) - elif message_type == "http.response.body": # Remaining body in streaming GZip response. body = message.get("body", b"") @@ -107,6 +106,10 @@ async def send_with_gzip(self, message: Message) -> None: self.gzip_buffer.truncate() await self.send(message) + elif message_type == "http.response.pathsend": + # Don't apply GZip to pathsend responses + await self.send(self.initial_message) + await self.send(message) async def unattached_send(message: Message) -> typing.NoReturn: diff --git a/starlette/responses.py b/starlette/responses.py index a6975747b..297a68b1d 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -247,12 +247,19 @@ async def stream_response(self, send: Send) -> None: "headers": self.raw_headers, } ) + should_close_body = True async for chunk in self.body_iterator: + if isinstance(chunk, dict): + # We got an ASGI message which is not response body (eg: pathsend) + should_close_body = False + await send(chunk) + break if not isinstance(chunk, (bytes, memoryview)): chunk = chunk.encode(self.charset) await send({"type": "http.response.body", "body": chunk, "more_body": True}) - await send({"type": "http.response.body", "body": b"", "more_body": False}) + if should_close_body: + await send({"type": "http.response.body", "body": b"", "more_body": False}) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async with anyio.create_task_group() as task_group: From a6014c6bb029763315784beb357e5f395ba17766 Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Mon, 10 Jun 2024 17:13:47 +0200 Subject: [PATCH 2/3] chore: make linter happy --- starlette/middleware/base.py | 5 ++++- starlette/responses.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 94bff3034..9c7dd0cc8 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -15,6 +15,9 @@ DispatchFunction = typing.Callable[ [Request, RequestResponseEndpoint], typing.Awaitable[Response] ] +BodyStreamGenerator = typing.AsyncGenerator[ + typing.Union[bytes, typing.MutableMapping[str, typing.Any]], None +] T = typing.TypeVar("T") @@ -167,7 +170,7 @@ async def coro() -> None: assert message["type"] == "http.response.start" - async def body_stream() -> typing.AsyncGenerator[bytes, None]: + async def body_stream() -> BodyStreamGenerator: async with recv_stream: async for message in recv_stream: if message["type"] == "http.response.pathsend": diff --git a/starlette/responses.py b/starlette/responses.py index 297a68b1d..bbd205b13 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -207,7 +207,7 @@ def __init__( self.headers["location"] = quote(str(url), safe=":/%#?=@[]!$&'()*+,;") -Content = typing.Union[str, bytes, memoryview] +Content = typing.Union[str, bytes, memoryview, typing.MutableMapping[str, typing.Any]] SyncContentStream = typing.Iterable[Content] AsyncContentStream = typing.AsyncIterable[Content] ContentStream = typing.Union[AsyncContentStream, SyncContentStream] @@ -254,7 +254,7 @@ async def stream_response(self, send: Send) -> None: should_close_body = False await send(chunk) break - if not isinstance(chunk, (bytes, memoryview)): + if isinstance(chunk, str): chunk = chunk.encode(self.charset) await send({"type": "http.response.body", "body": chunk, "more_body": True}) From 4f01bf6f04db293b4da96a1098c2fb7b9ae52788 Mon Sep 17 00:00:00 2001 From: Giovanni Barillari Date: Wed, 12 Jun 2024 18:08:17 +0200 Subject: [PATCH 3/3] test: pathsend with middlewares --- starlette/responses.py | 2 +- tests/middleware/test_base.py | 59 ++++++++++++++++++++++++++++++++++- tests/middleware/test_gzip.py | 53 +++++++++++++++++++++++++++++-- 3 files changed, 110 insertions(+), 4 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index bbd205b13..dae0a9dd1 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -253,7 +253,7 @@ async def stream_response(self, send: Send) -> None: # We got an ASGI message which is not response body (eg: pathsend) should_close_body = False await send(chunk) - break + continue if isinstance(chunk, str): chunk = chunk.encode(self.charset) await send({"type": "http.response.body", "body": chunk, "more_body": True}) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 2176404d8..34fc27c79 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -2,6 +2,7 @@ import contextvars from contextlib import AsyncExitStack +from pathlib import Path from typing import ( Any, AsyncGenerator, @@ -18,7 +19,12 @@ from starlette.middleware import Middleware, _MiddlewareClass from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request -from starlette.responses import PlainTextResponse, Response, StreamingResponse +from starlette.responses import ( + FileResponse, + PlainTextResponse, + Response, + StreamingResponse, +) from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send @@ -1035,3 +1041,54 @@ async def endpoint(request: Request) -> Response: resp.raise_for_status() assert bodies == [b"Hello, World!-foo"] + + +@pytest.mark.anyio +async def test_asgi_pathsend_events(tmpdir: Path) -> None: + path = tmpdir / "example.txt" + with path.open("w") as file: + file.write("") + + request_body_sent = False + response_complete = anyio.Event() + events: list[Message] = [] + + async def endpoint_with_pathsend(_: Request) -> FileResponse: + return FileResponse(path) + + async def passthrough( + request: Request, call_next: RequestResponseEndpoint + ) -> Response: + return await call_next(request) + + app = Starlette( + middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)], + routes=[Route("/", endpoint_with_pathsend)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + "extensions": {"http.response.pathsend": {}}, + } + + async def receive() -> Message: + nonlocal request_body_sent + if not request_body_sent: + request_body_sent = True + return {"type": "http.request", "body": b"", "more_body": False} + await response_complete.wait() + return {"type": "http.disconnect"} + + async def send(message: Message) -> None: + events.append(message) + if message["type"] == "http.response.pathsend": + response_complete.set() + + await app(scope, receive, send) + + assert len(events) == 2 + assert events[0]["type"] == "http.response.start" + assert events[1]["type"] == "http.response.pathsend" diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index 5bfecadb7..7c99a0ee6 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,13 +1,23 @@ +from __future__ import annotations + +from pathlib import Path from typing import Callable +import pytest + from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.requests import Request -from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse +from starlette.responses import ( + ContentStream, + FileResponse, + PlainTextResponse, + StreamingResponse, +) from starlette.routing import Route from starlette.testclient import TestClient -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Message TestClientFactory = Callable[[ASGIApp], TestClient] @@ -111,3 +121,42 @@ async def generator(bytes: bytes, count: int) -> ContentStream: assert response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "text" assert "Content-Length" not in response.headers + + +@pytest.mark.anyio +async def test_gzip_ignored_for_pathsend_responses(tmpdir: Path) -> None: + path = tmpdir / "example.txt" + with path.open("w") as file: + file.write("") + + events: list[Message] = [] + + async def endpoint_with_pathsend(request: Request) -> FileResponse: + _ = await request.body() + return FileResponse(path) + + app = Starlette( + routes=[Route("/", endpoint=endpoint_with_pathsend)], + middleware=[Middleware(GZipMiddleware)], + ) + + scope = { + "type": "http", + "version": "3", + "method": "GET", + "path": "/", + "headers": [(b"accept-encoding", b"gzip, text")], + "extensions": {"http.response.pathsend": {}}, + } + + async def receive() -> Message: + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message: Message) -> None: + events.append(message) + + await app(scope, receive, send) + + assert len(events) == 2 + assert events[0]["type"] == "http.response.start" + assert events[1]["type"] == "http.response.pathsend"