Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions azext_iot/digitaltwins/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
# --------------------------------------------------------------------------------------------

from azext_iot.digitaltwins.providers.resource import ResourceProvider
from azext_iot.digitaltwins.providers.auth import DigitalTwinAuthentication
from azext_iot.sdk.digitaltwins.dataplane import AzureDigitalTwinsAPI
from azext_iot.sdk.digitaltwins.dataplane.models import ErrorResponseException
from azext_iot.constants import DIGITALTWINS_RESOURCE_ID
from azext_iot.constants import DIGITALTWINS_RESOURCE_ID, USER_AGENT
from azext_iot.common.utility import valid_hostname
from knack.cli import CLIError

Expand All @@ -32,12 +31,14 @@ def _get_endpoint(self):
http_prefix = "http://"

if self.name.lower().startswith(https_prefix):
self.name = self.name[len(https_prefix):]
self.name = self.name[len(https_prefix) :]
elif self.name.lower().startswith(http_prefix):
self.name = self.name[len(http_prefix):]
self.name = self.name[len(http_prefix) :]

if not all([valid_hostname(self.name), "." in self.name]):
instance = self.rp.find_instance(name=self.name, resource_group_name=self.rg)
instance = self.rp.find_instance(
name=self.name, resource_group_name=self.rg
)
host_name = instance.host_name
if not host_name:
raise CLIError("Instance has invalid hostName. Aborting operation...")
Expand All @@ -47,5 +48,16 @@ def _get_endpoint(self):
return "https://{}".format(host_name)

def get_sdk(self):
creds = DigitalTwinAuthentication(cmd=self.cmd, resource_id=self.resource_id)
return AzureDigitalTwinsAPI(base_url=self._get_endpoint(), credentials=creds)
from azure.cli.core.commands.client_factory import get_mgmt_service_client

client = get_mgmt_service_client(
cli_ctx=self.cmd.cli_ctx,
client_or_resource_type=AzureDigitalTwinsAPI,
base_url=self._get_endpoint(),
resource=self.resource_id,
subscription_bound=False,
base_url_bound=False,
)

client.config.add_user_agent(USER_AGENT)
return client
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from .operations.digital_twins_operations import DigitalTwinsOperations
from .operations.event_routes_operations import EventRoutesOperations
from . import models
from azext_iot.constants import USER_AGENT


class AzureDigitalTwinsAPIConfiguration(AzureConfiguration):
Expand All @@ -43,8 +42,6 @@ def __init__(
super(AzureDigitalTwinsAPIConfiguration, self).__init__(base_url)

self.add_user_agent('azuredigitaltwinsapi/{}'.format(VERSION))
self.add_user_agent(USER_AGENT)

self.credentials = credentials


Expand Down
34 changes: 30 additions & 4 deletions azext_iot/tests/digitaltwins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
MOCK_DEAD_LETTER_ENDPOINT = "https://accountname.blob.core.windows.net/containerName"
MOCK_DEAD_LETTER_SECRET = "{}?sasToken".format(MOCK_DEAD_LETTER_ENDPOINT)
REGION_RESOURCE_LIMIT = 10
REGION_LIST = ["westus2", "westus", "eastus", "eastus2euap"]
REGION_LIST = ["westus2", "westcentralus", "eastus2", "eastus", "eastus2euap"]


def generate_resource_id():
Expand Down Expand Up @@ -135,10 +135,13 @@ def get_available_region(self, capacity: int = 1, skip_regions: list = None) ->
skip_regions = []

region_capacity = self.calculate_region_capacity
for region in region_capacity:
if region_capacity[region] + capacity <= REGION_RESOURCE_LIMIT:
if region not in skip_regions:

while region_capacity:
region = min(region_capacity, key=region_capacity.get)
if region not in skip_regions:
if region_capacity[region] + capacity <= REGION_RESOURCE_LIMIT:
return region
region_capacity.pop(region, None)

raise RuntimeError(
"There are no available regions with capacity: {} for provision DT instances in subscription: {}".format(
Expand All @@ -154,3 +157,26 @@ def tearDown(self):
self.embedded_cli.invoke(
"dt delete -n {} -g {} -y --no-wait".format(instance[0], instance[1])
)

# Needed because the DT service will indicate provisioning is finished before it actually is.
def wait_for_hostname(
self, instance: dict, wait_in_sec: int = 5, interval: int = 3
):
from time import sleep

while interval >= 1:
logger.info(
"Waiting :{} (sec) for provisioning to complete.".format(wait_in_sec)
)
sleep(wait_in_sec)
interval = interval - 1
refereshed_instance = self.embedded_cli.invoke(
"dt show -n {} -g {}".format(
instance["name"], instance["resourceGroup"]
)
).as_json()

if refereshed_instance.get("hostName"):
return refereshed_instance

return instance
10 changes: 0 additions & 10 deletions azext_iot/tests/digitaltwins/test_dt_privatelinks_lifecycle_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,17 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import pytest
from knack.log import get_logger
from azext_iot.tests.settings import DynamoSettings
from . import DTLiveScenarioTest
from . import generate_resource_id, generate_generic_id

logger = get_logger(__name__)

resource_test_env_vars = ["azext_dt_vnet_subnet_id"]

settings = DynamoSettings(opt_env_set=resource_test_env_vars)


class TestDTPrivateLinksLifecycle(DTLiveScenarioTest):
def __init__(self, test_case):
super(TestDTPrivateLinksLifecycle, self).__init__(test_case)

@pytest.mark.skipif(
not settings.env.azext_dt_vnet_subnet_id,
reason="Set azext_dt_vnet_subnet_id (fully qualified resource Id) for private-link/private-endpoint tests.",
)
def test_dt_privatelinks(self):
self.wait_for_capacity()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_dt_resource(self):
self.track_instance(create_output)

assert_common_resource_attributes(
create_output,
self.wait_for_hostname(create_output),
instance_names[0],
self.rg,
self.region,
Expand All @@ -97,7 +97,7 @@ def test_dt_resource(self):
self.track_instance(create_msi_output)

assert_common_resource_attributes(
create_msi_output,
self.wait_for_hostname(create_msi_output),
instance_names[1],
self.rg,
self.rg_region,
Expand Down Expand Up @@ -567,6 +567,7 @@ def assert_common_resource_attributes(
):
assert instance_output["createdTime"]
hostname = instance_output.get("hostName")

assert hostname, "Provisioned instance is missing hostName."
assert hostname.startswith(resource_id)
assert instance_output["location"] == location
Expand Down