From 2af4d8225114e0fcf806376bd928e918e95c3f68 Mon Sep 17 00:00:00 2001 From: "alex.oleshkevich" Date: Thu, 22 Jun 2023 17:05:22 +0200 Subject: [PATCH 1/5] make request a required argument of TemplateResponse --- docs/templates.md | 2 +- starlette/templating.py | 9 ++++----- tests/test_templates.py | 16 +++++----------- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/docs/templates.md b/docs/templates.md index ba9c4255b..01f343238 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -22,7 +22,7 @@ from starlette.staticfiles import StaticFiles templates = Jinja2Templates(directory='templates') async def homepage(request): - return templates.TemplateResponse('index.html', {'request': request}) + return templates.TemplateResponse(request, 'index.html') routes = [ Route('/', endpoint=homepage), diff --git a/starlette/templating.py b/starlette/templating.py index ec9ca193d..a96506608 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -142,17 +142,16 @@ def get_template(self, name: str) -> "jinja2.Template": def TemplateResponse( self, + request: Request, name: str, - context: dict, + context: typing.Optional[dict] = None, status_code: int = 200, headers: typing.Optional[typing.Mapping[str, str]] = None, media_type: typing.Optional[str] = None, background: typing.Optional[BackgroundTask] = None, ) -> _TemplateResponse: - if "request" not in context: - raise ValueError('context must include a "request" key') - - request = typing.cast(Request, context["request"]) + context = context or {} + context.setdefault("request", request) for context_processor in self.context_processors: context.update(context_processor(request)) diff --git a/tests/test_templates.py b/tests/test_templates.py index 1f1909f4b..566718d18 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -17,7 +17,7 @@ def test_templates(tmpdir, test_client_factory): file.write("Hello, world") async def homepage(request): - return templates.TemplateResponse("index.html", {"request": request}) + return templates.TemplateResponse(request, "index.html") app = Starlette( debug=True, @@ -32,18 +32,12 @@ async def homepage(request): assert set(response.context.keys()) == {"request"} -def test_template_response_requires_request(tmpdir): - templates = Jinja2Templates(str(tmpdir)) - with pytest.raises(ValueError): - templates.TemplateResponse("", {}) - - def test_calls_context_processors(tmp_path, test_client_factory): path = tmp_path / "index.html" path.write_text("Hello {{ username }}") async def homepage(request): - return templates.TemplateResponse("index.html", {"request": request}) + return templates.TemplateResponse(request, "index.html") def hello_world_processor(request): return {"username": "World"} @@ -72,7 +66,7 @@ def test_template_with_middleware(tmpdir, test_client_factory): file.write("Hello, world") async def homepage(request): - return templates.TemplateResponse("index.html", {"request": request}) + return templates.TemplateResponse(request, "index.html") class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): @@ -99,7 +93,7 @@ def test_templates_with_directories(tmp_path: Path, test_client_factory): template_a.write_text(" a") async def page_a(request): - return templates.TemplateResponse("template_a.html", {"request": request}) + return templates.TemplateResponse(request, "template_a.html") dir_b = tmp_path.resolve() / "b" dir_b.mkdir() @@ -107,7 +101,7 @@ async def page_a(request): template_b.write_text(" b") async def page_b(request): - return templates.TemplateResponse("template_b.html", {"request": request}) + return templates.TemplateResponse(request, "template_b.html") app = Starlette( debug=True, From 6847bc5b5bea174be5d764cfb69709f92aebdf58 Mon Sep 17 00:00:00 2001 From: "alex.oleshkevich" Date: Wed, 28 Jun 2023 20:44:34 +0200 Subject: [PATCH 2/5] add compatibility between old and new argument set of TemplateResponse --- starlette/templating.py | 67 +++++++++++++++++++ tests/test_templates.py | 141 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 208 insertions(+) diff --git a/starlette/templating.py b/starlette/templating.py index a96506608..4c409fc39 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -140,6 +140,7 @@ def url_for(context: dict, __name: str, **path_params: typing.Any) -> URL: def get_template(self, name: str) -> "jinja2.Template": return self.env.get_template(name) + @typing.overload def TemplateResponse( self, request: Request, @@ -150,6 +151,72 @@ def TemplateResponse( media_type: typing.Optional[str] = None, background: typing.Optional[BackgroundTask] = None, ) -> _TemplateResponse: + ... + + @typing.overload + def TemplateResponse( + self, + name: str, + context: typing.Optional[dict] = None, + status_code: int = 200, + headers: typing.Optional[typing.Mapping[str, str]] = None, + media_type: typing.Optional[str] = None, + background: typing.Optional[BackgroundTask] = None, + ) -> _TemplateResponse: + # Deprecated usage + ... + + def TemplateResponse( + self, *args: typing.Any, **kwargs: typing.Any + ) -> _TemplateResponse: + if args: + if isinstance( + args[0], str + ): # the first argument is template name (old style) + warnings.warn( + "Argument 1 of TemplateResponse must be a Request instance.", + DeprecationWarning, + ) + + name = args[0] + context = args[1] if len(args) > 1 else kwargs.get("context", {}) + status_code = ( + args[2] if len(args) > 2 else kwargs.get("status_code", 200) + ) + headers = args[2] if len(args) > 2 else kwargs.get("headers") + media_type = args[3] if len(args) > 3 else kwargs.get("media_type") + background = args[4] if len(args) > 4 else kwargs.get("background") + + if "request" not in context: + raise ValueError('context must include a "request" key') + request = context["request"] + else: # the first argument is a request instance (new style) + request = args[0] + name = args[1] if len(args) > 1 else kwargs["name"] + context = args[2] if len(args) > 2 else kwargs.get("context") + status_code = ( + args[3] if len(args) > 3 else kwargs.get("status_code", 200) + ) + headers = args[4] if len(args) > 4 else kwargs.get("headers") + media_type = args[5] if len(args) > 5 else kwargs.get("media_type") + background = args[6] if len(args) > 6 else kwargs.get("background") + else: # all arguments are kwargs + if "request" not in kwargs: + warnings.warn( + "TemplateResponse requires `request` keyword argument.", + DeprecationWarning, + ) + if "request" not in kwargs.get("context", {}): + raise ValueError('context must include a "request" key') + + context = kwargs.get("context", {}) + request = kwargs.get("request", context.get("request")) + name = typing.cast(str, kwargs["name"]) + status_code = kwargs.get("status_code", 200) + headers = kwargs.get("headers") + media_type = kwargs.get("media_type") + background = kwargs.get("background") + context = context or {} context.setdefault("request", request) for context_processor in self.context_processors: diff --git a/tests/test_templates.py b/tests/test_templates.py index 566718d18..5e29ccdd6 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,10 +1,12 @@ import os from pathlib import Path +from unittest import mock import jinja2 import pytest from starlette.applications import Starlette +from starlette.background import BackgroundTask from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.routing import Route @@ -152,3 +154,142 @@ def test_templates_with_environment(tmpdir): def test_templates_with_environment_options_emit_warning(tmpdir): with pytest.warns(DeprecationWarning): Jinja2Templates(str(tmpdir), autoescape=True) + + +def test_templates_with_kwargs_only(tmpdir, test_client_factory): + # MAINTAINERS: remove after 1.0 + path = os.path.join(tmpdir, "index.html") + with open(path, "w") as file: + file.write("value: {{ a }}") + templates = Jinja2Templates(directory=str(tmpdir)) + + spy = mock.AsyncMock() + + def page(request): + return templates.TemplateResponse( + request=request, + name="index.html", + context={"a": "b"}, + status_code=201, + headers={"x-key": "value"}, + media_type="text/plain", + background=BackgroundTask(func=spy), + ) + + app = Starlette(routes=[Route("/", page)]) + client = test_client_factory(app) + response = client.get("/") + + assert response.text == "value: b" # context was rendered + assert response.status_code == 201 + assert response.headers["x-key"] == "value" + assert response.headers["content-type"] == "text/plain; charset=utf-8" + spy.assert_called() + + +def test_templates_with_kwargs_only_requires_request_in_context(tmpdir): + # MAINTAINERS: remove after 1.0 + + templates = Jinja2Templates(directory=str(tmpdir)) + with pytest.warns( + DeprecationWarning, + match="TemplateResponse requires `request` keyword argument.", + ): + with pytest.raises(ValueError): + templates.TemplateResponse(name="index.html", context={"a": "b"}) + + +def test_templates_with_kwargs_only_warns_when_no_request_keyword( + tmpdir, test_client_factory +): + # MAINTAINERS: remove after 1.0 + + path = os.path.join(tmpdir, "index.html") + with open(path, "w") as file: + file.write("Hello") + + templates = Jinja2Templates(directory=str(tmpdir)) + + def page(request): + return templates.TemplateResponse( + name="index.html", context={"request": request} + ) + + app = Starlette(routes=[Route("/", page)]) + client = test_client_factory(app) + + with pytest.warns( + DeprecationWarning, + match="TemplateResponse requires `request` keyword argument.", + ): + client.get("/") + + +def test_templates_with_requires_request_in_context(tmpdir): + # MAINTAINERS: remove after 1.0 + templates = Jinja2Templates(directory=str(tmpdir)) + with pytest.warns(DeprecationWarning): + with pytest.raises(ValueError): + templates.TemplateResponse("index.html", context={}) + + +def test_templates_warns_when_first_argument_isnot_request(tmpdir, test_client_factory): + # MAINTAINERS: remove after 1.0 + path = os.path.join(tmpdir, "index.html") + with open(path, "w") as file: + file.write("value: {{ a }}") + templates = Jinja2Templates(directory=str(tmpdir)) + + spy = mock.AsyncMock() + + def page(request): + return templates.TemplateResponse( + "index.html", + {"a": "b", "request": request}, + status_code=201, + headers={"x-key": "value"}, + media_type="text/plain", + background=BackgroundTask(func=spy), + ) + + app = Starlette(routes=[Route("/", page)]) + client = test_client_factory(app) + with pytest.warns(DeprecationWarning): + response = client.get("/") + + assert response.text == "value: b" # context was rendered + assert response.status_code == 201 + assert response.headers["x-key"] == "value" + assert response.headers["content-type"] == "text/plain; charset=utf-8" + spy.assert_called() + + +def test_templates_warns_when_first_argument_is_request(tmpdir, test_client_factory): + # MAINTAINERS: remove after 1.0 + path = os.path.join(tmpdir, "index.html") + with open(path, "w") as file: + file.write("value: {{ a }}") + templates = Jinja2Templates(directory=str(tmpdir)) + + spy = mock.AsyncMock() + + def page(request): + return templates.TemplateResponse( + request, + "index.html", + {"a": "b"}, + status_code=201, + headers={"x-key": "value"}, + media_type="text/plain", + background=BackgroundTask(func=spy), + ) + + app = Starlette(routes=[Route("/", page)]) + client = test_client_factory(app) + response = client.get("/") + + assert response.text == "value: b" # context was rendered + assert response.status_code == 201 + assert response.headers["x-key"] == "value" + assert response.headers["content-type"] == "text/plain; charset=utf-8" + spy.assert_called() From dda3bc07d25eb85cd7e9ea2a63c6bdd1526df164 Mon Sep 17 00:00:00 2001 From: "alex.oleshkevich" Date: Wed, 28 Jun 2023 20:51:52 +0200 Subject: [PATCH 3/5] rename test --- tests/test_templates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_templates.py b/tests/test_templates.py index 5e29ccdd6..0aff294f2 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -264,7 +264,7 @@ def page(request): spy.assert_called() -def test_templates_warns_when_first_argument_is_request(tmpdir, test_client_factory): +def test_templates_when_first_argument_is_request(tmpdir, test_client_factory): # MAINTAINERS: remove after 1.0 path = os.path.join(tmpdir, "index.html") with open(path, "w") as file: From 527f4c3aac1929197fe406867bc0f53f8bc49d0c Mon Sep 17 00:00:00 2001 From: "alex.oleshkevich" Date: Wed, 28 Jun 2023 20:52:46 +0200 Subject: [PATCH 4/5] replace AsyncMock with MagicMock --- tests/test_templates.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_templates.py b/tests/test_templates.py index 0aff294f2..db0b2bb63 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -163,7 +163,7 @@ def test_templates_with_kwargs_only(tmpdir, test_client_factory): file.write("value: {{ a }}") templates = Jinja2Templates(directory=str(tmpdir)) - spy = mock.AsyncMock() + spy = mock.MagicMock() def page(request): return templates.TemplateResponse( @@ -240,7 +240,7 @@ def test_templates_warns_when_first_argument_isnot_request(tmpdir, test_client_f file.write("value: {{ a }}") templates = Jinja2Templates(directory=str(tmpdir)) - spy = mock.AsyncMock() + spy = mock.MagicMock() def page(request): return templates.TemplateResponse( @@ -271,7 +271,7 @@ def test_templates_when_first_argument_is_request(tmpdir, test_client_factory): file.write("value: {{ a }}") templates = Jinja2Templates(directory=str(tmpdir)) - spy = mock.AsyncMock() + spy = mock.MagicMock() def page(request): return templates.TemplateResponse( From f403d89dac006ffaa1ee8c4a30c13aa6500f3b0f Mon Sep 17 00:00:00 2001 From: "alex.oleshkevich" Date: Fri, 30 Jun 2023 11:55:06 +0200 Subject: [PATCH 5/5] remove intemediate code --- starlette/templating.py | 67 ------------------- tests/test_templates.py | 141 ---------------------------------------- 2 files changed, 208 deletions(-) diff --git a/starlette/templating.py b/starlette/templating.py index 4c409fc39..a96506608 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -140,7 +140,6 @@ def url_for(context: dict, __name: str, **path_params: typing.Any) -> URL: def get_template(self, name: str) -> "jinja2.Template": return self.env.get_template(name) - @typing.overload def TemplateResponse( self, request: Request, @@ -151,72 +150,6 @@ def TemplateResponse( media_type: typing.Optional[str] = None, background: typing.Optional[BackgroundTask] = None, ) -> _TemplateResponse: - ... - - @typing.overload - def TemplateResponse( - self, - name: str, - context: typing.Optional[dict] = None, - status_code: int = 200, - headers: typing.Optional[typing.Mapping[str, str]] = None, - media_type: typing.Optional[str] = None, - background: typing.Optional[BackgroundTask] = None, - ) -> _TemplateResponse: - # Deprecated usage - ... - - def TemplateResponse( - self, *args: typing.Any, **kwargs: typing.Any - ) -> _TemplateResponse: - if args: - if isinstance( - args[0], str - ): # the first argument is template name (old style) - warnings.warn( - "Argument 1 of TemplateResponse must be a Request instance.", - DeprecationWarning, - ) - - name = args[0] - context = args[1] if len(args) > 1 else kwargs.get("context", {}) - status_code = ( - args[2] if len(args) > 2 else kwargs.get("status_code", 200) - ) - headers = args[2] if len(args) > 2 else kwargs.get("headers") - media_type = args[3] if len(args) > 3 else kwargs.get("media_type") - background = args[4] if len(args) > 4 else kwargs.get("background") - - if "request" not in context: - raise ValueError('context must include a "request" key') - request = context["request"] - else: # the first argument is a request instance (new style) - request = args[0] - name = args[1] if len(args) > 1 else kwargs["name"] - context = args[2] if len(args) > 2 else kwargs.get("context") - status_code = ( - args[3] if len(args) > 3 else kwargs.get("status_code", 200) - ) - headers = args[4] if len(args) > 4 else kwargs.get("headers") - media_type = args[5] if len(args) > 5 else kwargs.get("media_type") - background = args[6] if len(args) > 6 else kwargs.get("background") - else: # all arguments are kwargs - if "request" not in kwargs: - warnings.warn( - "TemplateResponse requires `request` keyword argument.", - DeprecationWarning, - ) - if "request" not in kwargs.get("context", {}): - raise ValueError('context must include a "request" key') - - context = kwargs.get("context", {}) - request = kwargs.get("request", context.get("request")) - name = typing.cast(str, kwargs["name"]) - status_code = kwargs.get("status_code", 200) - headers = kwargs.get("headers") - media_type = kwargs.get("media_type") - background = kwargs.get("background") - context = context or {} context.setdefault("request", request) for context_processor in self.context_processors: diff --git a/tests/test_templates.py b/tests/test_templates.py index db0b2bb63..566718d18 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,12 +1,10 @@ import os from pathlib import Path -from unittest import mock import jinja2 import pytest from starlette.applications import Starlette -from starlette.background import BackgroundTask from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.routing import Route @@ -154,142 +152,3 @@ def test_templates_with_environment(tmpdir): def test_templates_with_environment_options_emit_warning(tmpdir): with pytest.warns(DeprecationWarning): Jinja2Templates(str(tmpdir), autoescape=True) - - -def test_templates_with_kwargs_only(tmpdir, test_client_factory): - # MAINTAINERS: remove after 1.0 - path = os.path.join(tmpdir, "index.html") - with open(path, "w") as file: - file.write("value: {{ a }}") - templates = Jinja2Templates(directory=str(tmpdir)) - - spy = mock.MagicMock() - - def page(request): - return templates.TemplateResponse( - request=request, - name="index.html", - context={"a": "b"}, - status_code=201, - headers={"x-key": "value"}, - media_type="text/plain", - background=BackgroundTask(func=spy), - ) - - app = Starlette(routes=[Route("/", page)]) - client = test_client_factory(app) - response = client.get("/") - - assert response.text == "value: b" # context was rendered - assert response.status_code == 201 - assert response.headers["x-key"] == "value" - assert response.headers["content-type"] == "text/plain; charset=utf-8" - spy.assert_called() - - -def test_templates_with_kwargs_only_requires_request_in_context(tmpdir): - # MAINTAINERS: remove after 1.0 - - templates = Jinja2Templates(directory=str(tmpdir)) - with pytest.warns( - DeprecationWarning, - match="TemplateResponse requires `request` keyword argument.", - ): - with pytest.raises(ValueError): - templates.TemplateResponse(name="index.html", context={"a": "b"}) - - -def test_templates_with_kwargs_only_warns_when_no_request_keyword( - tmpdir, test_client_factory -): - # MAINTAINERS: remove after 1.0 - - path = os.path.join(tmpdir, "index.html") - with open(path, "w") as file: - file.write("Hello") - - templates = Jinja2Templates(directory=str(tmpdir)) - - def page(request): - return templates.TemplateResponse( - name="index.html", context={"request": request} - ) - - app = Starlette(routes=[Route("/", page)]) - client = test_client_factory(app) - - with pytest.warns( - DeprecationWarning, - match="TemplateResponse requires `request` keyword argument.", - ): - client.get("/") - - -def test_templates_with_requires_request_in_context(tmpdir): - # MAINTAINERS: remove after 1.0 - templates = Jinja2Templates(directory=str(tmpdir)) - with pytest.warns(DeprecationWarning): - with pytest.raises(ValueError): - templates.TemplateResponse("index.html", context={}) - - -def test_templates_warns_when_first_argument_isnot_request(tmpdir, test_client_factory): - # MAINTAINERS: remove after 1.0 - path = os.path.join(tmpdir, "index.html") - with open(path, "w") as file: - file.write("value: {{ a }}") - templates = Jinja2Templates(directory=str(tmpdir)) - - spy = mock.MagicMock() - - def page(request): - return templates.TemplateResponse( - "index.html", - {"a": "b", "request": request}, - status_code=201, - headers={"x-key": "value"}, - media_type="text/plain", - background=BackgroundTask(func=spy), - ) - - app = Starlette(routes=[Route("/", page)]) - client = test_client_factory(app) - with pytest.warns(DeprecationWarning): - response = client.get("/") - - assert response.text == "value: b" # context was rendered - assert response.status_code == 201 - assert response.headers["x-key"] == "value" - assert response.headers["content-type"] == "text/plain; charset=utf-8" - spy.assert_called() - - -def test_templates_when_first_argument_is_request(tmpdir, test_client_factory): - # MAINTAINERS: remove after 1.0 - path = os.path.join(tmpdir, "index.html") - with open(path, "w") as file: - file.write("value: {{ a }}") - templates = Jinja2Templates(directory=str(tmpdir)) - - spy = mock.MagicMock() - - def page(request): - return templates.TemplateResponse( - request, - "index.html", - {"a": "b"}, - status_code=201, - headers={"x-key": "value"}, - media_type="text/plain", - background=BackgroundTask(func=spy), - ) - - app = Starlette(routes=[Route("/", page)]) - client = test_client_factory(app) - response = client.get("/") - - assert response.text == "value: b" # context was rendered - assert response.status_code == 201 - assert response.headers["x-key"] == "value" - assert response.headers["content-type"] == "text/plain; charset=utf-8" - spy.assert_called()