Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(parameters): distinct cache key for single vs path with same name #2839

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions aws_lambda_powertools/utilities/parameters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,16 @@ class BaseProvider(ABC):
Abstract Base Class for Parameter providers
"""

store: Dict[Tuple[str, TransformOptions], ExpirableValue]
store: Dict[Tuple, ExpirableValue]

def __init__(self):
"""
Initialize the base provider
"""

self.store: Dict[Tuple[str, TransformOptions], ExpirableValue] = {}
self.store: Dict[Tuple, ExpirableValue] = {}

def has_not_expired_in_cache(self, key: Tuple[str, TransformOptions]) -> bool:
def has_not_expired_in_cache(self, key: Tuple) -> bool:
return key in self.store and self.store[key].ttl >= datetime.now()

def get(
Expand Down Expand Up @@ -123,13 +123,13 @@ def get(
# parameter will always be used in a specific transform, this should be
# an acceptable tradeoff.
value: Optional[Union[str, bytes, dict]] = None
key = (name, transform)
key = self._build_cache_key(name=name, transform=transform)

# If max_age is not set, resolve it from the environment variable, defaulting to DEFAULT_MAX_AGE_SECS
max_age = resolve_max_age(env=os.getenv(constants.PARAMETERS_MAX_AGE_ENV, DEFAULT_MAX_AGE_SECS), choice=max_age)

if not force_fetch and self.has_not_expired_in_cache(key):
return self.store[key].value
return self.fetch_from_cache(key)

try:
value = self._get(name, **sdk_options)
Expand All @@ -142,7 +142,7 @@ def get(

# NOTE: don't cache None, as they might've been failed transforms and may be corrected
if value is not None:
self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age))
self.add_to_cache(key=key, value=value, max_age=max_age)

return value

Expand Down Expand Up @@ -191,13 +191,13 @@ def get_multiple(
TransformParameterError
When the parameter provider fails to transform a parameter value.
"""
key = (path, transform)
key = self._build_cache_key(name=path, transform=transform, is_nested=True)

# If max_age is not set, resolve it from the environment variable, defaulting to DEFAULT_MAX_AGE_SECS
max_age = resolve_max_age(env=os.getenv(constants.PARAMETERS_MAX_AGE_ENV, DEFAULT_MAX_AGE_SECS), choice=max_age)

if not force_fetch and self.has_not_expired_in_cache(key):
return self.store[key].value # type: ignore # need to revisit entire typing here
return self.fetch_from_cache(key)

try:
values = self._get_multiple(path, **sdk_options)
Expand All @@ -208,7 +208,7 @@ def get_multiple(
if transform:
values.update(transform_value(values, transform, raise_on_transform_error))

self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age))
self.add_to_cache(key=key, value=values, max_age=max_age)

return values

Expand All @@ -222,12 +222,39 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
def clear_cache(self):
self.store.clear()

def add_to_cache(self, key: Tuple[str, TransformOptions], value: Any, max_age: int):
def fetch_from_cache(self, key: Tuple):
return self.store[key].value if key in self.store else {}

def add_to_cache(self, key: Tuple, value: Any, max_age: int):
if max_age <= 0:
return

self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age))

def _build_cache_key(
self,
name: str,
transform: TransformOptions = None,
is_nested: bool = False,
):
"""Creates cache key for parameters

Parameters
----------
name : str
Name of parameter, secret or config
transform : TransformOptions, optional
Transform method used, by default None
is_nested : bool, optional
Whether it's a single parameter or multiple nested parameters, by default False

Returns
-------
Tuple[str, TransformOptions, bool]
Cache key
"""
return (name, transform, is_nested)

@staticmethod
def _build_boto3_client(
service_name: str,
Expand Down
34 changes: 30 additions & 4 deletions tests/functional/test_utilities_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def test_dynamodb_provider_get_cached(mock_name, mock_value, config):
provider = parameters.DynamoDBProvider(table_name, config=config)

# Inject value in the internal store
provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60))
cache_key = provider._build_cache_key(name=mock_name)
provider.add_to_cache(key=cache_key, value=mock_value, max_age=60)

# Stub the boto3 client
stubber = stub.Stubber(provider.table.meta.client)
Expand Down Expand Up @@ -631,7 +632,8 @@ def test_ssm_provider_get_cached(mock_name, mock_value, config):
provider = parameters.SSMProvider(config=config)

# Inject value in the internal store
provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60))
cache_key = provider._build_cache_key(name=mock_name)
provider.add_to_cache(key=cache_key, value=mock_value, max_age=60)

# Stub the boto3 client
stubber = stub.Stubber(provider.client)
Expand Down Expand Up @@ -1332,7 +1334,8 @@ def test_secrets_provider_get_cached(mock_name, mock_value, config):
provider = parameters.SecretsProvider(config=config)

# Inject value in the internal store
provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60))
cache_key = provider._build_cache_key(name=mock_name)
provider.add_to_cache(key=cache_key, value=mock_value, max_age=60)

# Stub the boto3 client
stubber = stub.Stubber(provider.client)
Expand Down Expand Up @@ -1734,7 +1737,8 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:

provider = TestProvider()

provider.store[(mock_name, None)] = ExpirableValue({"A": mock_value}, datetime.now() + timedelta(seconds=60))
cache_key = provider._build_cache_key(name=mock_name, is_nested=True)
provider.add_to_cache(key=cache_key, value={"A": mock_value}, max_age=60)

value = provider.get_multiple(mock_name)

Expand Down Expand Up @@ -2500,3 +2504,25 @@ def test_cache_ignores_max_age_zero_or_negative(mock_value, config):
# THEN they should not be added to the cache
assert len(provider.store) == 0
assert provider.has_not_expired_in_cache(cache_key) is False


def test_base_provider_single_and_nested_parameters_cached(mock_name, mock_value):
# GIVEN a custom provider
class TestProvider(BaseProvider):
def _get(self, name: str, **kwargs) -> str:
raise ValueError("This parameter doesn't exist")

def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
return {"A": mock_value}

provider = TestProvider()

# WHEN get_multiple is followed by get with the same name
# (path vs single parameter name)
provider.get_multiple(mock_name)

# THEN get should raise GetParameterError
# since a path will likely not be a valid parameter
# see #2438
with pytest.raises(parameters.exceptions.GetParameterError):
provider.get(mock_name)