From 9b70582d6afba5c20dfe121aa105a32137a5348d Mon Sep 17 00:00:00 2001 From: Phil Krylov Date: Sat, 16 Dec 2023 17:30:20 +0100 Subject: [PATCH] Added `excluded_handlers` parameter for selective compression, source: https://github.com/fullonic/brotli-asgi/pull/21 --- README.md | 4 ++- tests.py | 21 +++++++++++++ zstd_asgi/__init__.py | 73 ++++++++++++++++++++++++++++++++----------- 3 files changed, 79 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index c89fabf..71e5b14 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,8 @@ app.add_middleware( threads=0, write_checksum=True, write_content_size=False, - gzip_fallback=True + gzip_fallback=True, + excluded_handlers=None, ) ``` @@ -72,6 +73,7 @@ app.add_middleware( - `write_checksum`: If True, a 4 byte content checksum will be written with the compressed data, allowing the decompressor to perform content verification. - `write_content_size`: If True (the default), the decompressed content size will be included in the header of the compressed data. This data will only be written if the compressor knows the size of the input data. - `gzip_fallback`: If `True`, uses gzip encoding if `zstd` is not in the Accept-Encoding header. +- `excluded_handlers`: List of handlers to be excluded from being compressed. ## Performance diff --git a/tests.py b/tests.py index 158a07e..594ba1a 100644 --- a/tests.py +++ b/tests.py @@ -158,6 +158,27 @@ def homepage(request): assert int(response.headers["Content-Length"]) == 4000 +def test_excluded_handlers(): + app = Starlette() + + app.add_middleware( + ZstdMiddleware, + excluded_handlers=["/excluded"], + ) + + @app.route("/excluded") + def homepage(request): + return PlainTextResponse("x" * 4000, status_code=200) + + client = TestClient(app) + response = client.get("/excluded", headers={"accept-encoding": "zstd"}) + + assert response.status_code == 200 + assert response.text == "x" * 4000 + assert "Content-Encoding" not in response.headers + assert int(response.headers["Content-Length"]) == 4000 + + def test_zstd_avoids_double_encoding(): # See https://github.com/encode/starlette/pull/1901 diff --git a/zstd_asgi/__init__.py b/zstd_asgi/__init__.py index 56caeef..c69a958 100644 --- a/zstd_asgi/__init__.py +++ b/zstd_asgi/__init__.py @@ -1,4 +1,6 @@ import io +import re +from typing import List, Union, NoReturn from starlette.datastructures import Headers, MutableHeaders from starlette.middleware.gzip import GZipResponder @@ -11,6 +13,8 @@ class ZstdMiddleware: + """Zstd middleware public interface.""" + def __init__( self, app: ASGIApp, @@ -20,7 +24,30 @@ def __init__( write_checksum: bool = False, write_content_size: bool = True, gzip_fallback: bool = True, + excluded_handlers: Union[List, None] = None, ) -> None: + """ + Arguments. + + level: Integer compression level. + Valid values are all negative integers through 22. + Negative levels effectively engage --fast mode from the zstd CLI. + minimum_size: Only compress responses that are bigger than this value in bytes. + threads: Number of threads to use to compress data concurrently. + When set, compression operations are performed on multiple threads. + The default value (0) disables multi-threaded compression. + A value of -1 means to set the number of threads to the number + of detected logical CPUs. + write_checksum: If True, a 4 byte content checksum will be written with + the compressed data, allowing the decompressor to perform content + verification. + write_content_size: If True (the default), the decompressed content size + will be included in the header of the compressed data. This data + will only be written if the compressor knows the size of the input + data. + gzip_fallback: If True, uses gzip encoding if br is not in the Accept-Encoding header. + excluded_handlers: List of handlers to be excluded from being compressed. + """ self.app = app self.level = level self.minimum_size = minimum_size @@ -28,30 +55,40 @@ def __init__( self.write_checksum = write_checksum self.write_content_size = write_content_size self.gzip_fallback = gzip_fallback + if excluded_handlers: + self.excluded_handlers = [re.compile(path) for path in excluded_handlers] + else: + self.excluded_handlers = [] async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] == "http": - accept_encoding = Headers(scope=scope).get("Accept-Encoding", "") - if "zstd" in accept_encoding: - responder = ZstdResponder( - self.app, - self.level, - self.threads, - self.write_checksum, - self.write_content_size, - self.minimum_size, - ) - await responder(scope, receive, send) - return - if self.gzip_fallback and "gzip" in accept_encoding: - responder = GZipResponder(self.app, self.minimum_size) - await responder(scope, receive, send) - return + if self._is_handler_excluded(scope) or scope["type"] != "http": + return await self.app(scope, receive, send) + accept_encoding = Headers(scope=scope).get("Accept-Encoding", "") + if "zstd" in accept_encoding: + responder = ZstdResponder( + self.app, + self.level, + self.threads, + self.write_checksum, + self.write_content_size, + self.minimum_size, + ) + await responder(scope, receive, send) + return + if self.gzip_fallback and "gzip" in accept_encoding: + responder = GZipResponder(self.app, self.minimum_size) + await responder(scope, receive, send) + return await self.app(scope, receive, send) + def _is_handler_excluded(self, scope: Scope) -> bool: + handler = scope.get("path", "") + + return any(pattern.search(handler) for pattern in self.excluded_handlers) + class ZstdResponder: def __init__( @@ -155,5 +192,5 @@ async def send_with_zstd(self, message: Message) -> None: await self.send(message) -async def unattached_send(message: Message) -> None: +async def unattached_send(message: Message) -> NoReturn: raise RuntimeError("send awaitable not set") # pragma: no cover