Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions starlette/middleware/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self.app = app
self._dispatch_func = dispatch

def dispatch(self, conn: HTTPConnection) -> _DispatchFlow:
def dispatch(self, __conn: HTTPConnection) -> _DispatchFlow:
raise NotImplementedError # pragma: no cover

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
Expand All @@ -52,6 +52,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
maybe_early_response = await flow.__anext__()

if maybe_early_response is not None:
try:
await flow.__anext__()
except StopAsyncIteration:
pass
else:
raise RuntimeError("dispatch() should yield exactly once")

await maybe_early_response(scope, receive, send)
return

Expand Down Expand Up @@ -98,6 +105,3 @@ async def wrapped_send(message: Message) -> None:

await response(scope, receive, send)
return

if not response_started:
raise RuntimeError("No response returned.")
Comment on lines -102 to -103
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this here, we can just let the server handle it which is what would happen if we didn't install this middleware. I think we only do it in BaseHTTPMiddleware out of necessity / implementation detail.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

173 changes: 83 additions & 90 deletions tests/middleware/test_http.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,27 @@
import contextvars
from typing import AsyncGenerator, Optional
from typing import Any, AsyncGenerator, Callable, Optional

import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.http import HTTPMiddleware
from starlette.requests import HTTPConnection
from starlette.requests import HTTPConnection, Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Route, WebSocketRoute
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.websockets import WebSocket


class CustomMiddleware(HTTPMiddleware):
async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
response = yield
response.headers["Custom-Header"] = "Example"


def homepage(request):
def homepage(request: Request) -> Response:
return PlainTextResponse("Homepage")


def exc(request):
def exc(request: Request) -> Response:
raise Exception("Exc")


def exc_stream(request):
def exc_stream(request: Request) -> Response:
return StreamingResponse(_generate_faulty_stream())


Expand All @@ -36,22 +31,28 @@ def _generate_faulty_stream():


class NoResponse:
def __init__(self, scope, receive, send):
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
pass

def __await__(self):
def __await__(self) -> Any:
return self.dispatch().__await__()

async def dispatch(self):
async def dispatch(self) -> None:
pass


async def websocket_endpoint(session):
async def websocket_endpoint(session: WebSocket):
await session.accept()
await session.send_text("Hello, world!")
await session.close()


class CustomMiddleware(HTTPMiddleware):
async def dispatch(self, request: HTTPConnection) -> AsyncGenerator[None, Response]:
response = yield
response.headers["Custom-Header"] = "Example"


app = Starlette(
routes=[
Route("/", endpoint=homepage),
Expand All @@ -64,7 +65,9 @@ async def websocket_endpoint(session):
)


def test_custom_middleware(test_client_factory):
def test_custom_middleware(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
client = test_client_factory(app)
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"
Expand All @@ -77,49 +80,42 @@ def test_custom_middleware(test_client_factory):
response = client.get("/exc-stream")
assert str(ctx.value) == "Faulty Stream"

with pytest.raises(RuntimeError):
with pytest.raises(AssertionError): # from TestClient
response = client.get("/no-response")

with client.websocket_connect("/ws") as session:
text = session.receive_text()
assert text == "Hello, world!"


def test_state_data_across_multiple_middlewares(test_client_factory):
def test_state_data_across_multiple_middlewares(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
expected_value1 = "foo"
expected_value2 = "bar"

class aMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[None, Response]:
conn.state.foo = expected_value1
yield
async def middleware_a(request: HTTPConnection) -> AsyncGenerator[None, None]:
request.state.foo = expected_value1
yield

class bMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[None, Response]:
conn.state.bar = expected_value2
response = yield
response.headers["X-State-Foo"] = conn.state.foo
async def middleware_b(request: HTTPConnection) -> AsyncGenerator[None, Response]:
request.state.bar = expected_value2
response = yield
response.headers["X-State-Foo"] = request.state.foo

class cMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[None, Response]:
response = yield
response.headers["X-State-Bar"] = conn.state.bar
async def middleware_c(request: HTTPConnection) -> AsyncGenerator[None, Response]:
response = yield
response.headers["X-State-Bar"] = request.state.bar

def homepage(request):
def homepage(request: Request) -> Response:
return PlainTextResponse("OK")

app = Starlette(
routes=[Route("/", homepage)],
middleware=[
Middleware(aMiddleware),
Middleware(bMiddleware),
Middleware(cMiddleware),
Middleware(HTTPMiddleware, dispatch=middleware_a),
Middleware(HTTPMiddleware, dispatch=middleware_b),
Middleware(HTTPMiddleware, dispatch=middleware_c),
],
)

Expand All @@ -130,8 +126,10 @@ def homepage(request):
assert response.headers["X-State-Bar"] == expected_value2


def test_dispatch_argument(test_client_factory):
def homepage(request):
def test_dispatch_argument(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
def homepage(request: Request):
return PlainTextResponse("Homepage")

async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
Expand All @@ -148,13 +146,8 @@ async def dispatch(conn: HTTPConnection) -> AsyncGenerator[None, Response]:
assert response.headers["Custom-Header"] == "Example"


def test_middleware_repr():
middleware = Middleware(CustomMiddleware)
assert repr(middleware) == "Middleware(CustomMiddleware)"


def test_early_response(test_client_factory):
async def index(request):
def test_early_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
async def index(request: Request):
return PlainTextResponse("Hello, world!")

class CustomMiddleware(HTTPMiddleware):
Expand All @@ -179,12 +172,12 @@ async def dispatch(
assert response.status_code == 401


def test_too_many_yields(test_client_factory) -> None:
def test_too_many_yields(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[None, Response]:
_ = yield
yield
yield

app = Starlette(middleware=[Middleware(CustomMiddleware)])
Expand All @@ -194,11 +187,28 @@ async def dispatch(
client.get("/")


def test_error_response(test_client_factory) -> None:
def test_too_many_yields_early_response(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
class CustomMiddleware(HTTPMiddleware):
async def dispatch(
self, conn: HTTPConnection
) -> AsyncGenerator[Optional[Response], Response]:
yield Response()
yield None

app = Starlette(middleware=[Middleware(CustomMiddleware)])

client = test_client_factory(app)
with pytest.raises(RuntimeError, match="should yield exactly once"):
client.get("/")


def test_error_response(test_client_factory: Callable[[ASGIApp], TestClient]) -> None:
class Failed(Exception):
pass

async def failure(request):
async def failure(request: Request):
raise Failed()

class CustomMiddleware(HTTPMiddleware):
Expand All @@ -221,11 +231,13 @@ async def dispatch(
assert response.status_code == 500


def test_no_error_response(test_client_factory) -> None:
def test_no_error_response(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
class Failed(Exception):
pass

async def index(request):
async def index(request: Request):
raise Failed()

class CustomMiddleware(HTTPMiddleware):
Expand All @@ -247,43 +259,24 @@ async def dispatch(
client.get("/")


ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")


class PureASGICustomMiddleware:
def __init__(self, app: ASGIApp) -> None:
self.app = app
def test_modify_content_type(
test_client_factory: Callable[[ASGIApp], TestClient]
) -> None:
async def dispatch(request: HTTPConnection) -> AsyncGenerator[None, Response]:
resp = yield
resp.media_type = "text/csv"
Copy link
Copy Markdown
Contributor

@florimondmanca florimondmanca Jun 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be strange usage without modifying the response body, wouldn't it? Is there a situation where we'd actually do this w/o having to modify the response body and Content-Type header, which would require pure ASGI middleware (or the body modification API you thought about as a possible future improvement)?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'm not sure what the use case would be, but you are able to set the attribute, so maybe we should error in this situation instead by overriding Response.__setattr__?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we need to be that defensive, to be honest. If we're clear enough in the documentation (which ought to be written before we consider merging this anyway, right?) that this is a "simplified response interface, allowing to inspect the status code and headers, or tweak headers", then maybe that's enough.

Though, now that I think of it — what if users do want to inspect response.body in the "on response" part of the dispatch flow? Right now, it's empty, and actually makes no sense because this is called before we start processing any chunk of the body. Hmm.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I don't think there's any behavior that's right or wrong here, it's just about what will be least confusing.


async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
ctxvar.set("set by middleware")
await self.app(scope, receive, send)
assert ctxvar.get() == "set by endpoint"


class HTTPCustomMiddleware(HTTPMiddleware):
async def dispatch(self, conn: HTTPConnection) -> AsyncGenerator[None, Response]:
ctxvar.set("set by middleware")
yield
assert ctxvar.get() == "set by endpoint"


@pytest.mark.parametrize(
"middleware_cls",
[
PureASGICustomMiddleware,
HTTPCustomMiddleware,
],
)
def test_contextvars(test_client_factory, middleware_cls: type):
async def homepage(request):
assert ctxvar.get() == "set by middleware"
ctxvar.set("set by endpoint")
return PlainTextResponse("Homepage")
def homepage(request: Request) -> Response:
return PlainTextResponse("OK")

app = Starlette(
middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)]
routes=[Route("/", homepage)],
middleware=[
Middleware(HTTPMiddleware, dispatch=dispatch),
],
)

client = test_client_factory(app)
response = client.get("/")
assert response.status_code == 200, response.content
assert response.text == "OK"
assert response.headers["Content-Type"] == "text/csv; charset=utf-8"