From 6470c9f29fbad95ec989aaf50e7b711e5df2bf2d Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 16:09:42 -0600 Subject: [PATCH 01/13] Move exception handling logic to endpoints --- starlette/middleware/exceptions.py | 5 ++ starlette/routing.py | 113 +++++++++++++++++++++++++++-- tests/test_exceptions.py | 26 ++++++- 3 files changed, 135 insertions(+), 9 deletions(-) diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index cd7294170..6698705ea 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -55,6 +55,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return + scope["starlette.exception_handlers"] = ( + self._exception_handlers, + self._status_handlers, + ) + response_started = False async def sender(message: Message) -> None: diff --git a/starlette/routing.py b/starlette/routing.py index 0aa90aa25..ed85703f6 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -53,20 +53,78 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover return inspect.iscoroutinefunction(obj) +def _lookup_exception_handler( + exc: Exception, + handlers: typing.Mapping[typing.Type[Exception], typing.Callable[..., typing.Any]], +) -> typing.Optional[typing.Callable[..., typing.Any]]: + for cls in type(exc).__mro__: + if cls in handlers: + return handlers[cls] + return None + + def request_response(func: typing.Callable) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ + is_coroutine = is_async_callable(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: - request = Request(scope, receive=receive, send=send) - if is_coroutine: - response = await func(request) + exception_handlers: typing.Mapping[ + typing.Type[Exception], typing.Callable[..., typing.Any] + ] + status_handlers: typing.Mapping[int, typing.Callable[..., typing.Any]] + + try: + exception_handlers, status_handlers = scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} + + response_started = False + + async def sender(message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + request = Request(scope, receive=receive, send=sender) + + try: + if is_coroutine: + response = await func(request) + else: + response = await run_in_threadpool(func, request) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = status_handlers.get(exc.status_code) + + if handler is None: + handler = _lookup_exception_handler(exc, exception_handlers) + + if handler is None: + raise exc + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc + + if is_async_callable(handler): + response = await handler(request, exc) + else: + response = await run_in_threadpool(handler, request, exc) + await response(scope, receive, send) else: - response = await run_in_threadpool(func, request) - await response(scope, receive, send) + try: + await response(scope, receive, send) + except Exception as exc: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc return app @@ -78,8 +136,49 @@ def websocket_session(func: typing.Callable) -> ASGIApp: # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async" async def app(scope: Scope, receive: Receive, send: Send) -> None: - session = WebSocket(scope, receive=receive, send=send) - await func(session) + exception_handlers: typing.Mapping[ + typing.Type[Exception], typing.Callable[..., typing.Any] + ] + status_handlers: typing.Mapping[int, typing.Callable[..., typing.Any]] + + try: + exception_handlers, status_handlers = scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} + + response_started = False + + async def sender(message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + session = WebSocket(scope, receive=receive, send=sender) + + try: + await func(session) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = status_handlers.get(exc.status_code) + + if handler is None: + handler = _lookup_exception_handler(exc, exception_handlers) + + if handler is None: + raise exc + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc + + if is_async_callable(handler): + await handler(session, exc) + else: + await run_in_threadpool(handler, session, exc) return app diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 05583a430..fcbaaed47 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -4,7 +4,8 @@ from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.exceptions import ExceptionMiddleware -from starlette.responses import PlainTextResponse +from starlette.requests import Request +from starlette.responses import JSONResponse, PlainTextResponse, Response from starlette.routing import Route, Router, WebSocketRoute @@ -28,6 +29,20 @@ def with_headers(request): raise HTTPException(status_code=200, headers={"x-potato": "always"}) +class BadBodyException(HTTPException): + pass + + +async def read_body_and_raise_exc(request: Request): + await request.body() + raise BadBodyException(422) + + +async def handler_that_reads_body(request: Request, exc: BadBodyException) -> JSONResponse: + body = await request.body() + return JSONResponse(status_code=422, content={"body": body.decode()}) + + class HandledExcAfterResponse: async def __call__(self, scope, receive, send): response = PlainTextResponse("OK", status_code=200) @@ -44,11 +59,12 @@ async def __call__(self, scope, receive, send): Route("/with_headers", endpoint=with_headers), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), + Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]), ] ) -app = ExceptionMiddleware(router) +app = ExceptionMiddleware(router, handlers={BadBodyException: handler_that_reads_body}) @pytest.fixture @@ -160,3 +176,9 @@ def test_exception_middleware_deprecation() -> None: with pytest.warns(DeprecationWarning): starlette.exceptions.ExceptionMiddleware + + +def test_request_in_app_and_handler_is_the_same_object(client) -> None: + response = client.post("/consume_body_in_endpoint_and_handler", content=b"Hello!") + assert response.status_code == 422 + assert response.json() == {"body": "Hello!"} From 23c918ed52451b50903cb57aaa819413496e9cdc Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 16:19:43 -0600 Subject: [PATCH 02/13] lint --- tests/test_exceptions.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index fcbaaed47..3c36f67e1 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -38,7 +38,9 @@ async def read_body_and_raise_exc(request: Request): raise BadBodyException(422) -async def handler_that_reads_body(request: Request, exc: BadBodyException) -> JSONResponse: +async def handler_that_reads_body( + request: Request, exc: BadBodyException +) -> JSONResponse: body = await request.body() return JSONResponse(status_code=422, content={"body": body.decode()}) @@ -59,7 +61,11 @@ async def __call__(self, scope, receive, send): Route("/with_headers", endpoint=with_headers), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), - Route("/consume_body_in_endpoint_and_handler", endpoint=read_body_and_raise_exc, methods=["POST"]), + Route( + "/consume_body_in_endpoint_and_handler", + endpoint=read_body_and_raise_exc, + methods=["POST"], + ), ] ) From 0737cc95760e94a3b05fe5f2ef94c6c34825dfe6 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 16:21:32 -0600 Subject: [PATCH 03/13] lint --- tests/test_exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 3c36f67e1..0311c48fe 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -5,7 +5,7 @@ from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.exceptions import ExceptionMiddleware from starlette.requests import Request -from starlette.responses import JSONResponse, PlainTextResponse, Response +from starlette.responses import JSONResponse, PlainTextResponse from starlette.routing import Route, Router, WebSocketRoute From bdb9d68916fee0e89e20fe2423c524b8cac534bb Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 16:24:40 -0600 Subject: [PATCH 04/13] more checks --- starlette/routing.py | 6 +++--- tests/test_exceptions.py | 5 ++++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index ed85703f6..13f194bb8 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -17,7 +17,7 @@ from starlette.middleware import Middleware from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse -from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose @@ -84,7 +84,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response_started = False - async def sender(message) -> None: + async def sender(message: Message) -> None: nonlocal response_started if message["type"] == "http.response.start": @@ -148,7 +148,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response_started = False - async def sender(message) -> None: + async def sender(message: Message) -> None: nonlocal response_started if message["type"] == "http.response.start": diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 0311c48fe..2f2b89167 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -70,7 +70,10 @@ async def __call__(self, scope, receive, send): ) -app = ExceptionMiddleware(router, handlers={BadBodyException: handler_that_reads_body}) +app = ExceptionMiddleware( + router, + handlers={BadBodyException: handler_that_reads_body}, # type: ignore[dict-item] +) @pytest.fixture From ce4a3d5b535def20edf4665ab97b47ea5fcc95ef Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 16:58:25 -0600 Subject: [PATCH 05/13] fix tests --- starlette/routing.py | 9 ++------- tests/test_routing.py | 6 +----- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 13f194bb8..111478c98 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -118,13 +118,8 @@ async def sender(message: Message) -> None: response = await handler(request, exc) else: response = await run_in_threadpool(handler, request, exc) - await response(scope, receive, send) - else: - try: - await response(scope, receive, send) - except Exception as exc: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc + + await response(scope, receive, send) return app diff --git a/tests/test_routing.py b/tests/test_routing.py index 09beb8bb9..cb8eeed01 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -945,13 +945,9 @@ async def modified_send(msg: Message) -> None: assert resp.status_code == 200, resp.content assert "X-Mounted" in resp.headers - # this is the "surprising" behavior bit - # the middleware on the mount never runs because there - # is nothing to catch the HTTPException - # since Mount middlweare is not wrapped by ExceptionMiddleware resp = client.get("/mount/err") assert resp.status_code == 403, resp.content - assert "X-Mounted" not in resp.headers + assert "X-Mounted" in resp.headers def test_route_repr() -> None: From 62807bd1659d53c2bca9ce79dbac7952132f761c Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 19:13:33 -0600 Subject: [PATCH 06/13] Update exceptions.py --- starlette/middleware/exceptions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index 6698705ea..53c4e5506 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -110,5 +110,5 @@ def http_exception(self, request: Request, exc: HTTPException) -> Response: async def websocket_exception( self, websocket: WebSocket, exc: WebSocketException - ) -> None: - await websocket.close(code=exc.code, reason=exc.reason) + ) -> None: + await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover From 5ca404d7ffb58d4200f6dd05f1c3e961a0a1022f Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 27 Jan 2023 20:43:16 -0600 Subject: [PATCH 07/13] reformat --- starlette/middleware/exceptions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index 53c4e5506..6033b9d38 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -110,5 +110,5 @@ def http_exception(self, request: Request, exc: HTTPException) -> Response: async def websocket_exception( self, websocket: WebSocket, exc: WebSocketException - ) -> None: + ) -> None: await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover From 9468718e25f66b3e872789d16c3023ccdfa0b11e Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 28 Jan 2023 01:25:21 -0600 Subject: [PATCH 08/13] Move wrapper into it's own file so it can be re-used --- starlette/middleware/exceptions.py | 114 -------------------- starlette/middleware/exceptions/__init__.py | 76 +++++++++++++ starlette/middleware/exceptions/_wrapper.py | 72 +++++++++++++ starlette/routing.py | 112 ++++--------------- 4 files changed, 170 insertions(+), 204 deletions(-) delete mode 100644 starlette/middleware/exceptions.py create mode 100644 starlette/middleware/exceptions/__init__.py create mode 100644 starlette/middleware/exceptions/_wrapper.py diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py deleted file mode 100644 index 6033b9d38..000000000 --- a/starlette/middleware/exceptions.py +++ /dev/null @@ -1,114 +0,0 @@ -import typing - -from starlette._utils import is_async_callable -from starlette.concurrency import run_in_threadpool -from starlette.exceptions import HTTPException, WebSocketException -from starlette.requests import Request -from starlette.responses import PlainTextResponse, Response -from starlette.types import ASGIApp, Message, Receive, Scope, Send -from starlette.websockets import WebSocket - - -class ExceptionMiddleware: - def __init__( - self, - app: ASGIApp, - handlers: typing.Optional[ - typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] - ] = None, - debug: bool = False, - ) -> None: - self.app = app - self.debug = debug # TODO: We ought to handle 404 cases if debug is set. - self._status_handlers: typing.Dict[int, typing.Callable] = {} - self._exception_handlers: typing.Dict[ - typing.Type[Exception], typing.Callable - ] = { - HTTPException: self.http_exception, - WebSocketException: self.websocket_exception, - } - if handlers is not None: - for key, value in handlers.items(): - self.add_exception_handler(key, value) - - def add_exception_handler( - self, - exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], - handler: typing.Callable[[Request, Exception], Response], - ) -> None: - if isinstance(exc_class_or_status_code, int): - self._status_handlers[exc_class_or_status_code] = handler - else: - assert issubclass(exc_class_or_status_code, Exception) - self._exception_handlers[exc_class_or_status_code] = handler - - def _lookup_exception_handler( - self, exc: Exception - ) -> typing.Optional[typing.Callable]: - for cls in type(exc).__mro__: - if cls in self._exception_handlers: - return self._exception_handlers[cls] - return None - - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - if scope["type"] not in ("http", "websocket"): - await self.app(scope, receive, send) - return - - scope["starlette.exception_handlers"] = ( - self._exception_handlers, - self._status_handlers, - ) - - response_started = False - - async def sender(message: Message) -> None: - nonlocal response_started - - if message["type"] == "http.response.start": - response_started = True - await send(message) - - try: - await self.app(scope, receive, sender) - except Exception as exc: - handler = None - - if isinstance(exc, HTTPException): - handler = self._status_handlers.get(exc.status_code) - - if handler is None: - handler = self._lookup_exception_handler(exc) - - if handler is None: - raise exc - - if response_started: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc - - if scope["type"] == "http": - request = Request(scope, receive=receive) - if is_async_callable(handler): - response = await handler(request, exc) - else: - response = await run_in_threadpool(handler, request, exc) - await response(scope, receive, sender) - elif scope["type"] == "websocket": - websocket = WebSocket(scope, receive=receive, send=send) - if is_async_callable(handler): - await handler(websocket, exc) - else: - await run_in_threadpool(handler, websocket, exc) - - def http_exception(self, request: Request, exc: HTTPException) -> Response: - if exc.status_code in {204, 304}: - return Response(status_code=exc.status_code, headers=exc.headers) - return PlainTextResponse( - exc.detail, status_code=exc.status_code, headers=exc.headers - ) - - async def websocket_exception( - self, websocket: WebSocket, exc: WebSocketException - ) -> None: - await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover diff --git a/starlette/middleware/exceptions/__init__.py b/starlette/middleware/exceptions/__init__.py new file mode 100644 index 000000000..b3f35c44d --- /dev/null +++ b/starlette/middleware/exceptions/__init__.py @@ -0,0 +1,76 @@ +import typing + +from starlette.exceptions import HTTPException, WebSocketException +from starlette.middleware.exceptions._wrapper import ( + ExcHandlers, + StatusHandlers, + wrap_app_handling_exceptions, +) +from starlette.requests import Request +from starlette.responses import PlainTextResponse, Response +from starlette.types import ASGIApp, Receive, Scope, Send +from starlette.websockets import WebSocket + + +class ExceptionMiddleware: + def __init__( + self, + app: ASGIApp, + handlers: typing.Optional[ + typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] + ] = None, + debug: bool = False, + ) -> None: + self.app = app + self.debug = debug # TODO: We ought to handle 404 cases if debug is set. + self._status_handlers: StatusHandlers = {} + self._exception_handlers: ExcHandlers = { + HTTPException: self.http_exception, + WebSocketException: self.websocket_exception, # type: ignore[dict-item] + } + if handlers is not None: + for key, value in handlers.items(): + self.add_exception_handler(key, value) + + def add_exception_handler( + self, + exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], + handler: typing.Callable[[Request, Exception], Response], + ) -> None: + if isinstance(exc_class_or_status_code, int): + self._status_handlers[exc_class_or_status_code] = handler + else: + assert issubclass(exc_class_or_status_code, Exception) + self._exception_handlers[exc_class_or_status_code] = handler + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] not in ("http", "websocket"): + await self.app(scope, receive, send) + return + + scope["starlette.exception_handlers"] = ( + self._exception_handlers, + self._status_handlers, + ) + + conn: typing.Union[Request, WebSocket] + if scope["type"] == "http": + conn = Request(scope, receive, send) + else: + conn = WebSocket(scope, receive, send) + + await wrap_app_handling_exceptions( + self.app, self._exception_handlers, self._status_handlers, conn + )(scope, receive, send) + + def http_exception(self, request: Request, exc: Exception) -> Response: + assert isinstance(exc, HTTPException) + if exc.status_code in {204, 304}: + return Response(status_code=exc.status_code, headers=exc.headers) + return PlainTextResponse( + exc.detail, status_code=exc.status_code, headers=exc.headers + ) + + async def websocket_exception(self, websocket: WebSocket, exc: Exception) -> None: + assert isinstance(exc, WebSocketException) + await websocket.close(code=exc.code, reason=exc.reason) # pragma: no cover diff --git a/starlette/middleware/exceptions/_wrapper.py b/starlette/middleware/exceptions/_wrapper.py new file mode 100644 index 000000000..cc574fbda --- /dev/null +++ b/starlette/middleware/exceptions/_wrapper.py @@ -0,0 +1,72 @@ +import typing + +from starlette._utils import is_async_callable +from starlette.concurrency import run_in_threadpool +from starlette.exceptions import HTTPException +from starlette.requests import Request +from starlette.responses import Response +from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.websockets import WebSocket + +Handler = typing.Callable[..., typing.Any] +ExcHandlers = typing.Dict[typing.Any, Handler] +StatusHandlers = typing.Dict[int, Handler] + + +def _lookup_exception_handler( + exc_handlers: ExcHandlers, exc: Exception +) -> typing.Optional[Handler]: + for cls in type(exc).__mro__: + if cls in exc_handlers: + return exc_handlers[cls] + return None + + +def wrap_app_handling_exceptions( + app: ASGIApp, + exc_handlers: ExcHandlers, + status_handlers: StatusHandlers, + conn: typing.Union[Request, WebSocket], +) -> ASGIApp: + async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None: + response_started = False + + async def sender(message: Message) -> None: + nonlocal response_started + + if message["type"] == "http.response.start": + response_started = True + await send(message) + + try: + await app(scope, receive, sender) + except Exception as exc: + handler = None + + if isinstance(exc, HTTPException): + handler = status_handlers.get(exc.status_code) + + if handler is None: + handler = _lookup_exception_handler(exc_handlers, exc) + + if handler is None: + raise exc + + if response_started: + msg = "Caught handled exception, but response already started." + raise RuntimeError(msg) from exc + + if scope["type"] == "http": + response: Response + if is_async_callable(handler): + response = await handler(conn, exc) + else: + response = await run_in_threadpool(handler, conn, exc) + await response(scope, receive, sender) + elif scope["type"] == "websocket": + if is_async_callable(handler): + await handler(conn, exc) + else: + await run_in_threadpool(handler, conn, exc) + + return wrapped_app diff --git a/starlette/routing.py b/starlette/routing.py index 111478c98..f4af8d063 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -15,9 +15,10 @@ from starlette.datastructures import URL, Headers, URLPath from starlette.exceptions import HTTPException from starlette.middleware import Middleware +from starlette.middleware.exceptions._wrapper import wrap_app_handling_exceptions from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse -from starlette.types import ASGIApp, Message, Receive, Scope, Send +from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose @@ -53,73 +54,31 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover return inspect.iscoroutinefunction(obj) -def _lookup_exception_handler( - exc: Exception, - handlers: typing.Mapping[typing.Type[Exception], typing.Callable[..., typing.Any]], -) -> typing.Optional[typing.Callable[..., typing.Any]]: - for cls in type(exc).__mro__: - if cls in handlers: - return handlers[cls] - return None - - def request_response(func: typing.Callable) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, and returns an ASGI application. """ - is_coroutine = is_async_callable(func) async def app(scope: Scope, receive: Receive, send: Send) -> None: - exception_handlers: typing.Mapping[ - typing.Type[Exception], typing.Callable[..., typing.Any] - ] - status_handlers: typing.Mapping[int, typing.Callable[..., typing.Any]] - - try: - exception_handlers, status_handlers = scope["starlette.exception_handlers"] - except KeyError: - exception_handlers, status_handlers = {}, {} - - response_started = False - - async def sender(message: Message) -> None: - nonlocal response_started + request = Request(scope, receive, send) - if message["type"] == "http.response.start": - response_started = True - await send(message) - - request = Request(scope, receive=receive, send=sender) - - try: + async def app(scope: Scope, receive: Receive, send: Send) -> None: if is_coroutine: response = await func(request) else: response = await run_in_threadpool(func, request) - except Exception as exc: - handler = None - - if isinstance(exc, HTTPException): - handler = status_handlers.get(exc.status_code) - - if handler is None: - handler = _lookup_exception_handler(exc, exception_handlers) - - if handler is None: - raise exc - - if response_started: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc + await response(scope, receive, send) - if is_async_callable(handler): - response = await handler(request, exc) - else: - response = await run_in_threadpool(handler, request, exc) + try: + exception_handlers, status_handlers = scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} - await response(scope, receive, send) + await wrap_app_handling_exceptions( + app, exception_handlers, status_handlers, request + )(scope, receive, send) return app @@ -131,49 +90,22 @@ def websocket_session(func: typing.Callable) -> ASGIApp: # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async" async def app(scope: Scope, receive: Receive, send: Send) -> None: - exception_handlers: typing.Mapping[ - typing.Type[Exception], typing.Callable[..., typing.Any] - ] - status_handlers: typing.Mapping[int, typing.Callable[..., typing.Any]] + session = WebSocket(scope, receive=receive, send=send) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + await func(session) try: exception_handlers, status_handlers = scope["starlette.exception_handlers"] except KeyError: exception_handlers, status_handlers = {}, {} - response_started = False - - async def sender(message: Message) -> None: - nonlocal response_started - - if message["type"] == "http.response.start": - response_started = True - await send(message) - - session = WebSocket(scope, receive=receive, send=sender) - - try: - await func(session) - except Exception as exc: - handler = None - - if isinstance(exc, HTTPException): - handler = status_handlers.get(exc.status_code) - - if handler is None: - handler = _lookup_exception_handler(exc, exception_handlers) - - if handler is None: - raise exc - - if response_started: - msg = "Caught handled exception, but response already started." - raise RuntimeError(msg) from exc - - if is_async_callable(handler): - await handler(session, exc) - else: - await run_in_threadpool(handler, session, exc) + await wrap_app_handling_exceptions( + app, + exception_handlers, + status_handlers, + session, + )(scope, receive, send) return app From 3cf38810e2c50bd143b08a3f2c7c4a3393ce357b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 1 Jun 2023 14:43:11 -0500 Subject: [PATCH 09/13] rename type alias from ExcHandlers to ExceptionHandlers --- starlette/middleware/exceptions/__init__.py | 4 ++-- starlette/middleware/exceptions/_wrapper.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/starlette/middleware/exceptions/__init__.py b/starlette/middleware/exceptions/__init__.py index b3f35c44d..88fa68ab3 100644 --- a/starlette/middleware/exceptions/__init__.py +++ b/starlette/middleware/exceptions/__init__.py @@ -2,7 +2,7 @@ from starlette.exceptions import HTTPException, WebSocketException from starlette.middleware.exceptions._wrapper import ( - ExcHandlers, + ExceptionHandlers, StatusHandlers, wrap_app_handling_exceptions, ) @@ -24,7 +24,7 @@ def __init__( self.app = app self.debug = debug # TODO: We ought to handle 404 cases if debug is set. self._status_handlers: StatusHandlers = {} - self._exception_handlers: ExcHandlers = { + self._exception_handlers: ExceptionHandlers = { HTTPException: self.http_exception, WebSocketException: self.websocket_exception, # type: ignore[dict-item] } diff --git a/starlette/middleware/exceptions/_wrapper.py b/starlette/middleware/exceptions/_wrapper.py index cc574fbda..2a25fab85 100644 --- a/starlette/middleware/exceptions/_wrapper.py +++ b/starlette/middleware/exceptions/_wrapper.py @@ -9,12 +9,12 @@ from starlette.websockets import WebSocket Handler = typing.Callable[..., typing.Any] -ExcHandlers = typing.Dict[typing.Any, Handler] +ExceptionHandlers = typing.Dict[typing.Any, Handler] StatusHandlers = typing.Dict[int, Handler] def _lookup_exception_handler( - exc_handlers: ExcHandlers, exc: Exception + exc_handlers: ExceptionHandlers, exc: Exception ) -> typing.Optional[Handler]: for cls in type(exc).__mro__: if cls in exc_handlers: @@ -24,7 +24,7 @@ def _lookup_exception_handler( def wrap_app_handling_exceptions( app: ASGIApp, - exc_handlers: ExcHandlers, + exc_handlers: ExceptionHandlers, status_handlers: StatusHandlers, conn: typing.Union[Request, WebSocket], ) -> ASGIApp: From 6cc501e5c8b1c4e9c51e0a9d7840ea02ca2a8ea3 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 7 Jun 2023 07:11:24 +0200 Subject: [PATCH 10/13] Refactor exception handling structure --- .../exceptions/_wrapper.py => _exception_handler.py} | 0 .../middleware/{exceptions/__init__.py => exceptions.py} | 4 ++-- starlette/routing.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename starlette/{middleware/exceptions/_wrapper.py => _exception_handler.py} (100%) rename starlette/middleware/{exceptions/__init__.py => exceptions.py} (98%) diff --git a/starlette/middleware/exceptions/_wrapper.py b/starlette/_exception_handler.py similarity index 100% rename from starlette/middleware/exceptions/_wrapper.py rename to starlette/_exception_handler.py diff --git a/starlette/middleware/exceptions/__init__.py b/starlette/middleware/exceptions.py similarity index 98% rename from starlette/middleware/exceptions/__init__.py rename to starlette/middleware/exceptions.py index 88fa68ab3..a0d77a0af 100644 --- a/starlette/middleware/exceptions/__init__.py +++ b/starlette/middleware/exceptions.py @@ -1,11 +1,11 @@ import typing -from starlette.exceptions import HTTPException, WebSocketException -from starlette.middleware.exceptions._wrapper import ( +from starlette._exception_handler import ( ExceptionHandlers, StatusHandlers, wrap_app_handling_exceptions, ) +from starlette.exceptions import HTTPException, WebSocketException from starlette.requests import Request from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Receive, Scope, Send diff --git a/starlette/routing.py b/starlette/routing.py index 63a4be848..b6ed09163 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -9,13 +9,13 @@ from contextlib import asynccontextmanager from enum import Enum +from starlette._exception_handler import wrap_app_handling_exceptions from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool from starlette.convertors import CONVERTOR_TYPES, Convertor from starlette.datastructures import URL, Headers, URLPath from starlette.exceptions import HTTPException from starlette.middleware import Middleware -from starlette.middleware.exceptions._wrapper import wrap_app_handling_exceptions from starlette.requests import Request from starlette.responses import PlainTextResponse, RedirectResponse from starlette.types import ASGIApp, Lifespan, Receive, Scope, Send From f354ae614dd77e10363e4361d509052289356c72 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 7 Jun 2023 07:32:30 +0200 Subject: [PATCH 11/13] Remove implementation details from documentation --- docs/exceptions.md | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/docs/exceptions.md b/docs/exceptions.md index f97f1af89..d376e1c13 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -104,20 +104,11 @@ it will be handled by the `handle_error` function, but at that point, the respon the response created by `handle_error` will be discarded. In case the error happens before the response was sent, then it will use the response object - in the above example, the returned `JSONResponse`. -In order to deal with this behaviour correctly, the middleware stack of a -`Starlette` application is configured like this: - -* `ServerErrorMiddleware` - Returns 500 responses when server errors occur. -* Installed middleware -* `ExceptionMiddleware` - Deals with handled exceptions, and returns responses. -* Router -* Endpoints - ## HTTPException The `HTTPException` class provides a base class that you can use for any -handled exceptions. The `ExceptionMiddleware` implementation defaults to -returning plain-text HTTP responses for any `HTTPException`. +handled exceptions. By default, a plain-text HTTP response is returned +for `HTTPException`s raised. * `HTTPException(status_code, detail=None, headers=None)` From 680143280dc8ccef04bf4a6d29110f639965b71b Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 7 Jun 2023 07:34:54 +0200 Subject: [PATCH 12/13] Add documentation back --- docs/exceptions.md | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/docs/exceptions.md b/docs/exceptions.md index d376e1c13..f97f1af89 100644 --- a/docs/exceptions.md +++ b/docs/exceptions.md @@ -104,11 +104,20 @@ it will be handled by the `handle_error` function, but at that point, the respon the response created by `handle_error` will be discarded. In case the error happens before the response was sent, then it will use the response object - in the above example, the returned `JSONResponse`. +In order to deal with this behaviour correctly, the middleware stack of a +`Starlette` application is configured like this: + +* `ServerErrorMiddleware` - Returns 500 responses when server errors occur. +* Installed middleware +* `ExceptionMiddleware` - Deals with handled exceptions, and returns responses. +* Router +* Endpoints + ## HTTPException The `HTTPException` class provides a base class that you can use for any -handled exceptions. By default, a plain-text HTTP response is returned -for `HTTPException`s raised. +handled exceptions. The `ExceptionMiddleware` implementation defaults to +returning plain-text HTTP responses for any `HTTPException`. * `HTTPException(status_code, detail=None, headers=None)` From 3fb111835755c65a4144603c552d9d06c54d80e2 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 7 Jun 2023 08:17:09 +0200 Subject: [PATCH 13/13] Move retrieval of handlers to `wrap_app_handling_exceptions` --- starlette/_exception_handler.py | 14 +++++++++----- starlette/middleware/exceptions.py | 4 +--- starlette/routing.py | 21 ++------------------- 3 files changed, 12 insertions(+), 27 deletions(-) diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index 2a25fab85..8a9beb3b2 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -23,11 +23,15 @@ def _lookup_exception_handler( def wrap_app_handling_exceptions( - app: ASGIApp, - exc_handlers: ExceptionHandlers, - status_handlers: StatusHandlers, - conn: typing.Union[Request, WebSocket], + app: ASGIApp, conn: typing.Union[Request, WebSocket] ) -> ASGIApp: + exception_handlers: ExceptionHandlers + status_handlers: StatusHandlers + try: + exception_handlers, status_handlers = conn.scope["starlette.exception_handlers"] + except KeyError: + exception_handlers, status_handlers = {}, {} + async def wrapped_app(scope: Scope, receive: Receive, send: Send) -> None: response_started = False @@ -47,7 +51,7 @@ async def sender(message: Message) -> None: handler = status_handlers.get(exc.status_code) if handler is None: - handler = _lookup_exception_handler(exc_handlers, exc) + handler = _lookup_exception_handler(exception_handlers, exc) if handler is None: raise exc diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index a0d77a0af..59010c7e6 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -59,9 +59,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: else: conn = WebSocket(scope, receive, send) - await wrap_app_handling_exceptions( - self.app, self._exception_handlers, self._status_handlers, conn - )(scope, receive, send) + await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send) def http_exception(self, request: Request, exc: Exception) -> Response: assert isinstance(exc, HTTPException) diff --git a/starlette/routing.py b/starlette/routing.py index b6ed09163..8e01c8562 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -71,14 +71,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: response = await run_in_threadpool(func, request) await response(scope, receive, send) - try: - exception_handlers, status_handlers = scope["starlette.exception_handlers"] - except KeyError: - exception_handlers, status_handlers = {}, {} - - await wrap_app_handling_exceptions( - app, exception_handlers, status_handlers, request - )(scope, receive, send) + await wrap_app_handling_exceptions(app, request)(scope, receive, send) return app @@ -95,17 +88,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: await func(session) - try: - exception_handlers, status_handlers = scope["starlette.exception_handlers"] - except KeyError: - exception_handlers, status_handlers = {}, {} - - await wrap_app_handling_exceptions( - app, - exception_handlers, - status_handlers, - session, - )(scope, receive, send) + await wrap_app_handling_exceptions(app, session)(scope, receive, send) return app