diff --git a/src/aks-preview/azext_aks_preview/_validators.py b/src/aks-preview/azext_aks_preview/_validators.py index 965069f855c..6006a8f3053 100644 --- a/src/aks-preview/azext_aks_preview/_validators.py +++ b/src/aks-preview/azext_aks_preview/_validators.py @@ -22,7 +22,6 @@ from ._consts import ADDONS - logger = get_logger(__name__) @@ -215,7 +214,7 @@ def validate_spot_max_price(namespace): if not isnan(namespace.spot_max_price): if namespace.priority != "Spot": raise CLIError("--spot_max_price can only be set when --priority is Spot") - if namespace.spot_max_price * 100000 % 1 != 0: + if namespace.spot_max_price > 0 and not isclose(namespace.spot_max_price * 100000 % 1, 0, rel_tol=1e-06): raise CLIError("--spot_max_price can only include up to 5 decimal places") if namespace.spot_max_price <= 0 and not isclose(namespace.spot_max_price, -1.0, rel_tol=1e-06): raise CLIError( diff --git a/src/aks-preview/azext_aks_preview/tests/latest/test_validators.py b/src/aks-preview/azext_aks_preview/tests/latest/test_validators.py index 4514aeb7f5a..07532408b83 100644 --- a/src/aks-preview/azext_aks_preview/tests/latest/test_validators.py +++ b/src/aks-preview/azext_aks_preview/tests/latest/test_validators.py @@ -154,6 +154,12 @@ def __init__(self, max_surge): self.max_surge = max_surge +class SpotMaxPriceNamespace: + def __init__(self, spot_max_price): + self.priority = "Spot" + self.spot_max_price = spot_max_price + + class TestMaxSurge(unittest.TestCase): def test_valid_cases(self): valid = ["5", "33%", "1", "100%"] @@ -171,6 +177,30 @@ def test_throws_on_negative(self): self.assertTrue('positive' in str(cm.exception), msg=str(cm.exception)) +class TestSpotMaxPrice(unittest.TestCase): + def test_valid_cases(self): + valid = [5, 5.12345, -1.0] + for v in valid: + validators.validate_spot_max_price(SpotMaxPriceNamespace(v)) + + def test_throws_if_more_than_5(self): + with self.assertRaises(CLIError) as cm: + validators.validate_spot_max_price(SpotMaxPriceNamespace(5.123456)) + self.assertTrue('--spot_max_price can only include up to 5 decimal places' in str(cm.exception), msg=str(cm.exception)) + + def test_throws_if_non_valid_negative(self): + with self.assertRaises(CLIError) as cm: + validators.validate_spot_max_price(SpotMaxPriceNamespace(-2)) + self.assertTrue('--spot_max_price can only be any decimal value greater than zero, or -1 which indicates' in str(cm.exception), msg=str(cm.exception)) + + def test_throws_if_input_max_price_for_regular(self): + ns = SpotMaxPriceNamespace(2) + ns.priority = "Regular" + with self.assertRaises(CLIError) as cm: + validators.validate_spot_max_price(ns) + self.assertTrue('--spot_max_price can only be set when --priority is Spot' in str(cm.exception), msg=str(cm.exception)) + + class ValidateAddonsNamespace: def __init__(self, addons): self.addons = addons