From 44159d0285352530edf8bee095af27bef55c4fca Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 5 Mar 2023 14:44:13 -0600 Subject: [PATCH 1/3] Revert "Support lifespan state (#2060)" This reverts commit da6461b239cde16ee9709b7d266c2529c26239d7. --- starlette/applications.py | 6 +- starlette/routing.py | 57 +++++-------------- starlette/testclient.py | 9 +-- starlette/types.py | 6 -- tests/test_routing.py | 114 -------------------------------------- 5 files changed, 20 insertions(+), 172 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index ff88b6951..c68ad864a 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -9,7 +9,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.routing import BaseRoute, Router -from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send +from starlette.types import ASGIApp, Receive, Scope, Send class Starlette: @@ -55,7 +55,9 @@ def __init__( ] = None, on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, - lifespan: typing.Optional[Lifespan] = None, + lifespan: typing.Optional[ + typing.Callable[["Starlette"], typing.AsyncContextManager] + ] = None, ) -> None: # The lifespan context function is a newer style that replaces # on_startup / on_shutdown handlers. Use one or the other, not both. diff --git a/starlette/routing.py b/starlette/routing.py index 92b81f0bc..0aa90aa25 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -17,7 +17,7 @@ from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse -from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send, StatelessLifespan +from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose @@ -558,25 +558,17 @@ def wrapper(app: typing.Any) -> _AsyncLiftContextManager: return wrapper -_TDefaultLifespan = typing.TypeVar("_TDefaultLifespan", bound="_DefaultLifespan") - - class _DefaultLifespan: def __init__(self, router: "Router"): self._router = router async def __aenter__(self) -> None: - await self._router.startup(state=self._state) + await self._router.startup() async def __aexit__(self, *exc_info: object) -> None: - await self._router.shutdown(state=self._state) - - def __call__( - self: _TDefaultLifespan, - app: object, - state: typing.Optional[typing.Dict[str, typing.Any]], - ) -> _TDefaultLifespan: - self._state = state + await self._router.shutdown() + + def __call__(self: _T, app: object) -> _T: return self @@ -588,7 +580,9 @@ def __init__( default: typing.Optional[ASGIApp] = None, on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, - lifespan: typing.Optional[Lifespan] = None, + lifespan: typing.Optional[ + typing.Callable[[typing.Any], typing.AsyncContextManager] + ] = None, ) -> None: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes @@ -597,7 +591,10 @@ def __init__( self.on_shutdown = [] if on_shutdown is None else list(on_shutdown) if lifespan is None: - self.lifespan_context: Lifespan = _DefaultLifespan(self) + self.lifespan_context: typing.Callable[ + [typing.Any], typing.AsyncContextManager + ] = _DefaultLifespan(self) + elif inspect.isasyncgenfunction(lifespan): warnings.warn( "async generator function lifespans are deprecated, " @@ -642,31 +639,21 @@ def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath: pass raise NoMatchFound(name, path_params) - async def startup( - self, state: typing.Optional[typing.Dict[str, typing.Any]] - ) -> None: + async def startup(self) -> None: """ Run any `.on_startup` event handlers. """ for handler in self.on_startup: - sig = inspect.signature(handler) - if len(sig.parameters) == 1 and state is not None: - handler = functools.partial(handler, state) if is_async_callable(handler): await handler() else: handler() - async def shutdown( - self, state: typing.Optional[typing.Dict[str, typing.Any]] - ) -> None: + async def shutdown(self) -> None: """ Run any `.on_shutdown` event handlers. """ for handler in self.on_shutdown: - sig = inspect.signature(handler) - if len(sig.parameters) == 1 and state is not None: - handler = functools.partial(handler, state) if is_async_callable(handler): await handler() else: @@ -679,23 +666,9 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: """ started = False app = scope.get("app") - state = scope.get("state") await receive() - lifespan_needs_state = ( - len(inspect.signature(self.lifespan_context).parameters) == 2 - ) - server_supports_state = state is not None - if lifespan_needs_state and not server_supports_state: - raise RuntimeError( - 'The server does not support "state" in the lifespan scope.' - ) try: - lifespan_context: Lifespan - if lifespan_needs_state: - lifespan_context = functools.partial(self.lifespan_context, state=state) - else: - lifespan_context = typing.cast(StatelessLifespan, self.lifespan_context) - async with lifespan_context(app): + async with self.lifespan_context(app): await send({"type": "lifespan.startup.complete"}) started = True await receive() diff --git a/starlette/testclient.py b/starlette/testclient.py index bdae83bf0..549fa7621 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -188,14 +188,11 @@ def __init__( portal_factory: _PortalFactoryType, raise_server_exceptions: bool = True, root_path: str = "", - *, - app_state: typing.Dict[str, typing.Any], ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions self.root_path = root_path self.portal_factory = portal_factory - self.app_state = app_state def handle_request(self, request: httpx.Request) -> httpx.Response: scheme = request.url.scheme @@ -246,7 +243,6 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "client": ["testclient", 50000], "server": [host, port], "subprotocols": subprotocols, - "state": self.app_state.copy(), } session = WebSocketTestSession(self.app, scope, self.portal_factory) raise _Upgrade(session) @@ -264,7 +260,6 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "client": ["testclient", 50000], "server": [host, port], "extensions": {"http.response.debug": {}}, - "state": self.app_state.copy(), } request_complete = False @@ -385,13 +380,11 @@ def __init__( app = typing.cast(ASGI2App, app) # type: ignore[assignment] asgi_app = _WrapASGI2(app) # type: ignore[arg-type] self.app = asgi_app - self.app_state: typing.Dict[str, typing.Any] = {} transport = _TestClientTransport( self.app, portal_factory=self._portal_factory, raise_server_exceptions=raise_server_exceptions, root_path=root_path, - app_state=self.app_state, ) if headers is None: headers = {} @@ -756,7 +749,7 @@ def __exit__(self, *args: typing.Any) -> None: self.exit_stack.close() async def lifespan(self) -> None: - scope = {"type": "lifespan", "state": self.app_state} + scope = {"type": "lifespan"} try: await self.app(scope, self.stream_receive.receive, self.stream_send.send) finally: diff --git a/starlette/types.py b/starlette/types.py index b83d9101a..888645cae 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -7,9 +7,3 @@ Send = typing.Callable[[Message], typing.Awaitable[None]] ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] - -StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager[typing.Any]] -StateLifespan = typing.Callable[ - [typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager[typing.Any] -] -Lifespan = typing.Union[StatelessLifespan, StateLifespan] diff --git a/tests/test_routing.py b/tests/test_routing.py index b70641680..da4848b8d 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,4 +1,3 @@ -import contextlib import functools import typing import uuid @@ -670,119 +669,6 @@ def run_shutdown(): assert shutdown_complete -def test_lifespan_with_state(test_client_factory): - startup_complete = False - shutdown_complete = False - - async def hello_world(request): - # modifications to the state should not leak across requests - assert request.state.count == 0 - # modify the state, this should not leak to the lifespan or other requests - request.state.count += 1 - # since state.list is a mutable object this modification _will_ leak across - # requests and to the lifespan - request.state.list.append(1) - return PlainTextResponse("hello, world") - - async def run_startup(state): - nonlocal startup_complete - startup_complete = True - state["count"] = 0 - state["list"] = [] - - async def run_shutdown(state): - nonlocal shutdown_complete - shutdown_complete = True - # modifications made to the state from a request do not leak to the lifespan - assert state["count"] == 0 - # unless of course the request mutates a mutable object that is referenced - # via state - assert state["list"] == [1, 1] - - app = Router( - on_startup=[run_startup], - on_shutdown=[run_shutdown], - routes=[Route("/", hello_world)], - ) - - assert not startup_complete - assert not shutdown_complete - with test_client_factory(app) as client: - assert startup_complete - assert not shutdown_complete - client.get("/") - # Calling it a second time to ensure that the state is preserved. - client.get("/") - assert startup_complete - assert shutdown_complete - - -def test_lifespan_state_unsupported(test_client_factory): - @contextlib.asynccontextmanager - async def lifespan(app, scope): - yield None # pragma: no cover - - app = Router( - lifespan=lifespan, - routes=[Mount("/", PlainTextResponse("hello, world"))], - ) - - async def no_state_wrapper(scope, receive, send): - del scope["state"] - await app(scope, receive, send) - - with pytest.raises( - RuntimeError, match='The server does not support "state" in the lifespan scope' - ): - with test_client_factory(no_state_wrapper): - raise AssertionError("Should not be called") # pragma: no cover - - -def test_lifespan_async_cm(test_client_factory): - startup_complete = False - shutdown_complete = False - - async def hello_world(request): - # modifications to the state should not leak across requests - assert request.state.count == 0 - # modify the state, this should not leak to the lifespan or other requests - request.state.count += 1 - # since state.list is a mutable object this modification _will_ leak across - # requests and to the lifespan - request.state.list.append(1) - return PlainTextResponse("hello, world") - - @contextlib.asynccontextmanager - async def lifespan(app: Starlette, state: typing.Dict[str, typing.Any]): - nonlocal startup_complete, shutdown_complete - startup_complete = True - state["count"] = 0 - state["list"] = [] - yield - shutdown_complete = True - # modifications made to the state from a request do not leak to the lifespan - assert state["count"] == 0 - # unless of course the request mutates a mutable object that is referenced - # via state - assert state["list"] == [1, 1] - - app = Router( - lifespan=lifespan, - routes=[Route("/", hello_world)], - ) - - assert not startup_complete - assert not shutdown_complete - with test_client_factory(app) as client: - assert startup_complete - assert not shutdown_complete - client.get("/") - # Calling it a second time to ensure that the state is preserved. - client.get("/") - assert startup_complete - assert shutdown_complete - - def test_raise_on_startup(test_client_factory): def run_startup(): raise RuntimeError() From 14fd988b9a7713d31359af296c1c98beb2a63ec4 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 5 Mar 2023 15:12:02 -0600 Subject: [PATCH 2/3] new implementation --- starlette/applications.py | 6 ++-- starlette/routing.py | 20 ++++++----- starlette/testclient.py | 9 ++++- starlette/types.py | 9 +++++ tests/test_routing.py | 76 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 106 insertions(+), 14 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index c68ad864a..ff88b6951 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -9,7 +9,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.routing import BaseRoute, Router -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send class Starlette: @@ -55,9 +55,7 @@ def __init__( ] = None, on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, - lifespan: typing.Optional[ - typing.Callable[["Starlette"], typing.AsyncContextManager] - ] = None, + lifespan: typing.Optional[Lifespan] = None, ) -> None: # The lifespan context function is a newer style that replaces # on_startup / on_shutdown handlers. Use one or the other, not both. diff --git a/starlette/routing.py b/starlette/routing.py index 0aa90aa25..925d9e96b 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -17,7 +17,7 @@ from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose @@ -580,9 +580,7 @@ def __init__( default: typing.Optional[ASGIApp] = None, on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, - lifespan: typing.Optional[ - typing.Callable[[typing.Any], typing.AsyncContextManager] - ] = None, + lifespan: typing.Optional[Lifespan] = None, ) -> None: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes @@ -591,9 +589,7 @@ def __init__( self.on_shutdown = [] if on_shutdown is None else list(on_shutdown) if lifespan is None: - self.lifespan_context: typing.Callable[ - [typing.Any], typing.AsyncContextManager - ] = _DefaultLifespan(self) + self.lifespan_context: Lifespan = _DefaultLifespan(self) elif inspect.isasyncgenfunction(lifespan): warnings.warn( @@ -665,10 +661,16 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: startup and shutdown events. """ started = False - app = scope.get("app") + app: typing.Any = scope.get("app") await receive() try: - async with self.lifespan_context(app): + async with self.lifespan_context(app) as maybe_state: + if maybe_state is not None: + if "state" not in scope: + raise RuntimeError( + 'The server does not support "state" in the lifespan scope.' + ) + scope["state"].update(maybe_state) await send({"type": "lifespan.startup.complete"}) started = True await receive() diff --git a/starlette/testclient.py b/starlette/testclient.py index 549fa7621..bdae83bf0 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -188,11 +188,14 @@ def __init__( portal_factory: _PortalFactoryType, raise_server_exceptions: bool = True, root_path: str = "", + *, + app_state: typing.Dict[str, typing.Any], ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions self.root_path = root_path self.portal_factory = portal_factory + self.app_state = app_state def handle_request(self, request: httpx.Request) -> httpx.Response: scheme = request.url.scheme @@ -243,6 +246,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "client": ["testclient", 50000], "server": [host, port], "subprotocols": subprotocols, + "state": self.app_state.copy(), } session = WebSocketTestSession(self.app, scope, self.portal_factory) raise _Upgrade(session) @@ -260,6 +264,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "client": ["testclient", 50000], "server": [host, port], "extensions": {"http.response.debug": {}}, + "state": self.app_state.copy(), } request_complete = False @@ -380,11 +385,13 @@ def __init__( app = typing.cast(ASGI2App, app) # type: ignore[assignment] asgi_app = _WrapASGI2(app) # type: ignore[arg-type] self.app = asgi_app + self.app_state: typing.Dict[str, typing.Any] = {} transport = _TestClientTransport( self.app, portal_factory=self._portal_factory, raise_server_exceptions=raise_server_exceptions, root_path=root_path, + app_state=self.app_state, ) if headers is None: headers = {} @@ -749,7 +756,7 @@ def __exit__(self, *args: typing.Any) -> None: self.exit_stack.close() async def lifespan(self) -> None: - scope = {"type": "lifespan"} + scope = {"type": "lifespan", "state": self.app_state} try: await self.app(scope, self.stream_receive.receive, self.stream_send.send) finally: diff --git a/starlette/types.py b/starlette/types.py index 888645cae..05f5446e4 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -1,5 +1,8 @@ import typing +if typing.TYPE_CHECKING: + from starlette.applications import Starlette + Scope = typing.MutableMapping[str, typing.Any] Message = typing.MutableMapping[str, typing.Any] @@ -7,3 +10,9 @@ Send = typing.Callable[[Message], typing.Awaitable[None]] ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] + +StatelessLifespan = typing.Callable[["Starlette"], typing.AsyncContextManager[None]] +StatefulLifespan = typing.Callable[ + ["Starlette"], typing.AsyncContextManager[typing.Mapping[str, typing.Any]] +] +Lifespan = typing.Union[StatelessLifespan, StatefulLifespan] diff --git a/tests/test_routing.py b/tests/test_routing.py index da4848b8d..dccf089c2 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,7 +1,14 @@ +import contextlib import functools +import sys import typing import uuid +if sys.version_info < (3, 8): + from typing_extensions import TypedDict # pragma: no cover +else: + from typing import TypedDict # pragma: no cover + import pytest from starlette.applications import Starlette @@ -669,6 +676,75 @@ def run_shutdown(): assert shutdown_complete +def test_lifespan_state_unsupported(test_client_factory): + @contextlib.asynccontextmanager + async def lifespan(app): + yield {"foo": "bar"} + + app = Router( + lifespan=lifespan, + routes=[Mount("/", PlainTextResponse("hello, world"))], + ) + + async def no_state_wrapper(scope, receive, send): + del scope["state"] + await app(scope, receive, send) + + with pytest.raises( + RuntimeError, match='The server does not support "state" in the lifespan scope' + ): + with test_client_factory(no_state_wrapper): + raise AssertionError("Should not be called") # pragma: no cover + + +def test_lifespan_state_async_cm(test_client_factory): + startup_complete = False + shutdown_complete = False + + class State(TypedDict): + count: int + items: typing.List[int] + + async def hello_world(request: Request) -> Response: + # modifications to the state should not leak across requests + assert request.state.count == 0 + # modify the state, this should not leak to the lifespan or other requests + request.state.count += 1 + # since state.items is a mutable object this modification _will_ leak across + # requests and to the lifespan + request.state.items.append(1) + return PlainTextResponse("hello, world") + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> typing.AsyncIterator[State]: + nonlocal startup_complete, shutdown_complete + startup_complete = True + state = State(count=0, items=[]) + yield state + shutdown_complete = True + # modifications made to the state from a request do not leak to the lifespan + assert state["count"] == 0 + # unless of course the request mutates a mutable object that is referenced + # via state + assert state["items"] == [1, 1] + + app = Router( + lifespan=lifespan, + routes=[Route("/", hello_world)], + ) + + assert not startup_complete + assert not shutdown_complete + with test_client_factory(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + # Calling it a second time to ensure that the state is preserved. + client.get("/") + assert startup_complete + assert shutdown_complete + + def test_raise_on_startup(test_client_factory): def run_startup(): raise RuntimeError() From fc1182c4b0668e84568714b1e2d6aa6646c8bfdf Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 9 Mar 2023 20:24:33 +0100 Subject: [PATCH 3/3] Add documentation about lifespan state --- docs/events.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/events.md b/docs/events.md index 56c8c9ae6..2a43342c1 100644 --- a/docs/events.md +++ b/docs/events.md @@ -98,16 +98,21 @@ Analogously, the single lifespan `asynccontextmanager` can be used. ```python import contextlib +from typing import TypedDict + import httpx from starlette.applications import Starlette from starlette.routing import Route +class State(TypedDict): + http_client: httpx.AsyncClient + + @contextlib.asynccontextmanager -async def lifespan(app, state): +async def lifespan(app: Starlette) -> State: async with httpx.AsyncClient() as client: - state["http_client"] = client - yield + yield {"http_client": client} app = Starlette(