Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def __init__(
methods: typing.Optional[typing.List[str]] = None,
name: typing.Optional[str] = None,
include_in_schema: bool = True,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
Expand All @@ -236,6 +237,10 @@ def __init__(
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint

if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)

if methods is None:
self.methods = None
else:
Expand Down Expand Up @@ -309,6 +314,7 @@ def __init__(
endpoint: typing.Callable[..., typing.Any],
*,
name: typing.Optional[str] = None,
middleware: typing.Optional[typing.Sequence[Middleware]] = None,
) -> None:
assert path.startswith("/"), "Routed paths must start with '/'"
self.path = path
Expand All @@ -325,6 +331,10 @@ def __init__(
# Endpoint is a class. Treat it as ASGI.
self.app = endpoint

if middleware is not None:
for cls, options in reversed(middleware):
self.app = cls(app=self.app, **options)

self.path_regex, self.path_format, self.param_convertors = compile_path(path)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
Expand Down
2 changes: 1 addition & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ def delete( # type: ignore[override]

def websocket_connect(
self, url: str, subprotocols: typing.Sequence[str] = None, **kwargs: typing.Any
) -> typing.Any:
) -> "WebSocketTestSession":
url = urljoin("ws://testserver", url)
headers = kwargs.get("headers", {})
headers.setdefault("connection", "upgrade")
Expand Down
53 changes: 52 additions & 1 deletion tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,18 @@ def assert_middleware_header_route(request: Request) -> Response:
return Response()


route_with_middleware = Starlette(
routes=[
Route(
"/http",
endpoint=assert_middleware_header_route,
methods=["GET"],
middleware=[Middleware(AddHeadersMiddleware)],
),
Route("/home", homepage),
]
)

mounted_routes_with_middleware = Starlette(
routes=[
Mount(
Expand Down Expand Up @@ -960,9 +972,10 @@ def assert_middleware_header_route(request: Request) -> Response:
[
mounted_routes_with_middleware,
mounted_app_with_middleware,
route_with_middleware,
],
)
def test_mount_middleware(
def test_base_route_middleware(
test_client_factory: typing.Callable[..., TestClient],
app: Starlette,
) -> None:
Expand Down Expand Up @@ -1076,6 +1089,44 @@ async def modified_send(msg: Message) -> None:
assert "X-Mounted" in resp.headers


def test_websocket_route_middleware(
test_client_factory: typing.Callable[..., TestClient]
):
async def websocket_endpoint(session: WebSocket):
await session.accept()
await session.send_text("Hello, world!")
await session.close()

class WebsocketMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
async def modified_send(msg: Message) -> None:
if msg["type"] == "websocket.accept":
msg["headers"].append((b"X-Test", b"Set by middleware"))
await send(msg)

await self.app(scope, receive, modified_send)

app = Starlette(
routes=[
WebSocketRoute(
"/ws",
endpoint=websocket_endpoint,
middleware=[Middleware(WebsocketMiddleware)],
)
]
)

client = test_client_factory(app)

with client.websocket_connect("/ws") as websocket:
text = websocket.receive_text()
assert text == "Hello, world!"
assert websocket.extra_headers == [(b"X-Test", b"Set by middleware")]


def test_route_repr() -> None:
route = Route("/welcome", endpoint=homepage)
assert (
Expand Down