diff --git a/tests.py b/tests.py index 594ba1a..31deb64 100644 --- a/tests.py +++ b/tests.py @@ -3,6 +3,7 @@ Some of these tests are the same as the ones from starlette.tests.middleware.test_gzip but using zstd instead. """ + import functools import gzip import io @@ -16,6 +17,7 @@ Response, StreamingResponse, ) +from starlette.routing import Route from starlette.testclient import TestClient import zstandard @@ -33,14 +35,12 @@ def test_client_factory(anyio_backend_name, anyio_backend_options): def test_zstd_responses(test_client_factory): - app = Starlette() - - app.add_middleware(ZstdMiddleware) - - @app.route("/") def homepage(request): return PlainTextResponse("x" * 4000, status_code=200) + app = Starlette(routes=[Route("/", homepage)]) + app.add_middleware(ZstdMiddleware) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "zstd"}) assert response.status_code == 200 @@ -52,14 +52,13 @@ def homepage(request): def test_zstd_not_in_accept_encoding(test_client_factory): - app = Starlette() - - app.add_middleware(ZstdMiddleware) - - @app.route("/") def homepage(request): return PlainTextResponse("x" * 4000, status_code=200) + app = Starlette(routes=[Route("/", homepage)]) + + app.add_middleware(ZstdMiddleware) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "identity"}) assert response.status_code == 200 @@ -69,14 +68,13 @@ def homepage(request): def test_zstd_ignored_for_small_responses(test_client_factory): - app = Starlette() - - app.add_middleware(ZstdMiddleware) - - @app.route("/") def homepage(request): return PlainTextResponse("OK", status_code=200) + app = Starlette(routes=[Route("/", homepage)]) + + app.add_middleware(ZstdMiddleware) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "zstd"}) assert response.status_code == 200 @@ -86,11 +84,6 @@ def homepage(request): def test_zstd_streaming_response(test_client_factory): - app = Starlette() - - app.add_middleware(ZstdMiddleware) - - @app.route("/") def homepage(request): async def generator(bytes, count): for index in range(count): @@ -99,6 +92,9 @@ async def generator(bytes, count): streaming = generator(bytes=b"x" * 400, count=10) return StreamingResponse(streaming, status_code=200) + app = Starlette(routes=[Route("/", homepage)]) + app.add_middleware(ZstdMiddleware) + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "zstd"}) assert response.status_code == 200 @@ -107,32 +103,30 @@ async def generator(bytes, count): assert "Content-Length" not in response.headers -def test_zstd_api_options(): - """Tests default values overriding.""" - app = Starlette() +def test_zstd_api_options(test_client_factory): + def homepage(request): + return JSONResponse({"data": "a" * 4000}, status_code=200) + app = Starlette(routes=[Route("/", homepage)]) app.add_middleware( - ZstdMiddleware, level=19, write_checksum=True, threads=2, + ZstdMiddleware, + level=19, + write_checksum=True, + threads=2, ) - @app.route("/") - def homepage(request): - return JSONResponse({"data": "a" * 4000}, status_code=200) - client = TestClient(app) response = client.get("/", headers={"accept-encoding": "zstd"}) assert response.status_code == 200 -def test_gzip_fallback(): - app = Starlette() - - app.add_middleware(ZstdMiddleware, gzip_fallback=True) - - @app.route("/") +def test_gzip_fallback(test_client_factory): def homepage(request): return PlainTextResponse("x" * 4000, status_code=200) + app = Starlette(routes=[Route("/", homepage)]) + app.add_middleware(ZstdMiddleware, gzip_fallback=True) + client = TestClient(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 @@ -141,16 +135,14 @@ def homepage(request): assert int(response.headers["Content-Length"]) < 4000 -def test_gzip_fallback_false(): - app = Starlette() - - app.add_middleware(ZstdMiddleware, gzip_fallback=False) - - @app.route("/") +def test_gzip_fallback_false(test_client_factory): def homepage(request): return PlainTextResponse("x" * 4000, status_code=200) - client = TestClient(app) + app = Starlette(routes=[Route("/", homepage)]) + app.add_middleware(ZstdMiddleware, gzip_fallback=False) + + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "gzip"}) assert response.status_code == 200 assert response.text == "x" * 4000 @@ -158,19 +150,17 @@ def homepage(request): assert int(response.headers["Content-Length"]) == 4000 -def test_excluded_handlers(): - app = Starlette() +def test_excluded_handlers(test_client_factory): + def homepage(request): + return PlainTextResponse("x" * 4000, status_code=200) + app = Starlette(routes=[Route("/excluded", homepage)]) app.add_middleware( ZstdMiddleware, excluded_handlers=["/excluded"], ) - @app.route("/excluded") - def homepage(request): - return PlainTextResponse("x" * 4000, status_code=200) - - client = TestClient(app) + client = test_client_factory(app) response = client.get("/excluded", headers={"accept-encoding": "zstd"}) assert response.status_code == 200 @@ -179,14 +169,8 @@ def homepage(request): assert int(response.headers["Content-Length"]) == 4000 -def test_zstd_avoids_double_encoding(): +def test_zstd_avoids_double_encoding(test_client_factory): # See https://github.com/encode/starlette/pull/1901 - - app = Starlette() - - app.add_middleware(ZstdMiddleware, minimum_size=1) - - @app.route("/") def homepage(request): gzip_buffer = io.BytesIO() gzip_file = gzip.GzipFile(mode="wb", fileobj=gzip_buffer) @@ -197,13 +181,19 @@ def homepage(request): body, headers={ "content-encoding": "gzip", - "x-gzipped-content-length": str(len(body)) - } + "x-gzipped-content-length": str(len(body)), + }, ) - client = TestClient(app) + app = Starlette(routes=[Route("/", homepage)]) + app.add_middleware(ZstdMiddleware, minimum_size=1) + + client = test_client_factory(app) response = client.get("/", headers={"accept-encoding": "zstd"}) assert response.status_code == 200 assert response.text == "hello world" * 200 assert response.headers["Content-Encoding"] == "gzip" - assert response.headers["Content-Length"] == response.headers["x-gzipped-content-length"] + assert ( + response.headers["Content-Length"] + == response.headers["x-gzipped-content-length"] + )