diff --git a/src/k8s-extension/azext_k8s_extension/custom.py b/src/k8s-extension/azext_k8s_extension/custom.py index d3fd501a0f2..c3720078056 100644 --- a/src/k8s-extension/azext_k8s_extension/custom.py +++ b/src/k8s-extension/azext_k8s_extension/custom.py @@ -204,7 +204,7 @@ def update_k8s_extension(cmd, client, resource_group_name, cluster_name, name, c extension_class = ExtensionFactory(extension_type_lower) upd_extension = extension_class.Update(cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version, - config_settings, config_protected_settings) + config_settings, config_protected_settings, yes) return sdk_no_wait(no_wait, client.begin_update, resource_group_name, cluster_rp, cluster_type, cluster_name, name, upd_extension) diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py index 706973d9a13..4f4fb43fe15 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py @@ -84,6 +84,9 @@ def __init__(self): # constants for existing AKS to AMLARC migration self.IS_AKS_MIGRATION = 'isAKSMigration' + # constants for others in Spec + self.installNvidiaDevicePlugin = 'installNvidiaDevicePlugin' + # reference mapping self.reference_mapping = { self.RELAY_SERVER_CONNECTION_STRING: [self.RELAY_CONNECTION_STRING_KEY], @@ -165,9 +168,83 @@ def Delete(self, cmd, client, resource_group_name, cluster_name, name, cluster_t user_confirmation_factory(cmd, yes) def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version, configuration_settings, - configuration_protected_settings): + configuration_protected_settings, yes=False): self.__normalize_config(configuration_settings, configuration_protected_settings) + # Prompt message to ask customer to confirm again + if len(configuration_settings) > 0: + impactScenario = "" + messageBody = "" + disableTraining = False + disableInference = False + disableNvidiaDevicePlugin = False + hasAllowInsecureConnections = False + hasPrivateEndpointNodeport = False + hasPrivateEndpointILB = False + hasNodeSelector = False + enableLogAnalyticsWS = False + + enableTraining = _get_value_from_config_protected_config(self.ENABLE_TRAINING, configuration_settings, configuration_protected_settings) + if enableTraining is not None: + disableTraining = str(enableTraining).lower() == 'false' + if disableTraining: + messageBody = messageBody + "enableTraining from True to False,\n" + + enableInference = _get_value_from_config_protected_config(self.ENABLE_INFERENCE, configuration_settings, configuration_protected_settings) + if enableInference is not None: + disableInference = str(enableInference).lower() == 'false' + if disableInference: + messageBody = messageBody + "enableInference from True to False,\n" + + installNvidiaDevicePlugin = _get_value_from_config_protected_config(self.installNvidiaDevicePlugin, configuration_settings, configuration_protected_settings) + if installNvidiaDevicePlugin is not None: + disableNvidiaDevicePlugin = str(installNvidiaDevicePlugin).lower() == 'false' + if disableNvidiaDevicePlugin: + messageBody = messageBody + "installNvidiaDevicePlugin from True to False if Nvidia GPU is used,\n" + + allowInsecureConnections = _get_value_from_config_protected_config(self.allowInsecureConnections, configuration_settings, configuration_protected_settings) + if allowInsecureConnections is not None: + hasAllowInsecureConnections = True + messageBody = messageBody + "allowInsecureConnections\n" + + privateEndpointNodeport = _get_value_from_config_protected_config(self.privateEndpointNodeport, configuration_settings, configuration_protected_settings) + if privateEndpointNodeport is not None: + hasPrivateEndpointNodeport = True + messageBody = messageBody + "privateEndpointNodeport\n" + + privateEndpointILB = _get_value_from_config_protected_config(self.privateEndpointILB, configuration_settings, configuration_protected_settings) + if privateEndpointILB is not None: + hasPrivateEndpointILB = True + messageBody = messageBody + "privateEndpointILB\n" + + hasNodeSelector = _check_nodeselector_existed(configuration_settings, configuration_protected_settings) + if hasNodeSelector: + messageBody = messageBody + "nodeSelector. Update operation can't remove an existed node selector, but can update or add new ones.\n" + + logAnalyticsWS = _get_value_from_config_protected_config(self.LOG_ANALYTICS_WS_ENABLED, configuration_settings, configuration_protected_settings) + if logAnalyticsWS is not None: + enableLogAnalyticsWS = str(logAnalyticsWS).lower() == 'true' + if enableLogAnalyticsWS: + messageBody = messageBody + "To update logAnalyticsWS from False to True, please provide all original configurationProtectedSettings. Otherwise, those settings would be considered obsolete and deleted.\n" + + if disableTraining or disableNvidiaDevicePlugin or hasNodeSelector: + impactScenario = "jobs" + + if disableInference or disableNvidiaDevicePlugin or hasAllowInsecureConnections or hasPrivateEndpointNodeport or hasPrivateEndpointILB or hasNodeSelector: + if impactScenario == "": + impactScenario = "online endpoints and deployments" + else: + impactScenario = impactScenario + ", online endpoints and deployments" + + if impactScenario != "": + message = ("\nThe following configuration update will IMPACT your active Machine Learning " + impactScenario + + ". It will be the safe update if the cluster doesn't have active Machine Learning " + impactScenario + ".\n\n" + messageBody + "\nProceed?") + user_confirmation_factory(cmd, yes, message=message) + else: + if enableLogAnalyticsWS: + message = "\n" + messageBody + "\nProceed?" + user_confirmation_factory(cmd, yes, message=message) + if len(configuration_protected_settings) > 0: subscription_id = get_subscription_id(cmd.cli_ctx) @@ -558,3 +635,14 @@ def _get_cluster_rp_api_version(cluster_type) -> Tuple[str, str]: else: raise InvalidArgumentValueError("Error! Cluster type '{}' is not supported".format(cluster_type)) return rp, parent_api_version + + +def _check_nodeselector_existed(configuration_settings, configuration_protected_settings): + config_keys = configuration_settings.keys() + config_protected_keys = configuration_protected_settings.keys() + all_keys = set(config_keys) | set(config_protected_keys) + if all_keys: + for key in all_keys: + if "nodeSelector" in key: + return True + return False diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py index 5b76e500635..8289e931336 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/DefaultExtension.py @@ -46,7 +46,7 @@ def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_t return extension, name, create_identity def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version, configuration_settings, - configuration_protected_settings): + configuration_protected_settings, yes=False): """Default validations & defaults for Update Must create and return a valid 'PatchExtension' object. """ diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py index 33c8f683591..e1f37ec3a95 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/PartnerExtensionModel.py @@ -19,7 +19,7 @@ def Create(self, cmd, client, resource_group_name: str, cluster_name: str, name: @abstractmethod def Update(self, cmd, resource_group_name: str, cluster_name: str, auto_upgrade_minor_version: bool, release_train: str, version: str, - configuration_settings: dict, configuration_protected_settings: dict) -> PatchExtension: + configuration_settings: dict, configuration_protected_settings: dict, yes: bool) -> PatchExtension: pass @abstractmethod