Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -34,63 +34,6 @@ class _ConnectionMode(Enum):
SeparateConnection = 2


class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes
def __init__(self, **kwargs):
self._lock = Lock()
self._conn = None # type: Connection

self._container_id = kwargs.get("container_id")
self._debug = kwargs.get("debug")
self._error_policy = kwargs.get("error_policy")
self._properties = kwargs.get("properties")
self._encoding = kwargs.get("encoding") or "UTF-8"
self._transport_type = kwargs.get("transport_type") or TransportType.Amqp
self._http_proxy = kwargs.get("http_proxy")
self._max_frame_size = kwargs.get("max_frame_size")
self._channel_max = kwargs.get("channel_max")
self._idle_timeout = kwargs.get("idle_timeout")
self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get(
"remote_idle_timeout_empty_frame_send_ratio"
)

def get_connection(self, host, auth):
# type: (str, JWTTokenAuth) -> Connection
with self._lock:
if self._conn is None:
self._conn = Connection(
host,
auth,
container_id=self._container_id,
max_frame_size=self._max_frame_size,
channel_max=self._channel_max,
idle_timeout=self._idle_timeout,
properties=self._properties,
remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio,
error_policy=self._error_policy,
debug=self._debug,
encoding=self._encoding,
)
return self._conn

def close_connection(self):
# type: () -> None
with self._lock:
if self._conn:
self._conn.destroy()
self._conn = None

def reset_connection_if_broken(self):
# type: () -> None
with self._lock:
if self._conn and self._conn._state in ( # pylint:disable=protected-access
c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member
):
self._conn = None


class _SeparateConnectionManager(object):
def __init__(self, **kwargs):
pass
Expand All @@ -112,5 +55,5 @@ def get_connection_manager(**kwargs):
# type: (...) -> 'ConnectionManager'
connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection)
if connection_mode == _ConnectionMode.ShareConnection:
return _SharedConnectionManager(**kwargs)
pass
return _SeparateConnectionManager(**kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -32,62 +32,6 @@ async def reset_connection_if_broken(self) -> None:
pass


class _SharedConnectionManager(object): # pylint:disable=too-many-instance-attributes
def __init__(self, **kwargs) -> None:
self._loop = kwargs.get("loop")
self._lock = Lock(loop=self._loop)
self._conn = None

self._container_id = kwargs.get("container_id")
self._debug = kwargs.get("debug")
self._error_policy = kwargs.get("error_policy")
self._properties = kwargs.get("properties")
self._encoding = kwargs.get("encoding") or "UTF-8"
self._transport_type = kwargs.get("transport_type") or TransportType.Amqp
self._http_proxy = kwargs.get("http_proxy")
self._max_frame_size = kwargs.get("max_frame_size")
self._channel_max = kwargs.get("channel_max")
self._idle_timeout = kwargs.get("idle_timeout")
self._remote_idle_timeout_empty_frame_send_ratio = kwargs.get(
"remote_idle_timeout_empty_frame_send_ratio"
)

async def get_connection(self, host: str, auth: "JWTTokenAsync") -> ConnectionAsync:
async with self._lock:
if self._conn is None:
self._conn = ConnectionAsync(
host,
auth,
container_id=self._container_id,
max_frame_size=self._max_frame_size,
channel_max=self._channel_max,
idle_timeout=self._idle_timeout,
properties=self._properties,
remote_idle_timeout_empty_frame_send_ratio=self._remote_idle_timeout_empty_frame_send_ratio,
error_policy=self._error_policy,
debug=self._debug,
loop=self._loop,
encoding=self._encoding,
)
return self._conn

async def close_connection(self) -> None:
async with self._lock:
if self._conn:
await self._conn.destroy_async()
self._conn = None

async def reset_connection_if_broken(self) -> None:
async with self._lock:
if self._conn and self._conn._state in ( # pylint:disable=protected-access
c_uamqp.ConnectionState.CLOSE_RCVD, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.CLOSE_SENT, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.DISCARDING, # pylint:disable=c-extension-no-member
c_uamqp.ConnectionState.END, # pylint:disable=c-extension-no-member
):
self._conn = None


class _SeparateConnectionManager(object):
def __init__(self, **kwargs) -> None:
pass
Expand All @@ -105,5 +49,5 @@ async def reset_connection_if_broken(self) -> None:
def get_connection_manager(**kwargs) -> "ConnectionManager":
connection_mode = kwargs.get("connection_mode", _ConnectionMode.SeparateConnection)
if connection_mode == _ConnectionMode.ShareConnection:
return _SharedConnectionManager(**kwargs)
pass
return _SeparateConnectionManager(**kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def test_receive_no_partition_async(connstr_senders):

async def on_event(partition_context, event):
on_event.received += 1
await partition_context.update_checkpoint(event)
await partition_context.update_checkpoint(event, fake_kwarg="arg") # ignores fake_kwarg
on_event.namespace = partition_context.fully_qualified_namespace
on_event.eventhub_name = partition_context.eventhub_name
on_event.consumer_group = partition_context.consumer_group
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,21 @@ async def on_event(partition_context, event):
assert event.body_as_str() == "Receiving only a single event"
assert list(event.body)[0] == b"Receiving only a single event"
on_event.called = True
assert event.partition_key == b'0'
event_str = str(event)
assert ", offset: " in event_str
assert ", sequence_number: " in event_str
assert ", enqueued_time: " in event_str
assert ", partition_key: 0" in event_str

on_event.called = False
connection_str, senders = connstr_senders
client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default')
async with client:
task = asyncio.ensure_future(client.receive(on_event, partition_id="0", starting_position="@latest"))
await asyncio.sleep(10)
assert on_event.called is False
senders[0].send(EventData(b"Receiving only a single event"))
senders[0].send(EventData(b"Receiving only a single event"), partition_key='0')
await asyncio.sleep(10)
assert on_event.called is True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_receive_no_partition(connstr_senders):

def on_event(partition_context, event):
on_event.received += 1
partition_context.update_checkpoint(event)
partition_context.update_checkpoint(event, fake_kwarg="arg") # ignores fake_kwarg
on_event.namespace = partition_context.fully_qualified_namespace
on_event.eventhub_name = partition_context.eventhub_name
on_event.consumer_group = partition_context.consumer_group
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ def on_event(partition_context, event):
assert event.body_as_str() == "Receiving only a single event"
assert list(event.body)[0] == b"Receiving only a single event"
on_event.called = True
assert event.partition_key == b'0'
event_str = str(event)
assert ", offset: " in event_str
assert ", sequence_number: " in event_str
assert ", enqueued_time: " in event_str
assert ", partition_key: 0" in event_str
on_event.called = False
connection_str, senders = connstr_senders
client = EventHubConsumerClient.from_connection_string(connection_str, consumer_group='$default')
Expand All @@ -31,7 +37,7 @@ def on_event(partition_context, event):
thread.start()
time.sleep(10)
assert on_event.called is False
senders[0].send(EventData(b"Receiving only a single event"))
senders[0].send(EventData(b"Receiving only a single event"), partition_key='0')
time.sleep(10)
assert on_event.called is True
thread.join()
Expand Down Expand Up @@ -86,6 +92,7 @@ def on_event(partition_context, event):
thread.start()
time.sleep(10)
assert on_event.event.body_as_str() == expected_result

thread.join()


Expand Down