diff --git a/src/spring/azext_spring/_constant.py b/src/spring/azext_spring/_constant.py index a9635e05d97..ff717f4f322 100644 --- a/src/spring/azext_spring/_constant.py +++ b/src/spring/azext_spring/_constant.py @@ -9,3 +9,4 @@ MARKETPLACE_OFFER_ID = 'azure-spring-cloud-vmware-tanzu-2' MARKETPLACE_PUBLISHER_ID = 'vmware-inc' MARKETPLACE_PLAN_ID = 'asa-ent-hr-mtr' +AKS_RP = 'Microsoft.ContainerService' diff --git a/src/spring/azext_spring/_utils.py b/src/spring/azext_spring/_utils.py index 7d3a355712b..a4f746985fa 100644 --- a/src/spring/azext_spring/_utils.py +++ b/src/spring/azext_spring/_utils.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- import json +from datetime import datetime from enum import Enum import os from time import sleep @@ -14,9 +15,12 @@ from io import open from re import (search, match, compile) from json import dumps + +from azure.cli.core.commands.client_factory import get_subscription_id, get_mgmt_service_client +from azure.cli.core.profiles import ResourceType from knack.util import CLIError, todict from knack.log import get_logger -from azure.cli.core.azclierror import ValidationError +from azure.cli.core.azclierror import ValidationError, CLIInternalError from .vendored_sdks.appplatform.v2023_09_01_preview.models._app_platform_management_client_enums import SupportedRuntimeValue from ._client_factory import cf_resource_groups @@ -316,6 +320,53 @@ def handle_asc_exception(ex): raise CLIError(ex) +def register_provider_if_needed(cmd, rp_name): + if not _is_resource_provider_registered(cmd, rp_name): + _register_resource_provider(cmd, rp_name) + + +def _is_resource_provider_registered(cmd, resource_provider, subscription_id=None): + registered = None + if not subscription_id: + subscription_id = get_subscription_id(cmd.cli_ctx) + try: + providers_client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES, subscription_id=subscription_id).providers + registration_state = getattr(providers_client.get(resource_provider), 'registration_state', "NotRegistered") + + registered = (registration_state and registration_state.lower() == 'registered') + except Exception: # pylint: disable=broad-except + pass + return registered + + +def _register_resource_provider(cmd, resource_provider): + from azure.mgmt.resource.resources.models import ProviderRegistrationRequest, ProviderConsentDefinition + + logger.warning(f"Registering resource provider {resource_provider} ...") + properties = ProviderRegistrationRequest(third_party_provider_consent=ProviderConsentDefinition(consent_to_authorization=True)) + + client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES).providers + try: + client.register(resource_provider, properties=properties) + # wait for registration to finish + timeout_secs = 120 + registration = _is_resource_provider_registered(cmd, resource_provider) + start = datetime.utcnow() + while not registration: + registration = _is_resource_provider_registered(cmd, resource_provider) + sleep(3) + if (datetime.utcnow() - start).seconds >= timeout_secs: + raise CLIInternalError(f"Timed out while waiting for the {resource_provider} resource provider to be registered.") + + except Exception as e: + msg = ("This operation requires registering the resource provider {0}. " + "We were unable to perform that registration on your behalf: " + "Server responded with error message -- {1} . " + "Please check with your admin on permissions, " + "or try running registration manually with: az provider register --wait --namespace {0}") + raise ValidationError(resource_provider, msg.format(e.args)) from e + + class BearerAuth(requests.auth.AuthBase): def __init__(self, token): self.token = token diff --git a/src/spring/azext_spring/spring_instance.py b/src/spring/azext_spring/spring_instance.py index 3665c49c124..0009b1fb8fe 100644 --- a/src/spring/azext_spring/spring_instance.py +++ b/src/spring/azext_spring/spring_instance.py @@ -5,7 +5,7 @@ # pylint: disable=wrong-import-order # pylint: disable=unused-argument, logging-format-interpolation, protected-access, wrong-import-order, too-many-lines -from ._utils import (wait_till_end, _get_rg_location) +from ._utils import (wait_till_end, _get_rg_location, register_provider_if_needed) from .vendored_sdks.appplatform.v2023_09_01_preview import models from .custom import (_warn_enable_java_agent, _update_application_insights_asc_create) from ._build_service import _update_default_build_agent_pool, create_build_service @@ -23,7 +23,7 @@ from azure.cli.core.commands import LongRunningOperation from knack.log import get_logger from ._marketplace import _spring_list_marketplace_plan -from ._constant import (MARKETPLACE_OFFER_ID, MARKETPLACE_PUBLISHER_ID) +from ._constant import (MARKETPLACE_OFFER_ID, MARKETPLACE_PUBLISHER_ID, AKS_RP) logger = get_logger(__name__) @@ -239,6 +239,9 @@ def spring_create(cmd, client, resource_group, name, 'no_wait': no_wait } + if vnet: + register_provider_if_needed(cmd, AKS_RP) + spring_factory = _get_factory(cmd, client, resource_group, name, location=location, sku=sku) return spring_factory.create(**kwargs)