77from  abc  import  ABC , abstractmethod 
88from  collections  import  namedtuple 
99from  datetime  import  datetime , timedelta 
10- from  typing  import  TYPE_CHECKING , Any , Dict , Optional , Tuple , Type , Union 
10+ from  typing  import  (
11+     TYPE_CHECKING ,
12+     Any ,
13+     Callable ,
14+     Dict ,
15+     Optional ,
16+     Tuple ,
17+     Type ,
18+     Union ,
19+     cast ,
20+     overload ,
21+ )
1122
1223import  boto3 
1324from  botocore .config  import  Config 
1425
26+ from  aws_lambda_powertools .utilities .parameters .types  import  TransformOptions 
27+ 
1528from  .exceptions  import  GetParameterError , TransformParameterError 
1629
1730if  TYPE_CHECKING :
3043SUPPORTED_TRANSFORM_METHODS  =  [TRANSFORM_METHOD_JSON , TRANSFORM_METHOD_BINARY ]
3144ParameterClients  =  Union ["AppConfigDataClient" , "SecretsManagerClient" , "SSMClient" ]
3245
46+ TRANSFORM_METHOD_MAPPING  =  {
47+     TRANSFORM_METHOD_JSON : json .loads ,
48+     TRANSFORM_METHOD_BINARY : base64 .b64decode ,
49+     ".json" : json .loads ,
50+     ".binary" : base64 .b64decode ,
51+     None : lambda  x : x ,
52+ }
53+ 
3354
3455class  BaseProvider (ABC ):
3556    """ 
@@ -52,7 +73,7 @@ def get(
5273        self ,
5374        name : str ,
5475        max_age : int  =  DEFAULT_MAX_AGE_SECS ,
55-         transform : Optional [ str ]  =  None ,
76+         transform : TransformOptions  =  None ,
5677        force_fetch : bool  =  False ,
5778        ** sdk_options ,
5879    ) ->  Optional [Union [str , dict , bytes ]]:
@@ -124,7 +145,7 @@ def get_multiple(
124145        self ,
125146        path : str ,
126147        max_age : int  =  DEFAULT_MAX_AGE_SECS ,
127-         transform : Optional [ str ]  =  None ,
148+         transform : TransformOptions  =  None ,
128149        raise_on_transform_error : bool  =  False ,
129150        force_fetch : bool  =  False ,
130151        ** sdk_options ,
@@ -170,13 +191,7 @@ def get_multiple(
170191            raise  GetParameterError (str (exc ))
171192
172193        if  transform :
173-             transformed_values : dict  =  {}
174-             for  (item , value ) in  values .items ():
175-                 _transform  =  get_transform_method (item , transform )
176-                 if  not  _transform :
177-                     continue 
178-                 transformed_values [item ] =  transform_value (value , _transform , raise_on_transform_error )
179-             values .update (transformed_values )
194+             values .update (transform_value (values , transform , raise_on_transform_error ))
180195        self .store [key ] =  ExpirableValue (values , datetime .now () +  timedelta (seconds = max_age ))
181196
182197        return  values 
@@ -258,7 +273,7 @@ def _build_boto3_resource_client(
258273        return  session .resource (service_name = service_name , config = config , endpoint_url = endpoint_url )
259274
260275
261- def  get_transform_method (key : str , transform : Optional [ str ]  =  None ) ->  Optional [ str ]:
276+ def  get_transform_method (key : str , transform : TransformOptions  =  None ) ->  Callable [...,  Any ]:
262277    """ 
263278    Determine the transform method 
264279
@@ -278,37 +293,50 @@ def get_transform_method(key: str, transform: Optional[str] = None) -> Optional[
278293    Parameters 
279294    --------- 
280295    key: str 
281-         Only used when the tranform  is "auto". 
296+         Only used when the transform  is "auto". 
282297    transform: str, optional 
283298        Original transform method, only "auto" will try to detect the transform method by the key 
284299
285300    Returns 
286301    ------ 
287-     Optional[str]: 
288-         The transform method either when transform is "auto" then None, "json" or "binary" is returned 
289-         or the original transform method 
302+     Callable: 
303+         Transform function could be json.loads, base64.b64decode, or a lambda that echo the str value 
290304    """ 
291-     if  transform  !=  "auto" :
292-         return  transform 
305+     transform_method  =  TRANSFORM_METHOD_MAPPING .get (transform )
306+ 
307+     if  transform  ==  "auto" :
308+         key_suffix  =  key .rsplit ("." )[- 1 ]
309+         transform_method  =  TRANSFORM_METHOD_MAPPING .get (key_suffix , TRANSFORM_METHOD_MAPPING [None ])
310+ 
311+     return  cast (Callable , transform_method )  # https://github.com/python/mypy/issues/10740 
312+ 
313+ 
314+ @overload  
315+ def  transform_value (
316+     value : Dict [str , Any ], transform : TransformOptions , raise_on_transform_error : bool  =  False 
317+ ) ->  Dict [str , Any ]:
318+     ...
293319
294-     for  transform_method  in  SUPPORTED_TRANSFORM_METHODS :
295-         if  key .endswith ("."  +  transform_method ):
296-             return  transform_method 
297-     return  None 
320+ 
321+ @overload  
322+ def  transform_value (
323+     value : Union [str , bytes , Dict [str , Any ]], transform : TransformOptions , raise_on_transform_error : bool  =  False 
324+ ) ->  Optional [Union [str , bytes , Dict [str , Any ]]]:
325+     ...
298326
299327
300328def  transform_value (
301-     value : str , transform : str , raise_on_transform_error : Optional [ bool ]  =  True 
302- ) ->  Optional [Union [dict , bytes ]]:
329+     value : Union [ str , bytes ,  Dict [ str ,  Any ]],  transform : TransformOptions , raise_on_transform_error : bool  =  False 
330+ ) ->  Optional [Union [str , bytes ,  Dict [ str ,  Any ] ]]:
303331    """ 
304-     Apply  a transform to a value  
332+     Transform  a value using one of the available options.  
305333
306334    Parameters 
307335    --------- 
308336    value: str 
309337        Parameter value to transform 
310338    transform: str 
311-         Type of transform, supported values are "json"  and "binary"  
339+         Type of transform, supported values are "json", "binary",  and "auto" based on suffix (.json, .binary)  
312340    raise_on_transform_error: bool, optional 
313341        Raises an exception if any transform fails, otherwise this will 
314342        return a None value for each transform that failed 
@@ -318,18 +346,35 @@ def transform_value(
318346    TransformParameterError: 
319347        When the parameter value could not be transformed 
320348    """ 
349+     # Maintenance: For v3, we should consider returning the original value for soft transform failures. 
321350
322-     try :
323-         if  transform  ==  TRANSFORM_METHOD_JSON :
324-             return  json .loads (value )
325-         elif  transform  ==  TRANSFORM_METHOD_BINARY :
326-             return  base64 .b64decode (value )
327-         else :
328-             raise  ValueError (f"Invalid transform type '{ transform }  )
351+     err_msg  =  "Unable to transform value using '{transform}' transform: {exc}" 
352+ 
353+     if  isinstance (value , bytes ):
354+         value  =  value .decode ("utf-8" )
329355
356+     if  isinstance (value , dict ):
357+         # NOTE: We must handle partial failures when receiving multiple values 
358+         # where one of the keys might fail during transform, e.g. `{"a": "valid", "b": "{"}` 
359+         # expected: `{"a": "valid", "b": None}` 
360+ 
361+         transformed_values : Dict [str , Any ] =  {}
362+         for  dict_key , dict_value  in  value .items ():
363+             transform_method  =  get_transform_method (key = dict_key , transform = transform )
364+             try :
365+                 transformed_values [dict_key ] =  transform_method (dict_value )
366+             except  Exception  as  exc :
367+                 if  raise_on_transform_error :
368+                     raise  TransformParameterError (err_msg .format (transform = transform , exc = exc )) from  exc 
369+                 transformed_values [dict_key ] =  None 
370+         return  transformed_values 
371+ 
372+     try :
373+         transform_method  =  get_transform_method (key = value , transform = transform )
374+         return  transform_method (value )
330375    except  Exception  as  exc :
331376        if  raise_on_transform_error :
332-             raise  TransformParameterError (str ( exc )) 
377+             raise  TransformParameterError (err_msg . format ( transform = transform ,  exc = exc ))  from   exc 
333378        return  None 
334379
335380
0 commit comments