From ba1588a8e8ee0b5528eab3acbd5e39417cf572f6 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 4 Mar 2023 16:37:06 +0100 Subject: [PATCH 1/7] Support lifespan state --- starlette/applications.py | 6 +-- starlette/routing.py | 102 +++++++++++++++---------------------- starlette/testclient.py | 8 ++- starlette/types.py | 6 +++ tests/test_applications.py | 59 +++++---------------- tests/test_routing.py | 34 +++++++++++++ 6 files changed, 102 insertions(+), 113 deletions(-) diff --git a/starlette/applications.py b/starlette/applications.py index c68ad864a..ff88b6951 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -9,7 +9,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.routing import BaseRoute, Router -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send class Starlette: @@ -55,9 +55,7 @@ def __init__( ] = None, on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, - lifespan: typing.Optional[ - typing.Callable[["Starlette"], typing.AsyncContextManager] - ] = None, + lifespan: typing.Optional[Lifespan] = None, ) -> None: # The lifespan context function is a newer style that replaces # on_startup / on_shutdown handlers. Use one or the other, not both. diff --git a/starlette/routing.py b/starlette/routing.py index 0aa90aa25..9b3583223 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,12 +1,9 @@ -import contextlib import functools import inspect import re import traceback -import types import typing import warnings -from contextlib import asynccontextmanager from enum import Enum from starlette._utils import is_async_callable @@ -17,7 +14,7 @@ from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send, StatelessLifespan from starlette.websockets import WebSocket, WebSocketClose @@ -530,32 +527,7 @@ def __repr__(self) -> str: _T = typing.TypeVar("_T") -class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): - def __init__(self, cm: typing.ContextManager[_T]): - self._cm = cm - - async def __aenter__(self) -> _T: - return self._cm.__enter__() - - async def __aexit__( - self, - exc_type: typing.Optional[typing.Type[BaseException]], - exc_value: typing.Optional[BaseException], - traceback: typing.Optional[types.TracebackType], - ) -> typing.Optional[bool]: - return self._cm.__exit__(exc_type, exc_value, traceback) - - -def _wrap_gen_lifespan_context( - lifespan_context: typing.Callable[[typing.Any], typing.Generator] -) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: - cmgr = contextlib.contextmanager(lifespan_context) - - @functools.wraps(cmgr) - def wrapper(app: typing.Any) -> _AsyncLiftContextManager: - return _AsyncLiftContextManager(cmgr(app)) - - return wrapper +_TDefaultLifespan = typing.TypeVar("_TDefaultLifespan", bound="_DefaultLifespan") class _DefaultLifespan: @@ -563,12 +535,17 @@ def __init__(self, router: "Router"): self._router = router async def __aenter__(self) -> None: - await self._router.startup() + await self._router.startup(state=self._state) async def __aexit__(self, *exc_info: object) -> None: - await self._router.shutdown() - - def __call__(self: _T, app: object) -> _T: + await self._router.shutdown(state=self._state) + + def __call__( + self: _TDefaultLifespan, + app: object, + state: typing.Optional[typing.Dict[str, typing.Any]], + ) -> _TDefaultLifespan: + self._state = state return self @@ -580,9 +557,7 @@ def __init__( default: typing.Optional[ASGIApp] = None, on_startup: typing.Optional[typing.Sequence[typing.Callable]] = None, on_shutdown: typing.Optional[typing.Sequence[typing.Callable]] = None, - lifespan: typing.Optional[ - typing.Callable[[typing.Any], typing.AsyncContextManager] - ] = None, + lifespan: typing.Optional[Lifespan] = None, ) -> None: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes @@ -591,27 +566,13 @@ def __init__( self.on_shutdown = [] if on_shutdown is None else list(on_shutdown) if lifespan is None: - self.lifespan_context: typing.Callable[ - [typing.Any], typing.AsyncContextManager - ] = _DefaultLifespan(self) - - elif inspect.isasyncgenfunction(lifespan): - warnings.warn( - "async generator function lifespans are deprecated, " - "use an @contextlib.asynccontextmanager function instead", - DeprecationWarning, - ) - self.lifespan_context = asynccontextmanager( - lifespan, # type: ignore[arg-type] - ) - elif inspect.isgeneratorfunction(lifespan): - warnings.warn( - "generator function lifespans are deprecated, " - "use an @contextlib.asynccontextmanager function instead", - DeprecationWarning, - ) - self.lifespan_context = _wrap_gen_lifespan_context( - lifespan, # type: ignore[arg-type] + self.lifespan_context: Lifespan = _DefaultLifespan(self) + elif inspect.isasyncgenfunction(lifespan) or inspect.isgeneratorfunction( + lifespan + ): + raise RuntimeError( + "Generator functions are not supported for lifespan, " + "use an @contextlib.asynccontextmanager function instead." ) else: self.lifespan_context = lifespan @@ -639,21 +600,31 @@ def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath: pass raise NoMatchFound(name, path_params) - async def startup(self) -> None: + async def startup( + self, state: typing.Optional[typing.Dict[str, typing.Any]] + ) -> None: """ Run any `.on_startup` event handlers. """ for handler in self.on_startup: + sig = inspect.signature(handler) + if len(sig.parameters) == 1 and state is not None: + handler = functools.partial(handler, state) if is_async_callable(handler): await handler() else: handler() - async def shutdown(self) -> None: + async def shutdown( + self, state: typing.Optional[typing.Dict[str, typing.Any]] + ) -> None: """ Run any `.on_shutdown` event handlers. """ for handler in self.on_shutdown: + sig = inspect.signature(handler) + if len(sig.parameters) == 1 and state is not None: + handler = functools.partial(handler, state) if is_async_callable(handler): await handler() else: @@ -666,9 +637,18 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: """ started = False app = scope.get("app") + state = scope.get("state") await receive() try: - async with self.lifespan_context(app): + lifespan_context: Lifespan + if ( + len(inspect.signature(self.lifespan_context).parameters) == 2 + and state is not None + ): + lifespan_context = functools.partial(self.lifespan_context, state=state) + else: + lifespan_context = typing.cast(StatelessLifespan, self.lifespan_context) + async with lifespan_context(app): await send({"type": "lifespan.startup.complete"}) started = True await receive() diff --git a/starlette/testclient.py b/starlette/testclient.py index 549fa7621..276093b89 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -188,11 +188,13 @@ def __init__( portal_factory: _PortalFactoryType, raise_server_exceptions: bool = True, root_path: str = "", + state: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions self.root_path = root_path self.portal_factory = portal_factory + self.state = state def handle_request(self, request: httpx.Request) -> httpx.Response: scheme = request.url.scheme @@ -243,6 +245,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "client": ["testclient", 50000], "server": [host, port], "subprotocols": subprotocols, + "state": self.state, } session = WebSocketTestSession(self.app, scope, self.portal_factory) raise _Upgrade(session) @@ -260,6 +263,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "client": ["testclient", 50000], "server": [host, port], "extensions": {"http.response.debug": {}}, + "state": self.state, } request_complete = False @@ -380,11 +384,13 @@ def __init__( app = typing.cast(ASGI2App, app) # type: ignore[assignment] asgi_app = _WrapASGI2(app) # type: ignore[arg-type] self.app = asgi_app + self.app_state: typing.Dict[str, typing.Any] = {} transport = _TestClientTransport( self.app, portal_factory=self._portal_factory, raise_server_exceptions=raise_server_exceptions, root_path=root_path, + state=self.app_state, ) if headers is None: headers = {} @@ -749,7 +755,7 @@ def __exit__(self, *args: typing.Any) -> None: self.exit_stack.close() async def lifespan(self) -> None: - scope = {"type": "lifespan"} + scope = {"type": "lifespan", "state": self.app_state} try: await self.app(scope, self.stream_receive.receive, self.stream_send.send) finally: diff --git a/starlette/types.py b/starlette/types.py index 888645cae..a11c68a7e 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -7,3 +7,9 @@ Send = typing.Callable[[Message], typing.Awaitable[None]] ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] + +StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager] +StateLifespan = typing.Callable[ + [typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager +] +Lifespan = typing.Union[StatelessLifespan, StateLifespan] diff --git a/tests/test_applications.py b/tests/test_applications.py index ba10aff8e..19a6782f6 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -381,57 +381,22 @@ async def lifespan(app): assert cleanup_complete -deprecated_lifespan = pytest.mark.filterwarnings( - r"ignore" - r":(async )?generator function lifespans are deprecated, use an " - r"@contextlib\.asynccontextmanager function instead" - r":DeprecationWarning" - r":starlette.routing" -) - - -@deprecated_lifespan -def test_app_async_gen_lifespan(test_client_factory): - startup_complete = False - cleanup_complete = False - - async def lifespan(app): - nonlocal startup_complete, cleanup_complete - startup_complete = True - yield - cleanup_complete = True - - app = Starlette(lifespan=lifespan) - - assert not startup_complete - assert not cleanup_complete - with test_client_factory(app): - assert startup_complete - assert not cleanup_complete - assert startup_complete - assert cleanup_complete +async def async_gen_lifespan(): + yield # pragma: no cover -@deprecated_lifespan -def test_app_sync_gen_lifespan(test_client_factory): - startup_complete = False - cleanup_complete = False - - def lifespan(app): - nonlocal startup_complete, cleanup_complete - startup_complete = True - yield - cleanup_complete = True +def sync__gen_lifespan(): + yield # pragma: no cover - app = Starlette(lifespan=lifespan) - assert not startup_complete - assert not cleanup_complete - with test_client_factory(app): - assert startup_complete - assert not cleanup_complete - assert startup_complete - assert cleanup_complete +@pytest.mark.parametrize("lifespan", [async_gen_lifespan, sync__gen_lifespan]) +def test_app_gen_lifespan(lifespan): + with pytest.raises( + RuntimeError, + match="Generator functions are not supported for lifespan" + ", use an @contextlib.asynccontextmanager function instead.", + ): + Starlette(lifespan=lifespan) def test_decorator_deprecations() -> None: diff --git a/tests/test_routing.py b/tests/test_routing.py index 09beb8bb9..32ddcb87a 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -669,6 +669,40 @@ def run_shutdown(): assert shutdown_complete +def test_lifespan_with_state(test_client_factory): + startup_complete = False + shutdown_complete = False + + async def hello_world(request): + assert request.state.startup + return PlainTextResponse("hello, world") + + async def run_startup(state): + nonlocal startup_complete + startup_complete = True + state["startup"] = True + + async def run_shutdown(state): + nonlocal shutdown_complete + shutdown_complete = True + assert state["startup"] + + app = Router( + on_startup=[run_startup], + on_shutdown=[run_shutdown], + routes=[Route("/", hello_world)], + ) + + assert not startup_complete + assert not shutdown_complete + with test_client_factory(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + assert startup_complete + assert shutdown_complete + + def test_raise_on_startup(test_client_factory): def run_startup(): raise RuntimeError() From 3e7b8048b87942b461c9663f700ba56be791ec58 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 4 Mar 2023 19:52:06 +0100 Subject: [PATCH 2/7] Copy the state instead of reusing on multiple requests --- starlette/testclient.py | 11 ++++++----- tests/test_routing.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 276093b89..bdae83bf0 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -188,13 +188,14 @@ def __init__( portal_factory: _PortalFactoryType, raise_server_exceptions: bool = True, root_path: str = "", - state: typing.Optional[typing.Dict[str, typing.Any]] = None, + *, + app_state: typing.Dict[str, typing.Any], ) -> None: self.app = app self.raise_server_exceptions = raise_server_exceptions self.root_path = root_path self.portal_factory = portal_factory - self.state = state + self.app_state = app_state def handle_request(self, request: httpx.Request) -> httpx.Response: scheme = request.url.scheme @@ -245,7 +246,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "client": ["testclient", 50000], "server": [host, port], "subprotocols": subprotocols, - "state": self.state, + "state": self.app_state.copy(), } session = WebSocketTestSession(self.app, scope, self.portal_factory) raise _Upgrade(session) @@ -263,7 +264,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "client": ["testclient", 50000], "server": [host, port], "extensions": {"http.response.debug": {}}, - "state": self.state, + "state": self.app_state.copy(), } request_complete = False @@ -390,7 +391,7 @@ def __init__( portal_factory=self._portal_factory, raise_server_exceptions=raise_server_exceptions, root_path=root_path, - state=self.app_state, + app_state=self.app_state, ) if headers is None: headers = {} diff --git a/tests/test_routing.py b/tests/test_routing.py index 32ddcb87a..18dfbdee8 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -674,18 +674,21 @@ def test_lifespan_with_state(test_client_factory): shutdown_complete = False async def hello_world(request): - assert request.state.startup + request.state.count += 1 + request.state.list.append(1) return PlainTextResponse("hello, world") async def run_startup(state): nonlocal startup_complete startup_complete = True - state["startup"] = True + state["count"] = 0 + state["list"] = [] async def run_shutdown(state): nonlocal shutdown_complete shutdown_complete = True - assert state["startup"] + assert state["count"] == 0 + assert state["list"] == [1, 1] app = Router( on_startup=[run_startup], @@ -699,6 +702,8 @@ async def run_shutdown(state): assert startup_complete assert not shutdown_complete client.get("/") + # Calling it a second time to ensure that the state is preserved. + client.get("/") assert startup_complete assert shutdown_complete From 35ecd94854b9bdc86c1a2d2d6a2fd627605edf7c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 4 Mar 2023 20:00:36 +0100 Subject: [PATCH 3/7] Readd deprecated conditionals --- starlette/routing.py | 54 ++++++++++++++++++++++++++++++---- tests/test_applications.py | 59 ++++++++++++++++++++++++++++++-------- 2 files changed, 95 insertions(+), 18 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 9b3583223..31b870129 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,9 +1,12 @@ +import contextlib import functools import inspect import re import traceback +import types import typing import warnings +from contextlib import asynccontextmanager from enum import Enum from starlette._utils import is_async_callable @@ -527,6 +530,34 @@ def __repr__(self) -> str: _T = typing.TypeVar("_T") +class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): + def __init__(self, cm: typing.ContextManager[_T]): + self._cm = cm + + async def __aenter__(self) -> _T: + return self._cm.__enter__() + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]], + exc_value: typing.Optional[BaseException], + traceback: typing.Optional[types.TracebackType], + ) -> typing.Optional[bool]: + return self._cm.__exit__(exc_type, exc_value, traceback) + + +def _wrap_gen_lifespan_context( + lifespan_context: typing.Callable[[typing.Any], typing.Generator] +) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: + cmgr = contextlib.contextmanager(lifespan_context) + + @functools.wraps(cmgr) + def wrapper(app: typing.Any) -> _AsyncLiftContextManager: + return _AsyncLiftContextManager(cmgr(app)) + + return wrapper + + _TDefaultLifespan = typing.TypeVar("_TDefaultLifespan", bound="_DefaultLifespan") @@ -567,12 +598,23 @@ def __init__( if lifespan is None: self.lifespan_context: Lifespan = _DefaultLifespan(self) - elif inspect.isasyncgenfunction(lifespan) or inspect.isgeneratorfunction( - lifespan - ): - raise RuntimeError( - "Generator functions are not supported for lifespan, " - "use an @contextlib.asynccontextmanager function instead." + elif inspect.isasyncgenfunction(lifespan): + warnings.warn( + "async generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = asynccontextmanager( + lifespan, # type: ignore[arg-type] + ) + elif inspect.isgeneratorfunction(lifespan): + warnings.warn( + "generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, + ) + self.lifespan_context = _wrap_gen_lifespan_context( + lifespan, # type: ignore[arg-type] ) else: self.lifespan_context = lifespan diff --git a/tests/test_applications.py b/tests/test_applications.py index 19a6782f6..ba10aff8e 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -381,22 +381,57 @@ async def lifespan(app): assert cleanup_complete -async def async_gen_lifespan(): - yield # pragma: no cover +deprecated_lifespan = pytest.mark.filterwarnings( + r"ignore" + r":(async )?generator function lifespans are deprecated, use an " + r"@contextlib\.asynccontextmanager function instead" + r":DeprecationWarning" + r":starlette.routing" +) + + +@deprecated_lifespan +def test_app_async_gen_lifespan(test_client_factory): + startup_complete = False + cleanup_complete = False + + async def lifespan(app): + nonlocal startup_complete, cleanup_complete + startup_complete = True + yield + cleanup_complete = True + + app = Starlette(lifespan=lifespan) + + assert not startup_complete + assert not cleanup_complete + with test_client_factory(app): + assert startup_complete + assert not cleanup_complete + assert startup_complete + assert cleanup_complete -def sync__gen_lifespan(): - yield # pragma: no cover +@deprecated_lifespan +def test_app_sync_gen_lifespan(test_client_factory): + startup_complete = False + cleanup_complete = False + + def lifespan(app): + nonlocal startup_complete, cleanup_complete + startup_complete = True + yield + cleanup_complete = True + app = Starlette(lifespan=lifespan) -@pytest.mark.parametrize("lifespan", [async_gen_lifespan, sync__gen_lifespan]) -def test_app_gen_lifespan(lifespan): - with pytest.raises( - RuntimeError, - match="Generator functions are not supported for lifespan" - ", use an @contextlib.asynccontextmanager function instead.", - ): - Starlette(lifespan=lifespan) + assert not startup_complete + assert not cleanup_complete + with test_client_factory(app): + assert startup_complete + assert not cleanup_complete + assert startup_complete + assert cleanup_complete def test_decorator_deprecations() -> None: From 246a61f474f3fc5c6144ba709446e6b4bf7ae0a8 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 4 Mar 2023 20:04:31 +0100 Subject: [PATCH 4/7] Apply suggestions from code review Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> --- tests/test_routing.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_routing.py b/tests/test_routing.py index 18dfbdee8..6a9a50ef3 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -674,7 +674,12 @@ def test_lifespan_with_state(test_client_factory): shutdown_complete = False async def hello_world(request): + # modifications to the state should not leak across requests + assert request.state.count == 0 + # modify the state, this should not leak to the lifespan or other requests request.state.count += 1 + # since state.list is a mutable object this modification _will_ leak across + # requests and to the lifespan request.state.list.append(1) return PlainTextResponse("hello, world") @@ -687,7 +692,10 @@ async def run_startup(state): async def run_shutdown(state): nonlocal shutdown_complete shutdown_complete = True + # modifications made to the state from a request do not leak to the lifespan assert state["count"] == 0 + # unless of course the request mutates a mutable object that is referenced + # via state assert state["list"] == [1, 1] app = Router( From fa07156b3fc6d6755e84f4132df81352b760abdb Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sat, 4 Mar 2023 20:21:50 +0100 Subject: [PATCH 5/7] Add test for asynccontextmanager --- tests/test_routing.py | 46 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/test_routing.py b/tests/test_routing.py index 6a9a50ef3..c2cbf65a4 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,3 +1,4 @@ +import contextlib import functools import typing import uuid @@ -716,6 +717,51 @@ async def run_shutdown(state): assert shutdown_complete +def test_lifespan_async_cm(test_client_factory): + startup_complete = False + shutdown_complete = False + + async def hello_world(request): + # modifications to the state should not leak across requests + assert request.state.count == 0 + # modify the state, this should not leak to the lifespan or other requests + request.state.count += 1 + # since state.list is a mutable object this modification _will_ leak across + # requests and to the lifespan + request.state.list.append(1) + return PlainTextResponse("hello, world") + + @contextlib.asynccontextmanager + async def lifespan(app: Starlette, state: typing.Dict[str, typing.Any]): + nonlocal startup_complete, shutdown_complete + startup_complete = True + state["count"] = 0 + state["list"] = [] + yield + shutdown_complete = True + # modifications made to the state from a request do not leak to the lifespan + assert state["count"] == 0 + # unless of course the request mutates a mutable object that is referenced + # via state + assert state["list"] == [1, 1] + + app = Router( + lifespan=lifespan, + routes=[Route("/", hello_world)], + ) + + assert not startup_complete + assert not shutdown_complete + with test_client_factory(app) as client: + assert startup_complete + assert not shutdown_complete + client.get("/") + # Calling it a second time to ensure that the state is preserved. + client.get("/") + assert startup_complete + assert shutdown_complete + + def test_raise_on_startup(test_client_factory): def run_startup(): raise RuntimeError() From c791251347ee511da5f89b8e8db791a88795b411 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 5 Mar 2023 07:34:44 -0600 Subject: [PATCH 6/7] Add test for server not supporting lifespan state but app requiring it (#2061) --- starlette/routing.py | 19 +++++++++++++------ starlette/types.py | 4 ++-- tests/test_routing.py | 21 +++++++++++++++++++++ 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 31b870129..e8fc47074 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -679,15 +679,22 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: """ started = False app = scope.get("app") - state = scope.get("state") await receive() + lifespan_needs_state = ( + len(inspect.signature(self.lifespan_context).parameters) == 2 + ) + server_supports_state = "state" in scope + if lifespan_needs_state and not server_supports_state: + raise RuntimeError( + 'This server does not support "state" in the lifespan scope.' + " Please try updating your ASGI server." + ) try: lifespan_context: Lifespan - if ( - len(inspect.signature(self.lifespan_context).parameters) == 2 - and state is not None - ): - lifespan_context = functools.partial(self.lifespan_context, state=state) + if lifespan_needs_state: + lifespan_context = functools.partial( + self.lifespan_context, state=scope["state"] + ) else: lifespan_context = typing.cast(StatelessLifespan, self.lifespan_context) async with lifespan_context(app): diff --git a/starlette/types.py b/starlette/types.py index a11c68a7e..b83d9101a 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -8,8 +8,8 @@ ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] -StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager] +StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager[typing.Any]] StateLifespan = typing.Callable[ - [typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager + [typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager[typing.Any] ] Lifespan = typing.Union[StatelessLifespan, StateLifespan] diff --git a/tests/test_routing.py b/tests/test_routing.py index c2cbf65a4..b6f518450 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -717,6 +717,27 @@ async def run_shutdown(state): assert shutdown_complete +def test_lifespan_state_unsupported(test_client_factory): + @contextlib.asynccontextmanager + async def lifespan(app, scope): + yield None # pragma: no cover + + app = Router( + lifespan=lifespan, + routes=[Mount("/", PlainTextResponse("hello, world"))], + ) + + async def no_state_wrapper(scope, receive, send): + del scope["state"] + await app(scope, receive, send) + + with pytest.raises( + RuntimeError, match='This server does not support "state" in the lifespan scope' + ): + with test_client_factory(no_state_wrapper): + raise AssertionError("Should not be called") # pragma: no cover + + def test_lifespan_async_cm(test_client_factory): startup_complete = False shutdown_complete = False From b5aac4de8bcd1b982874ea49ba74d5f6551bab13 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 5 Mar 2023 14:41:30 +0100 Subject: [PATCH 7/7] Cutefy Adrian's improvements --- starlette/routing.py | 10 ++++------ tests/test_routing.py | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index e8fc47074..92b81f0bc 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -679,22 +679,20 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: """ started = False app = scope.get("app") + state = scope.get("state") await receive() lifespan_needs_state = ( len(inspect.signature(self.lifespan_context).parameters) == 2 ) - server_supports_state = "state" in scope + server_supports_state = state is not None if lifespan_needs_state and not server_supports_state: raise RuntimeError( - 'This server does not support "state" in the lifespan scope.' - " Please try updating your ASGI server." + 'The server does not support "state" in the lifespan scope.' ) try: lifespan_context: Lifespan if lifespan_needs_state: - lifespan_context = functools.partial( - self.lifespan_context, state=scope["state"] - ) + lifespan_context = functools.partial(self.lifespan_context, state=state) else: lifespan_context = typing.cast(StatelessLifespan, self.lifespan_context) async with lifespan_context(app): diff --git a/tests/test_routing.py b/tests/test_routing.py index b6f518450..636679757 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -732,7 +732,7 @@ async def no_state_wrapper(scope, receive, send): await app(scope, receive, send) with pytest.raises( - RuntimeError, match='This server does not support "state" in the lifespan scope' + RuntimeError, match='The server does not support "state" in the lifespan scope' ): with test_client_factory(no_state_wrapper): raise AssertionError("Should not be called") # pragma: no cover