diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index b793d38c2d28..b6bf9b4d2d6a 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -2,6 +2,9 @@ ## 5.2.2 (Unreleased) +**New Features** + +* Added a `parse_connection_string` method which parses a connection string into a properties bag, `EventHubConnectionStringProperties`, containing its component parts. ## 5.2.1 (2021-01-11) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py index b97277f9335b..2a5d7a304dd0 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/__init__.py @@ -14,6 +14,10 @@ from ._eventprocessor.checkpoint_store import CheckpointStore from ._eventprocessor.common import CloseReason, LoadBalancingStrategy from ._eventprocessor.partition_context import PartitionContext +from ._connection_string_parser import ( + parse_connection_string, + EventHubConnectionStringProperties +) TransportType = constants.TransportType @@ -28,4 +32,6 @@ "CloseReason", "LoadBalancingStrategy", "PartitionContext", + "parse_connection_string", + "EventHubConnectionStringProperties" ] diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 2013ae8454b3..ff3d910d815a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -437,3 +437,69 @@ def add(self, event_data): self.message._body_gen.append(event_data) # pylint: disable=protected-access self._size = size_after_add self._count += 1 + +class DictMixin(object): + def __setitem__(self, key, item): + # type: (Any, Any) -> None + self.__dict__[key] = item + + def __getitem__(self, key): + # type: (Any) -> Any + return self.__dict__[key] + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + # type: () -> str + return str(self) + + def __len__(self): + # type: () -> int + return len(self.keys()) + + def __delitem__(self, key): + # type: (Any) -> None + self.__dict__[key] = None + + def __eq__(self, other): + # type: (Any) -> bool + """Compare objects by comparing all attributes.""" + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + # type: (Any) -> bool + """Compare objects by comparing all attributes.""" + return not self.__eq__(other) + + def __str__(self): + # type: () -> str + return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) + + def has_key(self, k): + # type: (Any) -> bool + return k in self.__dict__ + + def update(self, *args, **kwargs): + # type: (Any, Any) -> None + return self.__dict__.update(*args, **kwargs) + + def keys(self): + # type: () -> list + return [k for k in self.__dict__ if not k.startswith("_")] + + def values(self): + # type: () -> list + return [v for k, v in self.__dict__.items() if not k.startswith("_")] + + def items(self): + # type: () -> list + return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] + + def get(self, key, default=None): + # type: (Any, Optional[Any]) -> Any + if key in self.__dict__: + return self.__dict__[key] + return default diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_string_parser.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_string_parser.py new file mode 100644 index 000000000000..e8918cec7629 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_string_parser.py @@ -0,0 +1,108 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +try: + from urllib.parse import urlparse +except ImportError: + from urlparse import urlparse # type: ignore + +from ._common import DictMixin + + +class EventHubConnectionStringProperties(DictMixin): + """ + Properties of a connection string. + """ + + def __init__(self, **kwargs): + self._fully_qualified_namespace = kwargs.pop("fully_qualified_namespace", None) + self._endpoint = kwargs.pop("endpoint", None) + self._eventhub_name = kwargs.pop("eventhub_name", None) + self._shared_access_signature = kwargs.pop("shared_access_signature", None) + self._shared_access_key_name = kwargs.pop("shared_access_key_name", None) + self._shared_access_key = kwargs.pop("shared_access_key", None) + + @property + def fully_qualified_namespace(self): + """The fully qualified host name for the Event Hubs namespace. + The namespace format is: `.servicebus.windows.net`. + """ + return self._fully_qualified_namespace + + @property + def endpoint(self): + """The endpoint for the Event Hubs resource. In the format sb:///""" + return self._endpoint + + @property + def eventhub_name(self): + """Optional. The name of the Event Hub, represented by `EntityPath` in the connection string.""" + return self._eventhub_name + + @property + def shared_access_signature(self): + """ + This can be provided instead of the shared_access_key_name and the shared_access_key. + """ + return self._shared_access_signature + + @property + def shared_access_key_name(self): + """ + The name of the shared_access_key. This must be used along with the shared_access_key. + """ + return self._shared_access_key_name + + @property + def shared_access_key(self): + """ + The shared_access_key can be used along with the shared_access_key_name as a credential. + """ + return self._shared_access_key + + +def parse_connection_string(conn_str): + # type(str) -> EventHubConnectionStringProperties + """Parse the connection string into a properties bag containing its component parts. + + :param conn_str: The connection string that has to be parsed. + :type conn_str: str + :rtype: ~azure.eventhub.EventHubConnectionStringProperties + """ + conn_settings = [s.split("=", 1) for s in conn_str.split(";")] + if any(len(tup) != 2 for tup in conn_settings): + raise ValueError("Connection string is either blank or malformed.") + conn_settings = dict(conn_settings) + shared_access_signature = None + for key, value in conn_settings.items(): + if key.lower() == "sharedaccesssignature": + shared_access_signature = value + shared_access_key = conn_settings.get("SharedAccessKey") + shared_access_key_name = conn_settings.get("SharedAccessKeyName") + if any([shared_access_key, shared_access_key_name]) and not all( + [shared_access_key, shared_access_key_name] + ): + raise ValueError( + "Connection string must have both SharedAccessKeyName and SharedAccessKey." + ) + if shared_access_signature is not None and shared_access_key is not None: + raise ValueError( + "Only one of the SharedAccessKey or SharedAccessSignature must be present." + ) + endpoint = conn_settings.get("Endpoint") + if not endpoint: + raise ValueError("Connection string is either blank or malformed.") + parsed = urlparse(endpoint.rstrip("/")) + if not parsed.netloc: + raise ValueError("Invalid Endpoint on the Connection String.") + namespace = parsed.netloc.strip() + props = { + "fully_qualified_namespace": namespace, + "endpoint": endpoint, + "eventhub_name": conn_settings.get("EntityPath"), + "shared_access_signature": shared_access_signature, + "shared_access_key_name": shared_access_key_name, + "shared_access_key": shared_access_key, + } + return EventHubConnectionStringProperties(**props) diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_connection_string_parser.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_connection_string_parser.py new file mode 100644 index 000000000000..671d06dfa6dd --- /dev/null +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_connection_string_parser.py @@ -0,0 +1,71 @@ +#------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +#-------------------------------------------------------------------------- + +import os +import pytest +from azure.eventhub import ( + EventHubConnectionStringProperties, + parse_connection_string, +) + +from devtools_testutils import AzureMgmtTestCase + +class EventHubConnectionStringParserTests(AzureMgmtTestCase): + + def test_eh_conn_str_parse_cs(self, **kwargs): + conn_str = 'Endpoint=sb://eh-namespace.servicebus.windows.net/;SharedAccessKeyName=test-policy;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + parse_result = parse_connection_string(conn_str) + assert parse_result.endpoint == 'sb://eh-namespace.servicebus.windows.net/' + assert parse_result.fully_qualified_namespace == 'eh-namespace.servicebus.windows.net' + assert parse_result.shared_access_key_name == 'test-policy' + assert parse_result.shared_access_key == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + + def test_eh_conn_str_parse_with_entity_path(self, **kwargs): + conn_str = 'Endpoint=sb://eh-namespace.servicebus.windows.net/;SharedAccessKeyName=test-policy;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;EntityPath=eventhub-name' + parse_result = parse_connection_string(conn_str) + assert parse_result.endpoint == 'sb://eh-namespace.servicebus.windows.net/' + assert parse_result.fully_qualified_namespace == 'eh-namespace.servicebus.windows.net' + assert parse_result.shared_access_key_name == 'test-policy' + assert parse_result.shared_access_key == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + assert parse_result.eventhub_name == 'eventhub-name' + + def test_eh_conn_str_parse_sas_and_shared_key(self, **kwargs): + conn_str = 'Endpoint=sb://eh-namespace.servicebus.windows.net/;SharedAccessKeyName=test-policy;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=;SharedAccessSignature=THISISASASXXXXXXX=' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str) + assert str(e.value) == 'Only one of the SharedAccessKey or SharedAccessSignature must be present.' + + def test_eh_parse_malformed_conn_str_no_endpoint(self, **kwargs): + conn_str = 'SharedAccessKeyName=test-policy;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str) + assert str(e.value) == 'Connection string is either blank or malformed.' + + def test_eh_parse_malformed_conn_str_no_netloc(self, **kwargs): + conn_str = 'Endpoint=MALFORMED;SharedAccessKeyName=test-policy;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str) + assert str(e.value) == 'Invalid Endpoint on the Connection String.' + + def test_eh_parse_conn_str_sas(self, **kwargs): + conn_str = 'Endpoint=sb://eh-namespace.servicebus.windows.net/;SharedAccessSignature=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + parse_result = parse_connection_string(conn_str) + assert parse_result.endpoint == 'sb://eh-namespace.servicebus.windows.net/' + assert parse_result.fully_qualified_namespace == 'eh-namespace.servicebus.windows.net' + assert parse_result.shared_access_signature == 'THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + assert parse_result.shared_access_key_name == None + + def test_eh_parse_conn_str_no_keyname(self, **kwargs): + conn_str = 'Endpoint=sb://eh-namespace.servicebus.windows.net/;SharedAccessKey=THISISATESTKEYXXXXXXXXXXXXXXXXXXXXXXXXXXXX=' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str) + assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.' + + def test_eh_parse_conn_str_no_key(self, **kwargs): + conn_str = 'Endpoint=sb://eh-namespace.servicebus.windows.net/;SharedAccessKeyName=test-policy' + with pytest.raises(ValueError) as e: + parse_result = parse_connection_string(conn_str) + assert str(e.value) == 'Connection string must have both SharedAccessKeyName and SharedAccessKey.'