diff --git a/README.md b/README.md index 1357f9c0e..bae9846bc 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,8 @@ It is production-ready, and gives you the following: * In-process background tasks. * Startup and shutdown events. * Test client built on `httpx`. -* CORS, GZip, Static Files, Streaming responses. +* CORS, Static Files, Streaming responses. +* ZStd, Brotli, GZip response compression. * Session and Cookie support. * 100% test coverage. * 100% type annotated codebase. @@ -88,6 +89,8 @@ Starlette only requires `anyio`, and the following are optional: * [`python-multipart`][python-multipart] - Required if you want to support form parsing, with `request.form()`. * [`itsdangerous`][itsdangerous] - Required for `SessionMiddleware` support. * [`pyyaml`][pyyaml] - Required for `SchemaGenerator` support. +* [`brotli`][brotli] or [`brotlicffi`][brotlicffi] - Required for Brotli response compression. +* [`zstandard`][zstandard] - Required for ZStd response compression. You can install all of these with `pip3 install starlette[full]`. @@ -134,6 +137,8 @@ in isolation. [jinja2]: https://jinja.palletsprojects.com/ [python-multipart]: https://andrew-d.github.io/python-multipart/ [itsdangerous]: https://itsdangerous.palletsprojects.com/ -[sqlalchemy]: https://www.sqlalchemy.org [pyyaml]: https://pyyaml.org/wiki/PyYAMLDocumentation [techempower]: https://www.techempower.com/benchmarks/#hw=ph&test=fortune&l=zijzen-sf +[brotli]: https://pypi.org/project/Brotli/ +[brotlicffi]: https://pypi.org/project/brotlicffi/ +[zstandard]: https://python-zstandard.readthedocs.io/ diff --git a/pyproject.toml b/pyproject.toml index 679deaade..45c28cfb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,9 @@ full = [ "python-multipart>=0.0.7", "pyyaml", "httpx>=0.22.0", + "brotli>=1; platform_python_implementation == 'CPython'", + "brotlicffi>=1; platform_python_implementation != 'CPython'", + "zstandard>=0.15", ] [project.urls] diff --git a/starlette/middleware/compress.py b/starlette/middleware/compress.py new file mode 100644 index 000000000..1e9b19621 --- /dev/null +++ b/starlette/middleware/compress.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +import gzip +import platform +import re +from functools import lru_cache +from io import BytesIO +from typing import TYPE_CHECKING, NoReturn + +from starlette.datastructures import Headers, MutableHeaders +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +_missing_packages: list[str] = [] + +if platform.python_implementation() == "CPython": + try: + try: + import brotli + except ModuleNotFoundError: # pragma: nocover + import brotlicffi as brotli + except ModuleNotFoundError: # pragma: nocover + _missing_packages.append("brotli") +else: # pragma: nocover + try: + try: + import brotlicffi as brotli + except ModuleNotFoundError: + import brotli + except ModuleNotFoundError: + _missing_packages.append("brotlicffi") + +try: + from zstandard import ZstdCompressor + + if TYPE_CHECKING: # pragma: nocover + from zstandard import ZstdCompressionChunker +except ModuleNotFoundError: # pragma: nocover + _missing_packages.append("zstandard") + +if _missing_packages: # pragma: nocover + missing_packages_and = " and ".join(_missing_packages) + missing_packages_space = " ".join(_missing_packages) + raise RuntimeError( + "The starlette.middleware.compress module requires " + f"the {missing_packages_and} package to be installed.\n" + "You can install this with:\n" + f" $ pip install {missing_packages_space}\n" + ) + + +class CompressMiddleware: + """ + Response compressing middleware. + """ + + __slots__ = ( + "app", + "minimum_size", + "gzip", + "gzip_level", + "brotli", + "brotli_quality", + "zstd", + "zstd_compressor", + ) + + def __init__( + self, + app: ASGIApp, + *, + minimum_size: int = 500, + gzip: bool = True, + gzip_level: int = 4, + brotli: bool = True, + brotli_quality: int = 4, + zstd: bool = True, + zstd_level: int = 4, + ) -> None: + self.app = app + self.minimum_size = minimum_size + self.gzip = gzip + self.gzip_level = gzip_level + self.brotli = brotli + self.brotli_quality = brotli_quality + self.zstd = zstd + self.zstd_compressor = ZstdCompressor(level=zstd_level) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http": + accept_encoding = Headers(scope=scope).get("Accept-Encoding") + + if not accept_encoding: + await self.app(scope, receive, send) + return + + accept_encodings = parse_accept_encoding(accept_encoding) + + if self.zstd and "zstd" in accept_encodings: + await _ZstdResponder(self.app, self.minimum_size, self.zstd_compressor)( + scope, receive, send + ) + return + elif self.brotli and "br" in accept_encodings: + await _BrotliResponder( + self.app, self.minimum_size, self.brotli_quality + )(scope, receive, send) + return + elif self.gzip and "gzip" in accept_encodings: + await _GZipResponder(self.app, self.minimum_size, self.gzip_level)( + scope, receive, send + ) + return + + await self.app(scope, receive, send) + + +class _ZstdResponder: + __slots__ = ( + "app", + "minimum_size", + "compressor", + "chunker", + "send", + "start_message", + ) + + def __init__( + self, app: ASGIApp, minimum_size: int, compressor: ZstdCompressor + ) -> None: + self.app = app + self.minimum_size = minimum_size + self.compressor = compressor + self.chunker: ZstdCompressionChunker | None = None + self.send: Send = _unattached_send + self.start_message: Message | None = None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + self.send = send + await self.app(scope, receive, self.wrapper) + + async def wrapper(self, message: Message) -> None: + message_type: str = message["type"] + + # handle start message + if message_type == "http.response.start": + if self.start_message is not None: # pragma: nocover + raise AssertionError("Unexpected repeated http.response.start message") + + if _is_start_message_satisfied(message): + # capture start message and wait for response body + self.start_message = message + return + else: + await self.send(message) + return + + # skip if start message is not satisfied or unknown message type + if self.start_message is None or message_type != "http.response.body": + await self.send(message) + return + + body: bytes = message.get("body", b"") + more_body: bool = message.get("more_body", False) + + if self.chunker is None: + # skip compression for small responses + if not more_body and len(body) < self.minimum_size: + await self.send(self.start_message) + await self.send(message) + return + + headers = MutableHeaders(raw=self.start_message["headers"]) + headers["Content-Encoding"] = "zstd" + headers.add_vary_header("Accept-Encoding") + + if not more_body: + # one-shot + compressed_body = self.compressor.compress(body) + headers["Content-Length"] = str(len(compressed_body)) + message["body"] = compressed_body + await self.send(self.start_message) + await self.send(message) + return + + # begin streaming + content_length: int = int(headers.get("Content-Length", -1)) + del headers["Content-Length"] + await self.send(self.start_message) + self.chunker = self.compressor.chunker(content_length) + + # streaming + for chunk in self.chunker.compress(body): + await self.send( + {"type": "http.response.body", "body": chunk, "more_body": True} + ) + if more_body: + return + for chunk in self.chunker.finish(): # type: ignore + await self.send( + {"type": "http.response.body", "body": chunk, "more_body": True} + ) + + await self.send({"type": "http.response.body"}) + + +class _BrotliResponder: + __slots__ = ( + "app", + "minimum_size", + "quality", + "compressor", + "send", + "start_message", + ) + + def __init__(self, app: ASGIApp, minimum_size: int, quality: int) -> None: + self.app = app + self.minimum_size = minimum_size + self.quality = quality + self.compressor: brotli.Compressor | None = None + self.send: Send = _unattached_send + self.start_message: Message | None = None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + self.send = send + await self.app(scope, receive, self.wrapper) + + async def wrapper(self, message: Message) -> None: + message_type: str = message["type"] + + # handle start message + if message_type == "http.response.start": + if self.start_message is not None: # pragma: nocover + raise AssertionError("Unexpected repeated http.response.start message") + + if _is_start_message_satisfied(message): + # capture start message and wait for response body + self.start_message = message + return + else: + await self.send(message) + return + + # skip if start message is not satisfied or unknown message type + if self.start_message is None or message_type != "http.response.body": + await self.send(message) + return + + body: bytes = message.get("body", b"") + more_body: bool = message.get("more_body", False) + + if self.compressor is None: + # skip compression for small responses + if not more_body and len(body) < self.minimum_size: + await self.send(self.start_message) + await self.send(message) + return + + headers = MutableHeaders(raw=self.start_message["headers"]) + headers["Content-Encoding"] = "br" + headers.add_vary_header("Accept-Encoding") + + if not more_body: + # one-shot + compressed_body = brotli.compress(body, quality=self.quality) + headers["Content-Length"] = str(len(compressed_body)) + message["body"] = compressed_body + await self.send(self.start_message) + await self.send(message) + return + + # begin streaming + del headers["Content-Length"] + await self.send(self.start_message) + self.compressor = brotli.Compressor(quality=self.quality) + + # streaming + chunk = self.compressor.process(body) + if chunk: + await self.send( + {"type": "http.response.body", "body": chunk, "more_body": True} + ) + if more_body: + return + chunk = self.compressor.finish() + await self.send({"type": "http.response.body", "body": chunk}) + + +class _GZipResponder: + __slots__ = ( + "app", + "minimum_size", + "level", + "compressor", + "buffer", + "send", + "start_message", + ) + + def __init__(self, app: ASGIApp, minimum_size: int, level: int) -> None: + self.app = app + self.minimum_size = minimum_size + self.level = level + self.compressor: gzip.GzipFile | None = None + self.buffer: BytesIO | None = None + self.send: Send = _unattached_send + self.start_message: Message | None = None + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + self.send = send + await self.app(scope, receive, self.wrapper) + + async def wrapper(self, message: Message) -> None: + message_type: str = message["type"] + + # handle start message + if message_type == "http.response.start": + if self.start_message is not None: # pragma: nocover + raise AssertionError("Unexpected repeated http.response.start message") + + if _is_start_message_satisfied(message): + # capture start message and wait for response body + self.start_message = message + return + else: + await self.send(message) + return + + # skip if start message is not satisfied or unknown message type + if self.start_message is None or message_type != "http.response.body": + await self.send(message) + return + + body: bytes = message.get("body", b"") + more_body: bool = message.get("more_body", False) + + if self.compressor is None: + # skip compression for small responses + if not more_body and len(body) < self.minimum_size: + await self.send(self.start_message) + await self.send(message) + return + + headers = MutableHeaders(raw=self.start_message["headers"]) + headers["Content-Encoding"] = "gzip" + headers.add_vary_header("Accept-Encoding") + + if not more_body: + # one-shot + compressed_body = gzip.compress(body, compresslevel=self.level) + headers["Content-Length"] = str(len(compressed_body)) + message["body"] = compressed_body + await self.send(self.start_message) + await self.send(message) + return + + # begin streaming + del headers["Content-Length"] + await self.send(self.start_message) + self.buffer = BytesIO() + self.compressor = gzip.GzipFile( + mode="wb", compresslevel=self.level, fileobj=self.buffer + ) + + if self.buffer is None: # pragma: nocover + raise AssertionError("Compressor is set but buffer is not") + + # streaming + self.compressor.write(body) + if not more_body: + self.compressor.close() + compressed_body = self.buffer.getvalue() + if more_body: + if compressed_body: + self.buffer.seek(0) + self.buffer.truncate() + else: + return + await self.send( + { + "type": "http.response.body", + "body": compressed_body, + "more_body": more_body, + } + ) + + +_accept_encoding_re = re.compile(r"[a-z]{2,8}") + + +@lru_cache(maxsize=128) +def parse_accept_encoding(accept_encoding: str) -> frozenset[str]: + """ + Parse the accept encoding header and return a set of supported encodings. + + >>> parse_accept_encoding('br;q=1.0, gzip;q=0.8, *;q=0.1') + {'br', 'gzip'} + """ + return frozenset(_accept_encoding_re.findall(accept_encoding)) + + +# Primarily based on: +# https://github.com/h5bp/server-configs-nginx/blob/main/h5bp/web_performance/compression.conf#L38 +_compress_content_types: set[str] = { + "application/atom+xml", + "application/geo+json", + "application/gpx+xml", + "application/javascript", + "application/x-javascript", + "application/json", + "application/ld+json", + "application/manifest+json", + "application/rdf+xml", + "application/rss+xml", + "application/vnd.mapbox-vector-tile", + "application/vnd.ms-fontobject", + "application/wasm", + "application/x-web-app-manifest+json", + "application/xhtml+xml", + "application/xml", + "font/eot", + "font/otf", + "font/ttf", + "image/bmp", + "image/svg+xml", + "image/vnd.microsoft.icon", + "image/x-icon", + "text/cache-manifest", + "text/calendar", + "text/css", + "text/html", + "text/javascript", + "text/markdown", + "text/plain", + "text/xml", + "text/vcard", + "text/vnd.rim.location.xloc", + "text/vtt", + "text/x-component", + "text/x-cross-domain-policy", +} + + +def register_compress_content_type(content_type: str) -> None: + """ + Register a new content type to be compressed. + """ + _compress_content_types.add(content_type) + + +def deregister_compress_content_type(content_type: str) -> None: + """ + Deregister a content type from being compressed. + """ + _compress_content_types.discard(content_type) + + +def _is_start_message_satisfied(message: Message) -> bool: + """ + Check if response should be compressed based on the start message. + """ + headers = Headers(raw=message["headers"]) + + # must not already be compressed + if "Content-Encoding" in headers: + return False + + # content type header must be present + content_type = headers.get("Content-Type") + if not content_type: + return False + + # must be a compressible content type + basic_content_type = content_type.partition(";")[0].strip() + return basic_content_type in _compress_content_types + + +async def _unattached_send(message: Message) -> NoReturn: # pragma: nocover + raise RuntimeError("send awaitable not set") diff --git a/tests/middleware/test_compress.py b/tests/middleware/test_compress.py new file mode 100644 index 000000000..ca2132f43 --- /dev/null +++ b/tests/middleware/test_compress.py @@ -0,0 +1,248 @@ +import random +from typing import Callable + +import pytest +import zstandard + +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.compress import ( + CompressMiddleware, + deregister_compress_content_type, + parse_accept_encoding, + register_compress_content_type, +) +from starlette.requests import Request +from starlette.responses import ( + ContentStream, + PlainTextResponse, + Response, + StreamingResponse, +) +from starlette.routing import Route +from starlette.testclient import TestClient +from starlette.types import ASGIApp + +TestClientFactory = Callable[[ASGIApp], TestClient] + + +def test_compress_responses(test_client_factory: TestClientFactory) -> None: + def homepage(request: Request) -> PlainTextResponse: + return PlainTextResponse("x" * 4000, status_code=200) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CompressMiddleware)], + ) + + client = test_client_factory(app) + + for encoding in ("gzip", "br", "zstd"): + response = client.get("/", headers={"accept-encoding": encoding}) + assert response.status_code == 200 + + # httpx does not support zstd yet + # https://github.com/encode/httpx/pull/3139 + if encoding == "zstd": + response._text = zstandard.decompress(response.content).decode() + + assert response.text == "x" * 4000 + assert response.headers["Content-Encoding"] == encoding + assert int(response.headers["Content-Length"]) < 4000 + + +def test_compress_not_in_accept_encoding( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: + return PlainTextResponse("x" * 4000, status_code=200) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CompressMiddleware)], + ) + + client = test_client_factory(app) + response = client.get("/", headers={"accept-encoding": "identity"}) + assert response.status_code == 200 + assert response.text == "x" * 4000 + assert "Content-Encoding" not in response.headers + assert int(response.headers["Content-Length"]) == 4000 + + +def test_compress_ignored_for_small_responses( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: + return PlainTextResponse("OK", status_code=200) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CompressMiddleware)], + ) + + client = test_client_factory(app) + + for encoding in ("gzip", "br", "zstd"): + response = client.get("/", headers={"accept-encoding": encoding}) + assert response.status_code == 200 + assert response.text == "OK" + assert "Content-Encoding" not in response.headers + assert int(response.headers["Content-Length"]) == 2 + + +@pytest.mark.parametrize( + "chunk_size", + [ + 1, + zstandard.COMPRESSION_RECOMMENDED_OUTPUT_SIZE, # currently 128KB + ], +) +def test_compress_streaming_response( + test_client_factory: TestClientFactory, chunk_size: int +) -> None: + random.seed(42) + chunk_count = 70 + + def homepage(request: Request) -> StreamingResponse: + async def generator(count: int) -> ContentStream: + for _ in range(count): + # enough entropy is required for successful chunks + yield random.getrandbits(8 * chunk_size).to_bytes(chunk_size, "big") + + streaming = generator(chunk_count) + return StreamingResponse(streaming, status_code=200, media_type="text/plain") + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CompressMiddleware)], + ) + + client = test_client_factory(app) + + for encoding in ("gzip", "br", "zstd"): + response = client.get("/", headers={"accept-encoding": encoding}) + assert response.status_code == 200 + + # httpx does not support zstd yet + # https://github.com/encode/httpx/pull/3139 + if encoding == "zstd": + response._content = ( + zstandard.ZstdDecompressor() + .decompressobj() + .decompress(response.content) + ) + + assert len(response.content) == chunk_count * chunk_size + assert response.headers["Content-Encoding"] == encoding + assert "Content-Length" not in response.headers + + +def test_compress_ignored_for_responses_with_encoding_set( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> StreamingResponse: + async def generator(bytes: bytes, count: int) -> ContentStream: + for _ in range(count): + yield bytes + + streaming = generator(bytes=b"x" * 400, count=10) + return StreamingResponse( + streaming, status_code=200, headers={"Content-Encoding": "test"} + ) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CompressMiddleware)], + ) + + client = test_client_factory(app) + + for encoding in ("gzip", "br", "zstd"): + response = client.get("/", headers={"accept-encoding": f"{encoding}, test"}) + assert response.status_code == 200 + assert response.text == "x" * 4000 + assert response.headers["Content-Encoding"] == "test" + assert "Content-Length" not in response.headers + + +def test_compress_ignored_for_missing_accept_encoding( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> PlainTextResponse: + return PlainTextResponse("x" * 4000, status_code=200) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CompressMiddleware)], + ) + + client = test_client_factory(app) + response = client.get("/", headers={"accept-encoding": ""}) + assert response.status_code == 200 + assert response.text == "x" * 4000 + assert "Content-Encoding" not in response.headers + assert int(response.headers["Content-Length"]) == 4000 + + +def test_compress_ignored_for_missing_content_type( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> Response: + return Response("x" * 4000, status_code=200, media_type=None) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CompressMiddleware)], + ) + + client = test_client_factory(app) + + for encoding in ("gzip", "br", "zstd"): + response = client.get("/", headers={"accept-encoding": encoding}) + assert response.status_code == 200 + assert response.text == "x" * 4000 + assert "Content-Encoding" not in response.headers + assert int(response.headers["Content-Length"]) == 4000 + + +def test_compress_registered_content_type( + test_client_factory: TestClientFactory, +) -> None: + def homepage(request: Request) -> Response: + return Response("x" * 4000, status_code=200, media_type="test/test") + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CompressMiddleware)], + ) + + client = test_client_factory(app) + + for encoding in ("gzip", "br", "zstd"): + response = client.get("/", headers={"accept-encoding": encoding}) + assert response.status_code == 200 + assert "Content-Encoding" not in response.headers + assert int(response.headers["Content-Length"]) == 4000 + + register_compress_content_type("test/test") + + for encoding in ("gzip", "br", "zstd"): + response = client.get("/", headers={"accept-encoding": encoding}) + assert response.status_code == 200 + assert response.headers["Content-Encoding"] == encoding + assert int(response.headers["Content-Length"]) < 4000 + + deregister_compress_content_type("test/test") + + for encoding in ("gzip", "br", "zstd"): + response = client.get("/", headers={"accept-encoding": encoding}) + assert response.status_code == 200 + assert "Content-Encoding" not in response.headers + assert int(response.headers["Content-Length"]) == 4000 + + +def test_parse_accept_encoding() -> None: + assert parse_accept_encoding("") == frozenset() + assert parse_accept_encoding("gzip, deflate") == {"gzip", "deflate"} + assert parse_accept_encoding("br;q=1.0,gzip;q=0.8, *;q=0.1") == {"br", "gzip"} diff --git a/tests/test_requests.py b/tests/test_requests.py index d8e2e9477..30b8613d8 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys from typing import Any, Callable, Iterator import anyio @@ -42,10 +41,6 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.json() == {"params": {"a": "123", "b": "456"}} -@pytest.mark.skipif( - any(module in sys.modules for module in ("brotli", "brotlicffi")), - reason='urllib3 includes "br" to the "accept-encoding" headers.', -) def test_request_headers(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) @@ -59,7 +54,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: "headers": { "host": "example.org", "user-agent": "testclient", - "accept-encoding": "gzip, deflate", + "accept-encoding": "gzip, deflate, br", "accept": "*/*", "connection": "keep-alive", } diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 854c26914..a370beab1 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -1,4 +1,3 @@ -import sys from typing import Any, Callable, MutableMapping import anyio @@ -74,10 +73,6 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert data == {"params": {"a": "abc", "b": "456"}} -@pytest.mark.skipif( - any(module in sys.modules for module in ("brotli", "brotlicffi")), - reason='urllib3 includes "br" to the "accept-encoding" headers.', -) def test_websocket_headers(test_client_factory: TestClientFactory) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: websocket = WebSocket(scope, receive=receive, send=send) @@ -90,7 +85,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: with client.websocket_connect("/") as websocket: expected_headers = { "accept": "*/*", - "accept-encoding": "gzip, deflate", + "accept-encoding": "gzip, deflate, br", "connection": "upgrade", "host": "testserver", "user-agent": "testclient",