From 6470c9f29fbad95ec989aaf50e7b711e5df2bf2d Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 16:09:42 -0600 Subject: [PATCH 1/7] Move exception handling logic to endpoints --- starlette/middleware/exceptions.py | 5 ++ starlette/routing.py | 113 +++++++++++++++++++++++++++-- tests/test_exceptions.py | 26 ++++++- 3 files changed, 135 insertions(+), 9 deletions(-) diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index cd7294170..6698705ea 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -55,6 +55,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return + scope["starlette.exception_handlers"] = ( + self._exception_handlers, + self._status_handlers, + ) + response_started = False async def sender(message: Message) -> None: diff --git a/starlette/routing.py b/starlette/routing.py index 0aa90aa25..ed85703f6 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -53,20 +53,78 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover return inspect.iscoroutinefunction(obj) +def _lookup_exception_handler( + exc: Exception, + handlers: typing.Mapping[typing.Type[Exception], typing.Callable[..., typing.Any]], +) -> typing.Optional[typing.Callable[..., typing.Any]]: + for cls in type(exc).__mro__: + if cls in handlers: + return handlers[cls] + return None + + def request_response(func: typing.Callable) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ + is_coroutine = is_async_callable(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request(scope, receive=receive, send=send) - if is_coroutine: - response = await func(request) + exception_handlers: typing.Mapping[ + typing.Type[Exception], typing.Callable[..., typing.Any] + ] + status_handlers: typing.Mapping[int, typing.Callable[..., typing.Any]] + + try: + exception_handlers, status_handlers = scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} + + response_started = False + + async def sender(message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + request = Request(scope, receive=receive, send=sender) + + try: + if is_coroutine: + response = await func(request) + else: + response = await run_in_threadpool(func, request) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = status_handlers.get(exc.status_code) + + if handler is None: + handler = _lookup_exception_handler(exc, exception_handlers) + + if handler is None: + raise exc + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc + + if is_async_callable(handler): + response = await handler(request, exc) + else: + response = await run_in_threadpool(handler, request, exc) + await response(scope, receive, send) else: - response = await run_in_threadpool(func, request) - await response(scope, receive, send) + try: + await response(scope, receive, send) + except Exception as exc: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc return app @@ -78,8 +136,49 @@ def websocket_session(func: typing.Callable) -> ASGIApp: # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async" async def app(scope: Scope, receive: Receive, send: Send) -> None: - session = WebSocket(scope, receive=receive, send=send) - await func(session) + exception_handlers: typing.Mapping[ + typing.Type[Exception], typing.Callable[..., typing.Any] + ] + status_handlers: typing.Mapping[int, typing.Callable[..., typing.Any]] + + try: + exception_handlers, status_handlers = scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} + + response_started = False + + async def sender(message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + session = WebSocket(scope, receive=receive, send=sender) + + try: + await func(session) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = status_handlers.get(exc.status_code) + + if handler is None: + handler = _lookup_exception_handler(exc, exception_handlers) + + if handler is None: + raise exc + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc + + if is_async_callable(handler): + await handler(session, exc) + else: + await run_in_threadpool(handler, session, exc) return app diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 05583a430..fcbaaed47 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -4,7 +4,8 @@ from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.exceptions import ExceptionMiddleware -from starlette.responses import PlainTextResponse +from starlette.requests import Request +from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.routing import Route, Router, WebSocketRoute @@ -28,6 +29,20 @@ def with_headers(request): raise HTTPException(status_code=200, headers={"x-potato": "always"}) +class BadBodyException(HTTPException): + pass + + +async def read_body_and_raise_exc(request: Request): + await request.body() + raise BadBodyException(422) + + +async def handler_that_reads_body(request: Request, exc: BadBodyException) -> JSONResponse: + body = await request.body() + return JSONResponse(status_code=422, content={"body": body.decode()}) + + class HandledExcAfterResponse: async def __call__(self, scope, receive, send): response = PlainTextResponse("OK", status_code=200) @@ -44,11 +59,12 @@ async def __call__(self, scope, receive, send): Route("/with_headers", endpoint=with_headers), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), + Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]), ] ) -app = ExceptionMiddleware(router) +app = ExceptionMiddleware(router, handlers={BadBodyException: handler_that_reads_body}) @pytest.fixture @@ -160,3 +176,9 @@ def test_exception_middleware_deprecation() -> None: with pytest.warns(DeprecationWarning): starlette.exceptions.ExceptionMiddleware + + +def test_request_in_app_and_handler_is_the_same_object(client) -> None: + response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!") + assert response.status_code == 422 + assert response.json() == {"body": "Hello!"} From 23c918ed52451b50903cb57aaa819413496e9cdc Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 16:19:43 -0600 Subject: [PATCH 2/7] lint --- tests/test_exceptions.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index fcbaaed47..3c36f67e1 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -38,7 +38,9 @@ async def read_body_and_raise_exc(request: Request): raise BadBodyException(422) -async def handler_that_reads_body(request: Request, exc: BadBodyException) -> JSONResponse: +async def handler_that_reads_body( + request: Request, exc: BadBodyException +) -> JSONResponse: body = await request.body() return JSONResponse(status_code=422, content={"body": body.decode()}) @@ -59,7 +61,11 @@ async def __call__(self, scope, receive, send): Route("/with_headers", endpoint=with_headers), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), - Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]), + Route( + "/consume_body_in_endpoint_and_handler", + endpoint=read_body_and_raise_exc, + methods=["POST"], + ), ] ) From 0737cc95760e94a3b05fe5f2ef94c6c34825dfe6 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 16:21:32 -0600 Subject: [PATCH 3/7] lint --- tests/test_exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 3c36f67e1..0311c48fe 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -5,7 +5,7 @@ from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.exceptions import ExceptionMiddleware from starlette.requests import Request -from starlette.responses import JSONResponse, PlainTextResponse, Response +from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute From bdb9d68916fee0e89e20fe2423c524b8cac534bb Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 16:24:40 -0600 Subject: [PATCH 4/7] more checks --- starlette/routing.py | 6 +++--- tests/test_exceptions.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index ed85703f6..13f194bb8 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -17,7 +17,7 @@ from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose @@ -84,7 +84,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response_started = False - async def sender(message) -> None: + async def sender(message: Message) -> None: nonlocal response_started if message["type"] == "http.response.start": @@ -148,7 +148,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response_started = False - async def sender(message) -> None: + async def sender(message: Message) -> None: nonlocal response_started if message["type"] == "http.response.start": diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 0311c48fe..2f2b89167 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -70,7 +70,10 @@ async def __call__(self, scope, receive, send): ) -app = ExceptionMiddleware(router, handlers={BadBodyException: handler_that_reads_body}) +app = ExceptionMiddleware( + router, + handlers={BadBodyException: handler_that_reads_body}, # type: ignore[dict-item] +) @pytest.fixture From ce4a3d5b535def20edf4665ab97b47ea5fcc95ef Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 16:58:25 -0600 Subject: [PATCH 5/7] fix tests --- starlette/routing.py | 9 ++------- tests/test_routing.py | 6 +----- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 13f194bb8..111478c98 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -118,13 +118,8 @@ async def sender(message: Message) -> None: response = await handler(request, exc) else: response = await run_in_threadpool(handler, request, exc) - await response(scope, receive, send) - else: - try: - await response(scope, receive, send) - except Exception as exc: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc + + await response(scope, receive, send) return app diff --git a/tests/test_routing.py b/tests/test_routing.py index 09beb8bb9..cb8eeed01 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -945,13 +945,9 @@ async def modified_send(msg: Message) -> None: assert resp.status_code == 200, resp.content assert "X-Mounted" in resp.headers - # this is the "surprising" behavior bit - # the middleware on the mount never runs because there - # is nothing to catch the HTTPException - # since Mount middlweare is not wrapped by ExceptionMiddleware resp = client.get("/mount/err") assert resp.status_code == 403, resp.content - assert "X-Mounted" not in resp.headers + assert "X-Mounted" in resp.headers def test_route_repr() -> None: From 62807bd1659d53c2bca9ce79dbac7952132f761c Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 19:13:33 -0600 Subject: [PATCH 6/7] Update exceptions.py --- starlette/middleware/exceptions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index 6698705ea..53c4e5506 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -110,5 +110,5 @@ def http_exception(self, request: Request, exc: HTTPException) -> Response: async def websocket_exception( self, websocket: WebSocket, exc: WebSocketException - ) -> None: - await websocket.close(code=exc.code, reason=exc.reason) + ) -> None: + await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover From 5ca404d7ffb58d4200f6dd05f1c3e961a0a1022f Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 20:43:16 -0600 Subject: [PATCH 7/7] reformat --- starlette/middleware/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index 53c4e5506..6033b9d38 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -110,5 +110,5 @@ def http_exception(self, request: Request, exc: HTTPException) -> Response: async def websocket_exception( self, websocket: WebSocket, exc: WebSocketException - ) -> None: + ) -> None: await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover