Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions starlette/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ def on_event(self, event_type: str) -> typing.Callable:
def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
self.router.mount(path, app=app, name=name)

def host(self, host: str, app: ASGIApp, name: str = None) -> None:
self.router.host(host, app=app, name=name)

def add_middleware(self, middleware_class: type, **kwargs: typing.Any) -> None:
self.error_middleware.app = middleware_class(
self.error_middleware.app, **kwargs
Expand Down
29 changes: 19 additions & 10 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,26 +157,35 @@ def replace(self, **kwargs: typing.Any) -> "URL":

class URLPath(str):
"""
A URL path string that also holds an associated protocol.
A URL path string that may also hold an associated protocol and/or host.
Used by the routing to return `url_path_for` matches.
"""

def __new__(cls, path: str, protocol: str) -> str:
assert protocol in ("http", "websocket")
def __new__(cls, path: str, protocol: str = "", host: str = "") -> str:
assert protocol in ("http", "websocket", "")
return str.__new__(cls, path) # type: ignore

def __init__(self, path: str, protocol: str) -> None:
def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
self.protocol = protocol
self.host = host

def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str:
if isinstance(base_url, str):
base_url = URL(base_url)
scheme = {
"http": {True: "https", False: "http"},
"websocket": {True: "wss", False: "ws"},
}[self.protocol][base_url.is_secure]
netloc = base_url.netloc
return str(URL(scheme=scheme, netloc=base_url.netloc, path=str(self)))
if self.protocol:
scheme = {
"http": {True: "https", False: "http"},
"websocket": {True: "wss", False: "ws"},
}[self.protocol][base_url.is_secure]
else:
scheme = base_url.scheme

if self.host:
netloc = self.host
else:
netloc = base_url.netloc

return str(URL(scheme=scheme, netloc=netloc, path=str(self)))


class Secret:
Expand Down
78 changes: 73 additions & 5 deletions starlette/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from starlette.concurrency import run_in_threadpool
from starlette.convertors import CONVERTOR_TYPES, Convertor
from starlette.datastructures import URL, URLPath
from starlette.datastructures import URL, Headers, URLPath
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, RedirectResponse
Expand Down Expand Up @@ -84,6 +84,10 @@ def replace_params(
return path, path_params


# Match parameters in URL paths, eg. '{param}', and '{param:int}'
PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")


def compile_path(
path: str
) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
Expand Down Expand Up @@ -124,9 +128,6 @@ def compile_path(
return re.compile(path_regex), path_format, param_convertors


PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")


class BaseRoute:
def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
raise NotImplementedError() # pragma: no cover
Expand Down Expand Up @@ -316,7 +317,7 @@ def url_path_for(self, name: str, **path_params: str) -> URLPath:
self.path_format, self.param_convertors, path_params
)
if not remaining_params:
return URLPath(path=path, protocol="http")
return URLPath(path=path)
elif self.name is None or name.startswith(self.name + ":"):
if self.name is None:
# No mount name.
Expand Down Expand Up @@ -349,6 +350,69 @@ def __eq__(self, other: typing.Any) -> bool:
)


class Host(BaseRoute):
def __init__(self, host: str, app: ASGIApp, name: str = None) -> None:
self.host = host
self.app = app
self.name = name
self.host_regex, self.host_format, self.param_convertors = compile_path(host)

@property
def routes(self) -> typing.List[BaseRoute]:
return getattr(self.app, "routes", None)

def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
headers = Headers(scope=scope)
host = headers.get("host", "").split(":")[0]
match = self.host_regex.match(host)
if match:
matched_params = match.groupdict()
for key, value in matched_params.items():
matched_params[key] = self.param_convertors[key].convert(value)
path_params = dict(scope.get("path_params", {}))
path_params.update(matched_params)
child_scope = {"path_params": path_params, "endpoint": self.app}
return Match.FULL, child_scope
return Match.NONE, {}

def url_path_for(self, name: str, **path_params: str) -> URLPath:
if self.name is not None and name == self.name and "path" in path_params:
# 'name' matches "<mount_name>".
path = path_params.pop("path")
host, remaining_params = replace_params(
self.host_format, self.param_convertors, path_params
)
if not remaining_params:
return URLPath(path=path, host=host)
elif self.name is None or name.startswith(self.name + ":"):
if self.name is None:
# No mount name.
remaining_name = name
else:
# 'name' matches "<mount_name>:<child_name>".
remaining_name = name[len(self.name) + 1 :]
host, remaining_params = replace_params(
self.host_format, self.param_convertors, path_params
)
for route in self.routes or []:
try:
url = route.url_path_for(remaining_name, **remaining_params)
return URLPath(path=str(url), protocol=url.protocol, host=host)
except NoMatchFound as exc:
pass
raise NoMatchFound()

def __call__(self, scope: Scope) -> ASGIInstance:
return self.app(scope)

def __eq__(self, other: typing.Any) -> bool:
return (
isinstance(other, Host)
and self.host == other.host
and self.app == other.app
)


class Router:
def __init__(
self,
Expand All @@ -364,6 +428,10 @@ def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
route = Mount(path, app=app, name=name)
self.routes.append(route)

def host(self, host: str, app: ASGIApp, name: str = None) -> None:
route = Host(host, app=app, name=name)
self.routes.append(route)

def add_route(
self,
path: str,
Expand Down
41 changes: 26 additions & 15 deletions tests/test_applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,16 @@
from starlette.datastructures import Headers
from starlette.endpoints import HTTPEndpoint
from starlette.exceptions import HTTPException
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.responses import JSONResponse, PlainTextResponse
from starlette.routing import Mount, Route, Router, WebSocketRoute
from starlette.routing import Host, Mount, Route, Router, WebSocketRoute
from starlette.staticfiles import StaticFiles
from starlette.testclient import TestClient


class TrustedHostMiddleware:
def __init__(self, app, hostname):
self.app = app
self.hostname = hostname

def __call__(self, scope):
headers = Headers(scope=scope)
if headers.get("host") != self.hostname:
return PlainTextResponse("Invalid host header", status_code=400)
return self.app(scope)


app = Starlette()


app.add_middleware(TrustedHostMiddleware, hostname="testserver")
app.add_middleware(TrustedHostMiddleware, allowed_hosts=["testserver", "*.example.org"])


@app.exception_handler(500)
Expand Down Expand Up @@ -76,6 +64,17 @@ def user_page(request):
app.mount("/users", users)


subdomain = Router()


@subdomain.route("/")
def custom_subdomain(request):
return PlainTextResponse("Subdomain: " + request.path_params["subdomain"])


app.host("{subdomain}.example.org", subdomain)


@app.route("/500")
def runtime_error(request):
raise RuntimeError()
Expand Down Expand Up @@ -129,6 +128,14 @@ def test_mounted_route_path_params():
assert response.text == "Hello, tomchristie!"


def test_subdomain_route():
client = TestClient(app, base_url="https://foo.example.org/")

response = client.get("/")
assert response.status_code == 200
assert response.text == "Subdomain: foo"


def test_websocket_route():
with client.websocket_connect("/ws") as session:
text = session.receive_text()
Expand Down Expand Up @@ -179,6 +186,10 @@ def test_routes():
]
),
),
Host(
"{subdomain}.example.org",
app=Router(routes=[Route("/", endpoint=custom_subdomain)]),
),
Route("/500", endpoint=runtime_error, methods=["GET"]),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
]
Expand Down
87 changes: 86 additions & 1 deletion tests/test_routing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from starlette.responses import JSONResponse, PlainTextResponse, Response
from starlette.routing import Mount, NoMatchFound, Route, Router, WebSocketRoute
from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute
from starlette.testclient import TestClient
from starlette.websockets import WebSocket, WebSocketDisconnect

Expand Down Expand Up @@ -271,3 +271,88 @@ def test_reverse_mount_urls():
assert (
mounted.url_path_for("users", subpath="test", path="/tom") == "/test/users/tom"
)


def users_api(request):
return JSONResponse({"users": [{"username": "tom"}]})


mixed_hosts_app = Router(
routes=[
Host(
"www.example.org",
app=Router(
[
Route("/", homepage, name="homepage"),
Route("/users", users, name="users"),
]
),
),
Host(
"api.example.org",
name="api",
app=Router([Route("/users", users_api, name="users")]),
),
]
)


def test_host_routing():
client = TestClient(mixed_hosts_app, base_url="https://api.example.org/")

response = client.get("/users")
assert response.status_code == 200
assert response.json() == {"users": [{"username": "tom"}]}

response = client.get("/")
assert response.status_code == 404

client = TestClient(mixed_hosts_app, base_url="https://www.example.org/")

response = client.get("/users")
assert response.status_code == 200
assert response.text == "All users"

response = client.get("/")
assert response.status_code == 200


def test_host_reverse_urls():
assert (
mixed_hosts_app.url_path_for("homepage").make_absolute_url("https://whatever")
== "https://www.example.org/"
)
assert (
mixed_hosts_app.url_path_for("users").make_absolute_url("https://whatever")
== "https://www.example.org/users"
)
assert (
mixed_hosts_app.url_path_for("api:users").make_absolute_url("https://whatever")
== "https://api.example.org/users"
)


def subdomain_app(scope):
return JSONResponse({"subdomain": scope["path_params"]["subdomain"]})


subdomain_app = Router(
routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")]
)


def test_subdomain_routing():
client = TestClient(subdomain_app, base_url="https://foo.example.org/")

response = client.get("/")
assert response.status_code == 200
assert response.json() == {"subdomain": "foo"}


def test_subdomain_reverse_urls():
assert (
subdomain_app.url_path_for(
"subdomains", subdomain="foo", path="/homepage"
).make_absolute_url("https://whatever")
== "https://foo.example.org/homepage"
)