diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py index 4ccf42e02b12..50c187482aba 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_common.py @@ -22,6 +22,7 @@ from uamqp import BatchMessage, Message, constants +from ._mixin import DictMixin from ._utils import ( set_message_partition_key, trace_message, @@ -585,70 +586,3 @@ def add(self, event_data): self.message._body_gen.append(outgoing_event_data) # pylint: disable=protected-access self._size = size_after_add self._count += 1 - - -class DictMixin(object): - def __setitem__(self, key, item): - # type: (Any, Any) -> None - self.__dict__[key] = item - - def __getitem__(self, key): - # type: (Any) -> Any - return self.__dict__[key] - - def __contains__(self, key): - return key in self.__dict__ - - def __repr__(self): - # type: () -> str - return str(self) - - def __len__(self): - # type: () -> int - return len(self.keys()) - - def __delitem__(self, key): - # type: (Any) -> None - self.__dict__[key] = None - - def __eq__(self, other): - # type: (Any) -> bool - """Compare objects by comparing all attributes.""" - if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ - return False - - def __ne__(self, other): - # type: (Any) -> bool - """Compare objects by comparing all attributes.""" - return not self.__eq__(other) - - def __str__(self): - # type: () -> str - return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) - - def has_key(self, k): - # type: (Any) -> bool - return k in self.__dict__ - - def update(self, *args, **kwargs): - # type: (Any, Any) -> None - return self.__dict__.update(*args, **kwargs) - - def keys(self): - # type: () -> list - return [k for k in self.__dict__ if not k.startswith("_")] - - def values(self): - # type: () -> list - return [v for k, v in self.__dict__.items() if not k.startswith("_")] - - def items(self): - # type: () -> list - return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] - - def get(self, key, default=None): - # type: (Any, Optional[Any]) -> Any - if key in self.__dict__: - return self.__dict__[key] - return default diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py index 81fe47d36b31..58eede3b3212 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_connection_manager.py @@ -29,68 +29,6 @@ def reset_connection_if_broken(self): pass -class _ConnectionMode(Enum): - ShareConnection = 1 - SeparateConnection = 2 - - -class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes - def __init__(self, **kwargs): - self._lock = Lock() - self._conn = None # type: Connection - - self._container_id = kwargs.get("container_id") - self._debug = kwargs.get("debug") - self._error_policy = kwargs.get("error_policy") - self._properties = kwargs.get("properties") - self._encoding = kwargs.get("encoding") or "UTF-8" - self._transport_type = kwargs.get("transport_type") or TransportType.Amqp - self._http_proxy = kwargs.get("http_proxy") - self._max_frame_size = kwargs.get("max_frame_size") - self._channel_max = kwargs.get("channel_max") - self._idle_timeout = kwargs.get("idle_timeout") - self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get( - "remote_idle_timeout_empty_frame_send_ratio" - ) - - def get_connection(self, host, auth): - # type: (str, JWTTokenAuth) -> Connection - with self._lock: - if self._conn is None: - self._conn = Connection( - host, - auth, - container_id=self._container_id, - max_frame_size=self._max_frame_size, - channel_max=self._channel_max, - idle_timeout=self._idle_timeout, - properties=self._properties, - remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio, - error_policy=self._error_policy, - debug=self._debug, - encoding=self._encoding, - ) - return self._conn - - def close_connection(self): - # type: () -> None - with self._lock: - if self._conn: - self._conn.destroy() - self._conn = None - - def reset_connection_if_broken(self): - # type: () -> None - with self._lock: - if self._conn and self._conn._state in ( # pylint:disable=protected-access - c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member - ): - self._conn = None - - class _SeparateConnectionManager(object): def __init__(self, **kwargs): pass @@ -110,7 +48,4 @@ def reset_connection_if_broken(self): def get_connection_manager(**kwargs): # type: (...) -> 'ConnectionManager' - connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection) - if connection_mode == _ConnectionMode.ShareConnection: - return _SharedConnectionManager(**kwargs) return _SeparateConnectionManager(**kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/_mixin.py b/sdk/eventhub/azure-eventhub/azure/eventhub/_mixin.py new file mode 100644 index 000000000000..f1db10965826 --- /dev/null +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/_mixin.py @@ -0,0 +1,72 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from typing import ( + Any, + Optional, +) + +class DictMixin(object): + def __setitem__(self, key, item): + # type: (Any, Any) -> None + self.__dict__[key] = item + + def __getitem__(self, key): + # type: (Any) -> Any + return self.__dict__[key] + + def __repr__(self): + # type: () -> str + return str(self) + + def __len__(self): + # type: () -> int + return len(self.keys()) + + def __delitem__(self, key): + # type: (Any) -> None + self.__dict__[key] = None + + def __eq__(self, other): + # type: (Any) -> bool + """Compare objects by comparing all attributes.""" + if isinstance(other, self.__class__): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + # type: (Any) -> bool + """Compare objects by comparing all attributes.""" + return not self.__eq__(other) + + def __str__(self): + # type: () -> str + return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) + + def has_key(self, k): + # type: (Any) -> bool + return k in self.__dict__ + + def update(self, *args, **kwargs): + # type: (Any, Any) -> None + return self.__dict__.update(*args, **kwargs) + + def keys(self): + # type: () -> list + return [k for k in self.__dict__ if not k.startswith("_")] + + def values(self): + # type: () -> list + return [v for k, v in self.__dict__.items() if not k.startswith("_")] + + def items(self): + # type: () -> list + return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] + + def get(self, key, default=None): + # type: (Any, Optional[Any]) -> Any + if key in self.__dict__: + return self.__dict__[key] + return default diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py index 5eeb7b83bb80..5c62f03cf77b 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/aio/_connection_manager_async.py @@ -9,8 +9,6 @@ from uamqp import TransportType, c_uamqp from uamqp.async_ops import ConnectionAsync -from .._connection_manager import _ConnectionMode - if TYPE_CHECKING: from uamqp.authentication import JWTTokenAsync @@ -32,62 +30,6 @@ async def reset_connection_if_broken(self) -> None: pass -class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes - def __init__(self, **kwargs) -> None: - self._loop = kwargs.get("loop") - self._lock = Lock(loop=self._loop) - self._conn = None - - self._container_id = kwargs.get("container_id") - self._debug = kwargs.get("debug") - self._error_policy = kwargs.get("error_policy") - self._properties = kwargs.get("properties") - self._encoding = kwargs.get("encoding") or "UTF-8" - self._transport_type = kwargs.get("transport_type") or TransportType.Amqp - self._http_proxy = kwargs.get("http_proxy") - self._max_frame_size = kwargs.get("max_frame_size") - self._channel_max = kwargs.get("channel_max") - self._idle_timeout = kwargs.get("idle_timeout") - self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get( - "remote_idle_timeout_empty_frame_send_ratio" - ) - - async def get_connection(self, host: str, auth: "JWTTokenAsync") -> ConnectionAsync: - async with self._lock: - if self._conn is None: - self._conn = ConnectionAsync( - host, - auth, - container_id=self._container_id, - max_frame_size=self._max_frame_size, - channel_max=self._channel_max, - idle_timeout=self._idle_timeout, - properties=self._properties, - remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio, - error_policy=self._error_policy, - debug=self._debug, - loop=self._loop, - encoding=self._encoding, - ) - return self._conn - - async def close_connection(self) -> None: - async with self._lock: - if self._conn: - await self._conn.destroy_async() - self._conn = None - - async def reset_connection_if_broken(self) -> None: - async with self._lock: - if self._conn and self._conn._state in ( # pylint:disable=protected-access - c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member - c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member - ): - self._conn = None - - class _SeparateConnectionManager(object): def __init__(self, **kwargs) -> None: pass @@ -103,7 +45,4 @@ async def reset_connection_if_broken(self) -> None: def get_connection_manager(**kwargs) -> "ConnectionManager": - connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection) - if connection_mode == _ConnectionMode.ShareConnection: - return _SharedConnectionManager(**kwargs) return _SeparateConnectionManager(**kwargs) diff --git a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py index 500d9797fbb6..2d0bad7551f8 100644 --- a/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py +++ b/sdk/eventhub/azure-eventhub/azure/eventhub/amqp/_amqp_message.py @@ -9,70 +9,7 @@ import uamqp from ._constants import AMQP_MESSAGE_BODY_TYPE_MAP, AmqpMessageBodyType - - -class DictMixin(object): - def __setitem__(self, key, item): - # type: (Any, Any) -> None - self.__dict__[key] = item - - def __getitem__(self, key): - # type: (Any) -> Any - return self.__dict__[key] - - def __repr__(self): - # type: () -> str - return str(self) - - def __len__(self): - # type: () -> int - return len(self.keys()) - - def __delitem__(self, key): - # type: (Any) -> None - self.__dict__[key] = None - - def __eq__(self, other): - # type: (Any) -> bool - """Compare objects by comparing all attributes.""" - if isinstance(other, self.__class__): - return self.__dict__ == other.__dict__ - return False - - def __ne__(self, other): - # type: (Any) -> bool - """Compare objects by comparing all attributes.""" - return not self.__eq__(other) - - def __str__(self): - # type: () -> str - return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")}) - - def has_key(self, k): - # type: (Any) -> bool - return k in self.__dict__ - - def update(self, *args, **kwargs): - # type: (Any, Any) -> None - return self.__dict__.update(*args, **kwargs) - - def keys(self): - # type: () -> list - return [k for k in self.__dict__ if not k.startswith("_")] - - def values(self): - # type: () -> list - return [v for k, v in self.__dict__.items() if not k.startswith("_")] - - def items(self): - # type: () -> list - return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")] - - def get(self, key, default=None): - # type: (Any, Optional[Any]) -> Any - if key in self.__dict__: - return self.__dict__[key] - return default +from .._mixin import DictMixin class AmqpAnnotatedMessage(object): diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_consumer_client_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_consumer_client_async.py index 0ff6ae711cbf..41c2a950b629 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_consumer_client_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_consumer_client_async.py @@ -16,7 +16,7 @@ async def test_receive_no_partition_async(connstr_senders): async def on_event(partition_context, event): on_event.received += 1 - await partition_context.update_checkpoint(event) + await partition_context.update_checkpoint(event, fake_kwarg="arg") # ignores fake_kwarg on_event.namespace = partition_context.fully_qualified_namespace on_event.eventhub_name = partition_context.eventhub_name on_event.consumer_group = partition_context.consumer_group diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_receive_async.py b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_receive_async.py index 3dece8daa1db..cd7de4f46037 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_receive_async.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/asynctests/test_receive_async.py @@ -21,6 +21,13 @@ async def on_event(partition_context, event): assert event.body_as_str() == "Receiving only a single event" assert list(event.body)[0] == b"Receiving only a single event" on_event.called = True + assert event.partition_key == b'0' + event_str = str(event) + assert ", offset: " in event_str + assert ", sequence_number: " in event_str + assert ", enqueued_time: " in event_str + assert ", partition_key: 0" in event_str + on_event.called = False connection_str, senders = connstr_senders client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') @@ -28,7 +35,7 @@ async def on_event(partition_context, event): task = asyncio.ensure_future(client.receive(on_event, partition_id="0", starting_position="@latest")) await asyncio.sleep(10) assert on_event.called is False - senders[0].send(EventData(b"Receiving only a single event")) + senders[0].send(EventData(b"Receiving only a single event"), partition_key='0') await asyncio.sleep(10) assert on_event.called is True diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py index 8da5ddeb6ead..5a2dca789f37 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_consumer_client.py @@ -17,7 +17,7 @@ def test_receive_no_partition(connstr_senders): def on_event(partition_context, event): on_event.received += 1 - partition_context.update_checkpoint(event) + partition_context.update_checkpoint(event, fake_kwarg="arg") # ignores fake_kwarg on_event.namespace = partition_context.fully_qualified_namespace on_event.eventhub_name = partition_context.eventhub_name on_event.consumer_group = partition_context.consumer_group diff --git a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py index 6152f0fff792..21d6e249581e 100644 --- a/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py +++ b/sdk/eventhub/azure-eventhub/tests/livetest/synctests/test_receive.py @@ -21,6 +21,12 @@ def on_event(partition_context, event): assert event.body_as_str() == "Receiving only a single event" assert list(event.body)[0] == b"Receiving only a single event" on_event.called = True + assert event.partition_key == b'0' + event_str = str(event) + assert ", offset: " in event_str + assert ", sequence_number: " in event_str + assert ", enqueued_time: " in event_str + assert ", partition_key: 0" in event_str on_event.called = False connection_str, senders = connstr_senders client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default') @@ -31,7 +37,7 @@ def on_event(partition_context, event): thread.start() time.sleep(10) assert on_event.called is False - senders[0].send(EventData(b"Receiving only a single event")) + senders[0].send(EventData(b"Receiving only a single event"), partition_key='0') time.sleep(10) assert on_event.called is True thread.join() @@ -86,6 +92,7 @@ def on_event(partition_context, event): thread.start() time.sleep(10) assert on_event.event.body_as_str() == expected_result + thread.join() diff --git a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py index 9f0d3a421859..ef562c2628a5 100644 --- a/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py +++ b/sdk/eventhub/azure-eventhub/tests/unittest/test_event_data.py @@ -2,6 +2,7 @@ import pytest import uamqp from packaging import version +from azure.eventhub.amqp import AmqpAnnotatedMessage from azure.eventhub import _common pytestmark = pytest.mark.skipif(platform.python_implementation() == "PyPy", reason="This is ignored for PyPy") @@ -105,3 +106,23 @@ def test_event_data_batch(): assert batch.size_in_bytes == 93 and len(batch) == 1 with pytest.raises(ValueError): batch.add(EventData("A")) + +def test_event_data_from_message(): + message = uamqp.Message('A') + event = EventData._from_message(message) + assert event.content_type is None + assert event.correlation_id is None + assert event.message_id is None + + event.content_type = 'content_type' + event.correlation_id = 'correlation_id' + event.message_id = 'message_id' + assert event.content_type == 'content_type' + assert event.correlation_id == 'correlation_id' + assert event.message_id == 'message_id' + +def test_amqp_message_str_repr(): + data_body = b'A' + message = AmqpAnnotatedMessage(data_body=data_body) + assert str(message) == 'A' + assert 'AmqpAnnotatedMessage(body=A, body_type=data' in repr(message)