diff --git a/starlette/applications.py b/starlette/applications.py index c95fa0a5c..b754bb967 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -8,6 +8,12 @@ from starlette.routing import BaseRoute, Router from starlette.types import ASGIApp, Receive, Scope, Send +_ExcKey = typing.TypeVar( + "_ExcKey", + bound=typing.Union[int, typing.Type[Exception]], + contravariant=True, +) + class Starlette: """ @@ -23,11 +29,10 @@ class Starlette: any uncaught errors occurring anywhere in the entire stack. `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled exception cases occurring in the routing or endpoints. - * **exception_handlers** - A dictionary mapping either integer status codes, - or exception class types onto callables which handle the exceptions. - Exception handler callables should be of the form - `handler(request, exc) -> response` and may be be either standard functions, or - async functions. + * **exception_handlers** - A mapping of either integer status codes, or exception + class types onto callables which handle the exceptions. Exception handler + callables should be of the form `handler(request, exc) -> response` and may be + either standard functions, or async functions. * **on_startup** - A list of callables to run on application startup. Startup handler callables do not take any arguments, and may be be either standard functions, or async functions. @@ -41,9 +46,7 @@ def __init__( debug: bool = False, routes: typing.Sequence[BaseRoute] = None, middleware: typing.Sequence[Middleware] = None, - exception_handlers: typing.Dict[ - typing.Union[int, typing.Type[Exception]], typing.Callable - ] = None, + exception_handlers: typing.Mapping[_ExcKey, typing.Callable] = None, on_startup: typing.Sequence[typing.Callable] = None, on_shutdown: typing.Sequence[typing.Callable] = None, lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None, @@ -59,7 +62,7 @@ def __init__( self.router = Router( routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan ) - self.exception_handlers = ( + self.exception_handlers: typing.Dict[_ExcKey, typing.Callable] = ( {} if exception_handlers is None else dict(exception_handlers) ) self.user_middleware = [] if middleware is None else list(middleware) @@ -128,7 +131,7 @@ def add_middleware(self, middleware_class: type, **options: typing.Any) -> None: def add_exception_handler( self, - exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], + exc_class_or_status_code: _ExcKey, handler: typing.Callable, ) -> None: self.exception_handlers[exc_class_or_status_code] = handler @@ -154,9 +157,7 @@ def add_websocket_route( ) -> None: self.router.add_websocket_route(path, route, name=name) - def exception_handler( - self, exc_class_or_status_code: typing.Union[int, typing.Type[Exception]] - ) -> typing.Callable: + def exception_handler(self, exc_class_or_status_code: _ExcKey) -> typing.Callable: def decorator(func: typing.Callable) -> typing.Callable: self.add_exception_handler(exc_class_or_status_code, func) return func