Skip to content

Commit

Permalink
RFCOMM: Avoid receive packets before DLC sink set
Browse files Browse the repository at this point in the history
  • Loading branch information
zxzxwu committed May 7, 2024
1 parent 593c619 commit 25dedf4
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
16 changes: 14 additions & 2 deletions bumble/rfcomm.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def __init__(
self.connection_result = None
self.drained = asyncio.Event()
self.drained.set()
self.opened = asyncio.Event()

# Compute the MTU
max_overhead = 4 + 1 # header with 2-byte length + fcs
Expand Down Expand Up @@ -505,6 +506,8 @@ def on_sabm_frame(self, _frame: RFCOMM_Frame) -> None:

self.change_state(DLC.State.CONNECTED)
self.emit('open')
# Opened as responder
self.opened.set()

def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:
if self.state != DLC.State.CONNECTING:
Expand All @@ -521,6 +524,8 @@ def on_ua_frame(self, _frame: RFCOMM_Frame) -> None:

self.change_state(DLC.State.CONNECTED)
self.multiplexer.on_dlc_open_complete(self)
# Opened as initiator
self.opened.set()

def on_dm_frame(self, frame: RFCOMM_Frame) -> None:
# TODO: handle all states
Expand Down Expand Up @@ -549,8 +554,7 @@ def on_uih_frame(self, frame: RFCOMM_Frame) -> None:
f'rx_credits={self.rx_credits}: {data.hex()}'
)
if data:
if self.sink:
self.sink(data) # pylint: disable=not-callable
asyncio.create_task(self._dispatch_data(data))

# Update the credits
if self.rx_credits > 0:
Expand Down Expand Up @@ -664,6 +668,14 @@ def write(self, data: Union[bytes, str]) -> None:
async def drain(self) -> None:
await self.drained.wait()

async def _dispatch_data(self, data: bytes) -> None:
# Make sure the DLC has been opened, otherwise the sink may not be set when the
# PDU is dispatched.
await self.opened.wait()

if self.sink:
self.sink(data) # pylint: disable=not-callable

def __str__(self) -> str:
return f'DLC(dlci={self.dlci},state={self.state.name})'

Expand Down
25 changes: 25 additions & 0 deletions tests/rfcomm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
RFCOMM_PSM,
)

_TIMEOUT = 0.1


# -----------------------------------------------------------------------------
def basic_frame_check(x):
Expand Down Expand Up @@ -82,6 +84,29 @@ async def test_basic_connection() -> None:
assert await queues[0].get() == b'Lorem ipsum dolor sit amet'


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_receive_pdu_before_open_dlc_returns() -> None:
devices = await test_utils.TwoDevices.create_with_connection()
DATA = b'123'

accept_future: asyncio.Future[DLC] = asyncio.get_running_loop().create_future()
channel = Server(devices[0]).listen(acceptor=accept_future.set_result)

assert devices.connections[1]
multiplexer = await Client(devices.connections[1]).start()
open_dlc_task = asyncio.create_task(multiplexer.open_dlc(channel))

dlc_responder = await accept_future
dlc_responder.write(DATA)

dlc_initiator = await open_dlc_task
dlc_initiator_queue = asyncio.Queue() # type: ignore[var-annotated]
dlc_initiator.sink = dlc_initiator_queue.put_nowait

assert await asyncio.wait_for(dlc_initiator_queue.get(), timeout=_TIMEOUT) == DATA


# -----------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_service_record():
Expand Down
9 changes: 8 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
# Imports
# -----------------------------------------------------------------------------
import asyncio
from typing import List, Optional
from typing import List, Optional, Type
from typing_extensions import Self

from bumble.controller import Controller
from bumble.link import LocalLink
Expand Down Expand Up @@ -81,6 +82,12 @@ async def setup_connection(self) -> None:
def __getitem__(self, index: int) -> Device:
return self.devices[index]

@classmethod
async def create_with_connection(cls: Type[Self]) -> Self:
devices = cls()
await devices.setup_connection()
return devices


# -----------------------------------------------------------------------------
async def async_barrier():
Expand Down

0 comments on commit 25dedf4

Please sign in to comment.