diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index b677063da..fc63e91b6 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -5,6 +5,8 @@ from starlette.datastructures import Headers, MutableHeaders from starlette.types import ASGIApp, Message, Receive, Scope, Send +DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",) + class GZipMiddleware: def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None: @@ -30,6 +32,7 @@ def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> N self.initial_message: Message = {} self.started = False self.content_encoding_set = False + self.content_type_is_excluded = False self.gzip_buffer = io.BytesIO() self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel) @@ -46,7 +49,8 @@ async def send_with_gzip(self, message: Message) -> None: self.initial_message = message headers = Headers(raw=self.initial_message["headers"]) self.content_encoding_set = "content-encoding" in headers - elif message_type == "http.response.body" and self.content_encoding_set: + self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES) + elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded): if not self.started: self.started = True await self.send(self.initial_message) diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index b20a7cb84..38a4e1e35 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -104,3 +104,25 @@ async def generator(bytes: bytes, count: int) -> ContentStream: assert response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "text" assert "Content-Length" not in response.headers + + +def test_gzip_ignored_on_server_sent_events(test_client_factory: TestClientFactory) -> None: + def homepage(request: Request) -> StreamingResponse: + async def generator(bytes: bytes, count: int) -> ContentStream: + for _ in range(count): + yield bytes + + streaming = generator(bytes=b"x" * 400, count=10) + return StreamingResponse(streaming, status_code=200, media_type="text/event-stream") + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(GZipMiddleware)], + ) + + client = test_client_factory(app) + response = client.get("/", headers={"accept-encoding": "gzip"}) + assert response.status_code == 200 + assert response.text == "x" * 4000 + assert "Content-Encoding" not in response.headers + assert "Content-Length" not in response.headers