diff --git a/azext_iot/iothub/common.py b/azext_iot/iothub/common.py index e7dc69ef4..9b3aa4dd5 100644 --- a/azext_iot/iothub/common.py +++ b/azext_iot/iothub/common.py @@ -193,3 +193,17 @@ class CertificateAuthorityVersions(Enum): """ v2 = "v2" v1 = "v1" + + +class IoTHubSDKVersion(Enum): + """ + Types to determine which object properties the hub supports for backwards compatibility with the + control plane sdk. Currently has these distinctions (from oldest to newest versions): + + No cosmos endpoints + Cosmos endpoints as collections + Cosmos endpoints as containers + """ + NoCosmos = 0 + CosmosCollections = 1 + CosmosContainers = 2 diff --git a/azext_iot/iothub/providers/message_endpoint.py b/azext_iot/iothub/providers/message_endpoint.py index 45960528d..b47c62f6a 100644 --- a/azext_iot/iothub/providers/message_endpoint.py +++ b/azext_iot/iothub/providers/message_endpoint.py @@ -23,7 +23,8 @@ SYSTEM_ASSIGNED_IDENTITY, AuthenticationType, EncodingFormat, - EndpointType + EndpointType, + IoTHubSDKVersion ) from azext_iot.iothub.providers.base import IoTHubProvider from azext_iot.common._azure import parse_cosmos_db_connection_string @@ -42,7 +43,12 @@ def __init__( rg: Optional[str] = None, ): super(MessageEndpoint, self).__init__(cmd, hub_name, rg, dataplane=False) - self.support_cosmos = hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_collections") + # Temporary flag to check for which cosmos property to look for. + self.support_cosmos = IoTHubSDKVersion.NoCosmos.value + if hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_collections"): + self.support_cosmos = IoTHubSDKVersion.CosmosCollections.value + if hasattr(self.hub_resource.properties.routing.endpoints, "cosmos_db_sql_containers"): + self.support_cosmos = IoTHubSDKVersion.CosmosContainers.value self.cli = EmbeddedCLI(cli_ctx=self.cmd.cli_ctx) def create( @@ -179,16 +185,22 @@ def create( del new_endpoint["connectionString"] new_endpoint.update({ "databaseName": database_name, - "collectionName": container_name, "primaryKey": primary_key, "secondaryKey": secondary_key, "partitionKeyName": partition_key_name, "partitionKeyTemplate": partition_key_template, }) - # TODO @vilit - why is this None if empty - if endpoints.cosmos_db_sql_collections is None: - endpoints.cosmos_db_sql_collections = [] - endpoints.cosmos_db_sql_collections.append(new_endpoint) + # @vilit - None checks for when the service breaks things + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + new_endpoint["containerName"] = container_name + if endpoints.cosmos_db_sql_containers is None: + endpoints.cosmos_db_sql_containers = [] + endpoints.cosmos_db_sql_containers.append(new_endpoint) + if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: + new_endpoint["collectionName"] = container_name + if endpoints.cosmos_db_sql_collections is None: + endpoints.cosmos_db_sql_collections = [] + endpoints.cosmos_db_sql_collections.append(new_endpoint) elif endpoint_type.lower() == EndpointType.AzureStorageContainer.value: if fetch_connection_string: # try to get connection string @@ -369,8 +381,11 @@ def _show_by_type(self, endpoint_name: str, endpoint_type: Optional[str] = None) endpoint_list.extend(endpoints.service_bus_topics) if endpoint_type is None or endpoint_type.lower() == EndpointType.AzureStorageContainer.value: endpoint_list.extend(endpoints.storage_containers) - if self.support_cosmos and (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value): - endpoint_list.extend(endpoints.cosmos_db_sql_collections) + if (endpoint_type is None or endpoint_type.lower() == EndpointType.CosmosDBContainer.value): + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + endpoint_list.extend(endpoints.cosmos_db_sql_containers) + elif self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: + endpoint_list.extend(endpoints.cosmos_db_sql_collections) for endpoint in endpoint_list: if endpoint.name.lower() == endpoint_name.lower(): @@ -397,8 +412,11 @@ def list(self, endpoint_type: Optional[str] = None): return endpoints.service_bus_queues elif EndpointType.ServiceBusTopic.value == endpoint_type: return endpoints.service_bus_topics - elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos: - return endpoints.cosmos_db_sql_collections + elif EndpointType.CosmosDBContainer.value == endpoint_type: + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + return endpoints.cosmos_db_sql_containers + elif self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: + return endpoints.cosmos_db_sql_collections elif EndpointType.CosmosDBContainer.value == endpoint_type: raise InvalidArgumentValueError(INVALID_CLI_CORE_FOR_COSMOS) elif EndpointType.AzureStorageContainer.value == endpoint_type: @@ -413,7 +431,9 @@ def delete( endpoints = self.hub_resource.properties.routing.endpoints if endpoint_type: endpoint_type = endpoint_type.lower() - if EndpointType.CosmosDBContainer.value == endpoint_type and not self.support_cosmos: + if ( + EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos == IoTHubSDKVersion.NoCosmos.value + ): raise InvalidArgumentValueError(INVALID_CLI_CORE_FOR_COSMOS) if self.hub_resource.properties.routing.enrichments or self.hub_resource.properties.routing.routes: @@ -433,8 +453,11 @@ def delete( endpoint_names.extend([e.name for e in endpoints.service_bus_queues]) if not endpoint_type or endpoint_type == EndpointType.ServiceBusTopic.value: endpoint_names.extend([e.name for e in endpoints.service_bus_topics]) - if self.support_cosmos and not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value: - endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_collections]) + if not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value: + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_containers]) + if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: + endpoint_names.extend([e.name for e in endpoints.cosmos_db_sql_collections]) if not endpoint_type or endpoint_type == EndpointType.AzureStorageContainer.value: endpoint_names.extend([e.name for e in endpoints.storage_containers]) @@ -481,11 +504,17 @@ def delete( endpoints.service_bus_queues = [e for e in endpoints.service_bus_queues if e.name.lower() != endpoint_name] if not endpoint_type or EndpointType.ServiceBusTopic.value == endpoint_type: endpoints.service_bus_topics = [e for e in endpoints.service_bus_topics if e.name.lower() != endpoint_name] - if self.support_cosmos and not endpoint_type or EndpointType.CosmosDBContainer.value == endpoint_type: - cosmos_db_endpoints = endpoints.cosmos_db_sql_collections if endpoints.cosmos_db_sql_collections else [] - endpoints.cosmos_db_sql_collections = [ - e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name - ] + if not endpoint_type or endpoint_type == EndpointType.CosmosDBContainer.value: + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + cosmos_db_endpoints = endpoints.cosmos_db_sql_containers if endpoints.cosmos_db_sql_containers else [] + endpoints.cosmos_db_sql_containers = [ + e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name + ] + if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: + cosmos_db_endpoints = endpoints.cosmos_db_sql_collections if endpoints.cosmos_db_sql_collections else [] + endpoints.cosmos_db_sql_collections = [ + e for e in cosmos_db_endpoints if e.name.lower() != endpoint_name + ] if not endpoint_type or EndpointType.AzureStorageContainer.value == endpoint_type: endpoints.storage_containers = [e for e in endpoints.storage_containers if e.name.lower() != endpoint_name] elif endpoint_type: @@ -496,8 +525,11 @@ def delete( endpoints.service_bus_queues = [] elif EndpointType.ServiceBusTopic.value == endpoint_type: endpoints.service_bus_topics = [] - elif EndpointType.CosmosDBContainer.value == endpoint_type and self.support_cosmos: - endpoints.cosmos_db_sql_collections = [] + elif EndpointType.CosmosDBContainer.value == endpoint_type: + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + endpoints.cosmos_db_sql_containers = [] + elif self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: + endpoints.cosmos_db_sql_collections = [] elif EndpointType.AzureStorageContainer.value == endpoint_type: endpoints.storage_containers = [] else: @@ -505,7 +537,9 @@ def delete( endpoints.event_hubs = [] endpoints.service_bus_queues = [] endpoints.service_bus_topics = [] - if self.support_cosmos: + if self.support_cosmos == IoTHubSDKVersion.CosmosContainers.value: + endpoints.cosmos_db_sql_containers = [] + if self.support_cosmos == IoTHubSDKVersion.CosmosCollections.value: endpoints.cosmos_db_sql_collections = [] endpoints.storage_containers = [] diff --git a/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_int.py b/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_int.py index 14b8788c3..c1e12b3e3 100644 --- a/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_int.py +++ b/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_int.py @@ -1041,7 +1041,8 @@ def test_iot_cosmos_endpoint_lifecycle(provisioned_cosmosdb_with_identity_module ).as_json() assert len(cosmos_list) == 3 - assert endpoint_list["cosmosDbSqlCollections"] == cosmos_list + expected_list = endpoint_list.get("cosmosDbSqlCollections", []) + endpoint_list.get("cosmosDbSqlContainers", []) + assert cosmos_list == expected_list # Update # Keybased -> User, add pkn + pkt @@ -1457,8 +1458,7 @@ def build_expected_endpoint( expected["connectionString"] = connection_string if entity_path: expected["entityPath"] = entity_path - if container_name and not database_name: - # storage container + if container_name: expected["containerName"] = container_name if encoding: expected["encoding"] = encoding @@ -1471,9 +1471,6 @@ def build_expected_endpoint( expected["maxChunkSizeInBytes"] = max_chunk_size_in_bytes * max_chunk_size_constant if database_name: expected["databaseName"] = database_name - if container_name and database_name: - # cosmosdb container - expected["collectionName"] = container_name if partition_key_name: expected["partitionKeyName"] = partition_key_name if partition_key_template: @@ -1522,9 +1519,15 @@ def assert_endpoint_properties(result: dict, expected: dict): if "entityPath" in expected: assert result["entityPath"] == expected["entityPath"] - # Storage Account only + # Shared between Storage and Cosmos Db: if "containerName" in expected: - assert result["containerName"] == expected["containerName"] + resulting_container_name = result.get("containerName") + if resulting_container_name is None: + # older version of cosmos + resulting_container_name = result.get("collectionName") + assert resulting_container_name == expected["containerName"] + + # Storage Account only if "encoding" in expected: assert result["encoding"] == expected["encoding"] if "fileNameFormat" in expected: @@ -1537,8 +1540,6 @@ def assert_endpoint_properties(result: dict, expected: dict): # Cosmos DB only if "databaseName" in expected: assert result["databaseName"] == expected["databaseName"] - if "collectionName" in expected: - assert result["collectionName"] == expected["collectionName"] if "partitionKeyName" in expected: assert result["partitionKeyName"] == expected["partitionKeyName"] if "partitionKeyTemplate" in expected: diff --git a/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_unit.py b/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_unit.py index 3b384aee3..28c0156a4 100644 --- a/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_unit.py +++ b/azext_iot/tests/iothub/message_endpoint/test_iothub_message_endpoint_unit.py @@ -52,6 +52,36 @@ def create_mock_endpoint(): hub_mock.properties.routing.endpoints.service_bus_queues = [create_mock_endpoint()] hub_mock.properties.routing.endpoints.service_bus_topics = [create_mock_endpoint()] hub_mock.properties.routing.endpoints.storage_containers = [create_mock_endpoint()] + hub_mock.properties.routing.endpoints.cosmos_db_sql_containers = [create_mock_endpoint()] + + def initialize_mock_client(self, *args): + self.client = mocker.MagicMock() + self.client.begin_create_or_update.return_value = generic_response + return hub_mock + + find_resource.side_effect = initialize_mock_client + + yield find_resource + + +@pytest.fixture() +def fixture_update_endpoint_backwards_comp_ops(mocker): + # Parse connection string + mocker.patch(parse_cosmos_db_cstring_path, return_value={ + "AccountKey": "get_cosmos_db_account_key", + "AccountEndpoint": "get_cosmos_db_account_endpoint" + }) + + # Hub Resource + find_resource = mocker.patch(path_find_resource, autospec=True) + + def create_mock_endpoint(): + endpoint = mocker.Mock() + endpoint.name = endpoint_name + return endpoint + + hub_mock = mocker.MagicMock() + del hub_mock.properties.routing.endpoints.cosmos_db_sql_containers hub_mock.properties.routing.endpoints.cosmos_db_sql_collections = [create_mock_endpoint()] def initialize_mock_client(self, *args): @@ -721,7 +751,7 @@ def test_message_endpoint_update_cosmos_db_sql_container(self, mocker, fixture_c assert req.get("resource_group_name") == resource_group hub_resource = fixture_find_resource.call_args[0][0].client.begin_create_or_update.call_args[0][2] # TODO: @vilit fix once service fixes their naming - endpoints = hub_resource.properties.routing.endpoints.cosmos_db_sql_collections + endpoints = hub_resource.properties.routing.endpoints.cosmos_db_sql_containers assert len(endpoints) == 1 endpoint = endpoints[0] @@ -800,7 +830,7 @@ def test_message_endpoint_update_cosmos_db_sql_container(self, mocker, fixture_c else: assert isinstance(endpoint.authentication_type, mock) - def test_message_endpoint_update_cosmos_db_sql_collections_error(self, fixture_cmd, fixture_update_endpoint_ops): + def test_message_endpoint_update_cosmos_db_sql_container_error(self, fixture_cmd, fixture_update_endpoint_ops): # Cannot do both types of Authentication with pytest.raises(MutuallyExclusiveArgumentError) as e: subject.message_endpoint_update_cosmos_db_container( @@ -848,3 +878,182 @@ def test_message_endpoint_update_cosmos_db_sql_collections_error(self, fixture_c hub_name=hub_name, endpoint_name=generate_names(), ) + + @pytest.mark.parametrize( + "req", + [ + {}, + { + "endpoint_resource_group": generate_names(), + "endpoint_subscription_id": generate_names(), + "database_name": generate_names(), + "connection_string": generate_names(), + "primary_key": None, + "secondary_key": None, + "endpoint_uri": generate_names(), + "partition_key_name": None, + "partition_key_template": None, + "identity": None, + "resource_group_name": None, + }, + { + "endpoint_resource_group": None, + "endpoint_subscription_id": None, + "database_name": None, + "connection_string": None, + "primary_key": None, + "secondary_key": None, + "endpoint_uri": generate_names(), + "partition_key_name": generate_names(), + "partition_key_template": generate_names(), + "identity": generate_names(), + "resource_group_name": generate_names(), + }, + { + "endpoint_resource_group": None, + "endpoint_subscription_id": None, + "database_name": None, + "connection_string": None, + "primary_key": None, + "secondary_key": None, + "endpoint_uri": None, + "partition_key_name": None, + "partition_key_template": None, + "identity": "[system]", + "resource_group_name": None, + }, + { + "endpoint_resource_group": None, + "endpoint_subscription_id": None, + "database_name": None, + "connection_string": generate_names(), + "primary_key": None, + "secondary_key": generate_names(), + "endpoint_uri": None, + "partition_key_name": None, + "partition_key_template": generate_names(), + "identity": None, + "resource_group_name": None, + }, + { + "endpoint_resource_group": generate_names(), + "endpoint_subscription_id": None, + "database_name": None, + "connection_string": generate_names(), + "primary_key": generate_names(), + "secondary_key": generate_names(), + "endpoint_uri": None, + "partition_key_name": generate_names(), + "partition_key_template": None, + "identity": None, + "resource_group_name": None, + }, + { + "endpoint_resource_group": None, + "endpoint_subscription_id": None, + "database_name": generate_names(), + "connection_string": None, + "primary_key": None, + "secondary_key": None, + "endpoint_uri": None, + "partition_key_name": None, + "partition_key_template": None, + "identity": None, + "resource_group_name": None, + }, + ] + ) + def test_message_endpoint_update_cosmos_db_sql_collections( + self, mocker, fixture_cmd, fixture_update_endpoint_backwards_comp_ops, req + ): + result = subject.message_endpoint_update_cosmos_db_container( + cmd=fixture_cmd, + hub_name=hub_name, + endpoint_name=endpoint_name, + **req + ) + fixture_find_resource = fixture_update_endpoint_backwards_comp_ops + + assert result == generic_response + resource_group = fixture_find_resource.call_args[0][2] + assert req.get("resource_group_name") == resource_group + hub_resource = fixture_find_resource.call_args[0][0].client.begin_create_or_update.call_args[0][2] + # TODO: @vilit fix once service fixes their naming + endpoints = hub_resource.properties.routing.endpoints.cosmos_db_sql_collections + assert len(endpoints) == 1 + endpoint = endpoints[0] + + assert endpoint.name == endpoint_name + mock = mocker.Mock + + # if a prop is not set, it will be a Mock object + # Props that will always be set if present + if req.get("endpoint_resource_group"): + assert endpoint.resource_group == req.get("endpoint_resource_group") + else: + assert isinstance(endpoint.resource_group, mock) + + if req.get("endpoint_subscription_id"): + assert endpoint.subscription_id == req.get("endpoint_subscription_id") + else: + assert isinstance(endpoint.subscription_id, mock) + + if req.get("database_name"): + assert endpoint.database_name == req.get("database_name").lower() + else: + assert isinstance(endpoint.database_name, mock) + + if req.get("partition_key_name"): + partition_key_name = req.get("partition_key_name") + if partition_key_name == "": + assert endpoint.partition_key_name is None + else: + endpoint.partition_key_name == partition_key_name + else: + assert isinstance(endpoint.partition_key_name, mock) + + if req.get("partition_key_template"): + partition_key_template = req.get("partition_key_template") + if partition_key_template == "": + assert endpoint.partition_key_template is None + else: + endpoint.partition_key_template == partition_key_template + else: + assert isinstance(endpoint.partition_key_template, mock) + + # Connection strings dont exist + assert isinstance(endpoint.connection_string, mock) + + # Authentication props + if req.get("identity"): + assert endpoint.authentication_type == AuthenticationType.IdentityBased.value + assert endpoint.primary_key is None + assert endpoint.secondary_key is None + identity = req.get("identity") + if identity == "[system]": + assert endpoint.identity is None + else: + assert isinstance(endpoint.identity, ManagedIdentity) + assert endpoint.identity.user_assigned_identity == identity + elif any([req.get("connection_string"), req.get("primary_key"), req.get("secondary_key")]): + assert endpoint.authentication_type == AuthenticationType.KeyBased.value + assert endpoint.identity is None + assert endpoint.entity_path is None + connection_string = req.get("connection_string") + primary_key = req.get("primary_key") + secondary_key = req.get("secondary_key") + endpoint_uri = req.get("endpoint_uri") + + if primary_key: + assert endpoint.primary_key == primary_key + if secondary_key: + assert endpoint.secondary_key == secondary_key + if not primary_key and not secondary_key and connection_string: + assert endpoint.primary_key == endpoint.secondary_key == "get_cosmos_db_account_key" + + if endpoint_uri: + assert endpoint.endpoint_uri == endpoint_uri + elif connection_string: + assert endpoint.endpoint_uri == "get_cosmos_db_account_endpoint" + else: + assert isinstance(endpoint.authentication_type, mock)