Skip to content

Commit

Permalink
Merge pull request #1 from jairhenrique/fix-tests-warning
Browse files Browse the repository at this point in the history
Fix tests warning
  • Loading branch information
tuffnatty authored Aug 23, 2024
2 parents 307a7ec + 4174940 commit e621422
Showing 1 changed file with 49 additions and 59 deletions.
108 changes: 49 additions & 59 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Some of these tests are the same as the ones from starlette.tests.middleware.test_gzip
but using zstd instead.
"""

import functools
import gzip
import io
Expand All @@ -16,6 +17,7 @@
Response,
StreamingResponse,
)
from starlette.routing import Route
from starlette.testclient import TestClient

import zstandard
Expand All @@ -33,14 +35,12 @@ def test_client_factory(anyio_backend_name, anyio_backend_options):


def test_zstd_responses(test_client_factory):
app = Starlette()

app.add_middleware(ZstdMiddleware)

@app.route("/")
def homepage(request):
return PlainTextResponse("x" * 4000, status_code=200)

app = Starlette(routes=[Route("/", homepage)])
app.add_middleware(ZstdMiddleware)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "zstd"})
assert response.status_code == 200
Expand All @@ -52,14 +52,13 @@ def homepage(request):


def test_zstd_not_in_accept_encoding(test_client_factory):
app = Starlette()

app.add_middleware(ZstdMiddleware)

@app.route("/")
def homepage(request):
return PlainTextResponse("x" * 4000, status_code=200)

app = Starlette(routes=[Route("/", homepage)])

app.add_middleware(ZstdMiddleware)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "identity"})
assert response.status_code == 200
Expand All @@ -69,14 +68,13 @@ def homepage(request):


def test_zstd_ignored_for_small_responses(test_client_factory):
app = Starlette()

app.add_middleware(ZstdMiddleware)

@app.route("/")
def homepage(request):
return PlainTextResponse("OK", status_code=200)

app = Starlette(routes=[Route("/", homepage)])

app.add_middleware(ZstdMiddleware)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "zstd"})
assert response.status_code == 200
Expand All @@ -86,11 +84,6 @@ def homepage(request):


def test_zstd_streaming_response(test_client_factory):
app = Starlette()

app.add_middleware(ZstdMiddleware)

@app.route("/")
def homepage(request):
async def generator(bytes, count):
for index in range(count):
Expand All @@ -99,6 +92,9 @@ async def generator(bytes, count):
streaming = generator(bytes=b"x" * 400, count=10)
return StreamingResponse(streaming, status_code=200)

app = Starlette(routes=[Route("/", homepage)])
app.add_middleware(ZstdMiddleware)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "zstd"})
assert response.status_code == 200
Expand All @@ -107,32 +103,30 @@ async def generator(bytes, count):
assert "Content-Length" not in response.headers


def test_zstd_api_options():
"""Tests default values overriding."""
app = Starlette()
def test_zstd_api_options(test_client_factory):
def homepage(request):
return JSONResponse({"data": "a" * 4000}, status_code=200)

app = Starlette(routes=[Route("/", homepage)])
app.add_middleware(
ZstdMiddleware, level=19, write_checksum=True, threads=2,
ZstdMiddleware,
level=19,
write_checksum=True,
threads=2,
)

@app.route("/")
def homepage(request):
return JSONResponse({"data": "a" * 4000}, status_code=200)

client = TestClient(app)
response = client.get("/", headers={"accept-encoding": "zstd"})
assert response.status_code == 200


def test_gzip_fallback():
app = Starlette()

app.add_middleware(ZstdMiddleware, gzip_fallback=True)

@app.route("/")
def test_gzip_fallback(test_client_factory):
def homepage(request):
return PlainTextResponse("x" * 4000, status_code=200)

app = Starlette(routes=[Route("/", homepage)])
app.add_middleware(ZstdMiddleware, gzip_fallback=True)

client = TestClient(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
Expand All @@ -141,36 +135,32 @@ def homepage(request):
assert int(response.headers["Content-Length"]) < 4000


def test_gzip_fallback_false():
app = Starlette()

app.add_middleware(ZstdMiddleware, gzip_fallback=False)

@app.route("/")
def test_gzip_fallback_false(test_client_factory):
def homepage(request):
return PlainTextResponse("x" * 4000, status_code=200)

client = TestClient(app)
app = Starlette(routes=[Route("/", homepage)])
app.add_middleware(ZstdMiddleware, gzip_fallback=False)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "gzip"})
assert response.status_code == 200
assert response.text == "x" * 4000
assert "Content-Encoding" not in response.headers
assert int(response.headers["Content-Length"]) == 4000


def test_excluded_handlers():
app = Starlette()
def test_excluded_handlers(test_client_factory):
def homepage(request):
return PlainTextResponse("x" * 4000, status_code=200)

app = Starlette(routes=[Route("/excluded", homepage)])
app.add_middleware(
ZstdMiddleware,
excluded_handlers=["/excluded"],
)

@app.route("/excluded")
def homepage(request):
return PlainTextResponse("x" * 4000, status_code=200)

client = TestClient(app)
client = test_client_factory(app)
response = client.get("/excluded", headers={"accept-encoding": "zstd"})

assert response.status_code == 200
Expand All @@ -179,14 +169,8 @@ def homepage(request):
assert int(response.headers["Content-Length"]) == 4000


def test_zstd_avoids_double_encoding():
def test_zstd_avoids_double_encoding(test_client_factory):
# See https://github.com/encode/starlette/pull/1901

app = Starlette()

app.add_middleware(ZstdMiddleware, minimum_size=1)

@app.route("/")
def homepage(request):
gzip_buffer = io.BytesIO()
gzip_file = gzip.GzipFile(mode="wb", fileobj=gzip_buffer)
Expand All @@ -197,13 +181,19 @@ def homepage(request):
body,
headers={
"content-encoding": "gzip",
"x-gzipped-content-length": str(len(body))
}
"x-gzipped-content-length": str(len(body)),
},
)

client = TestClient(app)
app = Starlette(routes=[Route("/", homepage)])
app.add_middleware(ZstdMiddleware, minimum_size=1)

client = test_client_factory(app)
response = client.get("/", headers={"accept-encoding": "zstd"})
assert response.status_code == 200
assert response.text == "hello world" * 200
assert response.headers["Content-Encoding"] == "gzip"
assert response.headers["Content-Length"] == response.headers["x-gzipped-content-length"]
assert (
response.headers["Content-Length"]
== response.headers["x-gzipped-content-length"]
)

0 comments on commit e621422

Please sign in to comment.