Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,15 +679,22 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
"""
started = False
app = scope.get("app")
state = scope.get("state")
await receive()
lifespan_needs_state = (
len(inspect.signature(self.lifespan_context).parameters) == 2
)
server_supports_state = "state" in scope
if lifespan_needs_state and not server_supports_state:
raise RuntimeError(
'This server does not support "state" in the lifespan scope.'
" Please try updating your ASGI server."
)
try:
lifespan_context: Lifespan
if (
len(inspect.signature(self.lifespan_context).parameters) == 2
and state is not None
):
lifespan_context = functools.partial(self.lifespan_context, state=state)
if lifespan_needs_state:
lifespan_context = functools.partial(
self.lifespan_context, state=scope["state"]
)
else:
lifespan_context = typing.cast(StatelessLifespan, self.lifespan_context)
async with lifespan_context(app):
Expand Down
4 changes: 2 additions & 2 deletions starlette/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager]
StatelessLifespan = typing.Callable[[object], typing.AsyncContextManager[typing.Any]]
StateLifespan = typing.Callable[
[typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager
[typing.Any, typing.Dict[str, typing.Any]], typing.AsyncContextManager[typing.Any]
]
Lifespan = typing.Union[StatelessLifespan, StateLifespan]
21 changes: 21 additions & 0 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,27 @@ async def run_shutdown(state):
assert shutdown_complete


def test_lifespan_state_unsupported(test_client_factory):
@contextlib.asynccontextmanager
async def lifespan(app, scope):
yield None # pragma: no cover

app = Router(
lifespan=lifespan,
routes=[Mount("/", PlainTextResponse("hello, world"))],
)

async def no_state_wrapper(scope, receive, send):
del scope["state"]
await app(scope, receive, send)

with pytest.raises(
RuntimeError, match='This server does not support "state" in the lifespan scope'
):
with test_client_factory(no_state_wrapper):
raise AssertionError("Should not be called") # pragma: no cover


def test_lifespan_async_cm(test_client_factory):
startup_complete = False
shutdown_complete = False
Expand Down