|  | 
| 1 | 1 | import logging | 
|  | 2 | +from abc import ABC | 
| 2 | 3 | from typing import Any, Callable, Optional, Type, TypeVar | 
| 3 | 4 | 
 | 
| 4 | 5 | from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent | 
|  | 
| 9 | 10 | AppSyncResolverEventT = TypeVar("AppSyncResolverEventT", bound=AppSyncResolverEvent) | 
| 10 | 11 | 
 | 
| 11 | 12 | 
 | 
| 12 |  | -class AppSyncResolver: | 
|  | 13 | +class BaseRouter(ABC): | 
|  | 14 | +    current_event: AppSyncResolverEventT  # type: ignore[valid-type] | 
|  | 15 | +    lambda_context: LambdaContext | 
|  | 16 | + | 
|  | 17 | +    def __init__(self): | 
|  | 18 | +        self._resolvers: dict = {} | 
|  | 19 | + | 
|  | 20 | +    def resolver(self, type_name: str = "*", field_name: Optional[str] = None): | 
|  | 21 | +        """Registers the resolver for field_name | 
|  | 22 | +
 | 
|  | 23 | +        Parameters | 
|  | 24 | +        ---------- | 
|  | 25 | +        type_name : str | 
|  | 26 | +            Type name | 
|  | 27 | +        field_name : str | 
|  | 28 | +            Field name | 
|  | 29 | +        """ | 
|  | 30 | + | 
|  | 31 | +        def register_resolver(func): | 
|  | 32 | +            logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`") | 
|  | 33 | +            self._resolvers[f"{type_name}.{field_name}"] = {"func": func} | 
|  | 34 | +            return func | 
|  | 35 | + | 
|  | 36 | +        return register_resolver | 
|  | 37 | + | 
|  | 38 | + | 
|  | 39 | +class AppSyncResolver(BaseRouter): | 
| 13 | 40 |     """ | 
| 14 | 41 |     AppSync resolver decorator | 
| 15 | 42 | 
 | 
| @@ -40,29 +67,8 @@ def common_field() -> str: | 
| 40 | 67 |             return str(uuid.uuid4()) | 
| 41 | 68 |     """ | 
| 42 | 69 | 
 | 
| 43 |  | -    current_event: AppSyncResolverEventT  # type: ignore[valid-type] | 
| 44 |  | -    lambda_context: LambdaContext | 
| 45 |  | - | 
| 46 | 70 |     def __init__(self): | 
| 47 |  | -        self._resolvers: dict = {} | 
| 48 |  | - | 
| 49 |  | -    def resolver(self, type_name: str = "*", field_name: Optional[str] = None): | 
| 50 |  | -        """Registers the resolver for field_name | 
| 51 |  | -
 | 
| 52 |  | -        Parameters | 
| 53 |  | -        ---------- | 
| 54 |  | -        type_name : str | 
| 55 |  | -            Type name | 
| 56 |  | -        field_name : str | 
| 57 |  | -            Field name | 
| 58 |  | -        """ | 
| 59 |  | - | 
| 60 |  | -        def register_resolver(func): | 
| 61 |  | -            logger.debug(f"Adding resolver `{func.__name__}` for field `{type_name}.{field_name}`") | 
| 62 |  | -            self._resolvers[f"{type_name}.{field_name}"] = {"func": func} | 
| 63 |  | -            return func | 
| 64 |  | - | 
| 65 |  | -        return register_resolver | 
|  | 71 | +        super().__init__() | 
| 66 | 72 | 
 | 
| 67 | 73 |     def resolve( | 
| 68 | 74 |         self, event: dict, context: LambdaContext, data_model: Type[AppSyncResolverEvent] = AppSyncResolverEvent | 
| @@ -136,10 +142,10 @@ def lambda_handler(event, context): | 
| 136 | 142 |         ValueError | 
| 137 | 143 |             If we could not find a field resolver | 
| 138 | 144 |         """ | 
| 139 |  | -        self.current_event = data_model(event) | 
| 140 |  | -        self.lambda_context = context | 
| 141 |  | -        resolver = self._get_resolver(self.current_event.type_name, self.current_event.field_name) | 
| 142 |  | -        return resolver(**self.current_event.arguments) | 
|  | 145 | +        BaseRouter.current_event = data_model(event) | 
|  | 146 | +        BaseRouter.lambda_context = context | 
|  | 147 | +        resolver = self._get_resolver(BaseRouter.current_event.type_name, BaseRouter.current_event.field_name) | 
|  | 148 | +        return resolver(**BaseRouter.current_event.arguments) | 
| 143 | 149 | 
 | 
| 144 | 150 |     def _get_resolver(self, type_name: str, field_name: str) -> Callable: | 
| 145 | 151 |         """Get resolver for field_name | 
| @@ -167,3 +173,18 @@ def __call__( | 
| 167 | 173 |     ) -> Any: | 
| 168 | 174 |         """Implicit lambda handler which internally calls `resolve`""" | 
| 169 | 175 |         return self.resolve(event, context, data_model) | 
|  | 176 | + | 
|  | 177 | +    def include_router(self, router: "Router") -> None: | 
|  | 178 | +        """Adds all resolvers defined in a router | 
|  | 179 | +
 | 
|  | 180 | +        Parameters | 
|  | 181 | +        ---------- | 
|  | 182 | +        router : Router | 
|  | 183 | +            A router containing a dict of field resolvers | 
|  | 184 | +        """ | 
|  | 185 | +        self._resolvers.update(router._resolvers) | 
|  | 186 | + | 
|  | 187 | + | 
|  | 188 | +class Router(BaseRouter): | 
|  | 189 | +    def __init__(self): | 
|  | 190 | +        super().__init__() | 
0 commit comments