diff --git a/sdk/eventgrid/azure-eventgrid/CHANGELOG.md b/sdk/eventgrid/azure-eventgrid/CHANGELOG.md index 2ad25d664031..2c1e07294ba5 100644 --- a/sdk/eventgrid/azure-eventgrid/CHANGELOG.md +++ b/sdk/eventgrid/azure-eventgrid/CHANGELOG.md @@ -1,11 +1,13 @@ # Release History -## 4.3.1 (Unreleased) +## 4.4.0 (Unreleased) - Bumped `msrest` dependency to `0.6.21` to align with mgmt package. ### Features Added +- `EventGridPublisherClient` now supports Azure Active Directory (AAD) for authentication. + ### Breaking Changes ### Key Bugs Fixed diff --git a/sdk/eventgrid/azure-eventgrid/README.md b/sdk/eventgrid/azure-eventgrid/README.md index 4328c3851d0a..c045e36e29a2 100644 --- a/sdk/eventgrid/azure-eventgrid/README.md +++ b/sdk/eventgrid/azure-eventgrid/README.md @@ -38,6 +38,34 @@ az eventgrid domain --create --location --resource-group ..eventgrid.azure.net/api/events"` diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_constants.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_constants.py index 0d26f09c4bdb..c246323f9476 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_constants.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_constants.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- +DEFAULT_EVENTGRID_SCOPE = "https://eventgrid.azure.net/.default" EVENTGRID_KEY_HEADER = "aeg-sas-key" EVENTGRID_TOKEN_HEADER = "aeg-sas-token" DEFAULT_API_VERSION = "2018-01-01" diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py index bc7b15bf089f..8c51aafffafc 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py @@ -16,7 +16,7 @@ from msrest import Serializer from azure.core.pipeline.transport import HttpRequest -from azure.core.pipeline.policies import AzureKeyCredentialPolicy +from azure.core.pipeline.policies import AzureKeyCredentialPolicy, BearerTokenCredentialPolicy from azure.core.credentials import AzureKeyCredential, AzureSasCredential from ._signature_credential_policy import EventGridSasCredentialPolicy from . import _constants as constants @@ -28,7 +28,6 @@ if TYPE_CHECKING: from datetime import datetime - def generate_sas(endpoint, shared_access_key, expiration_date_utc, **kwargs): # type: (str, str, datetime, Any) -> str """Helper method to generate shared access signature given hostname, key, and expiration date. @@ -70,9 +69,14 @@ def _generate_hmac(key, message): return base64.b64encode(hmac_new) -def _get_authentication_policy(credential): +def _get_authentication_policy(credential, bearer_token_policy=BearerTokenCredentialPolicy): if credential is None: raise ValueError("Parameter 'self._credential' must not be None.") + if hasattr(credential, "get_token"): + return bearer_token_policy( + credential, + constants.DEFAULT_EVENTGRID_SCOPE + ) if isinstance(credential, AzureKeyCredential): return AzureKeyCredentialPolicy( credential=credential, name=constants.EVENTGRID_KEY_HEADER @@ -82,7 +86,7 @@ def _get_authentication_policy(credential): credential=credential, name=constants.EVENTGRID_TOKEN_HEADER ) raise ValueError( - "The provided credential should be an instance of AzureSasCredential or AzureKeyCredential" + "The provided credential should be an instance of a TokenCredential, AzureSasCredential or AzureKeyCredential" ) diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py index 9c9b6abcec4b..197fc41a3a28 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py @@ -40,7 +40,11 @@ if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports - from azure.core.credentials import AzureKeyCredential, AzureSasCredential + from azure.core.credentials import ( + AzureKeyCredential, + AzureSasCredential, + TokenCredential, + ) SendType = Union[ CloudEvent, @@ -60,8 +64,9 @@ class EventGridPublisherClient(object): :param str endpoint: The topic endpoint to send the events to. :param credential: The credential object used for authentication which - implements SAS key authentication or SAS token authentication. - :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.AzureSasCredential + implements SAS key authentication or SAS token authentication or a TokenCredential. + :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.AzureSasCredential or + ~azure.core.credentials.TokenCredential :rtype: None .. admonition:: Example: @@ -82,7 +87,7 @@ class EventGridPublisherClient(object): """ def __init__(self, endpoint, credential, **kwargs): - # type: (str, Union[AzureKeyCredential, AzureSasCredential], Any) -> None + # type: (str, Union[AzureKeyCredential, AzureSasCredential, TokenCredential], Any) -> None self._endpoint = endpoint self._client = EventGridPublisherClientImpl( policies=EventGridPublisherClient._policies(credential, **kwargs), **kwargs @@ -90,7 +95,7 @@ def __init__(self, endpoint, credential, **kwargs): @staticmethod def _policies(credential, **kwargs): - # type: (Union[AzureKeyCredential, AzureSasCredential], Any) -> List[Any] + # type: (Union[AzureKeyCredential, AzureSasCredential, TokenCredential], Any) -> List[Any] auth_policy = _get_authentication_policy(credential) sdk_moniker = "eventgrid/{}".format(VERSION) policies = [ @@ -183,7 +188,8 @@ def send(self, events, **kwargs): if isinstance(events[0], CloudEvent) or _is_cloud_event(events[0]): try: events = [ - _cloud_event_to_generated(e, **kwargs) for e in events # pylint: disable=protected-access + _cloud_event_to_generated(e, **kwargs) + for e in events # pylint: disable=protected-access ] except AttributeError: pass # means it's a dictionary @@ -191,9 +197,8 @@ def send(self, events, **kwargs): elif isinstance(events[0], EventGridEvent) or _is_eventgrid_event(events[0]): for event in events: _eventgrid_data_typecheck(event) - self._client._send_request( # pylint: disable=protected-access - _build_request(self._endpoint, content_type, events), - **kwargs + self._client._send_request( # pylint: disable=protected-access + _build_request(self._endpoint, content_type, events), **kwargs ) def close(self): diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_version.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_version.py index 6d23ac4acdeb..b5234b1c4677 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_version.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/_version.py @@ -9,4 +9,4 @@ # regenerated. # -------------------------------------------------------------------------- -VERSION = "4.3.1" +VERSION = "4.4.0" diff --git a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/aio/_publisher_client_async.py b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/aio/_publisher_client_async.py index 7846bb27aae4..956b37fbe65f 100644 --- a/sdk/eventgrid/azure-eventgrid/azure/eventgrid/aio/_publisher_client_async.py +++ b/sdk/eventgrid/azure-eventgrid/azure/eventgrid/aio/_publisher_client_async.py @@ -6,7 +6,7 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, Union, List, Dict, cast +from typing import Any, Union, List, Dict, TYPE_CHECKING, cast from azure.core.credentials import AzureKeyCredential, AzureSasCredential from azure.core.tracing.decorator_async import distributed_trace_async from azure.core.messaging import CloudEvent @@ -22,20 +22,24 @@ DistributedTracingPolicy, HttpLoggingPolicy, UserAgentPolicy, + AsyncBearerTokenCredentialPolicy, ) from .._policies import CloudEventDistributedTracingPolicy from .._models import EventGridEvent from .._helpers import ( - _get_authentication_policy, _is_cloud_event, _is_eventgrid_event, _eventgrid_data_typecheck, _build_request, _cloud_event_to_generated, + _get_authentication_policy, ) from .._generated.aio import EventGridPublisherClient as EventGridPublisherClientAsync from .._version import VERSION +if TYPE_CHECKING: + from azure.core.credentials_async import AsyncTokenCredential + SendType = Union[ CloudEvent, EventGridEvent, Dict, List[CloudEvent], List[EventGridEvent], List[Dict] ] @@ -49,8 +53,9 @@ class EventGridPublisherClient: :param str endpoint: The topic endpoint to send the events to. :param credential: The credential object used for authentication which implements - SAS key authentication or SAS token authentication. - :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.AzureSasCredential + SAS key authentication or SAS token authentication or an AsyncTokenCredential. + :type credential: ~azure.core.credentials.AzureKeyCredential or ~azure.core.credentials.AzureSasCredential or + ~azure.core.credentials_async.AsyncTokenCredential :rtype: None .. admonition:: Example: @@ -73,7 +78,9 @@ class EventGridPublisherClient: def __init__( self, endpoint: str, - credential: Union[AzureKeyCredential, AzureSasCredential], + credential: Union[ + "AsyncTokenCredential", AzureKeyCredential, AzureSasCredential + ], **kwargs: Any ) -> None: self._client = EventGridPublisherClientAsync( @@ -83,9 +90,14 @@ def __init__( @staticmethod def _policies( - credential: Union[AzureKeyCredential, AzureSasCredential], **kwargs: Any + credential: Union[ + AzureKeyCredential, AzureSasCredential, "AsyncTokenCredential" + ], + **kwargs: Any ) -> List[Any]: - auth_policy = _get_authentication_policy(credential) + auth_policy = _get_authentication_policy( + credential, AsyncBearerTokenCredentialPolicy + ) sdk_moniker = "eventgridpublisherclient/{}".format(VERSION) policies = [ RequestIdPolicy(**kwargs), @@ -176,7 +188,8 @@ async def send(self, events: SendType, **kwargs: Any) -> None: if isinstance(events[0], CloudEvent) or _is_cloud_event(events[0]): try: events = [ - _cloud_event_to_generated(e, **kwargs) for e in events # pylint: disable=protected-access + _cloud_event_to_generated(e, **kwargs) + for e in events # pylint: disable=protected-access ] except AttributeError: pass # means it's a dictionary @@ -184,9 +197,8 @@ async def send(self, events: SendType, **kwargs: Any) -> None: elif isinstance(events[0], EventGridEvent) or _is_eventgrid_event(events[0]): for event in events: _eventgrid_data_typecheck(event) - await self._client._send_request( # pylint: disable=protected-access - _build_request(self._endpoint, content_type, events), - **kwargs + await self._client._send_request( # pylint: disable=protected-access + _build_request(self._endpoint, content_type, events), **kwargs ) async def __aenter__(self) -> "EventGridPublisherClient": diff --git a/sdk/eventgrid/azure-eventgrid/samples/async_samples/sample_authentication_async.py b/sdk/eventgrid/azure-eventgrid/samples/async_samples/sample_authentication_async.py index 99480bb2e3bc..efe9df4efebe 100644 --- a/sdk/eventgrid/azure-eventgrid/samples/async_samples/sample_authentication_async.py +++ b/sdk/eventgrid/azure-eventgrid/samples/async_samples/sample_authentication_async.py @@ -38,3 +38,20 @@ credential = AzureSasCredential(signature) client = EventGridPublisherClient(endpoint, credential) # [END client_auth_with_sas_cred_async] + +# [START client_auth_with_token_cred_async] +from azure.identity.aio import DefaultAzureCredential +from azure.eventgrid.aio import EventGridPublisherClient +from azure.eventgrid import EventGridEvent + +event = EventGridEvent( + data={"team": "azure-sdk"}, + subject="Door1", + event_type="Azure.Sdk.Demo", + data_version="2.0" +) + +credential = DefaultAzureCredential() +endpoint = os.environ["EG_TOPIC_HOSTNAME"] +client = EventGridPublisherClient(endpoint, credential) +# [END client_auth_with_token_cred_async] \ No newline at end of file diff --git a/sdk/eventgrid/azure-eventgrid/samples/sync_samples/sample_authentication.py b/sdk/eventgrid/azure-eventgrid/samples/sync_samples/sample_authentication.py index 5751bd14ed70..108f4f3010ee 100644 --- a/sdk/eventgrid/azure-eventgrid/samples/sync_samples/sample_authentication.py +++ b/sdk/eventgrid/azure-eventgrid/samples/sync_samples/sample_authentication.py @@ -38,3 +38,12 @@ credential = AzureSasCredential(signature) client = EventGridPublisherClient(endpoint, credential) # [END client_auth_with_sas_cred] + +# [START client_auth_with_token_cred] +from azure.identity import DefaultAzureCredential +from azure.eventgrid import EventGridPublisherClient, EventGridEvent + +credential = DefaultAzureCredential() +endpoint = os.environ["EG_TOPIC_HOSTNAME"] +client = EventGridPublisherClient(endpoint, credential) +# [END client_auth_with_token_cred] \ No newline at end of file diff --git a/sdk/eventgrid/azure-eventgrid/tests/eventgrid_preparer.py b/sdk/eventgrid/azure-eventgrid/tests/eventgrid_preparer.py index 52bb4ac64dd2..320f778ce380 100644 --- a/sdk/eventgrid/azure-eventgrid/tests/eventgrid_preparer.py +++ b/sdk/eventgrid/azure-eventgrid/tests/eventgrid_preparer.py @@ -3,12 +3,14 @@ import os from collections import namedtuple +from azure_devtools.scenario_tests import ReplayableTest +from azure.core.credentials import AccessToken from azure.mgmt.eventgrid import EventGridManagementClient from azure.mgmt.eventgrid.models import Topic, InputSchema, JsonInputSchemaMapping, JsonField, JsonFieldWithDefault from azure_devtools.scenario_tests.exceptions import AzureTestError from devtools_testutils import ( - ResourceGroupPreparer, AzureMgmtPreparer, FakeResource + ResourceGroupPreparer, AzureMgmtPreparer, FakeResource, AzureMgmtTestCase ) from devtools_testutils.resource_testcase import RESOURCE_GROUP_PARAM @@ -25,6 +27,15 @@ DATA_VERSION_JSON_FIELD_WITH_DEFAULT = JsonFieldWithDefault(source_field='customDataVersion', default_value='') CUSTOM_JSON_INPUT_SCHEMA_MAPPING = JsonInputSchemaMapping(id=ID_JSON_FIELD, topic=TOPIC_JSON_FIELD, event_time=EVENT_TIME_JSON_FIELD, event_type=EVENT_TYPE_JSON_FIELD_WITH_DEFAULT, subject=SUBJECT_JSON_FIELD_WITH_DEFAULT, data_version=DATA_VERSION_JSON_FIELD_WITH_DEFAULT) +class FakeTokenCredential(object): + """Protocol for classes able to provide OAuth tokens. + :param str scopes: Lets you specify the type of access needed. + """ + def __init__(self): + self.token = AccessToken("YOU SHALL NOT PASS", 0) + + def get_token(self, *args): + return self.token class EventGridTopicPreparer(AzureMgmtPreparer): def __init__(self, @@ -94,4 +105,5 @@ def _get_resource_group(self, **kwargs): 'decorator @{} in front of this event grid topic preparer.' raise AzureTestError(template.format(ResourceGroupPreparer.__name__)) + CachedEventGridTopicPreparer = functools.partial(EventGridTopicPreparer, use_cache=True) diff --git a/sdk/eventgrid/azure-eventgrid/tests/test_eg_publisher_client.py b/sdk/eventgrid/azure-eventgrid/tests/test_eg_publisher_client.py index 27af5f816442..3651363c2435 100644 --- a/sdk/eventgrid/azure-eventgrid/tests/test_eg_publisher_client.py +++ b/sdk/eventgrid/azure-eventgrid/tests/test_eg_publisher_client.py @@ -29,7 +29,7 @@ from azure.eventgrid._helpers import _cloud_event_to_generated from eventgrid_preparer import ( - CachedEventGridTopicPreparer + CachedEventGridTopicPreparer, ) class EventGridPublisherClientTests(AzureMgmtTestCase): @@ -343,5 +343,19 @@ def test_send_custom_schema_event_as_list(self, resource_group, eventgrid_topic, def test_send_throws_with_bad_credential(self): bad_credential = "I am a bad credential" - with pytest.raises(ValueError, match="The provided credential should be an instance of AzureSasCredential or AzureKeyCredential"): + with pytest.raises(ValueError, match="The provided credential should be an instance of a TokenCredential, AzureSasCredential or AzureKeyCredential"): client = EventGridPublisherClient("eventgrid_endpoint", bad_credential) + + @pytest.mark.live_test_only + @CachedResourceGroupPreparer(name_prefix='eventgridtest') + @CachedEventGridTopicPreparer(name_prefix='eventgridtest') + def test_send_token_credential(self, resource_group, eventgrid_topic, eventgrid_topic_primary_key, eventgrid_topic_endpoint): + credential = self.get_credential(EventGridPublisherClient) + client = EventGridPublisherClient(eventgrid_topic_endpoint, credential) + eg_event = EventGridEvent( + subject="sample", + data={"sample": "eventgridevent"}, + event_type="Sample.EventGrid.Event", + data_version="2.0" + ) + client.send(eg_event) diff --git a/sdk/eventgrid/azure-eventgrid/tests/test_eg_publisher_client_async.py b/sdk/eventgrid/azure-eventgrid/tests/test_eg_publisher_client_async.py index 656c0c7fa0db..b2ec715a6a65 100644 --- a/sdk/eventgrid/azure-eventgrid/tests/test_eg_publisher_client_async.py +++ b/sdk/eventgrid/azure-eventgrid/tests/test_eg_publisher_client_async.py @@ -29,6 +29,7 @@ CachedEventGridTopicPreparer ) + class EventGridPublisherClientTests(AzureMgmtTestCase): FILTER_HEADERS = ReplayableTest.FILTER_HEADERS + ['aeg-sas-key', 'aeg-sas-token'] @@ -328,4 +329,19 @@ async def test_send_and_close_async_session(self, resource_group, eventgrid_topi @pytest.mark.asyncio def test_send_NONE_credential_async(self, resource_group, eventgrid_topic, eventgrid_topic_primary_key, eventgrid_topic_endpoint): with pytest.raises(ValueError, match="Parameter 'self._credential' must not be None."): - client = EventGridPublisherClient(eventgrid_topic_endpoint, None) \ No newline at end of file + client = EventGridPublisherClient(eventgrid_topic_endpoint, None) + + @pytest.mark.live_test_only + @CachedResourceGroupPreparer(name_prefix='eventgridtest') + @CachedEventGridTopicPreparer(name_prefix='eventgridtest') + @pytest.mark.asyncio + async def test_send_token_credential(self, resource_group, eventgrid_topic, eventgrid_topic_primary_key, eventgrid_topic_endpoint): + credential = self.get_credential(EventGridPublisherClient) + client = EventGridPublisherClient(eventgrid_topic_endpoint, credential) + eg_event = EventGridEvent( + subject="sample", + data={"sample": "eventgridevent"}, + event_type="Sample.EventGrid.Event", + data_version="2.0" + ) + await client.send(eg_event) \ No newline at end of file