Skip to content

Commit

Permalink
Added excluded_handlers parameter for selective compression, source:
Browse files Browse the repository at this point in the history
  • Loading branch information
tuffnatty committed Dec 16, 2023
1 parent 238c7ac commit 9b70582
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 19 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ app.add_middleware(
threads=0,
write_checksum=True,
write_content_size=False,
gzip_fallback=True
gzip_fallback=True,
excluded_handlers=None,
)
```

Expand All @@ -72,6 +73,7 @@ app.add_middleware(
- `write_checksum`: If True, a 4 byte content checksum will be written with the compressed data, allowing the decompressor to perform content verification.
- `write_content_size`: If True (the default), the decompressed content size will be included in the header of the compressed data. This data will only be written if the compressor knows the size of the input data.
- `gzip_fallback`: If `True`, uses gzip encoding if `zstd` is not in the Accept-Encoding header.
- `excluded_handlers`: List of regular-expression patterns. A request whose path matches any of them (via `re.search`) is passed through without compression.

## Performance

Expand Down
21 changes: 21 additions & 0 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,27 @@ def homepage(request):
assert int(response.headers["Content-Length"]) == 4000


def test_excluded_handlers():
    """A path listed in ``excluded_handlers`` is served uncompressed.

    The client advertises ``zstd`` support, so any compression applied
    would show up as a ``Content-Encoding`` header and a shorter
    ``Content-Length``.
    """
    payload = "x" * 4000

    app = Starlette()
    app.add_middleware(ZstdMiddleware, excluded_handlers=["/excluded"])

    @app.route("/excluded")
    def homepage(request):
        return PlainTextResponse(payload, status_code=200)

    response = TestClient(app).get(
        "/excluded",
        headers={"accept-encoding": "zstd"},
    )

    assert response.status_code == 200
    assert response.text == payload
    # The body went out untouched: no encoding header, full original length.
    assert "Content-Encoding" not in response.headers
    assert int(response.headers["Content-Length"]) == 4000


def test_zstd_avoids_double_encoding():
# See https://github.com/encode/starlette/pull/1901

Expand Down
73 changes: 55 additions & 18 deletions zstd_asgi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
import re
from typing import List, Union, NoReturn

from starlette.datastructures import Headers, MutableHeaders
from starlette.middleware.gzip import GZipResponder
Expand All @@ -11,6 +13,8 @@


class ZstdMiddleware:
"""Zstd middleware public interface."""

def __init__(
self,
app: ASGIApp,
Expand All @@ -20,38 +24,71 @@ def __init__(
write_checksum: bool = False,
write_content_size: bool = True,
gzip_fallback: bool = True,
excluded_handlers: Union[List, None] = None,
) -> None:
"""
Arguments.
level: Integer compression level.
Valid values are all negative integers through 22.
Negative levels effectively engage --fast mode from the zstd CLI.
minimum_size: Only compress responses that are bigger than this value in bytes.
threads: Number of threads to use to compress data concurrently.
When set, compression operations are performed on multiple threads.
The default value (0) disables multi-threaded compression.
A value of -1 means to set the number of threads to the number
of detected logical CPUs.
write_checksum: If True, a 4 byte content checksum will be written with
the compressed data, allowing the decompressor to perform content
verification.
write_content_size: If True (the default), the decompressed content size
will be included in the header of the compressed data. This data
will only be written if the compressor knows the size of the input
data.
gzip_fallback: If True, uses gzip encoding if zstd is not in the Accept-Encoding header.
excluded_handlers: List of regular-expression patterns; requests whose path
matches any of them (via re.search) are not compressed.
"""
self.app = app
self.level = level
self.minimum_size = minimum_size
self.threads = threads
self.write_checksum = write_checksum
self.write_content_size = write_content_size
self.gzip_fallback = gzip_fallback
if excluded_handlers:
self.excluded_handlers = [re.compile(path) for path in excluded_handlers]
else:
self.excluded_handlers = []

async def __call__(self,
                   scope: Scope,
                   receive: Receive,
                   send: Send) -> None:
    """Dispatch one ASGI connection, compressing the response if allowed.

    Arguments.
    scope: ASGI connection scope; its "type" and "path" entries are read.
    receive: ASGI receive callable, forwarded unchanged.
    send: ASGI send callable, forwarded unchanged.

    The wrapped app is called directly (no compression) when the scope is
    not HTTP, when the path matches one of ``excluded_handlers``, or when
    the client accepts neither zstd nor (with ``gzip_fallback``) gzip.
    """
    # Check the scope type first so the path patterns are only evaluated
    # for HTTP requests; the exclusion test must run before any responder
    # is chosen, otherwise excluded paths would still be compressed.
    if scope["type"] != "http" or self._is_handler_excluded(scope):
        await self.app(scope, receive, send)
        return

    accept_encoding = Headers(scope=scope).get("Accept-Encoding", "")
    if "zstd" in accept_encoding:
        responder = ZstdResponder(
            self.app,
            self.level,
            self.threads,
            self.write_checksum,
            self.write_content_size,
            self.minimum_size,
        )
    elif self.gzip_fallback and "gzip" in accept_encoding:
        responder = GZipResponder(self.app, self.minimum_size)
    else:
        # No encoding the client accepts is available: pass through.
        await self.app(scope, receive, send)
        return

    await responder(scope, receive, send)

def _is_handler_excluded(self, scope: Scope) -> bool:
    """Tell whether the scope's path matches any excluded pattern.

    A scope without a "path" entry is treated as having the empty path.
    Patterns are tried with ``search``, so an unanchored pattern matches
    anywhere in the path.
    """
    path = scope.get("path", "")
    for pattern in self.excluded_handlers:
        if pattern.search(path):
            return True
    return False


class ZstdResponder:
def __init__(
Expand Down Expand Up @@ -155,5 +192,5 @@ async def send_with_zstd(self, message: Message) -> None:
await self.send(message)


async def unattached_send(message: Message) -> NoReturn:
    """Fail loudly when awaited: always raises ``RuntimeError``.

    NOTE(review): presumably the placeholder ``send`` a responder holds
    until the real ASGI ``send`` is attached -- confirm against
    ``ZstdResponder``, which is not fully visible here.

    Arguments.
    message: ASGI message; ignored.

    Raises.
    RuntimeError: unconditionally, hence the ``NoReturn`` annotation.
    """
    raise RuntimeError("send awaitable not set")  # pragma: no cover

0 comments on commit 9b70582

Please sign in to comment.