diff --git a/sdk/eventhub/azure-eventhub/CHANGELOG.md b/sdk/eventhub/azure-eventhub/CHANGELOG.md index 803e2df81e70..7ec20c6de589 100644 --- a/sdk/eventhub/azure-eventhub/CHANGELOG.md +++ b/sdk/eventhub/azure-eventhub/CHANGELOG.md @@ -1,7 +1,10 @@ # Release History -## 5.2.0b2 (Unreleased) +## 5.2.0 (2020-09-08) +**New Features** + +- Connection strings used with `from_connection_string` methods now supports using the `SharedAccessSignature` key in leiu of `sharedaccesskey` and `sharedaccesskeyname`, taking the string of the properly constructed token as value. ## 5.2.0b1 (2020-07-06) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py index ed9cf20e418c..79ef6b610659 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py @@ -20,6 +20,7 @@ from uamqp import AMQPClient, Message, authentication, constants, errors, compat, utils import six +from azure.core.credentials import AccessToken from .exceptions import _handle_exception, ClientClosedError, ConnectError from ._configuration import Configuration @@ -43,11 +44,13 @@ def _parse_conn_str(conn_str, kwargs): - # type: (str, Dict[str, Any]) -> Tuple[str, str, str, str] + # type: (str, Dict[str, Any]) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]] endpoint = None shared_access_key_name = None shared_access_key = None entity_path = None # type: Optional[str] + shared_access_signature = None # type: Optional[str] + shared_access_signature_expiry = None # type: Optional[int] eventhub_name = kwargs.pop("eventhub_name", None) # type: Optional[str] for element in conn_str.split(";"): key, _, value = element.partition("=") @@ -61,7 +64,16 @@ def _parse_conn_str(conn_str, kwargs): shared_access_key = value elif key.lower() == "entitypath": entity_path = value - if not all([endpoint, shared_access_key_name, shared_access_key]): + elif key.lower() == "sharedaccesssignature": + shared_access_signature = value + try: + # Expiry can be stored in the "se=" clause of the token. ('&'-separated key-value pairs) + # type: ignore + shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0]) + except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional. + # An arbitrary, absurdly large number, since you can't renew. + shared_access_signature_expiry = int(time.time() * 2) + if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))): raise ValueError( "Invalid connection string. Should be in the format: " "Endpoint=sb:///;SharedAccessKeyName=;SharedAccessKey=" @@ -72,7 +84,12 @@ def _parse_conn_str(conn_str, kwargs): host = cast(str, endpoint)[left_slash_pos + 2 :] else: host = str(endpoint) - return host, str(shared_access_key_name), str(shared_access_key), entity + return (host, + str(shared_access_key_name) if shared_access_key_name else None, + str(shared_access_key) if shared_access_key else None, + entity, + str(shared_access_signature) if shared_access_signature else None, + shared_access_signature_expiry) def _generate_sas_token(uri, policy, key, expiry=None): @@ -124,6 +141,30 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument return _generate_sas_token(scopes[0], self.policy, self.key) +class EventHubSASTokenCredential(object): + """The shared access token credential used for authentication. + + :param str token: The shared access token string + :param int expiry: The epoch timestamp + """ + def __init__(self, token, expiry): + # type: (str, int) -> None + """ + :param str token: The shared access token string + :param float expiry: The epoch timestamp + """ + self.token = token + self.expiry = expiry + self.token_type = b"servicebus.windows.net:sastoken" + + def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument + # type: (str, Any) -> AccessToken + """ + This method is automatically called when token is about to expire. + """ + return AccessToken(self.token, self.expiry) + + class ClientBase(object): # pylint:disable=too-many-instance-attributes def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwargs): # type: (str, str, TokenCredential, Any) -> None @@ -148,10 +189,13 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg @staticmethod def _from_connection_string(conn_str, **kwargs): # type: (str, Any) -> Dict[str, Any] - host, policy, key, entity = _parse_conn_str(conn_str, kwargs) + host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs) kwargs["fully_qualified_namespace"] = host kwargs["eventhub_name"] = entity - kwargs["credential"] = EventHubSharedKeyCredential(policy, key) + if token and token_expiry: + kwargs["credential"] = EventHubSASTokenCredential(token, token_expiry) + elif policy and key: + kwargs["credential"] = EventHubSharedKeyCredential(policy, key) return kwargs def _create_auth(self): diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py index aee8bd269e6b..0d016e67f81a 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_version.py @@ -3,4 +3,4 @@ # Licensed under the MIT License. # ------------------------------------ -VERSION = "5.2.0b2" +VERSION = "5.2.0" diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py index 680798c31ab4..49cb0f452b18 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_client_base_async.py @@ -19,6 +19,7 @@ Message, AMQPClientAsync, ) +from azure.core.credentials import AccessToken from .._client_base import ClientBase, _generate_sas_token, _parse_conn_str from .._utils import utc_from_timestamp @@ -62,6 +63,28 @@ async def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument return _generate_sas_token(scopes[0], self.policy, self.key) +class EventHubSASTokenCredential(object): + """The shared access token credential used for authentication. + + :param str token: The shared access token string + :param int expiry: The epoch timestamp + """ + def __init__(self, token: str, expiry: int) -> None: + """ + :param str token: The shared access token string + :param int expiry: The epoch timestamp + """ + self.token = token + self.expiry = expiry + self.token_type = b"servicebus.windows.net:sastoken" + + async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument + """ + This method is automatically called when token is about to expire. + """ + return AccessToken(self.token, self.expiry) + + class ClientBaseAsync(ClientBase): def __init__( self, @@ -86,10 +109,13 @@ def __enter__(self): @staticmethod def _from_connection_string(conn_str: str, **kwargs) -> Dict[str, Any]: - host, policy, key, entity = _parse_conn_str(conn_str, kwargs) + host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs) kwargs["fully_qualified_namespace"] = host kwargs["eventhub_name"] = entity - kwargs["credential"] = EventHubSharedKeyCredential(policy, key) + if token and token_expiry: + kwargs["credential"] = EventHubSASTokenCredential(token, token_expiry) + elif policy and key: + kwargs["credential"] = EventHubSharedKeyCredential(policy, key) return kwargs async def _create_auth_async(self) -> authentication.JWTTokenAsync: diff --git a/sdk/eventhub/azure-eventhub/conftest.py b/sdk/eventhub/azure-eventhub/conftest.py index 41a52d7790ed..b89d82af4ad6 100644 --- a/sdk/eventhub/azure-eventhub/conftest.py +++ b/sdk/eventhub/azure-eventhub/conftest.py @@ -224,3 +224,11 @@ def connstr_senders(live_eventhub): for s in senders: s.close() client.close() + +# Note: This is duplicated between here and the basic conftest, so that it does not throw warnings if you're +# running locally to this SDK. (Everything works properly, pytest just makes a bit of noise.) +def pytest_configure(config): + # register an additional marker + config.addinivalue_line( + "markers", "liveTest: mark test to be a live test only" + ) \ No newline at end of file diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py index 1a055feacfde..e841a3b5a08c 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_auth_async.py @@ -6,11 +6,19 @@ import pytest import asyncio +import datetime +import time from azure.identity.aio import EnvironmentCredential from azure.eventhub import EventData -from azure.eventhub.aio import EventHubConsumerClient, EventHubProducerClient +from azure.eventhub.aio import EventHubConsumerClient, EventHubProducerClient, EventHubSharedKeyCredential +from azure.eventhub.aio._client_base_async import EventHubSASTokenCredential +from devtools_testutils import AzureMgmtTestCase, CachedResourceGroupPreparer +from tests.eventhub_preparer import ( + CachedEventHubNamespacePreparer, + CachedEventHubPreparer +) @pytest.mark.liveTest @pytest.mark.asyncio @@ -43,3 +51,51 @@ def on_event(partition_context, event): assert on_event.called is True assert on_event.partition_id == "0" assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8') + + +class AsyncEventHubAuthTests(AzureMgmtTestCase): + + @pytest.mark.liveTest + @pytest.mark.live_test_only + @CachedResourceGroupPreparer(name_prefix='eventhubtest') + @CachedEventHubNamespacePreparer(name_prefix='eventhubtest') + @CachedEventHubPreparer(name_prefix='eventhubtest') + async def test_client_sas_credential_async(self, + eventhub, + eventhub_namespace, + eventhub_namespace_key_name, + eventhub_namespace_primary_key, + eventhub_namespace_connection_string, + **kwargs): + # This should "just work" to validate known-good. + hostname = "{}.servicebus.windows.net".format(eventhub_namespace.name) + producer_client = EventHubProducerClient.from_connection_string(eventhub_namespace_connection_string, eventhub_name = eventhub.name) + + async with producer_client: + batch = await producer_client.create_batch(partition_id='0') + batch.add(EventData(body='A single message')) + await producer_client.send_batch(batch) + + # This should also work, but now using SAS tokens. + credential = EventHubSharedKeyCredential(eventhub_namespace_key_name, eventhub_namespace_primary_key) + hostname = "{}.servicebus.windows.net".format(eventhub_namespace.name) + auth_uri = "sb://{}/{}".format(hostname, eventhub.name) + token = (await credential.get_token(auth_uri)).token + producer_client = EventHubProducerClient(fully_qualified_namespace=hostname, + eventhub_name=eventhub.name, + credential=EventHubSASTokenCredential(token, time.time() + 3000)) + + async with producer_client: + batch = await producer_client.create_batch(partition_id='0') + batch.add(EventData(body='A single message')) + await producer_client.send_batch(batch) + + # Finally let's do it with SAS token + conn str + token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode()) + conn_str_producer_client = EventHubProducerClient.from_connection_string(token_conn_str, + eventhub_name=eventhub.name) + + async with conn_str_producer_client: + batch = await conn_str_producer_client.create_batch(partition_id='0') + batch.add(EventData(body='A single message')) + await conn_str_producer_client.send_batch(batch) diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py index 3959164d902d..8cbfdf19a04a 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py @@ -6,10 +6,11 @@ import pytest import time import threading +import datetime from azure.identity import EnvironmentCredential -from azure.eventhub import EventData, EventHubProducerClient, EventHubConsumerClient - +from azure.eventhub import EventData, EventHubProducerClient, EventHubConsumerClient, EventHubSharedKeyCredential +from azure.eventhub._client_base import EventHubSASTokenCredential @pytest.mark.liveTest def test_client_secret_credential(live_eventhub): @@ -46,3 +47,37 @@ def on_event(partition_context, event): assert on_event.called is True assert on_event.partition_id == "0" assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8') + +@pytest.mark.liveTest +def test_client_sas_credential(live_eventhub): + # This should "just work" to validate known-good. + hostname = live_eventhub['hostname'] + producer_client = EventHubProducerClient.from_connection_string(live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub']) + + with producer_client: + batch = producer_client.create_batch(partition_id='0') + batch.add(EventData(body='A single message')) + producer_client.send_batch(batch) + + # This should also work, but now using SAS tokens. + credential = EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) + auth_uri = "sb://{}/{}".format(hostname, live_eventhub['event_hub']) + token = credential.get_token(auth_uri).token + producer_client = EventHubProducerClient(fully_qualified_namespace=hostname, + eventhub_name=live_eventhub['event_hub'], + credential=EventHubSASTokenCredential(token, time.time() + 3000)) + + with producer_client: + batch = producer_client.create_batch(partition_id='0') + batch.add(EventData(body='A single message')) + producer_client.send_batch(batch) + + # Finally let's do it with SAS token + conn str + token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode()) + conn_str_producer_client = EventHubProducerClient.from_connection_string(token_conn_str, + eventhub_name=live_eventhub['event_hub']) + + with conn_str_producer_client: + batch = conn_str_producer_client.create_batch(partition_id='0') + batch.add(EventData(body='A single message')) + conn_str_producer_client.send_batch(batch)