From 2502198ceaaa4383d1811b3482215dd772c46e68 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Tue, 25 Jul 2023 13:28:51 +0200 Subject: [PATCH 1/3] refactor(parameters): use newer add_to_cache, fetch_from_cache --- aws_lambda_powertools/utilities/parameters/base.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index 4357b5d520e..badc6fc1cea 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -129,7 +129,7 @@ def get( max_age = resolve_max_age(env=os.getenv(constants.PARAMETERS_MAX_AGE_ENV, DEFAULT_MAX_AGE_SECS), choice=max_age) if not force_fetch and self.has_not_expired_in_cache(key): - return self.store[key].value + return self.fetch_from_cache(key) try: value = self._get(name, **sdk_options) @@ -142,7 +142,7 @@ def get( # NOTE: don't cache None, as they might've been failed transforms and may be corrected if value is not None: - self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) + self.add_to_cache(key=key, value=value, max_age=max_age) return value @@ -197,7 +197,7 @@ def get_multiple( max_age = resolve_max_age(env=os.getenv(constants.PARAMETERS_MAX_AGE_ENV, DEFAULT_MAX_AGE_SECS), choice=max_age) if not force_fetch and self.has_not_expired_in_cache(key): - return self.store[key].value # type: ignore # need to revisit entire typing here + return self.fetch_from_cache(key) try: values = self._get_multiple(path, **sdk_options) @@ -208,7 +208,7 @@ def get_multiple( if transform: values.update(transform_value(values, transform, raise_on_transform_error)) - self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age)) + self.add_to_cache(key=key, value=values, max_age=max_age) return values @@ -222,6 +222,9 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: def clear_cache(self): self.store.clear() + def fetch_from_cache(self, key: Tuple[str, TransformOptions]): + return self.store[key].value if key in self.store else {} + def add_to_cache(self, key: Tuple[str, TransformOptions], value: Any, max_age: int): if max_age <= 0: return From aead4a231f87b26c596ddcfb30e2a523649286f1 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Tue, 25 Jul 2023 14:39:50 +0200 Subject: [PATCH 2/3] fix(parameters): make cache aware of single vs multiple calls Signed-off-by: heitorlessa --- .../utilities/parameters/base.py | 24 ++++++++++++------- .../utilities/parameters/types.py | 1 + tests/functional/test_utilities_parameters.py | 12 ++++++---- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index badc6fc1cea..b8be4d1acb6 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -27,7 +27,7 @@ from aws_lambda_powertools.shared import constants, user_agent from aws_lambda_powertools.shared.functions import resolve_max_age -from aws_lambda_powertools.utilities.parameters.types import TransformOptions +from aws_lambda_powertools.utilities.parameters.types import RecursiveOptions, TransformOptions from .exceptions import GetParameterError, TransformParameterError @@ -66,16 +66,16 @@ class BaseProvider(ABC): Abstract Base Class for Parameter providers """ - store: Dict[Tuple[str, TransformOptions], ExpirableValue] + store: Dict[Tuple, ExpirableValue] def __init__(self): """ Initialize the base provider """ - self.store: Dict[Tuple[str, TransformOptions], ExpirableValue] = {} + self.store: Dict[Tuple, ExpirableValue] = {} - def has_not_expired_in_cache(self, key: Tuple[str, TransformOptions]) -> bool: + def has_not_expired_in_cache(self, key: Tuple) -> bool: return key in self.store and self.store[key].ttl >= datetime.now() def get( @@ -123,7 +123,7 @@ def get( # parameter will always be used in a specific transform, this should be # an acceptable tradeoff. value: Optional[Union[str, bytes, dict]] = None - key = (name, transform) + key = self._build_cache_key(name=name, transform_options=transform, is_recursive=False) # If max_age is not set, resolve it from the environment variable, defaulting to DEFAULT_MAX_AGE_SECS max_age = resolve_max_age(env=os.getenv(constants.PARAMETERS_MAX_AGE_ENV, DEFAULT_MAX_AGE_SECS), choice=max_age) @@ -191,7 +191,7 @@ def get_multiple( TransformParameterError When the parameter provider fails to transform a parameter value. """ - key = (path, transform) + key = self._build_cache_key(name=path, transform_options=transform, is_recursive=True) # If max_age is not set, resolve it from the environment variable, defaulting to DEFAULT_MAX_AGE_SECS max_age = resolve_max_age(env=os.getenv(constants.PARAMETERS_MAX_AGE_ENV, DEFAULT_MAX_AGE_SECS), choice=max_age) @@ -222,15 +222,23 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: def clear_cache(self): self.store.clear() - def fetch_from_cache(self, key: Tuple[str, TransformOptions]): + def fetch_from_cache(self, key: Tuple): return self.store[key].value if key in self.store else {} - def add_to_cache(self, key: Tuple[str, TransformOptions], value: Any, max_age: int): + def add_to_cache(self, key: Tuple, value: Any, max_age: int): if max_age <= 0: return self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age)) + def _build_cache_key( + self, + name: str, + transform_options: TransformOptions = None, + is_recursive: RecursiveOptions = False, + ): + return (name, transform_options, is_recursive) + @staticmethod def _build_boto3_client( service_name: str, diff --git a/aws_lambda_powertools/utilities/parameters/types.py b/aws_lambda_powertools/utilities/parameters/types.py index 6a15873c496..2dbf1593d72 100644 --- a/aws_lambda_powertools/utilities/parameters/types.py +++ b/aws_lambda_powertools/utilities/parameters/types.py @@ -1,3 +1,4 @@ from typing_extensions import Literal TransformOptions = Literal["json", "binary", "auto", None] +RecursiveOptions = Literal[True, False] diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index f151c1cd781..8bcc30fc244 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -139,7 +139,8 @@ def test_dynamodb_provider_get_cached(mock_name, mock_value, config): provider = parameters.DynamoDBProvider(table_name, config=config) # Inject value in the internal store - provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60)) + cache_key = provider._build_cache_key(name=mock_name) + provider.add_to_cache(key=cache_key, value=mock_value, max_age=60) # Stub the boto3 client stubber = stub.Stubber(provider.table.meta.client) @@ -631,7 +632,8 @@ def test_ssm_provider_get_cached(mock_name, mock_value, config): provider = parameters.SSMProvider(config=config) # Inject value in the internal store - provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60)) + cache_key = provider._build_cache_key(name=mock_name) + provider.add_to_cache(key=cache_key, value=mock_value, max_age=60) # Stub the boto3 client stubber = stub.Stubber(provider.client) @@ -1332,7 +1334,8 @@ def test_secrets_provider_get_cached(mock_name, mock_value, config): provider = parameters.SecretsProvider(config=config) # Inject value in the internal store - provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60)) + cache_key = provider._build_cache_key(name=mock_name) + provider.add_to_cache(key=cache_key, value=mock_value, max_age=60) # Stub the boto3 client stubber = stub.Stubber(provider.client) @@ -1734,7 +1737,8 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: provider = TestProvider() - provider.store[(mock_name, None)] = ExpirableValue({"A": mock_value}, datetime.now() + timedelta(seconds=60)) + cache_key = provider._build_cache_key(name=mock_name, is_recursive=True) + provider.add_to_cache(key=cache_key, value={"A": mock_value}, max_age=60) value = provider.get_multiple(mock_name) From 275b418f2a4cdab6e14ea8ea56df913547c57d14 Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Tue, 25 Jul 2023 15:16:51 +0200 Subject: [PATCH 3/3] chore: cleanup, add test for single and nested Signed-off-by: heitorlessa --- .../utilities/parameters/base.py | 28 +++++++++++++++---- .../utilities/parameters/types.py | 1 - tests/functional/test_utilities_parameters.py | 24 +++++++++++++++- 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py index b8be4d1acb6..e4be9d33cdc 100644 --- a/aws_lambda_powertools/utilities/parameters/base.py +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -27,7 +27,7 @@ from aws_lambda_powertools.shared import constants, user_agent from aws_lambda_powertools.shared.functions import resolve_max_age -from aws_lambda_powertools.utilities.parameters.types import RecursiveOptions, TransformOptions +from aws_lambda_powertools.utilities.parameters.types import TransformOptions from .exceptions import GetParameterError, TransformParameterError @@ -123,7 +123,7 @@ def get( # parameter will always be used in a specific transform, this should be # an acceptable tradeoff. value: Optional[Union[str, bytes, dict]] = None - key = self._build_cache_key(name=name, transform_options=transform, is_recursive=False) + key = self._build_cache_key(name=name, transform=transform) # If max_age is not set, resolve it from the environment variable, defaulting to DEFAULT_MAX_AGE_SECS max_age = resolve_max_age(env=os.getenv(constants.PARAMETERS_MAX_AGE_ENV, DEFAULT_MAX_AGE_SECS), choice=max_age) @@ -191,7 +191,7 @@ def get_multiple( TransformParameterError When the parameter provider fails to transform a parameter value. """ - key = self._build_cache_key(name=path, transform_options=transform, is_recursive=True) + key = self._build_cache_key(name=path, transform=transform, is_nested=True) # If max_age is not set, resolve it from the environment variable, defaulting to DEFAULT_MAX_AGE_SECS max_age = resolve_max_age(env=os.getenv(constants.PARAMETERS_MAX_AGE_ENV, DEFAULT_MAX_AGE_SECS), choice=max_age) @@ -234,10 +234,26 @@ def add_to_cache(self, key: Tuple, value: Any, max_age: int): def _build_cache_key( self, name: str, - transform_options: TransformOptions = None, - is_recursive: RecursiveOptions = False, + transform: TransformOptions = None, + is_nested: bool = False, ): - return (name, transform_options, is_recursive) + """Creates cache key for parameters + + Parameters + ---------- + name : str + Name of parameter, secret or config + transform : TransformOptions, optional + Transform method used, by default None + is_nested : bool, optional + Whether it's a single parameter or multiple nested parameters, by default False + + Returns + ------- + Tuple[str, TransformOptions, bool] + Cache key + """ + return (name, transform, is_nested) @staticmethod def _build_boto3_client( diff --git a/aws_lambda_powertools/utilities/parameters/types.py b/aws_lambda_powertools/utilities/parameters/types.py index 2dbf1593d72..6a15873c496 100644 --- a/aws_lambda_powertools/utilities/parameters/types.py +++ b/aws_lambda_powertools/utilities/parameters/types.py @@ -1,4 +1,3 @@ from typing_extensions import Literal TransformOptions = Literal["json", "binary", "auto", None] -RecursiveOptions = Literal[True, False] diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py index 8bcc30fc244..7822ff80949 100644 --- a/tests/functional/test_utilities_parameters.py +++ b/tests/functional/test_utilities_parameters.py @@ -1737,7 +1737,7 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: provider = TestProvider() - cache_key = provider._build_cache_key(name=mock_name, is_recursive=True) + cache_key = provider._build_cache_key(name=mock_name, is_nested=True) provider.add_to_cache(key=cache_key, value={"A": mock_value}, max_age=60) value = provider.get_multiple(mock_name) @@ -2504,3 +2504,25 @@ def test_cache_ignores_max_age_zero_or_negative(mock_value, config): # THEN they should not be added to the cache assert len(provider.store) == 0 assert provider.has_not_expired_in_cache(cache_key) is False + + +def test_base_provider_single_and_nested_parameters_cached(mock_name, mock_value): + # GIVEN a custom provider + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise ValueError("This parameter doesn't exist") + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + return {"A": mock_value} + + provider = TestProvider() + + # WHEN get_multiple is followed by get with the same name + # (path vs single parameter name) + provider.get_multiple(mock_name) + + # THEN get should raise GetParameterError + # since a path will likely not be a valid parameter + # see #2438 + with pytest.raises(parameters.exceptions.GetParameterError): + provider.get(mock_name)