diff --git a/docs/events.md b/docs/events.md index 56c8c9ae6..2a43342c1 100644 --- a/docs/events.md +++ b/docs/events.md @@ -98,16 +98,21 @@ Analogously, the single lifespan `asynccontextmanager` can be used. ```python import contextlib +from typing import TypedDict + import httpx from starlette.applications import Starlette from starlette.routing import Route +class State(TypedDict): + http_client: httpx.AsyncClient + + @contextlib.asynccontextmanager -async def lifespan(app, state): +async def lifespan(app: Starlette) -> State: async with httpx.AsyncClient() as client: - state["http_client"] = client - yield + yield {"http_client": client} app = Starlette( diff --git a/starlette/routing.py b/starlette/routing.py index 696efa51a..46de35d93 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -17,7 +17,7 @@ from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse -from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send, StatelessLifespan +from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose @@ -558,25 +558,17 @@ def wrapper(app: typing.Any) -> _AsyncLiftContextManager: return wrapper -_TDefaultLifespan = typing.TypeVar("_TDefaultLifespan", bound="_DefaultLifespan") - - class _DefaultLifespan: def __init__(self, router: "Router"): self._router = router async def __aenter__(self) -> None: - await self._router.startup(state=self._state) + await self._router.startup() async def __aexit__(self, *exc_info: object) -> None: - await self._router.shutdown(state=self._state) - - def __call__( - self: _TDefaultLifespan, - app: object, - state: typing.Optional[typing.Dict[str, typing.Any]], - ) -> _TDefaultLifespan: - self._state = state + await self._router.shutdown() + + def __call__(self: _T, app: object) -> _T: return self @@ -598,6 +590,7 @@ def __init__( if lifespan is None: self.lifespan_context: Lifespan = _DefaultLifespan(self) + elif inspect.isasyncgenfunction(lifespan): warnings.warn( "async generator function lifespans are deprecated, " @@ -642,31 +635,21 @@ def url_path_for(self, __name: str, **path_params: typing.Any) -> URLPath: pass raise NoMatchFound(__name, path_params) - async def startup( - self, state: typing.Optional[typing.Dict[str, typing.Any]] - ) -> None: + async def startup(self) -> None: """ Run any `.on_startup` event handlers. """ for handler in self.on_startup: - sig = inspect.signature(handler) - if len(sig.parameters) == 1 and state is not None: - handler = functools.partial(handler, state) if is_async_callable(handler): await handler() else: handler() - async def shutdown( - self, state: typing.Optional[typing.Dict[str, typing.Any]] - ) -> None: + async def shutdown(self) -> None: """ Run any `.on_shutdown` event handlers. """ for handler in self.on_shutdown: - sig = inspect.signature(handler) - if len(sig.parameters) == 1 and state is not None: - handler = functools.partial(handler, state) if is_async_callable(handler): await handler() else: @@ -678,24 +661,16 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: startup and shutdown events. """ started = False - app = scope.get("app") - state = scope.get("state") + app: typing.Any = scope.get("app") await receive() - lifespan_needs_state = ( - len(inspect.signature(self.lifespan_context).parameters) == 2 - ) - server_supports_state = state is not None - if lifespan_needs_state and not server_supports_state: - raise RuntimeError( - 'The server does not support "state" in the lifespan scope.' - ) try: - lifespan_context: Lifespan - if lifespan_needs_state: - lifespan_context = functools.partial(self.lifespan_context, state=state) - else: - lifespan_context = typing.cast(StatelessLifespan, self.lifespan_context) - async with lifespan_context(app): + async with self.lifespan_context(app) as maybe_state: + if maybe_state is not None: + if "state" not in scope: + raise RuntimeError( + 'The server does not support "state" in the lifespan scope.' + ) + scope["state"].update(maybe_state) await send({"type": "lifespan.startup.complete"}) started = True await receive() diff --git a/starlette/types.py b/starlette/types.py index b83d9101a..05f5446e4 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -1,5 +1,8 @@ import typing +if typing.TYPE_CHECKING: + from starlette.applications import Starlette + Scope = typing.MutableMapping[str, typing.Any] Message = typing.MutableMapping[str, typing.Any] @@ -8,8 +11,8 @@ ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] -StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager[typing.Any]] -StateLifespan = typing.Callable[ - [typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager[typing.Any] +StatelessLifespan = typing.Callable[["Starlette"], typing.AsyncContextManager[None]] +StatefulLifespan = typing.Callable[ + ["Starlette"], typing.AsyncContextManager[typing.Mapping[str, typing.Any]] ] -Lifespan = typing.Union[StatelessLifespan, StateLifespan] +Lifespan = typing.Union[StatelessLifespan, StatefulLifespan] diff --git a/tests/test_routing.py b/tests/test_routing.py index b70641680..dccf089c2 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,8 +1,14 @@ import contextlib import functools +import sys import typing import uuid +if sys.version_info < (3, 8): + from typing_extensions import TypedDict # pragma: no cover +else: + from typing import TypedDict # pragma: no cover + import pytest from starlette.applications import Starlette @@ -670,57 +676,10 @@ def run_shutdown(): assert shutdown_complete -def test_lifespan_with_state(test_client_factory): - startup_complete = False - shutdown_complete = False - - async def hello_world(request): - # modifications to the state should not leak across requests - assert request.state.count == 0 - # modify the state, this should not leak to the lifespan or other requests - request.state.count += 1 - # since state.list is a mutable object this modification _will_ leak across - # requests and to the lifespan - request.state.list.append(1) - return PlainTextResponse("hello, world") - - async def run_startup(state): - nonlocal startup_complete - startup_complete = True - state["count"] = 0 - state["list"] = [] - - async def run_shutdown(state): - nonlocal shutdown_complete - shutdown_complete = True - # modifications made to the state from a request do not leak to the lifespan - assert state["count"] == 0 - # unless of course the request mutates a mutable object that is referenced - # via state - assert state["list"] == [1, 1] - - app = Router( - on_startup=[run_startup], - on_shutdown=[run_shutdown], - routes=[Route("/", hello_world)], - ) - - assert not startup_complete - assert not shutdown_complete - with test_client_factory(app) as client: - assert startup_complete - assert not shutdown_complete - client.get("/") - # Calling it a second time to ensure that the state is preserved. - client.get("/") - assert startup_complete - assert shutdown_complete - - def test_lifespan_state_unsupported(test_client_factory): @contextlib.asynccontextmanager - async def lifespan(app, scope): - yield None # pragma: no cover + async def lifespan(app): + yield {"foo": "bar"} app = Router( lifespan=lifespan, @@ -738,33 +697,36 @@ async def no_state_wrapper(scope, receive, send): raise AssertionError("Should not be called") # pragma: no cover -def test_lifespan_async_cm(test_client_factory): +def test_lifespan_state_async_cm(test_client_factory): startup_complete = False shutdown_complete = False - async def hello_world(request): + class State(TypedDict): + count: int + items: typing.List[int] + + async def hello_world(request: Request) -> Response: # modifications to the state should not leak across requests assert request.state.count == 0 # modify the state, this should not leak to the lifespan or other requests request.state.count += 1 - # since state.list is a mutable object this modification _will_ leak across + # since state.items is a mutable object this modification _will_ leak across # requests and to the lifespan - request.state.list.append(1) + request.state.items.append(1) return PlainTextResponse("hello, world") @contextlib.asynccontextmanager - async def lifespan(app: Starlette, state: typing.Dict[str, typing.Any]): + async def lifespan(app: Starlette) -> typing.AsyncIterator[State]: nonlocal startup_complete, shutdown_complete startup_complete = True - state["count"] = 0 - state["list"] = [] - yield + state = State(count=0, items=[]) + yield state shutdown_complete = True # modifications made to the state from a request do not leak to the lifespan assert state["count"] == 0 # unless of course the request mutates a mutable object that is referenced # via state - assert state["list"] == [1, 1] + assert state["items"] == [1, 1] app = Router( lifespan=lifespan,