diff --git a/src/aks-preview/azext_aks_preview/decorator.py b/src/aks-preview/azext_aks_preview/decorator.py index d06a7c36685..b69a7596eb0 100644 --- a/src/aks-preview/azext_aks_preview/decorator.py +++ b/src/aks-preview/azext_aks_preview/decorator.py @@ -25,6 +25,7 @@ InvalidArgumentValueError, MutuallyExclusiveArgumentError, RequiredArgumentMissingError, + UnknownError, ) from azure.cli.core.commands import AzCliCommand from azure.cli.core.profiles import ResourceType @@ -1014,15 +1015,16 @@ def _get_enable_windows_gmsa(self, enable_validation: bool = False, **kwargs) -> """ # read the original value passed by the command enable_windows_gmsa = self.raw_param.get("enable_windows_gmsa") - # try to read the property value corresponding to the parameter from the `mc` object - if ( - self.mc and - self.mc.windows_profile and - hasattr(self.mc.windows_profile, "gmsa_profile") and # backward compatibility - self.mc.windows_profile.gmsa_profile and - self.mc.windows_profile.gmsa_profile.enabled is not None - ): - enable_windows_gmsa = self.mc.windows_profile.gmsa_profile.enabled + # In create mode, try to read the property value corresponding to the parameter from the `mc` object. + if self.decorator_mode == DecoratorMode.CREATE: + if ( + self.mc and + self.mc.windows_profile and + hasattr(self.mc.windows_profile, "gmsa_profile") and # backward compatibility + self.mc.windows_profile.gmsa_profile and + self.mc.windows_profile.gmsa_profile.enabled is not None + ): + enable_windows_gmsa = self.mc.windows_profile.gmsa_profile.enabled # this parameter does not need dynamic completion # validation @@ -1064,32 +1066,34 @@ def _get_gmsa_dns_server_and_root_domain_name(self, enable_validation: bool = Fa # gmsa_dns_server # read the original value passed by the command gmsa_dns_server = self.raw_param.get("gmsa_dns_server") - # try to read the property value corresponding to the parameter from the `mc` object + # In create mode, try to read the property value corresponding to the parameter from the `mc` object. gmsa_dns_read_from_mc = False - if ( - self.mc and - self.mc.windows_profile and - hasattr(self.mc.windows_profile, "gmsa_profile") and # backward compatibility - self.mc.windows_profile.gmsa_profile and - self.mc.windows_profile.gmsa_profile.dns_server is not None - ): - gmsa_dns_server = self.mc.windows_profile.gmsa_profile.dns_server - gmsa_dns_read_from_mc = True + if self.decorator_mode == DecoratorMode.CREATE: + if ( + self.mc and + self.mc.windows_profile and + hasattr(self.mc.windows_profile, "gmsa_profile") and # backward compatibility + self.mc.windows_profile.gmsa_profile and + self.mc.windows_profile.gmsa_profile.dns_server is not None + ): + gmsa_dns_server = self.mc.windows_profile.gmsa_profile.dns_server + gmsa_dns_read_from_mc = True # gmsa_root_domain_name # read the original value passed by the command gmsa_root_domain_name = self.raw_param.get("gmsa_root_domain_name") - # try to read the property value corresponding to the parameter from the `mc` object + # In create mode, try to read the property value corresponding to the parameter from the `mc` object. gmsa_root_read_from_mc = False - if ( - self.mc and - self.mc.windows_profile and - hasattr(self.mc.windows_profile, "gmsa_profile") and # backward compatibility - self.mc.windows_profile.gmsa_profile and - self.mc.windows_profile.gmsa_profile.root_domain_name is not None - ): - gmsa_root_domain_name = self.mc.windows_profile.gmsa_profile.root_domain_name - gmsa_root_read_from_mc = True + if self.decorator_mode == DecoratorMode.CREATE: + if ( + self.mc and + self.mc.windows_profile and + hasattr(self.mc.windows_profile, "gmsa_profile") and # backward compatibility + self.mc.windows_profile.gmsa_profile and + self.mc.windows_profile.gmsa_profile.root_domain_name is not None + ): + gmsa_root_domain_name = self.mc.windows_profile.gmsa_profile.root_domain_name + gmsa_root_read_from_mc = True # consistent check if gmsa_dns_read_from_mc != gmsa_root_read_from_mc: @@ -1718,6 +1722,8 @@ def __init__( def update_load_balancer_profile(self, mc: ManagedCluster) -> ManagedCluster: """Update load balancer profile for the ManagedCluster object. + Note: Inherited and extended in aks-preview to set dual stack related properties. + :return: the ManagedCluster object """ mc = super().update_load_balancer_profile(mc) @@ -1750,6 +1756,30 @@ def update_pod_security_policy(self, mc: ManagedCluster) -> ManagedCluster: mc.enable_pod_security_policy = False return mc + def update_windows_profile(self, mc: ManagedCluster) -> ManagedCluster: + """Update windows profile for the ManagedCluster object. + + Note: Inherited and extended in aks-preview to set gmsa related properties. + + :return: the ManagedCluster object + """ + mc = super().update_windows_profile(mc) + windows_profile = mc.windows_profile + + if self.context.get_enable_windows_gmsa(): + if not windows_profile: + raise UnknownError( + "Encounter an unexpected error while getting windows profile " + "from the cluster in the process of update." + ) + gmsa_dns_server, gmsa_root_domain_name = self.context.get_gmsa_dns_server_and_root_domain_name() + windows_profile.gmsa_profile = self.models.WindowsGmsaProfile( + enabled=True, + dns_server=gmsa_dns_server, + root_domain_name=gmsa_root_domain_name, + ) + return mc + def update_mc_preview_profile(self) -> ManagedCluster: """The overall controller used to update the preview ManagedCluster profile. diff --git a/src/aks-preview/azext_aks_preview/tests/latest/test_decorator.py b/src/aks-preview/azext_aks_preview/tests/latest/test_decorator.py index 4adf8f433e9..103aa89986e 100644 --- a/src/aks-preview/azext_aks_preview/tests/latest/test_decorator.py +++ b/src/aks-preview/azext_aks_preview/tests/latest/test_decorator.py @@ -54,6 +54,7 @@ InvalidArgumentValueError, MutuallyExclusiveArgumentError, RequiredArgumentMissingError, + UnknownError, ) from msrestazure.azure_exceptions import CloudError @@ -217,15 +218,17 @@ def test_get_pod_cidrs(self): # default ctx_1 = AKSPreviewContext( self.cmd, - {'pod_cidrs': '10.244.0.0/16,2001:abcd::/64'}, + {"pod_cidrs": "10.244.0.0/16,2001:abcd::/64"}, self.models, decorator_mode=DecoratorMode.CREATE, ) - self.assertEqual(ctx_1.get_pod_cidrs(), ['10.244.0.0/16','2001:abcd::/64']) + self.assertEqual( + ctx_1.get_pod_cidrs(), ["10.244.0.0/16", "2001:abcd::/64"] + ) ctx_2 = AKSPreviewContext( self.cmd, - {'pod_cidrs': ''}, + {"pod_cidrs": ""}, self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -233,7 +236,7 @@ def test_get_pod_cidrs(self): ctx_3 = AKSPreviewContext( self.cmd, - {'pod_cidrs': None}, + {"pod_cidrs": None}, self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -243,15 +246,17 @@ def test_get_service_cidrs(self): # default ctx_1 = AKSPreviewContext( self.cmd, - {'service_cidrs': '10.244.0.0/16,2001:abcd::/64'}, + {"service_cidrs": "10.244.0.0/16,2001:abcd::/64"}, self.models, decorator_mode=DecoratorMode.CREATE, ) - self.assertEqual(ctx_1.get_service_cidrs(), ['10.244.0.0/16','2001:abcd::/64']) + self.assertEqual( + ctx_1.get_service_cidrs(), ["10.244.0.0/16", "2001:abcd::/64"] + ) ctx_2 = AKSPreviewContext( self.cmd, - {'service_cidrs': ''}, + {"service_cidrs": ""}, self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -259,7 +264,7 @@ def test_get_service_cidrs(self): ctx_3 = AKSPreviewContext( self.cmd, - {'service_cidrs': None}, + {"service_cidrs": None}, self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -269,15 +274,15 @@ def test_get_ip_families(self): # default ctx_1 = AKSPreviewContext( self.cmd, - {'ip_families': 'IPv4,IPv6'}, + {"ip_families": "IPv4,IPv6"}, self.models, decorator_mode=DecoratorMode.CREATE, ) - self.assertEqual(ctx_1.get_ip_families(), ['IPv4','IPv6']) + self.assertEqual(ctx_1.get_ip_families(), ["IPv4", "IPv6"]) ctx_2 = AKSPreviewContext( self.cmd, - {'ip_families': ''}, + {"ip_families": ""}, self.models, decorator_mode=DecoratorMode.CREATE, ) @@ -285,22 +290,64 @@ def test_get_ip_families(self): ctx_3 = AKSPreviewContext( self.cmd, - {'ip_families': None}, + {"ip_families": None}, self.models, decorator_mode=DecoratorMode.CREATE, ) self.assertEqual(ctx_3.get_ip_families(), None) - def test_get_ipv6_count(self): + def test_get_load_balancer_managed_outbound_ipv6_count(self): # default ctx_1 = AKSPreviewContext( self.cmd, - {'load_balancer_managed_outbound_ipv6_count': 4}, + { + "load_balancer_managed_outbound_ipv6_count": None, + }, self.models, decorator_mode=DecoratorMode.CREATE, ) self.assertEqual( - ctx_1.get_load_balancer_managed_outbound_ipv6_count(), 4) + ctx_1.get_load_balancer_managed_outbound_ipv6_count(), None + ) + load_balancer_profile = self.models.lb_models.get( + "ManagedClusterLoadBalancerProfile" + )( + managed_outbound_i_ps=self.models.lb_models.get( + "ManagedClusterLoadBalancerProfileManagedOutboundIPs" + )(count_ipv6=10) + ) + network_profile = self.models.ContainerServiceNetworkProfile( + load_balancer_profile=load_balancer_profile + ) + mc = self.models.ManagedCluster( + location="test_location", network_profile=network_profile + ) + ctx_1.attach_mc(mc) + self.assertEqual( + ctx_1.get_load_balancer_managed_outbound_ipv6_count(), 10 + ) + + # custom value + ctx_2 = AKSPreviewContext( + self.cmd, + {"load_balancer_managed_outbound_ipv6_count": 4}, + self.models, + decorator_mode=DecoratorMode.CREATE, + ) + self.assertEqual( + ctx_2.get_load_balancer_managed_outbound_ipv6_count(), 4 + ) + + # custom value + ctx_3 = AKSPreviewContext( + self.cmd, + {"load_balancer_managed_outbound_ipv6_count": 0}, + self.models, + decorator_mode=DecoratorMode.CREATE, + ) + self.assertEqual( + ctx_3.get_load_balancer_managed_outbound_ipv6_count(), 0 + ) def test_get_enable_fips_image(self): # default @@ -1263,6 +1310,7 @@ def test_test_get_outbound_type(self): load_balancer_profile=load_balancer_profile, ) + class AKSPreviewCreateDecoratorTestCase(unittest.TestCase): def setUp(self): # manually register CUSTOM_MGMT_AKS_PREVIEW @@ -2278,12 +2326,57 @@ def setUp(self): self.models = AKSPreviewModels(self.cmd, CUSTOM_MGMT_AKS_PREVIEW) self.client = MockClient() - def test_update_ipv6_count(self): + def test_update_load_balancer_profile(self): + # default value in `aks_update` + dec_1 = AKSPreviewUpdateDecorator( + self.cmd, + self.client, + { + "load_balancer_sku": None, + "load_balancer_managed_outbound_ip_count": None, + "load_balancer_outbound_ips": None, + "load_balancer_outbound_ip_prefixes": None, + "load_balancer_outbound_ports": None, + "load_balancer_idle_timeout": None, + "load_balancer_managed_outbound_ipv6_count": None, + }, + CUSTOM_MGMT_AKS_PREVIEW, + ) + # fail on passing the wrong mc object + with self.assertRaises(CLIInternalError): + dec_1.update_load_balancer_profile(None) + + network_profile_1 = self.models.ContainerServiceNetworkProfile() mc_1 = self.models.ManagedCluster( + location="test_location", + network_profile=network_profile_1, + ) + dec_1.context.attach_mc(mc_1) + dec_mc_1 = dec_1.update_load_balancer_profile(mc_1) + + ground_truth_network_profile_1 = ( + self.models.ContainerServiceNetworkProfile() + ) + ground_truth_mc_1 = self.models.ManagedCluster( + location="test_location", + network_profile=ground_truth_network_profile_1, + ) + self.assertEqual(dec_mc_1, ground_truth_mc_1) + + # custom value + dec_2 = AKSPreviewUpdateDecorator( + self.cmd, + self.client, + { + "load_balancer_managed_outbound_ipv6_count": 4, + }, + CUSTOM_MGMT_AKS_PREVIEW, + ) + mc_2 = self.models.ManagedCluster( location="test_location", network_profile=self.models.ContainerServiceNetworkProfile( load_balancer_profile=self.models.lb_models.get( - 'ManagedClusterLoadBalancerProfile' + "ManagedClusterLoadBalancerProfile" )( managed_outbound_i_ps=self.models.lb_models.get( "ManagedClusterLoadBalancerProfileManagedOutboundIPs" @@ -2292,26 +2385,18 @@ def test_update_ipv6_count(self): count_ipv6=7, ) ) - ) + ), ) - self.client.get = Mock(return_value=mc_1) - dec_1 = AKSPreviewUpdateDecorator( - self.cmd, - self.client, - { - "name": "test_cluster", - "resource_group_name": "test_rg_name", - "load_balancer_managed_outbound_ipv6_count": 4, - }, - CUSTOM_MGMT_AKS_PREVIEW, + dec_2.context.attach_mc(mc_2) + dec_mc_2 = dec_2.update_load_balancer_profile(mc_2) + self.assertEqual( + dec_mc_2.network_profile.load_balancer_profile.managed_outbound_i_ps.count, + 3, + ) + self.assertEqual( + dec_mc_2.network_profile.load_balancer_profile.managed_outbound_i_ps.count_ipv6, + 4, ) - - mc = dec_1.fetch_mc() - mc = dec_1.update_load_balancer_profile(mc) - self.assertEquals( - mc.network_profile.load_balancer_profile.managed_outbound_i_ps.count, 3) - self.assertEquals( - mc.network_profile.load_balancer_profile.managed_outbound_i_ps.count_ipv6, 4) def test_update_pod_security_policy(self): # default value in `aks_update` @@ -2384,3 +2469,93 @@ def test_update_pod_security_policy(self): enable_pod_security_policy=False, ) self.assertEqual(dec_mc_3, ground_truth_mc_3) + + def test_update_windows_profile(self): + # default value in `aks_update` + dec_1 = AKSPreviewUpdateDecorator( + self.cmd, + self.client, + { + "enable_ahub": False, + "disable_ahub": False, + "windows_admin_password": None, + "enable_windows_gmsa": False, + "gmsa_dns_server": None, + "gmsa_root_domain_name": None, + }, + CUSTOM_MGMT_AKS_PREVIEW, + ) + # fail on passing the wrong mc object + with self.assertRaises(CLIInternalError): + dec_1.update_windows_profile(None) + + mc_1 = self.models.ManagedCluster( + location="test_location", + ) + dec_1.context.attach_mc(mc_1) + dec_mc_1 = dec_1.update_windows_profile(mc_1) + ground_truth_mc_1 = self.models.ManagedCluster( + location="test_location", + ) + self.assertEqual(dec_mc_1, ground_truth_mc_1) + + # custom value + dec_2 = AKSPreviewUpdateDecorator( + self.cmd, + self.client, + { + "enable_windows_gmsa": True, + "gmsa_dns_server": "test_gmsa_dns_server", + "gmsa_root_domain_name": "test_gmsa_root_domain_name", + }, + CUSTOM_MGMT_AKS_PREVIEW, + ) + windows_profile_2 = self.models.ManagedClusterWindowsProfile( + # [SuppressMessage("Microsoft.Security", "CS002:SecretInNextLine", Justification="fake secrets in unit test")] + admin_username="test_win_admin_name", + admin_password="test_win_admin_password", + license_type="Windows_Server", + ) + mc_2 = self.models.ManagedCluster( + location="test_location", + windows_profile=windows_profile_2, + ) + dec_2.context.attach_mc(mc_2) + dec_mc_2 = dec_2.update_windows_profile(mc_2) + + ground_truth_gmsa_profile_2 = self.models.WindowsGmsaProfile( + enabled=True, + dns_server="test_gmsa_dns_server", + root_domain_name="test_gmsa_root_domain_name", + ) + ground_truth_windows_profile_2 = self.models.ManagedClusterWindowsProfile( + # [SuppressMessage("Microsoft.Security", "CS002:SecretInNextLine", Justification="fake secrets in unit test")] + admin_username="test_win_admin_name", + admin_password="test_win_admin_password", + license_type="Windows_Server", + gmsa_profile=ground_truth_gmsa_profile_2, + ) + ground_truth_mc_2 = self.models.ManagedCluster( + location="test_location", + windows_profile=ground_truth_windows_profile_2, + ) + self.assertEqual(dec_mc_2, ground_truth_mc_2) + + # custom value + dec_3 = AKSPreviewUpdateDecorator( + self.cmd, + self.client, + { + "enable_windows_gmsa": True, + "gmsa_dns_server": "test_gmsa_dns_server", + "gmsa_root_domain_name": "test_gmsa_root_domain_name", + }, + CUSTOM_MGMT_AKS_PREVIEW, + ) + mc_3 = self.models.ManagedCluster( + location="test_location", + ) + dec_3.context.attach_mc(mc_3) + # fail on incomplete mc object (no windows profile) + with self.assertRaises(UnknownError): + dec_3.update_windows_profile(mc_3)