Skip to content

Commit

Permalink
Add in test function for condition (#2801)
Browse files Browse the repository at this point in the history
* Add in test function for condition
  • Loading branch information
kddejong authored Jul 20, 2023
1 parent 0c62bee commit 86ffbad
Show file tree
Hide file tree
Showing 6 changed files with 210 additions and 4 deletions.
26 changes: 25 additions & 1 deletion src/cfnlint/conditions/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
"""
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union

from sympy import And, Not, Or, Symbol
from sympy.logic.boolalg import BooleanFunction
Expand Down Expand Up @@ -70,6 +70,14 @@ def build_cnf(
return params.get(self._fn_equals.hash)
return None

def _test(self, scenarios: Mapping[str, str]) -> bool:
if self._fn_equals:
return self._fn_equals.test(scenarios)
if self._condition:
# pylint: disable=W0212
return self._condition._test(scenarios)
return False


class ConditionList(Condition):
"""The generic class to represent any type of List condition
Expand Down Expand Up @@ -119,6 +127,10 @@ def build_cnf(self, params: Dict[str, Symbol]) -> BooleanFunction:

return And(*conditions)

def _test(self, scenarios: Mapping[str, str]) -> bool:
# pylint: disable=W0212
return all(condition._test(scenarios) for condition in self._conditions)


class ConditionNot(ConditionList):
"""Represents the logic specific to an Not Condition"""
Expand All @@ -139,6 +151,10 @@ def build_cnf(self, params: Dict[str, Symbol]) -> BooleanFunction:
"""
return Not(self._conditions[0].build_cnf(params))

def _test(self, scenarios: Mapping[str, str]) -> bool:
# pylint: disable=W0212
return not any(condition._test(scenarios) for condition in self._conditions)


class ConditionOr(ConditionList):
"""Represents the logic specific to an Or Condition"""
Expand All @@ -160,6 +176,10 @@ def build_cnf(self, params: Dict[str, Symbol]) -> BooleanFunction:
conditions.append(child.build_cnf(params))
return Or(*conditions)

def _test(self, scenarios: Mapping[str, str]) -> bool:
# pylint: disable=W0212
return any(condition._test(scenarios) for condition in self._conditions)


class ConditionUnnammed(Condition):
"""Represents an unnamed condition which is basically a nested Equals"""
Expand Down Expand Up @@ -205,3 +225,7 @@ def build_false_cnf(self, params: Dict[str, Symbol]) -> Any:
Any: A Not SymPy CNF clause
"""
return Not(self.build_true_cnf(params))

def test(self, scenarios: Mapping[str, str]) -> bool:
"""Test a condition based on a scenario"""
return self._test(scenarios)
4 changes: 4 additions & 0 deletions src/cfnlint/conditions/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def _init_parameters(self, cfn: Any) -> None:
if isinstance(allowed_value, (str, int, float, bool)):
self._parameters[param_hash].append(get_hash(str(allowed_value)))

def get(self, name: str, default: Any = None) -> ConditionNamed:
"""Return the conditions"""
return self._conditions.get(name, default)

def _build_cnf(
self, condition_names: List[str]
) -> Tuple[EncodedCNF, Dict[str, Any]]:
Expand Down
21 changes: 19 additions & 2 deletions src/cfnlint/conditions/equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
import json
import logging
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Mapping, Tuple, Union

from cfnlint.conditions._utils import get_hash

Expand All @@ -20,6 +20,8 @@ def __init__(self, value: dict):
self.hash: str = get_hash(value)

def __eq__(self, __o: Any):
if isinstance(__o, str):
return self.hash == __o
return self.hash == __o.hash


Expand All @@ -46,7 +48,8 @@ def __init__(self, equal: List[Union[str, dict]]) -> None:
elif isinstance(self._left, EqualParameter) and isinstance(
self._right, EqualParameter
):
self._is_static = self._left == self._right
if self._left == self._right:
self._is_static = True

self._is_region = (False, "")
if isinstance(self._left, EqualParameter):
Expand Down Expand Up @@ -116,3 +119,17 @@ def left(self):
@property
def right(self):
return self._right

def test(self, scenarios: Mapping[str, str]) -> bool:
"""Do an equals based on the provided scenario"""
if self._is_static in [True, False]:
return self._is_static
for scenario, value in scenarios.items():
if isinstance(self._left, EqualParameter):
if scenario == self._left:
return value == self._right
if isinstance(self._right, EqualParameter):
if scenario == self._right:
return value == self._left

raise ValueError("An appropriate scenario was not found")
54 changes: 53 additions & 1 deletion test/unit/module/conditions/test_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
import string
from unittest import TestCase

from cfnlint.conditions.condition import ConditionNot, ConditionUnnammed
from cfnlint.conditions._utils import get_hash
from cfnlint.conditions.condition import (
ConditionAnd,
ConditionNot,
ConditionOr,
ConditionUnnammed,
)
from cfnlint.decode import decode_str
from cfnlint.template import Template

Expand Down Expand Up @@ -33,3 +39,49 @@ def test_unnamed_condition(self):
"equals", # not a string
{},
)

def test_condition_test(self):
equals = {"Ref": "AWS::Region"}
h = get_hash(equals)
self.assertTrue(
ConditionNot(
[
{"Fn::Equals": [{"Ref": "AWS::Region"}, "us-east-1"]},
],
{},
)._test(
{h: "us-west-2"},
)
)
self.assertFalse(
ConditionNot(
[
{"Fn::Equals": [{"Ref": "AWS::Region"}, "us-east-1"]},
],
{},
)._test(
{h: "us-east-1"},
)
)
self.assertFalse(
ConditionAnd(
[
{"Fn::Equals": [{"Ref": "AWS::Region"}, "us-east-1"]},
{"Fn::Equals": [{"Ref": "AWS::Region"}, "us-west-2"]},
],
{},
)._test(
{h: "us-east-1"},
)
)
self.assertTrue(
ConditionOr(
[
{"Fn::Equals": ["us-east-1", {"Ref": "AWS::Region"}]},
{"Fn::Equals": ["us-west-2", {"Ref": "AWS::Region"}]},
],
{},
)._test(
{h: "us-east-1"},
)
)
34 changes: 34 additions & 0 deletions test/unit/module/conditions/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import string
from unittest import TestCase

from cfnlint.conditions._utils import get_hash
from cfnlint.decode import decode_str
from cfnlint.template import Template

Expand Down Expand Up @@ -229,3 +230,36 @@ def test_check_condition_region(self):
False,
],
)

def test_test_condition(self):
"""Get condition and test"""
template = decode_str(
"""
Parameters:
Environment:
Type: String
AllowedValues: ["prod", "dev", "stage"]
Conditions:
IsUsEast1: !Equals [!Ref AWS::Region, "us-east-1"]
IsUsWest2: !Equals ["us-west-2", !Ref AWS::Region]
IsProd: !Equals [!Ref Environment, "prod"]
IsUsEast1AndProd: !And [!Condition IsUsEast1, !Condition IsProd]
"""
)[0]

h_region = get_hash({"Ref": "AWS::Region"})
h_environment = get_hash({"Ref": "Environment"})

cfn = Template("", template)
self.assertTrue(cfn.conditions.get("IsUsEast1").test({h_region: "us-east-1"}))
self.assertFalse(cfn.conditions.get("IsProd").test({h_environment: "dev"}))
self.assertTrue(
cfn.conditions.get("IsUsEast1AndProd").test(
{h_region: "us-east-1", h_environment: "prod"}
)
)
self.assertFalse(
cfn.conditions.get("IsUsEast1AndProd").test(
{h_region: "us-east-1", h_environment: "dev"}
)
)
75 changes: 75 additions & 0 deletions test/unit/module/conditions/test_equals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: MIT-0
"""
from unittest import TestCase

from cfnlint.conditions._utils import get_hash
from cfnlint.conditions.equals import Equal
from cfnlint.decode import decode_str
from cfnlint.template import Template


class TestEquals(TestCase):
"""Test Equals"""

def setUp(self) -> None:
self.equals = {"Ref": "AWS::Region"}
self.equals_h = get_hash(self.equals)

self.parameter_region = {"Ref": "Region"}
self.parameter_region_h = get_hash(self.parameter_region)

self.parameter_env = {"Ref": "Environment"}
self.parameter_env_h = get_hash(self.parameter_env)

return super().setUp()

def test_equals_error_test(self):
"""Test equals scenarios condition"""
equal = Equal([self.equals, "us-east-1"])

with self.assertRaises(ValueError):
equal.test({"foo": "bar"})

def test_equals_left_test(self):
"""Test equals scenarios condition"""
equal = Equal([self.equals, "us-east-1"])

self.assertTrue(equal.test({self.equals_h: "us-east-1"}))
self.assertFalse(equal.test({self.equals_h: "us-west-2"}))

equal = Equal([self.equals, self.parameter_region])
self.assertTrue(
equal.test(
{
self.equals_h: self.parameter_region_h,
self.parameter_region_h: self.equals_h,
}
)
)
self.assertFalse(equal.test({self.equals_h: self.parameter_env_h}))

def test_equals_left_right(self):
"""Test equals scenarios condition"""
equal = Equal(["us-east-1", self.equals])

self.assertTrue(equal.test({self.equals_h: "us-east-1"}))
self.assertFalse(equal.test({self.equals_h: "us-west-2"}))

equal = Equal([self.parameter_region, self.equals])
self.assertTrue(
equal.test(
{
self.equals_h: self.parameter_region_h,
self.parameter_region_h: self.equals_h,
}
)
)
self.assertFalse(equal.test({self.equals_h: self.parameter_env_h}))

def test_equal_string_test(self):
"""Test equals scenarios condition"""
equal = Equal(["us-west-2", "us-east-1"])

self.assertFalse(equal.test({"foo": "bar"}))

0 comments on commit 86ffbad

Please sign in to comment.