diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py index 2e76da4ecab..db20b3bec3a 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py @@ -89,6 +89,7 @@ def __init__(self): self.inferenceRouterServiceType = 'inferenceRouterServiceType' self.internalLoadBalancerProvider = 'internalLoadBalancerProvider' self.inferenceRouterHA = 'inferenceRouterHA' + self.clusterPurpose = 'clusterPurpose' # constants for existing AKS to AMLARC migration self.IS_AKS_MIGRATION = 'isAKSMigration' @@ -133,15 +134,28 @@ def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_t cluster_resource_id, parent_api_version) cluster_location = resource.location.lower() try: + isSmallScale = False if cluster_type.lower() == 'connectedclusters': if resource.properties['totalNodeCount'] < 3: - configuration_settings['clusterPurpose'] = 'DevTest' + isSmallScale = True if cluster_type.lower() == 'managedclusters': nodeCount = 0 for agent in resource.properties['agentPoolProfiles']: nodeCount += agent['count'] if nodeCount < 3: - configuration_settings['clusterPurpose'] = 'DevTest' + isSmallScale = True + + if isSmallScale: + clusterPurpose = _get_value_from_config_protected_config( + self.clusterPurpose, configuration_settings, configuration_protected_settings) + if clusterPurpose is None: + configuration_settings[self.clusterPurpose] = 'DevTest' + + inferenceRouterHA = _get_value_from_config_protected_config( + self.inferenceRouterHA, configuration_settings, configuration_protected_settings) + if inferenceRouterHA is None: + configuration_settings[self.inferenceRouterHA] = 'false' + if resource.properties.get('distribution', '').lower() == self.OPEN_SHIFT: configuration_settings[self.OPEN_SHIFT] = 'true' except: @@ -362,16 +376,6 @@ def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_vers configuration_protected_settings=configuration_protected_settings) def __normalize_config(self, configuration_settings, configuration_protected_settings): - # inference - inferenceRouterHA = _get_value_from_config_protected_config( - self.inferenceRouterHA, configuration_settings, configuration_protected_settings) - if inferenceRouterHA is not None: - isTestCluster = str(inferenceRouterHA).lower() == 'false' - if isTestCluster: - configuration_settings['clusterPurpose'] = 'DevTest' - else: - configuration_settings['clusterPurpose'] = 'FastProd' - inferenceRouterServiceType = _get_value_from_config_protected_config( self.inferenceRouterServiceType, configuration_settings, configuration_protected_settings) if inferenceRouterServiceType: @@ -428,13 +432,6 @@ def __validate_config(self, configuration_settings, configuration_protected_sett configuration_protected_settings.pop(self.ENABLE_INFERENCE, None) def __validate_scoring_fe_settings(self, configuration_settings, configuration_protected_settings, release_namespace): - inferenceRouterHA = _get_value_from_config_protected_config( - self.inferenceRouterHA, configuration_settings, configuration_protected_settings) - isTestCluster = True if inferenceRouterHA is not None and str(inferenceRouterHA).lower() == 'false' else False - if isTestCluster: - configuration_settings['clusterPurpose'] = 'DevTest' - else: - configuration_settings['clusterPurpose'] = 'FastProd' isAKSMigration = _get_value_from_config_protected_config( self.IS_AKS_MIGRATION, configuration_settings, configuration_protected_settings) isAKSMigration = str(isAKSMigration).lower() == 'true'