Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.types import ASGIApp, Receive, Scope, Send
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
Expand All @@ -29,9 +29,18 @@ async def call_next(request: Request) -> Response:
async def coro() -> None:
nonlocal app_exc

app_receive = request.receive

if receive_called:
# middleware consumed the request body
async def error_receive() -> Message:
raise RuntimeError("Receive stream already consumed")

app_receive = error_receive

async with send_stream:
try:
await self.app(scope, request.receive, send_stream.send)
await self.app(scope, app_receive, send_stream.send)
except Exception as exc:
app_exc = exc

Expand Down Expand Up @@ -61,8 +70,15 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
response.raw_headers = message["headers"]
return response

receive_called = False

async def tracked_receive() -> Message:
nonlocal receive_called
receive_called = True
return await receive()

async with anyio.create_task_group() as task_group:
request = Request(scope, receive=receive)
request = Request(scope, receive=tracked_receive)
response = await self.dispatch_func(request, call_next)
await response(scope, receive, send)
task_group.cancel_scope.cancel()
Expand Down
23 changes: 22 additions & 1 deletion tests/middleware/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response, StreamingResponse
from starlette.routing import Mount, Route, WebSocketRoute


Expand Down Expand Up @@ -163,3 +164,23 @@ def test_exception_on_mounted_apps(test_client_factory):
with pytest.raises(Exception) as ctx:
client.get("/sub/")
assert str(ctx.value) == "Exc"


def test_stream_consumed_in_middleware(test_client_factory) -> None:
class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
await request.body()
response = await call_next(request)
return response # pragma: no cover

async def endpoint(request: Request) -> Response:
await request.body()
return Response() # pragma: no cover

app = Starlette(
middleware=[Middleware(CustomMiddleware)], routes=[Route("/", endpoint)]
)

client = test_client_factory(app)
with pytest.raises(RuntimeError, match="Receive stream already consumed"):
client.get("/")