-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
copy over some tests from alternative implementation #1694
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,32 +1,27 @@ | ||
| import contextvars | ||
| from typing import AsyncGenerator, Optional | ||
| from typing import Any, AsyncGenerator, Callable, Optional | ||
|
|
||
| import pytest | ||
|
|
||
| from starlette.applications import Starlette | ||
| from starlette.middleware import Middleware | ||
| from starlette.middleware.http import HTTPMiddleware | ||
| from starlette.requests import HTTPConnection | ||
| from starlette.requests import HTTPConnection, Request | ||
| from starlette.responses import PlainTextResponse, Response, StreamingResponse | ||
| from starlette.routing import Route, WebSocketRoute | ||
| from starlette.testclient import TestClient | ||
| from starlette.types import ASGIApp, Receive, Scope, Send | ||
| from starlette.websockets import WebSocket | ||
|
|
||
|
|
||
| class CustomMiddleware(HTTPMiddleware): | ||
| async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]: | ||
| response = yield | ||
| response.headers["Custom-Header"] = "Example" | ||
|
|
||
|
|
||
| def homepage(request): | ||
| def homepage(request: Request) -> Response: | ||
| return PlainTextResponse("Homepage") | ||
|
|
||
|
|
||
| def exc(request): | ||
| def exc(request: Request) -> Response: | ||
| raise Exception("Exc") | ||
|
|
||
|
|
||
| def exc_stream(request): | ||
| def exc_stream(request: Request) -> Response: | ||
| return StreamingResponse(_generate_faulty_stream()) | ||
|
|
||
|
|
||
|
|
@@ -36,22 +31,28 @@ def _generate_faulty_stream(): | |
|
|
||
|
|
||
| class NoResponse: | ||
| def __init__(self, scope, receive, send): | ||
| def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
| pass | ||
|
|
||
| def __await__(self): | ||
| def __await__(self) -> Any: | ||
| return self.dispatch().__await__() | ||
|
|
||
| async def dispatch(self): | ||
| async def dispatch(self) -> None: | ||
| pass | ||
|
|
||
|
|
||
| async def websocket_endpoint(session): | ||
| async def websocket_endpoint(session: WebSocket): | ||
| await session.accept() | ||
| await session.send_text("Hello, world!") | ||
| await session.close() | ||
|
|
||
|
|
||
| class CustomMiddleware(HTTPMiddleware): | ||
| async def dispatch(self, request: HTTPConnection) -> AsyncGenerator[None, Response]: | ||
| response = yield | ||
| response.headers["Custom-Header"] = "Example" | ||
|
|
||
|
|
||
| app = Starlette( | ||
| routes=[ | ||
| Route("/", endpoint=homepage), | ||
|
|
@@ -64,7 +65,9 @@ async def websocket_endpoint(session): | |
| ) | ||
|
|
||
|
|
||
| def test_custom_middleware(test_client_factory): | ||
| def test_custom_middleware( | ||
| test_client_factory: Callable[[ASGIApp], TestClient] | ||
| ) -> None: | ||
| client = test_client_factory(app) | ||
| response = client.get("/") | ||
| assert response.headers["Custom-Header"] == "Example" | ||
|
|
@@ -77,49 +80,42 @@ def test_custom_middleware(test_client_factory): | |
| response = client.get("/exc-stream") | ||
| assert str(ctx.value) == "Faulty Stream" | ||
|
|
||
| with pytest.raises(RuntimeError): | ||
| with pytest.raises(AssertionError): # from TestClient | ||
| response = client.get("/no-response") | ||
|
|
||
| with client.websocket_connect("/ws") as session: | ||
| text = session.receive_text() | ||
| assert text == "Hello, world!" | ||
|
|
||
|
|
||
| def test_state_data_across_multiple_middlewares(test_client_factory): | ||
| def test_state_data_across_multiple_middlewares( | ||
| test_client_factory: Callable[[ASGIApp], TestClient] | ||
| ) -> None: | ||
| expected_value1 = "foo" | ||
| expected_value2 = "bar" | ||
|
|
||
| class aMiddleware(HTTPMiddleware): | ||
| async def dispatch( | ||
| self, conn: HTTPConnection | ||
| ) -> AsyncGenerator[None, Response]: | ||
| conn.state.foo = expected_value1 | ||
| yield | ||
| async def middleware_a(request: HTTPConnection) -> AsyncGenerator[None, None]: | ||
| request.state.foo = expected_value1 | ||
| yield | ||
|
|
||
| class bMiddleware(HTTPMiddleware): | ||
| async def dispatch( | ||
| self, conn: HTTPConnection | ||
| ) -> AsyncGenerator[None, Response]: | ||
| conn.state.bar = expected_value2 | ||
| response = yield | ||
| response.headers["X-State-Foo"] = conn.state.foo | ||
| async def middleware_b(request: HTTPConnection) -> AsyncGenerator[None, Response]: | ||
| request.state.bar = expected_value2 | ||
| response = yield | ||
| response.headers["X-State-Foo"] = request.state.foo | ||
|
|
||
| class cMiddleware(HTTPMiddleware): | ||
| async def dispatch( | ||
| self, conn: HTTPConnection | ||
| ) -> AsyncGenerator[None, Response]: | ||
| response = yield | ||
| response.headers["X-State-Bar"] = conn.state.bar | ||
| async def middleware_c(request: HTTPConnection) -> AsyncGenerator[None, Response]: | ||
| response = yield | ||
| response.headers["X-State-Bar"] = request.state.bar | ||
|
|
||
| def homepage(request): | ||
| def homepage(request: Request) -> Response: | ||
| return PlainTextResponse("OK") | ||
|
|
||
| app = Starlette( | ||
| routes=[Route("/", homepage)], | ||
| middleware=[ | ||
| Middleware(aMiddleware), | ||
| Middleware(bMiddleware), | ||
| Middleware(cMiddleware), | ||
| Middleware(HTTPMiddleware, dispatch=middleware_a), | ||
| Middleware(HTTPMiddleware, dispatch=middleware_b), | ||
| Middleware(HTTPMiddleware, dispatch=middleware_c), | ||
| ], | ||
| ) | ||
|
|
||
|
|
@@ -130,8 +126,10 @@ def homepage(request): | |
| assert response.headers["X-State-Bar"] == expected_value2 | ||
|
|
||
|
|
||
| def test_dispatch_argument(test_client_factory): | ||
| def homepage(request): | ||
| def test_dispatch_argument( | ||
| test_client_factory: Callable[[ASGIApp], TestClient] | ||
| ) -> None: | ||
| def homepage(request: Request): | ||
| return PlainTextResponse("Homepage") | ||
|
|
||
| async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]: | ||
|
|
@@ -148,13 +146,8 @@ async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]: | |
| assert response.headers["Custom-Header"] == "Example" | ||
|
|
||
|
|
||
| def test_middleware_repr(): | ||
| middleware = Middleware(CustomMiddleware) | ||
| assert repr(middleware) == "Middleware(CustomMiddleware)" | ||
|
|
||
|
|
||
| def test_early_response(test_client_factory): | ||
| async def index(request): | ||
| def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: | ||
| async def index(request: Request): | ||
| return PlainTextResponse("Hello, world!") | ||
|
|
||
| class CustomMiddleware(HTTPMiddleware): | ||
|
|
@@ -179,12 +172,12 @@ async def dispatch( | |
| assert response.status_code == 401 | ||
|
|
||
|
|
||
| def test_too_many_yields(test_client_factory) -> None: | ||
| def test_too_many_yields(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: | ||
| class CustomMiddleware(HTTPMiddleware): | ||
| async def dispatch( | ||
| self, conn: HTTPConnection | ||
| ) -> AsyncGenerator[None, Response]: | ||
| _ = yield | ||
| yield | ||
| yield | ||
|
|
||
| app = Starlette(middleware=[Middleware(CustomMiddleware)]) | ||
|
|
@@ -194,11 +187,28 @@ async def dispatch( | |
| client.get("/") | ||
|
|
||
|
|
||
| def test_error_response(test_client_factory) -> None: | ||
| def test_too_many_yields_early_response( | ||
| test_client_factory: Callable[[ASGIApp], TestClient] | ||
| ) -> None: | ||
| class CustomMiddleware(HTTPMiddleware): | ||
| async def dispatch( | ||
| self, conn: HTTPConnection | ||
| ) -> AsyncGenerator[Optional[Response], Response]: | ||
| yield Response() | ||
| yield None | ||
|
|
||
| app = Starlette(middleware=[Middleware(CustomMiddleware)]) | ||
|
|
||
| client = test_client_factory(app) | ||
| with pytest.raises(RuntimeError, match="should yield exactly once"): | ||
| client.get("/") | ||
|
|
||
|
|
||
| def test_error_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None: | ||
| class Failed(Exception): | ||
| pass | ||
|
|
||
| async def failure(request): | ||
| async def failure(request: Request): | ||
| raise Failed() | ||
|
|
||
| class CustomMiddleware(HTTPMiddleware): | ||
|
|
@@ -221,11 +231,13 @@ async def dispatch( | |
| assert response.status_code == 500 | ||
|
|
||
|
|
||
| def test_no_error_response(test_client_factory) -> None: | ||
| def test_no_error_response( | ||
| test_client_factory: Callable[[ASGIApp], TestClient] | ||
| ) -> None: | ||
| class Failed(Exception): | ||
| pass | ||
|
|
||
| async def index(request): | ||
| async def index(request: Request): | ||
| raise Failed() | ||
|
|
||
| class CustomMiddleware(HTTPMiddleware): | ||
|
|
@@ -247,43 +259,24 @@ async def dispatch( | |
| client.get("/") | ||
|
|
||
|
|
||
| ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") | ||
|
|
||
|
|
||
| class PureASGICustomMiddleware: | ||
| def __init__(self, app: ASGIApp) -> None: | ||
| self.app = app | ||
| def test_modify_content_type( | ||
| test_client_factory: Callable[[ASGIApp], TestClient] | ||
| ) -> None: | ||
| async def dispatch(request: HTTPConnection) -> AsyncGenerator[None, Response]: | ||
| resp = yield | ||
| resp.media_type = "text/csv" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would be strange usage without modifying the response body, wouldn't it? Is there a situation where we'd actually do this w/o having to modify the response body and
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I'm not sure what the use case would be, but you are able to set the attribute, so maybe we should error in this situation instead by overriding
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure we need to be that defensive, to be honest. If we're clear enough in the documentation (which ought to be written before we consider merging this anyway, right?) that this is a "simplified response interface, allowing to inspect the status code and headers, or tweak headers", then maybe that's enough. Though, now that I think of it — what if users do want to inspect
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah I don't think there's any behavior that's right or wrong here, it's just about what will be least confusing. |
||
|
|
||
| async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: | ||
| ctxvar.set("set by middleware") | ||
| await self.app(scope, receive, send) | ||
| assert ctxvar.get() == "set by endpoint" | ||
|
|
||
|
|
||
| class HTTPCustomMiddleware(HTTPMiddleware): | ||
| async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]: | ||
| ctxvar.set("set by middleware") | ||
| yield | ||
| assert ctxvar.get() == "set by endpoint" | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "middleware_cls", | ||
| [ | ||
| PureASGICustomMiddleware, | ||
| HTTPCustomMiddleware, | ||
| ], | ||
| ) | ||
| def test_contextvars(test_client_factory, middleware_cls: type): | ||
| async def homepage(request): | ||
| assert ctxvar.get() == "set by middleware" | ||
| ctxvar.set("set by endpoint") | ||
| return PlainTextResponse("Homepage") | ||
| def homepage(request: Request) -> Response: | ||
| return PlainTextResponse("OK") | ||
|
|
||
| app = Starlette( | ||
| middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] | ||
| routes=[Route("/", homepage)], | ||
| middleware=[ | ||
| Middleware(HTTPMiddleware, dispatch=dispatch), | ||
| ], | ||
| ) | ||
|
|
||
| client = test_client_factory(app) | ||
| response = client.get("/") | ||
| assert response.status_code == 200, response.content | ||
| assert response.text == "OK" | ||
| assert response.headers["Content-Type"] == "text/csv; charset=utf-8" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need this here, we can just let the server handle it which is what would happen if we didn't install this middleware. I think we only do it in
BaseHTTPMiddlewareout of necessity / implementation detail.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, eg. Uvicorn handles that here: https://github.com/encode/uvicorn/blob/f3c33fe7bca90326e38f66181f4623d8558571fb/uvicorn/protocols/http/h11_impl.py#L386-L389