diff --git a/docs/templates.md b/docs/templates.md
index ba9c4255b..01f343238 100644
--- a/docs/templates.md
+++ b/docs/templates.md
@@ -22,7 +22,7 @@ from starlette.staticfiles import StaticFiles
templates = Jinja2Templates(directory='templates')
async def homepage(request):
- return templates.TemplateResponse('index.html', {'request': request})
+ return templates.TemplateResponse(request, 'index.html')
routes = [
Route('/', endpoint=homepage),
diff --git a/starlette/templating.py b/starlette/templating.py
index ec9ca193d..abc845fed 100644
--- a/starlette/templating.py
+++ b/starlette/templating.py
@@ -140,19 +140,87 @@ def url_for(context: dict, __name: str, **path_params: typing.Any) -> URL:
def get_template(self, name: str) -> "jinja2.Template":
return self.env.get_template(name)
+ @typing.overload
def TemplateResponse(
self,
+ request: Request,
name: str,
- context: dict,
+ context: typing.Optional[dict] = None,
+ status_code: int = 200,
+ headers: typing.Optional[typing.Mapping[str, str]] = None,
+ media_type: typing.Optional[str] = None,
+ background: typing.Optional[BackgroundTask] = None,
+ ) -> _TemplateResponse:
+ ...
+
+ @typing.overload
+ def TemplateResponse(
+ self,
+ name: str,
+ context: typing.Optional[dict] = None,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
) -> _TemplateResponse:
- if "request" not in context:
- raise ValueError('context must include a "request" key')
+ # Deprecated usage
+ ...
- request = typing.cast(Request, context["request"])
+ def TemplateResponse(
+ self, *args: typing.Any, **kwargs: typing.Any
+ ) -> _TemplateResponse:
+ if args:
+ if isinstance(
+ args[0], str
+ ): # the first argument is template name (old style)
+ warnings.warn(
+ "The `name` is not the first parameter anymore. "
+ "The first parameter should be the `Request` instance.\n"
+ 'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', # noqa: E501
+ DeprecationWarning,
+ )
+
+ name = args[0]
+ context = args[1] if len(args) > 1 else kwargs.get("context", {})
+ status_code = (
+ args[2] if len(args) > 2 else kwargs.get("status_code", 200)
+ )
+ headers = args[2] if len(args) > 2 else kwargs.get("headers")
+ media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
+ background = args[4] if len(args) > 4 else kwargs.get("background")
+
+ if "request" not in context:
+ raise ValueError('context must include a "request" key')
+ request = context["request"]
+ else: # the first argument is a request instance (new style)
+ request = args[0]
+ name = args[1] if len(args) > 1 else kwargs["name"]
+ context = args[2] if len(args) > 2 else kwargs.get("context", {})
+ status_code = (
+ args[3] if len(args) > 3 else kwargs.get("status_code", 200)
+ )
+ headers = args[4] if len(args) > 4 else kwargs.get("headers")
+ media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
+ background = args[6] if len(args) > 6 else kwargs.get("background")
+ else: # all arguments are kwargs
+ if "request" not in kwargs:
+ warnings.warn(
+ "The `TemplateResponse` now requires the `request` argument.\n"
+ 'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', # noqa: E501
+ DeprecationWarning,
+ )
+ if "request" not in kwargs.get("context", {}):
+ raise ValueError('context must include a "request" key')
+
+ context = kwargs.get("context", {})
+ request = kwargs.get("request", context.get("request"))
+ name = typing.cast(str, kwargs["name"])
+ status_code = kwargs.get("status_code", 200)
+ headers = kwargs.get("headers")
+ media_type = kwargs.get("media_type")
+ background = kwargs.get("background")
+
+ context.setdefault("request", request)
for context_processor in self.context_processors:
context.update(context_processor(request))
diff --git a/tests/test_templates.py b/tests/test_templates.py
index 1f1909f4b..102f0bfcc 100644
--- a/tests/test_templates.py
+++ b/tests/test_templates.py
@@ -1,10 +1,12 @@
import os
from pathlib import Path
+from unittest import mock
import jinja2
import pytest
from starlette.applications import Starlette
+from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Route
@@ -17,7 +19,7 @@ def test_templates(tmpdir, test_client_factory):
file.write("Hello, world")
async def homepage(request):
- return templates.TemplateResponse("index.html", {"request": request})
+ return templates.TemplateResponse(request, "index.html")
app = Starlette(
debug=True,
@@ -32,18 +34,12 @@ async def homepage(request):
assert set(response.context.keys()) == {"request"}
-def test_template_response_requires_request(tmpdir):
- templates = Jinja2Templates(str(tmpdir))
- with pytest.raises(ValueError):
- templates.TemplateResponse("", {})
-
-
def test_calls_context_processors(tmp_path, test_client_factory):
path = tmp_path / "index.html"
path.write_text("Hello {{ username }}")
async def homepage(request):
- return templates.TemplateResponse("index.html", {"request": request})
+ return templates.TemplateResponse(request, "index.html")
def hello_world_processor(request):
return {"username": "World"}
@@ -72,7 +68,7 @@ def test_template_with_middleware(tmpdir, test_client_factory):
file.write("Hello, world")
async def homepage(request):
- return templates.TemplateResponse("index.html", {"request": request})
+ return templates.TemplateResponse(request, "index.html")
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
@@ -99,7 +95,7 @@ def test_templates_with_directories(tmp_path: Path, test_client_factory):
template_a.write_text(" a")
async def page_a(request):
- return templates.TemplateResponse("template_a.html", {"request": request})
+ return templates.TemplateResponse(request, "template_a.html")
dir_b = tmp_path.resolve() / "b"
dir_b.mkdir()
@@ -107,7 +103,7 @@ async def page_a(request):
template_b.write_text(" b")
async def page_b(request):
- return templates.TemplateResponse("template_b.html", {"request": request})
+ return templates.TemplateResponse(request, "template_b.html")
app = Starlette(
debug=True,
@@ -158,3 +154,142 @@ def test_templates_with_environment(tmpdir):
def test_templates_with_environment_options_emit_warning(tmpdir):
with pytest.warns(DeprecationWarning):
Jinja2Templates(str(tmpdir), autoescape=True)
+
+
+def test_templates_with_kwargs_only(tmpdir, test_client_factory):
+ # MAINTAINERS: remove after 1.0
+ path = os.path.join(tmpdir, "index.html")
+ with open(path, "w") as file:
+ file.write("value: {{ a }}")
+ templates = Jinja2Templates(directory=str(tmpdir))
+
+ spy = mock.MagicMock()
+
+ def page(request):
+ return templates.TemplateResponse(
+ request=request,
+ name="index.html",
+ context={"a": "b"},
+ status_code=201,
+ headers={"x-key": "value"},
+ media_type="text/plain",
+ background=BackgroundTask(func=spy),
+ )
+
+ app = Starlette(routes=[Route("/", page)])
+ client = test_client_factory(app)
+ response = client.get("/")
+
+ assert response.text == "value: b" # context was rendered
+ assert response.status_code == 201
+ assert response.headers["x-key"] == "value"
+ assert response.headers["content-type"] == "text/plain; charset=utf-8"
+ spy.assert_called()
+
+
+def test_templates_with_kwargs_only_requires_request_in_context(tmpdir):
+ # MAINTAINERS: remove after 1.0
+
+ templates = Jinja2Templates(directory=str(tmpdir))
+ with pytest.warns(
+ DeprecationWarning,
+ match="requires the `request` argument",
+ ):
+ with pytest.raises(ValueError):
+ templates.TemplateResponse(name="index.html", context={"a": "b"})
+
+
+def test_templates_with_kwargs_only_warns_when_no_request_keyword(
+ tmpdir, test_client_factory
+):
+ # MAINTAINERS: remove after 1.0
+
+ path = os.path.join(tmpdir, "index.html")
+ with open(path, "w") as file:
+ file.write("Hello")
+
+ templates = Jinja2Templates(directory=str(tmpdir))
+
+ def page(request):
+ return templates.TemplateResponse(
+ name="index.html", context={"request": request}
+ )
+
+ app = Starlette(routes=[Route("/", page)])
+ client = test_client_factory(app)
+
+ with pytest.warns(
+ DeprecationWarning,
+ match="requires the `request` argument",
+ ):
+ client.get("/")
+
+
+def test_templates_with_requires_request_in_context(tmpdir):
+ # MAINTAINERS: remove after 1.0
+ templates = Jinja2Templates(directory=str(tmpdir))
+ with pytest.warns(DeprecationWarning):
+ with pytest.raises(ValueError):
+ templates.TemplateResponse("index.html", context={})
+
+
+def test_templates_warns_when_first_argument_isnot_request(tmpdir, test_client_factory):
+ # MAINTAINERS: remove after 1.0
+ path = os.path.join(tmpdir, "index.html")
+ with open(path, "w") as file:
+ file.write("value: {{ a }}")
+ templates = Jinja2Templates(directory=str(tmpdir))
+
+ spy = mock.MagicMock()
+
+ def page(request):
+ return templates.TemplateResponse(
+ "index.html",
+ {"a": "b", "request": request},
+ status_code=201,
+ headers={"x-key": "value"},
+ media_type="text/plain",
+ background=BackgroundTask(func=spy),
+ )
+
+ app = Starlette(routes=[Route("/", page)])
+ client = test_client_factory(app)
+ with pytest.warns(DeprecationWarning):
+ response = client.get("/")
+
+ assert response.text == "value: b" # context was rendered
+ assert response.status_code == 201
+ assert response.headers["x-key"] == "value"
+ assert response.headers["content-type"] == "text/plain; charset=utf-8"
+ spy.assert_called()
+
+
+def test_templates_when_first_argument_is_request(tmpdir, test_client_factory):
+ # MAINTAINERS: remove after 1.0
+ path = os.path.join(tmpdir, "index.html")
+ with open(path, "w") as file:
+ file.write("value: {{ a }}")
+ templates = Jinja2Templates(directory=str(tmpdir))
+
+ spy = mock.MagicMock()
+
+ def page(request):
+ return templates.TemplateResponse(
+ request,
+ "index.html",
+ {"a": "b"},
+ status_code=201,
+ headers={"x-key": "value"},
+ media_type="text/plain",
+ background=BackgroundTask(func=spy),
+ )
+
+ app = Starlette(routes=[Route("/", page)])
+ client = test_client_factory(app)
+ response = client.get("/")
+
+ assert response.text == "value: b" # context was rendered
+ assert response.status_code == 201
+ assert response.headers["x-key"] == "value"
+ assert response.headers["content-type"] == "text/plain; charset=utf-8"
+ spy.assert_called()