Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid dropping of messages after breaking from Select blocks #42

Merged
5 changes: 1 addition & 4 deletions benchmarks/benchmark_anycast.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ async def benchmark_anycast(
recv_trackers = [0]

async def update_tracker_on_receive(chan: Receiver[int]) -> None:
while True:
msg = await chan.receive()
if msg is None:
return
async for _ in chan:
shsms marked this conversation as resolved.
Show resolved Hide resolved
recv_trackers[0] += 1

receivers = []
Expand Down
5 changes: 1 addition & 4 deletions benchmarks/benchmark_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,7 @@ async def benchmark_broadcast(
recv_trackers = [0]

async def update_tracker_on_receive(chan: Receiver[int]) -> None:
while True:
msg = await chan.receive()
if msg is None:
return
async for _ in chan:
recv_trackers[0] += 1

receivers = []
Expand Down
36 changes: 25 additions & 11 deletions src/frequenz/channels/anycast.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from collections import deque
from typing import Deque, Generic, Optional

from frequenz.channels.base_classes import ChannelClosedError
from frequenz.channels.base_classes import Receiver as BaseReceiver
from frequenz.channels.base_classes import Sender as BaseSender
from frequenz.channels.base_classes import T
Expand Down Expand Up @@ -162,23 +163,36 @@ def __init__(self, chan: Anycast[T]) -> None:
chan: A reference to the channel that this receiver belongs to.
"""
self._chan = chan
self._next: Optional[T] = None

async def receive(self) -> Optional[T]:
"""Receive a message from the channel.
async def ready(self) -> None:
"""Wait until the receiver is ready with a value.

Waits for an message to become available, and returns that message.
When there are multiple receivers for the channel, only one receiver
will receive each message.

Returns:
`None`, if the channel is closed, a message otherwise.
Raises:
ChannelClosedError: if the underlying channel is closed.
"""
# if a message is already ready, then return immediately.
if self._next is not None:
return

while len(self._chan.deque) == 0:
if self._chan.closed:
return None
raise ChannelClosedError()
async with self._chan.recv_cv:
await self._chan.recv_cv.wait()
ret = self._chan.deque.popleft()
self._next = self._chan.deque.popleft()
async with self._chan.send_cv:
self._chan.send_cv.notify(1)
return ret

def consume(self) -> T:
"""Return the latest value once `ready()` is complete.

Returns:
The next value that was received.
"""
assert (
self._next is not None
), "calls to `consume()` must be follow a call to `ready()`"
next_val = self._next
self._next = None
return next_val
103 changes: 82 additions & 21 deletions src/frequenz/channels/base_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,41 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Generic, Optional, TypeVar
from typing import Any, Callable, Generic, Optional, TypeVar

T = TypeVar("T")
U = TypeVar("U")


class ChannelError(RuntimeError):
"""Base channel error.

All exceptions generated by channels inherit from this exception.
"""

def __init__(self, message: Any, channel: Any = None):
"""Create a ChannelError instance.

Args:
message: An error message.
channel: A reference to the channel that encountered the error.
"""
super().__init__(message)
self.channel: Any = channel


class ChannelClosedError(ChannelError):
"""Error raised when trying to operate on a closed channel."""

def __init__(self, channel: Any = None):
"""Create a `ChannelClosedError` instance.

Args:
channel: A reference to the channel that was closed.
"""
super().__init__(f"Channel {channel} was closed", channel)


class Sender(ABC, Generic[T]):
"""A channel Sender."""

Expand All @@ -31,12 +60,43 @@ async def send(self, msg: T) -> bool:
class Receiver(ABC, Generic[T]):
"""A channel Receiver."""

async def __anext__(self) -> T:
"""Await the next value in the async iteration over received values.

Returns:
The next value received.

Raises:
StopAsyncIteration: if the underlying channel is closed.
"""
try:
await self.ready()
return self.consume()
except ChannelClosedError as exc:
raise StopAsyncIteration() from exc

@abstractmethod
async def receive(self) -> Optional[T]:
"""Receive a message from the channel.
async def ready(self) -> None:
"""Wait until the receiver is ready with a value.

Once a call to `ready()` has finished, the value should be read with a call to
`consume()`.

Raises:
ChannelClosedError: if the underlying channel is closed.
"""

@abstractmethod
def consume(self) -> T:
"""Return the latest value once `ready()` is complete.

`ready()` must be called before each call to `consume()`.

Returns:
`None`, if the channel is closed, a message otherwise.
The next value received.

Raises:
ChannelClosedError: if the underlying channel is closed.
"""

def __aiter__(self) -> Receiver[T]:
Expand All @@ -47,19 +107,19 @@ def __aiter__(self) -> Receiver[T]:
"""
return self

async def __anext__(self) -> T:
"""Await the next value in the async iteration over received values.

Returns:
The next value received.
async def receive(self) -> T:
"""Receive a message from the channel.

Raises:
StopAsyncIteration: if we receive `None`, i.e. if the underlying
channel is closed.
ChannelClosedError: if the underlying channel is closed.

Returns:
The received message.
"""
received = await self.receive()
if received is None:
raise StopAsyncIteration
try:
received = await self.__anext__() # pylint: disable=unnecessary-dunder-call
except StopAsyncIteration as exc:
raise ChannelClosedError() from exc
return received

def map(self, call: Callable[[T], U]) -> Receiver[U]:
Expand Down Expand Up @@ -136,13 +196,14 @@ def __init__(self, recv: Receiver[T], transform: Callable[[T], U]) -> None:
self._recv = recv
self._transform = transform

async def receive(self) -> Optional[U]:
"""Return a transformed message received from the input channel.
async def ready(self) -> None:
"""Wait until the receiver is ready with a value."""
await self._recv.ready() # pylint: disable=protected-access

def consume(self) -> U:
"""Return a transformed value once `ready()` is complete.

Returns:
`None`, if the channel is closed, a message otherwise.
The next value that was received.
"""
msg = await self._recv.receive()
if msg is None:
return None
return self._transform(msg)
return self._transform(self._recv.consume()) # pylint: disable=protected-access
14 changes: 9 additions & 5 deletions src/frequenz/channels/bidirectional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from __future__ import annotations

from typing import Generic, Optional
from typing import Generic

from frequenz.channels.base_classes import Receiver, Sender, T, U
from frequenz.channels.broadcast import Broadcast
Expand Down Expand Up @@ -82,10 +82,14 @@ async def send(self, msg: T) -> bool:
"""
return await self._sender.send(msg)

async def receive(self) -> Optional[U]:
"""Receive a value from the other side.
async def ready(self) -> None:
"""Wait until the receiver is ready with a value."""
await self._receiver.ready() # pylint: disable=protected-access

def consume(self) -> U:
"""Return the latest value once `_ready` is complete.

Returns:
Received value, or `None` if the channels are closed.
The next value that was received.
"""
return await self._receiver.receive()
return self._receiver.consume() # pylint: disable=protected-access
34 changes: 18 additions & 16 deletions src/frequenz/channels/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Deque, Dict, Generic, Optional
from uuid import UUID, uuid4

from frequenz.channels.base_classes import BufferedReceiver
from frequenz.channels.base_classes import BufferedReceiver, ChannelClosedError
from frequenz.channels.base_classes import Peekable as BasePeekable
from frequenz.channels.base_classes import Sender as BaseSender
from frequenz.channels.base_classes import T
Expand Down Expand Up @@ -249,31 +249,33 @@ def __len__(self) -> int:
"""
return len(self._q)

async def receive(self) -> Optional[T]:
"""Receive a message from the Broadcast channel.

Waits until there are messages available in the channel and returns
them. If there are no remaining messages in the buffer and the channel
is closed, returns `None` immediately.

If [into_peekable()][frequenz.channels.Receiver.into_peekable] is called
on a broadcast `Receiver`, further calls to `receive`, will raise an
`EOFError`.
async def ready(self) -> None:
"""Wait until the receiver is ready with a value.

Raises:
EOFError: when the receiver has been converted into a `Peekable`.

Returns:
`None`, if the channel is closed, a message otherwise.
EOFError: if this receiver is no longer active.
ChannelClosedError: if the underlying channel is closed.
"""
if not self._active:
raise EOFError("This receiver is no longer active.")
leandro-lucarella-frequenz marked this conversation as resolved.
Show resolved Hide resolved

# Use a while loop here, to handle spurious wakeups of condition variables.
#
# The condition also makes sure that if there are already messages ready to be
# consumed, then we return immediately.
while len(self._q) == 0:
if self._chan.closed:
return None
raise ChannelClosedError()
async with self._chan.recv_cv:
await self._chan.recv_cv.wait()

def consume(self) -> T:
"""Return the latest value once `ready` is complete.

Returns:
The next value that was received.
"""
leandro-lucarella-frequenz marked this conversation as resolved.
Show resolved Hide resolved
assert self._q, "calls to `consume()` must be follow a call to `ready()`"
ret = self._q.popleft()
return ret

Expand Down
40 changes: 27 additions & 13 deletions src/frequenz/channels/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import asyncio
from collections import deque
from typing import Any, Deque, Optional, Set
from typing import Any, Deque, Set

from frequenz.channels.base_classes import Receiver, T
from frequenz.channels.base_classes import ChannelClosedError, Receiver, T


class Merge(Receiver[T]):
Expand All @@ -34,7 +34,7 @@ def __init__(self, *args: Receiver[T]) -> None:
"""
self._receivers = {str(id): recv for id, recv in enumerate(args)}
self._pending: Set[asyncio.Task[Any]] = {
asyncio.create_task(recv.receive(), name=name)
asyncio.create_task(recv.__anext__(), name=name)
for name, recv in self._receivers.items()
}
self._results: Deque[T] = deque(maxlen=len(self._receivers))
Expand All @@ -44,31 +44,45 @@ def __del__(self) -> None:
for task in self._pending:
task.cancel()

async def receive(self) -> Optional[T]:
"""Wait until there's a message in any of the channels.
async def ready(self) -> None:
"""Wait until the receiver is ready with a value.

Returns:
The next message that was received, or `None`, if all channels have
closed.
Raises:
ChannelClosedError: if the underlying channel is closed.
"""
# we use a while loop to continue to wait for new data, in case the
# previous `wait` completed because a channel was closed.
while True:
# if there are messages waiting to be consumed, return immediately.
if len(self._results) > 0:
return self._results.popleft()
return

if len(self._pending) == 0:
return None
raise ChannelClosedError()
done, self._pending = await asyncio.wait(
self._pending, return_when=asyncio.FIRST_COMPLETED
)
for item in done:
name = item.get_name()
result = item.result()
# if channel is closed, don't add a task for it again.
if result is None:
if isinstance(item.exception(), StopAsyncIteration):
continue
result = item.result()
self._results.append(result)
self._pending.add(
asyncio.create_task(self._receivers[name].receive(), name=name)
# pylint: disable=unnecessary-dunder-call
asyncio.create_task(self._receivers[name].__anext__(), name=name)
)

def consume(self) -> T:
"""Return the latest value once `ready` is complete.

Raises:
EOFError: When called before a call to `ready()` finishes.

Returns:
The next value that was received.
"""
assert self._results, "calls to `consume()` must be follow a call to `ready()`"

return self._results.popleft()
Loading