Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/spring/azext_spring/_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
MARKETPLACE_OFFER_ID = 'azure-spring-cloud-vmware-tanzu-2'
MARKETPLACE_PUBLISHER_ID = 'vmware-inc'
MARKETPLACE_PLAN_ID = 'asa-ent-hr-mtr'
AKS_RP = 'Microsoft.ContainerService'
53 changes: 52 additions & 1 deletion src/spring/azext_spring/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
import json
from datetime import datetime
from enum import Enum
import os
from time import sleep
Expand All @@ -14,9 +15,12 @@
from io import open
from re import (search, match, compile)
from json import dumps

from azure.cli.core.commands.client_factory import get_subscription_id, get_mgmt_service_client
from azure.cli.core.profiles import ResourceType
from knack.util import CLIError, todict
from knack.log import get_logger
from azure.cli.core.azclierror import ValidationError
from azure.cli.core.azclierror import ValidationError, CLIInternalError
from .vendored_sdks.appplatform.v2023_09_01_preview.models._app_platform_management_client_enums import SupportedRuntimeValue
from ._client_factory import cf_resource_groups

Expand Down Expand Up @@ -316,6 +320,53 @@ def handle_asc_exception(ex):
raise CLIError(ex)


def register_provider_if_needed(cmd, rp_name):
if not _is_resource_provider_registered(cmd, rp_name):
_register_resource_provider(cmd, rp_name)


def _is_resource_provider_registered(cmd, resource_provider, subscription_id=None):
registered = None
if not subscription_id:
subscription_id = get_subscription_id(cmd.cli_ctx)
try:
providers_client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES, subscription_id=subscription_id).providers
registration_state = getattr(providers_client.get(resource_provider), 'registration_state', "NotRegistered")

registered = (registration_state and registration_state.lower() == 'registered')
except Exception: # pylint: disable=broad-except
pass
return registered


def _register_resource_provider(cmd, resource_provider):
from azure.mgmt.resource.resources.models import ProviderRegistrationRequest, ProviderConsentDefinition

logger.warning(f"Registering resource provider {resource_provider} ...")
properties = ProviderRegistrationRequest(third_party_provider_consent=ProviderConsentDefinition(consent_to_authorization=True))

client = get_mgmt_service_client(cmd.cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES).providers
try:
client.register(resource_provider, properties=properties)
# wait for registration to finish
timeout_secs = 120
registration = _is_resource_provider_registered(cmd, resource_provider)
start = datetime.utcnow()
while not registration:
registration = _is_resource_provider_registered(cmd, resource_provider)
sleep(3)
if (datetime.utcnow() - start).seconds >= timeout_secs:
raise CLIInternalError(f"Timed out while waiting for the {resource_provider} resource provider to be registered.")

except Exception as e:
msg = ("This operation requires registering the resource provider {0}. "
"We were unable to perform that registration on your behalf: "
"Server responded with error message -- {1} . "
"Please check with your admin on permissions, "
"or try running registration manually with: az provider register --wait --namespace {0}")
raise ValidationError(resource_provider, msg.format(e.args)) from e


class BearerAuth(requests.auth.AuthBase):
def __init__(self, token):
self.token = token
Expand Down
7 changes: 5 additions & 2 deletions src/spring/azext_spring/spring_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

# pylint: disable=wrong-import-order
# pylint: disable=unused-argument, logging-format-interpolation, protected-access, wrong-import-order, too-many-lines
from ._utils import (wait_till_end, _get_rg_location)
from ._utils import (wait_till_end, _get_rg_location, register_provider_if_needed)
from .vendored_sdks.appplatform.v2023_09_01_preview import models
from .custom import (_warn_enable_java_agent, _update_application_insights_asc_create)
from ._build_service import _update_default_build_agent_pool, create_build_service
Expand All @@ -23,7 +23,7 @@
from azure.cli.core.commands import LongRunningOperation
from knack.log import get_logger
from ._marketplace import _spring_list_marketplace_plan
from ._constant import (MARKETPLACE_OFFER_ID, MARKETPLACE_PUBLISHER_ID)
from ._constant import (MARKETPLACE_OFFER_ID, MARKETPLACE_PUBLISHER_ID, AKS_RP)

logger = get_logger(__name__)

Expand Down Expand Up @@ -239,6 +239,9 @@ def spring_create(cmd, client, resource_group, name,
'no_wait': no_wait
}

if vnet:
register_provider_if_needed(cmd, AKS_RP)

spring_factory = _get_factory(cmd, client, resource_group, name, location=location, sku=sku)
return spring_factory.create(**kwargs)

Expand Down