Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ venv*/
.python-version
build/
dist/
.idea
/.venv/
8 changes: 6 additions & 2 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@


class Starlette:
"""Creates an Starlette application."""
"""Creates a Starlette application."""

# This is the default router class used by Starlette. if you want to customized router have to
# override this variable in your subclass.
router_class = Router

def __init__(
self: AppType,
Expand Down Expand Up @@ -71,7 +75,7 @@ def __init__(

self.debug = debug
self.state = State()
self.router = Router(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
self.router = self.router_class(routes, on_startup=on_startup, on_shutdown=on_shutdown, lifespan=lifespan)
self.exception_handlers = {} if exception_handlers is None else dict(exception_handlers)
self.user_middleware = [] if middleware is None else list(middleware)
self.middleware_stack: ASGIApp | None = None
Expand Down
9 changes: 7 additions & 2 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,11 @@ def __call__(self: _T, app: object) -> _T:


class Router:
# The default route and websocket route classes. if you want to use customized route classes
# you have to override this class variables in your subclass.
route_class = Route
websocket_route_class = WebSocketRoute

def __init__(
self,
routes: typing.Sequence[BaseRoute] | None = None,
Expand Down Expand Up @@ -782,7 +787,7 @@ def add_route(
name: str | None = None,
include_in_schema: bool = True,
) -> None: # pragma: no cover
route = Route(
route = self.route_class(
path,
endpoint=endpoint,
methods=methods,
Expand All @@ -797,7 +802,7 @@ def add_websocket_route(
endpoint: typing.Callable[[WebSocket], typing.Awaitable[None]],
name: str | None = None,
) -> None: # pragma: no cover
route = WebSocketRoute(path, endpoint=endpoint, name=name)
route = self.websocket_route_class(path, endpoint=endpoint, name=name)
self.routes.append(route)

def route(
Expand Down