Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ combine-as-imports = true
[tool.mypy]
disallow_untyped_defs = true
ignore_missing_imports = true
no_implicit_optional = true
show_error_codes = true

[[tool.mypy.overrides]]
Expand Down
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

# Testing
black==23.3.0
coverage==7.2.5
coverage==7.2.7
importlib-metadata==6.6.0
mypy==1.3.0
ruff==0.0.263
typing_extensions==4.5.0
mypy==1.4.1
ruff==0.0.275
typing_extensions==4.7.0
types-contextvars==2.4.7.2
types-PyYAML==6.0.12.10
types-dataclasses==0.6.6
Expand All @@ -16,7 +16,7 @@ trio==0.21.0

# Documentation
mkdocs==1.4.3
mkdocs-material==9.1.15
mkdocs-material==9.1.17
mkautodoc==0.2.0

# Packaging
Expand Down
2 changes: 2 additions & 0 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
body = message.get("body", b"")
if body:
yield body
if not message.get("more_body", False):
break

if app_exc is not None:
raise app_exc
Expand Down
6 changes: 6 additions & 0 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,12 @@ def __init__(
"See more about it on https://www.starlette.io/lifespan/.",
DeprecationWarning,
)
if lifespan:
warnings.warn(
"The `lifespan` parameter cannot be used with `on_startup` or "
"`on_shutdown`. Both `on_startup` and `on_shutdown` will be "
"ignored."
)

if lifespan is None:
self.lifespan_context: Lifespan = _DefaultLifespan(self)
Expand Down
67 changes: 0 additions & 67 deletions starlette/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def url_for(context: dict, __name: str, **path_params: typing.Any) -> URL:
def get_template(self, name: str) -> "jinja2.Template":
return self.env.get_template(name)

@typing.overload
def TemplateResponse(
self,
request: Request,
Expand All @@ -151,72 +150,6 @@ def TemplateResponse(
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
) -> _TemplateResponse:
...

@typing.overload
def TemplateResponse(
self,
name: str,
context: typing.Optional[dict] = None,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
) -> _TemplateResponse:
# Deprecated usage
...

def TemplateResponse(
self, *args: typing.Any, **kwargs: typing.Any
) -> _TemplateResponse:
if args:
if isinstance(
args[0], str
): # the first argument is template name (old style)
warnings.warn(
"Argument 1 of TemplateResponse must be a Request instance.",
DeprecationWarning,
)

name = args[0]
context = args[1] if len(args) > 1 else kwargs.get("context", {})
status_code = (
args[2] if len(args) > 2 else kwargs.get("status_code", 200)
)
headers = args[2] if len(args) > 2 else kwargs.get("headers")
media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
background = args[4] if len(args) > 4 else kwargs.get("background")

if "request" not in context:
raise ValueError('context must include a "request" key')
request = context["request"]
else: # the first argument is a request instance (new style)
request = args[0]
name = args[1] if len(args) > 1 else kwargs["name"]
context = args[2] if len(args) > 2 else kwargs.get("context")
status_code = (
args[3] if len(args) > 3 else kwargs.get("status_code", 200)
)
headers = args[4] if len(args) > 4 else kwargs.get("headers")
media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
background = args[6] if len(args) > 6 else kwargs.get("background")
else: # all arguments are kwargs
if "request" not in kwargs:
warnings.warn(
"TemplateResponse requires `request` keyword argument.",
DeprecationWarning,
)
if "request" not in kwargs.get("context", {}):
raise ValueError('context must include a "request" key')

context = kwargs.get("context", {})
request = kwargs.get("request", context.get("request"))
name = typing.cast(str, kwargs["name"])
status_code = kwargs.get("status_code", 200)
headers = kwargs.get("headers")
media_type = kwargs.get("media_type")
background = kwargs.get("background")

context = context or {}
context.setdefault("request", request)
for context_processor in self.context_processors:
Expand Down
3 changes: 2 additions & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def __init__(
backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
cookies: httpx._client.CookieTypes = None,
headers: typing.Dict[str, str] = None,
follow_redirects: bool = True,
) -> None:
self.async_backend = _AsyncBackend(
backend=backend, backend_options=backend_options or {}
Expand Down Expand Up @@ -409,7 +410,7 @@ def __init__(
base_url=base_url,
headers=headers,
transport=transport,
follow_redirects=True,
follow_redirects=follow_redirects,
cookies=cookies,
)

Expand Down
65 changes: 65 additions & 0 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,71 @@ async def send(message):
assert background_task_run.is_set()


@pytest.mark.anyio
async def test_do_not_block_on_background_tasks():
request_body_sent = False
response_complete = anyio.Event()
events: List[Union[str, Message]] = []

async def sleep_and_set():
events.append("Background task started")
await anyio.sleep(0.1)
events.append("Background task finished")

async def endpoint_with_background_task(_):
return PlainTextResponse(
content="Hello", background=BackgroundTask(sleep_and_set)
)

async def passthrough(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
return await call_next(request)

app = Starlette(
middleware=[Middleware(BaseHTTPMiddleware, dispatch=passthrough)],
routes=[Route("/", endpoint_with_background_task)],
)

scope = {
"type": "http",
"version": "3",
"method": "GET",
"path": "/",
}

async def receive() -> Message:
nonlocal request_body_sent
if not request_body_sent:
request_body_sent = True
return {"type": "http.request", "body": b"", "more_body": False}
await response_complete.wait()
return {"type": "http.disconnect"}

async def send(message: Message):
if message["type"] == "http.response.body":
events.append(message)
if not message.get("more_body", False):
response_complete.set()

async with anyio.create_task_group() as tg:
tg.start_soon(app, scope, receive, send)
tg.start_soon(app, scope, receive, send)

# Without the fix, the background tasks would start and finish before the
# last http.response.body is sent.
assert events == [
{"body": b"Hello", "more_body": True, "type": "http.response.body"},
{"body": b"", "more_body": False, "type": "http.response.body"},
{"body": b"Hello", "more_body": True, "type": "http.response.body"},
{"body": b"", "more_body": False, "type": "http.response.body"},
"Background task started",
"Background task started",
"Background task finished",
"Background task finished",
]


@pytest.mark.anyio
async def test_run_context_manager_exit_even_if_client_disconnects():
# test for https://github.com/encode/starlette/issues/1678#issuecomment-1172916042
Expand Down
44 changes: 44 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,50 @@ async def run_shutdown():
assert shutdown_complete


def test_lifespan_with_on_events(test_client_factory: typing.Callable[..., TestClient]):
lifespan_called = False
startup_called = False
shutdown_called = False

@contextlib.asynccontextmanager
async def lifespan(app: Starlette):
nonlocal lifespan_called
lifespan_called = True
yield

# We do not expected, neither of run_startup nor run_shutdown to be called
# we thus mark them as #pragma: no cover, to fulfill test coverage
def run_startup(): # pragma: no cover
nonlocal startup_called
startup_called = True

def run_shutdown(): # pragma: no cover
nonlocal shutdown_called
shutdown_called = True

with pytest.warns(
UserWarning,
match=(
"The `lifespan` parameter cannot be used with `on_startup` or `on_shutdown`." # noqa: E501
),
):
app = Router(
on_startup=[run_startup], on_shutdown=[run_shutdown], lifespan=lifespan
)

assert not lifespan_called
assert not startup_called
assert not shutdown_called

# Triggers the lifespan events
with test_client_factory(app):
...

assert lifespan_called
assert not startup_called
assert not shutdown_called


def test_lifespan_sync(test_client_factory):
startup_complete = False
shutdown_complete = False
Expand Down
Loading