Skip to content

Commit

Permalink
Add request argument to TemplateResponse (#2191)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <[email protected]>
  • Loading branch information
alex-oleshkevich and Kludex authored Jul 13, 2023
1 parent 11a3f12 commit 0308681
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 16 deletions.
2 changes: 1 addition & 1 deletion docs/templates.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ from starlette.staticfiles import StaticFiles
templates = Jinja2Templates(directory='templates')

async def homepage(request):
return templates.TemplateResponse('index.html', {'request': request})
return templates.TemplateResponse(request, 'index.html')

routes = [
Route('/', endpoint=homepage),
Expand Down
76 changes: 72 additions & 4 deletions starlette/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,87 @@ def url_for(context: dict, __name: str, **path_params: typing.Any) -> URL:
def get_template(self, name: str) -> "jinja2.Template":
return self.env.get_template(name)

@typing.overload
def TemplateResponse(
self,
request: Request,
name: str,
context: dict,
context: typing.Optional[dict] = None,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
) -> _TemplateResponse:
...

@typing.overload
def TemplateResponse(
self,
name: str,
context: typing.Optional[dict] = None,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
) -> _TemplateResponse:
if "request" not in context:
raise ValueError('context must include a "request" key')
# Deprecated usage
...

request = typing.cast(Request, context["request"])
def TemplateResponse(
self, *args: typing.Any, **kwargs: typing.Any
) -> _TemplateResponse:
if args:
if isinstance(
args[0], str
): # the first argument is template name (old style)
warnings.warn(
"The `name` is not the first parameter anymore. "
"The first parameter should be the `Request` instance.\n"
'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', # noqa: E501
DeprecationWarning,
)

name = args[0]
context = args[1] if len(args) > 1 else kwargs.get("context", {})
status_code = (
args[2] if len(args) > 2 else kwargs.get("status_code", 200)
)
headers = args[2] if len(args) > 2 else kwargs.get("headers")
media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
background = args[4] if len(args) > 4 else kwargs.get("background")

if "request" not in context:
raise ValueError('context must include a "request" key')
request = context["request"]
else: # the first argument is a request instance (new style)
request = args[0]
name = args[1] if len(args) > 1 else kwargs["name"]
context = args[2] if len(args) > 2 else kwargs.get("context", {})
status_code = (
args[3] if len(args) > 3 else kwargs.get("status_code", 200)
)
headers = args[4] if len(args) > 4 else kwargs.get("headers")
media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
background = args[6] if len(args) > 6 else kwargs.get("background")
else: # all arguments are kwargs
if "request" not in kwargs:
warnings.warn(
"The `TemplateResponse` now requires the `request` argument.\n"
'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', # noqa: E501
DeprecationWarning,
)
if "request" not in kwargs.get("context", {}):
raise ValueError('context must include a "request" key')

context = kwargs.get("context", {})
request = kwargs.get("request", context.get("request"))
name = typing.cast(str, kwargs["name"])
status_code = kwargs.get("status_code", 200)
headers = kwargs.get("headers")
media_type = kwargs.get("media_type")
background = kwargs.get("background")

context.setdefault("request", request)
for context_processor in self.context_processors:
context.update(context_processor(request))

Expand Down
157 changes: 146 additions & 11 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from pathlib import Path
from unittest import mock

import jinja2
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Route
Expand All @@ -17,7 +19,7 @@ def test_templates(tmpdir, test_client_factory):
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")

async def homepage(request):
return templates.TemplateResponse("index.html", {"request": request})
return templates.TemplateResponse(request, "index.html")

app = Starlette(
debug=True,
Expand All @@ -32,18 +34,12 @@ async def homepage(request):
assert set(response.context.keys()) == {"request"}


def test_template_response_requires_request(tmpdir):
templates = Jinja2Templates(str(tmpdir))
with pytest.raises(ValueError):
templates.TemplateResponse("", {})


def test_calls_context_processors(tmp_path, test_client_factory):
path = tmp_path / "index.html"
path.write_text("<html>Hello {{ username }}</html>")

async def homepage(request):
return templates.TemplateResponse("index.html", {"request": request})
return templates.TemplateResponse(request, "index.html")

def hello_world_processor(request):
return {"username": "World"}
Expand Down Expand Up @@ -72,7 +68,7 @@ def test_template_with_middleware(tmpdir, test_client_factory):
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")

async def homepage(request):
return templates.TemplateResponse("index.html", {"request": request})
return templates.TemplateResponse(request, "index.html")

class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
Expand All @@ -99,15 +95,15 @@ def test_templates_with_directories(tmp_path: Path, test_client_factory):
template_a.write_text("<html><a href='{{ url_for('page_a') }}'></a> a</html>")

async def page_a(request):
return templates.TemplateResponse("template_a.html", {"request": request})
return templates.TemplateResponse(request, "template_a.html")

dir_b = tmp_path.resolve() / "b"
dir_b.mkdir()
template_b = dir_b / "template_b.html"
template_b.write_text("<html><a href='{{ url_for('page_b') }}'></a> b</html>")

async def page_b(request):
return templates.TemplateResponse("template_b.html", {"request": request})
return templates.TemplateResponse(request, "template_b.html")

app = Starlette(
debug=True,
Expand Down Expand Up @@ -158,3 +154,142 @@ def test_templates_with_environment(tmpdir):
def test_templates_with_environment_options_emit_warning(tmpdir):
with pytest.warns(DeprecationWarning):
Jinja2Templates(str(tmpdir), autoescape=True)


def test_templates_with_kwargs_only(tmpdir, test_client_factory):
# MAINTAINERS: remove after 1.0
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("value: {{ a }}")
templates = Jinja2Templates(directory=str(tmpdir))

spy = mock.MagicMock()

def page(request):
return templates.TemplateResponse(
request=request,
name="index.html",
context={"a": "b"},
status_code=201,
headers={"x-key": "value"},
media_type="text/plain",
background=BackgroundTask(func=spy),
)

app = Starlette(routes=[Route("/", page)])
client = test_client_factory(app)
response = client.get("/")

assert response.text == "value: b" # context was rendered
assert response.status_code == 201
assert response.headers["x-key"] == "value"
assert response.headers["content-type"] == "text/plain; charset=utf-8"
spy.assert_called()


def test_templates_with_kwargs_only_requires_request_in_context(tmpdir):
# MAINTAINERS: remove after 1.0

templates = Jinja2Templates(directory=str(tmpdir))
with pytest.warns(
DeprecationWarning,
match="requires the `request` argument",
):
with pytest.raises(ValueError):
templates.TemplateResponse(name="index.html", context={"a": "b"})


def test_templates_with_kwargs_only_warns_when_no_request_keyword(
tmpdir, test_client_factory
):
# MAINTAINERS: remove after 1.0

path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("Hello")

templates = Jinja2Templates(directory=str(tmpdir))

def page(request):
return templates.TemplateResponse(
name="index.html", context={"request": request}
)

app = Starlette(routes=[Route("/", page)])
client = test_client_factory(app)

with pytest.warns(
DeprecationWarning,
match="requires the `request` argument",
):
client.get("/")


def test_templates_with_requires_request_in_context(tmpdir):
# MAINTAINERS: remove after 1.0
templates = Jinja2Templates(directory=str(tmpdir))
with pytest.warns(DeprecationWarning):
with pytest.raises(ValueError):
templates.TemplateResponse("index.html", context={})


def test_templates_warns_when_first_argument_isnot_request(tmpdir, test_client_factory):
# MAINTAINERS: remove after 1.0
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("value: {{ a }}")
templates = Jinja2Templates(directory=str(tmpdir))

spy = mock.MagicMock()

def page(request):
return templates.TemplateResponse(
"index.html",
{"a": "b", "request": request},
status_code=201,
headers={"x-key": "value"},
media_type="text/plain",
background=BackgroundTask(func=spy),
)

app = Starlette(routes=[Route("/", page)])
client = test_client_factory(app)
with pytest.warns(DeprecationWarning):
response = client.get("/")

assert response.text == "value: b" # context was rendered
assert response.status_code == 201
assert response.headers["x-key"] == "value"
assert response.headers["content-type"] == "text/plain; charset=utf-8"
spy.assert_called()


def test_templates_when_first_argument_is_request(tmpdir, test_client_factory):
# MAINTAINERS: remove after 1.0
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("value: {{ a }}")
templates = Jinja2Templates(directory=str(tmpdir))

spy = mock.MagicMock()

def page(request):
return templates.TemplateResponse(
request,
"index.html",
{"a": "b"},
status_code=201,
headers={"x-key": "value"},
media_type="text/plain",
background=BackgroundTask(func=spy),
)

app = Starlette(routes=[Route("/", page)])
client = test_client_factory(app)
response = client.get("/")

assert response.text == "value: b" # context was rendered
assert response.status_code == 201
assert response.headers["x-key"] == "value"
assert response.headers["content-type"] == "text/plain; charset=utf-8"
spy.assert_called()

0 comments on commit 0308681

Please sign in to comment.