Skip to content

Commit d2441fc

Browse files
author
Rakshith Bhyravabhotla
authored
Mypy Compatibilty for EventGrid (#14344)
* Mypy Compatibilyt for EventGrid * Update sdk/eventgrid/azure-eventgrid/azure/eventgrid/_models.py * comments
1 parent 3b315d2 commit d2441fc

File tree

10 files changed

+95
-32
lines changed

10 files changed

+95
-32
lines changed

eng/tox/mypy_hard_failure_packages.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
"azure-servicebus",
1212
"azure-ai-textanalytics",
1313
"azure-ai-formrecognizer",
14-
"azure-ai-metricsadvisor"
14+
"azure-ai-metricsadvisor",
15+
"azure-eventgrid",
1516
]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
1+
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore

sdk/eventgrid/azure-eventgrid/azure/eventgrid/_consumer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
77
# --------------------------------------------------------------------------
88

9-
from typing import TYPE_CHECKING
9+
from typing import cast, TYPE_CHECKING
1010
import logging
1111
from ._models import CloudEvent, EventGridEvent
1212

1313
if TYPE_CHECKING:
1414
# pylint: disable=unused-import,ungrouped-imports
15-
from typing import Any
15+
from typing import Any, Union
1616

1717
_LOGGER = logging.getLogger(__name__)
1818

@@ -58,7 +58,7 @@ def decode_eventgrid_event(self, eventgrid_event, **kwargs): # pylint: disable=n
5858
eventgrid_event = EventGridEvent._from_json(eventgrid_event, encode) # pylint: disable=protected-access
5959
deserialized_event = EventGridEvent.deserialize(eventgrid_event)
6060
EventGridEvent._deserialize_data(deserialized_event, deserialized_event.event_type) # pylint: disable=protected-access
61-
return deserialized_event
61+
return cast(EventGridEvent, deserialized_event)
6262
except Exception as err:
6363
_LOGGER.error('Error: cannot deserialize event. Event does not have a valid format. \
6464
Event must be a string, dict, or bytes following the CloudEvent schema.')

sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import hashlib
66
import hmac
77
import base64
8+
from typing import TYPE_CHECKING, Any
89
try:
910
from urllib.parse import quote
1011
except ImportError:
@@ -16,8 +17,11 @@
1617
from ._signature_credential_policy import EventGridSharedAccessSignatureCredentialPolicy
1718
from . import _constants as constants
1819

20+
if TYPE_CHECKING:
21+
from datetime import datetime
22+
1923
def generate_shared_access_signature(topic_hostname, shared_access_key, expiration_date_utc, **kwargs):
20-
# type: (str, str, datetime.Datetime, Any) -> str
24+
# type: (str, str, datetime, Any) -> str
2125
""" Helper method to generate shared access signature given hostname, key, and expiration date.
2226
:param str topic_hostname: The topic endpoint to send the events to.
2327
Similar to <YOUR-TOPIC-NAME>.<YOUR-REGION-NAME>-1.eventgrid.azure.net
@@ -82,7 +86,7 @@ def _get_authentication_policy(credential):
8286
return authentication_policy
8387

8488
def _is_cloud_event(event):
85-
# type: dict -> bool
89+
# type: (Any) -> bool
8690
required = ('id', 'source', 'specversion', 'type')
8791
try:
8892
return all([_ in event for _ in required]) and event['specversion'] == "1.0"

sdk/eventgrid/azure-eventgrid/azure/eventgrid/_models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Licensed under the MIT License. See License.txt in the project root for license information.
44
# --------------------------------------------------------------------------------------------
55
# pylint:disable=protected-access
6+
from typing import Union, Any, Dict
67
import datetime as dt
78
import uuid
89
import json
@@ -87,6 +88,7 @@ def __init__(self, source, type, **kwargs): # pylint: disable=redefined-builtin
8788

8889
@classmethod
8990
def _from_generated(cls, cloud_event, **kwargs):
91+
# type: (Union[str, Dict, bytes], Any) -> CloudEvent
9092
generated = InternalCloudEvent.deserialize(cloud_event)
9193
if generated.additional_properties:
9294
extensions = dict(generated.additional_properties)

sdk/eventgrid/azure-eventgrid/azure/eventgrid/_policies.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
# license information.
55
# --------------------------------------------------------------------------
66
import json
7+
from typing import TYPE_CHECKING
78
import logging
89
from azure.core.pipeline.policies import SansIOHTTPPolicy
910

1011
_LOGGER = logging.getLogger(__name__)
1112

13+
if TYPE_CHECKING:
14+
from azure.core.pipeline import PipelineRequest
1215

1316
class CloudEventDistributedTracingPolicy(SansIOHTTPPolicy):
1417
"""CloudEventDistributedTracingPolicy is a policy which adds distributed tracing informatiom

sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# license information.
66
# --------------------------------------------------------------------------
77

8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, cast, Dict, List, Any, Union
99

1010
from azure.core.tracing.decorator import distributed_trace
1111
from azure.core.pipeline.policies import (
@@ -27,10 +27,12 @@
2727
from ._generated._event_grid_publisher_client import EventGridPublisherClient as EventGridPublisherClientImpl
2828
from ._policies import CloudEventDistributedTracingPolicy
2929
from ._version import VERSION
30+
from ._generated.models import CloudEvent as InternalCloudEvent, EventGridEvent as InternalEventGridEvent
3031

3132
if TYPE_CHECKING:
3233
# pylint: disable=unused-import,ungrouped-imports
33-
from typing import Any, Union, Dict, List
34+
from azure.core.credentials import AzureKeyCredential
35+
from ._shared_access_signature_credential import EventGridSharedAccessSignatureCredential
3436
SendType = Union[
3537
CloudEvent,
3638
EventGridEvent,
@@ -42,6 +44,13 @@
4244
List[Dict]
4345
]
4446

47+
ListEventType = Union[
48+
List[CloudEvent],
49+
List[EventGridEvent],
50+
List[CustomEvent],
51+
List[Dict]
52+
]
53+
4554

4655
class EventGridPublisherClient(object):
4756
"""EventGrid Python Publisher Client.
@@ -79,7 +88,7 @@ def _policies(credential, **kwargs):
7988
CustomHookPolicy(**kwargs),
8089
NetworkTraceLoggingPolicy(**kwargs),
8190
DistributedTracingPolicy(**kwargs),
82-
CloudEventDistributedTracingPolicy(**kwargs),
91+
CloudEventDistributedTracingPolicy(),
8392
HttpLoggingPolicy(**kwargs)
8493
]
8594
return policies
@@ -98,20 +107,24 @@ def send(self, events, **kwargs):
98107
:raises: :class:`ValueError`, when events do not follow specified SendType.
99108
"""
100109
if not isinstance(events, list):
101-
events = [events]
110+
events = cast(ListEventType, [events])
102111

103112
if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(e) for e in events):
104113
try:
105-
events = [e._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
114+
events = [cast(CloudEvent, e)._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
106115
except AttributeError:
107116
pass # means it's a dictionary
108117
kwargs.setdefault("content_type", "application/cloudevents-batch+json; charset=utf-8")
109-
self._client.publish_cloud_event_events(self._topic_hostname, events, **kwargs)
118+
self._client.publish_cloud_event_events(
119+
self._topic_hostname,
120+
cast(List[InternalCloudEvent], events),
121+
**kwargs
122+
)
110123
elif all(isinstance(e, EventGridEvent) for e in events) or all(isinstance(e, dict) for e in events):
111124
kwargs.setdefault("content_type", "application/json; charset=utf-8")
112-
self._client.publish_events(self._topic_hostname, events, **kwargs)
125+
self._client.publish_events(self._topic_hostname, cast(List[InternalEventGridEvent], events), **kwargs)
113126
elif all(isinstance(e, CustomEvent) for e in events):
114-
serialized_events = [dict(e) for e in events]
115-
self._client.publish_custom_event_events(self._topic_hostname, serialized_events, **kwargs)
127+
serialized_events = [dict(e) for e in events] # type: ignore
128+
self._client.publish_custom_event_events(self._topic_hostname, cast(List, serialized_events), **kwargs)
116129
else:
117130
raise ValueError("Event schema is not correct.")

sdk/eventgrid/azure-eventgrid/azure/eventgrid/_signature_credential_policy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,15 @@
44
# license information.
55
# -------------------------------------------------------------------------
66

7+
from typing import Any, TYPE_CHECKING
78
import six
89

910
from azure.core.pipeline.policies import SansIOHTTPPolicy
1011

12+
if TYPE_CHECKING:
13+
from ._shared_access_signature_credential import EventGridSharedAccessSignatureCredential
14+
15+
1116
class EventGridSharedAccessSignatureCredentialPolicy(SansIOHTTPPolicy):
1217
"""Adds a token header for the provided credential.
1318
:param credential: The credential used to authenticate requests.

sdk/eventgrid/azure-eventgrid/azure/eventgrid/aio/_publisher_client_async.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
77
# --------------------------------------------------------------------------
88

9-
from typing import Any, Union, List, Dict
9+
from typing import Any, Union, List, Dict, cast
1010
from azure.core.credentials import AzureKeyCredential
1111
from azure.core.tracing.decorator_async import distributed_trace_async
1212
from azure.core.pipeline.policies import (
@@ -26,19 +26,27 @@
2626
from .._models import CloudEvent, EventGridEvent, CustomEvent
2727
from .._helpers import _get_topic_hostname_only_fqdn, _get_authentication_policy, _is_cloud_event
2828
from .._generated.aio import EventGridPublisherClient as EventGridPublisherClientAsync
29+
from .._generated.models import CloudEvent as InternalCloudEvent, EventGridEvent as InternalEventGridEvent
2930
from .._shared_access_signature_credential import EventGridSharedAccessSignatureCredential
3031
from .._version import VERSION
3132

3233
SendType = Union[
33-
CloudEvent,
34-
EventGridEvent,
35-
CustomEvent,
36-
Dict,
37-
List[CloudEvent],
38-
List[EventGridEvent],
39-
List[CustomEvent],
40-
List[Dict]
41-
]
34+
CloudEvent,
35+
EventGridEvent,
36+
CustomEvent,
37+
Dict,
38+
List[CloudEvent],
39+
List[EventGridEvent],
40+
List[CustomEvent],
41+
List[Dict]
42+
]
43+
44+
ListEventType = Union[
45+
List[CloudEvent],
46+
List[EventGridEvent],
47+
List[CustomEvent],
48+
List[Dict]
49+
]
4250

4351
class EventGridPublisherClient():
4452
"""Asynchronous EventGrid Python Publisher Client.
@@ -101,20 +109,34 @@ async def send(
101109
:raises: :class:`ValueError`, when events do not follow specified SendType.
102110
"""
103111
if not isinstance(events, list):
104-
events = [events]
112+
events = cast(ListEventType, [events])
105113

106114
if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(e) for e in events):
107115
try:
108-
events = [e._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
116+
events = [
117+
cast(CloudEvent, e)._to_generated(**kwargs) for e in events # pylint: disable=protected-access
118+
]
109119
except AttributeError:
110120
pass # means it's a dictionary
111121
kwargs.setdefault("content_type", "application/cloudevents-batch+json; charset=utf-8")
112-
await self._client.publish_cloud_event_events(self._topic_hostname, events, **kwargs)
122+
await self._client.publish_cloud_event_events(
123+
self._topic_hostname,
124+
cast(List[InternalCloudEvent], events),
125+
**kwargs
126+
)
113127
elif all(isinstance(e, EventGridEvent) for e in events) or all(isinstance(e, dict) for e in events):
114128
kwargs.setdefault("content_type", "application/json; charset=utf-8")
115-
await self._client.publish_events(self._topic_hostname, events, **kwargs)
129+
await self._client.publish_events(
130+
self._topic_hostname,
131+
cast(List[InternalEventGridEvent], events),
132+
**kwargs
133+
)
116134
elif all(isinstance(e, CustomEvent) for e in events):
117-
serialized_events = [dict(e) for e in events]
118-
await self._client.publish_custom_event_events(self._topic_hostname, serialized_events, **kwargs)
135+
serialized_events = [dict(e) for e in events] # type: ignore
136+
await self._client.publish_custom_event_events(
137+
self._topic_hostname,
138+
cast(List, serialized_events),
139+
**kwargs
140+
)
119141
else:
120142
raise ValueError("Event schema is not correct.")
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[mypy]
2+
python_version = 3.7
3+
warn_return_any = True
4+
warn_unused_configs = True
5+
ignore_missing_imports = True
6+
7+
# Per-module options:
8+
9+
[mypy-azure.eventgrid._generated.*]
10+
ignore_errors = True
11+
12+
[mypy-azure.core.*]
13+
ignore_errors = True

0 commit comments

Comments
 (0)