diff --git a/.gitignore b/.gitignore index bff8fa258..eab648d1b 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,5 @@ venv*/ .python-version build/ dist/ +.idea +/.venv/ diff --git a/starlette/applications.py b/starlette/applications.py index 6df5a707c..620b67490 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -25,7 +25,11 @@ class Starlette: - """Creates an Starlette application.""" + """Creates a Starlette application.""" + + # This is the default router class used by Starlette. if you want to customized router have to + # override this variable in your subclass. + router_class = Router def __init__( self: AppType, @@ -71,7 +75,7 @@ def __init__( self.debug = debug self.state = State() - self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan) + self.router = self.router_class(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan) self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers) self.user_middleware = [] if middleware is None else list(middleware) self.middleware_stack: ASGIApp | None = None diff --git a/starlette/routing.py b/starlette/routing.py index add7df0c2..965af0af0 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -576,6 +576,11 @@ def __call__(self: _T, app: object) -> _T: class Router: + # The default route and websocket route classes. if you want to use customized route classes + # you have to override this class variables in your subclass. + route_class = Route + websocket_route_class = WebSocketRoute + def __init__( self, routes: typing.Sequence[BaseRoute] | None = None, @@ -782,7 +787,7 @@ def add_route( name: str | None = None, include_in_schema: bool = True, ) -> None: # pragma: no cover - route = Route( + route = self.route_class( path, endpoint=endpoint, methods=methods, @@ -797,7 +802,7 @@ def add_websocket_route( endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]], name: str | None = None, ) -> None: # pragma: no cover - route = WebSocketRoute(path, endpoint=endpoint, name=name) + route = self.websocket_route_class(path, endpoint=endpoint, name=name) self.routes.append(route) def route(