-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Add Mount(..., middleware=[...]) #1649
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 27 commits
42548cf
01bbcdc
d8b626d
73dc39e
14d3005
8eb2699
baab334
bbca389
a75a523
578f618
ba59a35
3cadfb2
1eace2b
ca50340
f65dfc8
5edb100
1ef66e6
fb93ef5
273cc73
4369ee7
10f47ae
feeba5e
0bb54e4
d7c3f2a
e6fad81
adb52ab
0f7a2c4
f6de20f
72fbaa6
eee6a6f
aec580f
93ec37f
5f936e9
0079756
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -683,6 +683,41 @@ to use the `middleware=<List of Middleware instances>` style, as it will: | |
| * Ensure that everything remains wrapped in a single outermost `ServerErrorMiddleware`. | ||
| * Preserves the top-level `app` instance. | ||
|
|
||
| ## Applying middleware to `Mount`s | ||
|
|
||
| Middleware can also be added to `Mount`, which allows you to apply middleware to a single route, a group of routes or any mounted ASGI application: | ||
|
|
||
| ```python | ||
| from starlette.applications import Starlette | ||
| from starlette.middleware import Middleware | ||
| from starlette.middleware.gzip import GZipMiddleware | ||
| from starlette.routing import Mount, Route | ||
|
|
||
|
|
||
| routes = [ | ||
| Mount( | ||
| "/", | ||
| routes=[ | ||
| Route( | ||
| "/example", | ||
| endpoint=..., | ||
| ) | ||
| ], | ||
| middleware=[Middleware(GZipMiddleware)] | ||
| ) | ||
| ] | ||
|
|
||
| app = Starlette(routes=routes) | ||
| ``` | ||
|
|
||
| Note that middleware used in this way is *not* wrapped in exception handling middleware like the middleware applied to the `Starlette` application is. | ||
| This is often not a problem because it only applies to middleware that inspect or modify the `Response`, and even then you probably don't want to apply this logic to error responses. | ||
| If you do want to apply the middelware logic to error responses only on some routes you have a couple of options: | ||
|
|
||
| * Add an `ExceptionMiddleware` onto the `Mount` | ||
| * Add a `try/except` block to your middleware and return an error response from there | ||
| * Split up marking and processing into two middlewares, one that gets put on `Mount` which simply marks the response as needing processing (for example by setting `scope["log-response"] = True`) and another applied to the `Starlette` application that does the heavy lifting. | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm... This works... 🤔 from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import Response
from starlette.routing import Mount, Route
async def home(request):
return Response("Hi there!")
async def exception_handler(request, exc):
return Response("I have a 400 for you!", status_code=400)
class PotatoException(Exception):
...
class CustomMiddleware:
def __init__(self, app):
self.app = app
async def __call__(self, scope, receive, send):
raise PotatoException()
routes = [
Mount(
"/",
routes=[
Route(
"/example",
endpoint=home,
)
],
middleware=[Middleware(CustomMiddleware)],
),
]
app = Starlette(routes=routes, exception_handlers={PotatoException: exception_handler})
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Try this example: from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import Response
from starlette.routing import Mount, Route
class PotatoException(Exception):
...
async def bad_endpoint(request):
raise PotatoException
async def exception_handler(request, exc):
return Response("I have a 400 for you!", status_code=400)
class CustomMiddleware:
def __init__(self, app, name: str):
self.app = app
self.name = name
async def __call__(self, scope, receive, send):
async def wrapped_send(msg):
if msg["type"] == "http.response":
print(f"{self.name} called for {scope['raw_path']}")
await send(msg)
await self.app(scope, receive, wrapped_send)
routes = [
Mount(
"/mount",
routes=[
Route(
"/bad",
endpoint=bad_endpoint,
)
],
middleware=[Middleware(CustomMiddleware, name="on_mount")],
),
Route(
"/good",
endpoint=bad_endpoint,
)
]
app = Starlette(
routes=routes,
exception_handlers={PotatoException: exception_handler},
middleware=[Middleware(CustomMiddleware, name="on_app")],
)The "on_mount" middleware will never be called because PotatoException tears through
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, got it. But that's not what I got from what's written here... Do you think an image with middleware/app/mount as "blocks" would be helpful understanding this? 🤔 (I can help if you think it makes sense, but I lack design skills to do it 😎 👍)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That sounds like a good idea, please give it a shot! I won't have time until next week (and also don't have the design skills 😆). Maybe do it with ascii text instead of an image since it will be easier to embed?
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not a blocker. 👍 (Please do not resolve this conversation so ppl see it)
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that's right. Adrian's example was missing something (I had to add it myself before, and I forgot)... from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import Response
from starlette.routing import Mount, Route
class PotatoException(Exception):
...
async def bad_endpoint(request):
raise PotatoException
async def exception_handler(request, exc):
return Response("I have a 400 for you!", status_code=400)
class CustomMiddleware:
def __init__(self, app, name: str):
self.app = app
self.name = name
async def __call__(self, scope, receive, send):
async def wrapped_send(msg):
if msg["type"] == "http.response.start": # IT WAS MISSING `.start`
print(f"{self.name} called for {scope['raw_path']}")
await send(msg)
await self.app(scope, receive, wrapped_send)
routes = [
Mount(
"/mount",
routes=[
Route(
"/bad",
endpoint=bad_endpoint,
)
],
middleware=[Middleware(CustomMiddleware, name="on_mount")],
),
Route(
"/good",
endpoint=bad_endpoint,
)
]
app = Starlette(
routes=routes,
exception_handlers={PotatoException: exception_handler},
middleware=[Middleware(CustomMiddleware, name="on_app")],
)
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can use @florimondmanca 's drawing here for the explanation 👍
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry for the broken example 😫, but yes you got it right Florimond!
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added a test in aec580f
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @florimondmanca I'm realizing now that it was not clear if this was completely resolved, sorry if I missed it. Do you think we should tweak the docs more, maybe adding something along the lines of your explanation in #1649 (comment)?
Kludex marked this conversation as resolved.
Outdated
|
||
|
|
||
| ## Third party middleware | ||
|
|
||
| #### [asgi-auth-github](https://github.com/simonw/asgi-auth-github) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,8 +5,20 @@ | |
| import pytest | ||
|
|
||
| from starlette.applications import Starlette | ||
| from starlette.middleware import Middleware | ||
| from starlette.requests import Request | ||
| from starlette.responses import JSONResponse, PlainTextResponse, Response | ||
| from starlette.routing import Host, Mount, NoMatchFound, Route, Router, WebSocketRoute | ||
| from starlette.routing import ( | ||
| BaseRoute, | ||
| Host, | ||
| Mount, | ||
| NoMatchFound, | ||
| Route, | ||
| Router, | ||
| WebSocketRoute, | ||
| ) | ||
| from starlette.testclient import TestClient | ||
| from starlette.types import ASGIApp, Message, Receive, Scope, Send | ||
| from starlette.websockets import WebSocket, WebSocketDisconnect | ||
|
|
||
|
|
||
|
|
@@ -768,6 +780,101 @@ def test_route_name(endpoint: typing.Callable, expected_name: str): | |
| assert Route(path="/", endpoint=endpoint).name == expected_name | ||
|
|
||
|
|
||
| class AddHeadersMiddleware: | ||
| def __init__(self, app: ASGIApp) -> None: | ||
| self.app = app | ||
|
|
||
| async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
| scope["add_headers_middleware"] = True | ||
|
|
||
| async def modified_send(msg: Message) -> None: | ||
| if msg["type"] == "http.response.start": | ||
| msg["headers"].append((b"X-Test", b"Set by middleware")) | ||
| await send(msg) | ||
|
|
||
| await self.app(scope, receive, modified_send) | ||
|
|
||
|
|
||
| def assert_middleware_header_route(request: Request) -> Response: | ||
| assert request.scope["add_headers_middleware"] is True | ||
| return Response() | ||
|
|
||
|
|
||
| mounted_routes_with_middleware = Mount( | ||
| "/http", | ||
| routes=[ | ||
| Route( | ||
| "/", | ||
| endpoint=assert_middleware_header_route, | ||
| methods=["GET"], | ||
| name="route", | ||
| ), | ||
| ], | ||
| middleware=[Middleware(AddHeadersMiddleware)], | ||
| ) | ||
|
|
||
|
|
||
| mounted_app_with_middleware = Mount( | ||
| "/http", | ||
| app=Route( | ||
| "/", | ||
| endpoint=assert_middleware_header_route, | ||
| methods=["GET"], | ||
| name="route", | ||
| ), | ||
| middleware=[Middleware(AddHeadersMiddleware)], | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "route", | ||
| [ | ||
| mounted_routes_with_middleware, | ||
| mounted_routes_with_middleware, | ||
|
Kludex marked this conversation as resolved.
|
||
| mounted_app_with_middleware, | ||
| ], | ||
| ) | ||
| def test_mount_middleware( | ||
| test_client_factory: typing.Callable[..., TestClient], | ||
| route: BaseRoute, | ||
| ) -> None: | ||
| test_client = test_client_factory(Router([route])) | ||
| response = test_client.get("/http") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we also
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added in eee6a6f |
||
| assert response.status_code == 200 | ||
| assert response.headers["X-Test"] == "Set by middleware" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "route", | ||
| [ | ||
| mounted_routes_with_middleware, | ||
| mounted_routes_with_middleware, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Duplicated as well, so I think we can get rid of the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed in eee6a6f |
||
| ], | ||
| ) | ||
| def test_mount_middleware_url_path_for(route: BaseRoute) -> None: | ||
| """Checks that url_path_for still works with middleware on Mounts""" | ||
| router = Router([route]) | ||
| assert router.url_path_for("route") == "/http/" | ||
|
|
||
|
|
||
| def test_add_route_to_app_after_mount( | ||
| test_client_factory: typing.Callable[..., TestClient], | ||
| ) -> None: | ||
| """Checks that Mount will pick up routes | ||
| added to the underlying app after it is mounted | ||
| """ | ||
| inner_app = Router() | ||
| app = Mount("/http", app=inner_app) | ||
| inner_app.add_route( | ||
| "/inner", | ||
| endpoint=lambda request: Response(), | ||
| methods=["GET"], | ||
| ) | ||
| client = test_client_factory(app) | ||
| response = client.get("/http/inner") | ||
| assert response.status_code == 200 | ||
|
|
||
|
|
||
| def test_exception_on_mounted_apps(test_client_factory): | ||
| def exc(request): | ||
| raise Exception("Exc") | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.