Skip to content
This repository was archived by the owner on May 13, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/k8s-extension/azext_k8s_extension/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def update_k8s_extension(cmd, client, resource_group_name, cluster_name, name, c
extension_class = ExtensionFactory(extension_type_lower)

upd_extension = extension_class.Update(cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version,
config_settings, config_protected_settings)
config_settings, config_protected_settings, yes)

return sdk_no_wait(no_wait, client.begin_update, resource_group_name, cluster_rp, cluster_type,
cluster_name, name, upd_extension)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def __init__(self):
# constants for existing AKS to AMLARC migration
self.IS_AKS_MIGRATION = 'isAKSMigration'

# constants for others in Spec
self.installNvidiaDevicePlugin = 'installNvidiaDevicePlugin'

# reference mapping
self.reference_mapping = {
self.RELAY_SERVER_CONNECTION_STRING: [self.RELAY_CONNECTION_STRING_KEY],
Expand Down Expand Up @@ -165,9 +168,83 @@ def Delete(self, cmd, client, resource_group_name, cluster_name, name, cluster_t
user_confirmation_factory(cmd, yes)

def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version, configuration_settings,
configuration_protected_settings):
configuration_protected_settings, yes=False):
self.__normalize_config(configuration_settings, configuration_protected_settings)

# Prompt message to ask customer to confirm again
if len(configuration_settings) > 0:
impactScenario = ""
messageBody = ""
disableTraining = False
disableInference = False
disableNvidiaDevicePlugin = False
hasAllowInsecureConnections = False
hasPrivateEndpointNodeport = False
hasPrivateEndpointILB = False
hasNodeSelector = False
enableLogAnalyticsWS = False

enableTraining = _get_value_from_config_protected_config(self.ENABLE_TRAINING, configuration_settings, configuration_protected_settings)
if enableTraining is not None:
disableTraining = str(enableTraining).lower() == 'false'
if disableTraining:
messageBody = messageBody + "enableTraining from True to False,\n"

enableInference = _get_value_from_config_protected_config(self.ENABLE_INFERENCE, configuration_settings, configuration_protected_settings)
if enableInference is not None:
disableInference = str(enableInference).lower() == 'false'
if disableInference:
messageBody = messageBody + "enableInference from True to False,\n"

installNvidiaDevicePlugin = _get_value_from_config_protected_config(self.installNvidiaDevicePlugin, configuration_settings, configuration_protected_settings)
if installNvidiaDevicePlugin is not None:
disableNvidiaDevicePlugin = str(installNvidiaDevicePlugin).lower() == 'false'
if disableNvidiaDevicePlugin:
messageBody = messageBody + "installNvidiaDevicePlugin from True to False if Nvidia GPU is used,\n"

allowInsecureConnections = _get_value_from_config_protected_config(self.allowInsecureConnections, configuration_settings, configuration_protected_settings)
if allowInsecureConnections is not None:
hasAllowInsecureConnections = True
messageBody = messageBody + "allowInsecureConnections\n"

privateEndpointNodeport = _get_value_from_config_protected_config(self.privateEndpointNodeport, configuration_settings, configuration_protected_settings)
if privateEndpointNodeport is not None:
hasPrivateEndpointNodeport = True
messageBody = messageBody + "privateEndpointNodeport\n"

privateEndpointILB = _get_value_from_config_protected_config(self.privateEndpointILB, configuration_settings, configuration_protected_settings)
if privateEndpointILB is not None:
hasPrivateEndpointILB = True
messageBody = messageBody + "privateEndpointILB\n"

hasNodeSelector = _check_nodeselector_existed(configuration_settings, configuration_protected_settings)
if hasNodeSelector:
messageBody = messageBody + "nodeSelector. Update operation can't remove an existed node selector, but can update or add new ones.\n"

logAnalyticsWS = _get_value_from_config_protected_config(self.LOG_ANALYTICS_WS_ENABLED, configuration_settings, configuration_protected_settings)
if logAnalyticsWS is not None:
enableLogAnalyticsWS = str(logAnalyticsWS).lower() == 'true'
if enableLogAnalyticsWS:
messageBody = messageBody + "To update logAnalyticsWS from False to True, please provide all original configurationProtectedSettings. Otherwise, those settings would be considered obsolete and deleted.\n"

if disableTraining or disableNvidiaDevicePlugin or hasNodeSelector:
impactScenario = "jobs"

if disableInference or disableNvidiaDevicePlugin or hasAllowInsecureConnections or hasPrivateEndpointNodeport or hasPrivateEndpointILB or hasNodeSelector:
if impactScenario == "":
impactScenario = "online endpoints and deployments"
else:
impactScenario = impactScenario + ", online endpoints and deployments"

if impactScenario != "":
message = ("\nThe following configuration update will IMPACT your active Machine Learning " + impactScenario +
". It will be the safe update if the cluster doesn't have active Machine Learning " + impactScenario + ".\n\n" + messageBody + "\nProceed?")
user_confirmation_factory(cmd, yes, message=message)
else:
if enableLogAnalyticsWS:
message = "\n" + messageBody + "\nProceed?"
user_confirmation_factory(cmd, yes, message=message)

if len(configuration_protected_settings) > 0:
subscription_id = get_subscription_id(cmd.cli_ctx)

Expand Down Expand Up @@ -558,3 +635,14 @@ def _get_cluster_rp_api_version(cluster_type) -> Tuple[str, str]:
else:
raise InvalidArgumentValueError("Error! Cluster type '{}' is not supported".format(cluster_type))
return rp, parent_api_version


def _check_nodeselector_existed(configuration_settings, configuration_protected_settings):
config_keys = configuration_settings.keys()
config_protected_keys = configuration_protected_settings.keys()
all_keys = set(config_keys) | set(config_protected_keys)
if all_keys:
for key in all_keys:
if "nodeSelector" in key:
return True
return False
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_t
return extension, name, create_identity

def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_version, release_train, version, configuration_settings,
configuration_protected_settings):
configuration_protected_settings, yes=False):
"""Default validations & defaults for Update
Must create and return a valid 'PatchExtension' object.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def Create(self, cmd, client, resource_group_name: str, cluster_name: str, name:

@abstractmethod
def Update(self, cmd, resource_group_name: str, cluster_name: str, auto_upgrade_minor_version: bool, release_train: str, version: str,
configuration_settings: dict, configuration_protected_settings: dict) -> PatchExtension:
configuration_settings: dict, configuration_protected_settings: dict, yes: bool) -> PatchExtension:
pass

@abstractmethod
Expand Down