diff --git a/starlette/requests.py b/starlette/requests.py index 5c7a4296c..83a52aca1 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -101,9 +101,7 @@ def base_url(self) -> URL: base_url_scope = dict(self.scope) base_url_scope["path"] = "/" base_url_scope["query_string"] = b"" - base_url_scope["root_path"] = base_url_scope.get( - "app_root_path", base_url_scope.get("root_path", "") - ) + base_url_scope["root_path"] = base_url_scope.get("root_path", "") self._base_url = URL(scope=base_url_scope) return self._base_url diff --git a/starlette/routing.py b/starlette/routing.py index cb1b1d919..b167f3513 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -251,8 +251,11 @@ def __init__( self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: + path_params: "typing.Dict[str, typing.Any]" if scope["type"] == "http": - match = self.path_regex.match(scope["path"]) + root_path = scope.get("route_root_path", scope.get("root_path", "")) + path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"])) + match = self.path_regex.match(path) if match: matched_params = match.groupdict() for key, value in matched_params.items(): @@ -338,8 +341,11 @@ def __init__( self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: + path_params: "typing.Dict[str, typing.Any]" if scope["type"] == "websocket": - match = self.path_regex.match(scope["path"]) + root_path = scope.get("route_root_path", scope.get("root_path", "")) + path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"])) + match = self.path_regex.match(path) if match: matched_params = match.groupdict() for key, value in matched_params.items(): @@ -410,23 +416,25 @@ def routes(self) -> typing.List[BaseRoute]: return getattr(self._base_app, "routes", []) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: + path_params: "typing.Dict[str, typing.Any]" if scope["type"] in ("http", "websocket"): path = scope["path"] - match = self.path_regex.match(path) + root_path = scope.get("route_root_path", scope.get("root_path", "")) + route_path = scope.get("route_path", re.sub(r"^" + root_path, "", path)) + match = self.path_regex.match(route_path) if match: matched_params = match.groupdict() for key, value in matched_params.items(): matched_params[key] = self.param_convertors[key].convert(value) remaining_path = "/" + matched_params.pop("path") - matched_path = path[: -len(remaining_path)] + matched_path = route_path[: -len(remaining_path)] path_params = dict(scope.get("path_params", {})) path_params.update(matched_params) root_path = scope.get("root_path", "") child_scope = { "path_params": path_params, - "app_root_path": scope.get("app_root_path", root_path), - "root_path": root_path + matched_path, - "path": remaining_path, + "route_root_path": root_path + matched_path, + "route_path": remaining_path, "endpoint": self.app, } return Match.FULL, child_scope @@ -767,11 +775,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await partial.handle(scope, receive, send) return - if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/": + root_path = scope.get("route_root_path", scope.get("root_path", "")) + path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"])) + if scope["type"] == "http" and self.redirect_slashes and path != "/": redirect_scope = dict(scope) - if scope["path"].endswith("/"): + if path.endswith("/"): + redirect_scope["route_path"] = path.rstrip("/") redirect_scope["path"] = redirect_scope["path"].rstrip("/") else: + redirect_scope["route_path"] = path + "/" redirect_scope["path"] = redirect_scope["path"] + "/" for route in self.routes: diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 2f1f1ddab..f36d58664 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -1,5 +1,6 @@ import importlib.util import os +import re import stat import typing from email.utils import parsedate @@ -108,7 +109,9 @@ def get_path(self, scope: Scope) -> str: Given the ASGI scope, return the `path` string to serve up, with OS specific path separators, and any '..', '.' components removed. """ - return os.path.normpath(os.path.join(*scope["path"].split("/"))) # type: ignore[no-any-return] # noqa: E501 + root_path = scope.get("route_root_path", scope.get("root_path", "")) + path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"])) + return os.path.normpath(os.path.join(*path.split("/"))) # type: ignore[no-any-return] # noqa: E501 async def get_response(self, path: str, scope: Scope) -> Response: """ diff --git a/tests/test_routing.py b/tests/test_routing.py index 7644f5a24..1e99e335b 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -159,7 +159,7 @@ async def websocket_params(session: WebSocket): @pytest.fixture -def client(test_client_factory): +def client(test_client_factory: typing.Callable[..., TestClient]): with test_client_factory(app) as client: yield client @@ -170,7 +170,7 @@ def client(test_client_factory): r":UserWarning" r":charset_normalizer.api" ) -def test_router(client): +def test_router(client: TestClient): response = client.get("/") assert response.status_code == 200 assert response.text == "Hello, world" @@ -1210,3 +1210,57 @@ async def startup() -> None: ... # pragma: nocover router.on_event("startup")(startup) + + +async def echo_paths(request: Request, name: str): + return JSONResponse( + { + "name": name, + "path": request.scope["path"], + "root_path": request.scope["root_path"], + } + ) + + +echo_paths_routes = [ + Route( + "/path", + functools.partial(echo_paths, name="path"), + name="path", + methods=["GET"], + ), + Mount( + "/root", + name="mount", + routes=[ + Route( + "/path", + functools.partial(echo_paths, name="subpath"), + name="subpath", + methods=["GET"], + ) + ], + ), +] + + +def test_paths_with_root_path(test_client_factory: typing.Callable[..., TestClient]): + app = Starlette(routes=echo_paths_routes) + client = test_client_factory( + app, base_url="https://www.example.org/", root_path="/root" + ) + response = client.get("/root/path") + assert response.status_code == 200 + assert response.json() == { + "name": "path", + "path": "/root/path", + "root_path": "/root", + } + + response = client.get("/root/root/path") + assert response.status_code == 200 + assert response.json() == { + "name": "subpath", + "path": "/root/root/path", + "root_path": "/root", + }