diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index cd7294170..6033b9d38 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -55,6 +55,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return + scope["starlette.exception_handlers"] = ( + self._exception_handlers, + self._status_handlers, + ) + response_started = False async def sender(message: Message) -> None: @@ -106,4 +111,4 @@ def http_exception(self, request: Request, exc: HTTPException) -> Response: async def websocket_exception( self, websocket: WebSocket, exc: WebSocketException ) -> None: - await websocket.close(code=exc.code, reason=exc.reason) + await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover diff --git a/starlette/routing.py b/starlette/routing.py index 0aa90aa25..111478c98 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -17,7 +17,7 @@ from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose @@ -53,19 +53,72 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover return inspect.iscoroutinefunction(obj) +def _lookup_exception_handler( + exc: Exception, + handlers: typing.Mapping[typing.Type[Exception], typing.Callable[..., typing.Any]], +) -> typing.Optional[typing.Callable[..., typing.Any]]: + for cls in type(exc).__mro__: + if cls in handlers: + return handlers[cls] + return None + + def request_response(func: typing.Callable) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ + is_coroutine = is_async_callable(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request(scope, receive=receive, send=send) - if is_coroutine: - response = await func(request) - else: - response = await run_in_threadpool(func, request) + exception_handlers: typing.Mapping[ + typing.Type[Exception], typing.Callable[..., typing.Any] + ] + status_handlers: typing.Mapping[int, typing.Callable[..., typing.Any]] + + try: + exception_handlers, status_handlers = scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} + + response_started = False + + async def sender(message: Message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + request = Request(scope, receive=receive, send=sender) + + try: + if is_coroutine: + response = await func(request) + else: + response = await run_in_threadpool(func, request) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = status_handlers.get(exc.status_code) + + if handler is None: + handler = _lookup_exception_handler(exc, exception_handlers) + + if handler is None: + raise exc + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc + + if is_async_callable(handler): + response = await handler(request, exc) + else: + response = await run_in_threadpool(handler, request, exc) + await response(scope, receive, send) return app @@ -78,8 +131,49 @@ def websocket_session(func: typing.Callable) -> ASGIApp: # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async" async def app(scope: Scope, receive: Receive, send: Send) -> None: - session = WebSocket(scope, receive=receive, send=send) - await func(session) + exception_handlers: typing.Mapping[ + typing.Type[Exception], typing.Callable[..., typing.Any] + ] + status_handlers: typing.Mapping[int, typing.Callable[..., typing.Any]] + + try: + exception_handlers, status_handlers = scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} + + response_started = False + + async def sender(message: Message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + session = WebSocket(scope, receive=receive, send=sender) + + try: + await func(session) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = status_handlers.get(exc.status_code) + + if handler is None: + handler = _lookup_exception_handler(exc, exception_handlers) + + if handler is None: + raise exc + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc + + if is_async_callable(handler): + await handler(session, exc) + else: + await run_in_threadpool(handler, session, exc) return app diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 05583a430..2f2b89167 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -4,7 +4,8 @@ from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.exceptions import ExceptionMiddleware -from starlette.responses import PlainTextResponse +from starlette.requests import Request +from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute @@ -28,6 +29,22 @@ def with_headers(request): raise HTTPException(status_code=200, headers={"x-potato": "always"}) +class BadBodyException(HTTPException): + pass + + +async def read_body_and_raise_exc(request: Request): + await request.body() + raise BadBodyException(422) + + +async def handler_that_reads_body( + request: Request, exc: BadBodyException +) -> JSONResponse: + body = await request.body() + return JSONResponse(status_code=422, content={"body": body.decode()}) + + class HandledExcAfterResponse: async def __call__(self, scope, receive, send): response = PlainTextResponse("OK", status_code=200) @@ -44,11 +61,19 @@ async def __call__(self, scope, receive, send): Route("/with_headers", endpoint=with_headers), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), + Route( + "/consume_body_in_endpoint_and_handler", + endpoint=read_body_and_raise_exc, + methods=["POST"], + ), ] ) -app = ExceptionMiddleware(router) +app = ExceptionMiddleware( + router, + handlers={BadBodyException: handler_that_reads_body}, # type: ignore[dict-item] +) @pytest.fixture @@ -160,3 +185,9 @@ def test_exception_middleware_deprecation() -> None: with pytest.warns(DeprecationWarning): starlette.exceptions.ExceptionMiddleware + + +def test_request_in_app_and_handler_is_the_same_object(client) -> None: + response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!") + assert response.status_code == 422 + assert response.json() == {"body": "Hello!"} diff --git a/tests/test_routing.py b/tests/test_routing.py index 09beb8bb9..cb8eeed01 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -945,13 +945,9 @@ async def modified_send(msg: Message) -> None: assert resp.status_code == 200, resp.content assert "X-Mounted" in resp.headers - # this is the "surprising" behavior bit - # the middleware on the mount never runs because there - # is nothing to catch the HTTPException - # since Mount middlweare is not wrapped by ExceptionMiddleware resp = client.get("/mount/err") assert resp.status_code == 403, resp.content - assert "X-Mounted" not in resp.headers + assert "X-Mounted" in resp.headers def test_route_repr() -> None: