diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index b677063da..875198e52 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -81,6 +81,7 @@ async def send_with_gzip(self, message: Message) -> None: del headers["Content-Length"] self.gzip_file.write(body) + self.gzip_file.flush() message["body"] = self.gzip_buffer.getvalue() self.gzip_buffer.seek(0) self.gzip_buffer.truncate() @@ -94,7 +95,9 @@ async def send_with_gzip(self, message: Message) -> None: more_body = message.get("more_body", False) self.gzip_file.write(body) - if not more_body: + if more_body: + self.gzip_file.flush() + else: self.gzip_file.close() message["body"] = self.gzip_buffer.getvalue() diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index b20a7cb84..6c5f60247 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -1,9 +1,12 @@ +from __future__ import annotations + from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.gzip import GZipMiddleware from starlette.requests import Request from starlette.responses import ContentStream, PlainTextResponse, StreamingResponse from starlette.routing import Route +from starlette.types import ASGIApp, Message, Receive, Scope, Send from tests.types import TestClientFactory @@ -61,6 +64,24 @@ def homepage(request: Request) -> PlainTextResponse: def test_gzip_streaming_response(test_client_factory: TestClientFactory) -> None: + class VerifyingMiddleware: + def __init__(self, app: ASGIApp) -> None: + self.app = app + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + self.send = send + self.received_chunks: list[bytes] = [] + await self.app(scope, receive, self.sent_from_gzip) + assert all(chunk != b"" for chunk in self.received_chunks) + assert len(self.received_chunks) == 11 + + async def sent_from_gzip(self, message: Message) -> None: + message_type = message["type"] + if message_type == "http.response.body": + body = message.get("body", b"") + self.received_chunks.append(body) + await self.send(message) + def homepage(request: Request) -> StreamingResponse: async def generator(bytes: bytes, count: int) -> ContentStream: for index in range(count): @@ -71,7 +92,10 @@ async def generator(bytes: bytes, count: int) -> ContentStream: app = Starlette( routes=[Route("/", endpoint=homepage)], - middleware=[Middleware(GZipMiddleware)], + middleware=[ + Middleware(VerifyingMiddleware), + Middleware(GZipMiddleware), + ], ) client = test_client_factory(app)