From a68f91a21b6bc7e27e8adb1d3b05c163aaf738ed Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 4 Mar 2023 15:35:40 -0600 Subject: [PATCH 1/2] Test for state required but not supported by server --- starlette/routing.py | 19 +++++++++++++------ starlette/types.py | 4 ++-- tests/test_routing.py | 21 +++++++++++++++++++++ 3 files changed, 36 insertions(+), 8 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 31b870129..e8fc47074 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -679,15 +679,22 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: """ started = False app = scope.get("app") - state = scope.get("state") await receive() + lifespan_needs_state = ( + len(inspect.signature(self.lifespan_context).parameters) == 2 + ) + server_supports_state = "state" in scope + if lifespan_needs_state and not server_supports_state: + raise RuntimeError( + 'This server does not support "state" in the lifespan scope.' + " Please try updating your ASGI server." + ) try: lifespan_context: Lifespan - if ( - len(inspect.signature(self.lifespan_context).parameters) == 2 - and state is not None - ): - lifespan_context = functools.partial(self.lifespan_context, state=state) + if lifespan_needs_state: + lifespan_context = functools.partial( + self.lifespan_context, state=scope["state"] + ) else: lifespan_context = typing.cast(StatelessLifespan, self.lifespan_context) async with lifespan_context(app): diff --git a/starlette/types.py b/starlette/types.py index a11c68a7e..b83d9101a 100644 --- a/starlette/types.py +++ b/starlette/types.py @@ -8,8 +8,8 @@ ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]] -StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager] +StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager[typing.Any]] StateLifespan = typing.Callable[ - [typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager + [typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager[typing.Any] ] Lifespan = typing.Union[StatelessLifespan, StateLifespan] diff --git a/tests/test_routing.py b/tests/test_routing.py index c2cbf65a4..dab8cec59 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -717,6 +717,27 @@ async def run_shutdown(state): assert shutdown_complete +def test_lifespan_state_unsupported(test_client_factory): + @contextlib.asynccontextmanager + async def lifespan(app, scope): + yield None + + app = Router( + lifespan=lifespan, + routes=[Mount("/", PlainTextResponse("hello, world"))], + ) + + async def no_state_wrapper(scope, receive, send): + scope.pop("state", None) + await app(scope, receive, send) + + with pytest.raises( + RuntimeError, match='This server does not support "state" in the lifespan scope' + ): + with test_client_factory(no_state_wrapper): + raise AssertionError("Should not be called") + + def test_lifespan_async_cm(test_client_factory): startup_complete = False shutdown_complete = False From 5e441d8479cfa7f97515c4a4a0c5de10c3618dc0 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 4 Mar 2023 15:37:45 -0600 Subject: [PATCH 2/2] pragma no cover --- tests/test_routing.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_routing.py b/tests/test_routing.py index dab8cec59..b6f518450 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -720,7 +720,7 @@ async def run_shutdown(state): def test_lifespan_state_unsupported(test_client_factory): @contextlib.asynccontextmanager async def lifespan(app, scope): - yield None + yield None # pragma: no cover app = Router( lifespan=lifespan, @@ -728,14 +728,14 @@ async def lifespan(app, scope): ) async def no_state_wrapper(scope, receive, send): - scope.pop("state", None) + del scope["state"] await app(scope, receive, send) with pytest.raises( RuntimeError, match='This server does not support "state" in the lifespan scope' ): with test_client_factory(no_state_wrapper): - raise AssertionError("Should not be called") + raise AssertionError("Should not be called") # pragma: no cover def test_lifespan_async_cm(test_client_factory):