diff --git a/docs/templates.md b/docs/templates.md index ba9c4255b..01f343238 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -22,7 +22,7 @@ from starlette.staticfiles import StaticFiles templates = Jinja2Templates(directory='templates') async def homepage(request): - return templates.TemplateResponse('index.html', {'request': request}) + return templates.TemplateResponse(request, 'index.html') routes = [ Route('/', endpoint=homepage), diff --git a/starlette/templating.py b/starlette/templating.py index ec9ca193d..a96506608 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -142,17 +142,16 @@ def get_template(self, name: str) -> "jinja2.Template": def TemplateResponse( self, + request: Request, name: str, - context: dict, + context: typing.Optional[dict] = None, status_code: int = 200, headers: typing.Optional[typing.Mapping[str, str]] = None, media_type: typing.Optional[str] = None, background: typing.Optional[BackgroundTask] = None, ) -> _TemplateResponse: - if "request" not in context: - raise ValueError('context must include a "request" key') - - request = typing.cast(Request, context["request"]) + context = context or {} + context.setdefault("request", request) for context_processor in self.context_processors: context.update(context_processor(request)) diff --git a/tests/test_templates.py b/tests/test_templates.py index 1f1909f4b..566718d18 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -17,7 +17,7 @@ def test_templates(tmpdir, test_client_factory): file.write("Hello, world") async def homepage(request): - return templates.TemplateResponse("index.html", {"request": request}) + return templates.TemplateResponse(request, "index.html") app = Starlette( debug=True, @@ -32,18 +32,12 @@ async def homepage(request): assert set(response.context.keys()) == {"request"} -def test_template_response_requires_request(tmpdir): - templates = Jinja2Templates(str(tmpdir)) - with pytest.raises(ValueError): - templates.TemplateResponse("", {}) - - def test_calls_context_processors(tmp_path, test_client_factory): path = tmp_path / "index.html" path.write_text("Hello {{ username }}") async def homepage(request): - return templates.TemplateResponse("index.html", {"request": request}) + return templates.TemplateResponse(request, "index.html") def hello_world_processor(request): return {"username": "World"} @@ -72,7 +66,7 @@ def test_template_with_middleware(tmpdir, test_client_factory): file.write("Hello, world") async def homepage(request): - return templates.TemplateResponse("index.html", {"request": request}) + return templates.TemplateResponse(request, "index.html") class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): @@ -99,7 +93,7 @@ def test_templates_with_directories(tmp_path: Path, test_client_factory): template_a.write_text(" a") async def page_a(request): - return templates.TemplateResponse("template_a.html", {"request": request}) + return templates.TemplateResponse(request, "template_a.html") dir_b = tmp_path.resolve() / "b" dir_b.mkdir() @@ -107,7 +101,7 @@ async def page_a(request): template_b.write_text(" b") async def page_b(request): - return templates.TemplateResponse("template_b.html", {"request": request}) + return templates.TemplateResponse(request, "template_b.html") app = Starlette( debug=True,