From bf34abb88aab714d3d0dd4bfa8a6f0dff0e7bc5a Mon Sep 17 00:00:00 2001 From: lealre Date: Mon, 16 Dec 2024 20:51:39 +0000 Subject: [PATCH 1/3] tests: branches coverage in `responses.py`, `staticfiles.py`, `templating.py`, `middleware/wsgi.py` and `endpoints.py` --- starlette/endpoints.py | 2 +- starlette/middleware/wsgi.py | 2 +- starlette/templating.py | 2 +- tests/test_responses.py | 12 ++++++++++++ tests/test_staticfiles.py | 14 ++++++++++++++ 5 files changed, 29 insertions(+), 3 deletions(-) diff --git a/starlette/endpoints.py b/starlette/endpoints.py index eb1dace42..107690266 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -74,7 +74,7 @@ async def dispatch(self) -> None: if message["type"] == "websocket.receive": data = await self.decode(websocket, message) await self.on_receive(websocket, data) - elif message["type"] == "websocket.disconnect": + elif message["type"] == "websocket.disconnect": # pragma: no branch close_code = int(message.get("code") or status.WS_1000_NORMAL_CLOSURE) break except Exception as exc: diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 71f4ab5de..6e0a3fae6 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -121,7 +121,7 @@ def start_response( exc_info: typing.Any = None, ) -> None: self.exc_info = exc_info - if not self.response_started: + if not self.response_started: # pragma: no branch self.response_started = True status_code_string, _ = status.split(" ", 1) status_code = int(status_code_string) diff --git a/starlette/templating.py b/starlette/templating.py index 78bfb8c26..6b01aac92 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -43,7 +43,7 @@ def __init__( async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: request = self.context.get("request", {}) extensions = request.get("extensions", {}) - if "http.response.debug" in extensions: + if "http.response.debug" in extensions: # pragma: no branch await send( { "type": "http.response.debug", diff --git a/tests/test_responses.py b/tests/test_responses.py index 3c2d346d3..26e44048b 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -393,6 +393,18 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: assert response.headers["set-cookie"] == "mycookie=myvalue; SameSite=lax" +def test_set_cookie_samesite_none(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + response = Response("Hello, world!", media_type="text/plain") + response.set_cookie("mycookie", "myvalue", samesite=None) + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/") + assert response.text == "Hello, world!" + assert response.headers["set-cookie"] == "mycookie=myvalue; Path=/" + + @pytest.mark.parametrize( "expires", [ diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 8f7423593..696332fee 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -216,6 +216,20 @@ def test_staticfiles_304_with_etag_match(tmpdir: Path, test_client_factory: Test assert second_resp.content == b"" +def test_staticfiles_200_with_etag_mismatch(tmpdir: Path, test_client_factory: TestClientFactory) -> None: + path = os.path.join(tmpdir, "example.txt") + with open(path, "w") as file: + file.write("") + + app = StaticFiles(directory=tmpdir) + client = test_client_factory(app) + first_resp = client.get("/example.txt") + assert first_resp.status_code == 200 + second_resp = client.get("/example.txt", headers={"if-none-match": '"123"'}) + assert second_resp.status_code == 200 + assert second_resp.content == b"" + + def test_staticfiles_304_with_last_modified_compare_last_req( tmpdir: Path, test_client_factory: TestClientFactory ) -> None: From 846eb141d31fba311457db5114fef4b34d6d7841 Mon Sep 17 00:00:00 2001 From: lealre Date: Wed, 18 Dec 2024 15:25:40 +0000 Subject: [PATCH 2/3] remove `pragma: no branch` in `templating.py` and add a test for the case --- starlette/templating.py | 2 +- tests/test_templates.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/starlette/templating.py b/starlette/templating.py index 6b01aac92..78bfb8c26 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -43,7 +43,7 @@ def __init__( async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: request = self.context.get("request", {}) extensions = request.get("extensions", {}) - if "http.response.debug" in extensions: # pragma: no branch + if "http.response.debug" in extensions: await send( { "type": "http.response.debug", diff --git a/tests/test_templates.py b/tests/test_templates.py index 6b2080c17..ad30b8ccd 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -15,6 +15,7 @@ from starlette.responses import Response from starlette.routing import Route from starlette.templating import Jinja2Templates +from starlette.types import Message, Receive, Scope, Send from tests.types import TestClientFactory @@ -308,3 +309,36 @@ def page(request: Request) -> Response: assert response.headers["x-key"] == "value" assert response.headers["content-type"] == "text/plain; charset=utf-8" spy.assert_called() + + +@pytest.mark.anyio +async def test_branch_coverage_http_response_debug_not_in_message(tmpdir: Path) -> None: + path = os.path.join(tmpdir, "index.html") + with open(path, "w") as file: + file.write("Hello") + + templates = Jinja2Templates(directory=str(tmpdir)) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + assert scope["type"] == "http" + request = Request(scope, receive) + response = templates.TemplateResponse(request, "index.html") + await response(scope, receive, send) + + async def receive() -> Message: + raise NotImplementedError("Should not be called!") + + async def send(message: Message) -> None: + if message["type"] == "http.response.start": + assert message["status"] == 200 + if message["type"] == "http.response.body": + assert message["body"] == b"Hello" + assert "http.response.debug" not in message["type"] + + scope = { + "type": "http", + "method": "GET", + "path": "/", + } + + await app(scope, receive, send) From 877ec15775b9629ec0ed82c68be4ab5812401bd3 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 25 Dec 2024 09:53:17 +0100 Subject: [PATCH 3/3] Add some opinionated ideas --- starlette/templating.py | 2 +- tests/test_staticfiles.py | 1 + tests/test_templates.py | 34 ---------------------------------- 3 files changed, 2 insertions(+), 35 deletions(-) diff --git a/starlette/templating.py b/starlette/templating.py index 78bfb8c26..6b01aac92 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -43,7 +43,7 @@ def __init__( async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: request = self.context.get("request", {}) extensions = request.get("extensions", {}) - if "http.response.debug" in extensions: + if "http.response.debug" in extensions: # pragma: no branch await send( { "type": "http.response.debug", diff --git a/tests/test_staticfiles.py b/tests/test_staticfiles.py index 696332fee..b4f131719 100644 --- a/tests/test_staticfiles.py +++ b/tests/test_staticfiles.py @@ -225,6 +225,7 @@ def test_staticfiles_200_with_etag_mismatch(tmpdir: Path, test_client_factory: T client = test_client_factory(app) first_resp = client.get("/example.txt") assert first_resp.status_code == 200 + assert first_resp.headers["etag"] != '"123"' second_resp = client.get("/example.txt", headers={"if-none-match": '"123"'}) assert second_resp.status_code == 200 assert second_resp.content == b"" diff --git a/tests/test_templates.py b/tests/test_templates.py index ad30b8ccd..6b2080c17 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -15,7 +15,6 @@ from starlette.responses import Response from starlette.routing import Route from starlette.templating import Jinja2Templates -from starlette.types import Message, Receive, Scope, Send from tests.types import TestClientFactory @@ -309,36 +308,3 @@ def page(request: Request) -> Response: assert response.headers["x-key"] == "value" assert response.headers["content-type"] == "text/plain; charset=utf-8" spy.assert_called() - - -@pytest.mark.anyio -async def test_branch_coverage_http_response_debug_not_in_message(tmpdir: Path) -> None: - path = os.path.join(tmpdir, "index.html") - with open(path, "w") as file: - file.write("Hello") - - templates = Jinja2Templates(directory=str(tmpdir)) - - async def app(scope: Scope, receive: Receive, send: Send) -> None: - assert scope["type"] == "http" - request = Request(scope, receive) - response = templates.TemplateResponse(request, "index.html") - await response(scope, receive, send) - - async def receive() -> Message: - raise NotImplementedError("Should not be called!") - - async def send(message: Message) -> None: - if message["type"] == "http.response.start": - assert message["status"] == 200 - if message["type"] == "http.response.body": - assert message["body"] == b"Hello" - assert "http.response.debug" not in message["type"] - - scope = { - "type": "http", - "method": "GET", - "path": "/", - } - - await app(scope, receive, send)