diff --git a/shiny/_autoreload.py b/shiny/_autoreload.py index 40d836525..b4079d990 100644 --- a/shiny/_autoreload.py +++ b/shiny/_autoreload.py @@ -8,8 +8,9 @@ import secrets import threading import webbrowser -from typing import Callable, Optional +from typing import Callable, Optional, cast +import starlette.types from asgiref.typing import ( ASGI3Application, ASGIReceiveCallable, @@ -90,8 +91,19 @@ class InjectAutoreloadMiddleware: because we want autoreload to be effective even when displaying an error page. """ - def __init__(self, app: ASGI3Application): - self.app = app + def __init__( + self, + app: starlette.types.ASGIApp | ASGI3Application, + *args: object, + **kwargs: object, + ): + if len(args) > 0 or len(kwargs) > 0: + raise TypeError( + f"InjectAutoreloadMiddleware does not support positional or keyword arguments, received {args}, {kwargs}" + ) + # The starlette types and the asgiref types are compatible, but we'll use the + # latter internally. See the note in the __call__ method for more details. + self.app = cast(ASGI3Application, app) ws_url = autoreload_url() self.script = ( f""" @@ -103,10 +115,22 @@ def __init__(self, app: ASGI3Application): ) async def __call__( - self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable + self, + scope: starlette.types.Scope | Scope, + receive: starlette.types.Receive | ASGIReceiveCallable, + send: starlette.types.Send | ASGISendCallable, ) -> None: - if scope["type"] != "http" or scope["path"] != "/" or len(self.script) == 0: - return await self.app(scope, receive, send) + # The starlette types and the asgiref types are compatible, but the latter are + # more rigorous. In the call interface, we accept both types for compatibility + # with both. But internally we'll use the more rigorous types. + # See https://github.com/encode/starlette/blob/39dccd9/docs/middleware.md#type-annotations + scope = cast(Scope, scope) + receive_casted = cast(ASGIReceiveCallable, receive) + send_casted = cast(ASGISendCallable, send) + if scope["type"] != "http": + return await self.app(scope, receive_casted, send_casted) + if scope["path"] != "/" or len(self.script) == 0: + return await self.app(scope, receive_casted, send_casted) def mangle_callback(body: bytes) -> tuple[bytes, bool]: if b"" in body: @@ -114,8 +138,8 @@ def mangle_callback(body: bytes) -> tuple[bytes, bool]: else: return (body, False) - mangler = ResponseMangler(send, mangle_callback) - await self.app(scope, receive, mangler.send) + mangler = ResponseMangler(send_casted, mangle_callback) + await self.app(scope, receive_casted, mangler.send) # PARENT PROCESS ------------------------------------------------------------