diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 96a0240b..5fb36f8c 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -10,7 +10,15 @@ ## New Features - +- There is a new `Receiver.triggered` method that can be used instead of `selected_from`: + + ```python + async for selected in select(recv1, recv2): + if recv1.triggered(selected): + print('Received from recv1:', selected.message) + if recv2.triggered(selected): + print('Received from recv2:', selected.message) + ``` ## Bug Fixes diff --git a/src/frequenz/channels/_receiver.py b/src/frequenz/channels/_receiver.py index 13a79481..7b57a631 100644 --- a/src/frequenz/channels/_receiver.py +++ b/src/frequenz/channels/_receiver.py @@ -155,11 +155,14 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from typing import Generic, Self +from typing import TYPE_CHECKING, Any, Generic, Self, TypeGuard from ._exceptions import Error from ._generic import MappedMessageT_co, ReceiverMessageT_co +if TYPE_CHECKING: + from ._select import Selected + class Receiver(ABC, Generic[ReceiverMessageT_co]): """An endpoint to receive messages.""" @@ -284,6 +287,30 @@ def filter( """ return _Filter(receiver=self, filter_function=filter_function) + def triggered( + self, selected: Selected[Any] + ) -> TypeGuard[Selected[ReceiverMessageT_co]]: + """Check whether this receiver was selected by [`select()`][frequenz.channels.select]. + + This method is used in conjunction with the + [`Selected`][frequenz.channels.Selected] class to determine which receiver was + selected in `select()` iteration. + + It also works as a [type guard][typing.TypeGuard] to narrow the type of the + `Selected` instance to the type of the receiver. + + Please see [`select()`][frequenz.channels.select] for an example. + + Args: + selected: The result of a `select()` iteration. + + Returns: + Whether this receiver was selected. + """ + if handled := selected._recv is self: # pylint: disable=protected-access + selected._handled = True # pylint: disable=protected-access + return handled + class ReceiverError(Error, Generic[ReceiverMessageT_co]): """An error that originated in a [Receiver][frequenz.channels.Receiver]. @@ -373,9 +400,7 @@ def consume(self) -> MappedMessageT_co: # noqa: DOC502 ReceiverStoppedError: If the receiver stopped producing messages. ReceiverError: If there is a problem with the receiver. """ - return self._mapping_function( - self._receiver.consume() - ) # pylint: disable=protected-access + return self._mapping_function(self._receiver.consume()) def __str__(self) -> str: """Return a string representation of the mapper.""" diff --git a/src/frequenz/channels/_select.py b/src/frequenz/channels/_select.py index 41da79ba..ccd669eb 100644 --- a/src/frequenz/channels/_select.py +++ b/src/frequenz/channels/_select.py @@ -269,9 +269,7 @@ def selected_from( Returns: Whether the given receiver was selected. """ - if handled := selected._recv is receiver: # pylint: disable=protected-access - selected._handled = True # pylint: disable=protected-access - return handled + return receiver.triggered(selected) class SelectError(Error): @@ -378,14 +376,14 @@ async def select( # noqa: DOC503 import datetime from typing import assert_never - from frequenz.channels import ReceiverStoppedError, select, selected_from + from frequenz.channels import ReceiverStoppedError, select from frequenz.channels.timer import SkipMissedAndDrift, Timer, TriggerAllMissed timer1 = Timer(datetime.timedelta(seconds=1), TriggerAllMissed()) timer2 = Timer(datetime.timedelta(seconds=0.5), SkipMissedAndDrift()) async for selected in select(timer1, timer2): - if selected_from(selected, timer1): + if timer1.triggered(selected): # Beware: `selected.message` might raise an exception, you can always # check for exceptions with `selected.exception` first or use # a try-except block. You can also quickly check if the receiver was @@ -395,7 +393,7 @@ async def select( # noqa: DOC503 continue print(f"timer1: now={datetime.datetime.now()} drift={selected.message}") timer2.stop() - elif selected_from(selected, timer2): + elif timer2.triggered(selected): # Explicitly handling of exceptions match selected.exception: case ReceiverStoppedError():