Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions docs/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,21 @@ Analogously, the single lifespan `asynccontextmanager` can be used.

```python
import contextlib
from typing import TypedDict

import httpx
from starlette.applications import Starlette
from starlette.routing import Route


class State(TypedDict):
http_client: httpx.AsyncClient


@contextlib.asynccontextmanager
async def lifespan(app, state):
async def lifespan(app: Starlette) -> State:
async with httpx.AsyncClient() as client:
state["http_client"] = client
yield
yield {"http_client": client}


app = Starlette(
Expand Down
57 changes: 16 additions & 41 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send, StatelessLifespan
from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket, WebSocketClose


Expand Down Expand Up @@ -558,25 +558,17 @@ def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
return wrapper


_TDefaultLifespan = typing.TypeVar("_TDefaultLifespan", bound="_DefaultLifespan")


class _DefaultLifespan:
def __init__(self, router: "Router"):
self._router = router

async def __aenter__(self) -> None:
await self._router.startup(state=self._state)
await self._router.startup()

async def __aexit__(self, *exc_info: object) -> None:
await self._router.shutdown(state=self._state)

def __call__(
self: _TDefaultLifespan,
app: object,
state: typing.Optional[typing.Dict[str, typing.Any]],
) -> _TDefaultLifespan:
self._state = state
await self._router.shutdown()

def __call__(self: _T, app: object) -> _T:
return self


Expand All @@ -598,6 +590,7 @@ def __init__(

if lifespan is None:
self.lifespan_context: Lifespan = _DefaultLifespan(self)

elif inspect.isasyncgenfunction(lifespan):
warnings.warn(
"async generator function lifespans are deprecated, "
Expand Down Expand Up @@ -642,31 +635,21 @@ def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath:
pass
raise NoMatchFound(__name, path_params)

async def startup(
self, state: typing.Optional[typing.Dict[str, typing.Any]]
) -> None:
async def startup(self) -> None:
"""
Run any `.on_startup` event handlers.
"""
for handler in self.on_startup:
sig = inspect.signature(handler)
if len(sig.parameters) == 1 and state is not None:
handler = functools.partial(handler, state)
if is_async_callable(handler):
await handler()
else:
handler()

async def shutdown(
self, state: typing.Optional[typing.Dict[str, typing.Any]]
) -> None:
async def shutdown(self) -> None:
"""
Run any `.on_shutdown` event handlers.
"""
for handler in self.on_shutdown:
sig = inspect.signature(handler)
if len(sig.parameters) == 1 and state is not None:
handler = functools.partial(handler, state)
if is_async_callable(handler):
await handler()
else:
Expand All @@ -678,24 +661,16 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
startup and shutdown events.
"""
started = False
app = scope.get("app")
state = scope.get("state")
app: typing.Any = scope.get("app")
await receive()
lifespan_needs_state = (
len(inspect.signature(self.lifespan_context).parameters) == 2
)
server_supports_state = state is not None
if lifespan_needs_state and not server_supports_state:
raise RuntimeError(
'The server does not support "state" in the lifespan scope.'
)
try:
lifespan_context: Lifespan
if lifespan_needs_state:
lifespan_context = functools.partial(self.lifespan_context, state=state)
else:
lifespan_context = typing.cast(StatelessLifespan, self.lifespan_context)
async with lifespan_context(app):
async with self.lifespan_context(app) as maybe_state:
if maybe_state is not None:
if "state" not in scope:
raise RuntimeError(
'The server does not support "state" in the lifespan scope.'
)
scope["state"].update(maybe_state)
await send({"type": "lifespan.startup.complete"})
started = True
await receive()
Expand Down
11 changes: 7 additions & 4 deletions starlette/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import typing

if typing.TYPE_CHECKING:
from starlette.applications import Starlette

Scope = typing.MutableMapping[str, typing.Any]
Message = typing.MutableMapping[str, typing.Any]

Expand All @@ -8,8 +11,8 @@

ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager[typing.Any]]
StateLifespan = typing.Callable[
[typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager[typing.Any]
StatelessLifespan = typing.Callable[["Starlette"], typing.AsyncContextManager[None]]
StatefulLifespan = typing.Callable[
["Starlette"], typing.AsyncContextManager[typing.Mapping[str, typing.Any]]
]
Lifespan = typing.Union[StatelessLifespan, StateLifespan]
Lifespan = typing.Union[StatelessLifespan, StatefulLifespan]
78 changes: 20 additions & 58 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import contextlib
import functools
import sys
import typing
import uuid

if sys.version_info < (3, 8):
from typing_extensions import TypedDict # pragma: no cover
else:
from typing import TypedDict # pragma: no cover

import pytest

from starlette.applications import Starlette
Expand Down Expand Up @@ -670,57 +676,10 @@ def run_shutdown():
assert shutdown_complete


def test_lifespan_with_state(test_client_factory):
startup_complete = False
shutdown_complete = False

async def hello_world(request):
# modifications to the state should not leak across requests
assert request.state.count == 0
# modify the state, this should not leak to the lifespan or other requests
request.state.count += 1
# since state.list is a mutable object this modification _will_ leak across
# requests and to the lifespan
request.state.list.append(1)
return PlainTextResponse("hello, world")

async def run_startup(state):
nonlocal startup_complete
startup_complete = True
state["count"] = 0
state["list"] = []

async def run_shutdown(state):
nonlocal shutdown_complete
shutdown_complete = True
# modifications made to the state from a request do not leak to the lifespan
assert state["count"] == 0
# unless of course the request mutates a mutable object that is referenced
# via state
assert state["list"] == [1, 1]

app = Router(
on_startup=[run_startup],
on_shutdown=[run_shutdown],
routes=[Route("/", hello_world)],
)

assert not startup_complete
assert not shutdown_complete
with test_client_factory(app) as client:
assert startup_complete
assert not shutdown_complete
client.get("/")
# Calling it a second time to ensure that the state is preserved.
client.get("/")
assert startup_complete
assert shutdown_complete


def test_lifespan_state_unsupported(test_client_factory):
@contextlib.asynccontextmanager
async def lifespan(app, scope):
yield None # pragma: no cover
async def lifespan(app):
yield {"foo": "bar"}

app = Router(
lifespan=lifespan,
Expand All @@ -738,33 +697,36 @@ async def no_state_wrapper(scope, receive, send):
raise AssertionError("Should not be called") # pragma: no cover


def test_lifespan_async_cm(test_client_factory):
def test_lifespan_state_async_cm(test_client_factory):
startup_complete = False
shutdown_complete = False

async def hello_world(request):
class State(TypedDict):
count: int
items: typing.List[int]

async def hello_world(request: Request) -> Response:
# modifications to the state should not leak across requests
assert request.state.count == 0
# modify the state, this should not leak to the lifespan or other requests
request.state.count += 1
# since state.list is a mutable object this modification _will_ leak across
# since state.items is a mutable object this modification _will_ leak across
# requests and to the lifespan
request.state.list.append(1)
request.state.items.append(1)
return PlainTextResponse("hello, world")

@contextlib.asynccontextmanager
async def lifespan(app: Starlette, state: typing.Dict[str, typing.Any]):
async def lifespan(app: Starlette) -> typing.AsyncIterator[State]:
nonlocal startup_complete, shutdown_complete
startup_complete = True
state["count"] = 0
state["list"] = []
yield
state = State(count=0, items=[])
yield state
shutdown_complete = True
# modifications made to the state from a request do not leak to the lifespan
assert state["count"] == 0
# unless of course the request mutates a mutable object that is referenced
# via state
assert state["list"] == [1, 1]
assert state["items"] == [1, 1]

app = Router(
lifespan=lifespan,
Expand Down