1010from enum import Enum
1111from functools import partial
1212from http import HTTPStatus
13- from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
13+ from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Type , Union
1414
1515from aws_lambda_powertools .event_handler import content_types
16- from aws_lambda_powertools .event_handler .exceptions import ServiceError
16+ from aws_lambda_powertools .event_handler .exceptions import NotFoundError , ServiceError
1717from aws_lambda_powertools .shared import constants
1818from aws_lambda_powertools .shared .functions import resolve_truthy_env_var_choice
1919from aws_lambda_powertools .shared .json_encoder import Encoder
2727_SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt
2828# API GW/ALB decode non-safe URI chars; we must support them too
2929_UNSAFE_URI = "%<>\[\]{}|^" # noqa: W605
30-
3130_NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{ _SAFE_URI } { _UNSAFE_URI } \\w]+)"
3231
3332
@@ -435,6 +434,7 @@ def __init__(
435434 self ._proxy_type = proxy_type
436435 self ._routes : List [Route ] = []
437436 self ._route_keys : List [str ] = []
437+ self ._exception_handlers : Dict [Type , Callable ] = {}
438438 self ._cors = cors
439439 self ._cors_enabled : bool = cors is not None
440440 self ._cors_methods : Set [str ] = {"OPTIONS" }
@@ -596,6 +596,10 @@ def _not_found(self, method: str) -> ResponseBuilder:
596596 headers ["Access-Control-Allow-Methods" ] = "," .join (sorted (self ._cors_methods ))
597597 return ResponseBuilder (Response (status_code = 204 , content_type = None , headers = headers , body = None ))
598598
599+ handler = self ._lookup_exception_handler (NotFoundError )
600+ if handler :
601+ return ResponseBuilder (handler (NotFoundError ()))
602+
599603 return ResponseBuilder (
600604 Response (
601605 status_code = HTTPStatus .NOT_FOUND .value ,
@@ -609,16 +613,11 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
609613 """Actually call the matching route with any provided keyword arguments."""
610614 try :
611615 return ResponseBuilder (self ._to_response (route .func (** args )), route )
612- except ServiceError as e :
613- return ResponseBuilder (
614- Response (
615- status_code = e .status_code ,
616- content_type = content_types .APPLICATION_JSON ,
617- body = self ._json_dump ({"statusCode" : e .status_code , "message" : e .msg }),
618- ),
619- route ,
620- )
621- except Exception :
616+ except Exception as exc :
617+ response_builder = self ._call_exception_handler (exc , route )
618+ if response_builder :
619+ return response_builder
620+
622621 if self ._debug :
623622 # If the user has turned on debug mode,
624623 # we'll let the original exception propagate so
@@ -628,10 +627,46 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
628627 status_code = 500 ,
629628 content_type = content_types .TEXT_PLAIN ,
630629 body = "" .join (traceback .format_exc ()),
631- )
630+ ),
631+ route ,
632632 )
633+
633634 raise
634635
636+ def not_found (self , func : Callable ):
637+ return self .exception_handler (NotFoundError )(func )
638+
639+ def exception_handler (self , exc_class : Type [Exception ]):
640+ def register_exception_handler (func : Callable ):
641+ self ._exception_handlers [exc_class ] = func
642+
643+ return register_exception_handler
644+
645+ def _lookup_exception_handler (self , exp_type : Type ) -> Optional [Callable ]:
646+ # Use "Method Resolution Order" to allow for matching against a base class
647+ # of an exception
648+ for cls in exp_type .__mro__ :
649+ if cls in self ._exception_handlers :
650+ return self ._exception_handlers [cls ]
651+ return None
652+
653+ def _call_exception_handler (self , exp : Exception , route : Route ) -> Optional [ResponseBuilder ]:
654+ handler = self ._lookup_exception_handler (type (exp ))
655+ if handler :
656+ return ResponseBuilder (handler (exp ), route )
657+
658+ if isinstance (exp , ServiceError ):
659+ return ResponseBuilder (
660+ Response (
661+ status_code = exp .status_code ,
662+ content_type = content_types .APPLICATION_JSON ,
663+ body = self ._json_dump ({"statusCode" : exp .status_code , "message" : exp .msg }),
664+ ),
665+ route ,
666+ )
667+
668+ return None
669+
635670 def _to_response (self , result : Union [Dict , Response ]) -> Response :
636671 """Convert the route's result to a Response
637672
0 commit comments