|
27 | 27 | _SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt |
28 | 28 | # API GW/ALB decode non-safe URI chars; we must support them too |
29 | 29 | _UNSAFE_URI = "%<>\[\]{}|^" # noqa: W605 |
30 | | - |
31 | 30 | _NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)" |
32 | 31 |
|
33 | 32 |
|
@@ -435,7 +434,7 @@ def __init__( |
435 | 434 | self._proxy_type = proxy_type |
436 | 435 | self._routes: List[Route] = [] |
437 | 436 | self._route_keys: List[str] = [] |
438 | | - self._exception_handlers: Dict[Union[int, Type], Callable] = {} |
| 437 | + self._exception_handlers: Dict[Type, Callable] = {} |
439 | 438 | self._cors = cors |
440 | 439 | self._cors_enabled: bool = cors is not None |
441 | 440 | self._cors_methods: Set[str] = {"OPTIONS"} |
@@ -597,8 +596,7 @@ def _not_found(self, method: str) -> ResponseBuilder: |
597 | 596 | headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods)) |
598 | 597 | return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None)) |
599 | 598 |
|
600 | | - # Allow for custom exception handlers |
601 | | - handler = self._exception_handlers.get(404) |
| 599 | + handler = self._lookup_exception_handler(NotFoundError) |
602 | 600 | if handler: |
603 | 601 | return ResponseBuilder(handler(NotFoundError())) |
604 | 602 |
|
@@ -635,6 +633,40 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder: |
635 | 633 |
|
636 | 634 | raise |
637 | 635 |
|
| 636 | + def not_found(self, func: Callable): |
| 637 | + return self.exception_handler(NotFoundError)(func) |
| 638 | + |
| 639 | + def exception_handler(self, exc_class: Type[Exception]): |
| 640 | + def register_exception_handler(func: Callable): |
| 641 | + self._exception_handlers[exc_class] = func |
| 642 | + |
| 643 | + return register_exception_handler |
| 644 | + |
| 645 | + def _lookup_exception_handler(self, exp_type: Type) -> Optional[Callable]: |
| 646 | + # Use "Method Resolution Order" to allow for matching against a base class |
| 647 | + # of an exception |
| 648 | + for cls in exp_type.__mro__: |
| 649 | + if cls in self._exception_handlers: |
| 650 | + return self._exception_handlers[cls] |
| 651 | + return None |
| 652 | + |
| 653 | + def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]: |
| 654 | + handler = self._lookup_exception_handler(type(exp)) |
| 655 | + if handler: |
| 656 | + return ResponseBuilder(handler(exp), route) |
| 657 | + |
| 658 | + if isinstance(exp, ServiceError): |
| 659 | + return ResponseBuilder( |
| 660 | + Response( |
| 661 | + status_code=exp.status_code, |
| 662 | + content_type=content_types.APPLICATION_JSON, |
| 663 | + body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}), |
| 664 | + ), |
| 665 | + route, |
| 666 | + ) |
| 667 | + |
| 668 | + return None |
| 669 | + |
638 | 670 | def _to_response(self, result: Union[Dict, Response]) -> Response: |
639 | 671 | """Convert the route's result to a Response |
640 | 672 |
|
@@ -679,38 +711,6 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None |
679 | 711 |
|
680 | 712 | self.route(*route)(func) |
681 | 713 |
|
682 | | - def not_found(self, func: Callable): |
683 | | - return self.exception_handler(404)(func) |
684 | | - |
685 | | - def exception_handler(self, exc_class_or_status_code: Union[int, Type[Exception]]): |
686 | | - def register_exception_handler(func: Callable): |
687 | | - self._exception_handlers[exc_class_or_status_code] = func |
688 | | - |
689 | | - return register_exception_handler |
690 | | - |
691 | | - def _lookup_exception_handler(self, exp: Exception) -> Optional[Callable]: |
692 | | - for cls in type(exp).__mro__: |
693 | | - if cls in self._exception_handlers: |
694 | | - return self._exception_handlers[cls] |
695 | | - return None |
696 | | - |
697 | | - def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]: |
698 | | - handler = self._lookup_exception_handler(exp) |
699 | | - if handler: |
700 | | - return ResponseBuilder(handler(exp), route) |
701 | | - |
702 | | - if isinstance(exp, ServiceError): |
703 | | - return ResponseBuilder( |
704 | | - Response( |
705 | | - status_code=exp.status_code, |
706 | | - content_type=content_types.APPLICATION_JSON, |
707 | | - body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}), |
708 | | - ), |
709 | | - route, |
710 | | - ) |
711 | | - |
712 | | - return None |
713 | | - |
714 | 714 |
|
715 | 715 | class Router(BaseRouter): |
716 | 716 | """Router helper class to allow splitting ApiGatewayResolver into multiple files""" |
|
0 commit comments