Skip to content
68 changes: 1 addition & 67 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from uamqp import BatchMessage, Message, constants

from ._mixin import DictMixin
from ._utils import (
set_message_partition_key,
trace_message,
Expand Down Expand Up @@ -585,70 +586,3 @@ def add(self, event_data):
self.message._body_gen.append(outgoing_event_data) # pylint: disable=protected-access
self._size = size_after_add
self._count += 1


class DictMixin(object):
def __setitem__(self, key, item):
# type: (Any, Any) -> None
self.__dict__[key] = item

def __getitem__(self, key):
# type: (Any) -> Any
return self.__dict__[key]

def __contains__(self, key):
return key in self.__dict__

def __repr__(self):
# type: () -> str
return str(self)

def __len__(self):
# type: () -> int
return len(self.keys())

def __delitem__(self, key):
# type: (Any) -> None
self.__dict__[key] = None

def __eq__(self, other):
# type: (Any) -> bool
"""Compare objects by comparing all attributes."""
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return False

def __ne__(self, other):
# type: (Any) -> bool
"""Compare objects by comparing all attributes."""
return not self.__eq__(other)

def __str__(self):
# type: () -> str
return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")})

def has_key(self, k):
# type: (Any) -> bool
return k in self.__dict__

def update(self, *args, **kwargs):
# type: (Any, Any) -> None
return self.__dict__.update(*args, **kwargs)

def keys(self):
# type: () -> list
return [k for k in self.__dict__ if not k.startswith("_")]

def values(self):
# type: () -> list
return [v for k, v in self.__dict__.items() if not k.startswith("_")]

def items(self):
# type: () -> list
return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")]

def get(self, key, default=None):
# type: (Any, Optional[Any]) -> Any
if key in self.__dict__:
return self.__dict__[key]
return default
Original file line number Diff line number Diff line change
Expand Up @@ -29,68 +29,6 @@ def reset_connection_if_broken(self):
pass


class _ConnectionMode(Enum):
ShareConnection = 1
SeparateConnection = 2


class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes
def __init__(self, **kwargs):
self._lock = Lock()
self._conn = None # type: Connection

self._container_id = kwargs.get("container_id")
self._debug = kwargs.get("debug")
self._error_policy = kwargs.get("error_policy")
self._properties = kwargs.get("properties")
self._encoding = kwargs.get("encoding") or "UTF-8"
self._transport_type = kwargs.get("transport_type") or TransportType.Amqp
self._http_proxy = kwargs.get("http_proxy")
self._max_frame_size = kwargs.get("max_frame_size")
self._channel_max = kwargs.get("channel_max")
self._idle_timeout = kwargs.get("idle_timeout")
self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get(
"remote_idle_timeout_empty_frame_send_ratio"
)

def get_connection(self, host, auth):
# type: (str, JWTTokenAuth) -> Connection
with self._lock:
if self._conn is None:
self._conn = Connection(
host,
auth,
container_id=self._container_id,
max_frame_size=self._max_frame_size,
channel_max=self._channel_max,
idle_timeout=self._idle_timeout,
properties=self._properties,
remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio,
error_policy=self._error_policy,
debug=self._debug,
encoding=self._encoding,
)
return self._conn

def close_connection(self):
# type: () -> None
with self._lock:
if self._conn:
self._conn.destroy()
self._conn = None

def reset_connection_if_broken(self):
# type: () -> None
with self._lock:
if self._conn and self._conn._state in ( # pylint:disable=protected-access
c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member
):
self._conn = None


class _SeparateConnectionManager(object):
def __init__(self, **kwargs):
pass
Expand All @@ -110,7 +48,4 @@ def reset_connection_if_broken(self):

def get_connection_manager(**kwargs):
# type: (...) -> 'ConnectionManager'
connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection)
if connection_mode == _ConnectionMode.ShareConnection:
return _SharedConnectionManager(**kwargs)
return _SeparateConnectionManager(**kwargs)
68 changes: 68 additions & 0 deletions sdk/eventhub/azure-eventhub/azure/eventhub/_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# -------------------------------------------------------------------------

class DictMixin(object):
def __setitem__(self, key, item):
# type: (Any, Any) -> None
self.__dict__[key] = item

def __getitem__(self, key):
# type: (Any) -> Any
return self.__dict__[key]

def __repr__(self):
# type: () -> str
return str(self)

def __len__(self):
# type: () -> int
return len(self.keys())

def __delitem__(self, key):
# type: (Any) -> None
self.__dict__[key] = None

def __eq__(self, other):
# type: (Any) -> bool
"""Compare objects by comparing all attributes."""
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return False

def __ne__(self, other):
# type: (Any) -> bool
"""Compare objects by comparing all attributes."""
return not self.__eq__(other)

def __str__(self):
# type: () -> str
return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")})

def has_key(self, k):
# type: (Any) -> bool
return k in self.__dict__

def update(self, *args, **kwargs):
# type: (Any, Any) -> None
return self.__dict__.update(*args, **kwargs)

def keys(self):
# type: () -> list
return [k for k in self.__dict__ if not k.startswith("_")]

def values(self):
# type: () -> list
return [v for k, v in self.__dict__.items() if not k.startswith("_")]

def items(self):
# type: () -> list
return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")]

def get(self, key, default=None):
# type: (Any, Optional[Any]) -> Any
if key in self.__dict__:
return self.__dict__[key]
return default
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from uamqp import TransportType, c_uamqp
from uamqp.async_ops import ConnectionAsync

from .._connection_manager import _ConnectionMode

if TYPE_CHECKING:
from uamqp.authentication import JWTTokenAsync

Expand All @@ -32,62 +30,6 @@ async def reset_connection_if_broken(self) -> None:
pass


class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes
def __init__(self, **kwargs) -> None:
self._loop = kwargs.get("loop")
self._lock = Lock(loop=self._loop)
self._conn = None

self._container_id = kwargs.get("container_id")
self._debug = kwargs.get("debug")
self._error_policy = kwargs.get("error_policy")
self._properties = kwargs.get("properties")
self._encoding = kwargs.get("encoding") or "UTF-8"
self._transport_type = kwargs.get("transport_type") or TransportType.Amqp
self._http_proxy = kwargs.get("http_proxy")
self._max_frame_size = kwargs.get("max_frame_size")
self._channel_max = kwargs.get("channel_max")
self._idle_timeout = kwargs.get("idle_timeout")
self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get(
"remote_idle_timeout_empty_frame_send_ratio"
)

async def get_connection(self, host: str, auth: "JWTTokenAsync") -> ConnectionAsync:
async with self._lock:
if self._conn is None:
self._conn = ConnectionAsync(
host,
auth,
container_id=self._container_id,
max_frame_size=self._max_frame_size,
channel_max=self._channel_max,
idle_timeout=self._idle_timeout,
properties=self._properties,
remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio,
error_policy=self._error_policy,
debug=self._debug,
loop=self._loop,
encoding=self._encoding,
)
return self._conn

async def close_connection(self) -> None:
async with self._lock:
if self._conn:
await self._conn.destroy_async()
self._conn = None

async def reset_connection_if_broken(self) -> None:
async with self._lock:
if self._conn and self._conn._state in ( # pylint:disable=protected-access
c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member
):
self._conn = None


class _SeparateConnectionManager(object):
def __init__(self, **kwargs) -> None:
pass
Expand All @@ -103,7 +45,4 @@ async def reset_connection_if_broken(self) -> None:


def get_connection_manager(**kwargs) -> "ConnectionManager":
connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection)
if connection_mode == _ConnectionMode.ShareConnection:
return _SharedConnectionManager(**kwargs)
return _SeparateConnectionManager(**kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -9,70 +9,7 @@
import uamqp

from ._constants import AMQP_MESSAGE_BODY_TYPE_MAP, AmqpMessageBodyType


class DictMixin(object):
def __setitem__(self, key, item):
# type: (Any, Any) -> None
self.__dict__[key] = item

def __getitem__(self, key):
# type: (Any) -> Any
return self.__dict__[key]

def __repr__(self):
# type: () -> str
return str(self)

def __len__(self):
# type: () -> int
return len(self.keys())

def __delitem__(self, key):
# type: (Any) -> None
self.__dict__[key] = None

def __eq__(self, other):
# type: (Any) -> bool
"""Compare objects by comparing all attributes."""
if isinstance(other, self.__class__):
return self.__dict__ == other.__dict__
return False

def __ne__(self, other):
# type: (Any) -> bool
"""Compare objects by comparing all attributes."""
return not self.__eq__(other)

def __str__(self):
# type: () -> str
return str({k: v for k, v in self.__dict__.items() if not k.startswith("_")})

def has_key(self, k):
# type: (Any) -> bool
return k in self.__dict__

def update(self, *args, **kwargs):
# type: (Any, Any) -> None
return self.__dict__.update(*args, **kwargs)

def keys(self):
# type: () -> list
return [k for k in self.__dict__ if not k.startswith("_")]

def values(self):
# type: () -> list
return [v for k, v in self.__dict__.items() if not k.startswith("_")]

def items(self):
# type: () -> list
return [(k, v) for k, v in self.__dict__.items() if not k.startswith("_")]

def get(self, key, default=None):
# type: (Any, Optional[Any]) -> Any
if key in self.__dict__:
return self.__dict__[key]
return default
from .._mixin import DictMixin


class AmqpAnnotatedMessage(object):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def test_receive_no_partition_async(connstr_senders):

async def on_event(partition_context, event):
on_event.received += 1
await partition_context.update_checkpoint(event)
await partition_context.update_checkpoint(event, fake_kwarg="arg") # ignores fake_kwarg
on_event.namespace = partition_context.fully_qualified_namespace
on_event.eventhub_name = partition_context.eventhub_name
on_event.consumer_group = partition_context.consumer_group
Expand Down
Loading