Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion eng/tox/mypy_hard_failure_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
"azure-servicebus",
"azure-ai-textanalytics",
"azure-ai-formrecognizer",
"azure-ai-metricsadvisor"
"azure-ai-metricsadvisor",
"azure-eventgrid",
]
2 changes: 1 addition & 1 deletion sdk/eventgrid/azure-eventgrid/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__path__ = __import__('pkgutil').extend_path(__path__, __name__)
__path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore
8 changes: 4 additions & 4 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from typing import TYPE_CHECKING
from typing import cast, TYPE_CHECKING
import logging
from ._models import CloudEvent, EventGridEvent

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any
from typing import Any, Union

_LOGGER = logging.getLogger(__name__)

Expand All @@ -35,7 +35,7 @@ def decode_cloud_event(self, cloud_event, **kwargs): # pylint: disable=no-self-u
cloud_event = CloudEvent._from_json(cloud_event, encode) # pylint: disable=protected-access
deserialized_event = CloudEvent._from_generated(cloud_event) # pylint: disable=protected-access
CloudEvent._deserialize_data(deserialized_event, deserialized_event.type) # pylint: disable=protected-access
return deserialized_event
return cast(CloudEvent, deserialized_event)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the error that triggered the need for this cast? I would expect the return type of CloudEvent._from_generated to be CloudEvent , and then this cast unecessary

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was unnecessary - removed it

except Exception as err:
_LOGGER.error('Error: cannot deserialize event. Event does not have a valid format. \
Event must be a string, dict, or bytes following the CloudEvent schema.')
Expand All @@ -58,7 +58,7 @@ def decode_eventgrid_event(self, eventgrid_event, **kwargs): # pylint: disable=n
eventgrid_event = EventGridEvent._from_json(eventgrid_event, encode) # pylint: disable=protected-access
deserialized_event = EventGridEvent.deserialize(eventgrid_event)
EventGridEvent._deserialize_data(deserialized_event, deserialized_event.event_type) # pylint: disable=protected-access
return deserialized_event
return cast(EventGridEvent, deserialized_event)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same question

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is necessary unlike above - we get

Returning Any from function declared to return "EventGridEvent"

because the msrest's deserialize method on line 59 isn't typed

Copy link
Member

@lmazuel lmazuel Oct 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then let's fix msrest too :)
Azure/msrest-for-python#226

But ok here, since we should not wait for msrest fix

except Exception as err:
_LOGGER.error('Error: cannot deserialize event. Event does not have a valid format. \
Event must be a string, dict, or bytes following the CloudEvent schema.')
Expand Down
8 changes: 6 additions & 2 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import hashlib
import hmac
import base64
from typing import TYPE_CHECKING, Any
try:
from urllib.parse import quote
except ImportError:
Expand All @@ -16,8 +17,11 @@
from ._signature_credential_policy import EventGridSharedAccessSignatureCredentialPolicy
from . import _constants as constants

if TYPE_CHECKING:
from datetime import datetime

def generate_shared_access_signature(topic_hostname, shared_access_key, expiration_date_utc, **kwargs):
# type: (str, str, datetime.Datetime, Any) -> str
# type: (str, str, datetime, Any) -> str
""" Helper method to generate shared access signature given hostname, key, and expiration date.
:param str topic_hostname: The topic endpoint to send the events to.
Similar to <YOUR-TOPIC-NAME>.<YOUR-REGION-NAME>-1.eventgrid.azure.net
Expand Down Expand Up @@ -82,7 +86,7 @@ def _get_authentication_policy(credential):
return authentication_policy

def _is_cloud_event(event):
# type: dict -> bool
# type: (dict) -> bool
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Input is actually Any, since since if the input is not a dict, you don't raise an exception, you return False.

This will remove the cast later down the road

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MY rationale was that we expect only a dictionary here - but yes, you are right - changing it

required = ('id', 'source', 'specversion', 'type')
try:
return all([_ in event for _ in required]) and event['specversion'] == "1.0"
Expand Down
1 change: 1 addition & 0 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------
# pylint:disable=protected-access
from typing import Union, Any
import datetime as dt
import uuid
import json
Expand Down
3 changes: 3 additions & 0 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@
# license information.
# --------------------------------------------------------------------------
import json
from typing import TYPE_CHECKING
import logging
from azure.core.pipeline.policies import SansIOHTTPPolicy

_LOGGER = logging.getLogger(__name__)

if TYPE_CHECKING:
from azure.core.pipeline import PipelineRequest

class CloudEventDistributedTracingPolicy(SansIOHTTPPolicy):
"""CloudEventDistributedTracingPolicy is a policy which adds distributed tracing informatiom
Expand Down
33 changes: 23 additions & 10 deletions sdk/eventgrid/azure-eventgrid/azure/eventgrid/_publisher_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# license information.
# --------------------------------------------------------------------------

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast, Dict, List, Any, Union

from azure.core.tracing.decorator import distributed_trace
from azure.core.pipeline.policies import (
Expand All @@ -27,10 +27,12 @@
from ._generated._event_grid_publisher_client import EventGridPublisherClient as EventGridPublisherClientImpl
from ._policies import CloudEventDistributedTracingPolicy
from ._version import VERSION
from ._generated.models import CloudEvent as InternalCloudEvent, EventGridEvent as InternalEventGridEvent

if TYPE_CHECKING:
# pylint: disable=unused-import,ungrouped-imports
from typing import Any, Union, Dict, List
from azure.core.credentials import AzureKeyCredential
from ._shared_access_signature_credential import EventGridSharedAccessSignatureCredential
SendType = Union[
CloudEvent,
EventGridEvent,
Expand All @@ -42,6 +44,13 @@
List[Dict]
]

ListEventType = Union[
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]


class EventGridPublisherClient(object):
"""EventGrid Python Publisher Client.
Expand Down Expand Up @@ -79,7 +88,7 @@ def _policies(credential, **kwargs):
CustomHookPolicy(**kwargs),
NetworkTraceLoggingPolicy(**kwargs),
DistributedTracingPolicy(**kwargs),
CloudEventDistributedTracingPolicy(**kwargs),
CloudEventDistributedTracingPolicy(),
HttpLoggingPolicy(**kwargs)
]
return policies
Expand All @@ -98,20 +107,24 @@ def send(self, events, **kwargs):
:raises: :class:`ValueError`, when events do not follow specified SendType.
"""
if not isinstance(events, list):
events = [events]
events = cast(ListEventType, [events])

if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(e) for e in events):
if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(cast(Dict, e)) for e in events):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cast(Dict, e) not necessary once you change _is_cloud_event input to Any

try:
events = [e._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
events = [cast(CloudEvent, e)._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
except AttributeError:
pass # means it's a dictionary
kwargs.setdefault("content_type", "application/cloudevents-batch+json; charset=utf-8")
self._client.publish_cloud_event_events(self._topic_hostname, events, **kwargs)
self._client.publish_cloud_event_events(
self._topic_hostname,
cast(List[InternalCloudEvent], events),
**kwargs
)
elif all(isinstance(e, EventGridEvent) for e in events) or all(isinstance(e, dict) for e in events):
kwargs.setdefault("content_type", "application/json; charset=utf-8")
self._client.publish_events(self._topic_hostname, events, **kwargs)
self._client.publish_events(self._topic_hostname, cast(List[InternalEventGridEvent], events), **kwargs)
elif all(isinstance(e, CustomEvent) for e in events):
serialized_events = [dict(e) for e in events]
self._client.publish_custom_event_events(self._topic_hostname, serialized_events, **kwargs)
serialized_events = [dict(e) for e in events] # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? What's the error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argument 1 to "dict" has incompatible type "Union[CloudEvent, EventGridEvent, CustomEvent, Dict[Any, Any]]"; expected "Mapping[Any, Any]"
I can use a cast, but it won't entirely be true - ideally we validate that it's not a cloudevent, eventgrid event by the time we hit this line and they should not be included in the union - afaik, it's a problem with mypy.

self._client.publish_custom_event_events(self._topic_hostname, cast(List, serialized_events), **kwargs)
else:
raise ValueError("Event schema is not correct.")
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
# license information.
# -------------------------------------------------------------------------

from typing import Any, TYPE_CHECKING
import six

from azure.core.pipeline.policies import SansIOHTTPPolicy

if TYPE_CHECKING:
from ._shared_access_signature_credential import EventGridSharedAccessSignatureCredential


class EventGridSharedAccessSignatureCredentialPolicy(SansIOHTTPPolicy):
"""Adds a token header for the provided credential.
:param credential: The credential used to authenticate requests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from typing import Any, Union, List, Dict
from typing import Any, Union, List, Dict, cast
from azure.core.credentials import AzureKeyCredential
from azure.core.tracing.decorator_async import distributed_trace_async
from azure.core.pipeline.policies import (
Expand All @@ -26,19 +26,27 @@
from .._models import CloudEvent, EventGridEvent, CustomEvent
from .._helpers import _get_topic_hostname_only_fqdn, _get_authentication_policy, _is_cloud_event
from .._generated.aio import EventGridPublisherClient as EventGridPublisherClientAsync
from .._generated.models import CloudEvent as InternalCloudEvent, EventGridEvent as InternalEventGridEvent
from .._shared_access_signature_credential import EventGridSharedAccessSignatureCredential
from .._version import VERSION

SendType = Union[
CloudEvent,
EventGridEvent,
CustomEvent,
Dict,
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]
CloudEvent,
EventGridEvent,
CustomEvent,
Dict,
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]

ListEventType = Union[
List[CloudEvent],
List[EventGridEvent],
List[CustomEvent],
List[Dict]
]

class EventGridPublisherClient():
"""Asynchronous EventGrid Python Publisher Client.
Expand Down Expand Up @@ -101,20 +109,34 @@ async def send(
:raises: :class:`ValueError`, when events do not follow specified SendType.
"""
if not isinstance(events, list):
events = [events]
events = cast(ListEventType, [events])

if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(e) for e in events):
if all(isinstance(e, CloudEvent) for e in events) or all(_is_cloud_event(cast(Dict, e)) for e in events):
try:
events = [e._to_generated(**kwargs) for e in events] # pylint: disable=protected-access
events = [
cast(CloudEvent, e)._to_generated(**kwargs) for e in events # pylint: disable=protected-access
]
except AttributeError:
pass # means it's a dictionary
kwargs.setdefault("content_type", "application/cloudevents-batch+json; charset=utf-8")
await self._client.publish_cloud_event_events(self._topic_hostname, events, **kwargs)
await self._client.publish_cloud_event_events(
self._topic_hostname,
cast(List[InternalCloudEvent], events),
**kwargs
)
elif all(isinstance(e, EventGridEvent) for e in events) or all(isinstance(e, dict) for e in events):
kwargs.setdefault("content_type", "application/json; charset=utf-8")
await self._client.publish_events(self._topic_hostname, events, **kwargs)
await self._client.publish_events(
self._topic_hostname,
cast(List[InternalEventGridEvent], events),
**kwargs
)
elif all(isinstance(e, CustomEvent) for e in events):
serialized_events = [dict(e) for e in events]
await self._client.publish_custom_event_events(self._topic_hostname, serialized_events, **kwargs)
serialized_events = [dict(e) for e in events] # type: ignore
await self._client.publish_custom_event_events(
self._topic_hostname,
cast(List, serialized_events),
**kwargs
)
else:
raise ValueError("Event schema is not correct.")
13 changes: 13 additions & 0 deletions sdk/eventgrid/azure-eventgrid/mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[mypy]
python_version = 3.7
warn_return_any = True
warn_unused_configs = True
ignore_missing_imports = True

# Per-module options:

[mypy-azure.eventgrid._generated.*]
ignore_errors = True

[mypy-azure.core.*]
ignore_errors = True