diff --git a/starlette/applications.py b/starlette/applications.py index 82cfbdca0..3b83a3f6e 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -67,6 +67,9 @@ def on_event(self, event_type: str) -> typing.Callable: def mount(self, path: str, app: ASGIApp, name: str = None) -> None: self.router.mount(path, app=app, name=name) + def host(self, host: str, app: ASGIApp, name: str = None) -> None: + self.router.host(host, app=app, name=name) + def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None: self.error_middleware.app = middleware_class( self.error_middleware.app, **kwargs diff --git a/starlette/datastructures.py b/starlette/datastructures.py index 26b0486ba..822148e86 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -157,26 +157,35 @@ def replace(self, **kwargs: typing.Any) -> "URL": class URLPath(str): """ - A URL path string that also holds an associated protocol. + A URL path string that may also hold an associated protocol and/or host. Used by the routing to return `url_path_for` matches. """ - def __new__(cls, path: str, protocol: str) -> str: - assert protocol in ("http", "websocket") + def __new__(cls, path: str, protocol: str = "", host: str = "") -> str: + assert protocol in ("http", "websocket", "") return str.__new__(cls, path) # type: ignore - def __init__(self, path: str, protocol: str) -> None: + def __init__(self, path: str, protocol: str = "", host: str = "") -> None: self.protocol = protocol + self.host = host def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str: if isinstance(base_url, str): base_url = URL(base_url) - scheme = { - "http": {True: "https", False: "http"}, - "websocket": {True: "wss", False: "ws"}, - }[self.protocol][base_url.is_secure] - netloc = base_url.netloc - return str(URL(scheme=scheme, netloc=base_url.netloc, path=str(self))) + if self.protocol: + scheme = { + "http": {True: "https", False: "http"}, + "websocket": {True: "wss", False: "ws"}, + }[self.protocol][base_url.is_secure] + else: + scheme = base_url.scheme + + if self.host: + netloc = self.host + else: + netloc = base_url.netloc + + return str(URL(scheme=scheme, netloc=netloc, path=str(self))) class Secret: diff --git a/starlette/routing.py b/starlette/routing.py index 893bbf192..0d499f1d3 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -6,7 +6,7 @@ from starlette.concurrency import run_in_threadpool from starlette.convertors import CONVERTOR_TYPES, Convertor -from starlette.datastructures import URL, URLPath +from starlette.datastructures import URL, Headers, URLPath from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse @@ -84,6 +84,10 @@ def replace_params( return path, path_params +# Match parameters in URL paths, eg. '{param}', and '{param:int}' +PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}") + + def compile_path( path: str ) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]: @@ -124,9 +128,6 @@ def compile_path( return re.compile(path_regex), path_format, param_convertors -PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}") - - class BaseRoute: def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: raise NotImplementedError() # pragma: no cover @@ -316,7 +317,7 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath: self.path_format, self.param_convertors, path_params ) if not remaining_params: - return URLPath(path=path, protocol="http") + return URLPath(path=path) elif self.name is None or name.startswith(self.name + ":"): if self.name is None: # No mount name. @@ -349,6 +350,69 @@ def __eq__(self, other: typing.Any) -> bool: ) +class Host(BaseRoute): + def __init__(self, host: str, app: ASGIApp, name: str = None) -> None: + self.host = host + self.app = app + self.name = name + self.host_regex, self.host_format, self.param_convertors = compile_path(host) + + @property + def routes(self) -> typing.List[BaseRoute]: + return getattr(self.app, "routes", None) + + def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: + headers = Headers(scope=scope) + host = headers.get("host", "").split(":")[0] + match = self.host_regex.match(host) + if match: + matched_params = match.groupdict() + for key, value in matched_params.items(): + matched_params[key] = self.param_convertors[key].convert(value) + path_params = dict(scope.get("path_params", {})) + path_params.update(matched_params) + child_scope = {"path_params": path_params, "endpoint": self.app} + return Match.FULL, child_scope + return Match.NONE, {} + + def url_path_for(self, name: str, **path_params: str) -> URLPath: + if self.name is not None and name == self.name and "path" in path_params: + # 'name' matches "". + path = path_params.pop("path") + host, remaining_params = replace_params( + self.host_format, self.param_convertors, path_params + ) + if not remaining_params: + return URLPath(path=path, host=host) + elif self.name is None or name.startswith(self.name + ":"): + if self.name is None: + # No mount name. + remaining_name = name + else: + # 'name' matches ":". + remaining_name = name[len(self.name) + 1 :] + host, remaining_params = replace_params( + self.host_format, self.param_convertors, path_params + ) + for route in self.routes or []: + try: + url = route.url_path_for(remaining_name, **remaining_params) + return URLPath(path=str(url), protocol=url.protocol, host=host) + except NoMatchFound as exc: + pass + raise NoMatchFound() + + def __call__(self, scope: Scope) -> ASGIInstance: + return self.app(scope) + + def __eq__(self, other: typing.Any) -> bool: + return ( + isinstance(other, Host) + and self.host == other.host + and self.app == other.app + ) + + class Router: def __init__( self, @@ -364,6 +428,10 @@ def mount(self, path: str, app: ASGIApp, name: str = None) -> None: route = Mount(path, app=app, name=name) self.routes.append(route) + def host(self, host: str, app: ASGIApp, name: str = None) -> None: + route = Host(host, app=app, name=name) + self.routes.append(route) + def add_route( self, path: str, diff --git a/tests/test_applications.py b/tests/test_applications.py index 0a0345d13..0c62af93d 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -4,28 +4,16 @@ from starlette.datastructures import Headers from starlette.endpoints import HTTPEndpoint from starlette.exceptions import HTTPException +from starlette.middleware.trustedhost import TrustedHostMiddleware from starlette.responses import JSONResponse, PlainTextResponse -from starlette.routing import Mount, Route, Router, WebSocketRoute +from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles from starlette.testclient import TestClient - -class TrustedHostMiddleware: - def __init__(self, app, hostname): - self.app = app - self.hostname = hostname - - def __call__(self, scope): - headers = Headers(scope=scope) - if headers.get("host") != self.hostname: - return PlainTextResponse("Invalid host header", status_code=400) - return self.app(scope) - - app = Starlette() -app.add_middleware(TrustedHostMiddleware, hostname="testserver") +app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"]) @app.exception_handler(500) @@ -76,6 +64,17 @@ def user_page(request): app.mount("/users", users) +subdomain = Router() + + +@subdomain.route("/") +def custom_subdomain(request): + return PlainTextResponse("Subdomain: " + request.path_params["subdomain"]) + + +app.host("{subdomain}.example.org", subdomain) + + @app.route("/500") def runtime_error(request): raise RuntimeError() @@ -129,6 +128,14 @@ def test_mounted_route_path_params(): assert response.text == "Hello, tomchristie!" +def test_subdomain_route(): + client = TestClient(app, base_url="https://foo.example.org/") + + response = client.get("/") + assert response.status_code == 200 + assert response.text == "Subdomain: foo" + + def test_websocket_route(): with client.websocket_connect("/ws") as session: text = session.receive_text() @@ -179,6 +186,10 @@ def test_routes(): ] ), ), + Host( + "{subdomain}.example.org", + app=Router(routes=[Route("/", endpoint=custom_subdomain)]), + ), Route("/500", endpoint=runtime_error, methods=["GET"]), WebSocketRoute("/ws", endpoint=websocket_endpoint), ] diff --git a/tests/test_routing.py b/tests/test_routing.py index aebb72972..900d561d7 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,7 +1,7 @@ import pytest from starlette.responses import JSONResponse, PlainTextResponse, Response -from starlette.routing import Mount, NoMatchFound, Route, Router, WebSocketRoute +from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute from starlette.testclient import TestClient from starlette.websockets import WebSocket, WebSocketDisconnect @@ -271,3 +271,88 @@ def test_reverse_mount_urls(): assert ( mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom" ) + + +def users_api(request): + return JSONResponse({"users": [{"username": "tom"}]}) + + +mixed_hosts_app = Router( + routes=[ + Host( + "www.example.org", + app=Router( + [ + Route("/", homepage, name="homepage"), + Route("/users", users, name="users"), + ] + ), + ), + Host( + "api.example.org", + name="api", + app=Router([Route("/users", users_api, name="users")]), + ), + ] +) + + +def test_host_routing(): + client = TestClient(mixed_hosts_app, base_url="https://api.example.org/") + + response = client.get("/users") + assert response.status_code == 200 + assert response.json() == {"users": [{"username": "tom"}]} + + response = client.get("/") + assert response.status_code == 404 + + client = TestClient(mixed_hosts_app, base_url="https://www.example.org/") + + response = client.get("/users") + assert response.status_code == 200 + assert response.text == "All users" + + response = client.get("/") + assert response.status_code == 200 + + +def test_host_reverse_urls(): + assert ( + mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever") + == "https://www.example.org/" + ) + assert ( + mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever") + == "https://www.example.org/users" + ) + assert ( + mixed_hosts_app.url_path_for("api:users").make_absolute_url("https://whatever") + == "https://api.example.org/users" + ) + + +def subdomain_app(scope): + return JSONResponse({"subdomain": scope["path_params"]["subdomain"]}) + + +subdomain_app = Router( + routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")] +) + + +def test_subdomain_routing(): + client = TestClient(subdomain_app, base_url="https://foo.example.org/") + + response = client.get("/") + assert response.status_code == 200 + assert response.json() == {"subdomain": "foo"} + + +def test_subdomain_reverse_urls(): + assert ( + subdomain_app.url_path_for( + "subdomains", subdomain="foo", path="/homepage" + ).make_absolute_url("https://whatever") + == "https://foo.example.org/homepage" + )