diff --git a/sdk/eventhub/azure-eventhubs/HISTORY.md b/sdk/eventhub/azure-eventhubs/HISTORY.md index 4c56f0eac0fb..20dcea9c7301 100644 --- a/sdk/eventhub/azure-eventhubs/HISTORY.md +++ b/sdk/eventhub/azure-eventhubs/HISTORY.md @@ -1,5 +1,9 @@ # Release History +## 2019-11-04 5.0.0b6 + +**Breaking changes** + ## 2019-11-04 5.0.0b5 **Breaking changes** diff --git a/sdk/eventhub/azure-eventhubs/README.md b/sdk/eventhub/azure-eventhubs/README.md index 84cd7ef10e3c..68164d85a8d5 100644 --- a/sdk/eventhub/azure-eventhubs/README.md +++ b/sdk/eventhub/azure-eventhubs/README.md @@ -189,13 +189,13 @@ client = EventHubConsumerClient.from_connection_string(connection_str, event_hub logger = logging.getLogger("azure.eventhub") -def on_events(partition_context, events): - logger.info("Received {} events from partition {}".format(len(events), partition_context.partition_id)) +def on_event(partition_context, event): + logger.info("Received event from partition {}".format(partition_context.partition_id)) with client: - client.receive(on_events=on_events, consumer_group="$Default") + client.receive(on_event=on_event, consumer_group="$Default") # receive events from specified partition: - # client.receive(on_events=on_events, consumer_group="$Default", partition_id='0') + # client.receive(on_event=on_event, consumer_group="$Default", partition_id='0') ``` ### Async publish events to an Event Hub @@ -273,15 +273,15 @@ event_hub_path = '<< NAME OF THE EVENT HUB >>' logger = logging.getLogger("azure.eventhub") -async def on_events(partition_context, events): - logger.info("Received {} events from partition {}".format(len(events), partition_context.partition_id)) +async def on_event(partition_context, event): + logger.info("Received event from partition {}".format(partition_context.partition_id)) async def receive(): client = EventHubConsumerClient.from_connection_string(connection_str, event_hub_path=event_hub_path) async with client: - received = await client.receive(on_events=on_events, consumer_group='$Default') + received = await client.receive(on_event=on_event, consumer_group='$Default') # receive events from specified partition: - # received = await client.receive(on_events=on_events, consumer_group='$Default', partition_id='0') + # received = await client.receive(on_event=on_event, consumer_group='$Default', partition_id='0') if __name__ == '__main__': loop = asyncio.get_event_loop() @@ -330,13 +330,12 @@ async def do_operation(event): print(event) -async def process_events(partition_context, events): - await asyncio.gather(*[do_operation(event) for event in events]) - await partition_context.update_checkpoint(events[-1]) +async def process_event(partition_context, event): + await partition_context.update_checkpoint(event) async def receive(client): try: - await client.receive(on_events=process_events, consumer_group="$Default") + await client.receive(on_event=process_event, consumer_group="$Default") except KeyboardInterrupt: await client.close() diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/__init__.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/__init__.py index 20f63a354cd9..6ac63715f71d 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/__init__.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/__init__.py @@ -4,17 +4,23 @@ # -------------------------------------------------------------------------------------------- __path__ = __import__('pkgutil').extend_path(__path__, __name__) # type: ignore -__version__ = "5.0.0b5" +__version__ = "5.0.0b6" from uamqp import constants # 
type: ignore -from .common import EventData, EventDataBatch, EventPosition -from .error import EventHubError, EventDataError, ConnectError, \ - AuthenticationError, EventDataSendError, ConnectionLostError +from ._common import EventData, EventDataBatch, EventPosition from ._producer_client import EventHubProducerClient from ._consumer_client import EventHubConsumerClient -from .common import EventHubSharedKeyCredential, EventHubSASTokenCredential +from ._common import EventHubSharedKeyCredential, EventHubSASTokenCredential from ._eventprocessor.partition_manager import PartitionManager from ._eventprocessor.common import CloseReason, OwnershipLostError from ._eventprocessor.partition_context import PartitionContext +from .exceptions import ( + EventHubError, + EventDataError, + ConnectError, + AuthenticationError, + EventDataSendError, + ConnectionLostError +) TransportType = constants.TransportType diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_client_base.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_client_base.py new file mode 100644 index 000000000000..2642d1e001c3 --- /dev/null +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_client_base.py @@ -0,0 +1,378 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from __future__ import unicode_literals + +import logging +import uuid +import time +import functools +import collections +from base64 import b64encode, b64decode +from hashlib import sha256 +from hmac import HMAC +from typing import Any, TYPE_CHECKING +try: + from urlparse import urlparse # type: ignore + from urllib import urlencode, quote_plus # type: ignore +except ImportError: + from urllib.parse import urlparse, urlencode, quote_plus + +from uamqp import ( + AMQPClient, + Message, + authentication, + constants, + errors, + compat +) + +from .exceptions import _handle_exception, ClientClosedError +from ._configuration import Configuration +from ._utils import parse_sas_token, utc_from_timestamp +from ._common import EventHubSharedKeyCredential, EventHubSASTokenCredential +from ._connection_manager import get_connection_manager +from ._constants import ( + CONTAINER_PREFIX, + JWT_TOKEN_SCOPE, + MGMT_OPERATION, + MGMT_PARTITION_OPERATION +) + +if TYPE_CHECKING: + from azure.core.credentials import TokenCredential # type: ignore + +_LOGGER = logging.getLogger(__name__) + + +def _parse_conn_str(conn_str): + endpoint = None + shared_access_key_name = None + shared_access_key = None + entity_path = None + for element in conn_str.split(';'): + key, _, value = element.partition('=') + if key.lower() == 'endpoint': + endpoint = value.rstrip('/') + elif key.lower() == 'hostname': + endpoint = value.rstrip('/') + elif key.lower() == 'sharedaccesskeyname': + shared_access_key_name = value + elif key.lower() == 'sharedaccesskey': + shared_access_key = value + elif key.lower() == 'entitypath': + entity_path = value + if not all([endpoint, shared_access_key_name, shared_access_key]): + raise ValueError( + "Invalid connection string. 
Should be in the format: "
+            "Endpoint=sb://<FQDN>/;SharedAccessKeyName=<KeyName>;SharedAccessKey=<KeyValue>")
+    return endpoint, shared_access_key_name, shared_access_key, entity_path
+
+
+def _generate_sas_token(uri, policy, key, expiry=None):
+    """Create a shared access signature token as a string literal.
+    :returns: SAS token as string literal.
+    :rtype: str
+    """
+    if not expiry:
+        expiry = time.time() + 3600  # Default to 1 hour.
+    encoded_uri = quote_plus(uri)
+    ttl = int(expiry)
+    sign_key = '{}\n{}'.format(encoded_uri, ttl)
+    signature = b64encode(HMAC(b64decode(key), sign_key.encode('utf-8'), sha256).digest())
+    result = {
+        'sr': uri,
+        'sig': signature,
+        'se': str(ttl)}
+    if policy:
+        result['skn'] = policy
+    return 'SharedAccessSignature ' + urlencode(result)
+
+
+def _build_uri(address, entity):
+    parsed = urlparse(address)
+    if parsed.path:
+        return address
+    if not entity:
+        raise ValueError("No EventHub specified")
+    address += "/" + str(entity)
+    return address
+
+
+_Address = collections.namedtuple('Address', 'hostname path')
+
+
+class ClientBase(object):  # pylint:disable=too-many-instance-attributes
+    def __init__(self, host, event_hub_path, credential, **kwargs):
+        self.eh_name = event_hub_path
+        path = "/" + event_hub_path if event_hub_path else ""
+        self._address = _Address(hostname=host, path=path)
+        self._container_id = CONTAINER_PREFIX + str(uuid.uuid4())[:8]
+        self._credential = credential
+        self._keep_alive = kwargs.get("keep_alive", 30)
+        self._auto_reconnect = kwargs.get("auto_reconnect", True)
+        self._mgmt_target = "amqps://{}/{}".format(self._address.hostname, self.eh_name)
+        self._auth_uri = "sb://{}{}".format(self._address.hostname, self._address.path)
+        self._config = Configuration(**kwargs)
+        self._debug = self._config.network_tracing
+        self._conn_manager = get_connection_manager(**kwargs)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args):
+        self.close()
+
+    @classmethod
+    def from_connection_string(cls, conn_str, **kwargs):
+        event_hub_path = kwargs.pop("event_hub_path", None)
+        address, policy, key, entity = _parse_conn_str(conn_str)
+        entity = event_hub_path or entity
+        left_slash_pos = address.find("//")
+        if left_slash_pos != -1:
+            host = address[left_slash_pos + 2:]
+        else:
+            host = address
+        return cls(host, entity, EventHubSharedKeyCredential(policy, key), **kwargs)
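For context on `_generate_sas_token` above: a SAS token is an HMAC-SHA256 signature over the URL-encoded resource URI joined with an expiry timestamp, signed with the base64-decoded shared key. Below is a minimal standalone sketch of the same scheme; the namespace, policy name, and key are made-up placeholders, and unlike the helper above this version decodes the signature to `str` before URL-encoding it.

```python
import time
from base64 import b64encode, b64decode
from hashlib import sha256
from hmac import HMAC
from urllib.parse import urlencode, quote_plus


def generate_sas_token(uri, policy, key, ttl_seconds=3600):
    expiry = int(time.time()) + ttl_seconds
    # The signed payload is '<url-encoded-uri>\n<expiry>'.
    sign_key = '{}\n{}'.format(quote_plus(uri), expiry)
    signature = b64encode(
        HMAC(b64decode(key), sign_key.encode('utf-8'), sha256).digest()
    ).decode('utf-8')
    return 'SharedAccessSignature ' + urlencode(
        {'sr': uri, 'sig': signature, 'se': str(expiry), 'skn': policy})


# Placeholder values only -- not real credentials.
fake_key = b64encode(b'not-a-real-key').decode('utf-8')
print(generate_sas_token('sb://myns.servicebus.windows.net/myhub',
                         'RootManageSharedAccessKey', fake_key))
```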
+ """ + http_proxy = self._config.http_proxy + transport_type = self._config.transport_type + auth_timeout = self._config.auth_timeout + + # TODO: the following code can be refactored to create auth from classes directly instead of using if-else + if isinstance(self._credential, EventHubSharedKeyCredential): # pylint:disable=no-else-return + username = self._credential.policy + password = self._credential.key + if "@sas.root" in username: + return authentication.SASLPlain( + self._address.hostname, username, password, http_proxy=http_proxy, transport_type=transport_type) + return authentication.SASTokenAuth.from_shared_access_key( + self._auth_uri, username, password, timeout=auth_timeout, http_proxy=http_proxy, + transport_type=transport_type) + + elif isinstance(self._credential, EventHubSASTokenCredential): + token = self._credential.get_sas_token() + try: + expiry = int(parse_sas_token(token)['se']) + except (KeyError, TypeError, IndexError): + raise ValueError("Supplied SAS token has no valid expiry value.") + return authentication.SASTokenAuth( + self._auth_uri, self._auth_uri, token, + expires_at=expiry, + timeout=auth_timeout, + http_proxy=http_proxy, + transport_type=transport_type) + + else: # Azure credential + get_jwt_token = functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE) + return authentication.JWTTokenAuth(self._auth_uri, self._auth_uri, + get_jwt_token, http_proxy=http_proxy, + transport_type=transport_type) + + def _close_connection(self): + self._conn_manager.reset_connection_if_broken() + + def _backoff(self, retried_times, last_exception, timeout_time=None, entity_name=None): + entity_name = entity_name or self._container_id + backoff = self._config.backoff_factor * 2 ** retried_times + if backoff <= self._config.backoff_max and ( + timeout_time is None or time.time() + backoff <= timeout_time): # pylint:disable=no-else-return + time.sleep(backoff) + _LOGGER.info("%r has an exception (%r). Retrying...", format(entity_name), last_exception) + else: + _LOGGER.info("%r operation has timed out. Last exception before timeout is (%r)", + entity_name, last_exception) + raise last_exception + + def _management_request(self, mgmt_msg, op_type): + retried_times = 0 + last_exception = None + while retried_times <= self._config.max_retries: + mgmt_auth = self._create_auth() + mgmt_client = AMQPClient(self._mgmt_target) + try: + conn = self._conn_manager.get_connection(self._address.hostname, mgmt_auth) #pylint:disable=assignment-from-none + mgmt_client.open(connection=conn) + response = mgmt_client.mgmt_request( + mgmt_msg, + constants.READ_OPERATION, + op_type=op_type, + status_code_field=b'status-code', + description_fields=b'status-description') + return response + except Exception as exception: # pylint: disable=broad-except + last_exception = _handle_exception(exception, self) + self._backoff(retried_times=retried_times, last_exception=last_exception) + retried_times += 1 + if retried_times > self._config.max_retries: + _LOGGER.info("%r returns an exception %r", self._container_id, last_exception) + raise last_exception + finally: + mgmt_client.close() + + def _add_span_request_attributes(self, span): + span.add_attribute("component", "eventhubs") + span.add_attribute("message_bus.destination", self._address.path) + span.add_attribute("peer.address", self._address.hostname) + + def get_properties(self): + # type:() -> Dict[str, Any] + """Get properties of the EventHub. 
+    def get_properties(self):
+        # type:() -> Dict[str, Any]
+        """Get properties of the EventHub.
+
+        Keys in the returned dictionary include:
+
+            - path
+            - created_at
+            - partition_ids
+
+        :rtype: dict
+        :raises: :class:`EventHubError`
+        """
+        mgmt_msg = Message(application_properties={'name': self.eh_name})
+        response = self._management_request(mgmt_msg, op_type=MGMT_OPERATION)
+        output = {}
+        eh_info = response.get_data()
+        if eh_info:
+            output['path'] = eh_info[b'name'].decode('utf-8')
+            output['created_at'] = utc_from_timestamp(float(eh_info[b'created_at']) / 1000)
+            output['partition_ids'] = [p.decode('utf-8') for p in eh_info[b'partition_ids']]
+        return output
+
+    def get_partition_ids(self):
+        # type:() -> List[str]
+        """
+        Get partition ids of the specified EventHub.
+
+        :rtype: list[str]
+        :raises: :class:`EventHubError`
+        """
+        return self.get_properties()['partition_ids']
+
+    def get_partition_properties(self, partition):
+        # type:(str) -> Dict[str, Any]
+        """Get properties of the specified partition.
+
+        Keys in the details dictionary include:
+
+            - event_hub_path
+            - id
+            - beginning_sequence_number
+            - last_enqueued_sequence_number
+            - last_enqueued_offset
+            - last_enqueued_time_utc
+            - is_empty
+
+        :param partition: The target partition id.
+        :type partition: str
+        :rtype: dict
+        :raises: :class:`EventHubError`
+        """
+        mgmt_msg = Message(application_properties={'name': self.eh_name,
+                                                   'partition': partition})
+        response = self._management_request(mgmt_msg, op_type=MGMT_PARTITION_OPERATION)
+        partition_info = response.get_data()
+        output = {}
+        if partition_info:
+            output['event_hub_path'] = partition_info[b'name'].decode('utf-8')
+            output['id'] = partition_info[b'partition'].decode('utf-8')
+            output['beginning_sequence_number'] = partition_info[b'begin_sequence_number']
+            output['last_enqueued_sequence_number'] = partition_info[b'last_enqueued_sequence_number']
+            output['last_enqueued_offset'] = partition_info[b'last_enqueued_offset'].decode('utf-8')
+            output['is_empty'] = partition_info[b'is_partition_empty']
+            output['last_enqueued_time_utc'] = utc_from_timestamp(
+                float(partition_info[b'last_enqueued_time_utc']) / 1000
+            )
+        return output
+
+    def close(self):
+        # type:() -> None
+        self._conn_manager.close_connection()
+
+
+class ConsumerProducerMixin(object):
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    def _check_closed(self):
+        if self.closed:
+            raise ClientClosedError(
+                "{} has been closed. Please create a new one to handle event data.".format(self._name)
+            )
+
+    def _open(self):
+        """Open the EventHubConsumer/EventHubProducer using the supplied connection.
+ + """ + # pylint: disable=protected-access + if not self.running: + if self._handler: + self._handler.close() + self._create_handler() + self._handler.open(connection=self._client._conn_manager.get_connection( # pylint: disable=protected-access + self._client._address.hostname, + self._client._create_auth() + )) + while not self._handler.client_ready(): + time.sleep(0.05) + self._max_message_size_on_link = self._handler.message_handler._link.peer_max_message_size \ + or constants.MAX_MESSAGE_LENGTH_BYTES # pylint: disable=protected-access + self.running = True + + def _close_handler(self): + if self._handler: + self._handler.close() # close the link (sharing connection) or connection (not sharing) + self.running = False + + def _close_connection(self): + self._close_handler() + self._client._conn_manager.reset_connection_if_broken() # pylint: disable=protected-access + + def _handle_exception(self, exception): + if not self.running and isinstance(exception, compat.TimeoutException): + exception = errors.AuthenticationException("Authorization timeout.") + return _handle_exception(exception, self) + + def _do_retryable_operation(self, operation, timeout=None, **kwargs): + # pylint:disable=protected-access + timeout_time = (time.time() + timeout) if timeout else None + retried_times = 0 + last_exception = kwargs.pop('last_exception', None) + operation_need_param = kwargs.pop('operation_need_param', True) + max_retries = self._client._config.max_retries # pylint:disable=protected-access + + while retried_times <= max_retries: + try: + if operation_need_param: + return operation(timeout_time=timeout_time, last_exception=last_exception, **kwargs) + return operation() + except Exception as exception: # pylint:disable=broad-except + last_exception = self._handle_exception(exception) + self._client._backoff( + retried_times=retried_times, + last_exception=last_exception, + timeout_time=timeout_time, + entity_name=self._name + ) + retried_times += 1 + if retried_times > max_retries: + _LOGGER.info("%r operation has exhausted retry. Last exception: %r.", self._name, last_exception) + raise last_exception + + def close(self): + # type:() -> None + """ + Close down the handler. If the handler has already closed, + this will be a no op. 
+ """ + self._close_handler() + self.closed = True diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/common.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_common.py similarity index 64% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/common.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/_common.py index d48f0cd27a41..e39dd0846960 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/common.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_common.py @@ -10,35 +10,27 @@ import logging import six -from uamqp import BatchMessage, Message, types, constants # type: ignore -from uamqp.message import MessageHeader # type: ignore - -from azure.core.settings import settings # type: ignore - -from .error import EventDataError - -log = logging.getLogger(__name__) +from uamqp import BatchMessage, Message, constants # type: ignore + +from ._utils import set_message_partition_key, trace_message, utc_from_timestamp +from ._constants import ( + PROP_SEQ_NUMBER, + PROP_OFFSET, + PROP_PARTITION_KEY, + PROP_PARTITION_KEY_AMQP_SYMBOL, + PROP_TIMESTAMP, + PROP_LAST_ENQUEUED_SEQUENCE_NUMBER, + PROP_LAST_ENQUEUED_OFFSET, + PROP_LAST_ENQUEUED_TIME_UTC, + PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC, +) + +_LOGGER = logging.getLogger(__name__) # event_data.encoded_size < 255, batch encode overhead is 5, >=256, overhead is 8 each _BATCH_MESSAGE_OVERHEAD_COST = [5, 8] -def parse_sas_token(sas_token): - """Parse a SAS token into its components. - - :param sas_token: The SAS token. - :type sas_token: str - :rtype: dict[str, str] - """ - sas_data = {} - token = sas_token.partition(' ')[2] - fields = token.split('&') - for field in fields: - key, value = field.split('=', 1) - sas_data[key.lower()] = value - return sas_data - - class EventData(object): """ The EventData class is a holder of event content. 
@@ -57,16 +49,6 @@ class EventData(object): """ - _PROP_SEQ_NUMBER = b"x-opt-sequence-number" - _PROP_OFFSET = b"x-opt-offset" - _PROP_PARTITION_KEY = b"x-opt-partition-key" - _PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(_PROP_PARTITION_KEY) - _PROP_TIMESTAMP = b"x-opt-enqueued-time" - _PROP_LAST_ENQUEUED_SEQUENCE_NUMBER = b"last_enqueued_sequence_number" - _PROP_LAST_ENQUEUED_OFFSET = b"last_enqueued_offset" - _PROP_LAST_ENQUEUED_TIME_UTC = b"last_enqueued_time_utc" - _PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC = b"runtime_info_retrieval_time_utc" - def __init__(self, body=None): self._last_enqueued_event_properties = {} if body and isinstance(body, list): @@ -81,97 +63,68 @@ def __init__(self, body=None): self.message.application_properties = {} def __str__(self): - dic = { - 'body': self.body_as_str(), + try: + body = self.body_as_str() + except: # pylint: disable=bare-except + body = "" + message_as_dict = { + 'body': body, 'application_properties': str(self.application_properties) } + try: + if self.sequence_number: + message_as_dict['sequence_number'] = str(self.sequence_number) + if self.offset: + message_as_dict['offset'] = str(self.offset) + if self.enqueued_time: + message_as_dict['enqueued_time'] = str(self.enqueued_time) + if self.partition_key: + message_as_dict['partition_key'] = str(self.partition_key) + except: # pylint: disable=bare-except + pass + return str(message_as_dict) - if self.sequence_number: - dic['sequence_number'] = str(self.sequence_number) - if self.offset: - dic['offset'] = str(self.offset) - if self.enqueued_time: - dic['enqueued_time'] = str(self.enqueued_time) - if self.partition_key: - dic['partition_key'] = str(self.partition_key) - return str(dic) - - def _set_partition_key(self, value): - """ - Set the partition key of the event data object. + @classmethod + def _from_message(cls, message): + """Internal use only. - :param value: The partition key to set. - :type value: str or bytes - """ - annotations = dict(self.message.annotations) - annotations[EventData._PROP_PARTITION_KEY_AMQP_SYMBOL] = value - header = MessageHeader() - header.durable = True - self.message.annotations = annotations - self.message.header = header - - def _trace_message(self, parent_span=None): - """Add tracing information to this message. - - Will open and close a "Azure.EventHubs.message" span, and - add the "DiagnosticId" as app properties of the message. - """ - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] - if span_impl_type is not None: - current_span = parent_span or span_impl_type(span_impl_type.get_current_span()) - message_span = current_span.span(name="Azure.EventHubs.message") - message_span.start() - app_prop = dict(self.application_properties) if self.application_properties else dict() - app_prop.setdefault(b"Diagnostic-Id", message_span.get_trace_parent().encode('ascii')) - self.application_properties = app_prop - message_span.finish() - - def _trace_link_message(self, parent_span=None): - """Link the current message to current span. - - Will extract DiagnosticId if available. + Creates an EventData object from a raw uamqp message. + + :param ~uamqp.Message message: A received uamqp message. 
+ :rtype: ~azure.eventhub.EventData """ - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] - if span_impl_type is not None: - current_span = parent_span or span_impl_type(span_impl_type.get_current_span()) - if current_span and self.application_properties: - traceparent = self.application_properties.get(b"Diagnostic-Id", "").decode('ascii') - if traceparent: - current_span.link(traceparent) + event_data = cls(body='') + event_data.message = message + return event_data def _get_last_enqueued_event_properties(self): + """Extracts the last enqueued event in from the received event delivery annotations. + + :rtype: Dict[str, Any] + """ if self._last_enqueued_event_properties: return self._last_enqueued_event_properties if self.message.delivery_annotations: - enqueued_time_stamp = \ - self.message.delivery_annotations.get(EventData._PROP_LAST_ENQUEUED_TIME_UTC, None) - retrieval_time_stamp = \ - self.message.delivery_annotations.get(EventData._PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC, None) - + sequence_number = self.message.delivery_annotations.get(PROP_LAST_ENQUEUED_SEQUENCE_NUMBER, None) + enqueued_time_stamp = self.message.delivery_annotations.get(PROP_LAST_ENQUEUED_TIME_UTC, None) + if enqueued_time_stamp: + enqueued_time_stamp = utc_from_timestamp(float(enqueued_time_stamp)/1000) + retrieval_time_stamp = self.message.delivery_annotations.get(PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC, None) + if retrieval_time_stamp: + retrieval_time_stamp = utc_from_timestamp(float(retrieval_time_stamp)/1000) + offset_bytes = self.message.delivery_annotations.get(PROP_LAST_ENQUEUED_OFFSET, None) + offset = offset_bytes.decode('UTF-8') if offset_bytes else None self._last_enqueued_event_properties = { - "sequence_number": - self.message.delivery_annotations.get(EventData._PROP_LAST_ENQUEUED_SEQUENCE_NUMBER, None), - "offset": - self.message.delivery_annotations.get(EventData._PROP_LAST_ENQUEUED_OFFSET, None), - "enqueued_time": - datetime.datetime.utcfromtimestamp( - float(enqueued_time_stamp)/1000) if enqueued_time_stamp else None, - "retrieval_time": - datetime.datetime.utcfromtimestamp( - float(retrieval_time_stamp)/1000) if retrieval_time_stamp else None + "sequence_number": sequence_number, + "offset": offset, + "enqueued_time": enqueued_time_stamp, + "retrieval_time": retrieval_time_stamp } return self._last_enqueued_event_properties return None - @classmethod - def _from_message(cls, message): - # pylint:disable=protected-access - event_data = cls(body='') - event_data.message = message - return event_data - @property def sequence_number(self): """ @@ -179,7 +132,7 @@ def sequence_number(self): :rtype: int or long """ - return self.message.annotations.get(EventData._PROP_SEQ_NUMBER, None) + return self.message.annotations.get(PROP_SEQ_NUMBER, None) @property def offset(self): @@ -189,7 +142,7 @@ def offset(self): :rtype: str """ try: - return self.message.annotations[EventData._PROP_OFFSET].decode('UTF-8') + return self.message.annotations[PROP_OFFSET].decode('UTF-8') except (KeyError, AttributeError): return None @@ -200,9 +153,9 @@ def enqueued_time(self): :rtype: datetime.datetime """ - timestamp = self.message.annotations.get(EventData._PROP_TIMESTAMP, None) + timestamp = self.message.annotations.get(PROP_TIMESTAMP, None) if timestamp: - return datetime.datetime.utcfromtimestamp(float(timestamp)/1000) + return utc_from_timestamp(float(timestamp)/1000) return None @property @@ -213,9 +166,9 @@ def partition_key(self): :rtype: bytes """ try: - return 
self.message.annotations[EventData._PROP_PARTITION_KEY_AMQP_SYMBOL] + return self.message.annotations[PROP_PARTITION_KEY_AMQP_SYMBOL] except KeyError: - return self.message.annotations.get(EventData._PROP_PARTITION_KEY, None) + return self.message.annotations.get(PROP_PARTITION_KEY, None) @property def application_properties(self): @@ -336,7 +289,7 @@ def __init__(self, max_size=None, partition_key=None): self._partition_key = partition_key self.message = BatchMessage(data=[], multi_messages=False, properties=None) - self._set_partition_key(partition_key) + set_message_partition_key(self.message, self._partition_key) self._size = self.message.gather()[0].get_message_encoded_size() self._count = 0 @@ -349,17 +302,6 @@ def _from_batch(batch_data, partition_key=None): batch_data_instance.message._body_gen = batch_data # pylint:disable=protected-access return batch_data_instance - def _set_partition_key(self, value): - if value: - annotations = self.message.annotations - if annotations is None: - annotations = dict() - annotations[types.AMQPSymbol(EventData._PROP_PARTITION_KEY)] = value # pylint:disable=protected-access - header = MessageHeader() - header.durable = True - self.message.annotations = annotations - self.message.header = header - @property def size(self): """The size of EventDataBatch object in bytes @@ -377,29 +319,22 @@ def try_add(self, event_data): :rtype: None :raise: :class:`ValueError`, when exceeding the size limit. """ - if event_data is None: - log.warning("event_data is None when calling EventDataBatch.try_add. Ignored") - return - if not isinstance(event_data, EventData): - raise TypeError('event_data should be type of EventData') - if self._partition_key: if event_data.partition_key and event_data.partition_key != self._partition_key: - raise EventDataError('The partition_key of event_data does not match the one of the EventDataBatch') + raise ValueError('The partition key of event_data does not match the partition key of this batch.') if not event_data.partition_key: - event_data._set_partition_key(self._partition_key) # pylint:disable=protected-access - - event_data._trace_message() # pylint:disable=protected-access + set_message_partition_key(event_data.message, self._partition_key) + trace_message(event_data) event_data_size = event_data.message.get_message_encoded_size() # For a BatchMessage, if the encoded_message_size of event_data is < 256, then the overhead cost to encode that - # message into the BatchMessage would be 5 bytes, if >= 256, it would be 8 bytes. - size_after_add = self._size + event_data_size\ + # message into the BatchMessage would be 5 bytes, if >= 256, it would be 8 bytes. 
+ size_after_add = self._size + event_data_size \ + _BATCH_MESSAGE_OVERHEAD_COST[0 if (event_data_size < 256) else 1] if size_after_add > self.max_size: - raise ValueError("EventDataBatch has reached its size limit {}".format(self.max_size)) + raise ValueError("EventDataBatch has reached its size limit: {}".format(self.max_size)) self.message._body_gen.append(event_data) # pylint: disable=protected-access self._size = size_after_add @@ -481,9 +416,3 @@ class EventHubSharedKeyCredential(object): def __init__(self, policy, key): self.policy = policy self.key = key - - -class _Address(object): - def __init__(self, hostname=None, path=None): - self.hostname = hostname - self.path = path diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/configuration.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_configuration.py similarity index 100% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/configuration.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/_configuration.py diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_constants.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_constants.py new file mode 100644 index 000000000000..ab80c5e91f2a --- /dev/null +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_constants.py @@ -0,0 +1,38 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +from __future__ import unicode_literals + +from uamqp import types # type: ignore + + +PROP_SEQ_NUMBER = b"x-opt-sequence-number" +PROP_OFFSET = b"x-opt-offset" +PROP_PARTITION_KEY = b"x-opt-partition-key" +PROP_PARTITION_KEY_AMQP_SYMBOL = types.AMQPSymbol(PROP_PARTITION_KEY) +PROP_TIMESTAMP = b"x-opt-enqueued-time" +PROP_LAST_ENQUEUED_SEQUENCE_NUMBER = b"last_enqueued_sequence_number" +PROP_LAST_ENQUEUED_OFFSET = b"last_enqueued_offset" +PROP_LAST_ENQUEUED_TIME_UTC = b"last_enqueued_time_utc" +PROP_RUNTIME_INFO_RETRIEVAL_TIME_UTC = b"runtime_info_retrieval_time_utc" + +EPOCH_SYMBOL = b'com.microsoft:epoch' +TIMEOUT_SYMBOL = b'com.microsoft:timeout' +RECEIVER_RUNTIME_METRIC_SYMBOL = b'com.microsoft:enable-receiver-runtime-metric' + +MAX_USER_AGENT_LENGTH = 512 +ALL_PARTITIONS = "all-partitions" +CONTAINER_PREFIX = "eventhub.pysdk-" +JWT_TOKEN_SCOPE = "https://eventhubs.azure.net//.default" +MGMT_OPERATION = b'com.microsoft:eventhub' +MGMT_PARTITION_OPERATION = b'com.microsoft:partition' +USER_AGENT_PREFIX = "azsdk-python-eventhubs" + +NO_RETRY_ERRORS = ( + b"com.microsoft:argument-out-of-range", + b"com.microsoft:entity-disabled", + b"com.microsoft:auth-failed", + b"com.microsoft:precondition-failed", + b"com.microsoft:argument-error" +) diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer.py new file mode 100644 index 000000000000..b45dafd78578 --- /dev/null +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer.py @@ -0,0 +1,164 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# --------------------------------------------------------------------------------------------
+from __future__ import unicode_literals
+
+import uuid
+import logging
+
+import uamqp  # type: ignore
+from uamqp import types, errors, utils  # type: ignore
+from uamqp import ReceiveClient, Source  # type: ignore
+
+from .exceptions import _error_handler
+from ._common import EventData, EventPosition
+from ._client_base import ConsumerProducerMixin
+from ._utils import create_properties, trace_link_message
+from ._constants import (
+    EPOCH_SYMBOL,
+    TIMEOUT_SYMBOL,
+    RECEIVER_RUNTIME_METRIC_SYMBOL,
+)
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class EventHubConsumer(ConsumerProducerMixin):  # pylint:disable=too-many-instance-attributes
+    """
+    A consumer responsible for reading EventData from a specific Event Hub
+    partition, as a member of a specific consumer group.
+
+    A consumer may be exclusive, which asserts ownership over the partition for the consumer
+    group to ensure that only one consumer from that group is reading from the partition.
+    These exclusive consumers are sometimes referred to as "Epoch Consumers."
+
+    A consumer may also be non-exclusive, allowing multiple consumers from the same consumer
+    group to be actively reading events from the partition. These non-exclusive consumers are
+    sometimes referred to as "Non-Epoch Consumers."
+
+    Please use the method `create_consumer` on `EventHubClient` for creating `EventHubConsumer`.
+    """
+
+    def __init__(self, client, source, **kwargs):
+        """
+        Instantiate a consumer. EventHubConsumer should be instantiated by calling the `create_consumer` method
+        in EventHubClient.
+
+        :param client: The parent EventHubClient.
+        :type client: ~azure.eventhub.client.EventHubClient
+        :param source: The source EventHub from which to receive events.
+        :type source: str
+        :param prefetch: The number of events to prefetch from the service
+         for processing. Default is 300.
+        :type prefetch: int
+        :param owner_level: The priority of the exclusive consumer. An exclusive
+         consumer will be created if owner_level is set.
+        :type owner_level: int
+        :param track_last_enqueued_event_properties: Indicates whether or not the consumer should request information
+         on the last enqueued event on its associated partition, and track that information as events are received.
+         When information about the partition's last enqueued event is being tracked, each event received from the
+         Event Hubs service will carry metadata about the partition. This results in a small amount of additional
+         network bandwidth consumption that is generally a favorable trade-off when considered against periodically
+         making requests for partition properties using the Event Hub client.
+         It is set to `False` by default.
+        :type track_last_enqueued_event_properties: bool
+        """
+        event_position = kwargs.get("event_position", None)
+        prefetch = kwargs.get("prefetch", 300)
+        owner_level = kwargs.get("owner_level", None)
+        keep_alive = kwargs.get("keep_alive", None)
+        auto_reconnect = kwargs.get("auto_reconnect", True)
+        track_last_enqueued_event_properties = kwargs.get("track_last_enqueued_event_properties", False)
+
+        self.running = False
+        self.closed = False
+        self.stop = False  # used by event processor
+
+        self._on_event_received = kwargs.get("on_event_received")
+        self._client = client
+        self._source = source
+        self._offset = event_position
+        self._prefetch = prefetch
+        self._owner_level = owner_level
+        self._keep_alive = keep_alive
+        self._auto_reconnect = auto_reconnect
+        self._retry_policy = errors.ErrorPolicy(max_retries=self._client._config.max_retries, on_error=_error_handler)  # pylint:disable=protected-access
+        self._reconnect_backoff = 1
+        self._link_properties = {}
+        self._error = None
+        self._timeout = 0
+        partition = self._source.split('/')[-1]
+        self._partition = partition
+        self._name = "EHConsumer-{}-partition{}".format(uuid.uuid4(), partition)
+        if owner_level:
+            self._link_properties[types.AMQPSymbol(EPOCH_SYMBOL)] = types.AMQPLong(int(owner_level))
+        link_property_timeout_ms = (self._client._config.receive_timeout or self._timeout) * 1000  # pylint:disable=protected-access
+        self._link_properties[types.AMQPSymbol(TIMEOUT_SYMBOL)] = types.AMQPLong(int(link_property_timeout_ms))
+        self._handler = None
+        self._track_last_enqueued_event_properties = track_last_enqueued_event_properties
+        self._last_enqueued_event_properties = {}
+        self._last_received_event = None
+
+    def _create_handler(self):
+        source = Source(self._source)
+        if self._offset is not None:
+            source.set_filter(self._offset._selector())  # pylint:disable=protected-access
+        desired_capabilities = None
+        if self._track_last_enqueued_event_properties:
+            symbol_array = [types.AMQPSymbol(RECEIVER_RUNTIME_METRIC_SYMBOL)]
+            desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array))
+
+        properties = create_properties(self._client._config.user_agent)  # pylint:disable=protected-access
+        self._handler = ReceiveClient(
+            source,
+            auth=self._client._create_auth(),  # pylint:disable=protected-access
+            debug=self._client._config.network_tracing,  # pylint:disable=protected-access
+            prefetch=self._prefetch,
+            link_properties=self._link_properties,
+            timeout=self._timeout,
+            error_policy=self._retry_policy,
+            keep_alive_interval=self._keep_alive,
+            client_name=self._name,
+            receive_settle_mode=uamqp.constants.ReceiverSettleMode.ReceiveAndDelete,
+            auto_complete=False,
+            properties=properties,
+            desired_capabilities=desired_capabilities)
+
+        self._handler._streaming_receive = True  # pylint:disable=protected-access
+        self._handler._message_received_callback = self._message_received  # pylint:disable=protected-access
+
+    def _open_with_retry(self):
+        return self._do_retryable_operation(self._open, operation_need_param=False)
+
+    def _message_received(self, message):
+        # pylint:disable=protected-access
+        event_data = EventData._from_message(message)
+        trace_link_message(event_data)
+        self._last_received_event = event_data
+        self._on_event_received(event_data)
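The `receive` loop that follows reopens the AMQP link after recoverable failures and, when events have already been delivered, advances its starting `EventPosition` to the offset of the last received event so that nothing is redelivered. Here is the resume-from-offset pattern in isolation; `fake_fetch` and `TransientError` are invented stand-ins for the AMQP link, not part of the SDK:

```python
class TransientError(Exception):
    """Stand-in for a recoverable link error."""


def consume_with_resume(fetch_from, start_offset, max_retries=3):
    offset, retried = start_offset, 0
    while retried <= max_retries:
        try:
            for new_offset in fetch_from(offset):
                offset = new_offset  # track progress, like _last_received_event
            return
        except TransientError:
            retried += 1  # reconnect and resume from `offset`, not from the start
    raise RuntimeError("retries exhausted")


state = {"failed": False}

def fake_fetch(offset):
    for o in range(offset + 1, 6):
        if o == 3 and not state["failed"]:
            state["failed"] = True
            raise TransientError()
        yield o


consume_with_resume(fake_fetch, 0)  # delivers 1, 2, fails once, resumes at 3, 4, 5
```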
+    def receive(self):
+        retried_times = 0
+        last_exception = None
+        max_retries = self._client._config.max_retries  # pylint:disable=protected-access
+
+        while retried_times <= max_retries:
+            try:
+                self._open()
+                self._handler.do_work()
+                return
+            except uamqp.errors.LinkDetach as ld_error:
+                if ld_error.condition == uamqp.constants.ErrorCodes.LinkStolen:
+                    raise self._handle_exception(ld_error)
+            except Exception as exception:  # pylint: disable=broad-except
+                if not self.running:  # exit by close
+                    return
+                if self._last_received_event:
+                    self._offset = EventPosition(self._last_received_event.offset)
+                last_exception = self._handle_exception(exception)
+                retried_times += 1
+                if retried_times > max_retries:
+                    _LOGGER.info("%r operation has exhausted retry. Last exception: %r.", self._name, last_exception)
+                    raise last_exception
diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_client.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_client.py
index cbf7f299ac51..366181110e91 100644
--- a/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_client.py
+++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_client.py
@@ -3,20 +3,23 @@
 # Licensed under the MIT License. See License.txt in the project root for license information.
 # --------------------------------------------------------------------------------------------
 import logging
+import threading
 from typing import Any, Union, Dict, Tuple, TYPE_CHECKING, Callable, List
 
-from .common import EventHubSharedKeyCredential, EventHubSASTokenCredential, EventData
-from .client import EventHubClient
+from ._common import EventHubSharedKeyCredential, EventHubSASTokenCredential, EventData
+from ._client_base import ClientBase
+from ._consumer import EventHubConsumer
+from ._constants import ALL_PARTITIONS
 from ._eventprocessor.event_processor import EventProcessor
 from ._eventprocessor.partition_context import PartitionContext
 
 if TYPE_CHECKING:
     from azure.core.credentials import TokenCredential  # type: ignore
 
-log = logging.getLogger(__name__)
+_LOGGER = logging.getLogger(__name__)
 
 
-class EventHubConsumerClient(EventHubClient):
+class EventHubConsumerClient(ClientBase):
     """
     The EventHubConsumerClient class defines a high level interface for
     receiving events from the Azure Event Hubs service.
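The hunks below carry the central breaking change of this release: `receive` now takes a per-event `on_event` callback instead of the batched `on_events`. In application code the new shape looks like this sketch (the connection string and hub name are placeholders):

```python
from azure.eventhub import EventHubConsumerClient

client = EventHubConsumerClient.from_connection_string(
    "<< CONNECTION STRING FOR THE EVENT HUBS NAMESPACE >>",
    event_hub_path="<< NAME OF THE EVENT HUB >>")


def on_event(partition_context, event):
    # Called once per event now, rather than with a list of events.
    print("partition {}: {}".format(partition_context.partition_id, event))


with client:
    client.receive(on_event=on_event, consumer_group="$Default")
```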
@@ -70,38 +73,37 @@ class EventHubConsumerClient(EventHubClient): def __init__(self, host, event_hub_path, credential, **kwargs): # type:(str, str, Union[EventHubSharedKeyCredential, EventHubSASTokenCredential, TokenCredential], Any) -> None - """""" - receive_timeout = kwargs.get("receive_timeout", 3) - if receive_timeout <= 0: - raise ValueError("receive_timeout must be greater than 0.") - - kwargs['receive_timeout'] = receive_timeout self._partition_manager = kwargs.pop("partition_manager", None) self._load_balancing_interval = kwargs.pop("load_balancing_interval", 10) + network_tracing = kwargs.pop("logging_enable", False) super(EventHubConsumerClient, self).__init__( host=host, event_hub_path=event_hub_path, credential=credential, - network_tracing=kwargs.get("logging_enable"), **kwargs) - self._event_processors = dict() # type: Dict[Tuple[str, str], EventProcessor] - self._closed = False - - @classmethod - def _stop_eventprocessor(cls, event_processor): - # pylint: disable=protected-access - eventhub_client = event_processor._eventhub_client - consumer_group = event_processor._consumer_group_name - partition_id = event_processor._partition_id - with eventhub_client._lock: - event_processor.stop() - if partition_id and (consumer_group, partition_id) in eventhub_client._event_processors: - del eventhub_client._event_processors[(consumer_group, partition_id)] - elif (consumer_group, '-1') in eventhub_client._event_processors: - del eventhub_client._event_processors[(consumer_group, "-1")] + network_tracing=network_tracing, **kwargs) + self._lock = threading.Lock() + self._event_processors = {} # type: Dict[Tuple[str, str], EventProcessor] + + def _create_consumer(self, consumer_group, partition_id, event_position, **kwargs): + owner_level = kwargs.get("owner_level") + prefetch = kwargs.get("prefetch") or self._config.prefetch + track_last_enqueued_event_properties = kwargs.get("track_last_enqueued_event_properties", False) + on_event_received = kwargs.get("on_event_received") + + source_url = "amqps://{}{}/ConsumerGroups/{}/Partitions/{}".format( + self._address.hostname, self._address.path, consumer_group, partition_id) + handler = EventHubConsumer( + self, + source_url, + event_position=event_position, + owner_level=owner_level, + on_event_received=on_event_received, + prefetch=prefetch, + track_last_enqueued_event_properties=track_last_enqueued_event_properties) + return handler @classmethod def from_connection_string(cls, conn_str, **kwargs): # type: (str, Any) -> EventHubConsumerClient - """ - Create an EventHubConsumerClient from a connection string. + """Create an EventHubConsumerClient from a connection string. :param str conn_str: The connection string of an eventhub. :keyword str event_hub_path: The path of the specific Event Hub to connect the client to. @@ -124,6 +126,7 @@ def from_connection_string(cls, conn_str, **kwargs): :paramtype partition_manager: ~azure.eventhub.PartitionManager :keyword float load_balancing_interval: When load balancing kicks in, this is the interval in seconds between two load balancing. Default is 10. + :rtype: ~azure.eventhub.EventHubConsumerClient .. 
admonition:: Example:
@@ -137,16 +140,16 @@ def from_connection_string(cls, conn_str, **kwargs):
 
         """
         return super(EventHubConsumerClient, cls).from_connection_string(conn_str, **kwargs)
 
-    def receive(self, on_events, consumer_group, **kwargs):
+    def receive(self, on_event, consumer_group, **kwargs):
         # type: (Callable[[PartitionContext, EventData], None], str, Any) -> None
         """Receive events from partition(s) optionally with load balancing and checkpointing.
 
-        :param on_events: The callback function for handling received events. The callback takes two
-         parameters: `partition_context` which contains partition context and `events` which are the received events.
-         Please define the callback like `on_event(partition_context, events)`.
+        :param on_event: The callback function for handling the received event. The callback takes two
+         parameters: `partition_context` which contains partition context and `event` which is the received event.
+         Please define the callback like `on_event(partition_context, event)`.
         For detailed partition context information, please refer to :class:`PartitionContext`.
-        :type on_events: Callable[~azure.eventhub.PartitionContext, List[EventData]]
+        :type on_event: Callable[~azure.eventhub.PartitionContext, EventData]
         :param consumer_group: Receive events from the event hub for this consumer group
         :type consumer_group: str
         :keyword str partition_id: Receive from this partition only if it's not None.
@@ -193,31 +196,37 @@ def receive(self, on_events, consumer_group, **kwargs):
             :caption: Receive events from the EventHub.
         """
         partition_id = kwargs.get("partition_id")
-
         with self._lock:
             error = None
-            if (consumer_group, '-1') in self._event_processors:
-                error = ValueError("This consumer client is already receiving events from all partitions for"
-                                   " consumer group {}. ".format(consumer_group))
+            if (consumer_group, ALL_PARTITIONS) in self._event_processors:
+                error = ("This consumer client is already receiving events "
+                         "from all partitions for consumer group {}.".format(consumer_group))
             elif partition_id is None and any(x[0] == consumer_group for x in self._event_processors):
-                error = ValueError("This consumer client is already receiving events for consumer group {}. "
-                                   .format(consumer_group))
+                error = ("This consumer client is already receiving events "
+                         "for consumer group {}.".format(consumer_group))
             elif (consumer_group, partition_id) in self._event_processors:
-                error = ValueError("This consumer is already receiving events from partition {} for consumer group {}. "
-                                   .format(partition_id, consumer_group))
+                error = ("This consumer client is already receiving events "
+                         "from partition {} for consumer group {}. 
".format(partition_id, consumer_group)) if error: - log.warning(error) - raise error + _LOGGER.warning(error) + raise ValueError(error) event_processor = EventProcessor( - self, consumer_group, on_events, + self, consumer_group, on_event, partition_manager=self._partition_manager, polling_interval=self._load_balancing_interval, **kwargs ) - self._event_processors[(consumer_group, partition_id or "-1")] = event_processor - - event_processor.start() + self._event_processors[(consumer_group, partition_id or ALL_PARTITIONS)] = event_processor + try: + event_processor.start() + finally: + event_processor.stop() + with self._lock: + try: + del self._event_processors[(consumer_group, partition_id or ALL_PARTITIONS)] + except KeyError: + pass def close(self): # type: () -> None @@ -236,7 +245,7 @@ def close(self): """ with self._lock: - for _ in range(len(self._event_processors)): - _, ep = self._event_processors.popitem() - ep.stop() - super(EventHubConsumerClient, self).close() + for processor in self._event_processors.values(): + processor.stop() + self._event_processors = {} + super(EventHubConsumerClient, self).close() diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_producer_mixin.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_producer_mixin.py deleted file mode 100644 index 5da486ea1b11..000000000000 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/_consumer_producer_mixin.py +++ /dev/null @@ -1,101 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals - -import logging -import time - -from uamqp import errors, constants, compat # type: ignore -from .error import EventHubError, _handle_exception - -log = logging.getLogger(__name__) - - -class ConsumerProducerMixin(object): - def __init__(self): - self._client = None - self._handler = None - self._name = None - self._running = False - self._closed = False - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def _check_closed(self): - if self._closed: - raise EventHubError("{} has been closed. Please create a new one to handle event data.".format(self._name)) - - def _create_handler(self): - pass - - def _open(self): - """Open the EventHubConsumer/EventHubProducer using the supplied connection. 
- - """ - # pylint: disable=protected-access - if not self._running: - if self._handler: - self._handler.close() - self._create_handler() - self._handler.open(connection=self._client._conn_manager.get_connection( # pylint: disable=protected-access - self._client._address.hostname, - self._client._create_auth() - )) - while not self._handler.client_ready(): - time.sleep(0.05) - self._max_message_size_on_link = self._handler.message_handler._link.peer_max_message_size \ - or constants.MAX_MESSAGE_LENGTH_BYTES # pylint: disable=protected-access - self._running = True - - def _close_handler(self): - if self._handler: - self._handler.close() # close the link (sharing connection) or connection (not sharing) - self._running = False - - def _close_connection(self): - self._close_handler() - self._client._conn_manager.reset_connection_if_broken() # pylint: disable=protected-access - - def _handle_exception(self, exception): - if not self._running and isinstance(exception, compat.TimeoutException): - exception = errors.AuthenticationException("Authorization timeout.") - return _handle_exception(exception, self) - - def _do_retryable_operation(self, operation, timeout=100000, **kwargs): - # pylint:disable=protected-access - timeout_time = time.time() + ( - timeout if timeout else 100000) # timeout equals to 0 means no timeout, set the value to be a large number. - retried_times = 0 - last_exception = kwargs.pop('last_exception', None) - operation_need_param = kwargs.pop('operation_need_param', True) - - while retried_times <= self._client._config.max_retries: # pylint: disable=protected-access - try: - if operation_need_param: - return operation(timeout_time=timeout_time, last_exception=last_exception, **kwargs) - return operation() - except Exception as exception: # pylint:disable=broad-except - last_exception = self._handle_exception(exception) - self._client._try_delay(retried_times=retried_times, last_exception=last_exception, - timeout_time=timeout_time, entity_name=self._name) - retried_times += 1 - - log.info("%r operation has exhausted retry. Last exception: %r.", self._name, last_exception) - raise last_exception - - def close(self): - # type:() -> None - """ - Close down the handler. If the handler has already closed, - this will be a no op. - """ - if self._handler: - self._handler.close() # this will close link if sharing connection. Otherwise close connection - self._running = False - self._closed = True diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/_eventprocessor_mixin.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/_eventprocessor_mixin.py new file mode 100644 index 000000000000..d18e6c11f8f5 --- /dev/null +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/_eventprocessor_mixin.py @@ -0,0 +1,53 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. 
+# -------------------------------------------------------------------------------------------- + +from contextlib import contextmanager + +from azure.core.tracing import SpanKind # type: ignore +from azure.core.settings import settings # type: ignore + +from azure.eventhub import EventPosition + + +class EventProcessorMixin(object): + + def get_init_event_position(self, partition_id, checkpoint): + checkpoint_offset = checkpoint.get("offset") if checkpoint else None + if checkpoint_offset: + initial_event_position = EventPosition(checkpoint_offset) + elif isinstance(self._initial_event_position, EventPosition): + initial_event_position = self._initial_event_position + elif isinstance(self._initial_event_position, dict): + initial_event_position = self._initial_event_position.get(partition_id, EventPosition("-1")) + else: + initial_event_position = EventPosition(self._initial_event_position) + return initial_event_position + + def create_consumer(self, partition_id, initial_event_position, on_event_received): + consumer = self._eventhub_client._create_consumer( # pylint: disable=protected-access + self._consumer_group_name, + partition_id, + initial_event_position, + on_event_received=on_event_received, + owner_level=self._owner_level, + track_last_enqueued_event_properties=self._track_last_enqueued_event_properties, + prefetch=self._prefetch, + ) + return consumer + + @contextmanager + def _context(self, event): + # Tracing + span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] + if span_impl_type is None: + yield + else: + child = span_impl_type(name="Azure.EventHubs.process") + self._eventhub_client._add_span_request_attributes(child) # pylint: disable=protected-access + child.kind = SpanKind.SERVER + + event._trace_link_message(child) # pylint: disable=protected-access + with child: + yield diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/common.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/common.py index 90a53a20cd53..6bf25ea1a96b 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/common.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/common.py @@ -1,7 +1,7 @@ # -------------------------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. -# ----------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------- from enum import Enum diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/event_processor.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/event_processor.py index 86dbf7efca78..47a6708c860f 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/event_processor.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/event_processor.py @@ -1,29 +1,23 @@ # -------------------------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. 
-# -----------------------------------------------------------------------------------
-
-from contextlib import contextmanager
-from typing import Dict, Type
+# --------------------------------------------------------------------------------------------
 
 import uuid
 import logging
 import time
 import threading
-
-from uamqp.compat import queue  # type: ignore
-
-from azure.core.tracing import SpanKind  # type: ignore
-from azure.core.settings import settings  # type: ignore
+from functools import partial
 
 from azure.eventhub import EventPosition
 from .partition_context import PartitionContext
 from .ownership_manager import OwnershipManager
 from .common import CloseReason
+from ._eventprocessor_mixin import EventProcessorMixin
 
-log = logging.getLogger(__name__)
+_LOGGER = logging.getLogger(__name__)
 
 
-class EventProcessor(object):  # pylint:disable=too-many-instance-attributes
+class EventProcessor(EventProcessorMixin):  # pylint:disable=too-many-instance-attributes
     """
     An EventProcessor constantly receives events from one or multiple partitions of the Event Hub
     in the context of a given consumer group.
 
@@ -56,11 +50,7 @@ def __init__(self, eventhub_client, consumer_group_name, on_event, **kwargs):
         self._running = False
         self._lock = threading.RLock()
 
-        # Each partition consumer is working in its own thread
-        self._working_threads = {}  # type: Dict[str, threading.Thread]
-        self._threads_stop_flags = {}  # type: Dict[str, bool]
-
-        self._callback_queue = queue.Queue(maxsize=100)  # Right now the limitation of receiving speed is ~10k
+        self._consumers = {}
 
     def __repr__(self):
         return 'EventProcessor: id {}'.format(self._id)
@@ -68,66 +58,41 @@ def __repr__(self):
     def _cancel_tasks_for_partitions(self, to_cancel_partitions):
         with self._lock:
             for partition_id in to_cancel_partitions:
-                if partition_id in self._working_threads:
-                    self._threads_stop_flags[partition_id] = True  # the cancellation token sent to thread to stop
+                if partition_id in self._consumers:
+                    self._consumers[partition_id].stop = True
 
         if to_cancel_partitions:
-            log.info("EventProcesor %r has cancelled partitions %r", self._id, to_cancel_partitions)
+            _LOGGER.info("EventProcessor %r has cancelled partitions %r", self._id, to_cancel_partitions)
 
     def _create_tasks_for_claimed_ownership(self, claimed_partitions, checkpoints=None):
         with self._lock:
             for partition_id in claimed_partitions:
-                if partition_id not in self._working_threads or not self._working_threads[partition_id].is_alive():
+                if partition_id not in self._consumers:
+
+                    if partition_id not in self._partition_contexts:
+                        partition_context = PartitionContext(
+                            self._namespace,
+                            self._eventhub_name,
+                            self._consumer_group_name,
+                            partition_id,
+                            self._id,
+                            self._partition_manager
+                        )
+                        self._partition_contexts[partition_id] = partition_context
+
                     checkpoint = checkpoints.get(partition_id) if checkpoints else None
-                    self._working_threads[partition_id] = threading.Thread(target=self._receive,
-                                                                           args=(partition_id, checkpoint))
-                    self._working_threads[partition_id].daemon = True
-                    self._threads_stop_flags[partition_id] = False
-                    self._working_threads[partition_id].start()
-                    log.info("Working thread started, ownership %r, checkpoint %r", partition_id, checkpoint)
-
-    @contextmanager
-    def _context(self, events):
-        # Tracing
-        span_impl_type = settings.tracing_implementation()  # type: Type[AbstractSpan]
-        if span_impl_type is None:
-            yield
-        else:
-            child = span_impl_type(name="Azure.EventHubs.process")
-            self._eventhub_client._add_span_request_attributes(child)  # pylint: 
disable=protected-access - child.kind = SpanKind.SERVER - - for event in events: - event._trace_link_message(child) # pylint: disable=protected-access - with child: - yield - - def _process_error(self, partition_context, err): - log.warning( - "PartitionProcessor of EventProcessor instance %r of eventhub %r partition %r consumer group %r" - " has met an error. The exception is %r.", - partition_context.owner_id, - partition_context.eventhub_name, - partition_context.partition_id, - partition_context.consumer_group_name, - err - ) - if self._error_handler: - self._callback_queue.put((self._error_handler, partition_context, err), block=True) + initial_event_position = self.get_init_event_position(partition_id, checkpoint) + event_received_callback = partial(self._on_event_received, partition_context) - def _process_close(self, partition_context, reason): - if self._partition_close_handler: - log.info( - "PartitionProcessor of EventProcessor instance %r of eventhub %r partition %r consumer group %r" - " is being closed. Reason is: %r", - partition_context.owner_id, - partition_context.eventhub_name, - partition_context.partition_id, - partition_context.consumer_group_name, - reason - ) - if self._partition_close_handler: - self._callback_queue.put((self._partition_close_handler, partition_context, reason), block=True) + self._consumers[partition_id] = self.create_consumer(partition_id, + initial_event_position, + event_received_callback) + + if self._partition_initialize_handler: + self._handle_callback( + [self._partition_initialize_handler, + self._partition_contexts[partition_id]] + ) def _handle_callback(self, callback_and_args): callback = callback_and_args[0] @@ -135,10 +100,10 @@ def _handle_callback(self, callback_and_args): callback(*callback_and_args[1:]) except Exception as exp: # pylint:disable=broad-except partition_context = callback_and_args[1] - if callback != self._error_handler: - self._process_error(partition_context, exp) + if self._error_handler and callback != self._error_handler: + self._handle_callback([self._error_handler, partition_context, exp]) else: - log.warning( + _LOGGER.warning( "EventProcessor instance %r of eventhub %r partition %r consumer group %r" " has another error during running process_error(). 
The exception is %r.", partition_context.owner_id, @@ -148,74 +113,11 @@ def _handle_callback(self, callback_and_args): exp ) - def _receive(self, partition_id, checkpoint=None): # pylint: disable=too-many-statements - try: # pylint:disable =too-many-nested-blocks - log.info("start ownership %r, checkpoint %r", partition_id, checkpoint) - namespace = self._namespace - eventhub_name = self._eventhub_name - consumer_group_name = self._consumer_group_name - owner_id = self._id - checkpoint_offset = checkpoint.get("offset") if checkpoint else None - if checkpoint_offset: - initial_event_position = EventPosition(checkpoint_offset) - elif isinstance(self._initial_event_position, EventPosition): - initial_event_position = self._initial_event_position - elif isinstance(self._initial_event_position, dict): - initial_event_position = self._initial_event_position.get(partition_id, EventPosition("-1")) - else: - initial_event_position = EventPosition(self._initial_event_position) - if partition_id in self._partition_contexts: - partition_context = self._partition_contexts[partition_id] - else: - partition_context = PartitionContext( - namespace, - eventhub_name, - consumer_group_name, - partition_id, - owner_id, - self._partition_manager - ) - self._partition_contexts[partition_id] = partition_context - - partition_consumer = self._eventhub_client._create_consumer( # pylint: disable=protected-access - consumer_group_name, - partition_id, - initial_event_position, - owner_level=self._owner_level, - track_last_enqueued_event_properties=self._track_last_enqueued_event_properties, - prefetch=self._prefetch, - ) + def _on_event_received(self, partition_context, event): + with self._context(event): + self._handle_callback([self._event_handler, partition_context, event]) - try: - if self._partition_initialize_handler: - self._callback_queue.put((self._partition_initialize_handler, partition_context), block=True) - while self._threads_stop_flags[partition_id] is False: - try: - events = partition_consumer.receive() - if events: - if self._track_last_enqueued_event_properties: - self._last_enqueued_event_properties[partition_id] = \ - partition_consumer.last_enqueued_event_properties - with self._context(events): - self._callback_queue.put((self._event_handler, partition_context, events), block=True) - except Exception as error: # pylint:disable=broad-except - self._process_error(partition_context, error) - break - # Go to finally to stop this partition processor. - # Later an EventProcessor(this one or another one) will pick up this partition again. - finally: - partition_consumer.close() - if self._running: - # Event processor is running but the partition consumer has been stopped. - self._process_close(partition_context, CloseReason.OWNERSHIP_LOST) - else: - self._process_close(partition_context, CloseReason.SHUTDOWN) - finally: - with self._lock: - del self._working_threads[partition_id] - self._threads_stop_flags[partition_id] = True - - def _start(self): + def _load_balancing(self): """Start the EventProcessor. 
The EventProcessor will try to claim and balance partition ownership with other `EventProcessor` @@ -231,17 +133,17 @@ def _start(self): checkpoints = ownership_manager.get_checkpoints() if self._partition_manager else None claimed_partition_ids = ownership_manager.claim_ownership() if claimed_partition_ids: - to_cancel_list = set(self._working_threads.keys()) - set(claimed_partition_ids) + to_cancel_list = set(self._consumers.keys()) - set(claimed_partition_ids) self._create_tasks_for_claimed_ownership(claimed_partition_ids, checkpoints) else: - log.info("EventProcessor %r hasn't claimed an ownership. It keeps claiming.", self._id) - to_cancel_list = set(self._working_threads.keys()) + _LOGGER.info("EventProcessor %r hasn't claimed any partition ownership. It will keep trying to claim.", self._id) + to_cancel_list = set(self._consumers.keys()) if to_cancel_list: self._cancel_tasks_for_partitions(to_cancel_list) except Exception as err: # pylint:disable=broad-except - log.warning("An exception (%r) occurred during balancing and claiming ownership for " - "eventhub %r consumer group %r. Retrying after %r seconds", - err, self._eventhub_name, self._consumer_group_name, self._polling_interval) + _LOGGER.warning("An exception (%r) occurred during balancing and claiming ownership for " - "eventhub %r consumer group %r. Retrying after %r seconds", - err, self._eventhub_name, self._consumer_group_name, self._polling_interval) + "eventhub %r consumer group %r. Retrying after %r seconds", + err, self._eventhub_name, self._consumer_group_name, self._polling_interval) # ownership_manager.get_checkpoints() and ownership_manager.claim_ownership() may raise exceptions # when there are load balancing and/or checkpointing (partition_manager isn't None). # They're swallowed here to retry every self._polling_interval seconds. @@ -252,29 +154,64 @@ def _start(self): time.sleep(self._polling_interval) def _get_last_enqueued_event_properties(self, partition_id): - if partition_id in self._working_threads and partition_id in self._last_enqueued_event_properties: + if partition_id in self._consumers and partition_id in self._last_enqueued_event_properties: return self._last_enqueued_event_properties[partition_id] raise ValueError("You're not receiving events from partition {}".format(partition_id)) + def _close_consumer(self, partition_id, consumer, reason): + consumer.close() + with self._lock: + del self._consumers[partition_id] + + _LOGGER.info( + "PartitionProcessor of EventProcessor instance %r of eventhub %r partition %r consumer group %r" + " is being closed.
Reason is: %r", + self._partition_contexts[partition_id].owner_id, + self._partition_contexts[partition_id].eventhub_name, + self._partition_contexts[partition_id].partition_id, + self._partition_contexts[partition_id].consumer_group_name, + reason + ) + + if self._partition_close_handler: + self._handle_callback([self._partition_close_handler, self._partition_contexts[partition_id], reason]) + def start(self): - if not self._running: - log.info("EventProcessor %r is being started", self._id) - self._running = True - thread = threading.Thread(target=self._start) - thread.daemon = True - thread.start() + if self._running: + _LOGGER.info("EventProcessor %r has already started.", self._id) + return + + _LOGGER.info("EventProcessor %r is being started", self._id) + self._running = True + thread = threading.Thread(target=self._load_balancing) + thread.daemon = True + thread.start() + + while self._running: + for partition_id, consumer in list(self._consumers.items()): + if consumer.stop: + self._close_consumer(partition_id, consumer, CloseReason.OWNERSHIP_LOST) + continue - while self._running or self._callback_queue.qsize() or self._working_threads: try: - callback_and_args = self._callback_queue.get(block=False) - self._handle_callback(callback_and_args) - self._callback_queue.task_done() - except queue.Empty: - # ignore queue empty exception - time.sleep(0.01) # sleep a short while to avoid this thread dominating CPU. + consumer.receive() + except Exception as error: # pylint:disable=broad-except + _LOGGER.warning( + "PartitionProcessor of EventProcessor instance %r of eventhub %r partition %r consumer group %r" + " has met an error. The exception is %r.", + self._partition_contexts[partition_id].owner_id, + self._partition_contexts[partition_id].eventhub_name, + self._partition_contexts[partition_id].partition_id, + self._partition_contexts[partition_id].consumer_group_name, + error + ) + if self._error_handler: + self._handle_callback([self._error_handler, self._partition_contexts[partition_id], error]) + self._close_consumer(partition_id, consumer, CloseReason.OWNERSHIP_LOST) - else: - log.info("EventProcessor %r has already started.", self._id) + with self._lock: + for partition_id, consumer in list(self._consumers.items()): + self._close_consumer(partition_id, consumer, CloseReason.SHUTDOWN) def stop(self): """Stop the EventProcessor. @@ -289,18 +226,8 @@ def stop(self): """ if not self._running: - log.info("EventProcessor %r has already been stopped.", self._id) + _LOGGER.info("EventProcessor %r has already been stopped.", self._id) return self._running = False - - with self._lock: - to_join_threads = [x for x in self._working_threads.values()] - self._cancel_tasks_for_partitions(list(self._working_threads.keys())) - - for thread in to_join_threads: - thread.join() - - # self._threads_stop_flags.clear() - - log.info("EventProcessor %r has been stopped.", self._id) + _LOGGER.info("EventProcessor %r has been stopped.", self._id) diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/local_partition_manager.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/local_partition_manager.py index 1933d58e9f45..702e011ae7aa 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/local_partition_manager.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/local_partition_manager.py @@ -1,7 +1,7 @@ # -------------------------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. 
All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. -# ----------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------- from .sqlite3_partition_manager import Sqlite3PartitionManager diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/ownership_manager.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/ownership_manager.py index 7b1bc9c7a46e..487f0b6aae8f 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/ownership_manager.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/ownership_manager.py @@ -1,7 +1,7 @@ # -------------------------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. -# ----------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------- import time import random diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/partition_context.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/partition_context.py index 7dac5a2e713f..fc2c3cfa50c0 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/partition_context.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/partition_context.py @@ -1,7 +1,7 @@ # -------------------------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. -# ----------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------- import logging from .partition_manager import PartitionManager diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/partition_manager.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/partition_manager.py index db8545bd84da..5e8ee4425972 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/partition_manager.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/partition_manager.py @@ -1,7 +1,7 @@ # -------------------------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. 
-# ----------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------- from typing import Iterable, Dict, Any from abc import abstractmethod diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/sqlite3_partition_manager.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/sqlite3_partition_manager.py index 802f5954bedf..103af52cee27 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/sqlite3_partition_manager.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_eventprocessor/sqlite3_partition_manager.py @@ -1,7 +1,7 @@ # -------------------------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. -# ----------------------------------------------------------------------------------- +# -------------------------------------------------------------------------------------------- import time import threading @@ -10,7 +10,7 @@ import logging from .partition_manager import PartitionManager -logger = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) def _check_table_name(table_name): @@ -109,9 +109,9 @@ def claim_ownership(self, ownership_list): + ") values ("+",".join(["?"] * len(self.ownership_fields)) + ")" cursor.execute(sql, tuple(p.get(field) for field in self.ownership_fields)) except sqlite3.OperationalError as op_err: - logger.info("EventProcessor %r failed to claim partition %r " - "because it was claimed by another EventProcessor at the same time. " - "The Sqlite3 exception is %r", p["owner_id"], p["partition_id"], op_err) + _LOGGER.info("EventProcessor %r failed to claim partition %r " + "because it was claimed by another EventProcessor at the same time. 
" + "The Sqlite3 exception is %r", p["owner_id"], p["partition_id"], op_err) continue else: result.append(p) @@ -128,10 +128,10 @@ def claim_ownership(self, ownership_list): + tuple(p.get(field) for field in self.primary_keys)) result.append(p) else: - logger.info("EventProcessor %r failed to claim partition %r " - "because it was claimed by another EventProcessor at the same time", - p["owner_id"], - p["partition_id"]) + _LOGGER.info("EventProcessor %r failed to claim partition %r " + "because it was claimed by another EventProcessor at the same time", + p["owner_id"], + p["partition_id"]) self.conn.commit() return result finally: diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/producer.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_producer.py similarity index 66% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/producer.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/_producer.py index 189838013492..b4f9d1498567 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/producer.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_producer.py @@ -7,6 +7,7 @@ import uuid import logging import time +import threading from typing import Iterable, Union, Type from uamqp import types, constants, errors # type: ignore @@ -15,30 +16,25 @@ from azure.core.tracing import SpanKind, AbstractSpan # type: ignore from azure.core.settings import settings # type: ignore -from .common import EventData, EventDataBatch -from .error import _error_handler, OperationTimeoutError, EventDataError -from ._consumer_producer_mixin import ConsumerProducerMixin +from .exceptions import _error_handler, OperationTimeoutError +from ._common import EventData, EventDataBatch +from ._client_base import ConsumerProducerMixin +from ._utils import create_properties, set_message_partition_key, trace_message +from ._constants import TIMEOUT_SYMBOL -log = logging.getLogger(__name__) - - -def _error(outcome, condition): - if outcome != constants.MessageSendResult.Ok: - raise condition +_LOGGER = logging.getLogger(__name__) def _set_partition_key(event_datas, partition_key): - ed_iter = iter(event_datas) - for ed in ed_iter: - ed._set_partition_key(partition_key) # pylint:disable=protected-access + for ed in iter(event_datas): + set_message_partition_key(ed.message, partition_key) yield ed def _set_trace_message(event_datas, parent_span=None): - ed_iter = iter(event_datas) - for ed in ed_iter: - ed._trace_message(parent_span) # pylint:disable=protected-access + for ed in iter(event_datas): + trace_message(ed.message, parent_span) yield ed @@ -51,7 +47,6 @@ class EventHubProducer(ConsumerProducerMixin): # pylint:disable=too-many-instan Please use the method `create_producer` on `EventHubClient` for creating `EventHubProducer`. 
""" - _timeout_symbol = b'com.microsoft:timeout' def __init__(self, client, target, **kwargs): """ @@ -80,7 +75,9 @@ def __init__(self, client, target, **kwargs): keep_alive = kwargs.get("keep_alive", None) auto_reconnect = kwargs.get("auto_reconnect", True) - super(EventHubProducer, self).__init__() + self.running = False + self.closed = False + self._max_message_size_on_link = None self._client = client self._target = target @@ -99,42 +96,48 @@ def __init__(self, client, target, **kwargs): self._handler = None self._outcome = None self._condition = None - self._link_properties = {types.AMQPSymbol(self._timeout_symbol): types.AMQPLong(int(self._timeout * 1000))} + self._lock = threading.Lock() + self._link_properties = {types.AMQPSymbol(TIMEOUT_SYMBOL): types.AMQPLong(int(self._timeout * 1000))} def _create_handler(self): self._handler = SendClient( self._target, auth=self._client._create_auth(), # pylint:disable=protected-access debug=self._client._config.network_tracing, # pylint:disable=protected-access - msg_timeout=self._timeout, + msg_timeout=self._timeout * 1000, error_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, link_properties=self._link_properties, - properties=self._client._create_properties(self._client._config.user_agent)) # pylint: disable=protected-access + properties=create_properties(self._client._config.user_agent)) # pylint: disable=protected-access def _open_with_retry(self): return self._do_retryable_operation(self._open, operation_need_param=False) + def _set_msg_timeout(self, timeout_time, last_exception): + if not timeout_time: + return + remaining_time = timeout_time - time.time() + if remaining_time <= 0.0: + if last_exception: + error = last_exception + else: + error = OperationTimeoutError("Send operation timed out") + _LOGGER.info("%r send operation timed out. (%r)", self._name, error) + raise error + self._handler._msg_timeout = remaining_time * 1000 # pylint: disable=protected-access + def _send_event_data(self, timeout_time=None, last_exception=None): if self._unsent_events: self._open() - remaining_time = timeout_time - time.time() - if remaining_time <= 0.0: - if last_exception: - error = last_exception - else: - error = OperationTimeoutError("send operation timed out") - log.info("%r send operation timed out. (%r)", self._name, error) - raise error - self._handler._msg_timeout = remaining_time * 1000 # pylint: disable=protected-access + self._set_msg_timeout(timeout_time, last_exception) self._handler.queue_message(*self._unsent_events) self._handler.wait() self._unsent_events = self._handler.pending_messages if self._outcome != constants.MessageSendResult.Ok: if self._outcome == constants.MessageSendResult.Timeout: - self._condition = OperationTimeoutError("send operation timed out") - _error(self._outcome, self._condition) + self._condition = OperationTimeoutError("Send operation timed out") + raise self._condition def _send_event_data_with_retry(self, timeout=None): return self._do_retryable_operation(self._send_event_data, timeout=timeout) @@ -151,29 +154,24 @@ def _on_outcome(self, outcome, condition): self._outcome = outcome self._condition = condition - def create_batch(self, max_size=None, partition_key=None): - # type:(int, str) -> EventDataBatch - """ - Create an EventDataBatch object with max size being max_size. - The max_size should be no greater than the max allowed message size defined by the service side. - - :param max_size: The maximum size of bytes data that an EventDataBatch object can hold. 
- :type max_size: int - :param partition_key: With the given partition_key, event data will land to - a particular partition of the Event Hub decided by the service. - :type partition_key: str - :return: an EventDataBatch instance - :rtype: ~azure.eventhub.EventDataBatch - """ - - if not self._max_message_size_on_link: - self._open_with_retry() - - if max_size and max_size > self._max_message_size_on_link: - raise ValueError('Max message size: {} is too large, acceptable max batch size is: {} bytes.' - .format(max_size, self._max_message_size_on_link)) - - return EventDataBatch(max_size=(max_size or self._max_message_size_on_link), partition_key=partition_key) + def _wrap_eventdata(self, event_data, span, partition_key): + if isinstance(event_data, EventData): + if partition_key: + set_message_partition_key(event_data.message, partition_key) + wrapper_event_data = event_data + trace_message(wrapper_event_data.message, span) + else: + if isinstance(event_data, EventDataBatch): # The partition_key in the param will be omitted. + if partition_key and partition_key != event_data._partition_key: # pylint: disable=protected-access + raise ValueError('The partition_key does not match the partition key of the EventDataBatch') + wrapper_event_data = event_data # type:ignore + else: + if partition_key: + event_data = _set_partition_key(event_data, partition_key) + event_data = _set_trace_message(event_data) + wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # pylint: disable=protected-access + wrapper_event_data.message.on_send_complete = self._on_outcome + return wrapper_event_data def send(self, event_data, partition_key=None, timeout=None): # type:(Union[EventData, EventDataBatch, Iterable[EventData]], Union[str, bytes], float) -> None @@ -197,42 +195,28 @@ def send(self, event_data, partition_key=None, timeout=None): :rtype: None """ # Tracing code - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] - child = None - if span_impl_type is not None: - child = span_impl_type(name="Azure.EventHubs.send") - child.kind = SpanKind.CLIENT # Should be PRODUCER - - self._check_closed() - if isinstance(event_data, EventData): - if partition_key: - event_data._set_partition_key(partition_key) # pylint: disable=protected-access - wrapper_event_data = event_data - wrapper_event_data._trace_message(child) # pylint: disable=protected-access - else: - if isinstance(event_data, EventDataBatch): # The partition_key in the param will be omitted.
- if partition_key and partition_key != event_data._partition_key: # pylint: disable=protected-access - raise EventDataError('The partition_key does not match the one of the EventDataBatch') - wrapper_event_data = event_data # type:ignore + with self._lock: + span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] + child = None + if span_impl_type is not None: + child = span_impl_type(name="Azure.EventHubs.send") + child.kind = SpanKind.CLIENT # Should be PRODUCER + self._check_closed() + wrapper_event_data = self._wrap_eventdata(event_data, child, partition_key) + self._unsent_events = [wrapper_event_data.message] + + if span_impl_type is not None and child is not None: + with child: + self._client._add_span_request_attributes(child) # pylint: disable=protected-access + self._send_event_data_with_retry(timeout=timeout) else: - if partition_key: - event_data = _set_partition_key(event_data, partition_key) - event_data = _set_trace_message(event_data) - wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # pylint: disable=protected-access - wrapper_event_data.message.on_send_complete = self._on_outcome - self._unsent_events = [wrapper_event_data.message] - - if span_impl_type is not None and child is not None: - with child: - self._client._add_span_request_attributes(child) # pylint: disable=protected-access self._send_event_data_with_retry(timeout=timeout) - else: - self._send_event_data_with_retry(timeout=timeout) - def close(self): # pylint:disable=useless-super-delegation + def close(self): # type:() -> None """ Close down the handler. If the handler has already closed, this will be a no op. """ - super(EventHubProducer, self).close() + with self._lock: + super(EventHubProducer, self).close() diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_producer_client.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_producer_client.py index 02337bc21087..0f2ce2182921 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/_producer_client.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_producer_client.py @@ -7,18 +7,25 @@ from typing import Any, Union, TYPE_CHECKING, Iterable, List from uamqp import constants # type:ignore -from .client import EventHubClient -from .producer import EventHubProducer -from .common import EventData, \ - EventHubSharedKeyCredential, EventHubSASTokenCredential, EventDataBatch + +from .exceptions import ConnectError, EventHubError +from ._client_base import ClientBase +from ._producer import EventHubProducer +from ._constants import ALL_PARTITIONS +from ._common import ( + EventData, + EventHubSharedKeyCredential, + EventHubSASTokenCredential, + EventDataBatch +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential # type: ignore -log = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) -class EventHubProducerClient(EventHubClient): +class EventHubProducerClient(ClientBase): """ The EventHubProducerClient class defines a high level interface for sending events to the Azure Event Hubs service. 
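A minimal usage sketch of the pooled-producer design this class now implements: one `EventHubProducer` per partition is created lazily and cached in a dict, with the `ALL_PARTITIONS` key backing round-robin sends. The connection string variable and event hub name below are placeholders, not part of this change:

```python
import os
from azure.eventhub import EventHubProducerClient, EventData

# Placeholders: supply your own connection string and Event Hub name.
client = EventHubProducerClient.from_connection_string(
    os.environ['EVENT_HUB_CONN_STR'], event_hub_path='<< NAME OF THE EVENT HUB >>')

with client:
    # No partition_id: the cached ALL_PARTITIONS producer lets the service round-robin.
    client.send(EventData("broadcast event"))
    # With partition_id: a producer dedicated to partition "0" is created on first use.
    client.send(EventData("pinned event"), partition_id="0")
```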
@@ -58,26 +65,51 @@ def __init__(self, host, event_hub_path, credential, **kwargs): super(EventHubProducerClient, self).__init__( host=host, event_hub_path=event_hub_path, credential=credential, network_tracing=kwargs.get("logging_enable"), **kwargs) - self._producers = [] # type: List[EventHubProducer] - self._client_lock = threading.Lock() - self._producers_locks = [] # type: List[threading.Lock] + self._producers = {ALL_PARTITIONS: self._create_producer()} # type: Dict[str, EventHubProducer] self._max_message_size_on_link = 0 + self._partition_ids = None + self._lock = threading.Lock() + + def _get_partitions(self): + if not self._partition_ids: + self._partition_ids = self.get_partition_ids() + for p_id in self._partition_ids: + self._producers[p_id] = None + + def _get_max_message_size(self): + # pylint: disable=protected-access + with self._lock: + if not self._max_message_size_on_link: + self._producers[ALL_PARTITIONS]._open_with_retry() + self._max_message_size_on_link = \ + self._producers[ALL_PARTITIONS]._handler.message_handler._link.peer_max_message_size \ + or constants.MAX_MESSAGE_LENGTH_BYTES + + def _start_producer(self, partition_id, send_timeout): + with self._lock: + self._get_partitions() + if partition_id not in self._partition_ids and partition_id != ALL_PARTITIONS: + raise ConnectError("Invalid partition {} for the event hub {}".format(partition_id, self.eh_name)) + + if not self._producers[partition_id] or self._producers[partition_id].closed: + self._producers[partition_id] = self._create_producer( + partition_id=partition_id, + send_timeout=send_timeout ) - def _init_locks_for_producers(self): - if not self._producers: - with self._client_lock: - if not self._producers: - num_of_producers = len(self.get_partition_ids()) + 1 - self._producers = [None] * num_of_producers - for _ in range(num_of_producers): - self._producers_locks.append(threading.Lock()) + def _create_producer(self, partition_id=None, send_timeout=None): + target = "amqps://{}{}".format(self._address.hostname, self._address.path) + send_timeout = self._config.send_timeout if send_timeout is None else send_timeout + + handler = EventHubProducer( + self, target, partition=partition_id, send_timeout=send_timeout) + return handler @classmethod def from_connection_string(cls, conn_str, **kwargs): # type: (str, Any) -> EventHubProducerClient """ Create an EventHubProducerClient from a connection string. - :param str conn_str: The connection string of an eventhub. :keyword str event_hub_path: The path of the specific Event Hub to connect the client to. :keyword bool network_tracing: Whether to output network trace logs to the logger. Default is `False`. @@ -92,6 +124,7 @@ def from_connection_string(cls, conn_str, **kwargs): :keyword transport_type: The type of transport protocol that will be used for communicating with the Event Hubs service. Default is `TransportType.Amqp`. :paramtype transport_type: ~azure.eventhub.TransportType + :rtype: ~azure.eventhub.EventHubProducerClient ..
admonition:: Example: @@ -135,19 +168,13 @@ def send(self, event_data, **kwargs): :caption: Sends event data """ - partition_id = kwargs.pop("partition_id", None) - - self._init_locks_for_producers() - - producer_index = int(partition_id) if partition_id is not None else -1 - if self._producers[producer_index] is None or\ - self._producers[producer_index]._closed: # pylint:disable=protected-access - with self._producers_locks[producer_index]: - if self._producers[producer_index] is None: - self._producers[producer_index] = self._create_producer(partition_id=partition_id) - - with self._producers_locks[producer_index]: - self._producers[producer_index].send(event_data, **kwargs) + partition_id = kwargs.pop("partition_id", None) or ALL_PARTITIONS + send_timeout = kwargs.pop("timeout", None) + try: + self._producers[partition_id].send(event_data, **kwargs) + except (KeyError, AttributeError, EventHubError): + self._start_producer(partition_id, send_timeout) + self._producers[partition_id].send(event_data, **kwargs) def create_batch(self, max_size=None): # type:(int) -> EventDataBatch @@ -170,15 +197,7 @@ def create_batch(self, max_size=None): """ # pylint: disable=protected-access if not self._max_message_size_on_link: - self._init_locks_for_producers() - with self._producers_locks[-1]: - if self._producers[-1] is None: - self._producers[-1] = self._create_producer(partition_id=None) - self._producers[-1]._open_with_retry() # pylint: disable=protected-access - with self._client_lock: - self._max_message_size_on_link =\ - self._producers[-1]._handler.message_handler._link.peer_max_message_size \ - or constants.MAX_MESSAGE_LENGTH_BYTES + self._get_max_message_size() if max_size and max_size > self._max_message_size_on_link: raise ValueError('Max message size: {} is too large, acceptable max batch size is: {} bytes.' @@ -203,7 +222,9 @@ def close(self): :caption: Close down the client. """ - for p in self._producers: - if p: - p.close() + with self._lock: + for producer in self._producers.values(): + if producer: + producer.close() + self._producers = {} self._conn_manager.close_connection() diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/_utils.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/_utils.py new file mode 100644 index 000000000000..52b6f71d3169 --- /dev/null +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/_utils.py @@ -0,0 +1,142 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information.
+# -------------------------------------------------------------------------------------------- +from __future__ import unicode_literals + +import sys +import platform +import datetime + +from uamqp import types # type: ignore +from uamqp.message import MessageHeader # type: ignore + +from azure.core.settings import settings # type: ignore + +from azure.eventhub import __version__ +from ._constants import ( + PROP_PARTITION_KEY_AMQP_SYMBOL, + MAX_USER_AGENT_LENGTH, + USER_AGENT_PREFIX +) + + +class UTC(datetime.tzinfo): + """Time Zone info for handling UTC""" + + def utcoffset(self, dt): + """UTC offset for UTC is 0.""" + return datetime.timedelta(0) + + def tzname(self, dt): + """Timestamp representation.""" + return "Z" + + def dst(self, dt): + """No daylight saving for UTC.""" + return datetime.timedelta(0) + + +try: + from datetime import timezone + TZ_UTC = timezone.utc # type: ignore +except ImportError: + TZ_UTC = UTC() # type: ignore + + +def utc_from_timestamp(timestamp): + return datetime.datetime.fromtimestamp(timestamp, tz=TZ_UTC) + + +def create_properties(user_agent=None): + """ + Format the properties with which to instantiate the connection. + This acts like a user agent over HTTP. + + :rtype: dict + """ + properties = {} + properties[types.AMQPSymbol("product")] = USER_AGENT_PREFIX + properties[types.AMQPSymbol("version")] = __version__ + framework = "Python {}.{}.{}, {}".format( + sys.version_info[0], sys.version_info[1], sys.version_info[2], platform.python_implementation() + ) + properties[types.AMQPSymbol("framework")] = framework + platform_str = platform.platform() + properties[types.AMQPSymbol("platform")] = platform_str + + final_user_agent = '{}/{} ({}, {})'.format(USER_AGENT_PREFIX, __version__, framework, platform_str) + if user_agent: + final_user_agent = '{}, {}'.format(final_user_agent, user_agent) + + if len(final_user_agent) > MAX_USER_AGENT_LENGTH: + raise ValueError("The user-agent string cannot be more than {} in length. " + "Current user_agent string is: {} with length: {}".format( + MAX_USER_AGENT_LENGTH, final_user_agent, len(final_user_agent))) + properties[types.AMQPSymbol("user-agent")] = final_user_agent + return properties + + +def parse_sas_token(sas_token): + """Parse a SAS token into its components. + + :param sas_token: The SAS token. + :type sas_token: str + :rtype: dict[str, str] + """ + sas_data = {} + token = sas_token.partition(' ')[2] + fields = token.split('&') + for field in fields: + key, value = field.split('=', 1) + sas_data[key.lower()] = value + return sas_data + + +def set_message_partition_key(message, partition_key): + """Set the partition key as an annotation on a uamqp message. + + :param ~uamqp.Message message: The message to update. + :param bytes partition_key: The partition key value. + :rtype: None + """ + if partition_key: + annotations = message.annotations + if annotations is None: + annotations = dict() + annotations[PROP_PARTITION_KEY_AMQP_SYMBOL] = partition_key # pylint:disable=protected-access + header = MessageHeader() + header.durable = True + message.annotations = annotations + message.header = header + + +def trace_message(message, parent_span=None): + """Add tracing information to this message. + + Will open and close an "Azure.EventHubs.message" span, and + add the "DiagnosticId" as app properties of the message.
+ """ + span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] + if span_impl_type is not None: + current_span = parent_span or span_impl_type(span_impl_type.get_current_span()) + message_span = current_span.span(name="Azure.EventHubs.message") + message_span.start() + app_prop = dict(message.application_properties) if message.application_properties else dict() + app_prop.setdefault(b"Diagnostic-Id", message_span.get_trace_parent().encode('ascii')) + message.application_properties = app_prop + message_span.finish() + + +def trace_link_message(message, parent_span=None): + """Link the current message to current span. + + Will extract DiagnosticId if available. + """ + span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] + if span_impl_type is not None: + current_span = parent_span or span_impl_type(span_impl_type.get_current_span()) + if current_span and message.application_properties: + traceparent = message.application_properties.get(b"Diagnostic-Id", "").decode('ascii') + if traceparent: + current_span.link(traceparent) diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/__init__.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/__init__.py index 2f4abf2e3be3..90eb4ed245bf 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/__init__.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/__init__.py @@ -4,8 +4,8 @@ # -------------------------------------------------------------------------------------------- from ._consumer_client_async import EventHubConsumerClient from ._producer_client_async import EventHubProducerClient -from .eventprocessor.partition_manager import PartitionManager -from .eventprocessor.partition_context import PartitionContext +from ._eventprocessor.partition_manager import PartitionManager +from ._eventprocessor.partition_context import PartitionContext __all__ = [ "EventHubConsumerClient", diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/client_async.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_client_base_async.py similarity index 53% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/client_async.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_client_base_async.py index ef60cf02f9be..219a40c85ee2 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/client_async.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_client_base_async.py @@ -2,43 +2,41 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. 
# -------------------------------------------------------------------------------------------- +from __future__ import unicode_literals + import logging -import datetime +import asyncio import time import functools -import asyncio - -from typing import Any, List, Dict, Union, TYPE_CHECKING - -from uamqp import authentication, constants # type: ignore -from uamqp import Message, AMQPClientAsync # type: ignore - -from ..common import parse_sas_token, EventPosition, \ - EventHubSharedKeyCredential, EventHubSASTokenCredential -from ..client_abstract import EventHubClientAbstract - -from .producer_async import EventHubProducer -from .consumer_async import EventHubConsumer +from typing import Any, TYPE_CHECKING + +from uamqp import ( + authentication, + constants, + errors, + compat, + Message, + AMQPClientAsync +) + +from .._common import EventHubSharedKeyCredential, EventHubSASTokenCredential +from .._client_base import ClientBase +from .._utils import parse_sas_token, utc_from_timestamp +from ..exceptions import ClientClosedError +from .._constants import JWT_TOKEN_SCOPE, MGMT_OPERATION, MGMT_PARTITION_OPERATION from ._connection_manager_async import get_connection_manager -from .error_async import _handle_exception +from ._error_async import _handle_exception if TYPE_CHECKING: from azure.core.credentials import TokenCredential # type: ignore -log = logging.getLogger(__name__) - - -class EventHubClient(EventHubClientAbstract): - """ - The EventHubClient class defines a high level interface for asynchronously - sending events to and receiving events from the Azure Event Hubs service. +_LOGGER = logging.getLogger(__name__) - """ +class ClientBaseAsync(ClientBase): def __init__(self, host, event_hub_path, credential, **kwargs): - # type:(str, str, Union[EventHubSharedKeyCredential, EventHubSASTokenCredential, TokenCredential], Any) -> None - super(EventHubClient, self).__init__(host=host, event_hub_path=event_hub_path, credential=credential, **kwargs) - self._lock = asyncio.Lock() + super(ClientBaseAsync, self).__init__(host=host, event_hub_path=event_hub_path, + credential=credential, **kwargs) self._conn_manager = get_connection_manager(**kwargs) async def __aenter__(self): @@ -62,7 +60,7 @@ def _create_auth(self): password = self._credential.key if "@sas.root" in username: return authentication.SASLPlain( - self._host, username, password, http_proxy=http_proxy, transport_type=transport_type) + self._address.hostname, username, password, http_proxy=http_proxy, transport_type=transport_type) return authentication.SASTokenAsync.from_shared_access_key( self._auth_uri, username, password, timeout=auth_timeout, http_proxy=http_proxy, transport_type=transport_type) @@ -81,7 +79,7 @@ def _create_auth(self): transport_type=transport_type) else: - get_jwt_token = functools.partial(self._credential.get_token, 'https://eventhubs.azure.net//.default') + get_jwt_token = functools.partial(self._credential.get_token, JWT_TOKEN_SCOPE) return authentication.JWTTokenAsync(self._auth_uri, self._auth_uri, get_jwt_token, http_proxy=http_proxy, transport_type=transport_type) @@ -89,15 +87,15 @@ def _create_auth(self): async def _close_connection(self): await self._conn_manager.reset_connection_if_broken() - async def _try_delay(self, retried_times, last_exception, timeout_time=None, entity_name=None): + async def _backoff(self, retried_times, last_exception, timeout_time=None, entity_name=None): entity_name = entity_name or self._container_id backoff = self._config.backoff_factor * 2 ** retried_times if backoff <= 
self._config.backoff_max and ( timeout_time is None or time.time() + backoff <= timeout_time): # pylint:disable=no-else-return await asyncio.sleep(backoff) - log.info("%r has an exception (%r). Retrying...", format(entity_name), last_exception) + _LOGGER.info("%r has an exception (%r). Retrying...", entity_name, last_exception) else: - log.info("%r operation has timed out. Last exception before timeout is (%r)", + _LOGGER.info("%r operation has timed out. Last exception before timeout is (%r)", entity_name, last_exception) raise last_exception @@ -108,7 +106,7 @@ async def _management_request(self, mgmt_msg, op_type): mgmt_auth = self._create_auth() mgmt_client = AMQPClientAsync(self._mgmt_target, auth=mgmt_auth, debug=self._config.network_tracing) try: - conn = await self._conn_manager.get_connection(self._host, mgmt_auth) + conn = await self._conn_manager.get_connection(self._address.hostname, mgmt_auth) await mgmt_client.open_async(connection=conn) response = await mgmt_client.mgmt_request_async( mgmt_msg, @@ -119,12 +117,13 @@ async def _management_request(self, mgmt_msg, op_type): return response except Exception as exception: # pylint:disable=broad-except last_exception = await _handle_exception(exception, self) - await self._try_delay(retried_times=retried_times, last_exception=last_exception) + await self._backoff(retried_times=retried_times, last_exception=last_exception) retried_times += 1 + if retried_times > self._config.max_retries: + _LOGGER.info("%r returns an exception %r", self._container_id, last_exception) + raise last_exception finally: await mgmt_client.close_async() - log.info("%r returns an exception %r", self._container_id, last_exception) # pylint:disable=specify-parameter-names-in-call - raise last_exception async def get_properties(self): # type:() -> Dict[str, Any] @@ -140,12 +139,12 @@ async def get_properties(self): :raises: :class:`EventHubError` """ mgmt_msg = Message(application_properties={'name': self.eh_name}) - response = await self._management_request(mgmt_msg, op_type=b'com.microsoft:eventhub') + response = await self._management_request(mgmt_msg, op_type=MGMT_OPERATION) output = {} eh_info = response.get_data() if eh_info: output['path'] = eh_info[b'name'].decode('utf-8') - output['created_at'] = datetime.datetime.utcfromtimestamp(float(eh_info[b'created_at']) / 1000) + output['created_at'] = utc_from_timestamp(float(eh_info[b'created_at']) / 1000) output['partition_ids'] = [p.decode('utf-8') for p in eh_info[b'partition_ids']] return output @@ -180,7 +179,7 @@ async def get_partition_properties(self, partition): """ mgmt_msg = Message(application_properties={'name': self.eh_name, 'partition': partition}) - response = await self._management_request(mgmt_msg, op_type=b'com.microsoft:partition') + response = await self._management_request(mgmt_msg, op_type=MGMT_PARTITION_OPERATION) partition_info = response.get_data() output = {} if partition_info: @@ -189,83 +188,98 @@ async def get_partition_properties(self, partition): output['beginning_sequence_number'] = partition_info[b'begin_sequence_number'] output['last_enqueued_sequence_number'] = partition_info[b'last_enqueued_sequence_number'] output['last_enqueued_offset'] = partition_info[b'last_enqueued_offset'].decode('utf-8') - output['last_enqueued_time_utc'] = datetime.datetime.utcfromtimestamp( - float(partition_info[b'last_enqueued_time_utc'] / 1000)) output['is_empty'] = partition_info[b'is_partition_empty'] + output['last_enqueued_time_utc'] = utc_from_timestamp(
float(partition_info[b'last_enqueued_time_utc'] / 1000) + ) return output - def _create_consumer( - self, - consumer_group: str, - partition_id: str, - event_position: EventPosition, **kwargs - ) -> EventHubConsumer: - """ - Create an async consumer to the client for a particular consumer group and partition. - - :param consumer_group: The name of the consumer group this consumer is associated with. - Events are read in the context of this group. The default consumer_group for an event hub is "$Default". - :type consumer_group: str - :param partition_id: The identifier of the Event Hub partition from which events will be received. - :type partition_id: str - :param event_position: The position within the partition where the consumer should begin reading events. - :type event_position: ~azure.eventhub.common.EventPosition - :param owner_level: The priority of the exclusive consumer. The client will create an exclusive - consumer if owner_level is set. - :type owner_level: int - :param prefetch: The message prefetch count of the consumer. Default is 300. - :type prefetch: int - :param track_last_enqueued_event_properties: Indicates whether or not the consumer should request information - on the last enqueued event on its associated partition, and track that information as events are received. - When information about the partition's last enqueued event is being tracked, each event received from the - Event Hubs service will carry metadata about the partition. This results in a small amount of additional - network bandwidth consumption that is generally a favorable trade-off when considered against periodically - making requests for partition properties using the Event Hub client. - It is set to `False` by default. - :type track_last_enqueued_event_properties: bool - :param loop: An event loop. If not specified the default event loop will be used. - :rtype: ~azure.eventhub.aio.consumer_async.EventHubConsumer - """ - owner_level = kwargs.get("owner_level") - prefetch = kwargs.get("prefetch") or self._config.prefetch - track_last_enqueued_event_properties = kwargs.get("track_last_enqueued_event_properties", False) - loop = kwargs.get("loop") - - source_url = "amqps://{}{}/ConsumerGroups/{}/Partitions/{}".format( - self._address.hostname, self._address.path, consumer_group, partition_id) - handler = EventHubConsumer( - self, source_url, event_position=event_position, owner_level=owner_level, - prefetch=prefetch, - track_last_enqueued_event_properties=track_last_enqueued_event_properties, loop=loop) - return handler - - def _create_producer( - self, *, - partition_id: str = None, - send_timeout: float = None, - loop: asyncio.AbstractEventLoop = None - ) -> EventHubProducer: + async def close(self): + # type: () -> None + await self._conn_manager.close_connection() + + +class ConsumerProducerMixin(object): + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.close() + + def _check_closed(self): + if self.closed: + raise ClientClosedError( + "{} has been closed. Please create a new one to handle event data.".format(self._name) + ) + + async def _open(self): """ - Create an async producer to send EventData object to an EventHub. - - :param partition_id: Optionally specify a particular partition to send to. - If omitted, the events will be distributed to available partitions via - round-robin. - :type partition_id: str - :param send_timeout: The timeout in seconds for an individual event to be sent from the time that it is - queued. 
Default value is 60 seconds. If set to 0, there will be no timeout. - :type send_timeout: float - :param loop: An event loop. If not specified the default event loop will be used. - :rtype: ~azure.eventhub.aio.producer_async.EventHubProducer + Open the EventHubConsumer using the supplied connection. + """ + # pylint: disable=protected-access + if not self.running: + if self._handler: + await self._handler.close_async() + self._create_handler() + await self._handler.open_async(connection=await self._client._conn_manager.get_connection( + self._client._address.hostname, + self._client._create_auth() + )) + while not await self._handler.client_ready_async(): + await asyncio.sleep(0.05) + self._max_message_size_on_link = self._handler.message_handler._link.peer_max_message_size \ + or constants.MAX_MESSAGE_LENGTH_BYTES # pylint: disable=protected-access + self.running = True + + async def _close_handler(self): + if self._handler: + await self._handler.close_async() # close the link (sharing connection) or connection (not sharing) + self.running = False + + async def _close_connection(self): + await self._close_handler() + await self._client._conn_manager.reset_connection_if_broken() # pylint:disable=protected-access - target = "amqps://{}{}".format(self._address.hostname, self._address.path) - send_timeout = self._config.send_timeout if send_timeout is None else send_timeout + async def _handle_exception(self, exception): + if not self.running and isinstance(exception, compat.TimeoutException): + exception = errors.AuthenticationException("Authorization timeout.") + return await _handle_exception(exception, self) - handler = EventHubProducer( - self, target, partition=partition_id, send_timeout=send_timeout, loop=loop) - return handler + return await _handle_exception(exception, self) + + async def _do_retryable_operation(self, operation, timeout=None, **kwargs): + # pylint:disable=protected-access + timeout_time = (time.time() + timeout) if timeout else None + retried_times = 0 + last_exception = kwargs.pop('last_exception', None) + operation_need_param = kwargs.pop('operation_need_param', True) + max_retries = self._client._config.max_retries + + while retried_times <= max_retries: + try: + if operation_need_param: + return await operation(timeout_time=timeout_time, last_exception=last_exception, **kwargs) + return await operation() + except Exception as exception: # pylint:disable=broad-except + last_exception = await self._handle_exception(exception) + await self._client._backoff( + retried_times=retried_times, + last_exception=last_exception, + timeout_time=timeout_time, + entity_name=self._name + ) + retried_times += 1 + if retried_times > max_retries: + _LOGGER.info("%r operation has exhausted retry. Last exception: %r.", self._name, last_exception) + raise last_exception async def close(self): # type: () -> None - await self._conn_manager.close_connection() + """ + Close down the handler. If the handler has already closed, + this will be a no op. + """ + await self._close_handler() + self.closed = True diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_async.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_async.py new file mode 100644 index 000000000000..4eff0ae26cac --- /dev/null +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_async.py @@ -0,0 +1,174 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- +import asyncio +import uuid +import logging + +import uamqp # type: ignore +from uamqp import errors, types, utils # type: ignore +from uamqp import ReceiveClientAsync, Source # type: ignore +from uamqp.compat import queue + +from ._client_base_async import ConsumerProducerMixin +from .._common import EventData, EventPosition +from ..exceptions import _error_handler +from .._utils import create_properties, trace_link_message +from .._constants import ( + EPOCH_SYMBOL, + TIMEOUT_SYMBOL, + RECEIVER_RUNTIME_METRIC_SYMBOL +) + +_LOGGER = logging.getLogger(__name__) + + +class EventHubConsumer(ConsumerProducerMixin): # pylint:disable=too-many-instance-attributes + """ + A consumer responsible for reading EventData from a specific Event Hub + partition and as a member of a specific consumer group. + + A consumer may be exclusive, which asserts ownership over the partition for the consumer + group to ensure that only one consumer from that group is reading from the partition. + These exclusive consumers are sometimes referred to as "Epoch Consumers." + + A consumer may also be non-exclusive, allowing multiple consumers from the same consumer + group to be actively reading events from the partition. These non-exclusive consumers are + sometimes referred to as "Non-Epoch Consumers." + + Please use the method `create_consumer` on `EventHubClient` for creating `EventHubConsumer`. + """ + + def __init__( # pylint: disable=super-init-not-called + self, client, source, **kwargs): + """ + Instantiate an async consumer. EventHubConsumer should be instantiated by calling the `create_consumer` method + in EventHubClient. + + :param client: The parent EventHubClientAsync. + :type client: ~azure.eventhub.aio.EventHubClientAsync + :param source: The source EventHub from which to receive events. + :type source: ~uamqp.address.Source + :param event_position: The position from which to start receiving. + :type event_position: ~azure.eventhub.common.EventPosition + :param prefetch: The number of events to prefetch from the service + for processing. Default is 300. + :type prefetch: int + :param owner_level: The priority of the exclusive consumer. An exclusive + consumer will be created if owner_level is set. + :type owner_level: int + :param track_last_enqueued_event_properties: Indicates whether or not the consumer should request information + on the last enqueued event on its associated partition, and track that information as events are received. + When information about the partition's last enqueued event is being tracked, each event received from the + Event Hubs service will carry metadata about the partition. This results in a small amount of additional + network bandwidth consumption that is generally a favorable trade-off when considered against periodically + making requests for partition properties using the Event Hub client. + It is set to `False` by default. + :type track_last_enqueued_event_properties: bool + :param loop: An event loop.
+ """ + event_position = kwargs.get("event_position", None) + prefetch = kwargs.get("prefetch", 300) + owner_level = kwargs.get("owner_level", None) + keep_alive = kwargs.get("keep_alive", None) + auto_reconnect = kwargs.get("auto_reconnect", True) + track_last_enqueued_event_properties = kwargs.get("track_last_enqueued_event_properties", False) + loop = kwargs.get("loop", None) + + self.running = False + self.closed = False + + self._on_event_received = kwargs.get("on_event_received") + self._loop = loop or asyncio.get_event_loop() + self._client = client + self._source = source + self._offset = event_position + self._prefetch = prefetch + self._owner_level = owner_level + self._keep_alive = keep_alive + self._auto_reconnect = auto_reconnect + self._retry_policy = errors.ErrorPolicy(max_retries=self._client._config.max_retries, on_error=_error_handler) # pylint:disable=protected-access + self._reconnect_backoff = 1 + self._timeout = 0 + self._link_properties = {} + partition = self._source.split('/')[-1] + self._partition = partition + self._name = "EHReceiver-{}-partition{}".format(uuid.uuid4(), partition) + if owner_level: + self._link_properties[types.AMQPSymbol(EPOCH_SYMBOL)] = types.AMQPLong(int(owner_level)) + link_property_timeout_ms = (self._client._config.receive_timeout or self._timeout) * 1000 # pylint:disable=protected-access + self._link_properties[types.AMQPSymbol(TIMEOUT_SYMBOL)] = types.AMQPLong(int(link_property_timeout_ms)) + self._handler = None + self._track_last_enqueued_event_properties = track_last_enqueued_event_properties + self._last_enqueued_event_properties = {} + self._event_queue = queue.Queue() + self._last_received_event = None + + def _create_handler(self): + source = Source(self._source) + if self._offset is not None: + source.set_filter(self._offset._selector()) # pylint:disable=protected-access + desired_capabilities = None + if self._track_last_enqueued_event_properties: + symbol_array = [types.AMQPSymbol(RECEIVER_RUNTIME_METRIC_SYMBOL)] + desired_capabilities = utils.data_factory(types.AMQPArray(symbol_array)) + + properties = create_properties(self._client._config.user_agent) # pylint:disable=protected-access + self._handler = ReceiveClientAsync( + source, + auth=self._client._create_auth(), # pylint:disable=protected-access + debug=self._client._config.network_tracing, # pylint:disable=protected-access + prefetch=self._prefetch, + link_properties=self._link_properties, + timeout=self._timeout, + error_policy=self._retry_policy, + keep_alive_interval=self._keep_alive, + client_name=self._name, + receive_settle_mode=uamqp.constants.ReceiverSettleMode.ReceiveAndDelete, + auto_complete=False, + properties=properties, + desired_capabilities=desired_capabilities, + loop=self._loop) + + self._handler._streaming_receive = True # pylint:disable=protected-access + self._handler._message_received_callback = self._message_received # pylint:disable=protected-access + + async def _open_with_retry(self): + return await self._do_retryable_operation(self._open, operation_need_param=False) + + def _message_received(self, message): + self._event_queue.put(message) + + async def receive(self): + retried_times = 0 + last_exception = None + max_retries = self._client._config.max_retries # pylint:disable=protected-access + + while retried_times <= max_retries: + try: + await self._open() + await self._handler.do_work_async() + while not self._event_queue.empty(): + message = self._event_queue.get() + event_data = EventData._from_message(message) # 
pylint:disable=protected-access + self._last_received_event = event_data + trace_link_message(event_data) + await self._on_event_received(event_data) + self._event_queue.task_done() + return + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise + except uamqp.errors.LinkDetach as ld_error: + if ld_error.condition == uamqp.constants.ErrorCodes.LinkStolen: + raise await self._handle_exception(ld_error) + except Exception as exception: # pylint: disable=broad-except + if not self.running: # exit by close + return + if self._last_received_event: + self._offset = EventPosition(self._last_received_event.offset) + last_exception = await self._handle_exception(exception) + retried_times += 1 + if retried_times > max_retries: + _LOGGER.info("%r operation has exhausted retry. Last exception: %r.", self._name, last_exception) + raise last_exception diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_client_async.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_client_async.py index fd09857fe353..90f37b692985 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_client_async.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_client_async.py @@ -2,18 +2,24 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # -------------------------------------------------------------------------------------------- + +import asyncio import logging from typing import Any, Union, TYPE_CHECKING, Dict, Tuple + from azure.eventhub import EventPosition, EventHubSharedKeyCredential, EventHubSASTokenCredential -from .eventprocessor.event_processor import EventProcessor -from .client_async import EventHubClient +from ._eventprocessor.event_processor import EventProcessor +from ._consumer_async import EventHubConsumer +from ._client_base_async import ClientBaseAsync +from .._constants import ALL_PARTITIONS + if TYPE_CHECKING: from azure.core.credentials import TokenCredential # type: ignore -log = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) -class EventHubConsumerClient(EventHubClient): +class EventHubConsumerClient(ClientBaseAsync): """ The EventHubProducerClient class defines a high level interface for receiving events from the Azure Event Hubs service. 
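The `receive` loop above is the heart of the new push-model consumer: uAMQP's `_message_received` callback enqueues raw messages, and each `receive()` call pumps the connection, drains the queue through the `on_event_received` callback, and records the offset of the last delivered event so a retry can reopen the link from where it stopped. A minimal, self-contained sketch of that drain-and-retry shape (the `open_link`, `pump`, and `on_event` callables are illustrative stand-ins, not SDK API):

```python
import asyncio


async def receive_once(event_queue, open_link, pump, on_event, max_retries=3):
    """Illustrative drain-and-retry loop; not the SDK implementation."""
    retried, last_offset = 0, None
    while retried <= max_retries:
        try:
            await open_link(resume_from=last_offset)  # no-op if the link is already open
            await pump()                              # does network work, fills event_queue
            while not event_queue.empty():
                event = event_queue.get_nowait()
                last_offset = event["offset"]         # remember position for reconnects
                await on_event(event)                 # hand the event to user code
            return
        except asyncio.CancelledError:
            raise                                     # shutdown: never swallow cancellation
        except Exception:                             # real code maps AMQP errors first
            retried += 1
            if retried > max_retries:
                raise
```

Note how the real hunk treats a `LinkStolen` detach as immediately fatal (another epoch consumer took the partition), while only transient failures consume a retry.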
@@ -67,24 +73,59 @@ class EventHubConsumerClient(EventHubClient): def __init__(self, host, event_hub_path, credential, **kwargs) -> None: # type:(str, str, Union[EventHubSharedKeyCredential, EventHubSASTokenCredential, TokenCredential], Any) -> None - """""" self._partition_manager = kwargs.pop("partition_manager", None) self._load_balancing_interval = kwargs.pop("load_balancing_interval", 10) + network_tracing = kwargs.pop("logging_enable", False) super(EventHubConsumerClient, self).__init__( host=host, event_hub_path=event_hub_path, credential=credential, - network_tracing=kwargs.get("logging_enable"), **kwargs) + network_tracing=network_tracing, **kwargs) + self._lock = asyncio.Lock() self._event_processors = dict() # type: Dict[Tuple[str, str], EventProcessor] - self._closed = False + + def _create_consumer( + self, + consumer_group: str, + partition_id: str, + event_position: EventPosition, **kwargs + ) -> EventHubConsumer: + owner_level = kwargs.get("owner_level") + prefetch = kwargs.get("prefetch") or self._config.prefetch + track_last_enqueued_event_properties = kwargs.get("track_last_enqueued_event_properties", False) + on_event_received = kwargs.get("on_event_received") + loop = kwargs.get("loop") + + source_url = "amqps://{}{}/ConsumerGroups/{}/Partitions/{}".format( + self._address.hostname, self._address.path, consumer_group, partition_id) + handler = EventHubConsumer( + self, source_url, + on_event_received=on_event_received, + event_position=event_position, + owner_level=owner_level, + prefetch=prefetch, + track_last_enqueued_event_properties=track_last_enqueued_event_properties, loop=loop) + return handler @classmethod - def from_connection_string(cls, conn_str, **kwargs): - # type: (str, Any) -> EventHubConsumerClient + def from_connection_string(cls, conn_str: str, + *, + event_hub_path: str = None, + logging_enable: bool = False, + http_proxy: dict = None, + auth_timeout: float = 60, + user_agent: str = None, + retry_total: int = 3, + transport_type=None, + partition_manager=None, + load_balancing_interval: float = 10, + **kwargs + ) -> 'EventHubConsumerClient': + # pylint: disable=arguments-differ """ Create an EventHubConsumerClient from a connection string. :param str conn_str: The connection string of an eventhub. :keyword str event_hub_path: The path of the specific Event Hub to connect the client to. - :keyword bool network_tracing: Whether to output network trace logs to the logger. Default is `False`. + :keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`. :keyword dict[str,Any] http_proxy: HTTP proxy settings. This must be a dictionary with the following keys - 'proxy_hostname' (str value) and 'proxy_port' (int value). Additionally the following keys may also be present - 'username', 'password'. @@ -103,6 +144,7 @@ def from_connection_string(cls, conn_str, **kwargs): :paramtype partition_manager: ~azure.eventhub.aio.PartitionManager :keyword float load_balancing_interval: When load balancing kicks in, this is the interval in seconds between two load balancing. Default is 10. + :rtype: ~azure.eventhub.aio.EventHubConsumerClient .. admonition:: Example: @@ -114,10 +156,22 @@ def from_connection_string(cls, conn_str, **kwargs): :caption: Create a new instance of the EventHubConsumerClient from connection string. 
""" - return super(EventHubConsumerClient, cls).from_connection_string(conn_str, **kwargs) + return super(EventHubConsumerClient, cls).from_connection_string( + conn_str, + event_hub_path=event_hub_path, + logging_enable=logging_enable, + http_proxy=http_proxy, + auth_timeout=auth_timeout, + user_agent=user_agent, + retry_total=retry_total, + transport_type=transport_type, + partition_manager=partition_manager, + load_balancing_interval=load_balancing_interval, + **kwargs + ) async def receive( - self, on_events, consumer_group: str, + self, on_event, consumer_group: str, *, partition_id: str = None, owner_level: int = None, @@ -130,12 +184,12 @@ async def receive( ) -> None: """Receive events from partition(s) optionally with load balancing and checkpointing. - :param on_events: The callback function for handling received events. The callback takes two - parameters: `partition_context` which contains partition context and `events` which are the received events. - Please define the callback like `on_event(partition_context, events)`. + :param on_event: The callback function for handling received event. The callback takes two + parameters: `partition_context` which contains partition context and `event` which is the received event. + Please define the callback like `on_event(partition_context, event)`. For detailed partition context information, please refer to :class:`PartitionContext`. - :type on_events: Callable[~azure.eventhub.aio.PartitionContext, List[EventData]] + :type on_event: Callable[~azure.eventhub.aio.PartitionContext, EventData] :param consumer_group: Receive events from the event hub for this consumer group :type consumer_group: str :keyword str partition_id: Receive from this partition only if it's not None. @@ -183,21 +237,21 @@ async def receive( """ async with self._lock: error = None - if (consumer_group, '-1') in self._event_processors: - error = ValueError("This consumer client is already receiving events from all partitions for" - " consumer group {}. ".format(consumer_group)) + if (consumer_group, ALL_PARTITIONS) in self._event_processors: + error = ("This consumer client is already receiving events " + "from all partitions for consumer group {}. ".format(consumer_group)) elif partition_id is None and any(x[0] == consumer_group for x in self._event_processors): - error = ValueError("This consumer client is already receiving events for consumer group {}. " - .format(consumer_group)) + error = ("This consumer client is already receiving events " + "for consumer group {}. ".format(consumer_group)) elif (consumer_group, partition_id) in self._event_processors: - error = ValueError("This consumer is already receiving events from partition {} for consumer group {}. " - .format(partition_id, consumer_group)) + error = ("This consumer client is already receiving events " + "from partition {} for consumer group {}. 
".format(partition_id, consumer_group)) if error: - log.warning(error) - raise error + _LOGGER.warning(error) + raise ValueError(error) event_processor = EventProcessor( - self, consumer_group, on_events, + self, consumer_group, on_event, partition_id=partition_id, partition_manager=self._partition_manager, error_handler=on_error, @@ -209,19 +263,16 @@ async def receive( prefetch=prefetch, track_last_enqueued_event_properties=track_last_enqueued_event_properties, ) - if partition_id: - self._event_processors[(consumer_group, partition_id)] = event_processor - else: - self._event_processors[(consumer_group, "-1")] = event_processor + self._event_processors[(consumer_group, partition_id or ALL_PARTITIONS)] = event_processor try: await event_processor.start() finally: await event_processor.stop() async with self._lock: - if partition_id and (consumer_group, partition_id) in self._event_processors: - del self._event_processors[(consumer_group, partition_id)] - elif partition_id is None and (consumer_group, '-1') in self._event_processors: - del self._event_processors[(consumer_group, "-1")] + try: + del self._event_processors[(consumer_group, partition_id or ALL_PARTITIONS)] + except KeyError: + pass async def close(self): # type: () -> None @@ -240,7 +291,6 @@ async def close(self): """ async with self._lock: - for _ in range(len(self._event_processors)): - _, ep = self._event_processors.popitem() - await ep.stop() + await asyncio.gather(*[p.stop() for p in self._event_processors.values()], return_exceptions=True) + self._event_processors = {} await super().close() diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_producer_mixin_async.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_producer_mixin_async.py deleted file mode 100644 index 0fac427f7eae..000000000000 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_consumer_producer_mixin_async.py +++ /dev/null @@ -1,105 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- -import asyncio -import logging -import time - -from uamqp import errors, constants, compat # type: ignore -from ..error import EventHubError -from ..aio.error_async import _handle_exception - -log = logging.getLogger(__name__) - - -class ConsumerProducerMixin(object): - - def __init__(self): - self._client = None - self._handler = None - self._name = None - self._running = False - self._closed = False - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.close() - - def _check_closed(self): - if self._closed: - raise EventHubError("{} has been closed. Please create a new one to handle event data.".format(self._name)) - - def _create_handler(self): - pass - - async def _open(self): - """ - Open the EventHubConsumer using the supplied connection. 
- - """ - # pylint: disable=protected-access - if not self._running: - if self._handler: - await self._handler.close_async() - self._create_handler() - await self._handler.open_async(connection=await self._client._conn_manager.get_connection( - self._client._address.hostname, - self._client._create_auth() - )) - while not await self._handler.client_ready_async(): - await asyncio.sleep(0.05) - self._max_message_size_on_link = self._handler.message_handler._link.peer_max_message_size \ - or constants.MAX_MESSAGE_LENGTH_BYTES # pylint: disable=protected-access - self._running = True - - async def _close_handler(self): - if self._handler: - await self._handler.close_async() # close the link (sharing connection) or connection (not sharing) - self._running = False - - async def _close_connection(self): - await self._close_handler() - await self._client._conn_manager.reset_connection_if_broken() # pylint:disable=protected-access - - async def _handle_exception(self, exception): - if not self._running and isinstance(exception, compat.TimeoutException): - exception = errors.AuthenticationException("Authorization timeout.") - return await _handle_exception(exception, self) - - return await _handle_exception(exception, self) - - async def _do_retryable_operation(self, operation, timeout=100000, **kwargs): - # pylint:disable=protected-access - timeout_time = time.time() + ( - timeout if timeout else 100000) # timeout equals to 0 means no timeout, set the value to be a large number. - retried_times = 0 - last_exception = kwargs.pop('last_exception', None) - operation_need_param = kwargs.pop('operation_need_param', True) - - while retried_times <= self._client._config.max_retries: - try: - if operation_need_param: - return await operation(timeout_time=timeout_time, last_exception=last_exception, **kwargs) - return await operation() - except Exception as exception: # pylint:disable=broad-except - last_exception = await self._handle_exception(exception) - await self._client._try_delay(retried_times=retried_times, last_exception=last_exception, - timeout_time=timeout_time, entity_name=self._name) - retried_times += 1 - - log.info("%r operation has exhausted retry. Last exception: %r.", self._name, last_exception) - raise last_exception - - async def close(self): - # type: () -> None - """ - Close down the handler. If the handler has already closed, - this will be a no op. 
- """ - if self._handler: - await self._handler.close_async() - self._running = False - self._closed = True diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/error_async.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_error_async.py similarity index 66% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/error_async.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_error_async.py index 7bbc3b6153c1..f8866d9be0a3 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/error_async.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_error_async.py @@ -6,33 +6,15 @@ import logging from uamqp import errors, compat # type: ignore -from ..error import EventHubError, EventDataSendError, \ - EventDataError, ConnectError, ConnectionLostError, AuthenticationError +from ..exceptions import ( + _create_eventhub_exception, + EventHubError, + EventDataSendError, + EventDataError +) -log = logging.getLogger(__name__) - - -def _create_eventhub_exception(exception): - if isinstance(exception, errors.AuthenticationException): - error = AuthenticationError(str(exception), exception) - elif isinstance(exception, errors.VendorLinkDetach): - error = ConnectError(str(exception), exception) - elif isinstance(exception, errors.LinkDetach): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.ConnectionClose): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.MessageHandlerError): - error = ConnectionLostError(str(exception), exception) - elif isinstance(exception, errors.AMQPConnectionError): - error_type = AuthenticationError if str(exception).startswith("Unable to open authentication session") \ - else ConnectError - error = error_type(str(exception), exception) - elif isinstance(exception, compat.TimeoutException): - error = ConnectionLostError(str(exception), exception) - else: - error = EventHubError(str(exception), exception) - return error +_LOGGER = logging.getLogger(__name__) async def _handle_exception(exception, closable): # pylint:disable=too-many-branches, too-many-statements @@ -43,7 +25,7 @@ async def _handle_exception(exception, closable): # pylint:disable=too-many-bra except AttributeError: name = closable._container_id # pylint: disable=protected-access if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise - log.info("%r stops due to keyboard interrupt", name) + _LOGGER.info("%r stops due to keyboard interrupt", name) await closable.close() raise exception elif isinstance(exception, EventHubError): @@ -57,11 +39,11 @@ async def _handle_exception(exception, closable): # pylint:disable=too-many-bra errors.MessageReleased, errors.MessageContentTooLarge) ): - log.info("%r Event data error (%r)", name, exception) + _LOGGER.info("%r Event data error (%r)", name, exception) error = EventDataError(str(exception), exception) raise error elif isinstance(exception, errors.MessageException): - log.info("%r Event data send error (%r)", name, exception) + _LOGGER.info("%r Event data send error (%r)", name, exception) error = EventDataSendError(str(exception), exception) raise error else: diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/__init__.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/__init__.py similarity index 100% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/__init__.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/__init__.py diff --git 
a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/_ownership_manager.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/_ownership_manager.py similarity index 100% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/_ownership_manager.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/_ownership_manager.py diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/event_processor.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/event_processor.py similarity index 61% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/event_processor.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/event_processor.py index 7dc16c4519bb..6398b04da29b 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/event_processor.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/event_processor.py @@ -2,27 +2,24 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for license information. # ----------------------------------------------------------------------------------- - -from contextlib import contextmanager -from typing import Dict, Type, Callable, List, Any +from typing import Dict, Callable, List, Any import uuid import asyncio import logging +from functools import partial -from azure.core.tracing import SpanKind # type: ignore -from azure.core.settings import settings # type: ignore - -from azure.eventhub import EventPosition, EventData +from azure.eventhub import EventPosition, EventData, EventHubError from ..._eventprocessor.common import CloseReason +from ..._eventprocessor._eventprocessor_mixin import EventProcessorMixin from .partition_context import PartitionContext from .partition_manager import PartitionManager from ._ownership_manager import OwnershipManager from .utils import get_running_loop -log = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) -class EventProcessor(object): # pylint:disable=too-many-instance-attributes +class EventProcessor(EventProcessorMixin): # pylint:disable=too-many-instance-attributes """ An EventProcessor constantly receives events from one or multiple partitions of the Event Hub in the context of a given consumer group. 
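Before the hunks below, the overall shape of the processor is worth keeping in mind: a reconcile loop that periodically claims partition ownership, spawns one receive task per newly claimed partition, and cancels the tasks for partitions it lost, which is what `start()`, `_create_tasks_for_claimed_ownership()`, and `_cancel_tasks_for_partitions()` implement. A compressed, hypothetical sketch of that loop (the `claim_ownership` and `run_partition` callables are assumptions for illustration):

```python
import asyncio


async def balance_loop(claim_ownership, run_partition, polling_interval=10.0):
    """Illustrative ownership-reconcile loop; not the SDK implementation."""
    tasks = {}  # partition_id -> asyncio.Task
    while True:
        claimed = await claim_ownership()        # set of partition ids we now own
        for pid in claimed - tasks.keys():       # newly claimed: start a pump task
            tasks[pid] = asyncio.ensure_future(run_partition(pid))
        for pid in tasks.keys() - claimed:       # ownership lost: cancel its task
            tasks.pop(pid).cancel()
        await asyncio.sleep(polling_interval)    # the client's load_balancing_interval
```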
@@ -62,6 +59,8 @@ def __init__( self._id = str(uuid.uuid4()) self._running = False + self._consumers = {} + def __repr__(self): return 'EventProcessor: id {}'.format(self._id) @@ -70,13 +69,13 @@ def _get_last_enqueued_event_properties(self, partition_id): return self._last_enqueued_event_properties[partition_id] raise ValueError("You're not receiving events from partition {}".format(partition_id)) - def _cancel_tasks_for_partitions(self, to_cancel_partitions): + async def _cancel_tasks_for_partitions(self, to_cancel_partitions): for partition_id in to_cancel_partitions: task = self._tasks.get(partition_id) if task: task.cancel() if to_cancel_partitions: - log.info("EventProcesor %r has cancelled partitions %r", self._id, to_cancel_partitions) + _LOGGER.info("EventProcessor %r has cancelled partitions %r", self._id, to_cancel_partitions) def _create_tasks_for_claimed_ownership(self, claimed_partitions, checkpoints=None): for partition_id in claimed_partitions: @@ -84,24 +83,8 @@ def _create_tasks_for_claimed_ownership(self, claimed_partitions, checkpoints=No checkpoint = checkpoints.get(partition_id) if checkpoints else None self._tasks[partition_id] = get_running_loop().create_task(self._receive(partition_id, checkpoint)) - @contextmanager - def _context(self, events): - # Tracing - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] - if span_impl_type is None: - yield - else: - child = span_impl_type(name="Azure.EventHubs.process") - self._eventhub_client._add_span_request_attributes(child) # pylint: disable=protected-access - child.kind = SpanKind.SERVER - - for event in events: - event._trace_link_message(child) # pylint: disable=protected-access - with child: - yield - async def _process_error(self, partition_context, err): - log.warning( + _LOGGER.warning( "EventProcessor instance %r of eventhub %r partition %r consumer group %r" " has met an error. The exception is %r.", partition_context.owner_id, @@ -114,7 +97,7 @@ async def _process_error(self, partition_context, err): try: await self._error_handler(partition_context, err) except Exception as err_again: # pylint:disable=broad-except - log.warning( + _LOGGER.warning( "EventProcessor instance %r of eventhub %r partition %r consumer group %r. " "An error occurred while running process_error(). The exception is %r.", partition_context.owner_id, @@ -126,7 +109,7 @@ async def _process_error(self, partition_context, err): async def _close_partition(self, partition_context, reason): if self._partition_close_handler: - log.info( + _LOGGER.info( "EventProcessor instance %r of eventhub %r partition %r consumer group %r" " is being closed. Reason is: %r", partition_context.owner_id, @@ -138,7 +121,7 @@ async def _close_partition(self, partition_context, reason): try: await self._partition_close_handler(partition_context, reason) except Exception as err: # pylint:disable=broad-except - log.warning( + _LOGGER.warning( "EventProcessor instance %r of eventhub %r partition %r consumer group %r. " "An error occurred while running close(). 
The exception is %r.", partition_context.owner_id, @@ -148,85 +131,71 @@ async def _close_partition(self, partition_context, reason): err ) + async def _on_event_received(self, partition_context, event): + with self._context(event): + try: + await self._event_handler(partition_context, event) + except asyncio.CancelledError: # pylint: disable=try-except-raise + raise + except Exception as error: # pylint:disable=broad-except + await self._process_error(partition_context, error) + async def _receive(self, partition_id, checkpoint=None): # pylint: disable=too-many-statements try: # pylint:disable=too-many-nested-blocks - log.info("start ownership %r, checkpoint %r", partition_id, checkpoint) - namespace = self._namespace - eventhub_name = self._eventhub_name - consumer_group_name = self._consumer_group_name - owner_id = self._id - checkpoint_offset = checkpoint.get("offset") if checkpoint else None - if checkpoint_offset: - initial_event_position = EventPosition(checkpoint_offset) - elif isinstance(self._initial_event_position, EventPosition): - initial_event_position = self._initial_event_position - elif isinstance(self._initial_event_position, dict): - initial_event_position = self._initial_event_position.get(partition_id, EventPosition("-1")) - else: - initial_event_position = EventPosition(self._initial_event_position) + _LOGGER.info("start ownership %r, checkpoint %r", partition_id, checkpoint) + initial_event_position = self.get_init_event_position(partition_id, checkpoint) if partition_id in self._partition_contexts: partition_context = self._partition_contexts[partition_id] else: partition_context = PartitionContext( - namespace, - eventhub_name, - consumer_group_name, + self._namespace, + self._eventhub_name, + self._consumer_group_name, partition_id, - owner_id, + self._id, self._partition_manager ) self._partition_contexts[partition_id] = partition_context - partition_consumer = self._eventhub_client._create_consumer( # pylint: disable=protected-access - consumer_group_name, - partition_id, - initial_event_position, - owner_level=self._owner_level, - track_last_enqueued_event_properties=self._track_last_enqueued_event_properties, - prefetch=self._prefetch, - ) + event_received_callback = partial(self._on_event_received, partition_context) + self._consumers[partition_id] = self.create_consumer(partition_id, + initial_event_position, + event_received_callback) - try: - if self._partition_initialize_handler: - try: - await self._partition_initialize_handler(partition_context) - except Exception as err: # pylint:disable=broad-except - log.warning( - "EventProcessor instance %r of eventhub %r partition %r consumer group %r. " - " An error occurred while running initialize(). The exception is %r.", - owner_id, eventhub_name, partition_id, consumer_group_name, err - ) - while True: - try: - events = await partition_consumer.receive() - if events: - if self._track_last_enqueued_event_properties: - self._last_enqueued_event_properties[partition_id] = \ - partition_consumer.last_enqueued_event_properties - with self._context(events): - await self._event_handler(partition_context, events) - except asyncio.CancelledError: - log.info( - "EventProcessor instance %r of eventhub %r partition %r consumer group %r" - " is cancelled", - owner_id, - eventhub_name, - partition_id, - consumer_group_name - ) - raise - except Exception as error: # pylint:disable=broad-except - await self._process_error(partition_context, error) - break - # Go to finally to stop this partition processor. 
- # Later an EventProcessor(this one or another one) will pick up this partition again - finally: - await partition_consumer.close() - if self._running is False: - await self._close_partition(partition_context, CloseReason.SHUTDOWN) - else: - await self._close_partition(partition_context, CloseReason.OWNERSHIP_LOST) + if self._partition_initialize_handler: + try: + await self._partition_initialize_handler(partition_context) + except Exception as err: # pylint:disable=broad-except + _LOGGER.warning( + "EventProcessor instance %r of eventhub %r partition %r consumer group %r. " + "An error occurred while running initialize(). The exception is %r.", + self._id, self._eventhub_name, partition_id, self._consumer_group_name, err + ) + + while self._running: + try: + await self._consumers[partition_id].receive() + except asyncio.CancelledError: + _LOGGER.info( + "EventProcessor instance %r of eventhub %r partition %r consumer group %r" + " is cancelled", + self._id, + self._eventhub_name, + partition_id, + self._consumer_group_name + ) + raise + except EventHubError as eh_error: + await self._process_error(partition_context, eh_error) + break + except Exception as error: # pylint:disable=broad-except + await self._process_error(partition_context, error) finally: + await self._consumers[partition_id].close() + await self._close_partition( + partition_context, + CloseReason.OWNERSHIP_LOST if self._running else CloseReason.SHUTDOWN + ) if partition_id in self._tasks: del self._tasks[partition_id] @@ -239,7 +208,7 @@ async def start(self): :return: None """ - log.info("EventProcessor %r is being started", self._id) + _LOGGER.info("EventProcessor %r is being started", self._id) ownership_manager = OwnershipManager(self._eventhub_client, self._consumer_group_name, self._id, self._partition_manager, self._ownership_timeout, self._partition_id) if not self._running: @@ -252,9 +221,9 @@ async def start(self): to_cancel_list = self._tasks.keys() - claimed_partition_ids self._create_tasks_for_claimed_ownership(claimed_partition_ids, checkpoints) else: - log.info("EventProcessor %r hasn't claimed an ownership. It keeps claiming.", self._id) + _LOGGER.info("EventProcessor %r hasn't claimed an ownership. It keeps claiming.", self._id) to_cancel_list = set(self._tasks.keys()) - self._cancel_tasks_for_partitions(to_cancel_list) + await self._cancel_tasks_for_partitions(to_cancel_list) except Exception as err: # pylint:disable=broad-except ''' ownership_manager.get_checkpoints() and ownership_manager.claim_ownership() may raise exceptions @@ -265,9 +234,9 @@ async def start(self): that this EventProcessor is working on. So two or multiple EventProcessors may be working on the same partition. ''' # pylint:disable=pointless-string-statement - log.warning("An exception (%r) occurred during balancing and claiming ownership for " - "eventhub %r consumer group %r. Retrying after %r seconds", - err, self._eventhub_name, self._consumer_group_name, self._polling_interval) + _LOGGER.warning("An exception (%r) occurred during balancing and claiming ownership for " + "eventhub %r consumer group %r. 
Retrying after %r seconds", + err, self._eventhub_name, self._consumer_group_name, self._polling_interval) await asyncio.sleep(self._polling_interval) async def stop(self): @@ -284,8 +253,8 @@ async def stop(self): """ self._running = False pids = list(self._tasks.keys()) - self._cancel_tasks_for_partitions(pids) - log.info("EventProcessor %r tasks have been cancelled.", self._id) + await self._cancel_tasks_for_partitions(pids) + _LOGGER.info("EventProcessor %r tasks have been cancelled.", self._id) while self._tasks: await asyncio.sleep(1) - log.info("EventProcessor %r has been stopped.", self._id) + _LOGGER.info("EventProcessor %r has been stopped.", self._id) diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/local_partition_manager.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/local_partition_manager.py similarity index 100% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/local_partition_manager.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/local_partition_manager.py diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/partition_context.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/partition_context.py similarity index 100% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/partition_context.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/partition_context.py index fac34bba567e..596f846d6a2f 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/partition_context.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/partition_context.py @@ -6,9 +6,9 @@ import logging from .partition_manager import PartitionManager - _LOGGER = logging.getLogger(__name__) + class PartitionContext(object): """Contains partition related context information for a PartitionProcessor instance to use. 
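The `stop()` change above pairs with `_receive()`'s `finally` block: each partition task deletes itself from `self._tasks` on the way out, so `stop()` only has to flip the running flag, cancel everything, and poll until the dict drains. A stripped-down sketch of that shutdown handshake, under illustrative names:

```python
import asyncio


class MiniProcessor:
    """Illustrative shutdown handshake; not the SDK implementation."""

    def __init__(self):
        self._running = True
        self._tasks = {}  # partition_id -> asyncio.Task

    async def _pump(self, pid):
        try:
            while self._running:
                await asyncio.sleep(0.1)   # stand-in for consumer.receive()
        finally:
            self._tasks.pop(pid, None)     # each task de-registers itself

    def track(self, pid, loop):
        self._tasks[pid] = loop.create_task(self._pump(pid))

    async def stop(self):
        self._running = False
        for task in list(self._tasks.values()):
            task.cancel()
        while self._tasks:                 # wait until every finally block has run
            await asyncio.sleep(0.1)
```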
diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/partition_manager.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/partition_manager.py similarity index 100% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/partition_manager.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/partition_manager.py diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/partition_processor.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/partition_processor.py similarity index 100% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/partition_processor.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/partition_processor.py diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/sqlite3_partition_manager.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/sqlite3_partition_manager.py similarity index 92% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/sqlite3_partition_manager.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/sqlite3_partition_manager.py index 577ed917f56e..c37a3b832cbd 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/sqlite3_partition_manager.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/sqlite3_partition_manager.py @@ -9,7 +9,7 @@ import logging from .partition_manager import PartitionManager -logger = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) def _check_table_name(table_name: str): @@ -98,9 +98,9 @@ async def claim_ownership(self, ownership_list): + ") values ("+",".join(["?"] * len(self.ownership_fields)) + ")" cursor.execute(sql, tuple(p.get(field) for field in self.ownership_fields)) except sqlite3.OperationalError as op_err: - logger.info("EventProcessor %r failed to claim partition %r " - "because it was claimed by another EventProcessor at the same time. " - "The Sqlite3 exception is %r", p["owner_id"], p["partition_id"], op_err) + _LOGGER.info("EventProcessor %r failed to claim partition %r " + "because it was claimed by another EventProcessor at the same time. 
" + "The Sqlite3 exception is %r", p["owner_id"], p["partition_id"], op_err) continue else: result.append(p) @@ -117,9 +117,9 @@ async def claim_ownership(self, ownership_list): + tuple(p.get(field) for field in self.primary_keys)) result.append(p) else: - logger.info("EventProcessor %r failed to claim partition %r " - "because it was claimed by another EventProcessor at the same time", p["owner_id"], - p["partition_id"]) + _LOGGER.info("EventProcessor %r failed to claim partition %r " + "because it was claimed by another EventProcessor at the same time", p["owner_id"], + p["partition_id"]) self.conn.commit() return result finally: diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/utils.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/utils.py similarity index 100% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/eventprocessor/utils.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_eventprocessor/utils.py diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/producer_async.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_producer_async.py similarity index 67% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/aio/producer_async.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_producer_async.py index d3414c489c03..9773c428118c 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/producer_async.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_producer_async.py @@ -14,12 +14,14 @@ from azure.core.tracing import SpanKind, AbstractSpan # type: ignore from azure.core.settings import settings # type: ignore -from ..common import EventData, EventDataBatch -from ..error import _error_handler, OperationTimeoutError, EventDataError -from ..producer import _error, _set_partition_key, _set_trace_message -from ._consumer_producer_mixin_async import ConsumerProducerMixin +from .._common import EventData, EventDataBatch +from ..exceptions import _error_handler, OperationTimeoutError +from .._producer import _set_partition_key, _set_trace_message +from .._utils import create_properties, set_message_partition_key, trace_message +from .._constants import TIMEOUT_SYMBOL +from ._client_base_async import ConsumerProducerMixin -log = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) class EventHubProducer(ConsumerProducerMixin): # pylint: disable=too-many-instance-attributes @@ -31,7 +33,6 @@ class EventHubProducer(ConsumerProducerMixin): # pylint: disable=too-many-insta Please use the method `create_producer` on `EventHubClient` for creating `EventHubProducer`. 
""" - _timeout_symbol = b'com.microsoft:timeout' def __init__( # pylint: disable=super-init-not-called self, client, target, **kwargs): @@ -63,7 +64,9 @@ def __init__( # pylint: disable=super-init-not-called auto_reconnect = kwargs.get("auto_reconnect", True) loop = kwargs.get("loop", None) - super(EventHubProducer, self).__init__() + self.running = False + self.closed = False + self._loop = loop or asyncio.get_event_loop() self._max_message_size_on_link = None self._client = client @@ -72,7 +75,8 @@ def __init__( # pylint: disable=super-init-not-called self._keep_alive = keep_alive self._auto_reconnect = auto_reconnect self._timeout = send_timeout - self._retry_policy = errors.ErrorPolicy(max_retries=self._client._config.max_retries, on_error=_error_handler) # pylint:disable=protected-access + self._retry_policy = errors.ErrorPolicy( + max_retries=self._client._config.max_retries, on_error=_error_handler) # pylint:disable=protected-access self._reconnect_backoff = 1 self._name = "EHProducer-{}".format(uuid.uuid4()) self._unsent_events = None @@ -83,45 +87,49 @@ def __init__( # pylint: disable=super-init-not-called self._handler = None self._outcome = None self._condition = None - self._link_properties = {types.AMQPSymbol(self._timeout_symbol): types.AMQPLong(int(self._timeout * 1000))} + self._lock = asyncio.Lock() + self._link_properties = {types.AMQPSymbol(TIMEOUT_SYMBOL): types.AMQPLong(int(self._timeout * 1000))} def _create_handler(self): self._handler = SendClientAsync( self._target, auth=self._client._create_auth(), # pylint:disable=protected-access debug=self._client._config.network_tracing, # pylint:disable=protected-access - msg_timeout=self._timeout, + msg_timeout=self._timeout * 1000, error_policy=self._retry_policy, keep_alive_interval=self._keep_alive, client_name=self._name, link_properties=self._link_properties, - properties=self._client._create_properties( # pylint: disable=protected-access - self._client._config.user_agent), # pylint:disable=protected-access + properties=create_properties(self._client._config.user_agent), # pylint:disable=protected-access loop=self._loop) async def _open_with_retry(self): return await self._do_retryable_operation(self._open, operation_need_param=False) + def _set_msg_timeout(self, timeout_time, last_exception): + if not timeout_time: + return + remaining_time = timeout_time - time.time() + if remaining_time <= 0.0: + if last_exception: + error = last_exception + else: + error = OperationTimeoutError("Send operation timed out") + _LOGGER.info("%r send operation timed out. (%r)", self._name, error) + raise error + self._handler._msg_timeout = remaining_time * 1000 # pylint: disable=protected-access + async def _send_event_data(self, timeout_time=None, last_exception=None): if self._unsent_events: await self._open() - remaining_time = timeout_time - time.time() - if remaining_time <= 0.0: - if last_exception: - error = last_exception - else: - error = OperationTimeoutError("send operation timed out") - log.info("%r send operation timed out. 
(%r)", self._name, error) - raise error - self._handler._msg_timeout = remaining_time * 1000 # pylint: disable=protected-access + self._set_msg_timeout(timeout_time, last_exception) self._handler.queue_message(*self._unsent_events) await self._handler.wait_async() self._unsent_events = self._handler.pending_messages if self._outcome != constants.MessageSendResult.Ok: if self._outcome == constants.MessageSendResult.Timeout: - self._condition = OperationTimeoutError("send operation timed out") - _error(self._outcome, self._condition) - return + self._condition = OperationTimeoutError("Send operation timed out") + raise self._condition async def _send_event_data_with_retry(self, timeout=None): return await self._do_retryable_operation(self._send_event_data, timeout=timeout) @@ -138,26 +146,24 @@ def _on_outcome(self, outcome, condition): self._outcome = outcome self._condition = condition - async def create_batch(self, max_size=None): - # type:(int) -> EventDataBatch - """ - Create an EventDataBatch object with max size being max_size. - The max_size should be no greater than the max allowed message size defined by the service side. - - :param max_size: The maximum size of bytes data that an EventDataBatch object can hold. - :type max_size: int - :return: an EventDataBatch instance - :rtype: ~azure.eventhub.EventDataBatch - """ - - if not self._max_message_size_on_link: - await self._open_with_retry() - - if max_size and max_size > self._max_message_size_on_link: - raise ValueError('Max message size: {} is too large, acceptable max batch size is: {} bytes.' - .format(max_size, self._max_message_size_on_link)) - - return EventDataBatch(max_size=(max_size or self._max_message_size_on_link)) + def _wrap_eventdata(self, event_data, span, partition_key): + if isinstance(event_data, EventData): + if partition_key: + set_message_partition_key(event_data.message, partition_key) + wrapper_event_data = event_data + trace_message(wrapper_event_data.message, span) + else: + if isinstance(event_data, EventDataBatch): # The partition_key in the param will be omitted. 
+ if partition_key and partition_key != event_data._partition_key: # pylint: disable=protected-access + raise ValueError('The partition_key does not match the one of the EventDataBatch') + wrapper_event_data = event_data # type:ignore + else: + if partition_key: + event_data = _set_partition_key(event_data, partition_key) + event_data = _set_trace_message(event_data) + wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # pylint: disable=protected-access + wrapper_event_data.message.on_send_complete = self._on_outcome + return wrapper_event_data async def send( self, event_data: Union[EventData, EventDataBatch, Iterable[EventData]], @@ -182,38 +188,22 @@ async def send( :rtype: None """ # Tracing code - span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] - child = None - if span_impl_type is not None: - child = span_impl_type(name="Azure.EventHubs.send") - child.kind = SpanKind.CLIENT # Should be PRODUCER - - self._check_closed() - if isinstance(event_data, EventData): - if partition_key: - event_data._set_partition_key(partition_key) # pylint: disable=protected-access - wrapper_event_data = event_data - wrapper_event_data._trace_message(child) # pylint: disable=protected-access - else: - if isinstance(event_data, EventDataBatch): - if partition_key and partition_key != event_data._partition_key: # pylint: disable=protected-access - raise EventDataError('The partition_key does not match the one of the EventDataBatch') - wrapper_event_data = event_data #type: ignore + async with self._lock: + span_impl_type = settings.tracing_implementation() # type: Type[AbstractSpan] + child = None + if span_impl_type is not None: + child = span_impl_type(name="Azure.EventHubs.send") + child.kind = SpanKind.CLIENT # Should be PRODUCER + self._check_closed() + wrapper_event_data = self._wrap_eventdata(event_data, child, partition_key) + self._unsent_events = [wrapper_event_data.message] + + if span_impl_type is not None and child is not None: + with child: + self._client._add_span_request_attributes(child) # pylint: disable=protected-access + await self._send_event_data_with_retry(timeout=timeout) # pylint:disable=unexpected-keyword-arg # TODO: to refactor else: - if partition_key: - event_data = _set_partition_key(event_data, partition_key) - event_data = _set_trace_message(event_data) - wrapper_event_data = EventDataBatch._from_batch(event_data, partition_key) # pylint: disable=protected-access - - wrapper_event_data.message.on_send_complete = self._on_outcome - self._unsent_events = [wrapper_event_data.message] - - if span_impl_type is not None: - with child: - self._client._add_span_request_attributes(child) # pylint: disable=protected-access await self._send_event_data_with_retry(timeout=timeout) # pylint:disable=unexpected-keyword-arg # TODO: to refactor - else: - await self._send_event_data_with_retry(timeout=timeout) # pylint:disable=unexpected-keyword-arg # TODO: to refactor async def close(self): # type: () -> None @@ -221,4 +211,5 @@ async def close(self): Close down the handler. If the handler has already closed, this will be a no op. 
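The rewritten `send` above funnels all three accepted input shapes through `_wrap_eventdata`: a bare `EventData` gets the partition key stamped onto its message, an `EventDataBatch` (whose key was fixed at creation) is only validated against a conflicting key, and any other iterable is folded into a batch. A rough, self-contained sketch of that dispatch using stand-in types rather than the SDK classes:

```python
from dataclasses import dataclass, field
from typing import Iterable, List, Optional, Union


@dataclass
class Event:                       # stand-in for azure.eventhub.EventData
    body: bytes
    partition_key: Optional[str] = None


@dataclass
class Batch:                       # stand-in for azure.eventhub.EventDataBatch
    partition_key: Optional[str] = None
    events: List[Event] = field(default_factory=list)


def wrap_for_send(data, partition_key=None):
    # type: (Union[Event, Batch, Iterable[Event]], Optional[str]) -> Union[Event, Batch]
    """Illustrative _wrap_eventdata-style dispatch."""
    if isinstance(data, Event):
        if partition_key:
            data.partition_key = partition_key      # stamp the key on a single event
        return data
    if isinstance(data, Batch):
        # the batch fixed its key at creation; a conflicting key is a user error
        if partition_key and partition_key != data.partition_key:
            raise ValueError("partition_key does not match the EventDataBatch")
        return data
    # any other iterable is folded into a batch under the given key
    return Batch(partition_key, [Event(e.body, partition_key) for e in data])
```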
""" - await super(EventHubProducer, self).close() + async with self._lock: + await super(EventHubProducer, self).close() diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_producer_client_async.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_producer_client_async.py index 51f1eff40cfb..d957ce58b547 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_producer_client_async.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/_producer_client_async.py @@ -7,17 +7,25 @@ from typing import Any, Union, TYPE_CHECKING, Iterable, List from uamqp import constants # type: ignore -from azure.eventhub import EventData, EventHubSharedKeyCredential, EventHubSASTokenCredential, EventDataBatch -from .client_async import EventHubClient -from .producer_async import EventHubProducer + +from ..exceptions import ConnectError, EventHubError +from ._client_base_async import ClientBaseAsync +from ._producer_async import EventHubProducer +from .._constants import ALL_PARTITIONS +from .._common import ( + EventData, + EventHubSharedKeyCredential, + EventHubSASTokenCredential, + EventDataBatch +) if TYPE_CHECKING: from azure.core.credentials import TokenCredential # type: ignore -log = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) -class EventHubProducerClient(EventHubClient): +class EventHubProducerClient(ClientBaseAsync): """ The EventHubProducerClient class defines a high level interface for sending events to the Azure Event Hubs service. @@ -56,31 +64,72 @@ def __init__(self, host, event_hub_path, credential, **kwargs) -> None: """""" super(EventHubProducerClient, self).__init__( host=host, event_hub_path=event_hub_path, credential=credential, - network_tracing=kwargs.get("logging_enable"), **kwargs) - self._producers = [] # type: List[EventHubProducer] - self._client_lock = asyncio.Lock() # sync the creation of self._producers - self._producers_locks = [] # type: List[asyncio.Lock] + network_tracing=kwargs.pop("logging_enable", False), **kwargs) + self._producers = {ALL_PARTITIONS: self._create_producer()} # type: Dict[str, EventHubProducer] + self._lock = asyncio.Lock() # sync the creation of self._producers self._max_message_size_on_link = 0 - - async def _init_locks_for_producers(self): - if not self._producers: - async with self._client_lock: - if not self._producers: - num_of_producers = len(await self.get_partition_ids()) + 1 - self._producers = [None] * num_of_producers - for _ in range(num_of_producers): - self._producers_locks.append(asyncio.Lock()) - # self._producers_locks = [asyncio.Lock()] * num_of_producers + self._partition_ids = None + + async def _get_partitions(self): + if not self._partition_ids: + self._partition_ids = await self.get_partition_ids() + for p_id in self._partition_ids: + self._producers[p_id] = None + + async def _get_max_mesage_size(self): + # pylint: disable=protected-access + async with self._lock: + if not self._max_message_size_on_link: + await self._producers[ALL_PARTITIONS]._open_with_retry() + self._max_message_size_on_link = \ + self._producers[ALL_PARTITIONS]._handler.message_handler._link.peer_max_message_size \ + or constants.MAX_MESSAGE_LENGTH_BYTES + + async def _start_producer(self, partition_id, send_timeout): + async with self._lock: + await self._get_partitions() + if partition_id not in self._partition_ids and partition_id != ALL_PARTITIONS: + raise ConnectError("Invalid partition {} for the event hub {}".format(partition_id, self.eh_name)) + + if not self._producers[partition_id] or 
self._producers[partition_id].closed: + self._producers[partition_id] = self._create_producer( + partition_id=partition_id, + send_timeout=send_timeout + ) + + def _create_producer( + self, *, + partition_id: str = None, + send_timeout: float = None, + loop: asyncio.AbstractEventLoop = None + ) -> EventHubProducer: + target = "amqps://{}{}".format(self._address.hostname, self._address.path) + send_timeout = self._config.send_timeout if send_timeout is None else send_timeout + + handler = EventHubProducer( + self, target, partition=partition_id, send_timeout=send_timeout, loop=loop) + return handler @classmethod - def from_connection_string(cls, conn_str, **kwargs): + def from_connection_string( + cls, conn_str: str, + *, + event_hub_path: str = None, + logging_enable: bool = False, + http_proxy: dict = None, + auth_timeout: float = 60, + user_agent: str = None, + retry_total: int = 3, + transport_type=None, + **kwargs): # type: (str, Any) -> EventHubProducerClient + # pylint: disable=arguments-differ """ Create an EventHubProducerClient from a connection string. :param str conn_str: The connection string of an eventhub. :keyword str event_hub_path: The path of the specific Event Hub to connect the client to. - :keyword bool network_tracing: Whether to output network trace logs to the logger. Default is `False`. + :keyword bool logging_enable: Whether to output network trace logs to the logger. Default is `False`. :keyword dict[str,Any] http_proxy: HTTP proxy settings. This must be a dictionary with the following keys - 'proxy_hostname' (str value) and 'proxy_port' (int value). Additionally the following keys may also be present - 'username', 'password'. @@ -92,6 +141,7 @@ def from_connection_string(cls, conn_str, **kwargs): :keyword transport_type: The type of transport protocol that will be used for communicating with the Event Hubs service. Default is `TransportType.Amqp`. :paramtype transport_type: ~azure.eventhub.TransportType + :rtype: ~azure.eventhub.aio.EventHubProducerClient .. admonition:: Example: @@ -102,7 +152,17 @@ def from_connection_string(cls, conn_str, **kwargs): :dedent: 4 :caption: Create a new instance of the EventHubProducerClient from connection string. 
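`send` in the next hunk leans on a cache of producers keyed by partition id, with `ALL_PARTITIONS` as the sentinel for service-routed sends: it tries the cached handler first and only takes the client lock to rebuild one when the entry is missing, `None`, or closed, which is what `_start_producer` above provides. A minimal sketch of that cache-then-recover pattern under assumed names:

```python
import asyncio

ALL_PARTITIONS = "all-partitions"      # sentinel: let the service pick the partition


class ProducerCache:
    """Illustrative per-partition producer cache; not the SDK implementation."""

    def __init__(self, factory):
        self._factory = factory        # partition_id -> freshly opened producer
        self._producers = {ALL_PARTITIONS: factory(ALL_PARTITIONS)}
        self._lock = asyncio.Lock()

    async def send(self, event, partition_id=None):
        pid = partition_id or ALL_PARTITIONS
        try:
            await self._producers[pid].send(event)     # fast path: cached producer
        except (KeyError, AttributeError):             # missing or closed (None) entry
            async with self._lock:                     # slow path: rebuild under lock
                self._producers[pid] = self._factory(pid)
            await self._producers[pid].send(event)
```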
""" - return super(EventHubProducerClient, cls).from_connection_string(conn_str, **kwargs) + return super(EventHubProducerClient, cls).from_connection_string( + conn_str, + event_hub_path=event_hub_path, + logging_enable=logging_enable, + http_proxy=http_proxy, + auth_timeout=auth_timeout, + user_agent=user_agent, + retry_total=retry_total, + transport_type=transport_type, + **kwargs + ) async def send(self, event_data, *, partition_key: Union[str, bytes] = None, partition_id: str = None, timeout: float = None) -> None: @@ -136,16 +196,12 @@ async def send(self, event_data, :caption: Asynchronously sends event data """ - - await self._init_locks_for_producers() - - producer_index = int(partition_id) if partition_id is not None else -1 - if self._producers[producer_index] is None or self._producers[producer_index]._closed: # pylint:disable=protected-access - async with self._producers_locks[producer_index]: - if self._producers[producer_index] is None: - self._producers[producer_index] = self._create_producer(partition_id=partition_id) - async with self._producers_locks[producer_index]: - await self._producers[producer_index].send(event_data, partition_key=partition_key, timeout=timeout) + partition_id = partition_id or ALL_PARTITIONS + try: + await self._producers[partition_id].send(event_data, partition_key=partition_key) + except (KeyError, AttributeError, EventHubError): + await self._start_producer(partition_id, timeout) + await self._producers[partition_id].send(event_data, partition_key=partition_key) async def create_batch(self, max_size=None): # type:(int) -> EventDataBatch @@ -167,14 +223,7 @@ async def create_batch(self, max_size=None): """ if not self._max_message_size_on_link: - await self._init_locks_for_producers() - async with self._producers_locks[-1]: - if self._producers[-1] is None: - self._producers[-1] = self._create_producer(partition_id=None) - await self._producers[-1]._open_with_retry() # pylint: disable=protected-access - async with self._client_lock: - self._max_message_size_on_link = \ - self._producers[-1]._handler.message_handler._link.peer_max_message_size or constants.MAX_MESSAGE_LENGTH_BYTES # pylint: disable=protected-access, line-too-long + await self._get_max_mesage_size() if max_size and max_size > self._max_message_size_on_link: raise ValueError('Max message size: {} is too large, acceptable max batch size is: {} bytes.' @@ -200,7 +249,8 @@ async def close(self): :caption: Close down the handler. """ - for p in self._producers: - if p: - await p.close() + async with self._lock: + for producer in self._producers.values(): + if producer: + await producer.close() await self._conn_manager.close_connection() diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/consumer_async.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/consumer_async.py deleted file mode 100644 index 02573b00b262..000000000000 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/aio/consumer_async.py +++ /dev/null @@ -1,257 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. 
-# -------------------------------------------------------------------------------------------- -import asyncio -import uuid -import logging -from typing import List, Any -import time -from distutils.version import StrictVersion - -import uamqp # type: ignore -from uamqp import errors, types, utils # type: ignore -from uamqp import ReceiveClientAsync, Source # type: ignore - -from ..common import EventData, EventPosition -from ..error import _error_handler -from ._consumer_producer_mixin_async import ConsumerProducerMixin - -log = logging.getLogger(__name__) - - -class EventHubConsumer(ConsumerProducerMixin): # pylint:disable=too-many-instance-attributes - """ - A consumer responsible for reading EventData from a specific Event Hub - partition and as a member of a specific consumer group. - - A consumer may be exclusive, which asserts ownership over the partition for the consumer - group to ensure that only one consumer from that group is reading the from the partition. - These exclusive consumers are sometimes referred to as "Epoch Consumers." - - A consumer may also be non-exclusive, allowing multiple consumers from the same consumer - group to be actively reading events from the partition. These non-exclusive consumers are - sometimes referred to as "Non-Epoch Consumers." - - Please use the method `create_consumer` on `EventHubClient` for creating `EventHubConsumer`. - """ - _timeout = 0 - _epoch_symbol = b'com.microsoft:epoch' - _timeout_symbol = b'com.microsoft:timeout' - _receiver_runtime_metric_symbol = b'com.microsoft:enable-receiver-runtime-metric' - - def __init__( # pylint: disable=super-init-not-called - self, client, source, **kwargs): - """ - Instantiate an async consumer. EventHubConsumer should be instantiated by calling the `create_consumer` method - in EventHubClient. - - :param client: The parent EventHubClientAsync. - :type client: ~azure.eventhub.aio.EventHubClientAsync - :param source: The source EventHub from which to receive events. - :type source: ~uamqp.address.Source - :param event_position: The position from which to start receiving. - :type event_position: ~azure.eventhub.common.EventPosition - :param prefetch: The number of events to prefetch from the service - for processing. Default is 300. - :type prefetch: int - :param owner_level: The priority of the exclusive consumer. An exclusive - consumer will be created if owner_level is set. - :type owner_level: int - :param track_last_enqueued_event_properties: Indicates whether or not the consumer should request information - on the last enqueued event on its associated partition, and track that information as events are received. - When information about the partition's last enqueued event is being tracked, each event received from the - Event Hubs service will carry metadata about the partition. This results in a small amount of additional - network bandwidth consumption that is generally a favorable trade-off when considered against periodically - making requests for partition properties using the Event Hub client. - It is set to `False` by default. - :type track_last_enqueued_event_properties: bool - :param loop: An event loop. 
- """ - event_position = kwargs.get("event_position", None) - prefetch = kwargs.get("prefetch", 300) - owner_level = kwargs.get("owner_level", None) - keep_alive = kwargs.get("keep_alive", None) - auto_reconnect = kwargs.get("auto_reconnect", True) - track_last_enqueued_event_properties = kwargs.get("track_last_enqueued_event_properties", False) - loop = kwargs.get("loop", None) - - super(EventHubConsumer, self).__init__() - self._loop = loop or asyncio.get_event_loop() - self._client = client - self._source = source - self._offset = event_position - self._messages_iter = None - self._prefetch = prefetch - self._owner_level = owner_level - self._keep_alive = keep_alive - self._auto_reconnect = auto_reconnect - self._retry_policy = errors.ErrorPolicy(max_retries=self._client._config.max_retries, on_error=_error_handler) # pylint:disable=protected-access - self._reconnect_backoff = 1 - self._link_properties = {} - partition = self._source.split('/')[-1] - self._partition = partition - self._name = "EHReceiver-{}-partition{}".format(uuid.uuid4(), partition) - if owner_level: - self._link_properties[types.AMQPSymbol(self._epoch_symbol)] = types.AMQPLong(int(owner_level)) - link_property_timeout_ms = (self._client._config.receive_timeout or self._timeout) * 1000 # pylint:disable=protected-access - self._link_properties[types.AMQPSymbol(self._timeout_symbol)] = types.AMQPLong(int(link_property_timeout_ms)) - self._handler = None - self._track_last_enqueued_event_properties = track_last_enqueued_event_properties - self._last_enqueued_event_properties = {} - - def __aiter__(self): - return self - - async def __anext__(self): - retried_times = 0 - last_exception = None - while retried_times < self._client._config.max_retries: # pylint:disable=protected-access - try: - await self._open() - if not self._messages_iter: - self._messages_iter = self._handler.receive_messages_iter_async() - message = await self._messages_iter.__anext__() - event_data = EventData._from_message(message) # pylint:disable=protected-access - event_data._trace_link_message() # pylint:disable=protected-access - self._offset = EventPosition(event_data.offset, inclusive=False) - retried_times = 0 - if self._track_last_enqueued_event_properties: - self._last_enqueued_event_properties = event_data._get_last_enqueued_event_properties() # pylint:disable=protected-access - return event_data - except Exception as exception: # pylint:disable=broad-except - last_exception = await self._handle_exception(exception) - await self._client._try_delay(retried_times=retried_times, last_exception=last_exception, # pylint:disable=protected-access - entity_name=self._name) - retried_times += 1 - log.info("%r operation has exhausted retry. 
Last exception: %r.", self._name, last_exception) - raise last_exception - - def _create_handler(self): - source = Source(self._source) - if self._offset is not None: - source.set_filter(self._offset._selector()) # pylint:disable=protected-access - - if StrictVersion(uamqp.__version__) < StrictVersion("1.2.3"): # backward compatible until uamqp 1.2.3 released - desired_capabilities = {} - elif self._track_last_enqueued_event_properties: - symbol_array = [types.AMQPSymbol(self._receiver_runtime_metric_symbol)] - desired_capabilities = {"desired_capabilities": utils.data_factory(types.AMQPArray(symbol_array))} - else: - desired_capabilities = {"desired_capabilities": None} - - self._handler = ReceiveClientAsync( - source, - auth=self._client._create_auth(), # pylint:disable=protected-access - debug=self._client._config.network_tracing, # pylint:disable=protected-access - prefetch=self._prefetch, - link_properties=self._link_properties, - timeout=self._timeout, - error_policy=self._retry_policy, - keep_alive_interval=self._keep_alive, - client_name=self._name, - receive_settle_mode=uamqp.constants.ReceiverSettleMode.ReceiveAndDelete, - auto_complete=False, - properties=self._client._create_properties( # pylint:disable=protected-access - self._client._config.user_agent), # pylint:disable=protected-access - **desired_capabilities, # pylint:disable=protected-access - loop=self._loop) - self._messages_iter = None - - async def _open_with_retry(self): - return await self._do_retryable_operation(self._open, operation_need_param=False) - - async def _receive(self, timeout_time=None, max_batch_size=None, **kwargs): - last_exception = kwargs.get("last_exception") - data_batch = [] - - await self._open() - remaining_time = timeout_time - time.time() - if remaining_time <= 0.0: - if last_exception: - log.info("%r receive operation timed out. (%r)", self._name, last_exception) - raise last_exception - return data_batch - - remaining_time_ms = 1000 * remaining_time - message_batch = await self._handler.receive_message_batch_async( - max_batch_size=max_batch_size, - timeout=remaining_time_ms) - for message in message_batch: - event_data = EventData._from_message(message) # pylint:disable=protected-access - data_batch.append(event_data) - event_data._trace_link_message() # pylint:disable=protected-access - - if data_batch: - self._offset = EventPosition(data_batch[-1].offset) - - if self._track_last_enqueued_event_properties and data_batch: - self._last_enqueued_event_properties = data_batch[-1]._get_last_enqueued_event_properties() # pylint:disable=protected-access - - return data_batch - - async def _receive_with_retry(self, timeout=None, max_batch_size=None, **kwargs): - return await self._do_retryable_operation(self._receive, timeout=timeout, - max_batch_size=max_batch_size, **kwargs) - - @property - def last_enqueued_event_properties(self): - """ - The latest enqueued event information. This property will be updated each time an event is received when - the receiver is created with `track_last_enqueued_event_properties` being `True`. - The dict includes following information of the partition: - - - `sequence_number` - - `offset` - - `enqueued_time` - - `retrieval_time` - - :rtype: dict or None - """ - return self._last_enqueued_event_properties if self._track_last_enqueued_event_properties else None - - @property - def queue_size(self): - # type: () -> int - """ - The current size of the unprocessed Event queue. 
- - :rtype: int - """ - # pylint: disable=protected-access - if self._handler._received_messages: - return self._handler._received_messages.qsize() - return 0 - - async def receive(self, *, max_batch_size=None, timeout=None): - # type: (Any, int, float) -> List[EventData] - """ - Receive events asynchronously from the EventHub. - - :param max_batch_size: Receive a batch of events. Batch size will - be up to the maximum specified, but will return as soon as service - returns no new events. If combined with a timeout and no events are - retrieve before the time, the result will be empty. If no batch - size is supplied, the prefetch size will be the maximum. - :type max_batch_size: int - :param timeout: The maximum wait time to build up the requested message count for the batch. - If not specified, the default wait time specified when the consumer was created will be used. - :type timeout: float - :rtype: list[~azure.eventhub.common.EventData] - :raises: ~azure.eventhub.AuthenticationError, ~azure.eventhub.ConnectError, ~azure.eventhub.ConnectionLostError, - ~azure.eventhub.EventHubError - """ - self._check_closed() - - timeout = timeout or self._client._config.receive_timeout # pylint:disable=protected-access - max_batch_size = max_batch_size or min(self._client._config.max_batch_size, self._prefetch) # pylint:disable=protected-access - - return await self._receive_with_retry(timeout=timeout, max_batch_size=max_batch_size) - - async def close(self): - # type: () -> None - """ - Close down the handler. If the handler has already closed, - this will be a no op. - """ - await super(EventHubConsumer, self).close() diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/client.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/client.py deleted file mode 100644 index 58b4e6822012..000000000000 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/client.py +++ /dev/null @@ -1,266 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals - -import logging -import datetime -import time -import functools -import threading - -from typing import Any, List, Dict, Union, TYPE_CHECKING - -import uamqp # type: ignore -from uamqp import Message # type: ignore -from uamqp import authentication # type: ignore -from uamqp import constants # type: ignore - -from .producer import EventHubProducer -from .consumer import EventHubConsumer -from .common import parse_sas_token, EventPosition -from .client_abstract import EventHubClientAbstract -from .common import EventHubSASTokenCredential, EventHubSharedKeyCredential -from ._connection_manager import get_connection_manager -from .error import _handle_exception - -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential # type: ignore - -log = logging.getLogger(__name__) - - -class EventHubClient(EventHubClientAbstract): - """ - The EventHubClient class defines a high level interface for sending - events to and receiving events from the Azure Event Hubs service. 
- """ - - def __init__(self, host, event_hub_path, credential, **kwargs): - # type:(str, str, Union[EventHubSharedKeyCredential, EventHubSASTokenCredential, TokenCredential], Any) -> None - super(EventHubClient, self).__init__(host=host, event_hub_path=event_hub_path, credential=credential, **kwargs) - self._lock = threading.RLock() - self._conn_manager = get_connection_manager(**kwargs) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def _create_auth(self): - """ - Create an ~uamqp.authentication.SASTokenAuth instance to authenticate - the session. - """ - http_proxy = self._config.http_proxy - transport_type = self._config.transport_type - auth_timeout = self._config.auth_timeout - - # TODO: the following code can be refactored to create auth from classes directly instead of using if-else - if isinstance(self._credential, EventHubSharedKeyCredential): # pylint:disable=no-else-return - username = self._credential.policy - password = self._credential.key - if "@sas.root" in username: - return authentication.SASLPlain( - self._host, username, password, http_proxy=http_proxy, transport_type=transport_type) - return authentication.SASTokenAuth.from_shared_access_key( - self._auth_uri, username, password, timeout=auth_timeout, http_proxy=http_proxy, - transport_type=transport_type) - - elif isinstance(self._credential, EventHubSASTokenCredential): - token = self._credential.get_sas_token() - try: - expiry = int(parse_sas_token(token)['se']) - except (KeyError, TypeError, IndexError): - raise ValueError("Supplied SAS token has no valid expiry value.") - return authentication.SASTokenAuth( - self._auth_uri, self._auth_uri, token, - expires_at=expiry, - timeout=auth_timeout, - http_proxy=http_proxy, - transport_type=transport_type) - - else: # Azure credential - get_jwt_token = functools.partial(self._credential.get_token, - 'https://eventhubs.azure.net//.default') - return authentication.JWTTokenAuth(self._auth_uri, self._auth_uri, - get_jwt_token, http_proxy=http_proxy, - transport_type=transport_type) - - def _close_connection(self): - self._conn_manager.reset_connection_if_broken() - - def _try_delay(self, retried_times, last_exception, timeout_time=None, entity_name=None): - entity_name = entity_name or self._container_id - backoff = self._config.backoff_factor * 2 ** retried_times - if backoff <= self._config.backoff_max and ( - timeout_time is None or time.time() + backoff <= timeout_time): # pylint:disable=no-else-return - time.sleep(backoff) - log.info("%r has an exception (%r). Retrying...", format(entity_name), last_exception) - else: - log.info("%r operation has timed out. 
Last exception before timeout is (%r)", - entity_name, last_exception) - raise last_exception - - def _management_request(self, mgmt_msg, op_type): - retried_times = 0 - last_exception = None - while retried_times <= self._config.max_retries: - mgmt_auth = self._create_auth() - mgmt_client = uamqp.AMQPClient(self._mgmt_target) - try: - conn = self._conn_manager.get_connection(self._host, mgmt_auth) #pylint:disable=assignment-from-none - mgmt_client.open(connection=conn) - response = mgmt_client.mgmt_request( - mgmt_msg, - constants.READ_OPERATION, - op_type=op_type, - status_code_field=b'status-code', - description_fields=b'status-description') - return response - except Exception as exception: # pylint: disable=broad-except - last_exception = _handle_exception(exception, self) - self._try_delay(retried_times=retried_times, last_exception=last_exception) - retried_times += 1 - finally: - mgmt_client.close() - log.info("%r returns an exception %r", self._container_id, last_exception) # pylint:disable=specify-parameter-names-in-call - raise last_exception - - def get_properties(self): - # type:() -> Dict[str, Any] - """ - Get properties of the specified EventHub async. - Keys in the details dictionary include: - - - path - - created_at - - partition_ids - - :rtype: dict - :raises: :class:`EventHubError` - """ - mgmt_msg = Message(application_properties={'name': self.eh_name}) - response = self._management_request(mgmt_msg, op_type=b'com.microsoft:eventhub') - output = {} - eh_info = response.get_data() - if eh_info: - output['path'] = eh_info[b'name'].decode('utf-8') - output['created_at'] = datetime.datetime.utcfromtimestamp(float(eh_info[b'created_at']) / 1000) - output['partition_ids'] = [p.decode('utf-8') for p in eh_info[b'partition_ids']] - return output - - def get_partition_ids(self): - # type:() -> List[str] - """ - Get partition ids of the specified EventHub. - - :rtype: list[str] - :raises: :class:`EventHubError` - """ - return self.get_properties()['partition_ids'] - - def get_partition_properties(self, partition): - # type:(str) -> Dict[str, Any] - """ - Get properties of the specified partition async. - Keys in the details dictionary include: - - - event_hub_path - - id - - beginning_sequence_number - - last_enqueued_sequence_number - - last_enqueued_offset - - last_enqueued_time_utc - - is_empty - - :param partition: The target partition id. 
- :type partition: str - :rtype: dict - :raises: :class:`EventHubError` - """ - mgmt_msg = Message(application_properties={'name': self.eh_name, - 'partition': partition}) - response = self._management_request(mgmt_msg, op_type=b'com.microsoft:partition') - partition_info = response.get_data() - output = {} - if partition_info: - output['event_hub_path'] = partition_info[b'name'].decode('utf-8') - output['id'] = partition_info[b'partition'].decode('utf-8') - output['beginning_sequence_number'] = partition_info[b'begin_sequence_number'] - output['last_enqueued_sequence_number'] = partition_info[b'last_enqueued_sequence_number'] - output['last_enqueued_offset'] = partition_info[b'last_enqueued_offset'].decode('utf-8') - output['last_enqueued_time_utc'] = datetime.datetime.utcfromtimestamp( - float(partition_info[b'last_enqueued_time_utc'] / 1000)) - output['is_empty'] = partition_info[b'is_partition_empty'] - return output - - def _create_consumer(self, consumer_group, partition_id, event_position, **kwargs): - # type: (str, str, EventPosition, Any) -> EventHubConsumer - """ - Create a consumer to the client for a particular consumer group and partition. - - :param consumer_group: The name of the consumer group this consumer is associated with. - Events are read in the context of this group. The default consumer_group for an event hub is "$Default". - :type consumer_group: str - :param partition_id: The identifier of the Event Hub partition from which events will be received. - :type partition_id: str - :param event_position: The position within the partition where the consumer should begin reading events. - :type event_position: ~azure.eventhub.common.EventPosition - :param owner_level: The priority of the exclusive consumer. The client will create an exclusive - consumer if owner_level is set. - :type owner_level: int - :param prefetch: The message prefetch count of the consumer. Default is 300. - :type prefetch: int - :param track_last_enqueued_event_properties: Indicates whether or not the consumer should request information - on the last enqueued event on its associated partition, and track that information as events are received. - When information about the partition's last enqueued event is being tracked, each event received from the - Event Hubs service will carry metadata about the partition. This results in a small amount of additional - network bandwidth consumption that is generally a favorable trade-off when considered against periodically - making requests for partition properties using the Event Hub client. - It is set to `False` by default. - :type track_last_enqueued_event_properties: bool - :rtype: ~azure.eventhub.consumer.EventHubConsumer - """ - owner_level = kwargs.get("owner_level") - prefetch = kwargs.get("prefetch") or self._config.prefetch - track_last_enqueued_event_properties = kwargs.get("track_last_enqueued_event_properties", False) - - source_url = "amqps://{}{}/ConsumerGroups/{}/Partitions/{}".format( - self._address.hostname, self._address.path, consumer_group, partition_id) - handler = EventHubConsumer( - self, source_url, event_position=event_position, owner_level=owner_level, - prefetch=prefetch, - track_last_enqueued_event_properties=track_last_enqueued_event_properties) - return handler - - def _create_producer(self, partition_id=None, send_timeout=None): - # type: (str, float) -> EventHubProducer - """ - Create an producer to send EventData object to an EventHub. - - :param partition_id: Optionally specify a particular partition to send to. 
- If omitted, the events will be distributed to available partitions via - round-robin. - :type partition_id: str - :param operation: An optional operation to be appended to the hostname in the target URL. - The value must start with `/` character. - :type operation: str - :param send_timeout: The timeout in seconds for an individual event to be sent from the time that it is - queued. Default value is 60 seconds. If set to 0, there will be no timeout. - :type send_timeout: float - :rtype: ~azure.eventhub.producer.EventHubProducer - """ - - target = "amqps://{}{}".format(self._address.hostname, self._address.path) - send_timeout = self._config.send_timeout if send_timeout is None else send_timeout - - handler = EventHubProducer( - self, target, partition=partition_id, send_timeout=send_timeout) - return handler - - def close(self): - # type:() -> None - self._conn_manager.close_connection() diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/client_abstract.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/client_abstract.py deleted file mode 100644 index 6a8979369c90..000000000000 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/client_abstract.py +++ /dev/null @@ -1,217 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals - -import logging -import sys -import platform -import uuid -import time -from abc import abstractmethod -from typing import Union, Any, TYPE_CHECKING - -from uamqp import types # type: ignore -from azure.eventhub import __version__ -from .configuration import Configuration -from .common import EventHubSharedKeyCredential, EventHubSASTokenCredential, _Address - -try: - from urlparse import urlparse # type: ignore - from urllib import urlencode, quote_plus # type: ignore -except ImportError: - from urllib.parse import urlparse, urlencode, quote_plus - -if TYPE_CHECKING: - from azure.core.credentials import TokenCredential # type: ignore - -log = logging.getLogger(__name__) -MAX_USER_AGENT_LENGTH = 512 - - -def _parse_conn_str(conn_str): - endpoint = None - shared_access_key_name = None - shared_access_key = None - entity_path = None - for element in conn_str.split(';'): - key, _, value = element.partition('=') - if key.lower() == 'endpoint': - endpoint = value.rstrip('/') - elif key.lower() == 'hostname': - endpoint = value.rstrip('/') - elif key.lower() == 'sharedaccesskeyname': - shared_access_key_name = value - elif key.lower() == 'sharedaccesskey': - shared_access_key = value - elif key.lower() == 'entitypath': - entity_path = value - if not all([endpoint, shared_access_key_name, shared_access_key]): - raise ValueError("Invalid connection string") - return endpoint, shared_access_key_name, shared_access_key, entity_path - - -def _generate_sas_token(uri, policy, key, expiry=None): - """Create a shared access signiture token as a string literal. - :returns: SAS token as string literal. - :rtype: str - """ - from base64 import b64encode, b64decode - from hashlib import sha256 - from hmac import HMAC - if not expiry: - expiry = time.time() + 3600 # Default to 1 hour. 
- encoded_uri = quote_plus(uri) - ttl = int(expiry) - sign_key = '%s\n%d' % (encoded_uri, ttl) - signature = b64encode(HMAC(b64decode(key), sign_key.encode('utf-8'), sha256).digest()) - result = { - 'sr': uri, - 'sig': signature, - 'se': str(ttl)} - if policy: - result['skn'] = policy - return 'SharedAccessSignature ' + urlencode(result) - - -def _build_uri(address, entity): - parsed = urlparse(address) - if parsed.path: - return address - if not entity: - raise ValueError("No EventHub specified") - address += "/" + str(entity) - return address - - -class EventHubClientAbstract(object): # pylint:disable=too-many-instance-attributes - """ - The EventHubClientAbstract class defines a high level interface for sending - events to and receiving events from the Azure Event Hubs service. - """ - - def __init__(self, host, event_hub_path, credential, **kwargs): - # type:(str, str, Union[EventHubSharedKeyCredential, EventHubSASTokenCredential, TokenCredential], Any) -> None - """ - :param host: The hostname of the Event Hub. - :type host: str - :param event_hub_path: The path of the specific Event Hub to connect the client to. - :type event_hub_path: str - :param network_tracing: Whether to output network trace logs to the logger. Default - is `False`. - :type network_tracing: bool - :param credential: The credential object used for authentication which implements particular interface - of getting tokens. It accepts ~azure.eventhub.EventHubSharedKeyCredential, - ~azure.eventhub.EventHubSASTokenCredential, credential objects generated by the azure-identity library and - objects that implement `get_token(self, *scopes)` method. - :param http_proxy: HTTP proxy settings. This must be a dictionary with the following - keys - 'proxy_hostname' (str value) and 'proxy_port' (int value). - Additionally the following keys may also be present - 'username', 'password'. - :type http_proxy: dict[str, Any] - :param auth_timeout: The time in seconds to wait for a token to be authorized by the service. - The default value is 60 seconds. If set to 0, no timeout will be enforced from the client. - :type auth_timeout: float - :param user_agent: The user agent that needs to be appended to the built in user agent string. - :type user_agent: str - :param retry_total: The total number of attempts to redo the failed operation when an error happened. Default - value is 3. - :type retry_total: int - :param transport_type: The type of transport protocol that will be used for communicating with - the Event Hubs service. Default is ~azure.eventhub.TransportType.Amqp. - :type transport_type: ~azure.eventhub.TransportType - """ - self.eh_name = event_hub_path - self._host = host - self._container_id = "eventhub.pysdk-" + str(uuid.uuid4())[:8] - self._address = _Address() - self._address.hostname = host - self._address.path = "/" + event_hub_path if event_hub_path else "" - self._credential = credential - self._keep_alive = kwargs.get("keep_alive", 30) - self._auto_reconnect = kwargs.get("auto_reconnect", True) - self._mgmt_target = "amqps://{}/{}".format(self._host, self.eh_name) - self._auth_uri = "sb://{}{}".format(self._address.hostname, self._address.path) - self._config = Configuration(**kwargs) - self._debug = self._config.network_tracing - - log.info("%r: Created the Event Hub client", self._container_id) - - @abstractmethod - def _create_auth(self): - pass - - def _create_properties(self, user_agent=None): # pylint: disable=no-self-use - """ - Format the properties with which to instantiate the connection. 
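The removed `_generate_sas_token` helper implements the standard Service Bus SAS scheme: HMAC-SHA256 over `<url-encoded uri>\n<expiry>` keyed with the base64-decoded access key. A self-contained sketch of the same computation, with entirely made-up key material and the signature decoded so `urlencode` emits plain text:

```python
import time
from base64 import b64decode, b64encode
from hashlib import sha256
from hmac import HMAC
from urllib.parse import quote_plus, urlencode


def generate_sas_token(uri, policy, key, expiry=None):
    """Sign '<url-encoded uri>\\n<expiry>' with HMAC-SHA256, as the removed helper does."""
    if not expiry:
        expiry = time.time() + 3600  # default to 1 hour
    encoded_uri = quote_plus(uri)
    ttl = int(expiry)
    sign_key = '%s\n%d' % (encoded_uri, ttl)
    signature = b64encode(HMAC(b64decode(key), sign_key.encode('utf-8'), sha256).digest()).decode('utf-8')
    result = {'sr': uri, 'sig': signature, 'se': str(ttl)}
    if policy:
        result['skn'] = policy
    return 'SharedAccessSignature ' + urlencode(result)


# Made-up namespace and key, for illustration only.
fake_key = b64encode(b'0' * 32).decode('utf-8')
print(generate_sas_token('sb://myns.servicebus.windows.net/myhub', 'RootManageSharedAccessKey', fake_key))
```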
- This acts like a user agent over HTTP. - - :rtype: dict - """ - properties = {} - product = "azsdk-python-eventhubs" - properties[types.AMQPSymbol("product")] = product - properties[types.AMQPSymbol("version")] = __version__ - framework = "Python {}.{}.{}, {}".format( - sys.version_info[0], sys.version_info[1], sys.version_info[2], platform.python_implementation() - ) - properties[types.AMQPSymbol("framework")] = framework - platform_str = platform.platform() - properties[types.AMQPSymbol("platform")] = platform_str - - final_user_agent = '{}/{} ({}, {})'.format(product, __version__, framework, platform_str) - if user_agent: - final_user_agent = '{}, {}'.format(final_user_agent, user_agent) - - if len(final_user_agent) > MAX_USER_AGENT_LENGTH: - raise ValueError("The user-agent string cannot be more than {} in length." - "Current user_agent string is: {} with length: {}".format( - MAX_USER_AGENT_LENGTH, final_user_agent, len(final_user_agent))) - properties[types.AMQPSymbol("user-agent")] = final_user_agent - return properties - - def _add_span_request_attributes(self, span): - span.add_attribute("component", "eventhubs") - span.add_attribute("message_bus.destination", self._address.path) - span.add_attribute("peer.address", self._address.hostname) - - @classmethod - def from_connection_string(cls, conn_str, **kwargs): - """ - Create an EventHubProducerClient/EventHubConsumerClient from a connection string. - - :param str conn_str: The connection string of an eventhub. - :keyword str event_hub_path: The path of the specific Event Hub to connect the client to. - :keyword credential: The credential object used for authentication which implements particular interface - of getting tokens. It accepts ~azure.eventhub.EventHubSharedKeyCredential, - ~azure.eventhub.EventHubSASTokenCredential, credential objects generated by the azure-identity library and - objects that implement `get_token(self, *scopes)` method. - :keyword bool network_tracing: Whether to output network trace logs to the logger. Default is `False`. - :keyword dict[str, Any] http_proxy: HTTP proxy settings. This must be a dictionary with the following - keys - 'proxy_hostname' (str value) and 'proxy_port' (int value). - Additionally the following keys may also be present - 'username', 'password'. - :keyword float auth_timeout: The time in seconds to wait for a token to be authorized by the service. - The default value is 60 seconds. If set to 0, no timeout will be enforced from the client. - :keyword str user_agent: The user agent that needs to be appended to the built in user agent string. - :keyword int retry_total: The total number of attempts to redo the failed operation when an error happened. - Default value is 3. - :param transport_type: The type of transport protocol that will be used for communicating with - the Event Hubs service. Default is ~azure.eventhub.TransportType.Amqp. - :type transport_type: ~azure.eventhub.TransportType - :keyword partition_manager: **Only for EventHubConsumerClient** - stores the load balancing data and checkpoint data when receiving events - if partition_manager is specified. If it's None, this EventHubConsumerClient instance will receive - events without load balancing and checkpoint. - :paramtype partition_manager: Implementation classes of ~azure.eventhub.aio.PartitionManager - :keyword float load_balancing_interval: **Only for EventHubConsumerClient** - When load balancing kicks in, this is the interval in seconds between two load balancing. Default is 10. 
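The `_create_properties` logic above stamps the AMQP connection with product, version, framework, and platform details and enforces the 512-character cap on the combined user agent. A runnable sketch of just that string-building step, with a hard-coded stand-in for `azure.eventhub.__version__`:

```python
import platform
import sys

MAX_USER_AGENT_LENGTH = 512
VERSION = "5.0.0b6"  # stand-in for azure.eventhub.__version__


def build_user_agent(custom_user_agent=None):
    """Mirror the user-agent portion of _create_properties."""
    product = "azsdk-python-eventhubs"
    framework = "Python {}.{}.{}, {}".format(
        sys.version_info[0], sys.version_info[1], sys.version_info[2],
        platform.python_implementation())
    user_agent = '{}/{} ({}, {})'.format(product, VERSION, framework, platform.platform())
    if custom_user_agent:
        user_agent = '{}, {}'.format(user_agent, custom_user_agent)
    if len(user_agent) > MAX_USER_AGENT_LENGTH:
        raise ValueError("The user-agent string cannot be more than {} in length.".format(MAX_USER_AGENT_LENGTH))
    return user_agent


print(build_user_agent("my-app/1.0"))
```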
- """ - event_hub_path = kwargs.pop("event_hub_path", None) - address, policy, key, entity = _parse_conn_str(conn_str) - entity = event_hub_path or entity - left_slash_pos = address.find("//") - if left_slash_pos != -1: - host = address[left_slash_pos + 2:] - else: - host = address - return cls(host, entity, EventHubSharedKeyCredential(policy, key), **kwargs) diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/consumer.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/consumer.py deleted file mode 100644 index 0550bb618a74..000000000000 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/consumer.py +++ /dev/null @@ -1,254 +0,0 @@ -# -------------------------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for license information. -# -------------------------------------------------------------------------------------------- -from __future__ import unicode_literals - -import uuid -import logging -import time -from typing import List -from distutils.version import StrictVersion - -import uamqp # type: ignore -from uamqp import types, errors, utils # type: ignore -from uamqp import ReceiveClient, Source # type: ignore - -from .common import EventData, EventPosition -from .error import _error_handler -from ._consumer_producer_mixin import ConsumerProducerMixin - - -log = logging.getLogger(__name__) - - -class EventHubConsumer(ConsumerProducerMixin): # pylint:disable=too-many-instance-attributes - """ - A consumer responsible for reading EventData from a specific Event Hub - partition and as a member of a specific consumer group. - - A consumer may be exclusive, which asserts ownership over the partition for the consumer - group to ensure that only one consumer from that group is reading the from the partition. - These exclusive consumers are sometimes referred to as "Epoch Consumers." - - A consumer may also be non-exclusive, allowing multiple consumers from the same consumer - group to be actively reading events from the partition. These non-exclusive consumers are - sometimes referred to as "Non-Epoch Consumers." - - Please use the method `create_consumer` on `EventHubClient` for creating `EventHubConsumer`. - """ - _timeout = 0 - _epoch_symbol = b'com.microsoft:epoch' - _timeout_symbol = b'com.microsoft:timeout' - _receiver_runtime_metric_symbol = b'com.microsoft:enable-receiver-runtime-metric' - - def __init__(self, client, source, **kwargs): - """ - Instantiate a consumer. EventHubConsumer should be instantiated by calling the `create_consumer` method - in EventHubClient. - - :param client: The parent EventHubClient. - :type client: ~azure.eventhub.client.EventHubClient - :param source: The source EventHub from which to receive events. - :type source: str - :param prefetch: The number of events to prefetch from the service - for processing. Default is 300. - :type prefetch: int - :param owner_level: The priority of the exclusive consumer. An exclusive - consumer will be created if owner_level is set. - :type owner_level: int - :param track_last_enqueued_event_properties: Indicates whether or not the consumer should request information - on the last enqueued event on its associated partition, and track that information as events are received. - When information about the partition's last enqueued event is being tracked, each event received from the - Event Hubs service will carry metadata about the partition. 
This results in a small amount of additional - network bandwidth consumption that is generally a favorable trade-off when considered against periodically - making requests for partition properties using the Event Hub client. - It is set to `False` by default. - :type track_last_enqueued_event_properties: bool - """ - event_position = kwargs.get("event_position", None) - prefetch = kwargs.get("prefetch", 300) - owner_level = kwargs.get("owner_level", None) - keep_alive = kwargs.get("keep_alive", None) - auto_reconnect = kwargs.get("auto_reconnect", True) - track_last_enqueued_event_properties = kwargs.get("track_last_enqueued_event_properties", False) - - super(EventHubConsumer, self).__init__() - self._client = client - self._source = source - self._offset = event_position - self._messages_iter = None - self._prefetch = prefetch - self._owner_level = owner_level - self._keep_alive = keep_alive - self._auto_reconnect = auto_reconnect - self._retry_policy = errors.ErrorPolicy(max_retries=self._client._config.max_retries, on_error=_error_handler) # pylint:disable=protected-access - self._reconnect_backoff = 1 - self._link_properties = {} - self._error = None - partition = self._source.split('/')[-1] - self._partition = partition - self._name = "EHConsumer-{}-partition{}".format(uuid.uuid4(), partition) - if owner_level: - self._link_properties[types.AMQPSymbol(self._epoch_symbol)] = types.AMQPLong(int(owner_level)) - link_property_timeout_ms = (self._client._config.receive_timeout or self._timeout) * 1000 # pylint:disable=protected-access - self._link_properties[types.AMQPSymbol(self._timeout_symbol)] = types.AMQPLong(int(link_property_timeout_ms)) - self._handler = None - self._track_last_enqueued_event_properties = track_last_enqueued_event_properties - self._last_enqueued_event_properties = {} - - def __iter__(self): - return self - - def __next__(self): - retried_times = 0 - last_exception = None - while retried_times < self._client._config.max_retries: # pylint:disable=protected-access - try: - self._open() - if not self._messages_iter: - self._messages_iter = self._handler.receive_messages_iter() - message = next(self._messages_iter) - event_data = EventData._from_message(message) # pylint:disable=protected-access - event_data._trace_link_message() # pylint:disable=protected-access - self._offset = EventPosition(event_data.offset, inclusive=False) - retried_times = 0 - if self._track_last_enqueued_event_properties: - self._last_enqueued_event_properties = event_data._get_last_enqueued_event_properties() # pylint:disable=protected-access - return event_data - except Exception as exception: # pylint:disable=broad-except - last_exception = self._handle_exception(exception) - self._client._try_delay(retried_times=retried_times, last_exception=last_exception, # pylint:disable=protected-access - entity_name=self._name) - retried_times += 1 - log.info("%r operation has exhausted retry. 
Last exception: %r.", self._name, last_exception) - raise last_exception - - def _create_handler(self): - source = Source(self._source) - if self._offset is not None: - source.set_filter(self._offset._selector()) # pylint:disable=protected-access - - if StrictVersion(uamqp.__version__) < StrictVersion("1.2.3"): # backward compatible until uamqp 1.2.3 released - desired_capabilities = {} - elif self._track_last_enqueued_event_properties: - symbol_array = [types.AMQPSymbol(self._receiver_runtime_metric_symbol)] - desired_capabilities = {"desired_capabilities": utils.data_factory(types.AMQPArray(symbol_array))} - else: - desired_capabilities = {"desired_capabilities": None} - - self._handler = ReceiveClient( - source, - auth=self._client._create_auth(), # pylint:disable=protected-access - debug=self._client._config.network_tracing, # pylint:disable=protected-access - prefetch=self._prefetch, - link_properties=self._link_properties, - timeout=self._timeout, - error_policy=self._retry_policy, - keep_alive_interval=self._keep_alive, - client_name=self._name, - receive_settle_mode=uamqp.constants.ReceiverSettleMode.ReceiveAndDelete, - auto_complete=False, - properties=self._client._create_properties( # pylint:disable=protected-access - self._client._config.user_agent), # pylint:disable=protected-access - **desired_capabilities) # pylint:disable=protected-access - self._messages_iter = None - - def _open_with_retry(self): - return self._do_retryable_operation(self._open, operation_need_param=False) - - def _receive(self, timeout_time=None, max_batch_size=None, **kwargs): - last_exception = kwargs.get("last_exception") - data_batch = [] - - self._open() - remaining_time = timeout_time - time.time() - if remaining_time <= 0.0: - if last_exception: - log.info("%r receive operation timed out. (%r)", self._name, last_exception) - raise last_exception - return data_batch - remaining_time_ms = 1000 * remaining_time - message_batch = self._handler.receive_message_batch( - max_batch_size=max_batch_size, - timeout=remaining_time_ms) - for message in message_batch: - event_data = EventData._from_message(message) # pylint:disable=protected-access - data_batch.append(event_data) - event_data._trace_link_message() # pylint:disable=protected-access - - if data_batch: - self._offset = EventPosition(data_batch[-1].offset) - - if self._track_last_enqueued_event_properties and data_batch: - self._last_enqueued_event_properties = data_batch[-1]._get_last_enqueued_event_properties() # pylint:disable=protected-access - - return data_batch - - def _receive_with_retry(self, timeout=None, max_batch_size=None, **kwargs): - return self._do_retryable_operation(self._receive, timeout=timeout, - max_batch_size=max_batch_size, **kwargs) - - @property - def last_enqueued_event_properties(self): - """ - The latest enqueued event information. This property will be updated each time an event is received when - the receiver is created with `track_last_enqueued_event_properties` being `True`. - The dict includes following information of the partition: - - - `sequence_number` - - `offset` - - `enqueued_time` - - `retrieval_time` - - :rtype: dict or None - """ - return self._last_enqueued_event_properties if self._track_last_enqueued_event_properties else None - - @property - def queue_size(self): - # type:() -> int - """ - The current size of the unprocessed Event queue. 
- - :rtype: int - """ - # pylint: disable=protected-access - if self._handler._received_messages: - return self._handler._received_messages.qsize() - return 0 - - def receive(self, max_batch_size=None, timeout=None): - # type: (int, float) -> List[EventData] - """ - Receive events from the EventHub. - - :param max_batch_size: Receive a batch of events. Batch size will - be up to the maximum specified, but will return as soon as service - returns no new events. If combined with a timeout and no events are - retrieve before the time, the result will be empty. If no batch - size is supplied, the prefetch size will be the maximum. - :type max_batch_size: int - :param timeout: The maximum wait time to build up the requested message count for the batch. - If not specified, the default wait time specified when the consumer was created will be used. - :type timeout: float - :rtype: list[~azure.eventhub.common.EventData] - :raises: ~azure.eventhub.AuthenticationError, ~azure.eventhub.ConnectError, ~azure.eventhub.ConnectionLostError, - ~azure.eventhub.EventHubError - """ - self._check_closed() - - timeout = timeout or self._client._config.receive_timeout # pylint:disable=protected-access - max_batch_size = max_batch_size or min(self._client._config.max_batch_size, self._prefetch) # pylint:disable=protected-access - - return self._receive_with_retry(timeout=timeout, max_batch_size=max_batch_size) - - def close(self): # pylint:disable=useless-super-delegation - # type:() -> None - """ - Close down the handler. If the handler has already closed, - this will be a no op. - """ - super(EventHubConsumer, self).close() - - next = __next__ # for python2.7 diff --git a/sdk/eventhub/azure-eventhubs/azure/eventhub/error.py b/sdk/eventhub/azure-eventhubs/azure/eventhub/exceptions.py similarity index 91% rename from sdk/eventhub/azure-eventhubs/azure/eventhub/error.py rename to sdk/eventhub/azure-eventhubs/azure/eventhub/exceptions.py index 755925bca743..1c01721ca2f4 100644 --- a/sdk/eventhub/azure-eventhubs/azure/eventhub/error.py +++ b/sdk/eventhub/azure-eventhubs/azure/eventhub/exceptions.py @@ -7,16 +7,9 @@ from uamqp import errors, compat # type: ignore +from ._constants import NO_RETRY_ERRORS -_NO_RETRY_ERRORS = ( - b"com.microsoft:argument-out-of-range", - b"com.microsoft:entity-disabled", - b"com.microsoft:auth-failed", - b"com.microsoft:precondition-failed", - b"com.microsoft:argument-error" -) - -log = logging.getLogger(__name__) +_LOGGER = logging.getLogger(__name__) def _error_handler(error): @@ -38,7 +31,7 @@ def _error_handler(error): return errors.ErrorAction(retry=True) if error.condition == b"com.microsoft:container-close": return errors.ErrorAction(retry=True, backoff=4) - if error.condition in _NO_RETRY_ERRORS: + if error.condition in NO_RETRY_ERRORS: return errors.ErrorAction(retry=False) return errors.ErrorAction(retry=True) @@ -97,41 +90,32 @@ def _parse_error(self, error_list): self.details = details -class ConnectionLostError(EventHubError): - """Connection to event hub is lost. SDK will retry. So this shouldn't happen. +class ClientClosedError(EventHubError): + """The Client has been closed and is unable to process further events.""" - """ + +class ConnectionLostError(EventHubError): + """Connection to event hub is lost. SDK will retry. 
So this shouldn't happen.""" class ConnectError(EventHubError): - """Fail to connect to event hubs - - """ + """Fail to connect to event hubs.""" class AuthenticationError(ConnectError): - """Fail to connect to event hubs because of authentication problem - - - """ + """Fail to connect to event hubs because of authentication problem.""" class EventDataError(EventHubError): - """Problematic event data so the send will fail at client side - - """ + """Problematic event data so the send will fail at client side.""" class EventDataSendError(EventHubError): - """Service returns error while an event data is being sent - - """ + """Service returns error while an event data is being sent.""" class OperationTimeoutError(EventHubError): - """Operation times out - - """ + """Operation times out.""" def _create_eventhub_exception(exception): @@ -162,7 +146,7 @@ def _handle_exception(exception, closable): # pylint:disable=too-many-branches, except AttributeError: # closable is an client object name = closable._container_id # pylint: disable=protected-access if isinstance(exception, KeyboardInterrupt): # pylint:disable=no-else-raise - log.info("%r stops due to keyboard interrupt", name) + _LOGGER.info("%r stops due to keyboard interrupt", name) closable.close() raise exception elif isinstance(exception, EventHubError): @@ -176,11 +160,11 @@ def _handle_exception(exception, closable): # pylint:disable=too-many-branches, errors.MessageReleased, errors.MessageContentTooLarge) ): - log.info("%r Event data error (%r)", name, exception) + _LOGGER.info("%r Event data error (%r)", name, exception) error = EventDataError(str(exception), exception) raise error elif isinstance(exception, errors.MessageException): - log.info("%r Event data send error (%r)", name, exception) + _LOGGER.info("%r Event data send error (%r)", name, exception) error = EventDataSendError(str(exception), exception) raise error else: diff --git a/sdk/eventhub/azure-eventhubs/conftest.py b/sdk/eventhub/azure-eventhubs/conftest.py index b5e07e8915ca..39ff979bd0b7 100644 --- a/sdk/eventhub/azure-eventhubs/conftest.py +++ b/sdk/eventhub/azure-eventhubs/conftest.py @@ -21,9 +21,15 @@ collect_ignore.append("samples/async_samples") collect_ignore.append("examples/async_examples") -from azure.eventhub.client import EventHubClient +# from azure.eventhub.client import EventHubClient +from azure.eventhub import EventHubConsumerClient +from azure.eventhub import EventHubProducerClient from azure.eventhub import EventPosition +import uamqp +from uamqp import authentication +PARTITION_COUNT = 2 +CONN_STR = "Endpoint=sb://{}/;SharedAccessKeyName={};SharedAccessKey={};EntityPath={}" def pytest_addoption(parser): parser.addoption( @@ -65,7 +71,7 @@ def get_logger(filename, level=logging.INFO): def create_eventhub(eventhub_config, client=None): from azure.servicebus.control_client import ServiceBusService, EventHub hub_name = str(uuid.uuid4()) - hub_value = EventHub(partition_count=2) + hub_value = EventHub(partition_count=PARTITION_COUNT) client = client or ServiceBusService( service_namespace=eventhub_config['namespace'], shared_access_key_name=eventhub_config['key_name'], @@ -95,6 +101,7 @@ def live_eventhub_config(): config['namespace'] = os.environ['EVENT_HUB_NAMESPACE'] config['consumer_group'] = "$Default" config['partition'] = "0" + config['connection_str'] = CONN_STR except KeyError: pytest.skip("Live EventHub configuration not found.") else: @@ -123,7 +130,7 @@ def live_eventhub(live_eventhub_config): # pylint: disable=redefined-outer-name 
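`exceptions.py` now imports `NO_RETRY_ERRORS` from `_constants` instead of defining `_NO_RETRY_ERRORS` locally. `_constants.py` itself is not shown in this patch, but given the tuple deleted above, the moved constant presumably reads:

```python
# Presumed contents in azure/eventhub/_constants.py (file not shown in this diff);
# the values are exactly the tuple removed from error.py above.
NO_RETRY_ERRORS = (
    b"com.microsoft:argument-out-of-range",
    b"com.microsoft:entity-disabled",
    b"com.microsoft:auth-failed",
    b"com.microsoft:precondition-failed",
    b"com.microsoft:argument-error"
)
```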
@pytest.fixture() def connection_str(live_eventhub): - return "Endpoint=sb://{}/;SharedAccessKeyName={};SharedAccessKey={};EntityPath={}".format( + return CONN_STR.format( live_eventhub['hostname'], live_eventhub['key_name'], live_eventhub['access_key'], @@ -132,7 +139,8 @@ def connection_str(live_eventhub): @pytest.fixture() def invalid_hostname(live_eventhub_config): - return "Endpoint=sb://invalid123.servicebus.windows.net/;SharedAccessKeyName={};SharedAccessKey={};EntityPath={}".format( + return CONN_STR.format( + "invalid123.servicebus.windows.net", live_eventhub_config['key_name'], live_eventhub_config['access_key'], live_eventhub_config['event_hub']) @@ -140,16 +148,18 @@ def invalid_hostname(live_eventhub_config): @pytest.fixture() def invalid_key(live_eventhub_config): - return "Endpoint=sb://{}/;SharedAccessKeyName={};SharedAccessKey=invalid;EntityPath={}".format( + return CONN_STR.format( live_eventhub_config['hostname'], live_eventhub_config['key_name'], + "invalid", live_eventhub_config['event_hub']) @pytest.fixture() def invalid_policy(live_eventhub_config): - return "Endpoint=sb://{}/;SharedAccessKeyName=invalid;SharedAccessKey={};EntityPath={}".format( + return CONN_STR.format( live_eventhub_config['hostname'], + "invalid", live_eventhub_config['access_key'], live_eventhub_config['event_hub']) @@ -163,24 +173,30 @@ def aad_credential(): @pytest.fixture() -def connstr_receivers(connection_str): - client = EventHubClient.from_connection_string(connection_str, network_tracing=False) - partitions = client.get_partition_ids() +def connstr_receivers(connection_str, live_eventhub_config): + partitions = [str(i) for i in range(PARTITION_COUNT)] receivers = [] for p in partitions: - receiver = client._create_consumer(consumer_group="$default", partition_id=p, event_position=EventPosition("-1"), prefetch=500) - receiver._open() + uri = "sb://{}/{}".format(live_eventhub_config['hostname'], live_eventhub_config['event_hub']) + sas_auth = authentication.SASTokenAuth.from_shared_access_key( + uri, live_eventhub_config['key_name'], live_eventhub_config['access_key']) + + source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( + live_eventhub_config['hostname'], + live_eventhub_config['event_hub'], + live_eventhub_config['consumer_group'], + p) + receiver = uamqp.ReceiveClient(source, auth=sas_auth, debug=False, timeout=5000, prefetch=500) + receiver.open() receivers.append(receiver) yield connection_str, receivers - for r in receivers: r.close() - client.close() @pytest.fixture() def connstr_senders(connection_str): - client = EventHubClient.from_connection_string(connection_str, network_tracing=False) + client = EventHubProducerClient.from_connection_string(connection_str) partitions = client.get_partition_ids() senders = [] diff --git a/sdk/eventhub/azure-eventhubs/dev_requirements.txt b/sdk/eventhub/azure-eventhubs/dev_requirements.txt index c808b7948163..2a1e4e3d9ceb 100644 --- a/sdk/eventhub/azure-eventhubs/dev_requirements.txt +++ b/sdk/eventhub/azure-eventhubs/dev_requirements.txt @@ -2,6 +2,7 @@ ../../core/azure-core -e ../../identity/azure-identity -e ../../servicebus/azure-servicebus +aiohttp>=3.0; python_version >= '3.5' docutils>=0.14 pygments>=2.2.0 behave==1.2.6 diff --git a/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_async.py b/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_async.py index 591234613edb..9dbe2eee7861 100644 --- a/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_async.py +++ 
b/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_async.py
@@ -16,22 +16,15 @@
 CONNECTION_STR = os.environ["EVENT_HUB_CONN_STR"]
 
 
-async def do_operation(event):
-    pass
-    # do some sync or async operations. If the operation is i/o intensive, async will have better performance
+async def on_event(partition_context, event):
+    print("Received event from partition: {}".format(partition_context.partition_id))
+    # Do some sync or async operations. If the operation is i/o intensive, async will have better performance
     # print(event)
 
 
-async def on_events(partition_context, events):
-    # put your code here
-    print("received events: {} from partition: {}".format(len(events), partition_context.partition_id))
-    await asyncio.gather(*[do_operation(event) for event in events])
-
-
 async def receive(client):
     try:
-        await client.receive(on_events=on_events,
-                             consumer_group="$default")
+        await client.receive(on_event=on_event, consumer_group="$default")
     except KeyboardInterrupt:
         await client.close()
 
diff --git a/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_track_last_enqueued_event_info_async.py b/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_track_last_enqueued_event_info_async.py
index f8b1266437ed..9e5f21f2340c 100644
--- a/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_track_last_enqueued_event_info_async.py
+++ b/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_track_last_enqueued_event_info_async.py
@@ -17,28 +17,25 @@
 CONNECTION_STR = os.environ["EVENT_HUB_CONN_STR"]
 
 
-async def do_operation(event):
-    pass
-    # do some sync or async operations. If the operation is i/o intensive, async will have better performance
-    # print(event)
-
-
-async def on_events(partition_context, events):
-    # put your code here
-    print("received events: {} from partition: {}".format(len(events), partition_context.partition_id))
-    await asyncio.gather(*[do_operation(event) for event in events])
-
-    print("Last enqueued event properties from partition: {} is: {}".
-          format(partition_context.partition_id,
-                 events[-1].last_enqueued_event_properties))
+async def on_event(partition_context, event):
+    print("Received event from partition: {}".format(partition_context.partition_id))
+    # Do some sync or async operations. 
If the operation is i/o intensive, async will have better performance + print(event) + + print("Last enqueued event properties from partition: {} is: {}".format( + partition_context.partition_id, + event.last_enqueued_event_properties) + ) async def receive(client): try: - await client.receive(on_events=on_events, - consumer_group="$default", - partition_id='0', - track_last_enqueued_event_properties=True) + await client.receive( + on_event=on_event, + consumer_group="$default", + partition_id='0', + track_last_enqueued_event_properties=True + ) except KeyboardInterrupt: await client.close() diff --git a/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_with_partition_manager_async.py b/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_with_partition_manager_async.py index 46f368ae4876..7128121511e0 100644 --- a/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_with_partition_manager_async.py +++ b/sdk/eventhub/azure-eventhubs/samples/async_samples/recv_with_partition_manager_async.py @@ -28,11 +28,10 @@ async def do_operation(event): # print(event) -async def on_events(partition_context, events): +async def on_event(partition_context, event): # put your code here - print("received events: {} from partition: {}".format(len(events), partition_context.partition_id)) - await asyncio.gather(*[do_operation(event) for event in events]) - await partition_context.update_checkpoint(events[-1]) + print("Received event from partition: {}".format(partition_context.partition_id)) + await partition_context.update_checkpoint(event) async def receive(client): @@ -42,9 +41,9 @@ async def receive(client): partition manager, the client will load-balance partition assignment with other EventHubConsumerClient instances which also try to receive events from all partitions and use the same storage resource. 
""" - await client.receive(on_events=on_events, consumer_group="$Default") + await client.receive(on_event=on_event, consumer_group="$Default") # With specified partition_id, load-balance will be disabled - # await client.receive(event_handler=event_handler, consumer_group="$default", partition_id = '0')) + # await client.receive(on_event=on_event, consumer_group="$default", partition_id = '0')) except KeyboardInterrupt: await client.close() diff --git a/sdk/eventhub/azure-eventhubs/samples/async_samples/sample_code_eventhub_async.py b/sdk/eventhub/azure-eventhubs/samples/async_samples/sample_code_eventhub_async.py index 6fb29ab47015..790905149222 100644 --- a/sdk/eventhub/azure-eventhubs/samples/async_samples/sample_code_eventhub_async.py +++ b/sdk/eventhub/azure-eventhubs/samples/async_samples/sample_code_eventhub_async.py @@ -88,12 +88,12 @@ async def example_eventhub_async_send_and_receive(): # [START eventhub_consumer_client_receive_async] logger = logging.getLogger("azure.eventhub") - async def on_events(partition_context, events): - logger.info("Received {} messages from partition: {}".format( - len(events), partition_context.partition_id)) - # Do ops on received events + async def on_event(partition_context, event): + logger.info("Received event from partition: {}".format(partition_context.partition_id)) + # Do asnchronous ops on received events + async with consumer: - await consumer.receive(on_events=on_events, consumer_group="$default") + await consumer.receive(on_event=on_event, consumer_group="$default") # [END eventhub_consumer_client_receive_async] finally: pass @@ -108,8 +108,10 @@ async def example_eventhub_async_producer_ops(): event_hub_connection_str = os.environ['EVENT_HUB_CONN_STR'] event_hub = os.environ['EVENT_HUB_NAME'] - producer = EventHubProducerClient.from_connection_string(conn_str=event_hub_connection_str, - event_hub_path=event_hub) + producer = EventHubProducerClient.from_connection_string( + conn_str=event_hub_connection_str, + event_hub_path=event_hub + ) try: await producer.send(EventData(b"A single event")) finally: @@ -133,15 +135,14 @@ async def example_eventhub_async_consumer_ops(): logger = logging.getLogger("azure.eventhub") - async def on_events(partition_context, events): - logger.info("Received {} messages from partition: {}".format( - len(events), partition_context.partition_id)) - # Do ops on received events + async def on_event(partition_context, event): + logger.info("Received event from partition: {}".format(partition_context.partition_id)) + # Do asynchronous ops on received events # The receive method is a coroutine method which can be called by `await consumer.receive(...)` and it will block. # so execute it in an async task to better demonstrate how to stop the receiving by calling he close method. 
- recv_task = asyncio.ensure_future(consumer.receive(on_events=on_events, consumer_group='$Default')) + recv_task = asyncio.ensure_future(consumer.receive(on_event=on_event, consumer_group='$Default')) await asyncio.sleep(3) # keep receiving for 3 seconds recv_task.cancel() # stop receiving diff --git a/sdk/eventhub/azure-eventhubs/samples/sync_samples/proxy.py b/sdk/eventhub/azure-eventhubs/samples/sync_samples/proxy.py index f4c027132f36..1db71885fea5 100644 --- a/sdk/eventhub/azure-eventhubs/samples/sync_samples/proxy.py +++ b/sdk/eventhub/azure-eventhubs/samples/sync_samples/proxy.py @@ -25,17 +25,12 @@ } -def do_operation(event): +def on_event(partition_context, event): + print("Received event from partition: {}".format(partition_context.partition_id)) # do some operations on the event print(event) -def on_events(partition_context, events): - print("received events: {} from partition: {}".format(len(events), partition_context.partition_id)) - for event in events: - do_operation(event) - - consumer_client = EventHubConsumerClient.from_connection_string( conn_str=CONNECTION_STR, event_hub_path=EVENT_HUB, http_proxy=HTTP_PROXY) producer_client = EventHubProducerClient.from_connection_string( @@ -47,7 +42,6 @@ def on_events(partition_context, events): with consumer_client: receiving_time = 5 - consumer_client.receive(on_events=on_events, consumer_group='$Default') - time.sleep(receiving_time) + consumer_client.receive(on_event=on_event, consumer_group='$Default') print('Finish receiving.') diff --git a/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv.py b/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv.py index c6ea05711e4d..eecf8d1d9fc0 100644 --- a/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv.py +++ b/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv.py @@ -17,13 +17,6 @@ EVENT_POSITION = EventPosition("-1") PARTITION = "0" -total = 0 - - -def do_operation(event): - # do some operations on the event, avoid time-consuming ops - pass - def on_partition_initialize(partition_context): # put your code here @@ -35,21 +28,14 @@ def on_partition_close(partition_context, reason): print("Partition: {} has been closed, reason for closing: {}".format(partition_context.partition_id, reason)) - def on_error(partition_context, error): # put your code here print("Partition: {} met an exception during receiving: {}".format(partition_context.partition_id, error)) - -def on_events(partition_context, events): +def on_event(partition_context, event): # put your code here - global total - - print("received events: {} from partition: {}".format(len(events), partition_context.partition_id)) - total += len(events) - for event in events: - do_operation(event) + print("Received event from partition: {}".format(partition_context.partition_id)) if __name__ == '__main__': @@ -60,11 +46,12 @@ def on_events(partition_context, events): try: with consumer_client: - consumer_client.receive(on_events=on_events, consumer_group='$Default', - on_partition_initialize=on_partition_initialize, - on_partition_close=on_partition_close, - on_error=on_error) - # Receive with owner level: - # consumer_client.receive(on_events=on_events, consumer_group='$Default', owner_level=1) + consumer_client.receive( + on_event=on_event, + consumer_group='$Default', + on_partition_initialize=on_partition_initialize, + on_partition_close=on_partition_close, + on_error=on_error + ) except KeyboardInterrupt: print('Stop receiving.') diff --git
a/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv_track_last_enqueued_event_info.py b/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv_track_last_enqueued_event_info.py index 5665befa3dbd..5c06fb1ca714 100644 --- a/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv_track_last_enqueued_event_info.py +++ b/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv_track_last_enqueued_event_info.py @@ -19,25 +19,18 @@ EVENT_POSITION = EventPosition("-1") PARTITION = "0" -total = 0 +def on_event(partition_context, event): + print("Received event from partition {}".format(partition_context.partition_id)) + + # Put your code here to do some operations on the event. + # Avoid time-consuming operations. + print(event) -def do_operation(event): - # do some operations on the event, avoid time-consuming ops - pass - - -def on_events(partition_context, events): - # put your code here - global total - print("received events: {} from partition {}".format(len(events), partition_context.partition_id)) - total += len(events) - for event in events: - do_operation(event) - - print("Last enqueued event properties from partition: {} is: {}". - format(partition_context.partition_id, - events[-1].last_enqueued_event_properties)) + print("Last enqueued event properties from partition: {} is: {}".format( + partition_context.partition_id, + event.last_enqueued_event_properties) + ) if __name__ == '__main__': @@ -48,8 +41,11 @@ def on_events(partition_context, events): try: with consumer_client: - consumer_client.receive(on_events=on_events, consumer_group='$Default', - partition_id='0', track_last_enqueued_event_properties=True) - + consumer_client.receive( + on_event=on_event, + consumer_group='$Default', + partition_id='0', + track_last_enqueued_event_properties=True + ) except KeyboardInterrupt: print('Stop receiving.') diff --git a/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv_with_partition_manager.py b/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv_with_partition_manager.py index d10fcd544eff..5c6930a500d0 100644 --- a/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv_with_partition_manager.py +++ b/sdk/eventhub/azure-eventhubs/samples/sync_samples/recv_with_partition_manager.py @@ -21,18 +21,14 @@ STORAGE_CONNECTION_STR = os.environ["AZURE_STORAGE_CONN_STR"] -def do_operation(event): - # do some operations on the event, avoid time-consuming ops - pass +def on_event(partition_context, event): + print("Received event from partition: {}".format(partition_context.partition_id)) + # Put your code here to do some operations on the event. + # Avoid time-consuming operations. + print(event) -def on_events(partition_context, events): - # put your code here - print("received events: {} from partition: {}".format(len(events), partition_context.partition_id)) - for event in events: - do_operation(event) - - partition_context.update_checkpoint(events[-1]) + partition_context.update_checkpoint(event) if __name__ == '__main__': @@ -50,8 +46,8 @@ def on_events(partition_context, events): partition manager, the client will load-balance partition assignment with other EventHubConsumerClient instances which also try to receive events from all partitions and use the same storage resource. 
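Returning to the recv_track_last_enqueued_event_info sample rewritten at the top of this hunk: a compact standalone variant may clarify what the tracked properties give you. The connection string is a placeholder, the property keys follow the assertions in this PR's receive tests, and the lag arithmetic is only an illustration of how the snapshot might be used:

```python
from azure.eventhub import EventHubConsumerClient

CONNECTION_STR = "<< CONNECTION STRING FOR THE EVENT HUBS NAMESPACE >>"  # placeholder


def on_event(partition_context, event):
    # With tracking enabled, every received event carries a snapshot of the
    # partition's last enqueued event properties: 'sequence_number', 'offset',
    # 'enqueued_time' and 'retrieval_time'.
    props = event.last_enqueued_event_properties
    lag = props.get('sequence_number') - event.sequence_number
    print("Partition {}: {} events behind the head of the stream".format(
        partition_context.partition_id, lag))


consumer_client = EventHubConsumerClient.from_connection_string(CONNECTION_STR)
try:
    with consumer_client:
        consumer_client.receive(
            on_event=on_event,
            consumer_group='$Default',
            partition_id='0',
            track_last_enqueued_event_properties=True
        )
except KeyboardInterrupt:
    print('Stop receiving.')
```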
""" - consumer_client.receive(on_events=on_events, consumer_group='$Default') + consumer_client.receive(on_event=on_event, consumer_group='$Default') # With specified partition_id, load-balance will be disabled - # client.receive(on_events=on_events, consumer_group='$Default', partition_id='0') + # client.receive(on_event=on_event, consumer_group='$Default', partition_id='0') except KeyboardInterrupt: print('Stop receiving.') diff --git a/sdk/eventhub/azure-eventhubs/samples/sync_samples/sample_code_eventhub.py b/sdk/eventhub/azure-eventhubs/samples/sync_samples/sample_code_eventhub.py index aa45e23fdf11..089d68f3d39c 100644 --- a/sdk/eventhub/azure-eventhubs/samples/sync_samples/sample_code_eventhub.py +++ b/sdk/eventhub/azure-eventhubs/samples/sync_samples/sample_code_eventhub.py @@ -14,8 +14,10 @@ def create_eventhub_producer_client(): from azure.eventhub import EventHubProducerClient event_hub_connection_str = os.environ['EVENT_HUB_CONN_STR'] event_hub = os.environ['EVENT_HUB_NAME'] - producer = EventHubProducerClient.from_connection_string(conn_str=event_hub_connection_str, - event_hub_path=event_hub) + producer = EventHubProducerClient.from_connection_string( + conn_str=event_hub_connection_str, + event_hub_path=event_hub + ) # [END create_eventhub_producer_client_from_conn_str_sync] # [START create_eventhub_producer_client_sync] @@ -27,9 +29,12 @@ def create_eventhub_producer_client(): shared_access_policy = os.environ['EVENT_HUB_SAS_POLICY'] shared_access_key = os.environ['EVENT_HUB_SAS_KEY'] - producer = EventHubProducerClient(host=hostname, - event_hub_path=event_hub, - credential=EventHubSharedKeyCredential(shared_access_policy, shared_access_key)) + credential = EventHubSharedKeyCredential(shared_access_policy, shared_access_key) + producer = EventHubProducerClient( + host=hostname, + event_hub_path=event_hub, + credential=credential + ) # [END create_eventhub_producer_client_sync] return producer @@ -40,8 +45,10 @@ def create_eventhub_consumer_client(): from azure.eventhub import EventHubConsumerClient event_hub_connection_str = os.environ['EVENT_HUB_CONN_STR'] event_hub = os.environ['EVENT_HUB_NAME'] - consumer = EventHubConsumerClient.from_connection_string(conn_str=event_hub_connection_str, - event_hub_path=event_hub) + consumer = EventHubConsumerClient.from_connection_string( + conn_str=event_hub_connection_str, + event_hub_path=event_hub + ) # [END create_eventhub_consumer_client_from_conn_str_sync] # [START create_eventhub_consumer_client_sync] @@ -53,9 +60,11 @@ def create_eventhub_consumer_client(): shared_access_policy = os.environ['EVENT_HUB_SAS_POLICY'] shared_access_key = os.environ['EVENT_HUB_SAS_KEY'] - consumer = EventHubConsumerClient(host=hostname, - event_hub_path=event_hub, - credential=EventHubSharedKeyCredential(shared_access_policy, shared_access_key)) + credential = EventHubSharedKeyCredential(shared_access_policy, shared_access_key) + consumer = EventHubConsumerClient( + host=hostname, + event_hub_path=event_hub, + credential=credential) # [END create_eventhub_consumer_client_sync] return consumer @@ -95,13 +104,12 @@ def example_eventhub_sync_send_and_receive(): # [START eventhub_consumer_client_receive_sync] logger = logging.getLogger("azure.eventhub") - def on_events(partition_context, events): - logger.info("Received {} messages from partition: {}".format( - len(events), partition_context.partition_id)) + def on_event(partition_context, event): + logger.info("Received event from partition: {}".format(partition_context.partition_id)) # Do ops on 
received events with consumer: - consumer.receive(on_events=on_events, consumer_group='$Default') + consumer.receive(on_event=on_event, consumer_group='$Default') # [END eventhub_consumer_client_receive_sync] finally: pass @@ -115,8 +123,10 @@ def example_eventhub_producer_ops(): event_hub_connection_str = os.environ['EVENT_HUB_CONN_STR'] event_hub = os.environ['EVENT_HUB_NAME'] - producer = EventHubProducerClient.from_connection_string(conn_str=event_hub_connection_str, - event_hub_path=event_hub) + producer = EventHubProducerClient.from_connection_string( + conn_str=event_hub_connection_str, + event_hub_path=event_hub + ) try: producer.send(EventData(b"A single event")) finally: @@ -141,17 +151,17 @@ def example_eventhub_consumer_ops(): logger = logging.getLogger("azure.eventhub") - def on_events(partition_context, events): - logger.info("Received {} messages from partition: {}".format( - len(events), partition_context.partition_id)) + def on_event(partition_context, event): + logger.info("Received event from partition: {}".format(partition_context.partition_id)) # Do ops on received events # The receive method is a blocking call, so execute it in a thread to # better demonstrate how to stop the receiving by calling the close method. - worker = threading.Thread(target=consumer.receive, - kwargs={"on_events": on_events, - "consumer_group": "$Default"}) + worker = threading.Thread( + target=consumer.receive, + kwargs={"on_event": on_event, "consumer_group": "$Default"} + ) worker.start() time.sleep(10) # Keep receiving for 10s then close. # Close down the consumer handler explicitly. diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_auth_async.py b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_auth_async.py index 1dc9a517f24e..41ec3763d8c4 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_auth_async.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_auth_async.py @@ -30,15 +30,17 @@ async def test_client_secret_credential_async(aad_credential, live_eventhub): user_agent='customized information') async with producer_client: - await producer_client.send(EventData(body='A single message')) - - async def event_handler(partition_context, events): - assert partition_context.partition_id == '0' - assert len(events) == 1 - assert list(events[0].body)[0] == 'A single message'.encode('utf-8') + await producer_client.send(EventData(body='A single message'), partition_id='0') + def on_event(partition_context, event): + on_event.called = True + on_event.partition_id = partition_context.partition_id + on_event.event = event + on_event.called = False async with consumer_client: - task = asyncio.ensure_future( - consumer_client.receive(event_handler=event_handler, consumer_group='$default', partition_id='0')) - await asyncio.sleep(2) - task.cancel() + task = asyncio.ensure_future(consumer_client.receive(on_event, '$default', partition_id='0')) + await asyncio.sleep(6) + task.cancel() + assert on_event.called is True + assert on_event.partition_id == "0" + assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8') diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_consumer_client_async.py b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_consumer_client_async.py index 85811d35fc73..b8d458b765f0 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_consumer_client_async.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_consumer_client_async.py @@ -2,7
+2,8 @@ import asyncio from azure.eventhub import EventData, EventPosition from azure.eventhub.aio import EventHubConsumerClient -from azure.eventhub.aio.eventprocessor.local_partition_manager import InMemoryPartitionManager +from azure.eventhub.aio._eventprocessor.local_partition_manager import InMemoryPartitionManager +from azure.eventhub._constants import ALL_PARTITIONS @pytest.mark.liveTest @@ -12,18 +13,17 @@ async def test_receive_no_partition_async(connstr_senders): senders[0].send(EventData("Test EventData")) senders[1].send(EventData("Test EventData")) client = EventHubConsumerClient.from_connection_string(connection_str) - received = 0 - async def on_events(partition_context, events): - nonlocal received - received += len(events) + async def on_event(partition_context, event): + on_event.received += 1 + on_event.received = 0 async with client: task = asyncio.ensure_future( - client.receive(on_events, consumer_group="$default", initial_event_position="-1")) + client.receive(on_event, consumer_group="$default", initial_event_position="-1")) await asyncio.sleep(10) - assert received == 2 - # task.cancel() + assert on_event.received == 2 + task.cancel() @pytest.mark.liveTest @@ -32,22 +32,21 @@ async def test_receive_partition_async(connstr_senders): connection_str, senders = connstr_senders senders[0].send(EventData("Test EventData")) client = EventHubConsumerClient.from_connection_string(connection_str) - received = 0 - async def on_events(partition_context, events): - nonlocal received - received += len(events) + async def on_event(partition_context, event): assert partition_context.partition_id == "0" assert partition_context.consumer_group_name == "$default" assert partition_context.fully_qualified_namespace in connection_str assert partition_context.eventhub_name == senders[0]._client.eh_name + on_event.received += 1 + on_event.received = 0 async with client: task = asyncio.ensure_future( - client.receive(on_events, consumer_group="$default", partition_id="0", initial_event_position="-1")) + client.receive(on_event, consumer_group="$default", partition_id="0", initial_event_position="-1")) await asyncio.sleep(10) - assert received == 1 - # task.cancel() + assert on_event.received == 1 + task.cancel() @pytest.mark.liveTest @@ -60,16 +59,16 @@ async def test_receive_load_balancing_async(connstr_senders): client2 = EventHubConsumerClient.from_connection_string( connection_str, partition_manager=pm, load_balancing_interval=1) - async def on_events(partition_context, events): + async def on_event(partition_context, event): pass async with client1, client2: task1 = asyncio.ensure_future( - client1.receive(on_events, consumer_group="$default", initial_event_position="-1")) + client1.receive(on_event, consumer_group="$default", initial_event_position="-1")) task2 = asyncio.ensure_future( - client2.receive(on_events, consumer_group="$default", initial_event_position="-1")) + client2.receive(on_event, consumer_group="$default", initial_event_position="-1")) await asyncio.sleep(10) - assert len(client1._event_processors[("$default", "-1")]._tasks) == 1 - assert len(client2._event_processors[("$default", "-1")]._tasks) == 1 - # task1.cancel() - # task2.cancel() + assert len(client1._event_processors[("$default", ALL_PARTITIONS)]._tasks) == 1 + assert len(client2._event_processors[("$default", ALL_PARTITIONS)]._tasks) == 1 + task1.cancel() + task2.cancel() diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_eventprocessor_async.py 
b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_eventprocessor_async.py index 03a13d403590..edfdc9e76a1f 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_eventprocessor_async.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_eventprocessor_async.py @@ -9,14 +9,14 @@ from azure.eventhub import EventData, EventHubError -from azure.eventhub.aio.client_async import EventHubClient -from azure.eventhub.aio.eventprocessor.event_processor import EventProcessor, CloseReason -from azure.eventhub.aio.eventprocessor.local_partition_manager import InMemoryPartitionManager +from azure.eventhub.aio import EventHubConsumerClient +from azure.eventhub.aio._eventprocessor.event_processor import EventProcessor, CloseReason +from azure.eventhub.aio._eventprocessor.local_partition_manager import InMemoryPartitionManager from azure.eventhub import OwnershipLostError -from azure.eventhub.common import _Address +from azure.eventhub._client_base import _Address -async def event_handler(partition_context, events): +async def event_handler(partition_context, event): pass @@ -27,7 +27,7 @@ async def test_loadbalancer_balance(connstr_senders): connection_str, senders = connstr_senders for sender in senders: sender.send(EventData("EventProcessor Test")) - eventhub_client = EventHubClient.from_connection_string(connection_str, receive_timeout=3) + eventhub_client = EventHubConsumerClient.from_connection_string(connection_str) partition_manager = InMemoryPartitionManager() tasks = [] @@ -76,10 +76,6 @@ async def test_loadbalancer_balance(connstr_senders): assert len(event_processor2._tasks) == 2 # event_processor2 takes another one after event_processor1 stops await event_processor2.stop() - ''' - for task in tasks: - task.cancel() - ''' await eventhub_client.close() @@ -93,7 +89,7 @@ async def list_ownership(self, fully_qualified_namespace, eventhub_name, consume connection_str, senders = connstr_senders for sender in senders: sender.send(EventData("EventProcessor Test")) - eventhub_client = EventHubClient.from_connection_string(connection_str, receive_timeout=3) + eventhub_client = EventHubConsumerClient.from_connection_string(connection_str) partition_manager = ErrorPartitionManager() event_processor = EventProcessor(eventhub_client=eventhub_client, @@ -123,31 +119,31 @@ async def test_partition_processor(connstr_senders): error = None async def partition_initialize_handler(partition_context): - assert partition_context + partition_initialize_handler.partition_context = partition_context - async def event_handler(partition_context, events): + async def event_handler(partition_context, event): async with lock: - if events: + if event: nonlocal checkpoint, event_map - event_map[partition_context.partition_id] = event_map.get(partition_context.partition_id, 0) + len(events) - offset, sn = events[-1].offset, events[-1].sequence_number + event_map[partition_context.partition_id] = event_map.get(partition_context.partition_id, 0) + 1 + offset, sn = event.offset, event.sequence_number checkpoint = (offset, sn) - await partition_context.update_checkpoint(events[-1]) + await partition_context.update_checkpoint(event) async def partition_close_handler(partition_context, reason): + assert partition_context and reason nonlocal close_reason close_reason = reason - assert partition_context and reason async def error_handler(partition_context, err): + assert partition_context and err nonlocal error error = err - assert partition_context and err connection_str, senders =
connstr_senders for sender in senders: sender.send(EventData("EventProcessor Test")) - eventhub_client = EventHubClient.from_connection_string(connection_str, receive_timeout=3) + eventhub_client = EventHubConsumerClient.from_connection_string(connection_str) partition_manager = InMemoryPartitionManager() event_processor = EventProcessor(eventhub_client=eventhub_client, @@ -170,12 +166,13 @@ async def error_handler(partition_context, err): assert checkpoint is not None assert close_reason == CloseReason.SHUTDOWN assert error is None + assert partition_initialize_handler.partition_context @pytest.mark.liveTest @pytest.mark.asyncio async def test_partition_processor_process_events_error(connstr_senders): - async def event_handler(partition_context, events): + async def event_handler(partition_context, event): if partition_context.partition_id == "1": raise RuntimeError("processing events error") else: @@ -183,7 +180,7 @@ async def event_handler(partition_context, events): async def error_handler(partition_context, error): if partition_context.partition_id == "1": - assert isinstance(error, RuntimeError) + error_handler.error = error else: raise RuntimeError("There shouldn't be an error for partition other than 1") @@ -196,7 +193,7 @@ async def partition_close_handler(partition_context, reason): connection_str, senders = connstr_senders for sender in senders: sender.send(EventData("EventProcessor Test")) - eventhub_client = EventHubClient.from_connection_string(connection_str, receive_timeout=3) + eventhub_client = EventHubConsumerClient.from_connection_string(connection_str) partition_manager = InMemoryPartitionManager() event_processor = EventProcessor(eventhub_client=eventhub_client, @@ -212,18 +209,19 @@ async def partition_close_handler(partition_context, reason): await event_processor.stop() # task.cancel() await eventhub_client.close() + assert isinstance(error_handler.error, RuntimeError) @pytest.mark.asyncio async def test_partition_processor_process_eventhub_consumer_error(): - async def event_handler(partition_context, events): + async def event_handler(partition_context, event): pass async def error_handler(partition_context, error): - assert isinstance(error, EventHubError) + error_handler.error = error async def partition_close_handler(partition_context, reason): - assert reason == CloseReason.OWNERSHIP_LOST + partition_close_handler.reason = reason class MockEventHubClient(object): eh_name = "test_eh_name" @@ -232,14 +230,19 @@ def __init__(self): self._address = _Address(hostname="test", path=MockEventHubClient.eh_name) def _create_consumer(self, consumer_group_name, partition_id, event_position, **kwargs): - return MockEventhubConsumer() + return MockEventhubConsumer(**kwargs) async def get_partition_ids(self): return ["0", "1"] class MockEventhubConsumer(object): + def __init__(self, **kwargs): + self.stop = False + self._on_event_received = kwargs.get("on_event_received") + async def receive(self): raise EventHubError("Mock EventHubConsumer EventHubError") + async def close(self): pass @@ -258,22 +261,29 @@ async def close(self): await asyncio.sleep(5) await event_processor.stop() task.cancel() + assert isinstance(error_handler.error, EventHubError) + assert partition_close_handler.reason == CloseReason.OWNERSHIP_LOST + @pytest.mark.asyncio async def test_partition_processor_process_error_close_error(): async def partition_initialize_handler(partition_context): + partition_initialize_handler.called = True raise RuntimeError("initialize error") - async def 
event_handler(partition_context, events): + async def event_handler(partition_context, event): + event_handler.called = True raise RuntimeError("process_events error") async def error_handler(partition_context, error): assert isinstance(error, RuntimeError) + error_handler.called = True raise RuntimeError("process_error error") async def partition_close_handler(partition_context, reason): - assert reason == CloseReason.OWNERSHIP_LOST + assert reason == CloseReason.SHUTDOWN + partition_close_handler.called = True raise RuntimeError("close error") class MockEventHubClient(object): @@ -283,14 +293,20 @@ def __init__(self): self._address = _Address(hostname="test", path=MockEventHubClient.eh_name) def _create_consumer(self, consumer_group_name, partition_id, event_position, **kwargs): - return MockEventhubConsumer() + return MockEventhubConsumer(**kwargs) async def get_partition_ids(self): return ["0", "1"] class MockEventhubConsumer(object): + def __init__(self, **kwargs): + self.stop = False + self._on_event_received = kwargs.get("on_event_received") + async def receive(self): - return [EventData("mock events")] + await asyncio.sleep(0.1) + await self._on_event_received(EventData("mock events")) + async def close(self): pass @@ -309,6 +325,10 @@ async def close(self): await asyncio.sleep(5) await event_processor.stop() # task.cancel() + assert partition_initialize_handler.called + assert event_handler.called + assert error_handler.called + # assert partition_close_handler.called @pytest.mark.liveTest @@ -321,23 +341,21 @@ async def update_checkpoint( if partition_id == "1": raise OwnershipLostError("Mocked ownership lost") - async def event_handler(partition_context, events): - if events: - await partition_context.update_checkpoint(events[-1]) + async def event_handler(partition_context, event): + await partition_context.update_checkpoint(event) async def error_handler(partition_context, error): assert isinstance(error, OwnershipLostError) async def partition_close_handler(partition_context, reason): if partition_context.partition_id == "1": - assert reason == CloseReason.OWNERSHIP_LOST - else: assert reason == CloseReason.SHUTDOWN + partition_close_handler.called = True connection_str, senders = connstr_senders for sender in senders: sender.send(EventData("EventProcessor Test")) - eventhub_client = EventHubClient.from_connection_string(connection_str, receive_timeout=3) + eventhub_client = EventHubConsumerClient.from_connection_string(connection_str) partition_manager = ErrorPartitionManager() event_processor = EventProcessor(eventhub_client=eventhub_client, @@ -354,3 +372,4 @@ async def partition_close_handler(partition_context, reason): # task.cancel() await asyncio.sleep(1) await eventhub_client.close() + assert partition_close_handler.called diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_negative_async.py b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_negative_async.py index cce75cac93af..55eff8239192 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_negative_async.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_negative_async.py @@ -16,126 +16,78 @@ AuthenticationError, EventDataSendError, ) -from azure.eventhub.aio.client_async import EventHubClient +from azure.eventhub.aio import EventHubConsumerClient, EventHubProducerClient @pytest.mark.liveTest @pytest.mark.asyncio async def test_send_with_invalid_hostname_async(invalid_hostname, connstr_receivers): + if sys.platform.startswith('darwin'): + 
pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " + "and blocking other tests") _, receivers = connstr_receivers - client = EventHubClient.from_connection_string(invalid_hostname) - sender = client._create_producer() - with pytest.raises(AuthenticationError): - await sender.send(EventData("test data")) - await sender.close() - await client.close() + client = EventHubProducerClient.from_connection_string(invalid_hostname) + async with client: + with pytest.raises(ConnectError): + await client.send(EventData("test data")) +@pytest.mark.parametrize("invalid_place", + ["hostname", "key_name", "access_key", "event_hub", "partition"]) @pytest.mark.liveTest @pytest.mark.asyncio -async def test_receive_with_invalid_hostname_async(invalid_hostname): - client = EventHubClient.from_connection_string(invalid_hostname) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1")) - with pytest.raises(AuthenticationError): - await receiver.receive(timeout=3) - await receiver.close() - await client.close() - +async def test_receive_with_invalid_param_async(live_eventhub_config, invalid_place): + eventhub_config = live_eventhub_config.copy() + if invalid_place != "partition": + eventhub_config[invalid_place] = "invalid " + invalid_place + conn_str = live_eventhub_config["connection_str"].format( + eventhub_config['hostname'], + eventhub_config['key_name'], + eventhub_config['access_key'], + eventhub_config['event_hub']) -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_send_with_invalid_key_async(invalid_key, connstr_receivers): - _, receivers = connstr_receivers - client = EventHubClient.from_connection_string(invalid_key) - sender = client._create_producer() - with pytest.raises(AuthenticationError): - await sender.send(EventData("test data")) - await sender.close() - await client.close() + client = EventHubConsumerClient.from_connection_string(conn_str, retry_total=0) + async def on_event(partition_context, event): + pass -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_receive_with_invalid_key_async(invalid_key): - client = EventHubClient.from_connection_string(invalid_key) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1")) - with pytest.raises(AuthenticationError): - await receiver.receive(timeout=3) - await receiver.close() - await client.close() + async with client: + if invalid_place == "partition": + task = asyncio.ensure_future(client.receive(on_event, "$default", partition_id=invalid_place, + initial_event_position=EventPosition("-1"))) + else: + task = asyncio.ensure_future(client.receive(on_event, "$default", partition_id="0", + initial_event_position=EventPosition("-1"))) + await asyncio.sleep(10) + assert len(client._event_processors) == 1 + await task @pytest.mark.liveTest @pytest.mark.asyncio -async def test_send_with_invalid_policy_async(invalid_policy, connstr_receivers): - _, receivers = connstr_receivers - client = EventHubClient.from_connection_string(invalid_policy) - sender = client._create_producer() - with pytest.raises(AuthenticationError): - await sender.send(EventData("test data")) - await sender.close() - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_receive_with_invalid_policy_async(invalid_policy): - client = EventHubClient.from_connection_string(invalid_policy) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", 
event_position=EventPosition("-1")) - with pytest.raises(AuthenticationError): - await receiver.receive(timeout=3) - await receiver.close() - await client.close() +async def test_send_with_invalid_key_async(invalid_key): + client = EventHubProducerClient.from_connection_string(invalid_key) + async with client: + with pytest.raises(ConnectError): + await client.send(EventData("test data")) @pytest.mark.liveTest @pytest.mark.asyncio -async def test_send_partition_key_with_partition_async(connection_str): - pytest.skip("No longer raise value error. EventData will be sent to partition_id") - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer(partition_id="1") - try: - data = EventData(b"Data") - with pytest.raises(ValueError): - await sender.send(EventData("test data")) - finally: - await sender.close() - await client.close() +async def test_send_with_invalid_policy_async(invalid_policy): + client = EventHubProducerClient.from_connection_string(invalid_policy) + async with client: + with pytest.raises(ConnectError): + await client.send(EventData("test data")) @pytest.mark.liveTest @pytest.mark.asyncio async def test_non_existing_entity_sender_async(connection_str): - client = EventHubClient.from_connection_string(connection_str, event_hub_path="nemo") - sender = client._create_producer(partition_id="1") - with pytest.raises(AuthenticationError): - await sender.send(EventData("test data")) - await sender.close() - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_non_existing_entity_receiver_async(connection_str): - client = EventHubClient.from_connection_string(connection_str, event_hub_path="nemo") - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1")) - with pytest.raises(AuthenticationError): - await receiver.receive(timeout=5) - await receiver.close() - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_receive_from_invalid_partitions_async(connection_str): - partitions = ["XYZ", "-1", "1000", "-"] - for p in partitions: - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id=p, event_position=EventPosition("-1")) + client = EventHubProducerClient.from_connection_string(connection_str, event_hub_path="nemo") + async with client: with pytest.raises(ConnectError): - await receiver.receive(timeout=5) - await receiver.close() - await client.close() + await client.send(EventData("test data")) @pytest.mark.liveTest @@ -143,12 +95,12 @@ async def test_receive_from_invalid_partitions_async(connection_str): async def test_send_to_invalid_partitions_async(connection_str): partitions = ["XYZ", "-1", "1000", "-"] for p in partitions: - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer(partition_id=p) - with pytest.raises(ConnectError): - await sender.send(EventData("test data")) - await sender.close() - await client.close() + client = EventHubProducerClient.from_connection_string(connection_str) + try: + with pytest.raises(ConnectError): + await client.send(EventData("test data"), partition_id=p) + finally: + await client.close() @pytest.mark.liveTest @@ -156,88 +108,43 @@ async def test_send_to_invalid_partitions_async(connection_str): async def test_send_too_large_message_async(connection_str): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - open issue regarding message 
size") - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() + client = EventHubProducerClient.from_connection_string(connection_str) try: data = EventData(b"A" * 1100000) with pytest.raises(EventDataSendError): - await sender.send(data) + await client.send(data) finally: - await sender.close() await client.close() @pytest.mark.liveTest @pytest.mark.asyncio async def test_send_null_body_async(connection_str): - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() + client = EventHubProducerClient.from_connection_string(connection_str) try: with pytest.raises(ValueError): data = EventData(None) - await sender.send(data) + await client.send(data) finally: - await sender.close() await client.close() -async def pump(receiver): - async with receiver: - messages = 0 - count = 0 - batch = await receiver.receive(timeout=10) - while batch and count <= 5: - count += 1 - messages += len(batch) - batch = await receiver.receive(timeout=10) - return messages - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_max_receivers_async(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receivers = [] - for i in range(6): - receivers.append(client._create_consumer(consumer_group="$default", partition_id="0", prefetch=1000, event_position=EventPosition('@latest'))) - - outputs = await asyncio.gather( - pump(receivers[0]), - pump(receivers[1]), - pump(receivers[2]), - pump(receivers[3]), - pump(receivers[4]), - pump(receivers[5]), - return_exceptions=True) - print(outputs) - failed = [o for o in outputs if isinstance(o, EventHubError)] - assert len(failed) == 1 - print(failed[0].message) - await client.close() - - @pytest.mark.liveTest @pytest.mark.asyncio async def test_create_batch_with_invalid_hostname_async(invalid_hostname): - client = EventHubClient.from_connection_string(invalid_hostname) - sender = client._create_producer() - try: - with pytest.raises(AuthenticationError): - batch_event_data = await sender.create_batch(max_size=300) - finally: - await sender.close() - await client.close() + if sys.platform.startswith('darwin'): + pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " + "and blocking other tests") + client = EventHubProducerClient.from_connection_string(invalid_hostname) + async with client: + with pytest.raises(ConnectError): + await client.create_batch(max_size=300) @pytest.mark.liveTest @pytest.mark.asyncio async def test_create_batch_with_too_large_size_async(connection_str): - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() - try: + client = EventHubProducerClient.from_connection_string(connection_str) + async with client: with pytest.raises(ValueError): - batch_event_data = await sender.create_batch(max_size=5 * 1024 * 1024) - finally: - await sender.close() - await client.close() + await client.create_batch(max_size=5 * 1024 * 1024) diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_producer_client_async.py b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_producer_client_async.py deleted file mode 100644 index 71879dad6e63..000000000000 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_producer_client_async.py +++ /dev/null @@ -1,77 +0,0 @@ -import pytest -import asyncio -from azure.eventhub import EventData -from azure.eventhub.aio import EventHubProducerClient - 
- -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_send_with_partition_key_async(connstr_receivers): - connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) - - async with client: - data_val = 0 - for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]: - partition_key = b"test_partition_" + partition - for i in range(50): - data = EventData(str(data_val)) - data_val += 1 - await client.send(data, partition_key=partition_key) - - found_partition_keys = {} - for index, partition in enumerate(receivers): - received = partition.receive(timeout=5) - for message in received: - try: - existing = found_partition_keys[message.partition_key] - assert existing == index - except KeyError: - found_partition_keys[message.partition_key] = index - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_send_partition_async(connstr_receivers): - connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) - async with client: - await client.send(EventData(b"Data"), partition_id="1") - - partition_0 = receivers[0].receive(timeout=2) - assert len(partition_0) == 0 - partition_1 = receivers[1].receive(timeout=2) - assert len(partition_1) == 1 - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_send_partitio_concurrent_async(connstr_receivers): - connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) - async with client: - await asyncio.gather(client.send(EventData(b"Data"), partition_id="1"), - client.send(EventData(b"Data"), partition_id="1")) - - partition_0 = receivers[0].receive(timeout=2) - assert len(partition_0) == 0 - partition_1 = receivers[1].receive(timeout=2) - assert len(partition_1) == 2 - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_send_no_partition_batch_async(connstr_receivers): - connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) - async with client: - event_batch = await client.create_batch() - try: - while True: - event_batch.try_add(EventData(b"Data")) - except ValueError: - await client.send(event_batch) - - partition_0 = receivers[0].receive(timeout=2) - partition_1 = receivers[1].receive(timeout=2) - assert len(partition_0) + len(partition_1) > 10 diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_properties_async.py b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_properties_async.py index 475640ed9b28..d40613be1c7c 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_properties_async.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_properties_async.py @@ -6,13 +6,13 @@ import pytest from azure.eventhub import EventHubSharedKeyCredential -from azure.eventhub.aio.client_async import EventHubClient +from azure.eventhub.aio import EventHubConsumerClient, EventHubProducerClient @pytest.mark.liveTest @pytest.mark.asyncio async def test_get_properties(live_eventhub): - client = EventHubClient(live_eventhub['hostname'], live_eventhub['event_hub'], + client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) ) properties = await client.get_properties() @@ -22,7 +22,7 @@ async def test_get_properties(live_eventhub): @pytest.mark.liveTest @pytest.mark.asyncio async def test_get_partition_ids(live_eventhub): - 
client = EventHubClient(live_eventhub['hostname'], live_eventhub['event_hub'], + client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) ) partition_ids = await client.get_partition_ids() @@ -33,7 +33,7 @@ async def test_get_partition_ids(live_eventhub): @pytest.mark.liveTest @pytest.mark.asyncio async def test_get_partition_properties(live_eventhub): - client = EventHubClient(live_eventhub['hostname'], live_eventhub['event_hub'], + client = EventHubProducerClient(live_eventhub['hostname'], live_eventhub['event_hub'], EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']) ) properties = await client.get_partition_properties('0') diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_receive_async.py b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_receive_async.py index 2584e505da8a..15baaa504ee9 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_receive_async.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_receive_async.py @@ -8,362 +8,142 @@ import pytest import time -from azure.eventhub import EventData, EventPosition, TransportType, ConnectionLostError -from azure.eventhub.aio.client_async import EventHubClient +from azure.eventhub import EventData, EventPosition, TransportType, EventHubError +from azure.eventhub.aio import EventHubConsumerClient @pytest.mark.liveTest @pytest.mark.asyncio async def test_receive_end_of_stream_async(connstr_senders): + async def on_event(partition_context, event): + if partition_context.partition_id == "0": + assert event.body_as_str() == "Receiving only a single event" + assert list(event.body)[0] == b"Receiving only a single event" + on_event.called = True + on_event.called = False connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - async with receiver: - received = await receiver.receive(timeout=5) - assert len(received) == 0 + client = EventHubConsumerClient.from_connection_string(connection_str) + async with client: + task = asyncio.ensure_future(client.receive(on_event, "$default", partition_id="0", initial_event_position="-1")) + await asyncio.sleep(10) + assert on_event.called is False senders[0].send(EventData(b"Receiving only a single event")) - received = await receiver.receive(timeout=5) - assert len(received) == 1 - - assert list(received[-1].body)[0] == b"Receiving only a single event" - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_receive_with_offset_async(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - async with receiver: - received = await receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Data")) - time.sleep(1) - received = await receiver.receive(timeout=3) - assert len(received) == 1 - offset = received[0].offset - - offset_receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition(offset, inclusive=False)) - async with offset_receiver: - received = await offset_receiver.receive(timeout=5) - assert len(received) == 0 - 
senders[0].send(EventData(b"Message after offset")) - received = await offset_receiver.receive(timeout=5) - assert len(received) == 1 - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_receive_with_inclusive_offset_async(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - async with receiver: - received = await receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Data")) - time.sleep(1) - received = await receiver.receive(timeout=5) - assert len(received) == 1 - offset = received[0].offset - - offset_receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition(offset, inclusive=True)) - async with offset_receiver: - received = await offset_receiver.receive(timeout=5) - assert len(received) == 1 - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_receive_with_datetime_async(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - async with receiver: - received = await receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Data")) - received = await receiver.receive(timeout=5) - assert len(received) == 1 - offset = received[0].enqueued_time - - offset_receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition(offset)) - async with offset_receiver: - received = await offset_receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Message after timestamp")) - time.sleep(1) - received = await offset_receiver.receive(timeout=5) - assert len(received) == 1 - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_receive_with_sequence_no_async(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - async with receiver: - received = await receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Data")) - received = await receiver.receive(timeout=5) - assert len(received) == 1 - offset = received[0].sequence_number - - offset_receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition(offset)) - async with offset_receiver: - received = await offset_receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Message next in sequence")) - time.sleep(1) - received = await offset_receiver.receive(timeout=5) - assert len(received) == 1 - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_receive_with_inclusive_sequence_no_async(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - async with receiver: - received = await receiver.receive(timeout=5) - assert len(received) == 0 - 
senders[0].send(EventData(b"Data")) - received = await receiver.receive(timeout=5) - assert len(received) == 1 - offset = received[0].sequence_number - - offset_receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition(offset, inclusive=True)) - async with offset_receiver: - received = await offset_receiver.receive(timeout=5) - assert len(received) == 1 - await client.close() + await asyncio.sleep(10) + assert on_event.called is True + task.cancel() +@pytest.mark.parametrize("position, inclusive, expected_result", + [("offset", False, "Exclusive"), + ("offset", True, "Inclusive"), + ("sequence", False, "Exclusive"), + ("sequence", True, "Inclusive"), + ("enqueued_time", False, "Exclusive")]) @pytest.mark.liveTest @pytest.mark.asyncio -async def test_receive_batch_async(connstr_senders): +async def test_receive_with_event_position_async(connstr_senders, position, inclusive, expected_result): + async def on_event(partition_context, event): + assert event.last_enqueued_event_properties.get('sequence_number') == event.sequence_number + assert event.last_enqueued_event_properties.get('offset') == event.offset + assert event.last_enqueued_event_properties.get('enqueued_time') == event.enqueued_time + assert event.last_enqueued_event_properties.get('retrieval_time') is not None + + if position == "offset": + on_event.event_position = event.offset + elif position == "sequence": + on_event.event_position = event.sequence_number + else: + on_event.event_position = event.enqueued_time + on_event.event = event + + on_event.event_position = None connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest'), prefetch=500) - async with receiver: - received = await receiver.receive(timeout=5) - assert len(received) == 0 - for i in range(10): - senders[0].send(EventData(b"Data")) - received = await receiver.receive(max_batch_size=5, timeout=5) - assert len(received) == 5 - - for event in received: - assert event.system_properties - assert event.sequence_number is not None - assert event.offset - assert event.enqueued_time - await client.close() - - -async def pump(receiver, sleep=None): - messages = 0 - count = 0 - if sleep: - await asyncio.sleep(sleep) - batch = await receiver.receive(timeout=10) - while batch: - count += 1 - if count >= 10: - break - messages += len(batch) - batch = await receiver.receive(timeout=10) - return messages + senders[0].send(EventData(b"Inclusive")) + client = EventHubConsumerClient.from_connection_string(connection_str) + async with client: + task = asyncio.ensure_future(client.receive(on_event, "$default", + initial_event_position="-1", + track_last_enqueued_event_properties=True)) + await asyncio.sleep(10) + assert on_event.event_position is not None + task.cancel() + senders[0].send(EventData(expected_result)) + client2 = EventHubConsumerClient.from_connection_string(connection_str) + async with client2: + task = asyncio.ensure_future( + client2.receive(on_event, "$default", + initial_event_position= EventPosition(on_event.event_position, inclusive), + track_last_enqueued_event_properties=True)) + await asyncio.sleep(10) + assert on_event.event.body_as_str() == expected_result + task.cancel() @pytest.mark.liveTest @pytest.mark.asyncio -async def test_exclusive_receiver_async(connstr_senders): - connection_str, senders = connstr_senders - 
senders[0].send(EventData(b"Receiving only a single event")) +async def test_receive_owner_level_async(connstr_senders): + app_prop = {"raw_prop": "raw_value"} - client = EventHubClient.from_connection_string(connection_str) - receiver1 = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1"), owner_level=10, prefetch=5) - receiver2 = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1"), owner_level=20, prefetch=10) - try: - await pump(receiver1) - output2 = await pump(receiver2) - with pytest.raises(ConnectionLostError): - await receiver1.receive(timeout=3) - assert output2 == 1 - finally: - await receiver1.close() - await receiver2.close() - await client.close() + async def on_event(partition_context, event): + pass + async def on_error(partition_context, error): + on_error.error = error - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_multiple_receiver_async(connstr_senders): + on_error.error = None connection_str, senders = connstr_senders - senders[0].send(EventData(b"Receiving only a single event")) - - client = EventHubClient.from_connection_string(connection_str) - partitions = await client.get_properties() - assert partitions["partition_ids"] == ["0", "1"] - receivers = [] - for i in range(2): - receivers.append(client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1"), prefetch=10)) - try: - more_partitions = await client.get_properties() - assert more_partitions["partition_ids"] == ["0", "1"] - outputs = [0, 0] - outputs[0] = await pump(receivers[0]) - outputs[1] = await pump(receivers[1]) - assert isinstance(outputs[0], int) and outputs[0] == 1 - assert isinstance(outputs[1], int) and outputs[1] == 1 - finally: - for r in receivers: - await r.close() - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_exclusive_receiver_after_non_exclusive_receiver_async(connstr_senders): - connection_str, senders = connstr_senders - senders[0].send(EventData(b"Receiving only a single event")) - - client = EventHubClient.from_connection_string(connection_str) - receiver1 = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1"), prefetch=10) - receiver2 = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1"), owner_level=15, prefetch=10) - try: - await pump(receiver1) - output2 = await pump(receiver2) - with pytest.raises(ConnectionLostError): - await receiver1.receive(timeout=3) - assert output2 == 1 - finally: - await receiver1.close() - await receiver2.close() - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_non_exclusive_receiver_after_exclusive_receiver_async(connstr_senders): - connection_str, senders = connstr_senders - senders[0].send(EventData(b"Receiving only a single event")) - - client = EventHubClient.from_connection_string(connection_str) - receiver1 = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1"), owner_level=15, prefetch=10) - receiver2 = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1"), prefetch=10) - try: - output1 = await pump(receiver1) - with pytest.raises(ConnectionLostError): - await pump(receiver2) - assert output1 == 1 - finally: - await receiver1.close() - await receiver2.close() - await client.close() - - -@pytest.mark.liveTest 
-@pytest.mark.asyncio -async def test_receive_batch_with_app_prop_async(connstr_senders): - connection_str, senders = connstr_senders - app_prop_key = "raw_prop" - app_prop_value = "raw_value" - app_prop = {app_prop_key: app_prop_value} - - def batched(): - for i in range(10): - ed = EventData("Event Data {}".format(i)) - ed.application_properties = app_prop - yield ed - for i in range(10, 20): - ed = EventData("Event Data {}".format(i)) - ed.application_properties = app_prop - yield ed - - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest'), prefetch=500) - async with receiver: - received = await receiver.receive(timeout=5) - assert len(received) == 0 - - senders[0].send(batched()) - - await asyncio.sleep(1) - - received = await receiver.receive(max_batch_size=15, timeout=5) - assert len(received) == 15 - - for index, message in enumerate(received): - assert list(message.body)[0] == "Event Data {}".format(index).encode('utf-8') - assert (app_prop_key.encode('utf-8') in message.application_properties) \ - and (dict(message.application_properties)[app_prop_key.encode('utf-8')] == app_prop_value.encode('utf-8')) - await client.close() + client1 = EventHubConsumerClient.from_connection_string(connection_str) + client2 = EventHubConsumerClient.from_connection_string(connection_str) + async with client1, client2: + task1 = asyncio.ensure_future(client1.receive(on_event, "$default", + partition_id="0", initial_event_position="-1", + on_error=on_error)) + event_list = [] + for i in range(5): + ed = EventData("Event Number {}".format(i)) + event_list.append(ed) + senders[0].send(event_list) + await asyncio.sleep(10) + task2 = asyncio.ensure_future(client2.receive(on_event, "$default", + partition_id="0", initial_event_position="-1", + owner_level=1)) + event_list = [] + for i in range(5): + ed = EventData("Event Number {}".format(i)) + event_list.append(ed) + senders[0].send(event_list) + await asyncio.sleep(10) + task1.cancel() + task2.cancel() + assert isinstance(on_error.error, EventHubError) @pytest.mark.liveTest @pytest.mark.asyncio async def test_receive_over_websocket_async(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest'), prefetch=500) - - event_list = [] - for i in range(20): - event_list.append(EventData("Event Number {}".format(i))) - - async with receiver: - received = await receiver.receive(timeout=5) - assert len(received) == 0 - - senders[0].send(event_list) + app_prop = {"raw_prop": "raw_value"} - time.sleep(1) + async def on_event(partition_context, event): + on_event.received.append(event) + on_event.app_prop = event.application_properties - received = await receiver.receive(max_batch_size=50, timeout=5) - assert len(received) == 20 - await client.close() - - -@pytest.mark.asyncio -@pytest.mark.liveTest -async def test_receive_run_time_metric_async(connstr_senders): - from uamqp import __version__ as uamqp_version - from distutils.version import StrictVersion - if StrictVersion(uamqp_version) < StrictVersion('1.2.3'): - pytest.skip("Disabled for uamqp 1.2.2. 
Will enable after uamqp 1.2.3 is released.") + on_event.received = [] + on_event.app_prop = None connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", - event_position=EventPosition('@latest'), prefetch=500, - track_last_enqueued_event_properties=True) + client = EventHubConsumerClient.from_connection_string(connection_str, + transport_type=TransportType.AmqpOverWebsocket) event_list = [] - for i in range(20): - event_list.append(EventData("Event Number {}".format(i))) - - async with receiver: - received = await receiver.receive(timeout=5) - assert len(received) == 0 - - senders[0].send(event_list) - - await asyncio.sleep(1) - - received = await receiver.receive(max_batch_size=50, timeout=5) - assert len(received) == 20 - assert receiver.last_enqueued_event_properties - assert receiver.last_enqueued_event_properties.get('sequence_number', None) - assert receiver.last_enqueued_event_properties.get('offset', None) - assert receiver.last_enqueued_event_properties.get('enqueued_time', None) - assert receiver.last_enqueued_event_properties.get('retrieval_time', None) - await client.close() + for i in range(5): + ed = EventData("Event Number {}".format(i)) + ed.application_properties = app_prop + event_list.append(ed) + senders[0].send(event_list) + + async with client: + task = asyncio.ensure_future(client.receive(on_event, "$default", + partition_id="0", initial_event_position="-1")) + await asyncio.sleep(10) + task.cancel() + assert len(on_event.received) == 5 + for ed in on_event.received: + assert ed.application_properties[b"raw_prop"] == b"raw_value" diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_receiver_iterator_async.py b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_receiver_iterator_async.py deleted file mode 100644 index e10fc60db09c..000000000000 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_receiver_iterator_async.py +++ /dev/null @@ -1,28 +0,0 @@ -#------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. 
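`test_receive_over_websocket_async` above only changes how the connection is established; the callback-based receive flow is untouched. A short sketch of the transport switch, reusing the same placeholder connection string:

```python
from azure.eventhub import TransportType
from azure.eventhub.aio import EventHubConsumerClient

CONN_STR = "<< CONNECTION STRING WITH EntityPath >>"  # placeholder

# AMQP over WebSocket rides on port 443 instead of the default AMQP port 5671,
# which helps behind restrictive firewalls and proxies.
client = EventHubConsumerClient.from_connection_string(
    CONN_STR, transport_type=TransportType.AmqpOverWebsocket)
```

The deletion of `test_receiver_iterator_async.py` just below is consistent with the larger API shift in this changeset: iterator-style receivers are gone, and consuming happens through the `on_event` callback.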
-#-------------------------------------------------------------------------- - -import pytest - -from azure.eventhub import EventData, EventPosition, EventHubError, TransportType -from azure.eventhub.aio.client_async import EventHubClient - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_receive_iterator_async(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - async with receiver: - received = await receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Receiving only a single event")) - async for item in receiver: - received.append(item) - break - assert len(received) == 1 - assert list(received[-1].body)[0] == b"Receiving only a single event" - await client.close() diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_reconnect_async.py b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_reconnect_async.py index b260af3c6b07..9f850f2080c9 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_reconnect_async.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_reconnect_async.py @@ -8,32 +8,45 @@ import asyncio import pytest -from azure.eventhub import EventData -from azure.eventhub.aio.client_async import EventHubClient +from azure.eventhub import EventData, EventHubSharedKeyCredential +from azure.eventhub.aio import EventHubProducerClient + +import uamqp +from uamqp import authentication @pytest.mark.liveTest @pytest.mark.asyncio -async def test_send_with_long_interval_async(connstr_receivers, sleep): - connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() - try: +async def test_send_with_long_interval_async(live_eventhub, sleep): + sender = EventHubProducerClient(live_eventhub['hostname'], live_eventhub['event_hub'], + EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) + async with sender: await sender.send(EventData(b"A single event")) for _ in range(1): if sleep: await asyncio.sleep(300) else: - sender._handler._connection._conn.destroy() + await sender._producers[-1]._handler._connection._conn.destroy() await sender.send(EventData(b"A single event")) - finally: - await sender.close() + partition_ids = await sender.get_partition_ids() received = [] - for r in receivers: - if not sleep: # if sender sleeps, the receivers will be disconnected. 
destroy connection to simulate - r._handler._connection._conn.destroy() - received.extend(r.receive(timeout=5)) + for p in partition_ids: + uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub']) + sas_auth = authentication.SASTokenAuth.from_shared_access_key( + uri, live_eventhub['key_name'], live_eventhub['access_key']) + + source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format( + live_eventhub['hostname'], + live_eventhub['event_hub'], + live_eventhub['consumer_group'], + p) + receiver = uamqp.ReceiveClient(source, auth=sas_auth, debug=False, timeout=5000, prefetch=500) + try: + receiver.open() + received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(timeout=5000)]) + finally: + receiver.close() + assert len(received) == 2 assert list(received[0].body)[0] == b"A single event" - await client.close() diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_send_async.py b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_send_async.py index a4ca028f4f16..8f6f75e67c54 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_send_async.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/asynctests/test_send_async.py @@ -12,276 +12,143 @@ import json from azure.eventhub import EventData, TransportType -from azure.eventhub.aio.client_async import EventHubClient +from azure.eventhub.aio import EventHubProducerClient @pytest.mark.liveTest @pytest.mark.asyncio async def test_send_with_partition_key_async(connstr_receivers): connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() - - async with sender: + client = EventHubProducerClient.from_connection_string(connection_str) + async with client: data_val = 0 for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]: partition_key = b"test_partition_" + partition for i in range(50): data = EventData(str(data_val)) - # data.partition_key = partition_key data_val += 1 - await sender.send(data, partition_key=partition_key) + await client.send(data, partition_key=partition_key) found_partition_keys = {} for index, partition in enumerate(receivers): - received = partition.receive(timeout=5) + received = partition.receive_message_batch(timeout=5000) for message in received: try: - existing = found_partition_keys[message.partition_key] + event_data = EventData._from_message(message) + existing = found_partition_keys[event_data.partition_key] assert existing == index except KeyError: - found_partition_keys[message.partition_key] = index - await client.close() + found_partition_keys[event_data.partition_key] = index +@pytest.mark.parametrize("payload", [b"", b"A single event"]) @pytest.mark.liveTest @pytest.mark.asyncio -async def test_send_and_receive_zero_length_body_async(connstr_receivers): +async def test_send_and_receive_small_body_async(connstr_receivers, payload): connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() - async with sender: - await sender.send(EventData("")) - + client = EventHubProducerClient.from_connection_string(connection_str) + async with client: + await client.send(EventData(payload)) received = [] for r in receivers: - received.extend(r.receive(timeout=1)) + received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=5000)]) assert len(received) == 1 - assert list(received[0].body)[0] == b"" - await client.close() - - 
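The rewritten send tests route events three ways through the single `EventHubProducerClient.send` call: by `partition_key`, by explicit `partition_id`, or by neither. A sketch, again assuming a placeholder connection string with `EntityPath` and a hypothetical `send_three_ways` helper:

```python
from azure.eventhub import EventData
from azure.eventhub.aio import EventHubProducerClient


async def send_three_ways(conn_str):
    client = EventHubProducerClient.from_connection_string(conn_str)
    async with client:
        # Events sharing a partition key land on the same (service-chosen) partition.
        await client.send(EventData(b"keyed"), partition_key=b"test_partition_a")
        # Or pin the event to a specific partition.
        await client.send(EventData(b"pinned"), partition_id="1")
        # With neither, the service distributes events across partitions.
        await client.send(EventData(b"balanced"))
```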
-@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_send_single_event_async(connstr_receivers): - connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() - async with sender: - await sender.send(EventData(b"A single event")) - - received = [] - for r in receivers: - received.extend(r.receive(timeout=1)) - - assert len(received) == 1 - assert list(received[0].body)[0] == b"A single event" - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_send_batch_async(connstr_receivers): - connection_str, receivers = connstr_receivers - - def batched(): - for i in range(10): - yield EventData("Event number {}".format(i)) - - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() - async with sender: - await sender.send(batched()) - - time.sleep(1) - received = [] - for r in receivers: - received.extend(r.receive(timeout=3)) - - assert len(received) == 10 - for index, message in enumerate(received): - assert list(message.body)[0] == "Event number {}".format(index).encode('utf-8') - await client.close() + assert list(received[0].body)[0] == payload @pytest.mark.liveTest @pytest.mark.asyncio async def test_send_partition_async(connstr_receivers): connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer(partition_id="1") - async with sender: - await sender.send(EventData(b"Data")) + client = EventHubProducerClient.from_connection_string(connection_str) + async with client: + await client.send(EventData(b"Data"), partition_id="1") - partition_0 = receivers[0].receive(timeout=2) + partition_0 = receivers[0].receive_message_batch(timeout=5000) assert len(partition_0) == 0 - partition_1 = receivers[1].receive(timeout=2) + partition_1 = receivers[1].receive_message_batch(timeout=5000) assert len(partition_1) == 1 - await client.close() @pytest.mark.liveTest @pytest.mark.asyncio async def test_send_non_ascii_async(connstr_receivers): connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer(partition_id="0") - async with sender: - await sender.send(EventData("é,è,à,ù,â,ê,î,ô,û")) - await sender.send(EventData(json.dumps({"foo": "漢字"}))) + client = EventHubProducerClient.from_connection_string(connection_str) + async with client: + await client.send(EventData(u"é,è,à,ù,â,ê,î,ô,û"), partition_id="0") + await client.send(EventData(json.dumps({"foo": u"漢字"})), partition_id="0") await asyncio.sleep(1) - partition_0 = receivers[0].receive(timeout=2) + partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] assert len(partition_0) == 2 - assert partition_0[0].body_as_str() == "é,è,à,ù,â,ê,î,ô,û" - assert partition_0[1].body_as_json() == {"foo": "漢字"} - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_send_partition_batch_async(connstr_receivers): - connection_str, receivers = connstr_receivers - - def batched(): - for i in range(10): - yield EventData("Event number {}".format(i)) - - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer(partition_id="1") - async with sender: - await sender.send(batched()) - - partition_0 = receivers[0].receive(timeout=2) - assert len(partition_0) == 0 - partition_1 = receivers[1].receive(timeout=2) - assert 
len(partition_1) == 10 - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_send_array_async(connstr_receivers): - connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() - async with sender: - await sender.send(EventData([b"A", b"B", b"C"])) - - received = [] - for r in receivers: - received.extend(r.receive(timeout=1)) - - assert len(received) == 1 - assert list(received[0].body) == [b"A", b"B", b"C"] - await client.close() + assert partition_0[0].body_as_str() == u"é,è,à,ù,â,ê,î,ô,û" + assert partition_0[1].body_as_json() == {"foo": u"漢字"} @pytest.mark.liveTest @pytest.mark.asyncio -async def test_send_multiple_clients_async(connstr_receivers): - connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str) - sender_0 = client._create_producer(partition_id="0") - sender_1 = client._create_producer(partition_id="1") - async with sender_0: - await sender_0.send(EventData(b"Message 0")) - async with sender_1: - await sender_1.send(EventData(b"Message 1")) - - partition_0 = receivers[0].receive(timeout=2) - assert len(partition_0) == 1 - partition_1 = receivers[1].receive(timeout=2) - assert len(partition_1) == 1 - await client.close() - - -@pytest.mark.liveTest -@pytest.mark.asyncio -async def test_send_batch_with_app_prop_async(connstr_receivers): +async def test_send_multiple_partition_with_app_prop_async(connstr_receivers): connection_str, receivers = connstr_receivers app_prop_key = "raw_prop" app_prop_value = "raw_value" app_prop = {app_prop_key: app_prop_value} - - def batched(): - for i in range(10): - ed = EventData("Event number {}".format(i)) - ed.application_properties = app_prop - yield ed - for i in range(10, 20): - ed = EventData("Event number {}".format(i)) - ed.application_properties = app_prop - yield ed - - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() - async with sender: - await sender.send(batched()) - - time.sleep(1) - - received = [] - for r in receivers: - received.extend(r.receive(timeout=3)) - - assert len(received) == 20 - for index, message in enumerate(received): - assert list(message.body)[0] == "Event number {}".format(index).encode('utf-8') - assert (app_prop_key.encode('utf-8') in message.application_properties) \ - and (dict(message.application_properties)[app_prop_key.encode('utf-8')] == app_prop_value.encode('utf-8')) - await client.close() + client = EventHubProducerClient.from_connection_string(connection_str) + async with client: + ed0 = EventData(b"Message 0") + ed0.application_properties = app_prop + await client.send(ed0, partition_id="0") + ed1 = EventData(b"Message 1") + ed1.application_properties = app_prop + await client.send(ed1, partition_id="1") + + partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] + assert len(partition_0) == 1 + assert partition_0[0].application_properties[b"raw_prop"] == b"raw_value" + partition_1 = [EventData._from_message(x) for x in receivers[1].receive_message_batch(timeout=5000)] + assert len(partition_1) == 1 + assert partition_1[0].application_properties[b"raw_prop"] == b"raw_value" @pytest.mark.liveTest @pytest.mark.asyncio async def test_send_over_websocket_async(connstr_receivers): connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str, 
transport_type=TransportType.AmqpOverWebsocket) - sender = client._create_producer() - - event_list = [] - for i in range(20): - event_list.append(EventData("Event Number {}".format(i))) + client = EventHubProducerClient.from_connection_string(connection_str, + transport_type=TransportType.AmqpOverWebsocket) - async with sender: - await sender.send(event_list) + async with client: + for i in range(20): + await client.send(EventData("Event Number {}".format(i))) time.sleep(1) received = [] for r in receivers: - received.extend(r.receive(timeout=3)) - + received.extend(r.receive_message_batch(timeout=5000)) assert len(received) == 20 - for r in receivers: - r.close() - await client.close() - @pytest.mark.liveTest @pytest.mark.asyncio async def test_send_with_create_event_batch_async(connstr_receivers): connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) - sender = client._create_producer() - - event_data_batch = await sender.create_batch(max_size=100000) - while True: - try: - event_data_batch.try_add(EventData('A single event data')) - except ValueError: - break - - await sender.send(event_data_batch) - - event_data_batch = await sender.create_batch(max_size=100000) - while True: - try: - event_data_batch.try_add(EventData('A single event data')) - except ValueError: - break - - await sender.send(event_data_batch) - await sender.close() - await client.close() + app_prop_key = "raw_prop" + app_prop_value = "raw_value" + app_prop = {app_prop_key: app_prop_value} + client = EventHubProducerClient.from_connection_string(connection_str, + transport_type=TransportType.AmqpOverWebsocket) + async with client: + event_data_batch = await client.create_batch(max_size=100000) + while True: + try: + ed = EventData('A single event data') + ed.application_properties = app_prop + event_data_batch.try_add(ed) + except ValueError: + break + await client.send(event_data_batch) + received = [] + for r in receivers: + received.extend(r.receive_message_batch(timeout=5000)) + assert len(received) > 1 + assert EventData._from_message(received[0]).application_properties[b"raw_prop"] == b"raw_value" diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_auth.py b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_auth.py index eb2d028c3964..ef9094c15703 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_auth.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_auth.py @@ -14,7 +14,7 @@ def test_client_secret_credential(aad_credential, live_eventhub): try: from azure.identity import EnvironmentCredential - except ImportError: + except ImportError as e: pytest.skip("No azure identity library") credential = EnvironmentCredential() producer_client = EventHubProducerClient(host=live_eventhub['hostname'], @@ -26,15 +26,21 @@ def test_client_secret_credential(aad_credential, live_eventhub): credential=credential, user_agent='customized information') with producer_client: - producer_client.send(EventData(body='A single message')) + producer_client.send(EventData(body='A single message'), partition_id="0") - def on_events(partition_context, events): - assert partition_context.partition_id == '0' - assert len(events) == 1 - assert list(events[0].body)[0] == 'A single message'.encode('utf-8') + def on_event(partition_context, event): + on_event.called = True + on_event.partition_id = partition_context.partition_id + on_event.event = event + on_event.called = 
False with consumer_client: - worker = threading.Thread(target=consumer_client.receive, args=(on_events,), + worker = threading.Thread(target=consumer_client.receive, args=(on_event,), kwargs={"consumer_group": '$default', "partition_id": '0'}) worker.start() - time.sleep(2) + time.sleep(6) + + worker.join() + assert on_event.called is True + assert on_event.partition_id == "0" + assert list(on_event.event.body)[0] == 'A single message'.encode('utf-8') diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_consumer_client.py b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_consumer_client.py index ee007256fb72..fc790f1c2eff 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_consumer_client.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_consumer_client.py @@ -5,6 +5,7 @@ from azure.eventhub import EventData from azure.eventhub import EventHubConsumerClient from azure.eventhub._eventprocessor.local_partition_manager import InMemoryPartitionManager +from azure.eventhub._constants import ALL_PARTITIONS @pytest.mark.liveTest @@ -14,18 +15,17 @@ def test_receive_no_partition(connstr_senders): senders[1].send(EventData("Test EventData")) client = EventHubConsumerClient.from_connection_string(connection_str, receive_timeout=1) - recv_cnt = {"received": 0} # substitution for nonlocal variable, 2.7 compatible - - def on_events(partition_context, events): - recv_cnt["received"] += len(events) + def on_event(partition_context, event): + on_event.received += 1 + on_event.received = 0 with client: worker = threading.Thread(target=client.receive, - args=(on_events,), + args=(on_event,), kwargs={"consumer_group": "$default", "initial_event_position": "-1"}) worker.start() time.sleep(10) - assert recv_cnt["received"] == 2 + assert on_event.received == 2 @pytest.mark.liveTest @@ -34,23 +34,26 @@ def test_receive_partition(connstr_senders): senders[0].send(EventData("Test EventData")) client = EventHubConsumerClient.from_connection_string(connection_str) - recv_cnt = {"received": 0} # substitution for nonlocal variable, 2.7 compatible - - def on_events(partition_context, events): - recv_cnt["received"] += len(events) - assert partition_context.partition_id == "0" - assert partition_context.consumer_group_name == "$default" - assert partition_context.fully_qualified_namespace in connection_str - assert partition_context.eventhub_name == senders[0]._client.eh_name + def on_event(partition_context, event): + on_event.received += 1 + on_event.partition_id = partition_context.partition_id + on_event.consumer_group_name = partition_context.consumer_group_name + on_event.fully_qualified_namespace = partition_context.fully_qualified_namespace + on_event.eventhub_name = partition_context.eventhub_name + on_event.received = 0 with client: worker = threading.Thread(target=client.receive, - args=(on_events,), + args=(on_event,), kwargs={"consumer_group": "$default", "initial_event_position": "-1", "partition_id": "0"}) worker.start() time.sleep(10) - assert recv_cnt["received"] == 1 + assert on_event.received == 1 + assert on_event.partition_id == "0" + assert on_event.consumer_group_name == "$default" + assert on_event.fully_qualified_namespace in connection_str + assert on_event.eventhub_name == senders[0]._client.eh_name @pytest.mark.liveTest @@ -65,20 +68,20 @@ def test_receive_load_balancing(connstr_senders): client2 = EventHubConsumerClient.from_connection_string( connection_str, partition_manager=pm, load_balancing_interval=1) - def 
on_events(partition_context, events): + def on_event(partition_context, event): pass with client1, client2: worker1 = threading.Thread(target=client1.receive, - args=(on_events,), + args=(on_event,), kwargs={"consumer_group": "$default", "initial_event_position": "-1"}) worker2 = threading.Thread(target=client2.receive, - args=(on_events,), + args=(on_event,), kwargs={"consumer_group": "$default", "initial_event_position": "-1"}) worker1.start() worker2.start() - time.sleep(20) - assert len(client1._event_processors[("$default", "-1")]._working_threads) == 1 - assert len(client2._event_processors[("$default", "-1")]._working_threads) == 1 + time.sleep(10) + assert len(client1._event_processors[("$default", ALL_PARTITIONS)]._consumers) == 1 + assert len(client2._event_processors[("$default", ALL_PARTITIONS)]._consumers) == 1 diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_eventprocessor.py b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_eventprocessor.py index 1927a4ed865a..74b9acd12871 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_eventprocessor.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_eventprocessor.py @@ -9,12 +9,10 @@ import time from azure.eventhub import EventData, EventHubError -from azure.eventhub.client import EventHubClient from azure.eventhub._eventprocessor.event_processor import EventProcessor from azure.eventhub import CloseReason from azure.eventhub._eventprocessor.local_partition_manager import InMemoryPartitionManager -from azure.eventhub._eventprocessor.common import OwnershipLostError -from azure.eventhub.common import _Address +from azure.eventhub._client_base import _Address def event_handler(partition_context, events): @@ -30,15 +28,21 @@ def __init__(self): self._address = _Address(hostname="test", path=MockEventHubClient.eh_name) def _create_consumer(self, consumer_group_name, partition_id, event_position, **kwargs): - return MockEventhubConsumer() + consumer = MockEventhubConsumer(**kwargs) + return consumer def get_partition_ids(self): return ["0", "1"] class MockEventhubConsumer(object): + def __init__(self, **kwargs): + self.stop = False + self._on_event_received = kwargs.get("on_event_received") + def receive(self): time.sleep(0.1) - return [] + self._on_event_received(EventData("")) + def close(self): pass @@ -57,7 +61,7 @@ def close(self): threads.append(thread1) time.sleep(2) - ep1_after_start = len(event_processor1._working_threads) + ep1_after_start = len(event_processor1._consumers) event_processor2 = EventProcessor(eventhub_client=eventhub_client, consumer_group_name='$default', partition_manager=partition_manager, @@ -69,12 +73,12 @@ def close(self): thread2.start() threads.append(thread2) time.sleep(10) - ep2_after_start = len(event_processor2._working_threads) + ep2_after_start = len(event_processor2._consumers) event_processor1.stop() thread1.join() time.sleep(10) - ep2_after_ep1_stopped = len(event_processor2._working_threads) + ep2_after_ep1_stopped = len(event_processor2._consumers) event_processor2.stop() assert ep1_after_start == 2 @@ -94,15 +98,18 @@ def __init__(self): self._address = _Address(hostname="test", path=MockEventHubClient.eh_name) def _create_consumer(self, consumer_group_name, partition_id, event_position, **kwargs): - return MockEventhubConsumer() + return MockEventhubConsumer(**kwargs) def get_partition_ids(self): return ["0", "1"] class MockEventhubConsumer(object): + def __init__(self, **kwargs): + self.stop = False + 
self._on_event_received = kwargs.get("on_event_received") + def receive(self): - time.sleep(0.5) - return [] + time.sleep(0.1) def close(self): pass @@ -120,7 +127,7 @@ def close(self): thread.start() time.sleep(2) event_processor_running = event_processor._running - event_processor_partitions = len(event_processor._working_threads) + event_processor_partitions = len(event_processor._consumers) event_processor.stop() thread.join() assert event_processor_running is True @@ -132,12 +139,12 @@ def test_partition_processor(): event_map = {} def partition_initialize_handler(partition_context): - assert_map["initialize"] = "called" assert partition_context + assert_map["initialize"] = "called" - def event_handler(partition_context, events): - event_map[partition_context.partition_id] = event_map.get(partition_context.partition_id, 0) + len(events) - partition_context.update_checkpoint(events[-1]) + def event_handler(partition_context, event): + event_map[partition_context.partition_id] = event_map.get(partition_context.partition_id, 0) + 1 + partition_context.update_checkpoint(event) assert_map["checkpoint"] = "checkpoint called" def partition_close_handler(partition_context, reason): @@ -153,15 +160,19 @@ def __init__(self): self._address = _Address(hostname="test", path=MockEventHubClient.eh_name) def _create_consumer(self, consumer_group_name, partition_id, event_position, **kwargs): - return MockEventhubConsumer() + return MockEventhubConsumer(**kwargs) def get_partition_ids(self): return ["0", "1"] class MockEventhubConsumer(object): + def __init__(self, **kwargs): + self.stop = False + self._on_event_received = kwargs.get("on_event_received") + def receive(self): time.sleep(0.5) - return [EventData("test data")] + self._on_event_received(EventData("test data")) def close(self): pass @@ -182,7 +193,7 @@ def close(self): thread = threading.Thread(target=event_processor.start) thread.start() time.sleep(2) - ep_partitions = len(event_processor._working_threads) + ep_partitions = len(event_processor._consumers) event_processor.stop() time.sleep(2) assert ep_partitions == 2 @@ -195,7 +206,7 @@ def close(self): def test_partition_processor_process_events_error(): assert_result = {} - def event_handler(partition_context, events): + def event_handler(partition_context, event): if partition_context.partition_id == "1": raise RuntimeError("processing events error") else: @@ -203,15 +214,12 @@ def event_handler(partition_context, events): def error_handler(partition_context, error): if partition_context.partition_id == "1": - assert_result["error"] = "runtime error" + assert_result["error"] = error else: assert_result["error"] = "not an error" def partition_close_handler(partition_context, reason): - if partition_context.partition_id == "1": - assert reason == CloseReason.OWNERSHIP_LOST - else: - assert reason == CloseReason.SHUTDOWN + pass class MockEventHubClient(object): eh_name = "test_eh_name" @@ -220,15 +228,19 @@ def __init__(self): self._address = _Address(hostname="test", path=MockEventHubClient.eh_name) def _create_consumer(self, consumer_group_name, partition_id, event_position, **kwargs): - return MockEventhubConsumer() + return MockEventhubConsumer(**kwargs) def get_partition_ids(self): return ["0", "1"] class MockEventhubConsumer(object): + def __init__(self, **kwargs): + self.stop = False + self._on_event_received = kwargs.get("on_event_received") + def receive(self): time.sleep(0.5) - return [EventData("test data")] + self._on_event_received(EventData("test data")) def close(self): 
pass @@ -248,7 +260,7 @@ def close(self): time.sleep(2) event_processor.stop() thread.join() - assert assert_result["error"] == "runtime error" + assert isinstance(assert_result["error"], RuntimeError) def test_partition_processor_process_eventhub_consumer_error(): @@ -269,12 +281,16 @@ def __init__(self): self._address = _Address(hostname="test", path=MockEventHubClient.eh_name) def _create_consumer(self, consumer_group_name, partition_id, event_position, **kwargs): - return MockEventhubConsumer() + return MockEventhubConsumer(**kwargs) def get_partition_ids(self): return ["0", "1"] class MockEventhubConsumer(object): + def __init__(self, **kwargs): + self.stop = False + self._on_event_received = kwargs.get("on_event_received") + def receive(self): time.sleep(0.5) raise EventHubError("Mock EventHubConsumer EventHubError") @@ -301,18 +317,23 @@ def close(self): def test_partition_processor_process_error_close_error(): + def partition_initialize_handler(partition_context): + partition_initialize_handler.called = True raise RuntimeError("initialize error") - def event_handler(partition_context, events): + def event_handler(partition_context, event): + event_handler.called = True raise RuntimeError("process_events error") def error_handler(partition_context, error): assert isinstance(error, RuntimeError) + error_handler.called = True raise RuntimeError("process_error error") def partition_close_handler(partition_context, reason): - assert reason == CloseReason.OWNERSHIP_LOST + assert reason == CloseReason.SHUTDOWN + partition_close_handler.called = True raise RuntimeError("close error") class MockEventHubClient(object): @@ -322,15 +343,20 @@ def __init__(self): self._address = _Address(hostname="test", path=MockEventHubClient.eh_name) def _create_consumer(self, consumer_group_name, partition_id, event_position, **kwargs): - return MockEventhubConsumer() + return MockEventhubConsumer(**kwargs) def get_partition_ids(self): return ["0", "1"] class MockEventhubConsumer(object): + def __init__(self, **kwargs): + self.stop = False + self._on_event_received = kwargs.get("on_event_received") + def receive(self): time.sleep(0.5) - return [EventData("mock events")] + self._on_event_received(EventData("test data")) + def close(self): pass @@ -351,6 +377,11 @@ def close(self): event_processor.stop() thread.join() + assert partition_initialize_handler.called + assert event_handler.called + assert error_handler.called + assert partition_close_handler.called + def test_partition_processor_process_update_checkpoint_error(): assert_map = {} @@ -361,9 +392,9 @@ def update_checkpoint( if partition_id == "1": raise ValueError("Mocked error") - def event_handler(partition_context, events): - if events: - partition_context.update_checkpoint(events[-1]) + def event_handler(partition_context, event): + if event: + partition_context.update_checkpoint(event) def error_handler(partition_context, error): assert_map["error"] = error @@ -378,15 +409,19 @@ def __init__(self): self._address = _Address(hostname="test", path=MockEventHubClient.eh_name) def _create_consumer(self, consumer_group_name, partition_id, event_position, **kwargs): - return MockEventhubConsumer() + return MockEventhubConsumer(**kwargs) def get_partition_ids(self): return ["0", "1"] class MockEventhubConsumer(object): + def __init__(self, **kwargs): + self.stop = False + self._on_event_received = kwargs.get("on_event_received") + def receive(self): time.sleep(0.5) - return [EventData("test data")] + self._on_event_received(EventData("test data")) def 
close(self): pass diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_negative.py b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_negative.py index 21e434c0b9be..4f849f5bf33a 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_negative.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_negative.py @@ -7,6 +7,7 @@ import pytest import time import sys +import threading from azure.eventhub import ( EventData, @@ -15,130 +16,59 @@ ConnectError, EventDataSendError) -from azure.eventhub.client import EventHubClient +from azure.eventhub import EventHubConsumerClient +from azure.eventhub import EventHubProducerClient @pytest.mark.liveTest def test_send_with_invalid_hostname(invalid_hostname): - client = EventHubClient.from_connection_string(invalid_hostname) - sender = client._create_producer() - with pytest.raises(AuthenticationError): - sender.send(EventData("test data")) - sender.close() - client.close() + if sys.platform.startswith('darwin'): + pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " + "and blocking other tests") + client = EventHubProducerClient.from_connection_string(invalid_hostname) + with client: + with pytest.raises(ConnectError): + client.send(EventData("test data")) @pytest.mark.liveTest def test_receive_with_invalid_hostname_sync(invalid_hostname): - client = EventHubClient.from_connection_string(invalid_hostname) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1")) - with pytest.raises(AuthenticationError): - receiver.receive(timeout=5) - receiver.close() - client.close() - - -@pytest.mark.liveTest -def test_send_with_invalid_key(invalid_key): - client = EventHubClient.from_connection_string(invalid_key) - sender = client._create_producer() - with pytest.raises(AuthenticationError): - sender.send(EventData("test data")) - sender.close() - client.close() - - -@pytest.mark.liveTest -def test_receive_with_invalid_key_sync(invalid_key): - client = EventHubClient.from_connection_string(invalid_key) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1")) - - with pytest.raises(AuthenticationError): - receiver.receive(timeout=10) - receiver.close() + def on_event(partition_context, event): + pass + + client = EventHubConsumerClient.from_connection_string(invalid_hostname) + with client: + thread = threading.Thread(target=client.receive, + args=(on_event, ), + kwargs={"consumer_group": '$default'}) + thread.start() + time.sleep(2) + assert len(client._event_processors) == 1 + thread.join() + + +@pytest.mark.liveTest +def test_send_with_invalid_key(live_eventhub): + conn_str = live_eventhub["connection_str"].format( + live_eventhub['hostname'], + live_eventhub['key_name'], + 'invalid', + live_eventhub['event_hub']) + client = EventHubProducerClient.from_connection_string(conn_str) + with pytest.raises(ConnectError): + client.send(EventData("test data")) client.close() -@pytest.mark.liveTest -def test_send_with_invalid_policy(invalid_policy): - client = EventHubClient.from_connection_string(invalid_policy) - sender = client._create_producer() - with pytest.raises(AuthenticationError): - sender.send(EventData("test data")) - sender.close() - client.close() - - -@pytest.mark.liveTest -def test_receive_with_invalid_policy_sync(invalid_policy): - client = EventHubClient.from_connection_string(invalid_policy) - receiver = 
client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1")) - with pytest.raises(AuthenticationError): - receiver.receive(timeout=5) - receiver.close() - client.close() - - -@pytest.mark.liveTest -def test_send_partition_key_with_partition_sync(connection_str): - pytest.skip("Skipped tentatively. Confirm whether to throw ValueError or just warn users") - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer(partition_id="1") - try: - data = EventData(b"Data") - data._set_partition_key(b"PKey") - with pytest.raises(ValueError): - sender.send(data) - finally: - sender.close() - client.close() - - -@pytest.mark.liveTest -def test_non_existing_entity_sender(connection_str): - client = EventHubClient.from_connection_string(connection_str, event_hub_path="nemo") - sender = client._create_producer(partition_id="1") - with pytest.raises(AuthenticationError): - sender.send(EventData("test data")) - sender.close() - client.close() - - -@pytest.mark.liveTest -def test_non_existing_entity_receiver(connection_str): - client = EventHubClient.from_connection_string(connection_str, event_hub_path="nemo") - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition("-1")) - with pytest.raises(AuthenticationError): - receiver.receive(timeout=5) - receiver.close() - client.close() - - -@pytest.mark.liveTest -def test_receive_from_invalid_partitions_sync(connection_str): - partitions = ["XYZ", "-1", "1000", "-"] - for p in partitions: - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id=p, event_position=EventPosition("-1")) - try: - with pytest.raises(ConnectError): - receiver.receive(timeout=5) - finally: - receiver.close() - client.close() - - @pytest.mark.liveTest def test_send_to_invalid_partitions(connection_str): partitions = ["XYZ", "-1", "1000", "-"] for p in partitions: - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer(partition_id=p) + client = EventHubProducerClient.from_connection_string(connection_str) try: with pytest.raises(ConnectError): - sender.send(EventData("test data")) + client.send(EventData("test data"), partition_id=p) finally: - sender.close() client.close() @@ -146,100 +76,40 @@ def test_send_to_invalid_partitions(connection_str): def test_send_too_large_message(connection_str): if sys.platform.startswith('darwin'): pytest.skip("Skipping on OSX - open issue regarding message size") - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() + client = EventHubProducerClient.from_connection_string(connection_str) try: data = EventData(b"A" * 1100000) with pytest.raises(EventDataSendError): - sender.send(data) + client.send(data) finally: - sender.close() client.close() @pytest.mark.liveTest def test_send_null_body(connection_str): - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() + client = EventHubProducerClient.from_connection_string(connection_str) try: with pytest.raises(ValueError): data = EventData(None) - sender.send(data) + client.send(data) finally: - sender.close() - client.close() - - -@pytest.mark.liveTest -def test_message_body_types(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = 
client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - try: - received = receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Bytes Data")) - time.sleep(1) - received = receiver.receive(timeout=5) - assert len(received) == 1 - assert list(received[0].body) == [b'Bytes Data'] - assert received[0].body_as_str() == "Bytes Data" - with pytest.raises(TypeError): - received[0].body_as_json() - - senders[0].send(EventData("Str Data")) - time.sleep(1) - received = receiver.receive(timeout=5) - assert len(received) == 1 - assert list(received[0].body) == [b'Str Data'] - assert received[0].body_as_str() == "Str Data" - with pytest.raises(TypeError): - received[0].body_as_json() - - senders[0].send(EventData(b'{"test_value": "JSON bytes data", "key1": true, "key2": 42}')) - time.sleep(1) - received = receiver.receive(timeout=5) - assert len(received) == 1 - assert list(received[0].body) == [b'{"test_value": "JSON bytes data", "key1": true, "key2": 42}'] - assert received[0].body_as_str() == '{"test_value": "JSON bytes data", "key1": true, "key2": 42}' - assert received[0].body_as_json() == {"test_value": "JSON bytes data", "key1": True, "key2": 42} - - senders[0].send(EventData('{"test_value": "JSON str data", "key1": true, "key2": 42}')) - time.sleep(1) - received = receiver.receive(timeout=5) - assert len(received) == 1 - assert list(received[0].body) == [b'{"test_value": "JSON str data", "key1": true, "key2": 42}'] - assert received[0].body_as_str() == '{"test_value": "JSON str data", "key1": true, "key2": 42}' - assert received[0].body_as_json() == {"test_value": "JSON str data", "key1": True, "key2": 42} - - senders[0].send(EventData(42)) - time.sleep(1) - received = receiver.receive(timeout=5) - assert len(received) == 1 - assert received[0].body_as_str() == "42" - assert received[0].body == 42 - except: - raise - finally: - receiver.close() client.close() @pytest.mark.liveTest def test_create_batch_with_invalid_hostname_sync(invalid_hostname): - client = EventHubClient.from_connection_string(invalid_hostname) - sender = client._create_producer() - with pytest.raises(AuthenticationError): - sender.create_batch(max_size=300) - sender.close() - client.close() + if sys.platform.startswith('darwin'): + pytest.skip("Skipping on OSX - it keeps reporting 'Unable to set external certificates' " + "and blocking other tests") + client = EventHubProducerClient.from_connection_string(invalid_hostname) + with client: + with pytest.raises(ConnectError): + client.create_batch(max_size=300) @pytest.mark.liveTest def test_create_batch_with_too_large_size_sync(connection_str): - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() - with pytest.raises(ValueError): - sender.create_batch(max_size=5 * 1024 * 1024) - sender.close() - client.close() + client = EventHubProducerClient.from_connection_string(connection_str) + with client: + with pytest.raises(ValueError): + client.create_batch(max_size=5 * 1024 * 1024) diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_producer_client.py b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_producer_client.py deleted file mode 100644 index 5e3a9b7ba7f4..000000000000 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_producer_client.py +++ /dev/null @@ -1,59 +0,0 @@ -import pytest -from azure.eventhub import EventData -from azure.eventhub import EventHubProducerClient - - -@pytest.mark.liveTest 
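Both the negative tests above and the batch tests elsewhere in this changeset rely on the same fill-until-full idiom: `create_batch` sizes the batch up front, and `try_add` raises `ValueError` once the next event no longer fits. A synchronous sketch under the usual placeholder-connection-string assumption, with a hypothetical `send_full_batch` helper:

```python
from azure.eventhub import EventData, EventHubProducerClient


def send_full_batch(conn_str):
    client = EventHubProducerClient.from_connection_string(conn_str)
    with client:
        batch = client.create_batch(max_size=100000)  # byte limit for the whole batch
        try:
            while True:
                batch.try_add(EventData(b"Data"))  # raises ValueError when full
        except ValueError:
            pass  # the batch is as full as it can get
        client.send(batch)  # the whole batch goes out as a single transfer
```

Asking for more than the service permits fails fast instead: `create_batch(max_size=5 * 1024 * 1024)` raises `ValueError`, which is exactly what `test_create_batch_with_too_large_size_sync` asserts.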
-def test_send_with_partition_key(connstr_receivers): - connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) - - with client: - data_val = 0 - for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]: - partition_key = b"test_partition_" + partition - for i in range(50): - data = EventData(str(data_val)) - data_val += 1 - client.send(data, partition_key=partition_key) - - found_partition_keys = {} - for index, partition in enumerate(receivers): - received = partition.receive(timeout=5) - for message in received: - try: - existing = found_partition_keys[message.partition_key] - assert existing == index - except KeyError: - found_partition_keys[message.partition_key] = index - - -@pytest.mark.liveTest -def test_send_partition(connstr_receivers): - connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) - with client: - client.send(EventData(b"Data"), partition_id="1") - - partition_0 = receivers[0].receive(timeout=2) - assert len(partition_0) == 0 - partition_1 = receivers[1].receive(timeout=2) - assert len(partition_1) == 1 - client.close() - - -@pytest.mark.liveTest -def test_send_no_partition_batch(connstr_receivers): - connection_str, receivers = connstr_receivers - client = EventHubProducerClient.from_connection_string(connection_str) - with client: - event_batch = client.create_batch() - try: - while True: - event_batch.try_add(EventData(b"Data")) - except ValueError: - client.send(event_batch) - - partition_0 = receivers[0].receive(timeout=2) - partition_1 = receivers[1].receive(timeout=2) - assert len(partition_0) + len(partition_1) > 10 diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_properties.py b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_properties.py index 9711c878afd3..bdbed7f99579 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_properties.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_properties.py @@ -6,12 +6,12 @@ import pytest from azure.eventhub import EventHubSharedKeyCredential -from azure.eventhub.client import EventHubClient +from azure.eventhub import EventHubConsumerClient @pytest.mark.liveTest def test_get_properties(live_eventhub): - client = EventHubClient(live_eventhub['hostname'], live_eventhub['event_hub'], + client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) properties = client.get_properties() assert properties['path'] == live_eventhub['event_hub'] and properties['partition_ids'] == ['0', '1'] @@ -20,7 +20,7 @@ def test_get_properties(live_eventhub): @pytest.mark.liveTest def test_get_partition_ids(live_eventhub): - client = EventHubClient(live_eventhub['hostname'], live_eventhub['event_hub'], + client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) partition_ids = client.get_partition_ids() assert partition_ids == ['0', '1'] @@ -29,7 +29,7 @@ def test_get_partition_ids(live_eventhub): @pytest.mark.liveTest def test_get_partition_properties(live_eventhub): - client = EventHubClient(live_eventhub['hostname'], live_eventhub['event_hub'], + client = EventHubConsumerClient(live_eventhub['hostname'], live_eventhub['event_hub'], EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key'])) properties 
= client.get_partition_properties('0') assert properties['event_hub_path'] == live_eventhub['event_hub'] \ diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_receive.py b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_receive.py index 6ca06d8b148b..ca3fc981213d 100644 --- a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_receive.py +++ b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_receive.py @@ -5,301 +5,149 @@ #-------------------------------------------------------------------------- import os +import threading import pytest import time import datetime -from azure.eventhub import EventData, EventPosition, TransportType -from azure.eventhub.client import EventHubClient +from azure.eventhub import EventData, EventPosition, TransportType, EventHubError +from azure.eventhub import EventHubConsumerClient @pytest.mark.liveTest def test_receive_end_of_stream(connstr_senders): + def on_event(partition_context, event): + if partition_context.partition_id == "0": + assert event.body_as_str() == "Receiving only a single event" + assert list(event.body)[0] == b"Receiving only a single event" + on_event.called = True + on_event.called = False connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - with receiver: - received = receiver.receive(timeout=5) - assert len(received) == 0 + client = EventHubConsumerClient.from_connection_string(connection_str) + with client: + thread = threading.Thread(target=client.receive, args=(on_event, "$default"), + kwargs={"partition_id": "0", "initial_event_position": "@latest"}) + thread.daemon = True + thread.start() + time.sleep(10) + assert on_event.called is False senders[0].send(EventData(b"Receiving only a single event")) - received = receiver.receive(timeout=5) - assert len(received) == 1 - - assert received[0].body_as_str() == "Receiving only a single event" - assert list(received[-1].body)[0] == b"Receiving only a single event" - client.close() - - -@pytest.mark.liveTest -def test_receive_with_offset_sync(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - partitions = client.get_properties() - assert partitions["partition_ids"] == ["0", "1"] - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - with receiver: - more_partitions = client.get_properties() - assert more_partitions["partition_ids"] == ["0", "1"] - - received = receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Data")) - received = receiver.receive(timeout=5) - assert len(received) == 1 - offset = received[0].offset - - assert list(received[0].body) == [b'Data'] - assert received[0].body_as_str() == "Data" - - offset_receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition(offset, inclusive=False)) - with offset_receiver: - received = offset_receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Message after offset")) - received = offset_receiver.receive(timeout=5) - assert len(received) == 1 - client.close() - - -@pytest.mark.liveTest -def test_receive_with_inclusive_offset(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - 
receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - - with receiver: - received = receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Data")) - time.sleep(1) - received = receiver.receive(timeout=5) - assert len(received) == 1 - offset = received[0].offset - - assert list(received[0].body) == [b'Data'] - assert received[0].body_as_str() == "Data" - - offset_receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition(offset, inclusive=True)) - with offset_receiver: - received = offset_receiver.receive(timeout=5) - assert len(received) == 1 - client.close() - - -@pytest.mark.liveTest -def test_receive_with_datetime_sync(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - partitions = client.get_properties() - assert partitions["partition_ids"] == ["0", "1"] - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - with receiver: - more_partitions = client.get_properties() - assert more_partitions["partition_ids"] == ["0", "1"] - received = receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Data")) - received = receiver.receive(timeout=5) - assert len(received) == 1 - offset = received[0].enqueued_time - - assert list(received[0].body) == [b'Data'] - assert received[0].body_as_str() == "Data" - - offset_receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition(offset)) - with offset_receiver: - received = offset_receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Message after timestamp")) - received = offset_receiver.receive(timeout=5) - assert len(received) == 1 - client.close() + time.sleep(10) + assert on_event.called is True + thread.join() +@pytest.mark.parametrize("position, inclusive, expected_result", + [("offset", False, "Exclusive"), + ("offset", True, "Inclusive"), + ("sequence", False, "Exclusive"), + ("sequence", True, "Inclusive"), + ("enqueued_time", False, "Exclusive")]) @pytest.mark.liveTest -def test_receive_with_custom_datetime_sync(connstr_senders): +def test_receive_with_event_position_sync(connstr_senders, position, inclusive, expected_result): + def on_event(partition_context, event): + assert event.last_enqueued_event_properties.get('sequence_number') == event.sequence_number + assert event.last_enqueued_event_properties.get('offset') == event.offset + assert event.last_enqueued_event_properties.get('enqueued_time') == event.enqueued_time + assert event.last_enqueued_event_properties.get('retrieval_time') is not None + + if position == "offset": + on_event.event_position = event.offset + elif position == "sequence": + on_event.event_position = event.sequence_number + else: + on_event.event_position = event.enqueued_time + on_event.event = event + + on_event.event_position = None connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - for i in range(5): - senders[0].send(EventData(b"Message before timestamp")) - time.sleep(65) - - now = datetime.datetime.utcnow() - offset = datetime.datetime(now.year, now.month, now.day, now.hour, now.minute) - for i in range(5): - senders[0].send(EventData(b"Message after timestamp")) - - receiver = client._create_consumer(consumer_group="$default", partition_id="0", 
event_position=EventPosition(offset)) - with receiver: - all_received = [] - received = receiver.receive(timeout=5) - while received: - all_received.extend(received) - received = receiver.receive(timeout=5) - - assert len(all_received) == 5 - for received_event in all_received: - assert received_event.body_as_str() == "Message after timestamp" - assert received_event.enqueued_time > offset - client.close() + senders[0].send(EventData(b"Inclusive")) + senders[1].send(EventData(b"Inclusive")) + client = EventHubConsumerClient.from_connection_string(connection_str) + with client: + thread = threading.Thread(target=client.receive, args=(on_event, "$default"), + kwargs={"initial_event_position": "-1", + "track_last_enqueued_event_properties": True}) + thread.daemon = True + thread.start() + time.sleep(10) + assert on_event.event_position is not None + thread.join() + senders[0].send(EventData(expected_result)) + senders[1].send(EventData(expected_result)) + client2 = EventHubConsumerClient.from_connection_string(connection_str) + with client2: + thread = threading.Thread(target=client2.receive, args=(on_event, "$default"), + kwargs={"initial_event_position": EventPosition(on_event.event_position, inclusive), + "track_last_enqueued_event_properties": True}) + thread.daemon = True + thread.start() + time.sleep(10) + assert on_event.event.body_as_str() == expected_result + thread.join() @pytest.mark.liveTest -def test_receive_with_sequence_no(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - - with receiver: - received = receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Data")) - time.sleep(1) - received = receiver.receive(timeout=5) - assert len(received) == 1 - offset = received[0].sequence_number +def test_receive_owner_level(connstr_senders): + def on_event(partition_context, event): + pass + def on_error(partition_context, error): + on_error.error = error - offset_receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition(offset, False)) - with offset_receiver: - received = offset_receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Message next in sequence")) - time.sleep(1) - received = offset_receiver.receive(timeout=5) - assert len(received) == 1 - client.close() - - -@pytest.mark.liveTest -def test_receive_with_inclusive_sequence_no(connstr_senders): + on_error.error = None connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest')) - with receiver: - received = receiver.receive(timeout=5) - assert len(received) == 0 - senders[0].send(EventData(b"Data")) - received = receiver.receive(timeout=5) - assert len(received) == 1 - offset = received[0].sequence_number - offset_receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition(offset, inclusive=True)) - with offset_receiver: - received = offset_receiver.receive(timeout=5) - assert len(received) == 1 - client.close() - - -@pytest.mark.liveTest -def test_receive_batch(connstr_senders): - connection_str, senders = connstr_senders - client = EventHubClient.from_connection_string(connection_str) - receiver = 
-    with receiver:
-        received = receiver.receive(timeout=5)
-        assert len(received) == 0
-        for i in range(10):
-            senders[0].send(EventData(b"Data"))
-        received = receiver.receive(max_batch_size=5, timeout=5)
-        assert len(received) == 5
-
-        for event in received:
-            assert event.system_properties
-            assert event.sequence_number is not None
-            assert event.offset
-            assert event.enqueued_time
-    client.close()
-
-
-@pytest.mark.liveTest
-def test_receive_batch_with_app_prop_sync(connstr_senders):
-    #pytest.skip("Waiting on uAMQP release")
-    connection_str, senders = connstr_senders
-    app_prop_key = "raw_prop"
-    app_prop_value = "raw_value"
-    batch_app_prop = {app_prop_key: app_prop_value}
-
-    def batched():
-        for i in range(10):
-            ed = EventData("Event Data {}".format(i))
-            ed.application_properties = batch_app_prop
-            yield ed
-        for i in range(10, 20):
-            ed = EventData("Event Data {}".format(i))
-            ed.application_properties = batch_app_prop
-            yield ed
-
-    client = EventHubClient.from_connection_string(connection_str)
-    receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest'), prefetch=500)
-    with receiver:
-        received = receiver.receive(timeout=5)
-        assert len(received) == 0
-
-        senders[0].send(batched())
-
-        time.sleep(1)
-
-        received = receiver.receive(max_batch_size=15, timeout=5)
-        assert len(received) == 15
-
-        for index, message in enumerate(received):
-            assert list(message.body)[0] == "Event Data {}".format(index).encode('utf-8')
-            assert (app_prop_key.encode('utf-8') in message.application_properties) \
-                and (dict(message.application_properties)[app_prop_key.encode('utf-8')] == app_prop_value.encode('utf-8'))
-    client.close()
+    client1 = EventHubConsumerClient.from_connection_string(connection_str)
+    client2 = EventHubConsumerClient.from_connection_string(connection_str)
+    with client1, client2:
+        thread1 = threading.Thread(target=client1.receive, args=(on_event, "$default"),
+                                   kwargs={"partition_id": "0", "initial_event_position": "-1",
+                                           "on_error": on_error})
+        thread1.start()
+        event_list = []
+        for i in range(5):
+            ed = EventData("Event Number {}".format(i))
+            event_list.append(ed)
+        senders[0].send(event_list)
+        time.sleep(10)
+        thread2 = threading.Thread(target=client2.receive, args=(on_event, "$default"),
+                                   kwargs={"partition_id": "0", "initial_event_position": "-1", "owner_level": 1})
+        thread2.start()
+        event_list = []
+        for i in range(5):
+            ed = EventData("Event Number {}".format(i))
+            event_list.append(ed)
+        senders[0].send(event_list)
+        time.sleep(20)
+    thread1.join()
+    thread2.join()
+    assert isinstance(on_error.error, EventHubError)


 @pytest.mark.liveTest
 def test_receive_over_websocket_sync(connstr_senders):
-    connection_str, senders = connstr_senders
-    client = EventHubClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket)
-    receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest'), prefetch=500)
-
-    event_list = []
-    for i in range(20):
-        event_list.append(EventData("Event Number {}".format(i)))
-
-    with receiver:
-        received = receiver.receive(timeout=5)
-        assert len(received) == 0
-
-        senders[0].send(event_list)
-
-        time.sleep(1)
-
-        received = receiver.receive(max_batch_size=50, timeout=5)
-        assert len(received) == 20
-    client.close()
+    app_prop = {"raw_prop": "raw_value"}
+    def on_event(partition_context, event):
+        on_event.received.append(event)
+        on_event.app_prop = event.application_properties

-@pytest.mark.liveTest
-def test_receive_run_time_metric(connstr_senders):
-    from uamqp import __version__ as uamqp_version
-    from distutils.version import StrictVersion
-    if StrictVersion(uamqp_version) < StrictVersion('1.2.3'):
-        pytest.skip("Disabled for uamqp 1.2.2. Will enable after uamqp 1.2.3 is released.")
+    on_event.received = []
+    on_event.app_prop = None
     connection_str, senders = connstr_senders
-    client = EventHubClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket)
-    receiver = client._create_consumer(consumer_group="$default", partition_id="0",
-                                       event_position=EventPosition('@latest'), prefetch=500,
-                                       track_last_enqueued_event_properties=True)
+    client = EventHubConsumerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket)

     event_list = []
-    for i in range(20):
-        event_list.append(EventData("Event Number {}".format(i)))
-
-    with receiver:
-        received = receiver.receive(timeout=5)
-        assert len(received) == 0
-
-        senders[0].send(event_list)
-
-        time.sleep(1)
-
-        received = receiver.receive(max_batch_size=50, timeout=5)
-        assert len(received) == 20
-        assert receiver.last_enqueued_event_properties
-        assert receiver.last_enqueued_event_properties.get('sequence_number', None)
-        assert receiver.last_enqueued_event_properties.get('offset', None)
-        assert receiver.last_enqueued_event_properties.get('enqueued_time', None)
-        assert receiver.last_enqueued_event_properties.get('retrieval_time', None)
-    client.close()
+    for i in range(5):
+        ed = EventData("Event Number {}".format(i))
+        ed.application_properties = app_prop
+        event_list.append(ed)
+    senders[0].send(event_list)
+
+    with client:
+        thread = threading.Thread(target=client.receive, args=(on_event, "$default"),
+                                  kwargs={"partition_id": "0", "initial_event_position": "-1"})
+        thread.start()
+        time.sleep(10)
+        assert len(on_event.received) == 5
+        for ed in on_event.received:
+            assert ed.application_properties[b"raw_prop"] == b"raw_value"
diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_receiver_iterator.py b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_receiver_iterator.py
deleted file mode 100644
index 0e30aed027e3..000000000000
--- a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_receiver_iterator.py
+++ /dev/null
@@ -1,30 +0,0 @@
-#-------------------------------------------------------------------------
-# Copyright (c) Microsoft Corporation. All rights reserved.
-# Licensed under the MIT License. See License.txt in the project root for
-# license information.
-#--------------------------------------------------------------------------
-
-import pytest
-
-from azure.eventhub import EventData, EventPosition
-from azure.eventhub.client import EventHubClient
-
-
-@pytest.mark.liveTest
-def test_receive_iterator(connstr_senders):
-    connection_str, senders = connstr_senders
-    client = EventHubClient.from_connection_string(connection_str)
-    receiver = client._create_consumer(consumer_group="$default", partition_id="0", event_position=EventPosition('@latest'))
-    with receiver:
-        received = receiver.receive(timeout=5)
-        assert len(received) == 0
-        senders[0].send(EventData(b"Receiving only a single event"))
-
-        for item in receiver:
-            received.append(item)
-            break
-
-        assert len(received) == 1
-        assert received[0].body_as_str() == "Receiving only a single event"
-        assert list(received[-1].body)[0] == b"Receiving only a single event"
-    client.close()
diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_reconnect.py b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_reconnect.py
index 08281996c197..a604dbcc14e0 100644
--- a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_reconnect.py
+++ b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_reconnect.py
@@ -7,29 +7,42 @@
 import time
 import pytest

-from azure.eventhub import EventData
-from azure.eventhub.client import EventHubClient
+import uamqp
+from uamqp import authentication
+from azure.eventhub import EventData, EventHubSharedKeyCredential
+from azure.eventhub import EventHubProducerClient


 @pytest.mark.liveTest
-def test_send_with_long_interval_sync(connstr_receivers, sleep):
-    connection_str, receivers = connstr_receivers
-    client = EventHubClient.from_connection_string(connection_str)
-    sender = client._create_producer()
+def test_send_with_long_interval_sync(live_eventhub, sleep):
+    sender = EventHubProducerClient(live_eventhub['hostname'], live_eventhub['event_hub'],
+                                    EventHubSharedKeyCredential(live_eventhub['key_name'], live_eventhub['access_key']))
     with sender:
         sender.send(EventData(b"A single event"))
         for _ in range(1):
             if sleep:
                 time.sleep(300)
             else:
-                sender._handler._connection._conn.destroy()
+                sender._producers[-1]._handler._connection._conn.destroy()
             sender.send(EventData(b"A single event"))
+        partition_ids = sender.get_partition_ids()

     received = []
-    for r in receivers:
-        if not sleep:
-            r._handler._connection._conn.destroy()
-        received.extend(r.receive(timeout=5))
+    for p in partition_ids:
+        uri = "sb://{}/{}".format(live_eventhub['hostname'], live_eventhub['event_hub'])
+        sas_auth = authentication.SASTokenAuth.from_shared_access_key(
+            uri, live_eventhub['key_name'], live_eventhub['access_key'])
+
+        source = "amqps://{}/{}/ConsumerGroups/{}/Partitions/{}".format(
+            live_eventhub['hostname'],
+            live_eventhub['event_hub'],
+            live_eventhub['consumer_group'],
+            p)
+        receiver = uamqp.ReceiveClient(source, auth=sas_auth, debug=False, timeout=5000, prefetch=500)
+        try:
+            receiver.open()
+            received.extend([EventData._from_message(x) for x in receiver.receive_message_batch(timeout=5000)])
+        finally:
+            receiver.close()

     assert len(received) == 2
     assert list(received[0].body)[0] == b"A single event"
-    client.close()
diff --git a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_send.py b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_send.py
index 568bf3405e51..d18c559c8c93 100644
--- a/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_send.py
+++ b/sdk/eventhub/azure-eventhubs/tests/livetest/synctests/test_send.py
@@ -11,34 +11,32 @@
 import sys

 from azure.eventhub import EventData, TransportType
-from azure.eventhub.client import EventHubClient
+from azure.eventhub import EventHubProducerClient


 @pytest.mark.liveTest
 def test_send_with_partition_key(connstr_receivers):
     connection_str, receivers = connstr_receivers
-    client = EventHubClient.from_connection_string(connection_str)
-    sender = client._create_producer()
-    with sender:
+    client = EventHubProducerClient.from_connection_string(connection_str)
+    with client:
         data_val = 0
         for partition in [b"a", b"b", b"c", b"d", b"e", b"f"]:
             partition_key = b"test_partition_" + partition
             for i in range(50):
                 data = EventData(str(data_val))
-                #data.partition_key = partition_key
                 data_val += 1
-                sender.send(data, partition_key=partition_key)
+                client.send(data, partition_key=partition_key)

     found_partition_keys = {}
     for index, partition in enumerate(receivers):
-        received = partition.receive(timeout=5)
+        received = partition.receive_message_batch(timeout=5000)
         for message in received:
             try:
-                existing = found_partition_keys[message.partition_key]
+                event_data = EventData._from_message(message)
+                existing = found_partition_keys[event_data.partition_key]
                 assert existing == index
             except KeyError:
-                found_partition_keys[message.partition_key] = index
-    client.close()
+                found_partition_keys[event_data.partition_key] = index


 @pytest.mark.liveTest
@@ -46,246 +44,120 @@ def test_send_and_receive_large_body_size(connstr_receivers):
     if sys.platform.startswith('darwin'):
         pytest.skip("Skipping on OSX - open issue regarding message size")
     connection_str, receivers = connstr_receivers
-    client = EventHubClient.from_connection_string(connection_str)
-    sender = client._create_producer()
-    with sender:
+    client = EventHubProducerClient.from_connection_string(connection_str)
+    with client:
         payload = 250 * 1024
-        sender.send(EventData("A" * payload))
+        client.send(EventData("A" * payload))

     received = []
     for r in receivers:
-        received.extend(r.receive(timeout=10))
+        received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=10000)])

     assert len(received) == 1
     assert len(list(received[0].body)[0]) == payload
-    client.close()


+@pytest.mark.parametrize("payload",
+                         [b"", b"A single event"])
 @pytest.mark.liveTest
-def test_send_and_receive_zero_length_body(connstr_receivers):
+def test_send_and_receive_small_body(connstr_receivers, payload):
     connection_str, receivers = connstr_receivers
-    client = EventHubClient.from_connection_string(connection_str)
-    sender = client._create_producer()
-    with sender:
-        sender.send(EventData(""))
-
+    client = EventHubProducerClient.from_connection_string(connection_str)
+    with client:
+        client.send(EventData(payload))
     received = []
     for r in receivers:
-        received.extend(r.receive(timeout=1))
+        received.extend([EventData._from_message(x) for x in r.receive_message_batch(timeout=5000)])

     assert len(received) == 1
-    assert list(received[0].body)[0] == b""
-    client.close()
-
-
-@pytest.mark.liveTest
-def test_send_single_event(connstr_receivers):
-    connection_str, receivers = connstr_receivers
-    client = EventHubClient.from_connection_string(connection_str)
-    sender = client._create_producer()
-    with sender:
-        sender.send(EventData(b"A single event"))
-
-    received = []
-    for r in receivers:
-        received.extend(r.receive(timeout=1))
-
-    assert len(received) == 1
-    assert list(received[0].body)[0] == b"A single event"
-    client.close()
-
-
-@pytest.mark.liveTest
-def test_send_batch_sync(connstr_receivers):
-    connection_str, receivers = connstr_receivers
-
-    def batched():
-        for i in range(10):
-            yield EventData("Event number {}".format(i))
-
-    client = EventHubClient.from_connection_string(connection_str)
-    sender = client._create_producer()
-    with sender:
-        sender.send(batched())
-
-    time.sleep(1)
-    received = []
-    for r in receivers:
-        received.extend(r.receive(timeout=3))
-
-    assert len(received) == 10
-    for index, message in enumerate(received):
-        assert list(message.body)[0] == "Event number {}".format(index).encode('utf-8')
-    client.close()
+    assert list(received[0].body)[0] == payload


 @pytest.mark.liveTest
 def test_send_partition(connstr_receivers):
     connection_str, receivers = connstr_receivers
-    client = EventHubClient.from_connection_string(connection_str)
-    sender = client._create_producer(partition_id="1")
-    with sender:
-        sender.send(EventData(b"Data"))
+    client = EventHubProducerClient.from_connection_string(connection_str)
+    with client:
+        client.send(EventData(b"Data"), partition_id="1")

-    partition_0 = receivers[0].receive(timeout=2)
+    partition_0 = receivers[0].receive_message_batch(timeout=5000)
     assert len(partition_0) == 0
-    partition_1 = receivers[1].receive(timeout=2)
+    partition_1 = receivers[1].receive_message_batch(timeout=5000)
     assert len(partition_1) == 1
-    client.close()


 @pytest.mark.liveTest
 def test_send_non_ascii(connstr_receivers):
     connection_str, receivers = connstr_receivers
-    client = EventHubClient.from_connection_string(connection_str)
-    sender = client._create_producer(partition_id="0")
-    with sender:
-        sender.send(EventData(u"é,è,à,ù,â,ê,î,ô,û"))
-        sender.send(EventData(json.dumps({"foo": u"漢字"})))
+    client = EventHubProducerClient.from_connection_string(connection_str)
+    with client:
+        client.send(EventData(u"é,è,à,ù,â,ê,î,ô,û"), partition_id="0")
+        client.send(EventData(json.dumps({"foo": u"漢字"})), partition_id="0")
     time.sleep(1)
-    partition_0 = receivers[0].receive(timeout=2)
+    partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)]
     assert len(partition_0) == 2
     assert partition_0[0].body_as_str() == u"é,è,à,ù,â,ê,î,ô,û"
     assert partition_0[1].body_as_json() == {"foo": u"漢字"}
-    client.close()
-
-
-@pytest.mark.liveTest
-def test_send_partition_batch(connstr_receivers):
-    connection_str, receivers = connstr_receivers
-
-    def batched():
-        for i in range(10):
-            yield EventData("Event number {}".format(i))
-
-    client = EventHubClient.from_connection_string(connection_str)
-    sender = client._create_producer(partition_id="1")
-    with sender:
-        sender.send(batched())
-        time.sleep(1)
-
-    partition_0 = receivers[0].receive(timeout=2)
-    assert len(partition_0) == 0
-    partition_1 = receivers[1].receive(timeout=2)
-    assert len(partition_1) == 10
-    client.close()
-
-
-@pytest.mark.liveTest
-def test_send_array_sync(connstr_receivers):
-    connection_str, receivers = connstr_receivers
-    client = EventHubClient.from_connection_string(connection_str)
-    sender = client._create_producer()
-    with sender:
-        sender.send(EventData([b"A", b"B", b"C"]))
-
-    received = []
-    for r in receivers:
-        received.extend(r.receive(timeout=1))
-
-    assert len(received) == 1
-    assert list(received[0].body) == [b"A", b"B", b"C"]
-    client.close()


 @pytest.mark.liveTest
-def test_send_multiple_clients(connstr_receivers):
-    connection_str, receivers = connstr_receivers
-    client = EventHubClient.from_connection_string(connection_str)
-    sender_0 = client._create_producer(partition_id="0")
-    sender_1 = client._create_producer(partition_id="1")
-    with sender_0:
sender_0.send(EventData(b"Message 0")) - with sender_1: - sender_1.send(EventData(b"Message 1")) - - partition_0 = receivers[0].receive(timeout=2) - assert len(partition_0) == 1 - partition_1 = receivers[1].receive(timeout=2) - assert len(partition_1) == 1 - client.close() - - -@pytest.mark.liveTest -def test_send_batch_with_app_prop_sync(connstr_receivers): +def test_send_multiple_partitions_with_app_prop(connstr_receivers): connection_str, receivers = connstr_receivers app_prop_key = "raw_prop" app_prop_value = "raw_value" app_prop = {app_prop_key: app_prop_value} - - def batched(): - for i in range(10): - ed = EventData("Event number {}".format(i)) - ed.application_properties = app_prop - yield ed - for i in range(10, 20): - ed = EventData("Event number {}".format(i)) - ed.application_properties = app_prop - yield ed - - client = EventHubClient.from_connection_string(connection_str) - sender = client._create_producer() - with sender: - sender.send(batched()) - - time.sleep(1) - - received = [] - for r in receivers: - received.extend(r.receive(timeout=3)) - - assert len(received) == 20 - for index, message in enumerate(received): - assert list(message.body)[0] == "Event number {}".format(index).encode('utf-8') - assert (app_prop_key.encode('utf-8') in message.application_properties) \ - and (dict(message.application_properties)[app_prop_key.encode('utf-8')] == app_prop_value.encode('utf-8')) - client.close() + client = EventHubProducerClient.from_connection_string(connection_str) + with client: + ed0 = EventData(b"Message 0") + ed0.application_properties = app_prop + client.send(ed0, partition_id="0") + ed1 = EventData(b"Message 1") + ed1.application_properties = app_prop + client.send(ed1, partition_id="1") + + partition_0 = [EventData._from_message(x) for x in receivers[0].receive_message_batch(timeout=5000)] + assert len(partition_0) == 1 + assert partition_0[0].application_properties[b"raw_prop"] == b"raw_value" + partition_1 = [EventData._from_message(x) for x in receivers[1].receive_message_batch(timeout=5000)] + assert len(partition_1) == 1 + assert partition_1[0].application_properties[b"raw_prop"] == b"raw_value" @pytest.mark.liveTest def test_send_over_websocket_sync(connstr_receivers): connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) - sender = client._create_producer() + client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) - event_list = [] - for i in range(20): - event_list.append(EventData("Event Number {}".format(i))) - - with sender: - sender.send(event_list) + with client: + for i in range(20): + client.send(EventData("Event Number {}".format(i))) time.sleep(1) received = [] for r in receivers: - received.extend(r.receive(timeout=3)) - + received.extend(r.receive_message_batch(timeout=5000)) assert len(received) == 20 - client.close() @pytest.mark.liveTest -def test_send_with_create_event_batch_sync(connstr_receivers): +def test_send_with_create_event_batch_with_app_prop_sync(connstr_receivers): connection_str, receivers = connstr_receivers - client = EventHubClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket) - sender = client._create_producer() - - event_data_batch = sender.create_batch(max_size=100000) - while True: - try: - event_data_batch.try_add(EventData('A single event data')) - except ValueError: - break - - sender.send(event_data_batch) - - 
-    event_data_batch = sender.create_batch(max_size=100000)
-    while True:
-        try:
-            event_data_batch.try_add(EventData('A single event data'))
-        except ValueError:
-            break
-
-    sender.send(event_data_batch)
-    sender.close()
-    client.close()
+    app_prop_key = "raw_prop"
+    app_prop_value = "raw_value"
+    app_prop = {app_prop_key: app_prop_value}
+    client = EventHubProducerClient.from_connection_string(connection_str, transport_type=TransportType.AmqpOverWebsocket)
+    with client:
+        event_data_batch = client.create_batch(max_size=100000)
+        while True:
+            try:
+                ed = EventData('A single event data')
+                ed.application_properties = app_prop
+                event_data_batch.try_add(ed)
+            except ValueError:
+                break
+        client.send(event_data_batch)
+    received = []
+    for r in receivers:
+        received.extend(r.receive_message_batch(timeout=5000))
+    assert len(received) > 1
+    assert EventData._from_message(received[0]).application_properties[b"raw_prop"] == b"raw_value"
diff --git a/sdk/eventhub/azure-eventhubs/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhubs/tests/unittest/test_event_data.py
index 44612759109f..c7cce9cb3730 100644
--- a/sdk/eventhub/azure-eventhubs/tests/unittest/test_event_data.py
+++ b/sdk/eventhub/azure-eventhubs/tests/unittest/test_event_data.py
@@ -33,6 +33,12 @@ def test_body_json():
     assert jo["a"] == "b"


+def test_body_wrong_json():
+    event_data = EventData('aaa')
+    with pytest.raises(TypeError):
+        event_data.body_as_json()
+
+
 def test_app_properties():
     app_props = {"a": "b"}
     event_data = EventData("")