From fa4df6e3a2b040cca7d0495c952c743a492cf115 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Tue, 9 Jan 2024 13:54:55 +0800 Subject: [PATCH] Controller: CIS implementation --- bumble/controller.py | 232 ++++++++++++++++++++++++++++++++++++++++++- bumble/hci.py | 3 +- bumble/link.py | 54 ++++++++++ tests/device_test.py | 49 +++++++++ 4 files changed, 336 insertions(+), 2 deletions(-) diff --git a/bumble/controller.py b/bumble/controller.py index db8f0e65..374f138d 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -57,6 +57,8 @@ HCI_Encryption_Change_Event, HCI_Synchronous_Connection_Complete_Event, HCI_LE_Advertising_Report_Event, + HCI_LE_CIS_Established_Event, + HCI_LE_CIS_Request_Event, HCI_LE_Connection_Complete_Event, HCI_LE_Read_Remote_Features_Complete_Event, HCI_Number_Of_Completed_Packets_Event, @@ -82,6 +84,15 @@ class DataObject: pass +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class CisLink: + handle: int + cis_id: int + cig_id: int + acl_connection: Optional[Connection] = None + + # ----------------------------------------------------------------------------- @dataclasses.dataclass class Connection: @@ -132,6 +143,8 @@ def __init__( self.classic_connections: Dict[ Address, Connection ] = {} # Connections in BR/EDR + self.central_cis_links: Dict[int, CisLink] = {} # CIS links by handle + self.peripheral_cis_links: Dict[int, CisLink] = {} # CIS links by handle self.hci_version = HCI_VERSION_BLUETOOTH_CORE_5_0 self.hci_revision = 0 @@ -310,7 +323,7 @@ async def wait_for_termination(self): ############################################################ # Link connections ############################################################ - def allocate_connection_handle(self): + def allocate_connection_handle(self) -> int: handle = 0 max_handle = 0 for connection in itertools.chain( @@ -322,6 +335,13 @@ def allocate_connection_handle(self): if connection.handle == handle: # Already used, continue searching after the current max handle = max_handle + 1 + for cis_handle in itertools.chain( + self.central_cis_links.keys(), self.peripheral_cis_links.keys() + ): + max_handle = max(max_handle, cis_handle) + if cis_handle == handle: + # Already used, continue searching after the current max + handle = max_handle + 1 return handle def find_le_connection_by_address(self, address): @@ -549,6 +569,104 @@ def on_link_advertising_data(self, sender_address, data): ) self.send_hci_packet(HCI_LE_Advertising_Report_Event([report])) + def on_link_cis_request( + self, central_address: Address, cig_id: int, cis_id: int + ) -> None: + ''' + Called when an incoming CIS request occurs from a central on the link + ''' + + connection = self.peripheral_connections.get(central_address) + assert connection + + pending_cis_link = CisLink( + handle=self.allocate_connection_handle(), + cis_id=cis_id, + cig_id=cig_id, + acl_connection=connection, + ) + self.peripheral_cis_links[pending_cis_link.handle] = pending_cis_link + + self.send_hci_packet( + HCI_LE_CIS_Request_Event( + acl_connection_handle=connection.handle, + cis_connection_handle=pending_cis_link.handle, + cig_id=cig_id, + cis_id=cis_id, + ) + ) + + def on_link_cis_established(self, cig_id: int, cis_id: int) -> None: + ''' + Called when an incoming CIS established. + ''' + + cis_link = next( + cis_link + for cis_link in itertools.chain( + self.central_cis_links.values(), self.peripheral_cis_links.values() + ) + if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id + ) + + self.send_hci_packet( + HCI_LE_CIS_Established_Event( + status=HCI_SUCCESS, + connection_handle=cis_link.handle, + # CIS parameters are ignored. + cig_sync_delay=0, + cis_sync_delay=0, + transport_latency_c_to_p=0, + transport_latency_p_to_c=0, + phy_c_to_p=0, + phy_p_to_c=0, + nse=0, + bn_c_to_p=0, + bn_p_to_c=0, + ft_c_to_p=0, + ft_p_to_c=0, + max_pdu_c_to_p=0, + max_pdu_p_to_c=0, + iso_interval=0, + ) + ) + + def on_link_cis_disconnected(self, cig_id: int, cis_id: int) -> None: + ''' + Called when a CIS disconnected. + ''' + + if cis_link := next( + ( + cis_link + for cis_link in self.peripheral_cis_links.values() + if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id + ), + None, + ): + # Remove peripheral CIS on disconnection. + self.peripheral_cis_links.pop(cis_link.handle) + elif cis_link := next( + ( + cis_link + for cis_link in self.central_cis_links.values() + if cis_link.cis_id == cis_id and cis_link.cig_id == cig_id + ), + None, + ): + # Keep central CIS on disconnection. They should be removed by HCI_LE_Remove_CIG_Command. + cis_link.acl_connection = None + else: + return + + self.send_hci_packet( + HCI_Disconnection_Complete_Event( + status=HCI_SUCCESS, + connection_handle=cis_link.handle, + reason=HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, + ) + ) + ############################################################ # Classic link connections ############################################################ @@ -769,6 +887,17 @@ def on_hci_disconnect_command(self, command): else: # Remove the connection del self.classic_connections[connection.peer_address] + elif cis_link := ( + self.central_cis_links.get(handle) or self.peripheral_cis_links.get(handle) + ): + if self.link: + self.link.disconnect_cis( + initiator_controller=self, + peer_address=cis_link.acl_connection.peer_address, + cig_id=cis_link.cig_id, + cis_id=cis_link.cis_id, + ) + # Spec requires handle to be kept after disconnection. def on_hci_accept_connection_request_command(self, command): ''' @@ -1399,6 +1528,107 @@ def on_hci_le_read_transmit_power_command(self, _command): ''' return struct.pack(' None: + logger.debug( + f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}' + ) + if peripheral_controller := self.find_controller(peripheral_address): + asyncio.get_running_loop().call_soon( + peripheral_controller.on_link_cis_request, + central_controller.random_address, + cig_id, + cis_id, + ) + + def accept_cis( + self, + peripheral_controller: controller.Controller, + central_address: Address, + cig_id: int, + cis_id: int, + ) -> None: + logger.debug( + f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}' + ) + if central_controller := self.find_controller(central_address): + asyncio.get_running_loop().call_soon( + central_controller.on_link_cis_established, cig_id, cis_id + ) + asyncio.get_running_loop().call_soon( + peripheral_controller.on_link_cis_established, cig_id, cis_id + ) + + def disconnect_cis( + self, + initiator_controller: controller.Controller, + peer_address: Address, + cig_id: int, + cis_id: int, + ) -> None: + logger.debug( + f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}' + ) + if peer_controller := self.find_controller(peer_address): + asyncio.get_running_loop().call_soon( + initiator_controller.on_link_cis_disconnected, cig_id, cis_id + ) + asyncio.get_running_loop().call_soon( + peer_controller.on_link_cis_disconnected, cig_id, cis_id + ) + ############################################################ # Classic handlers ############################################################ diff --git a/tests/device_test.py b/tests/device_test.py index f6cd2213..d2b51d86 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -423,6 +423,55 @@ async def test_get_remote_le_features(): assert (await devices.connections[0].get_remote_le_features()) is not None +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_cis(): + devices = TwoDevices() + await devices.setup_connection() + + peripheral_cis_futures = {} + + def on_cis_request( + acl_connection: Connection, + cis_handle: int, + _cig_id: int, + _cis_id: int, + ): + acl_connection.abort_on( + 'disconnection', devices[1].accept_cis_request(cis_handle) + ) + peripheral_cis_futures[cis_handle] = asyncio.get_running_loop().create_future() + + devices[1].on('cis_request', on_cis_request) + devices[1].on( + 'cis_establishment', + lambda cis_link: peripheral_cis_futures[cis_link.handle].set_result(None), + ) + + cis_handles = await devices[0].setup_cig( + cig_id=1, + cis_id=[2, 3], + sdu_interval=(0, 0), + framing=0, + max_sdu=(0, 0), + retransmission_number=0, + max_transport_latency=(0, 0), + ) + assert len(cis_handles) == 2 + cis_links = await devices[0].create_cis( + [ + (cis_handles[0], devices.connections[0].handle), + (cis_handles[1], devices.connections[0].handle), + ] + ) + await asyncio.gather(*peripheral_cis_futures.values()) + assert len(cis_links) == 2 + + # TODO: Fix Host CIS support. + # await cis_links[0].disconnect() + # await cis_links[1].disconnect() + + # ----------------------------------------------------------------------------- def test_gatt_services_with_gas(): device = Device(host=Host(None, None))