diff --git a/docs/templates.md b/docs/templates.md index ba9c4255b..01f343238 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -22,7 +22,7 @@ from starlette.staticfiles import StaticFiles templates = Jinja2Templates(directory='templates') async def homepage(request): - return templates.TemplateResponse('index.html', {'request': request}) + return templates.TemplateResponse(request, 'index.html') routes = [ Route('/', endpoint=homepage), diff --git a/starlette/templating.py b/starlette/templating.py index ec9ca193d..abc845fed 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -140,19 +140,87 @@ def url_for(context: dict, __name: str, **path_params: typing.Any) -> URL: def get_template(self, name: str) -> "jinja2.Template": return self.env.get_template(name) + @typing.overload def TemplateResponse( self, + request: Request, name: str, - context: dict, + context: typing.Optional[dict] = None, + status_code: int = 200, + headers: typing.Optional[typing.Mapping[str, str]] = None, + media_type: typing.Optional[str] = None, + background: typing.Optional[BackgroundTask] = None, + ) -> _TemplateResponse: + ... + + @typing.overload + def TemplateResponse( + self, + name: str, + context: typing.Optional[dict] = None, status_code: int = 200, headers: typing.Optional[typing.Mapping[str, str]] = None, media_type: typing.Optional[str] = None, background: typing.Optional[BackgroundTask] = None, ) -> _TemplateResponse: - if "request" not in context: - raise ValueError('context must include a "request" key') + # Deprecated usage + ... - request = typing.cast(Request, context["request"]) + def TemplateResponse( + self, *args: typing.Any, **kwargs: typing.Any + ) -> _TemplateResponse: + if args: + if isinstance( + args[0], str + ): # the first argument is template name (old style) + warnings.warn( + "The `name` is not the first parameter anymore. " + "The first parameter should be the `Request` instance.\n" + 'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', # noqa: E501 + DeprecationWarning, + ) + + name = args[0] + context = args[1] if len(args) > 1 else kwargs.get("context", {}) + status_code = ( + args[2] if len(args) > 2 else kwargs.get("status_code", 200) + ) + headers = args[2] if len(args) > 2 else kwargs.get("headers") + media_type = args[3] if len(args) > 3 else kwargs.get("media_type") + background = args[4] if len(args) > 4 else kwargs.get("background") + + if "request" not in context: + raise ValueError('context must include a "request" key') + request = context["request"] + else: # the first argument is a request instance (new style) + request = args[0] + name = args[1] if len(args) > 1 else kwargs["name"] + context = args[2] if len(args) > 2 else kwargs.get("context", {}) + status_code = ( + args[3] if len(args) > 3 else kwargs.get("status_code", 200) + ) + headers = args[4] if len(args) > 4 else kwargs.get("headers") + media_type = args[5] if len(args) > 5 else kwargs.get("media_type") + background = args[6] if len(args) > 6 else kwargs.get("background") + else: # all arguments are kwargs + if "request" not in kwargs: + warnings.warn( + "The `TemplateResponse` now requires the `request` argument.\n" + 'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', # noqa: E501 + DeprecationWarning, + ) + if "request" not in kwargs.get("context", {}): + raise ValueError('context must include a "request" key') + + context = kwargs.get("context", {}) + request = kwargs.get("request", context.get("request")) + name = typing.cast(str, kwargs["name"]) + status_code = kwargs.get("status_code", 200) + headers = kwargs.get("headers") + media_type = kwargs.get("media_type") + background = kwargs.get("background") + + context.setdefault("request", request) for context_processor in self.context_processors: context.update(context_processor(request)) diff --git a/tests/test_templates.py b/tests/test_templates.py index 1f1909f4b..102f0bfcc 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,10 +1,12 @@ import os from pathlib import Path +from unittest import mock import jinja2 import pytest from starlette.applications import Starlette +from starlette.background import BackgroundTask from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware from starlette.routing import Route @@ -17,7 +19,7 @@ def test_templates(tmpdir, test_client_factory): file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>") async def homepage(request): - return templates.TemplateResponse("index.html", {"request": request}) + return templates.TemplateResponse(request, "index.html") app = Starlette( debug=True, @@ -32,18 +34,12 @@ async def homepage(request): assert set(response.context.keys()) == {"request"} -def test_template_response_requires_request(tmpdir): - templates = Jinja2Templates(str(tmpdir)) - with pytest.raises(ValueError): - templates.TemplateResponse("", {}) - - def test_calls_context_processors(tmp_path, test_client_factory): path = tmp_path / "index.html" path.write_text("<html>Hello {{ username }}</html>") async def homepage(request): - return templates.TemplateResponse("index.html", {"request": request}) + return templates.TemplateResponse(request, "index.html") def hello_world_processor(request): return {"username": "World"} @@ -72,7 +68,7 @@ def test_template_with_middleware(tmpdir, test_client_factory): file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>") async def homepage(request): - return templates.TemplateResponse("index.html", {"request": request}) + return templates.TemplateResponse(request, "index.html") class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): @@ -99,7 +95,7 @@ def test_templates_with_directories(tmp_path: Path, test_client_factory): template_a.write_text("<html><a href='{{ url_for('page_a') }}'></a> a</html>") async def page_a(request): - return templates.TemplateResponse("template_a.html", {"request": request}) + return templates.TemplateResponse(request, "template_a.html") dir_b = tmp_path.resolve() / "b" dir_b.mkdir() @@ -107,7 +103,7 @@ async def page_a(request): template_b.write_text("<html><a href='{{ url_for('page_b') }}'></a> b</html>") async def page_b(request): - return templates.TemplateResponse("template_b.html", {"request": request}) + return templates.TemplateResponse(request, "template_b.html") app = Starlette( debug=True, @@ -158,3 +154,142 @@ def test_templates_with_environment(tmpdir): def test_templates_with_environment_options_emit_warning(tmpdir): with pytest.warns(DeprecationWarning): Jinja2Templates(str(tmpdir), autoescape=True) + + +def test_templates_with_kwargs_only(tmpdir, test_client_factory): + # MAINTAINERS: remove after 1.0 + path = os.path.join(tmpdir, "index.html") + with open(path, "w") as file: + file.write("value: {{ a }}") + templates = Jinja2Templates(directory=str(tmpdir)) + + spy = mock.MagicMock() + + def page(request): + return templates.TemplateResponse( + request=request, + name="index.html", + context={"a": "b"}, + status_code=201, + headers={"x-key": "value"}, + media_type="text/plain", + background=BackgroundTask(func=spy), + ) + + app = Starlette(routes=[Route("/", page)]) + client = test_client_factory(app) + response = client.get("/") + + assert response.text == "value: b" # context was rendered + assert response.status_code == 201 + assert response.headers["x-key"] == "value" + assert response.headers["content-type"] == "text/plain; charset=utf-8" + spy.assert_called() + + +def test_templates_with_kwargs_only_requires_request_in_context(tmpdir): + # MAINTAINERS: remove after 1.0 + + templates = Jinja2Templates(directory=str(tmpdir)) + with pytest.warns( + DeprecationWarning, + match="requires the `request` argument", + ): + with pytest.raises(ValueError): + templates.TemplateResponse(name="index.html", context={"a": "b"}) + + +def test_templates_with_kwargs_only_warns_when_no_request_keyword( + tmpdir, test_client_factory +): + # MAINTAINERS: remove after 1.0 + + path = os.path.join(tmpdir, "index.html") + with open(path, "w") as file: + file.write("Hello") + + templates = Jinja2Templates(directory=str(tmpdir)) + + def page(request): + return templates.TemplateResponse( + name="index.html", context={"request": request} + ) + + app = Starlette(routes=[Route("/", page)]) + client = test_client_factory(app) + + with pytest.warns( + DeprecationWarning, + match="requires the `request` argument", + ): + client.get("/") + + +def test_templates_with_requires_request_in_context(tmpdir): + # MAINTAINERS: remove after 1.0 + templates = Jinja2Templates(directory=str(tmpdir)) + with pytest.warns(DeprecationWarning): + with pytest.raises(ValueError): + templates.TemplateResponse("index.html", context={}) + + +def test_templates_warns_when_first_argument_isnot_request(tmpdir, test_client_factory): + # MAINTAINERS: remove after 1.0 + path = os.path.join(tmpdir, "index.html") + with open(path, "w") as file: + file.write("value: {{ a }}") + templates = Jinja2Templates(directory=str(tmpdir)) + + spy = mock.MagicMock() + + def page(request): + return templates.TemplateResponse( + "index.html", + {"a": "b", "request": request}, + status_code=201, + headers={"x-key": "value"}, + media_type="text/plain", + background=BackgroundTask(func=spy), + ) + + app = Starlette(routes=[Route("/", page)]) + client = test_client_factory(app) + with pytest.warns(DeprecationWarning): + response = client.get("/") + + assert response.text == "value: b" # context was rendered + assert response.status_code == 201 + assert response.headers["x-key"] == "value" + assert response.headers["content-type"] == "text/plain; charset=utf-8" + spy.assert_called() + + +def test_templates_when_first_argument_is_request(tmpdir, test_client_factory): + # MAINTAINERS: remove after 1.0 + path = os.path.join(tmpdir, "index.html") + with open(path, "w") as file: + file.write("value: {{ a }}") + templates = Jinja2Templates(directory=str(tmpdir)) + + spy = mock.MagicMock() + + def page(request): + return templates.TemplateResponse( + request, + "index.html", + {"a": "b"}, + status_code=201, + headers={"x-key": "value"}, + media_type="text/plain", + background=BackgroundTask(func=spy), + ) + + app = Starlette(routes=[Route("/", page)]) + client = test_client_factory(app) + response = client.get("/") + + assert response.text == "value: b" # context was rendered + assert response.status_code == 201 + assert response.headers["x-key"] == "value" + assert response.headers["content-type"] == "text/plain; charset=utf-8" + spy.assert_called()