Skip to content

Commit 37d20d3

Browse files
committed
Prevents reraising of exception from BaseHttpMiddleware
Fix for encode#2893 In current state exception is reraised and this adds a flag, which prevent the background exception handling code to rethrow a raised exception
1 parent bcdf0ad commit 37d20d3

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

starlette/middleware/base.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
104104
wrapped_receive = request.wrapped_receive
105105
response_sent = anyio.Event()
106106
app_exc: Exception | None = None
107+
exception_already_raised = False
107108

108109
async def call_next(request: Request) -> Response:
109110
async def receive_or_disconnect() -> Message:
@@ -150,6 +151,8 @@ async def coro() -> None:
150151
message = await recv_stream.receive()
151152
except anyio.EndOfStream:
152153
if app_exc is not None:
154+
nonlocal exception_already_raised
155+
exception_already_raised = True
153156
raise app_exc
154157
raise RuntimeError("No response returned.")
155158

@@ -176,8 +179,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
176179
await response(scope, wrapped_receive, send)
177180
response_sent.set()
178181
recv_stream.close()
179-
180-
if app_exc is not None:
182+
if app_exc is not None and not exception_already_raised:
181183
raise app_exc
182184

183185
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:

tests/middleware/test_base.py

+21
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,27 @@ async def passthrough(request: Request, call_next: RequestResponseEndpoint) -> R
320320
client.get("/")
321321

322322

323+
def test_exception_can_be_caught(test_client_factory: TestClientFactory) -> None:
324+
async def error_endpoint(_: Request):
325+
raise ValueError("TEST")
326+
327+
async def catches_error(request: Request, call_next: RequestResponseEndpoint) -> Response:
328+
try:
329+
return await call_next(request)
330+
except ValueError as exc:
331+
return PlainTextResponse(content=str(exc), status_code=400)
332+
333+
app = Starlette(
334+
middleware=[Middleware(BaseHTTPMiddleware, dispatch=catches_error)],
335+
routes=[Route("/", error_endpoint)],
336+
)
337+
338+
client = test_client_factory(app)
339+
response = client.get("/")
340+
assert response.status_code == 400
341+
assert response.text == "TEST"
342+
343+
323344
@pytest.mark.anyio
324345
async def test_do_not_block_on_background_tasks() -> None:
325346
response_complete = anyio.Event()

0 commit comments

Comments
 (0)