Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion sdk/eventhub/azure-eventhub/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Release History

## 5.2.0b2 (Unreleased)
## 5.2.0 (2020-09-08)

**New Features**

- Connection strings used with `from_connection_string` methods now supports using the `SharedAccessSignature` key in leiu of `sharedaccesskey` and `sharedaccesskeyname`, taking the string of the properly constructed token as value.

## 5.2.0b1 (2020-07-06)

Expand Down
54 changes: 49 additions & 5 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from uamqp import AMQPClient, Message, authentication, constants, errors, compat, utils
import six
from azure.core.credentials import AccessToken

from .exceptions import _handle_exception, ClientClosedError, ConnectError
from ._configuration import Configuration
Expand All @@ -43,11 +44,13 @@


def _parse_conn_str(conn_str, kwargs):
# type: (str, Dict[str, Any]) -> Tuple[str, str, str, str]
# type: (str, Dict[str, Any]) -> Tuple[str, Optional[str], Optional[str], str, Optional[str], Optional[int]]
endpoint = None
shared_access_key_name = None
shared_access_key = None
entity_path = None # type: Optional[str]
shared_access_signature = None # type: Optional[str]
shared_access_signature_expiry = None # type: Optional[int]
eventhub_name = kwargs.pop("eventhub_name", None) # type: Optional[str]
for element in conn_str.split(";"):
key, _, value = element.partition("=")
Expand All @@ -61,7 +64,16 @@ def _parse_conn_str(conn_str, kwargs):
shared_access_key = value
elif key.lower() == "entitypath":
entity_path = value
if not all([endpoint, shared_access_key_name, shared_access_key]):
elif key.lower() == "sharedaccesssignature":
shared_access_signature = value
try:
# Expiry can be stored in the "se=<timestamp>" clause of the token. ('&'-separated key-value pairs)
# type: ignore
shared_access_signature_expiry = int(shared_access_signature.split('se=')[1].split('&')[0])
except (IndexError, TypeError, ValueError): # Fallback since technically expiry is optional.
# An arbitrary, absurdly large number, since you can't renew.
shared_access_signature_expiry = int(time.time() * 2)
if not (all((endpoint, shared_access_key_name, shared_access_key)) or all((endpoint, shared_access_signature))):
raise ValueError(
"Invalid connection string. Should be in the format: "
"Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>"
Expand All @@ -72,7 +84,12 @@ def _parse_conn_str(conn_str, kwargs):
host = cast(str, endpoint)[left_slash_pos + 2 :]
else:
host = str(endpoint)
return host, str(shared_access_key_name), str(shared_access_key), entity
return (host,
str(shared_access_key_name) if shared_access_key_name else None,
str(shared_access_key) if shared_access_key else None,
entity,
str(shared_access_signature) if shared_access_signature else None,
shared_access_signature_expiry)


def _generate_sas_token(uri, policy, key, expiry=None):
Expand Down Expand Up @@ -124,6 +141,30 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
return _generate_sas_token(scopes[0], self.policy, self.key)


class EventHubSASTokenCredential(object):
"""The shared access token credential used for authentication.

:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
def __init__(self, token, expiry):
# type: (str, int) -> None
"""
:param str token: The shared access token string
:param float expiry: The epoch timestamp
"""
self.token = token
self.expiry = expiry
self.token_type = b"servicebus.windows.net:sastoken"

def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
# type: (str, Any) -> AccessToken
"""
This method is automatically called when token is about to expire.
"""
return AccessToken(self.token, self.expiry)


class ClientBase(object): # pylint:disable=too-many-instance-attributes
def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwargs):
# type: (str, str, TokenCredential, Any) -> None
Expand All @@ -148,10 +189,13 @@ def __init__(self, fully_qualified_namespace, eventhub_name, credential, **kwarg
@staticmethod
def _from_connection_string(conn_str, **kwargs):
# type: (str, Any) -> Dict[str, Any]
host, policy, key, entity = _parse_conn_str(conn_str, kwargs)
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs)
kwargs["fully_qualified_namespace"] = host
kwargs["eventhub_name"] = entity
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
if token and token_expiry:
kwargs["credential"] = EventHubSASTokenCredential(token, token_expiry)
elif policy and key:
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
return kwargs

def _create_auth(self):
Expand Down
2 changes: 1 addition & 1 deletion sdk/eventhub/azure-eventhub/azure/eventhub/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
# Licensed under the MIT License.
# ------------------------------------

VERSION = "5.2.0b2"
VERSION = "5.2.0"
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Message,
AMQPClientAsync,
)
from azure.core.credentials import AccessToken

from .._client_base import ClientBase, _generate_sas_token, _parse_conn_str
from .._utils import utc_from_timestamp
Expand Down Expand Up @@ -62,6 +63,28 @@ async def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
return _generate_sas_token(scopes[0], self.policy, self.key)


class EventHubSASTokenCredential(object):
"""The shared access token credential used for authentication.

:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
def __init__(self, token: str, expiry: int) -> None:
"""
:param str token: The shared access token string
:param int expiry: The epoch timestamp
"""
self.token = token
self.expiry = expiry
self.token_type = b"servicebus.windows.net:sastoken"

async def get_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:disable=unused-argument
"""
This method is automatically called when token is about to expire.
"""
return AccessToken(self.token, self.expiry)


class ClientBaseAsync(ClientBase):
def __init__(
self,
Expand All @@ -86,10 +109,13 @@ def __enter__(self):

@staticmethod
def _from_connection_string(conn_str: str, **kwargs) -> Dict[str, Any]:
host, policy, key, entity = _parse_conn_str(conn_str, kwargs)
host, policy, key, entity, token, token_expiry = _parse_conn_str(conn_str, kwargs)
kwargs["fully_qualified_namespace"] = host
kwargs["eventhub_name"] = entity
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
if token and token_expiry:
kwargs["credential"] = EventHubSASTokenCredential(token, token_expiry)
elif policy and key:
kwargs["credential"] = EventHubSharedKeyCredential(policy, key)
return kwargs

async def _create_auth_async(self) -> authentication.JWTTokenAsync:
Expand Down
8 changes: 8 additions & 0 deletions sdk/eventhub/azure-eventhub/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,3 +224,11 @@ def connstr_senders(live_eventhub):
for s in senders:
s.close()
client.close()

# Note: This is duplicated between here and the basic conftest, so that it does not throw warnings if you're
# running locally to this SDK. (Everything works properly, pytest just makes a bit of noise.)
def pytest_configure(config):
# register an additional marker
config.addinivalue_line(
"markers", "liveTest: mark test to be a live test only"
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,19 @@

import pytest
import asyncio
import datetime
import time

from azure.identity.aio import EnvironmentCredential
from azure.eventhub import EventData
from azure.eventhub.aio import EventHubConsumerClient, EventHubProducerClient
from azure.eventhub.aio import EventHubConsumerClient, EventHubProducerClient, EventHubSharedKeyCredential
from azure.eventhub.aio._client_base_async import EventHubSASTokenCredential

from devtools_testutils import AzureMgmtTestCase, CachedResourceGroupPreparer
from tests.eventhub_preparer import (
CachedEventHubNamespacePreparer,
CachedEventHubPreparer
)

@pytest.mark.liveTest
@pytest.mark.asyncio
Expand Down Expand Up @@ -43,3 +51,51 @@ def on_event(partition_context, event):
assert on_event.called is True
assert on_event.partition_id == "0"
assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8')


class AsyncEventHubAuthTests(AzureMgmtTestCase):

@pytest.mark.liveTest
@pytest.mark.live_test_only
@CachedResourceGroupPreparer(name_prefix='eventhubtest')
@CachedEventHubNamespacePreparer(name_prefix='eventhubtest')
@CachedEventHubPreparer(name_prefix='eventhubtest')
async def test_client_sas_credential_async(self,
eventhub,
eventhub_namespace,
eventhub_namespace_key_name,
eventhub_namespace_primary_key,
eventhub_namespace_connection_string,
**kwargs):
# This should "just work" to validate known-good.
hostname = "{}.servicebus.windows.net".format(eventhub_namespace.name)
producer_client = EventHubProducerClient.from_connection_string(eventhub_namespace_connection_string, eventhub_name = eventhub.name)

async with producer_client:
batch = await producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
await producer_client.send_batch(batch)

# This should also work, but now using SAS tokens.
credential = EventHubSharedKeyCredential(eventhub_namespace_key_name, eventhub_namespace_primary_key)
hostname = "{}.servicebus.windows.net".format(eventhub_namespace.name)
auth_uri = "sb://{}/{}".format(hostname, eventhub.name)
token = (await credential.get_token(auth_uri)).token
producer_client = EventHubProducerClient(fully_qualified_namespace=hostname,
eventhub_name=eventhub.name,
credential=EventHubSASTokenCredential(token, time.time() + 3000))

async with producer_client:
batch = await producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
await producer_client.send_batch(batch)

# Finally let's do it with SAS token + conn str
token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode())
conn_str_producer_client = EventHubProducerClient.from_connection_string(token_conn_str,
eventhub_name=eventhub.name)

async with conn_str_producer_client:
batch = await conn_str_producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
await conn_str_producer_client.send_batch(batch)
39 changes: 37 additions & 2 deletions sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import pytest
import time
import threading
import datetime

from azure.identity import EnvironmentCredential
from azure.eventhub import EventData, EventHubProducerClient, EventHubConsumerClient

from azure.eventhub import EventData, EventHubProducerClient, EventHubConsumerClient, EventHubSharedKeyCredential
from azure.eventhub._client_base import EventHubSASTokenCredential

@pytest.mark.liveTest
def test_client_secret_credential(live_eventhub):
Expand Down Expand Up @@ -46,3 +47,37 @@ def on_event(partition_context, event):
assert on_event.called is True
assert on_event.partition_id == "0"
assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8')

@pytest.mark.liveTest
def test_client_sas_credential(live_eventhub):
# This should "just work" to validate known-good.
hostname = live_eventhub['hostname']
producer_client = EventHubProducerClient.from_connection_string(live_eventhub['connection_str'], eventhub_name = live_eventhub['event_hub'])

with producer_client:
batch = producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
producer_client.send_batch(batch)

# This should also work, but now using SAS tokens.
credential = EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])
auth_uri = "sb://{}/{}".format(hostname, live_eventhub['event_hub'])
token = credential.get_token(auth_uri).token
producer_client = EventHubProducerClient(fully_qualified_namespace=hostname,
eventhub_name=live_eventhub['event_hub'],
credential=EventHubSASTokenCredential(token, time.time() + 3000))

with producer_client:
batch = producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
producer_client.send_batch(batch)

# Finally let's do it with SAS token + conn str
token_conn_str = "Endpoint=sb://{}/;SharedAccessSignature={};".format(hostname, token.decode())
conn_str_producer_client = EventHubProducerClient.from_connection_string(token_conn_str,
eventhub_name=live_eventhub['event_hub'])

with conn_str_producer_client:
batch = conn_str_producer_client.create_batch(partition_id='0')
batch.add(EventData(body='A single message'))
conn_str_producer_client.send_batch(batch)