diff --git a/src/aks-preview/azext_aks_preview/decorator.py b/src/aks-preview/azext_aks_preview/decorator.py index 86b414f2e42..21695b2a8d7 100644 --- a/src/aks-preview/azext_aks_preview/decorator.py +++ b/src/aks-preview/azext_aks_preview/decorator.py @@ -13,11 +13,13 @@ AKSModels, AKSUpdateDecorator, safe_list_get, + safe_lower, ) from azure.cli.core import AzCommandsLoader from azure.cli.core.azclierror import ( CLIInternalError, InvalidArgumentValueError, + RequiredArgumentMissingError, ) from azure.cli.core.commands import AzCliCommand from azure.cli.core.profiles import ResourceType @@ -37,6 +39,7 @@ KubeletConfig = TypeVar("KubeletConfig") LinuxOSConfig = TypeVar("LinuxOSConfig") ManagedClusterHTTPProxyConfig = TypeVar("ManagedClusterHTTPProxyConfig") +ContainerServiceNetworkProfile = TypeVar("ContainerServiceNetworkProfile") # pylint: disable=too-many-instance-attributes,too-few-public-methods @@ -59,6 +62,11 @@ def __init__(self, cmd: AzCommandsLoader, resource_type: ResourceType): resource_type=self.resource_type, operation_group="managed_clusters", ) + self.ManagedClusterPodIdentityProfile = self.__cmd.get_models( + "ManagedClusterPodIdentityProfile", + resource_type=self.resource_type, + operation_group="managed_clusters", + ) # init nat gateway models self.init_nat_gateway_models() @@ -358,6 +366,136 @@ def get_nat_gateway_idle_timeout(self) -> Union[int, None]: # this parameter does not need validation return nat_gateway_idle_timeout + def get_enable_pod_security_policy(self) -> bool: + """Obtain the value of enable_pod_security_policy. + + :return: bool + """ + # read the original value passed by the command + enable_pod_security_policy = self.raw_param.get("enable_pod_security_policy") + # try to read the property value corresponding to the parameter from the `mc` object + if ( + self.mc and + self.mc.enable_pod_security_policy is not None + ): + enable_pod_security_policy = self.mc.enable_pod_security_policy + + # this parameter does not need dynamic completion + # this parameter does not need validation + return enable_pod_security_policy + + # pylint: disable=unused-argument + def _get_enable_managed_identity( + self, enable_validation: bool = False, read_only: bool = False, **kwargs + ) -> bool: + """Internal function to obtain the value of enable_pod_identity. + + Inherited and extended to perform additional validation. + + This function supports the option of enable_validation. When enabled, if enable_managed_identity is not + specified but enable_pod_identity is, raise a RequiredArgumentMissingError. + + :return: bool + """ + enable_managed_identity = super()._get_enable_managed_identity(enable_validation, read_only, **kwargs) + # additional validation + if enable_validation: + if not enable_managed_identity and self._get_enable_pod_identity(enable_validation=False): + raise RequiredArgumentMissingError( + "--enable-pod-identity can only be specified when --enable-managed-identity is specified" + ) + return enable_managed_identity + + # pylint: disable=unused-argument + def _get_enable_pod_identity(self, enable_validation: bool = False, **kwargs) -> bool: + """Internal function to obtain the value of enable_pod_identity. + + This function supports the option of enable_validation. When enabled, if enable_managed_identity is not + specified but enable_pod_identity is, raise a RequiredArgumentMissingError. If network_profile has been set + up in `mc`, network_plugin equals to "kubenet" and enable_pod_identity is specified but + enable_pod_identity_with_kubenet is not, raise a RequiredArgumentMissingError. + + :return: bool + """ + # read the original value passed by the command + enable_pod_identity = self.raw_param.get("enable_pod_identity") + # try to read the property value corresponding to the parameter from the `mc` object + if ( + self.mc and + self.mc.pod_identity_profile and + self.mc.pod_identity_profile.enabled is not None + ): + enable_pod_identity = self.mc.pod_identity_profile.enabled + + # this parameter does not need dynamic completion + # validation + if enable_validation: + if enable_pod_identity and not self._get_enable_managed_identity(enable_validation=False): + raise RequiredArgumentMissingError( + "--enable-pod-identity can only be specified when --enable-managed-identity is specified" + ) + if self.mc and self.mc.network_profile and safe_lower(self.mc.network_profile.network_plugin) == "kubenet": + if enable_pod_identity and not self._get_enable_pod_identity_with_kubenet(enable_validation=False): + raise RequiredArgumentMissingError( + "--enable-pod-identity-with-kubenet is required for enabling pod identity addon " + "when using Kubenet network plugin" + ) + return enable_pod_identity + + def get_enable_pod_identity(self) -> bool: + """Obtain the value of enable_pod_identity. + + This function will verify the parameter by default. If enable_managed_identity is not specified but + enable_pod_identity is, raise a RequiredArgumentMissingError. If network_profile has been set up in `mc`, + network_plugin equals to "kubenet" and enable_pod_identity is specified but enable_pod_identity_with_kubenet + is not, raise a RequiredArgumentMissingError. + + :return: bool + """ + + return self._get_enable_pod_identity(enable_validation=True) + + def _get_enable_pod_identity_with_kubenet(self, enable_validation: bool = False, **kwargs) -> bool: + """Internal function to obtain the value of enable_pod_identity_with_kubenet. + + This function supports the option of enable_validation. When enabled, if network_profile has been set up in + `mc`, network_plugin equals to "kubenet" and enable_pod_identity is specified but + enable_pod_identity_with_kubenet is not, raise a RequiredArgumentMissingError. + + :return: bool + """ + # read the original value passed by the command + enable_pod_identity_with_kubenet = self.raw_param.get("enable_pod_identity_with_kubenet") + # try to read the property value corresponding to the parameter from the `mc` object + if ( + self.mc and + self.mc.pod_identity_profile and + self.mc.pod_identity_profile.allow_network_plugin_kubenet is not None + ): + enable_pod_identity_with_kubenet = self.mc.pod_identity_profile.allow_network_plugin_kubenet + + # this parameter does not need dynamic completion + # validation + if enable_validation: + if self.mc and self.mc.network_profile and safe_lower(self.mc.network_profile.network_plugin) == "kubenet": + if not enable_pod_identity_with_kubenet and self._get_enable_pod_identity(enable_validation=False): + raise RequiredArgumentMissingError( + "--enable-pod-identity-with-kubenet is required for enabling pod identity addon " + "when using Kubenet network plugin" + ) + return enable_pod_identity_with_kubenet + + def get_enable_pod_identity_with_kubenet(self) -> bool: + """Obtain the value of enable_pod_identity_with_kubenet. + + This function will verify the parameter by default. If network_profile has been set up in `mc`, network_plugin + equals to "kubenet" and enable_pod_identity is specified but enable_pod_identity_with_kubenet is not, raise a + RequiredArgumentMissingError. + + :return: bool + """ + return self._get_enable_pod_identity_with_kubenet(enable_validation=True) + class AKSPreviewCreateDecorator(AKSCreateDecorator): # pylint: disable=super-init-not-called @@ -461,6 +599,42 @@ def set_up_network_profile(self, mc: ManagedCluster) -> ManagedCluster: mc.network_profile = network_profile return mc + def set_up_pod_security_policy(self, mc: ManagedCluster) -> ManagedCluster: + """Set up pod security policy for the ManagedCluster object. + + :return: the ManagedCluster object + """ + if not isinstance(mc, self.models.ManagedCluster): + raise CLIInternalError( + "Unexpected mc object with type '{}'.".format(type(mc)) + ) + + mc.enable_pod_security_policy = self.context.get_enable_pod_security_policy() + return mc + + def set_up_pod_identity_profile(self, mc: ManagedCluster) -> ManagedCluster: + """Set up pod identity profile for the ManagedCluster object. + + This profile depends on network profile. + + :return: the ManagedCluster object + """ + if not isinstance(mc, self.models.ManagedCluster): + raise CLIInternalError( + "Unexpected mc object with type '{}'.".format(type(mc)) + ) + + pod_identity_profile = None + enable_pod_identity = self.context.get_enable_pod_identity() + enable_pod_identity_with_kubenet = self.context.get_enable_pod_identity_with_kubenet() + if enable_pod_identity: + pod_identity_profile = self.models.ManagedClusterPodIdentityProfile( + enabled=True, + allow_network_plugin_kubenet=enable_pod_identity_with_kubenet, + ) + mc.pod_identity_profile = pod_identity_profile + return mc + def construct_preview_mc_profile(self) -> ManagedCluster: """The overall controller used to construct the preview ManagedCluster profile. @@ -475,6 +649,10 @@ def construct_preview_mc_profile(self) -> ManagedCluster: mc = self.set_up_http_proxy_config(mc) # set up node resource group mc = self.set_up_node_resource_group(mc) + # set up pod security policy + mc = self.set_up_pod_security_policy(mc) + # set up pod identity profile + mc = self.set_up_pod_identity_profile(mc) return mc diff --git a/src/aks-preview/azext_aks_preview/tests/latest/test_decorator.py b/src/aks-preview/azext_aks_preview/tests/latest/test_decorator.py index c64c46fe1ec..1f24ba88a2c 100644 --- a/src/aks-preview/azext_aks_preview/tests/latest/test_decorator.py +++ b/src/aks-preview/azext_aks_preview/tests/latest/test_decorator.py @@ -346,6 +346,124 @@ def test_get_nat_gateway_idle_timeout(self): ctx_1.attach_mc(mc) self.assertEqual(ctx_1.get_nat_gateway_idle_timeout(), 20) + def test_get_enable_pod_security_policy(self): + # default + ctx_1 = AKSPreviewContext( + self.cmd, + {"enable_pod_security_policy": False}, + self.models, + decorator_mode=DecoratorMode.CREATE, + ) + self.assertEqual(ctx_1.get_enable_pod_security_policy(), False) + mc = self.models.ManagedCluster( + location="test_location", + enable_pod_security_policy=True, + ) + ctx_1.attach_mc(mc) + self.assertEqual(ctx_1.get_enable_pod_security_policy(), True) + + def test_get_enable_managed_identity(self): + # custom value + ctx_1 = AKSPreviewContext( + self.cmd, + {"enable_managed_identity": False, "enable_pod_identity": True}, + self.models, + decorator_mode=DecoratorMode.CREATE, + ) + with self.assertRaises(RequiredArgumentMissingError): + self.assertEqual(ctx_1.get_enable_managed_identity(), False) + + def test_get_enable_pod_identity(self): + # default + ctx_1 = AKSPreviewContext( + self.cmd, + {"enable_pod_identity": False}, + self.models, + decorator_mode=DecoratorMode.CREATE, + ) + self.assertEqual(ctx_1.get_enable_pod_identity(), False) + pod_identity_profile = self.models.ManagedClusterPodIdentityProfile( + enabled=True + ) + mc = self.models.ManagedCluster( + location="test_location", + pod_identity_profile=pod_identity_profile, + ) + ctx_1.attach_mc(mc) + # fail on enable_managed_identity not specified + with self.assertRaises(RequiredArgumentMissingError): + self.assertEqual(ctx_1.get_enable_pod_identity(), True) + + # custom value + ctx_2 = AKSPreviewContext( + self.cmd, + { + "enable_managed_identity": True, + "enable_pod_identity": True, + "enable_pod_identity_with_kubenet": False, + }, + self.models, + decorator_mode=DecoratorMode.CREATE, + ) + network_profile_2 = self.models.ContainerServiceNetworkProfile( + network_plugin="kubenet" + ) + mc_2 = self.models.ManagedCluster( + location="test_location", + network_profile=network_profile_2, + ) + ctx_2.attach_mc(mc_2) + # fail on enable_pod_identity_with_kubenet not specified + with self.assertRaises(RequiredArgumentMissingError): + self.assertEqual(ctx_2.get_enable_pod_identity(), True) + + def test_get_enable_pod_identity_with_kubenet(self): + # default + ctx_1 = AKSPreviewContext( + self.cmd, + {"enable_pod_identity_with_kubenet": False}, + self.models, + decorator_mode=DecoratorMode.CREATE, + ) + self.assertEqual(ctx_1.get_enable_pod_identity_with_kubenet(), False) + pod_identity_profile = self.models.ManagedClusterPodIdentityProfile( + enabled=True, + allow_network_plugin_kubenet=True, + ) + mc = self.models.ManagedCluster( + location="test_location", + pod_identity_profile=pod_identity_profile, + ) + ctx_1.attach_mc(mc) + # fail on enable_managed_identity not specified + # with self.assertRaises(RequiredArgumentMissingError): + self.assertEqual(ctx_1.get_enable_pod_identity_with_kubenet(), True) + + # custom value + ctx_2 = AKSPreviewContext( + self.cmd, + { + "enable_managed_identity": True, + "enable_pod_identity": True, + "enable_pod_identity_with_kubenet": False, + }, + self.models, + decorator_mode=DecoratorMode.CREATE, + ) + network_profile_2 = self.models.ContainerServiceNetworkProfile( + network_plugin="kubenet" + ) + mc_2 = self.models.ManagedCluster( + location="test_location", + network_profile=network_profile_2, + ) + ctx_2.attach_mc(mc_2) + # fail on enable_pod_identity_with_kubenet not specified + with self.assertRaises(RequiredArgumentMissingError): + self.assertEqual( + ctx_2.get_enable_pod_identity_with_kubenet(), False + ) + class AKSPreviewCreateDecoratorTestCase(unittest.TestCase): def setUp(self): @@ -660,7 +778,9 @@ def test_set_up_network_profile(self): mc_2 = self.models.ManagedCluster(location="test_location") dec_mc_2 = dec_2.set_up_network_profile(mc_2) - nat_gateway_profile_2 = self.models.nat_gateway_models.get("ManagedClusterNATGatewayProfile")( + nat_gateway_profile_2 = self.models.nat_gateway_models.get( + "ManagedClusterNATGatewayProfile" + )( managed_outbound_ip_profile=self.models.nat_gateway_models.get( "ManagedClusterManagedOutboundIPProfile" )(count=10), @@ -681,6 +801,94 @@ def test_set_up_network_profile(self): ) self.assertEqual(dec_mc_2, ground_truth_mc_2) + def test_set_up_pod_security_policy(self): + # default value in `aks_create` + dec_1 = AKSPreviewCreateDecorator( + self.cmd, + self.client, + { + "enable_pod_security_policy": False, + }, + CUSTOM_MGMT_AKS_PREVIEW, + ) + mc_1 = self.models.ManagedCluster(location="test_location") + # fail on passing the wrong mc object + with self.assertRaises(CLIInternalError): + dec_1.set_up_pod_security_policy(None) + dec_mc_1 = dec_1.set_up_pod_security_policy(mc_1) + ground_truth_mc_1 = self.models.ManagedCluster( + location="test_location", enable_pod_security_policy=False + ) + self.assertEqual(dec_mc_1, ground_truth_mc_1) + + # custom value + dec_2 = AKSPreviewCreateDecorator( + self.cmd, + self.client, + {"enable_pod_security_policy": True}, + CUSTOM_MGMT_AKS_PREVIEW, + ) + mc_2 = self.models.ManagedCluster(location="test_location") + dec_mc_2 = dec_2.set_up_pod_security_policy(mc_2) + ground_truth_mc_2 = self.models.ManagedCluster( + location="test_location", + enable_pod_security_policy=True, + ) + self.assertEqual(dec_mc_2, ground_truth_mc_2) + + def test_set_up_pod_identity_profile(self): + # default value in `aks_create` + dec_1 = AKSPreviewCreateDecorator( + self.cmd, + self.client, + { + "enable_pod_identity": False, + "enable_pod_identity_with_kubenet": False, + }, + CUSTOM_MGMT_AKS_PREVIEW, + ) + mc_1 = self.models.ManagedCluster(location="test_location") + # fail on passing the wrong mc object + with self.assertRaises(CLIInternalError): + dec_1.set_up_pod_identity_profile(None) + dec_mc_1 = dec_1.set_up_pod_identity_profile(mc_1) + ground_truth_mc_1 = self.models.ManagedCluster(location="test_location") + self.assertEqual(dec_mc_1, ground_truth_mc_1) + + # custom value + dec_2 = AKSPreviewCreateDecorator( + self.cmd, + self.client, + { + "enable_managed_identity": True, + "enable_pod_identity": True, + "enable_pod_identity_with_kubenet": True, + }, + CUSTOM_MGMT_AKS_PREVIEW, + ) + network_profile_2 = self.models.ContainerServiceNetworkProfile( + network_plugin="kubenet" + ) + mc_2 = self.models.ManagedCluster( + location="test_location", network_profile=network_profile_2 + ) + dec_mc_2 = dec_2.set_up_pod_identity_profile(mc_2) + ground_truth_network_profile_2 = ( + self.models.ContainerServiceNetworkProfile(network_plugin="kubenet") + ) + ground_truth_pod_identity_profile_2 = ( + self.models.ManagedClusterPodIdentityProfile( + enabled=True, + allow_network_plugin_kubenet=True, + ) + ) + ground_truth_mc_2 = self.models.ManagedCluster( + location="test_location", + network_profile=ground_truth_network_profile_2, + pod_identity_profile=ground_truth_pod_identity_profile_2, + ) + self.assertEqual(dec_mc_2, ground_truth_mc_2) + def test_construct_preview_mc_profile(self): pass