55# license information.
66# --------------------------------------------------------------------------
77
8- from typing import TYPE_CHECKING
8+ from typing import TYPE_CHECKING , cast , Dict , List , Any , Union
99
1010from azure .core .tracing .decorator import distributed_trace
1111from azure .core .pipeline .policies import (
2727from ._generated ._event_grid_publisher_client import EventGridPublisherClient as EventGridPublisherClientImpl
2828from ._policies import CloudEventDistributedTracingPolicy
2929from ._version import VERSION
30+ from ._generated .models import CloudEvent as InternalCloudEvent , EventGridEvent as InternalEventGridEvent
3031
3132if TYPE_CHECKING :
3233 # pylint: disable=unused-import,ungrouped-imports
33- from typing import Any , Union , Dict , List
34+ from azure .core .credentials import AzureKeyCredential
35+ from ._shared_access_signature_credential import EventGridSharedAccessSignatureCredential
3436 SendType = Union [
3537 CloudEvent ,
3638 EventGridEvent ,
4244 List [Dict ]
4345 ]
4446
47+ ListEventType = Union [
48+ List [CloudEvent ],
49+ List [EventGridEvent ],
50+ List [CustomEvent ],
51+ List [Dict ]
52+ ]
53+
4554
4655class EventGridPublisherClient (object ):
4756 """EventGrid Python Publisher Client.
@@ -79,7 +88,7 @@ def _policies(credential, **kwargs):
7988 CustomHookPolicy (** kwargs ),
8089 NetworkTraceLoggingPolicy (** kwargs ),
8190 DistributedTracingPolicy (** kwargs ),
82- CloudEventDistributedTracingPolicy (** kwargs ),
91+ CloudEventDistributedTracingPolicy (),
8392 HttpLoggingPolicy (** kwargs )
8493 ]
8594 return policies
@@ -98,20 +107,24 @@ def send(self, events, **kwargs):
98107 :raises: :class:`ValueError`, when events do not follow specified SendType.
99108 """
100109 if not isinstance (events , list ):
101- events = [events ]
110+ events = cast ( ListEventType , [events ])
102111
103112 if all (isinstance (e , CloudEvent ) for e in events ) or all (_is_cloud_event (e ) for e in events ):
104113 try :
105- events = [e ._to_generated (** kwargs ) for e in events ] # pylint: disable=protected-access
114+ events = [cast ( CloudEvent , e ) ._to_generated (** kwargs ) for e in events ] # pylint: disable=protected-access
106115 except AttributeError :
107116 pass # means it's a dictionary
108117 kwargs .setdefault ("content_type" , "application/cloudevents-batch+json; charset=utf-8" )
109- self ._client .publish_cloud_event_events (self ._topic_hostname , events , ** kwargs )
118+ self ._client .publish_cloud_event_events (
119+ self ._topic_hostname ,
120+ cast (List [InternalCloudEvent ], events ),
121+ ** kwargs
122+ )
110123 elif all (isinstance (e , EventGridEvent ) for e in events ) or all (isinstance (e , dict ) for e in events ):
111124 kwargs .setdefault ("content_type" , "application/json; charset=utf-8" )
112- self ._client .publish_events (self ._topic_hostname , events , ** kwargs )
125+ self ._client .publish_events (self ._topic_hostname , cast ( List [ InternalEventGridEvent ], events ) , ** kwargs )
113126 elif all (isinstance (e , CustomEvent ) for e in events ):
114- serialized_events = [dict (e ) for e in events ]
115- self ._client .publish_custom_event_events (self ._topic_hostname , serialized_events , ** kwargs )
127+ serialized_events = [dict (e ) for e in events ] # type: ignore
128+ self ._client .publish_custom_event_events (self ._topic_hostname , cast ( List , serialized_events ) , ** kwargs )
116129 else :
117130 raise ValueError ("Event schema is not correct." )
0 commit comments