diff --git a/bumble/rfcomm.py b/bumble/rfcomm.py index 6ca0f509..0a671de4 100644 --- a/bumble/rfcomm.py +++ b/bumble/rfcomm.py @@ -445,7 +445,9 @@ class State(enum.IntEnum): RESET = 0x05 connection_result: Optional[asyncio.Future] - sink: Optional[Callable[[bytes], None]] + _sink: Optional[Callable[[bytes], None]] + if TYPE_CHECKING: + _packet_queue: asyncio.Queue[bytes] def __init__( self, @@ -466,10 +468,12 @@ def __init__( self.state = DLC.State.INIT self.role = multiplexer.role self.c_r = 1 if self.role == Multiplexer.Role.INITIATOR else 0 - self.sink = None self.connection_result = None self.drained = asyncio.Event() self.drained.set() + # Queued packets when sink is not set. + self._packet_queue = asyncio.Queue(maxsize=32) + self._sink = None # Compute the MTU max_overhead = 4 + 1 # header with 2-byte length + fcs @@ -477,6 +481,17 @@ def __init__( max_frame_size, self.multiplexer.l2cap_channel.peer_mtu - max_overhead ) + @property + def sink(self) -> Optional[Callable[[bytes], None]]: + return self._sink + + @sink.setter + def sink(self, sink: Optional[Callable[[bytes], None]]) -> None: + self._sink = sink + # Dump queued packets to sink + while sink and not self._packet_queue.empty(): + sink(self._packet_queue.get_nowait()) # pylint: disable=not-callable + def change_state(self, new_state: State) -> None: logger.debug(f'{self} state change -> {color(new_state.name, "magenta")}') self.state = new_state @@ -549,8 +564,12 @@ def on_uih_frame(self, frame: RFCOMM_Frame) -> None: f'rx_credits={self.rx_credits}: {data.hex()}' ) if data: - if self.sink: - self.sink(data) # pylint: disable=not-callable + if self._sink: + self._sink(data) # pylint: disable=not-callable + elif not self._packet_queue.full(): + self._packet_queue.put_nowait(data) + else: + logger.warning(f'DLC [{self.dlci}] packet queue is full') # Update the credits if self.rx_credits > 0: diff --git a/tests/rfcomm_test.py b/tests/rfcomm_test.py index 4ce4d116..fcd43108 100644 --- a/tests/rfcomm_test.py +++ b/tests/rfcomm_test.py @@ -32,6 +32,8 @@ RFCOMM_PSM, ) +_TIMEOUT = 0.1 + # ----------------------------------------------------------------------------- def basic_frame_check(x): @@ -82,6 +84,29 @@ async def test_basic_connection() -> None: assert await queues[0].get() == b'Lorem ipsum dolor sit amet' +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_receive_pdu_before_open_dlc_returns() -> None: + devices = await test_utils.TwoDevices.create_with_connection() + DATA = b'123' + + accept_future: asyncio.Future[DLC] = asyncio.get_running_loop().create_future() + channel = Server(devices[0]).listen(acceptor=accept_future.set_result) + + assert devices.connections[1] + multiplexer = await Client(devices.connections[1]).start() + open_dlc_task = asyncio.create_task(multiplexer.open_dlc(channel)) + + dlc_responder = await accept_future + dlc_responder.write(DATA) + + dlc_initiator = await open_dlc_task + dlc_initiator_queue = asyncio.Queue() # type: ignore[var-annotated] + dlc_initiator.sink = dlc_initiator_queue.put_nowait + + assert await asyncio.wait_for(dlc_initiator_queue.get(), timeout=_TIMEOUT) == DATA + + # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_service_record(): diff --git a/tests/test_utils.py b/tests/test_utils.py index d193d6e5..1f0b4f3b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -16,7 +16,8 @@ # Imports # ----------------------------------------------------------------------------- import asyncio -from typing import List, Optional +from typing import List, Optional, Type +from typing_extensions import Self from bumble.controller import Controller from bumble.link import LocalLink @@ -81,6 +82,12 @@ async def setup_connection(self) -> None: def __getitem__(self, index: int) -> Device: return self.devices[index] + @classmethod + async def create_with_connection(cls: Type[Self]) -> Self: + devices = cls() + await devices.setup_connection() + return devices + # ----------------------------------------------------------------------------- async def async_barrier():