diff --git a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py index 6334877f59e..ac71251ecbd 100644 --- a/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py +++ b/src/k8s-extension/azext_k8s_extension/partner_extensions/AzureMLKubernetes.py @@ -78,10 +78,12 @@ def __init__(self): self.sslKeyPemFile = 'sslKeyPemFile' self.sslCertPemFile = 'sslCertPemFile' self.allowInsecureConnections = 'allowInsecureConnections' - self.privateEndpointILB = 'privateEndpointILB' - self.privateEndpointNodeport = 'privateEndpointNodeport' - self.inferenceLoadBalancerHA = 'inferenceLoadBalancerHA' self.SSL_SECRET = 'sslSecret' + self.SSL_Cname = 'sslCname' + + self.inferenceRouterServiceType = 'inferenceRouterServiceType' + self.internalLoadBalancerProvider = 'internalLoadBalancerProvider' + self.inferenceLoadBalancerHA = 'inferenceLoadBalancerHA' # constants for existing AKS to AMLARC migration self.IS_AKS_MIGRATION = 'isAKSMigration' @@ -96,12 +98,14 @@ def __init__(self): 'cluster_name': ['clusterId', 'prometheus.prometheusSpec.externalLabels.cluster_name'], } + self.OPEN_SHIFT = 'openshift' + def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_type, extension_type, scope, auto_upgrade_minor_version, release_train, version, target_namespace, release_namespace, configuration_settings, configuration_protected_settings, configuration_settings_file, configuration_protected_settings_file): if scope == 'namespace': - raise InvalidArgumentValueError("Invalid scope '{}'. This extension can be installed " + raise InvalidArgumentValueError("Invalid scope '{}'. 
This extension can be installed " "only at 'cluster' scope.".format(scope)) if not release_namespace: release_namespace = self.DEFAULT_RELEASE_NAMESPACE @@ -122,6 +126,10 @@ def Create(self, cmd, client, resource_group_name, cluster_name, name, cluster_t resource = resources.get_by_id( cluster_resource_id, parent_api_version) cluster_location = resource.location.lower() + if resource.properties['totalNodeCount'] == 1 or resource.properties['totalNodeCount'] == 2: + configuration_settings['clusterPurpose'] = 'DevTest' + if resource.properties['distribution'].lower() == "openshift": + configuration_settings[self.OPEN_SHIFT] = "true" except CloudError as ex: raise ex @@ -181,8 +189,9 @@ def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_vers disableInference = False disableNvidiaDevicePlugin = False hasAllowInsecureConnections = False - hasPrivateEndpointNodeport = False - hasPrivateEndpointILB = False + hasInferenceRouterServiceType = False + hasInternalLoadBalancerProvider = False + hasSslCname = False hasNodeSelector = False enableLogAnalyticsWS = False @@ -209,15 +218,20 @@ def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_vers hasAllowInsecureConnections = True messageBody = messageBody + "allowInsecureConnections\n" - privateEndpointNodeport = _get_value_from_config_protected_config(self.privateEndpointNodeport, configuration_settings, configuration_protected_settings) - if privateEndpointNodeport is not None: - hasPrivateEndpointNodeport = True - messageBody = messageBody + "privateEndpointNodeport\n" + inferenceRouterServiceType = _get_value_from_config_protected_config(self.inferenceRouterServiceType, configuration_settings, configuration_protected_settings) + if inferenceRouterServiceType is not None: + hasInferenceRouterServiceType = True + messageBody = messageBody + "inferenceRouterServiceType\n" - privateEndpointILB = _get_value_from_config_protected_config(self.privateEndpointILB, 
configuration_settings, configuration_protected_settings) - if privateEndpointILB is not None: - hasPrivateEndpointILB = True - messageBody = messageBody + "privateEndpointILB\n" + internalLoadBalancerProvider = _get_value_from_config_protected_config(self.internalLoadBalancerProvider, configuration_settings, configuration_protected_settings) + if internalLoadBalancerProvider is not None: + hasInternalLoadBalancerProvider = True + messageBody = messageBody + "internalLoadBalancerProvider\n" + + sslCname = _get_value_from_config_protected_config(self.SSL_Cname, configuration_settings, configuration_protected_settings) + if sslCname is not None: + hasSslCname = True + messageBody = messageBody + "sslCname\n" hasNodeSelector = _check_nodeselector_existed(configuration_settings, configuration_protected_settings) if hasNodeSelector: @@ -232,7 +246,7 @@ def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_vers if disableTraining or disableNvidiaDevicePlugin or hasNodeSelector: impactScenario = "jobs" - if disableInference or disableNvidiaDevicePlugin or hasAllowInsecureConnections or hasPrivateEndpointNodeport or hasPrivateEndpointILB or hasNodeSelector: + if disableInference or disableNvidiaDevicePlugin or hasAllowInsecureConnections or hasInferenceRouterServiceType or hasInternalLoadBalancerProvider or hasNodeSelector or hasSslCname: if impactScenario == "": impactScenario = "online endpoints and deployments" else: @@ -286,7 +300,11 @@ def Update(self, cmd, resource_group_name, cluster_name, auto_upgrade_minor_vers if self.sslKeyPemFile in configuration_protected_settings and \ self.sslCertPemFile in configuration_protected_settings: logger.info(f"Both {self.sslKeyPemFile} and {self.sslCertPemFile} are set, update ssl key.") - self.__set_inference_ssl_from_file(configuration_protected_settings, self.sslCertPemFile, self.sslKeyPemFile) + fe_ssl_cert_file = configuration_protected_settings.get(self.sslCertPemFile) + fe_ssl_key_file = 
configuration_protected_settings.get(self.sslKeyPemFile) + + if fe_ssl_cert_file and fe_ssl_key_file: + self.__set_inference_ssl_from_file(configuration_protected_settings, fe_ssl_cert_file, fe_ssl_key_file) return PatchExtension(auto_upgrade_minor_version=auto_upgrade_minor_version, release_train=release_train, @@ -305,16 +323,21 @@ def __normalize_config(self, configuration_settings, configuration_protected_set else: configuration_settings['clusterPurpose'] = 'FastProd' - feIsNodePort = _get_value_from_config_protected_config( - self.privateEndpointNodeport, configuration_settings, configuration_protected_settings) - if feIsNodePort is not None: - feIsNodePort = str(feIsNodePort).lower() == 'true' + inferenceRouterServiceType = _get_value_from_config_protected_config( + self.inferenceRouterServiceType, configuration_settings, configuration_protected_settings) + if inferenceRouterServiceType: + if inferenceRouterServiceType.lower() != 'nodeport' and inferenceRouterServiceType.lower() != 'loadbalancer': + raise InvalidArgumentValueError( + "inferenceRouterServiceType only supports nodePort or loadBalancer." 
+ " Check https://aka.ms/arcmltsg for more information.") + + feIsNodePort = str(inferenceRouterServiceType).lower() == 'nodeport' configuration_settings['scoringFe.serviceType.nodePort'] = feIsNodePort - feIsInternalLoadBalancer = _get_value_from_config_protected_config( - self.privateEndpointILB, configuration_settings, configuration_protected_settings) - if feIsInternalLoadBalancer is not None: - feIsInternalLoadBalancer = str(feIsInternalLoadBalancer).lower() == 'true' + internalLoadBalancerProvider = _get_value_from_config_protected_config( + self.internalLoadBalancerProvider, configuration_settings, configuration_protected_settings) + if internalLoadBalancerProvider: + feIsInternalLoadBalancer = str(internalLoadBalancerProvider).lower() == 'azure' configuration_settings['scoringFe.serviceType.internalLoadBalancer'] = feIsInternalLoadBalancer logger.warning( 'Internal load balancer only supported on AKS and AKS Engine Clusters.') @@ -345,7 +368,8 @@ def __validate_config(self, configuration_settings, configuration_protected_sett raise InvalidArgumentValueError( "To create Microsoft.AzureML.Kubernetes extension, either " "enable Machine Learning training or inference by specifying " - f"'--configuration-settings {self.ENABLE_TRAINING}=true' or '--configuration-settings {self.ENABLE_INFERENCE}=true'") + f"'--configuration-settings {self.ENABLE_TRAINING}=true' or '--configuration-settings {self.ENABLE_INFERENCE}=true'." 
+ " Please check https://aka.ms/arcmltsg for more information.") configuration_settings[self.ENABLE_TRAINING] = configuration_settings.get(self.ENABLE_TRAINING, enable_training) configuration_settings[self.ENABLE_INFERENCE] = configuration_settings.get( @@ -378,20 +402,34 @@ def __validate_scoring_fe_settings(self, configuration_settings, configuration_p if not sslEnabled and not allowInsecureConnections: raise InvalidArgumentValueError( "To enable HTTPs endpoint, " - "either provide sslCertPemFile and sslKeyPemFile to config protected settings, " - f"or provide sslSecret (kubernetes secret name) containing both ssl cert and ssl key under {release_namespace} namespace. " + "either provide sslCertPemFile and sslKeyPemFile to --configuration-protected-settings, " + f"or provide sslSecret(kubernetes secret name) in --configuration-settings containing both ssl cert and ssl key under {release_namespace} namespace. " "Otherwise, to enable HTTP endpoint, explicitly set allowInsecureConnections=true.") - feIsNodePort = _get_value_from_config_protected_config( - self.privateEndpointNodeport, configuration_settings, configuration_protected_settings) - feIsNodePort = str(feIsNodePort).lower() == 'true' - feIsInternalLoadBalancer = _get_value_from_config_protected_config( - self.privateEndpointILB, configuration_settings, configuration_protected_settings) - feIsInternalLoadBalancer = str(feIsInternalLoadBalancer).lower() == 'true' + if sslEnabled: + sslCname = _get_value_from_config_protected_config( + self.SSL_Cname, configuration_settings, configuration_protected_settings) + if not sslCname: + raise InvalidArgumentValueError( + "To enable HTTPs endpoint, " + "please specify sslCname parameter in --configuration-settings. 
Check https://aka.ms/arcmltsg for more information.") + + inferenceRouterServiceType = _get_value_from_config_protected_config( + self.inferenceRouterServiceType, configuration_settings, configuration_protected_settings) + if not inferenceRouterServiceType or (inferenceRouterServiceType.lower() != 'nodeport' and inferenceRouterServiceType.lower() != 'loadbalancer'): + raise InvalidArgumentValueError( + "To use inference, " + "please specify inferenceRouterServiceType=nodePort or inferenceRouterServiceType=loadBalancer in --configuration-settings and also set internalLoadBalancerProvider=azure if your aks only supports internal load balancer." + " Check https://aka.ms/arcmltsg for more information.") + + feIsNodePort = str(inferenceRouterServiceType).lower() == 'nodeport' + internalLoadBalancerProvider = _get_value_from_config_protected_config( + self.internalLoadBalancerProvider, configuration_settings, configuration_protected_settings) + feIsInternalLoadBalancer = str(internalLoadBalancerProvider).lower() == 'azure' if feIsNodePort and feIsInternalLoadBalancer: raise MutuallyExclusiveArgumentError( - "Specify either privateEndpointNodeport=true or privateEndpointILB=true, but not both.") + "When using nodePort as inferenceRouterServiceType, no need to specify internalLoadBalancerProvider.") if feIsNodePort: configuration_settings['scoringFe.serviceType.nodePort'] = feIsNodePort elif feIsInternalLoadBalancer: diff --git a/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 b/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 index b1e4b3d1d39..1b3b6bb662d 100644 --- a/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 +++ b/testing/test/extensions/public/AzureMLKubernetes.Tests.ps1 @@ -13,7 +13,7 @@ Describe 'AzureML Kubernetes Testing' { It 'Creates the extension and checks that it onboards correctly with inference and SSL enabled' { $sslKeyPemFile = Join-Path (Join-Path (Join-Path (Split-Path $PSScriptRoot -Parent) "data") "azure_ml") 
"test_key.pem" $sslCertPemFile = Join-Path (Join-Path (Join-Path (Split-Path $PSScriptRoot -Parent) "data") "azure_ml") "test_cert.pem" - az $Env:K8sExtensionName create -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters --extension-type $extensionType -n $extensionName --release-train staging --config enableInference=true identity.proxy.remoteEnabled=True identity.proxy.remoteHost=https://master.experiments.azureml-test.net inferenceLoadBalancerHA=False --config-protected sslKeyPemFile=$sslKeyPemFile sslCertPemFile=$sslCertPemFile --no-wait + az $Env:K8sExtensionName create -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters --extension-type $extensionType -n $extensionName --release-train staging --config enableInference=true identity.proxy.remoteEnabled=True identity.proxy.remoteHost=https://master.experiments.azureml-test.net inferenceRouterServiceType=nodePort sslCname=test.domain --config-protected sslKeyPemFile=$sslKeyPemFile sslCertPemFile=$sslCertPemFile --no-wait $? | Should -BeTrue $output = az $Env:K8sExtensionName show -c $($ENVCONFIG.arcClusterName) -g $($ENVCONFIG.resourceGroup) --cluster-type connectedClusters -n $extensionName