Skip to content

Commit

Permalink
Do not overwrite "path" and "root_path" scope keys (#2352)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Dec 1, 2023
1 parent 164b350 commit e8f0dcd
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 15 deletions.
4 changes: 1 addition & 3 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,7 @@ def base_url(self) -> URL:
base_url_scope = dict(self.scope)
base_url_scope["path"] = "/"
base_url_scope["query_string"] = b""
base_url_scope["root_path"] = base_url_scope.get(
"app_root_path", base_url_scope.get("root_path", "")
)
base_url_scope["root_path"] = base_url_scope.get("root_path", "")
self._base_url = URL(scope=base_url_scope)
return self._base_url

Expand Down
30 changes: 21 additions & 9 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,11 @@ def __init__(
self.path_regex, self.path_format, self.param_convertors = compile_path(path)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] == "http":
match = self.path_regex.match(scope["path"])
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
match = self.path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
Expand Down Expand Up @@ -338,8 +341,11 @@ def __init__(
self.path_regex, self.path_format, self.param_convertors = compile_path(path)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] == "websocket":
match = self.path_regex.match(scope["path"])
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
match = self.path_regex.match(path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
Expand Down Expand Up @@ -410,23 +416,25 @@ def routes(self) -> typing.List[BaseRoute]:
return getattr(self._base_app, "routes", [])

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
path_params: "typing.Dict[str, typing.Any]"
if scope["type"] in ("http", "websocket"):
path = scope["path"]
match = self.path_regex.match(path)
root_path = scope.get("route_root_path", scope.get("root_path", ""))
route_path = scope.get("route_path", re.sub(r"^" + root_path, "", path))
match = self.path_regex.match(route_path)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
remaining_path = "/" + matched_params.pop("path")
matched_path = path[: -len(remaining_path)]
matched_path = route_path[: -len(remaining_path)]
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
root_path = scope.get("root_path", "")
child_scope = {
"path_params": path_params,
"app_root_path": scope.get("app_root_path", root_path),
"root_path": root_path + matched_path,
"path": remaining_path,
"route_root_path": root_path + matched_path,
"route_path": remaining_path,
"endpoint": self.app,
}
return Match.FULL, child_scope
Expand Down Expand Up @@ -767,11 +775,15 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await partial.handle(scope, receive, send)
return

if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/":
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
if scope["type"] == "http" and self.redirect_slashes and path != "/":
redirect_scope = dict(scope)
if scope["path"].endswith("/"):
if path.endswith("/"):
redirect_scope["route_path"] = path.rstrip("/")
redirect_scope["path"] = redirect_scope["path"].rstrip("/")
else:
redirect_scope["route_path"] = path + "/"
redirect_scope["path"] = redirect_scope["path"] + "/"

for route in self.routes:
Expand Down
5 changes: 4 additions & 1 deletion starlette/staticfiles.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib.util
import os
import re
import stat
import typing
from email.utils import parsedate
Expand Down Expand Up @@ -108,7 +109,9 @@ def get_path(self, scope: Scope) -> str:
Given the ASGI scope, return the `path` string to serve up,
with OS specific path separators, and any '..', '.' components removed.
"""
return os.path.normpath(os.path.join(*scope["path"].split("/"))) # type: ignore[no-any-return] # noqa: E501
root_path = scope.get("route_root_path", scope.get("root_path", ""))
path = scope.get("route_path", re.sub(r"^" + root_path, "", scope["path"]))
return os.path.normpath(os.path.join(*path.split("/"))) # type: ignore[no-any-return] # noqa: E501

async def get_response(self, path: str, scope: Scope) -> Response:
"""
Expand Down
58 changes: 56 additions & 2 deletions tests/test_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def websocket_params(session: WebSocket):


@pytest.fixture
def client(test_client_factory):
def client(test_client_factory: typing.Callable[..., TestClient]):
with test_client_factory(app) as client:
yield client

Expand All @@ -170,7 +170,7 @@ def client(test_client_factory):
r":UserWarning"
r":charset_normalizer.api"
)
def test_router(client):
def test_router(client: TestClient):
response = client.get("/")
assert response.status_code == 200
assert response.text == "Hello, world"
Expand Down Expand Up @@ -1210,3 +1210,57 @@ async def startup() -> None:
... # pragma: nocover

router.on_event("startup")(startup)


async def echo_paths(request: Request, name: str):
return JSONResponse(
{
"name": name,
"path": request.scope["path"],
"root_path": request.scope["root_path"],
}
)


echo_paths_routes = [
Route(
"/path",
functools.partial(echo_paths, name="path"),
name="path",
methods=["GET"],
),
Mount(
"/root",
name="mount",
routes=[
Route(
"/path",
functools.partial(echo_paths, name="subpath"),
name="subpath",
methods=["GET"],
)
],
),
]


def test_paths_with_root_path(test_client_factory: typing.Callable[..., TestClient]):
app = Starlette(routes=echo_paths_routes)
client = test_client_factory(
app, base_url="https://www.example.org/", root_path="/root"
)
response = client.get("/root/path")
assert response.status_code == 200
assert response.json() == {
"name": "path",
"path": "/root/path",
"root_path": "/root",
}

response = client.get("/root/root/path")
assert response.status_code == 200
assert response.json() == {
"name": "subpath",
"path": "/root/root/path",
"root_path": "/root",
}

0 comments on commit e8f0dcd

Please sign in to comment.