diff --git a/apps/auracast.py b/apps/auracast.py index 463fc9b5..96f2a23e 100644 --- a/apps/auracast.py +++ b/apps/auracast.py @@ -17,10 +17,11 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio +import contextlib import dataclasses import logging import os -from typing import cast, Dict, Optional, Tuple +from typing import cast, Any, AsyncGenerator, Coroutine, Dict, Optional, Tuple import click import pyee @@ -32,6 +33,7 @@ import bumble.gatt import bumble.hci import bumble.profiles.bap +import bumble.profiles.bass import bumble.profiles.pbp import bumble.transport import bumble.utils @@ -46,14 +48,16 @@ # ----------------------------------------------------------------------------- # Constants # ----------------------------------------------------------------------------- -AURACAST_DEFAULT_DEVICE_NAME = "Bumble Auracast" -AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address("F0:F1:F2:F3:F4:F5") +AURACAST_DEFAULT_DEVICE_NAME = 'Bumble Auracast' +AURACAST_DEFAULT_DEVICE_ADDRESS = bumble.hci.Address('F0:F1:F2:F3:F4:F5') +AURACAST_DEFAULT_SYNC_TIMEOUT = 5.0 +AURACAST_DEFAULT_ATT_MTU = 256 # ----------------------------------------------------------------------------- -# Discover Broadcasts +# Scan For Broadcasts # ----------------------------------------------------------------------------- -class BroadcastDiscoverer: +class BroadcastScanner(pyee.EventEmitter): @dataclasses.dataclass class Broadcast(pyee.EventEmitter): name: str @@ -79,22 +83,6 @@ def __post_init__(self) -> None: self.sync.on('periodic_advertisement', self.on_periodic_advertisement) self.sync.on('biginfo_advertisement', self.on_biginfo_advertisement) - self.establishment_timeout_task = asyncio.create_task( - self.wait_for_establishment() - ) - - async def wait_for_establishment(self) -> None: - await asyncio.sleep(5.0) - if self.sync.state == bumble.device.PeriodicAdvertisingSync.State.PENDING: - print( - color( - '!!! Periodic advertisement sync not established in time, ' - 'canceling', - 'red', - ) - ) - await self.sync.terminate() - def update(self, advertisement: bumble.device.Advertisement) -> None: self.rssi = advertisement.rssi for service_data in advertisement.data.get_all( @@ -139,6 +127,8 @@ def update(self, advertisement: bumble.device.Advertisement) -> None: data, ) + self.emit('update') + def print(self) -> None: print( color('Broadcast:', 'yellow'), @@ -227,13 +217,12 @@ def print(self) -> None: ) def on_sync_establishment(self) -> None: - self.establishment_timeout_task.cancel() - self.emit('change') + self.emit('sync_establishment') def on_sync_loss(self) -> None: self.basic_audio_announcement = None self.biginfo = None - self.emit('change') + self.emit('sync_loss') def on_periodic_advertisement( self, advertisement: bumble.device.PeriodicAdvertisement @@ -268,37 +257,21 @@ def __init__( filter_duplicates: bool, sync_timeout: float, ): + super().__init__() self.device = device self.filter_duplicates = filter_duplicates self.sync_timeout = sync_timeout - self.broadcasts: Dict[bumble.hci.Address, BroadcastDiscoverer.Broadcast] = {} - self.status_message = '' + self.broadcasts: Dict[bumble.hci.Address, BroadcastScanner.Broadcast] = {} device.on('advertisement', self.on_advertisement) - async def run(self) -> None: - self.status_message = color('Scanning...', 'green') + async def start(self) -> None: await self.device.start_scanning( active=False, filter_duplicates=False, ) - def refresh(self) -> None: - # Clear the screen from the top - print('\033[H') - print('\033[0J') - print('\033[H') - - # Print the status message - print(self.status_message) - print("==========================================") - - # Print all broadcasts - for broadcast in self.broadcasts.values(): - broadcast.print() - print('------------------------------------------') - - # Clear the screen to the bottom - print('\033[0J') + async def stop(self) -> None: + await self.device.stop_scanning() def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None: if ( @@ -311,7 +284,6 @@ def on_advertisement(self, advertisement: bumble.device.Advertisement) -> None: if broadcast := self.broadcasts.get(advertisement.address): broadcast.update(advertisement) - self.refresh() return bumble.utils.AsyncRunner.spawn( @@ -331,41 +303,318 @@ async def on_new_broadcast( name, periodic_advertising_sync, ) - broadcast.on('change', self.refresh) broadcast.update(advertisement) self.broadcasts[advertisement.address] = broadcast periodic_advertising_sync.on('loss', lambda: self.on_broadcast_loss(broadcast)) - self.status_message = color( - f'+Found {len(self.broadcasts)} broadcasts', 'green' - ) - self.refresh() + self.emit('new_broadcast', broadcast) def on_broadcast_loss(self, broadcast: Broadcast) -> None: del self.broadcasts[broadcast.sync.advertiser_address] bumble.utils.AsyncRunner.spawn(broadcast.sync.terminate()) + self.emit('broadcast_loss', broadcast) + + +class PrintingBroadcastScanner: + def __init__( + self, device: bumble.device.Device, filter_duplicates: bool, sync_timeout: float + ) -> None: + self.scanner = BroadcastScanner(device, filter_duplicates, sync_timeout) + self.scanner.on('new_broadcast', self.on_new_broadcast) + self.scanner.on('broadcast_loss', self.on_broadcast_loss) + self.scanner.on('update', self.refresh) + self.status_message = '' + + async def start(self) -> None: + self.status_message = color('Scanning...', 'green') + await self.scanner.start() + + def on_new_broadcast(self, broadcast: BroadcastScanner.Broadcast) -> None: self.status_message = color( - f'-Found {len(self.broadcasts)} broadcasts', 'green' + f'+Found {len(self.scanner.broadcasts)} broadcasts', 'green' ) + broadcast.on('change', self.refresh) + broadcast.on('update', self.refresh) self.refresh() + def on_broadcast_loss(self, broadcast: BroadcastScanner.Broadcast) -> None: + self.status_message = color( + f'-Found {len(self.scanner.broadcasts)} broadcasts', 'green' + ) + self.refresh() -async def run_discover_broadcasts( - filter_duplicates: bool, sync_timeout: float, transport: str -) -> None: + def refresh(self) -> None: + # Clear the screen from the top + print('\033[H') + print('\033[0J') + print('\033[H') + + # Print the status message + print(self.status_message) + print("==========================================") + + # Print all broadcasts + for broadcast in self.scanner.broadcasts.values(): + broadcast.print() + print('------------------------------------------') + + # Clear the screen to the bottom + print('\033[0J') + + +@contextlib.asynccontextmanager +async def create_device(transport: str) -> AsyncGenerator[bumble.device.Device, Any]: async with await bumble.transport.open_transport(transport) as ( hci_source, hci_sink, ): - device = bumble.device.Device.with_hci( - AURACAST_DEFAULT_DEVICE_NAME, - AURACAST_DEFAULT_DEVICE_ADDRESS, + device_config = bumble.device.DeviceConfiguration( + name=AURACAST_DEFAULT_DEVICE_NAME, + address=AURACAST_DEFAULT_DEVICE_ADDRESS, + keystore='JsonKeyStore', + ) + + device = bumble.device.Device.from_config_with_hci( + device_config, hci_source, hci_sink, ) await device.power_on() - discoverer = BroadcastDiscoverer(device, filter_duplicates, sync_timeout) - await discoverer.run() - await hci_source.terminated + + yield device + + +async def find_broadcast_by_name( + device: bumble.device.Device, name: Optional[str] +) -> BroadcastScanner.Broadcast: + result = asyncio.get_running_loop().create_future() + + def on_broadcast_change(broadcast: BroadcastScanner.Broadcast) -> None: + if broadcast.basic_audio_announcement and not result.done(): + print(color('Broadcast basic audio announcement received', 'green')) + result.set_result(broadcast) + + def on_new_broadcast(broadcast: BroadcastScanner.Broadcast) -> None: + if name is None or broadcast.name == name: + print(color('Broadcast found:', 'green'), broadcast.name) + broadcast.on('change', lambda: on_broadcast_change(broadcast)) + return + + print(color(f'Skipping broadcast {broadcast.name}')) + + scanner = BroadcastScanner(device, False, AURACAST_DEFAULT_SYNC_TIMEOUT) + scanner.on('new_broadcast', on_new_broadcast) + await scanner.start() + + broadcast = await result + await scanner.stop() + + return broadcast + + +async def run_scan( + filter_duplicates: bool, sync_timeout: float, transport: str +) -> None: + async with create_device(transport) as device: + if not device.supports_le_periodic_advertising: + print(color('Periodic advertising not supported', 'red')) + return + + scanner = PrintingBroadcastScanner(device, filter_duplicates, sync_timeout) + await scanner.start() + await asyncio.get_running_loop().create_future() + + +async def run_assist( + broadcast_name: Optional[str], + source_id: Optional[int], + command: str, + transport: str, + address: str, +) -> None: + async with create_device(transport) as device: + if not device.supports_le_periodic_advertising: + print(color('Periodic advertising not supported', 'red')) + return + + # Connect to the server + print(f'=== Connecting to {address}...') + connection = await device.connect(address) + peer = bumble.device.Peer(connection) + print(f'=== Connected to {peer}') + + print("+++ Encrypting connection...") + await peer.connection.encrypt() + print("+++ Connection encrypted") + + # Request a larger MTU + mtu = AURACAST_DEFAULT_ATT_MTU + print(color(f'$$$ Requesting MTU={mtu}', 'yellow')) + await peer.request_mtu(mtu) + + # Get the BASS service + bass = await peer.discover_service_and_create_proxy( + bumble.profiles.bass.BroadcastAudioScanServiceProxy + ) + + # Check that the service was found + if not bass: + print(color('!!! Broadcast Audio Scan Service not found', 'red')) + return + + # Subscribe to and read the broadcast receive state characteristics + for i, broadcast_receive_state in enumerate(bass.broadcast_receive_states): + try: + await broadcast_receive_state.subscribe( + lambda value, i=i: print( + f"{color(f'Broadcast Receive State Update [{i}]:', 'green')} {value}" + ) + ) + except bumble.core.ProtocolError as error: + print( + color( + f'!!! Failed to subscribe to Broadcast Receive State characteristic:', + 'red', + ), + error, + ) + value = await broadcast_receive_state.read_value() + print( + f'{color(f"Initial Broadcast Receive State [{i}]:", "green")} {value}' + ) + + if command == 'monitor-state': + await peer.sustain() + return + + if command == 'add-source': + # Find the requested broadcast + await bass.remote_scan_started() + if broadcast_name: + print(color('Scanning for broadcast:', 'cyan'), broadcast_name) + else: + print(color('Scanning for any broadcast', 'cyan')) + broadcast = await find_broadcast_by_name(device, broadcast_name) + + if broadcast.broadcast_audio_announcement is None: + print(color('No broadcast audio announcement found', 'red')) + return + + if ( + broadcast.basic_audio_announcement is None + or not broadcast.basic_audio_announcement.subgroups + ): + print(color('No subgroups found', 'red')) + return + + # Add the source + print(color('Adding source:', 'blue'), broadcast.sync.advertiser_address) + await bass.add_source( + broadcast.sync.advertiser_address, + broadcast.sync.sid, + broadcast.broadcast_audio_announcement.broadcast_id, + bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_AVAILABLE, + 0xFFFF, + [ + bumble.profiles.bass.SubgroupInfo( + bumble.profiles.bass.SubgroupInfo.ANY_BIS, + bytes(broadcast.basic_audio_announcement.subgroups[0].metadata), + ) + ], + ) + + # Initiate a PA Sync Transfer + await broadcast.sync.transfer(peer.connection) + + # Notify the sink that we're done scanning. + await bass.remote_scan_stopped() + + await peer.sustain() + return + + if command == 'modify-source': + if source_id is None: + print(color('!!! modify-source requires --source-id')) + return + + # Find the requested broadcast + await bass.remote_scan_started() + if broadcast_name: + print(color('Scanning for broadcast:', 'cyan'), broadcast_name) + else: + print(color('Scanning for any broadcast', 'cyan')) + broadcast = await find_broadcast_by_name(device, broadcast_name) + + if broadcast.broadcast_audio_announcement is None: + print(color('No broadcast audio announcement found', 'red')) + return + + if ( + broadcast.basic_audio_announcement is None + or not broadcast.basic_audio_announcement.subgroups + ): + print(color('No subgroups found', 'red')) + return + + # Modify the source + print( + color('Modifying source:', 'blue'), + source_id, + ) + await bass.modify_source( + source_id, + bumble.profiles.bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, + 0xFFFF, + [ + bumble.profiles.bass.SubgroupInfo( + bumble.profiles.bass.SubgroupInfo.ANY_BIS, + bytes(broadcast.basic_audio_announcement.subgroups[0].metadata), + ) + ], + ) + await peer.sustain() + return + + if command == 'remove-source': + if source_id is None: + print(color('!!! remove-source requires --source-id')) + return + + # Remove the source + print(color('Removing source:', 'blue'), source_id) + await bass.remove_source(source_id) + await peer.sustain() + return + + print(color(f'!!! invalid command {command}')) + + +async def run_pair(transport: str, address: str) -> None: + async with create_device(transport) as device: + + # Connect to the server + print(f'=== Connecting to {address}...') + async with device.connect_as_gatt(address) as peer: + print(f'=== Connected to {peer}') + + print("+++ Initiating pairing...") + await peer.connection.pair() + print("+++ Paired") + + +def run_async(async_command: Coroutine) -> None: + try: + asyncio.run(async_command) + except bumble.core.ProtocolError as error: + if error.error_namespace == 'att' and error.error_code in list( + bumble.profiles.bass.ApplicationError + ): + message = bumble.profiles.bass.ApplicationError(error.error_code).name + else: + message = str(error) + + print( + color('!!! An error occurred while executing the command:', 'red'), message + ) # ----------------------------------------------------------------------------- @@ -379,7 +628,7 @@ def auracast( ctx.ensure_object(dict) -@auracast.command('discover-broadcasts') +@auracast.command('scan') @click.option( '--filter-duplicates', is_flag=True, default=False, help='Filter duplicates' ) @@ -387,14 +636,50 @@ def auracast( '--sync-timeout', metavar='SYNC_TIMEOUT', type=float, - default=5.0, + default=AURACAST_DEFAULT_SYNC_TIMEOUT, help='Sync timeout (in seconds)', ) @click.argument('transport') @click.pass_context -def discover_broadcasts(ctx, filter_duplicates, sync_timeout, transport): - """Discover public broadcasts""" - asyncio.run(run_discover_broadcasts(filter_duplicates, sync_timeout, transport)) +def scan(ctx, filter_duplicates, sync_timeout, transport): + """Scan for public broadcasts""" + run_async(run_scan(filter_duplicates, sync_timeout, transport)) + + +@auracast.command('assist') +@click.option( + '--broadcast-name', + metavar='BROADCAST_NAME', + help='Broadcast Name to tune to', +) +@click.option( + '--source-id', + metavar='SOURCE_ID', + type=int, + help='Source ID (for remove-source command)', +) +@click.option( + '--command', + type=click.Choice( + ['monitor-state', 'add-source', 'modify-source', 'remove-source'] + ), + required=True, +) +@click.argument('transport') +@click.argument('address') +@click.pass_context +def assist(ctx, broadcast_name, source_id, command, transport, address): + """Scan for broadcasts on behalf of a audio server""" + run_async(run_assist(broadcast_name, source_id, command, transport, address)) + + +@auracast.command('pair') +@click.argument('transport') +@click.argument('address') +@click.pass_context +def pair(ctx, transport, address): + """Pair with an audio server""" + run_async(run_pair(transport, address)) def main(): diff --git a/apps/device_info.py b/apps/device_info.py new file mode 100644 index 00000000..df18c65d --- /dev/null +++ b/apps/device_info.py @@ -0,0 +1,230 @@ +# Copyright 2021-2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import os +import logging +from typing import Callable, Iterable, Optional + +import click + +from bumble.core import ProtocolError +from bumble.colors import color +from bumble.device import Device, Peer +from bumble.gatt import Service +from bumble.profiles.device_information_service import DeviceInformationServiceProxy +from bumble.profiles.battery_service import BatteryServiceProxy +from bumble.profiles.gap import GenericAccessServiceProxy +from bumble.profiles.tmap import TelephonyAndMediaAudioServiceProxy +from bumble.transport import open_transport_or_link + + +# ----------------------------------------------------------------------------- +async def try_show(function: Callable, *args, **kwargs) -> None: + try: + await function(*args, **kwargs) + except ProtocolError as error: + print(color('ERROR:', 'red'), error) + + +# ----------------------------------------------------------------------------- +def show_services(services: Iterable[Service]) -> None: + for service in services: + print(color(str(service), 'cyan')) + + for characteristic in service.characteristics: + print(color(' ' + str(characteristic), 'magenta')) + + +# ----------------------------------------------------------------------------- +async def show_gap_information( + gap_service: GenericAccessServiceProxy, +): + print(color('### Generic Access Profile', 'yellow')) + + if gap_service.device_name: + print( + color(' Device Name:', 'green'), + await gap_service.device_name.read_value(), + ) + + if gap_service.appearance: + print( + color(' Appearance: ', 'green'), + await gap_service.appearance.read_value(), + ) + + print() + + +# ----------------------------------------------------------------------------- +async def show_device_information( + device_information_service: DeviceInformationServiceProxy, +): + print(color('### Device Information', 'yellow')) + + if device_information_service.manufacturer_name: + print( + color(' Manufacturer Name:', 'green'), + await device_information_service.manufacturer_name.read_value(), + ) + + if device_information_service.model_number: + print( + color(' Model Number: ', 'green'), + await device_information_service.model_number.read_value(), + ) + + if device_information_service.serial_number: + print( + color(' Serial Number: ', 'green'), + await device_information_service.serial_number.read_value(), + ) + + if device_information_service.firmware_revision: + print( + color(' Firmware Revision:', 'green'), + await device_information_service.firmware_revision.read_value(), + ) + + print() + + +# ----------------------------------------------------------------------------- +async def show_battery_level( + battery_service: BatteryServiceProxy, +): + print(color('### Battery Information', 'yellow')) + + if battery_service.battery_level: + print( + color(' Battery Level:', 'green'), + await battery_service.battery_level.read_value(), + ) + + print() + + +# ----------------------------------------------------------------------------- +async def show_tmas( + tmas: TelephonyAndMediaAudioServiceProxy, +): + print(color('### Telephony And Media Audio Service', 'yellow')) + + if tmas.role: + print( + color(' Role:', 'green'), + await tmas.role.read_value(), + ) + + print() + + +# ----------------------------------------------------------------------------- +async def show_device_info(peer, done: Optional[asyncio.Future]) -> None: + try: + # Discover all services + print(color('### Discovering Services and Characteristics', 'magenta')) + await peer.discover_services() + for service in peer.services: + await service.discover_characteristics() + + print(color('=== Services ===', 'yellow')) + show_services(peer.services) + print() + + if gap_service := peer.create_service_proxy(GenericAccessServiceProxy): + await try_show(show_gap_information, gap_service) + + if device_information_service := peer.create_service_proxy( + DeviceInformationServiceProxy + ): + await try_show(show_device_information, device_information_service) + + if battery_service := peer.create_service_proxy(BatteryServiceProxy): + await try_show(show_battery_level, battery_service) + + if tmas := peer.create_service_proxy(TelephonyAndMediaAudioServiceProxy): + await try_show(show_tmas, tmas) + + if done is not None: + done.set_result(None) + except asyncio.CancelledError: + print(color('!!! Operation canceled', 'red')) + + +# ----------------------------------------------------------------------------- +async def async_main(device_config, encrypt, transport, address_or_name): + async with await open_transport_or_link(transport) as (hci_source, hci_sink): + + # Create a device + if device_config: + device = Device.from_config_file_with_hci( + device_config, hci_source, hci_sink + ) + else: + device = Device.with_hci( + 'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink + ) + await device.power_on() + + if address_or_name: + # Connect to the target peer + print(color('>>> Connecting...', 'green')) + connection = await device.connect(address_or_name) + print(color('>>> Connected', 'green')) + + # Encrypt the connection if required + if encrypt: + print(color('+++ Encrypting connection...', 'blue')) + await connection.encrypt() + print(color('+++ Encryption established', 'blue')) + + await show_device_info(Peer(connection), None) + else: + # Wait for a connection + done = asyncio.get_running_loop().create_future() + device.on( + 'connection', + lambda connection: asyncio.create_task( + show_device_info(Peer(connection), done) + ), + ) + await device.start_advertising(auto_restart=True) + + print(color('### Waiting for connection...', 'blue')) + await done + + +# ----------------------------------------------------------------------------- +@click.command() +@click.option('--device-config', help='Device configuration', type=click.Path()) +@click.option('--encrypt', help='Encrypt the connection', is_flag=True, default=False) +@click.argument('transport') +@click.argument('address-or-name', required=False) +def main(device_config, encrypt, transport, address_or_name): + """ + Dump the GATT database on a remote device. If ADDRESS_OR_NAME is not specified, + wait for an incoming connection. + """ + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + asyncio.run(async_main(device_config, encrypt, transport, address_or_name)) + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + main() diff --git a/apps/gatt_dump.py b/apps/gatt_dump.py index a3205c00..3b3e874e 100644 --- a/apps/gatt_dump.py +++ b/apps/gatt_dump.py @@ -75,11 +75,15 @@ async def async_main(device_config, encrypt, transport, address_or_name): if address_or_name: # Connect to the target peer + print(color('>>> Connecting...', 'green')) connection = await device.connect(address_or_name) + print(color('>>> Connected', 'green')) # Encrypt the connection if required if encrypt: + print(color('+++ Encrypting connection...', 'blue')) await connection.encrypt() + print(color('+++ Encryption established', 'blue')) await dump_gatt_db(Peer(connection), None) else: diff --git a/apps/lea_unicast/app.py b/apps/lea_unicast/app.py index ae3b4422..5885dab3 100644 --- a/apps/lea_unicast/app.py +++ b/apps/lea_unicast/app.py @@ -33,7 +33,6 @@ import wasmtime import wasmtime.loader import liblc3 # type: ignore -import logging import click import aiohttp.web @@ -43,7 +42,7 @@ from bumble.colors import color from bumble.device import Device, DeviceConfiguration, AdvertisingParameters from bumble.transport import open_transport -from bumble.profiles import bap +from bumble.profiles import ascs, bap, pacs from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket # ----------------------------------------------------------------------------- @@ -57,8 +56,8 @@ DEFAULT_UI_PORT = 7654 -def _sink_pac_record() -> bap.PacRecord: - return bap.PacRecord( +def _sink_pac_record() -> pacs.PacRecord: + return pacs.PacRecord( coding_format=CodingFormat(CodecID.LC3), codec_specific_capabilities=bap.CodecSpecificCapabilities( supported_sampling_frequencies=( @@ -79,8 +78,8 @@ def _sink_pac_record() -> bap.PacRecord: ) -def _source_pac_record() -> bap.PacRecord: - return bap.PacRecord( +def _source_pac_record() -> pacs.PacRecord: + return pacs.PacRecord( coding_format=CodingFormat(CodecID.LC3), codec_specific_capabilities=bap.CodecSpecificCapabilities( supported_sampling_frequencies=( @@ -447,7 +446,7 @@ async def run(self) -> None: ) self.device.add_service( - bap.PublishedAudioCapabilitiesService( + pacs.PublishedAudioCapabilitiesService( supported_source_context=bap.ContextType(0xFFFF), available_source_context=bap.ContextType(0xFFFF), supported_sink_context=bap.ContextType(0xFFFF), # All context types @@ -461,10 +460,10 @@ async def run(self) -> None: ) ) - ascs = bap.AudioStreamControlService( + ascs_service = ascs.AudioStreamControlService( self.device, sink_ase_id=[1], source_ase_id=[2] ) - self.device.add_service(ascs) + self.device.add_service(ascs_service) advertising_data = bytes( AdvertisingData( @@ -479,13 +478,13 @@ async def run(self) -> None: ), ( AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, - bytes(bap.PublishedAudioCapabilitiesService.UUID), + bytes(pacs.PublishedAudioCapabilitiesService.UUID), ), ] ) ) + bytes(bap.UnicastServerAdvertisingData()) - def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine): + def on_pdu(pdu: HCI_IsoDataPacket, ase: ascs.AseStateMachine): codec_config = ase.codec_specific_configuration assert isinstance(codec_config, bap.CodecSpecificConfiguration) pcm = decode( @@ -495,12 +494,12 @@ def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine): ) self.device.abort_on('disconnection', self.ui_server.send_audio(pcm)) - def on_ase_state_change(ase: bap.AseStateMachine) -> None: - if ase.state == bap.AseStateMachine.State.STREAMING: + def on_ase_state_change(ase: ascs.AseStateMachine) -> None: + if ase.state == ascs.AseStateMachine.State.STREAMING: codec_config = ase.codec_specific_configuration assert isinstance(codec_config, bap.CodecSpecificConfiguration) assert ase.cis_link - if ase.role == bap.AudioRole.SOURCE: + if ase.role == ascs.AudioRole.SOURCE: ase.cis_link.abort_on( 'disconnection', lc3_source_task( @@ -516,10 +515,10 @@ def on_ase_state_change(ase: bap.AseStateMachine) -> None: ) else: ase.cis_link.sink = functools.partial(on_pdu, ase=ase) - elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED: + elif ase.state == ascs.AseStateMachine.State.CODEC_CONFIGURED: codec_config = ase.codec_specific_configuration assert isinstance(codec_config, bap.CodecSpecificConfiguration) - if ase.role == bap.AudioRole.SOURCE: + if ase.role == ascs.AudioRole.SOURCE: setup_encoders( codec_config.sampling_frequency.hz, codec_config.frame_duration.us, @@ -532,7 +531,7 @@ def on_ase_state_change(ase: bap.AseStateMachine) -> None: codec_config.audio_channel_allocation.channel_count, ) - for ase in ascs.ase_state_machines.values(): + for ase in ascs_service.ase_state_machines.values(): ase.on('state_change', functools.partial(on_ase_state_change, ase=ase)) await self.device.power_on() diff --git a/bumble/core.py b/bumble/core.py index a566650e..f6d42dd5 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -148,6 +148,10 @@ class InvalidOperationError(BaseBumbleError, RuntimeError): """Invalid Operation Error""" +class NotSupportedError(BaseBumbleError, RuntimeError): + """Not Supported""" + + class OutOfResourcesError(BaseBumbleError, RuntimeError): """Out of Resources Error""" diff --git a/bumble/device.py b/bumble/device.py index 9a9c4dbc..b04f5ddc 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -113,6 +113,7 @@ HCI_LE_Periodic_Advertising_Create_Sync_Command, HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command, HCI_LE_Periodic_Advertising_Report_Event, + HCI_LE_Periodic_Advertising_Sync_Transfer_Command, HCI_LE_Periodic_Advertising_Terminate_Sync_Command, HCI_LE_Enable_Encryption_Command, HCI_LE_Extended_Advertising_Report_Event, @@ -174,7 +175,7 @@ phy_list_to_bits, ) from .host import Host -from .gap import GenericAccessService +from .profiles.gap import GenericAccessService from .core import ( BT_BR_EDR_TRANSPORT, BT_CENTRAL_ROLE, @@ -189,6 +190,7 @@ InvalidArgumentError, InvalidOperationError, InvalidStateError, + NotSupportedError, OutOfResourcesError, UnreachableError, ) @@ -970,20 +972,25 @@ async def terminate(self) -> None: response = await self.device.send_command( HCI_LE_Periodic_Advertising_Create_Sync_Cancel_Command(), ) - if response.status == HCI_SUCCESS: + if response.return_parameters == HCI_SUCCESS: if self in self.device.periodic_advertising_syncs: self.device.periodic_advertising_syncs.remove(self) return if self.state in (self.State.ESTABLISHED, self.State.ERROR, self.State.LOST): self.state = self.State.TERMINATED - await self.device.send_command( - HCI_LE_Periodic_Advertising_Terminate_Sync_Command( - sync_handle=self.sync_handle + if self.sync_handle is not None: + await self.device.send_command( + HCI_LE_Periodic_Advertising_Terminate_Sync_Command( + sync_handle=self.sync_handle + ) ) - ) self.device.periodic_advertising_syncs.remove(self) + async def transfer(self, connection: Connection, service_data: int = 0) -> None: + if self.sync_handle is not None: + await connection.transfer_periodic_sync(self.sync_handle, service_data) + def on_establishment( self, status, @@ -1209,8 +1216,13 @@ def get_characteristics_by_uuid( return self.gatt_client.get_characteristics_by_uuid(uuid, service) - def create_service_proxy(self, proxy_class: Type[_PROXY_CLASS]) -> _PROXY_CLASS: - return cast(_PROXY_CLASS, proxy_class.from_client(self.gatt_client)) + def create_service_proxy( + self, proxy_class: Type[_PROXY_CLASS] + ) -> Optional[_PROXY_CLASS]: + if proxy := proxy_class.from_client(self.gatt_client): + return cast(_PROXY_CLASS, proxy) + + return None async def discover_service_and_create_proxy( self, proxy_class: Type[_PROXY_CLASS] @@ -1495,11 +1507,9 @@ async def sustain(self, timeout: Optional[float] = None) -> None: try: await asyncio.wait_for(self.device.abort_on('flush', abort), timeout) - except asyncio.TimeoutError: - pass - - self.remove_listener('disconnection', abort.set_result) - self.remove_listener('disconnection_failure', abort.set_exception) + finally: + self.remove_listener('disconnection', abort.set_result) + self.remove_listener('disconnection_failure', abort.set_exception) async def set_data_length(self, tx_octets, tx_time) -> None: return await self.device.set_data_length(self, tx_octets, tx_time) @@ -1530,6 +1540,11 @@ async def get_rssi(self): async def get_phy(self): return await self.device.get_connection_phy(self) + async def transfer_periodic_sync( + self, sync_handle: int, service_data: int = 0 + ) -> None: + await self.device.transfer_periodic_sync(self, sync_handle, service_data) + # [Classic only] async def request_remote_name(self): return await self.device.request_remote_name(self) @@ -2384,6 +2399,10 @@ def supports_le_phy(self, phy: int) -> bool: def supports_le_extended_advertising(self): return self.supports_le_features(LeFeatureMask.LE_EXTENDED_ADVERTISING) + @property + def supports_le_periodic_advertising(self): + return self.supports_le_features(LeFeatureMask.LE_PERIODIC_ADVERTISING) + async def start_advertising( self, advertising_type: AdvertisingType = AdvertisingType.UNDIRECTED_CONNECTABLE_SCANNABLE, @@ -2786,6 +2805,10 @@ async def create_periodic_advertising_sync( sync_timeout: float = DEVICE_DEFAULT_PERIODIC_ADVERTISING_SYNC_TIMEOUT, filter_duplicates: bool = False, ) -> PeriodicAdvertisingSync: + # Check that the controller supports the feature. + if not self.supports_le_periodic_advertising: + raise NotSupportedError() + # Check that there isn't already an equivalent entry if any( sync.advertiser_address == advertiser_address and sync.sid == sid @@ -2983,18 +3006,47 @@ async def connect( ] = None, own_address_type: int = OwnAddressType.RANDOM, timeout: Optional[float] = DEVICE_DEFAULT_CONNECT_TIMEOUT, + always_resolve: bool = False, ) -> Connection: ''' Request a connection to a peer. - When transport is BLE, this method cannot be called if there is already a + + When the transport is BLE, this method cannot be called if there is already a pending connection. - connection_parameters_preferences: (BLE only, ignored for BR/EDR) - * None: use the 1M PHY with default parameters - * map: each entry has a PHY as key and a ConnectionParametersPreferences - object as value + Args: + peer_address: + Address or name of the device to connect to. + If a string is passed: + If the string is an address followed by a `@` suffix, the `always_resolve` + argument is implicitly set to True, so the connection is made to the + address after resolution. + If the string is any other address, the connection is made to that + address (with or without address resolution, depending on the + `always_resolve` argument). + For any other string, a scan for devices using that string as their name + is initiated, and a connection to the first matching device's address + is made. In that case, `always_resolve` is ignored. + + connection_parameters_preferences: + (BLE only, ignored for BR/EDR) + * None: use the 1M PHY with default parameters + * map: each entry has a PHY as key and a ConnectionParametersPreferences + object as value - own_address_type: (BLE only) + own_address_type: + (BLE only, ignored for BR/EDR) + OwnAddressType.RANDOM to use this device's random address, or + OwnAddressType.PUBLIC to use this device's public address. + + timeout: + Maximum time to wait for a connection to be established, in seconds. + Pass None for an unlimited time. + + always_resolve: + (BLE only, ignored for BR/EDR) + If True, always initiate a scan, resolving addresses, and connect to the + address that resolves to `peer_address`. ''' # Check parameters @@ -3013,11 +3065,19 @@ async def connect( if isinstance(peer_address, str): try: - peer_address = Address.from_string_for_transport( - peer_address, transport - ) - except InvalidArgumentError: + if transport == BT_LE_TRANSPORT and peer_address.endswith('@'): + peer_address = Address.from_string_for_transport( + peer_address[:-1], transport + ) + always_resolve = True + logger.debug('forcing address resolution') + else: + peer_address = Address.from_string_for_transport( + peer_address, transport + ) + except (InvalidArgumentError, ValueError): # If the address is not parsable, assume it is a name instead + always_resolve = False logger.debug('looking for peer by name') peer_address = await self.find_peer_by_name( peer_address, transport @@ -3032,6 +3092,12 @@ async def connect( assert isinstance(peer_address, Address) + if transport == BT_LE_TRANSPORT and always_resolve: + logger.debug('resolving address') + peer_address = await self.find_peer_by_identity_address( + peer_address + ) # TODO: timeout + def on_connection(connection): if transport == BT_LE_TRANSPORT or ( # match BR/EDR connection event against peer address @@ -3533,15 +3599,26 @@ async def set_default_phy(self, tx_phys=None, rx_phys=None): check_result=True, ) + async def transfer_periodic_sync( + self, connection: Connection, sync_handle: int, service_data: int = 0 + ) -> None: + return await self.send_command( + HCI_LE_Periodic_Advertising_Sync_Transfer_Command( + connection_handle=connection.handle, + service_data=service_data, + sync_handle=sync_handle, + ), + check_result=True, + ) + async def find_peer_by_name(self, name, transport=BT_LE_TRANSPORT): """ - Scan for a peer with a give name and return its address and transport + Scan for a peer with a given name and return its address. """ # Create a future to wait for an address to be found peer_address = asyncio.get_running_loop().create_future() - # Scan/inquire with event handlers to handle scan/inquiry results def on_peer_found(address, ad_data): local_name = ad_data.get(AdvertisingData.COMPLETE_LOCAL_NAME, raw=True) if local_name is None: @@ -3550,13 +3627,13 @@ def on_peer_found(address, ad_data): if local_name.decode('utf-8') == name: peer_address.set_result(address) - handler = None + listener = None was_scanning = self.scanning was_discovering = self.discovering try: if transport == BT_LE_TRANSPORT: event_name = 'advertisement' - handler = self.on( + listener = self.on( event_name, lambda advertisement: on_peer_found( advertisement.address, advertisement.data @@ -3568,7 +3645,7 @@ def on_peer_found(address, ad_data): elif transport == BT_BR_EDR_TRANSPORT: event_name = 'inquiry_result' - handler = self.on( + listener = self.on( event_name, lambda address, class_of_device, eir_data, rssi: on_peer_found( address, eir_data @@ -3582,14 +3659,60 @@ def on_peer_found(address, ad_data): return await self.abort_on('flush', peer_address) finally: - if handler is not None: - self.remove_listener(event_name, handler) + if listener is not None: + self.remove_listener(event_name, listener) if transport == BT_LE_TRANSPORT and not was_scanning: await self.stop_scanning() elif transport == BT_BR_EDR_TRANSPORT and not was_discovering: await self.stop_discovery() + async def find_peer_by_identity_address(self, identity_address: Address) -> Address: + """ + Scan for a peer with a resolvable address that can be resolved to a given + identity address. + """ + + # Create a future to wait for an address to be found + peer_address = asyncio.get_running_loop().create_future() + + def on_peer_found(address, _): + if address == identity_address: + if not peer_address.done(): + logger.debug(f'*** Matching public address found for {address}') + peer_address.set_result(address) + return + + if address.is_resolvable: + resolved_address = self.address_resolver.resolve(address) + if resolved_address == identity_address: + if not peer_address.done(): + logger.debug(f'*** Matching identity found for {address}') + peer_address.set_result(address) + return + + was_scanning = self.scanning + event_name = 'advertisement' + listener = None + try: + listener = self.on( + event_name, + lambda advertisement: on_peer_found( + advertisement.address, advertisement.data + ), + ) + + if not self.scanning: + await self.start_scanning(filter_duplicates=True) + + return await self.abort_on('flush', peer_address) + finally: + if listener is not None: + self.remove_listener(event_name, listener) + + if not was_scanning: + await self.stop_scanning() + @property def pairing_config_factory(self) -> Callable[[Connection], PairingConfig]: return self.smp_manager.pairing_config_factory @@ -3716,6 +3839,7 @@ def on_encryption_failure(error_code): if self.keystore is None: raise InvalidOperationError('no key store') + logger.debug(f'Looking up key for {connection.peer_address}') keys = await self.keystore.get(str(connection.peer_address)) if keys is None: raise InvalidOperationError('keys not found in key store') @@ -4194,6 +4318,12 @@ def on_connection( role: int, connection_parameters: ConnectionParameters, ) -> None: + # Convert all-zeros addresses into None. + if self_resolvable_address == Address.ANY_RANDOM: + self_resolvable_address = None + if peer_resolvable_address == Address.ANY_RANDOM: + peer_resolvable_address = None + logger.debug( f'*** Connection: [0x{connection_handle:04X}] ' f'{peer_address} {"" if role is None else HCI_Constant.role_name(role)}' @@ -4253,12 +4383,6 @@ def on_connection( else self.random_address ) - # Convert all-zeros addresses into None. - if self_resolvable_address == Address.ANY_RANDOM: - self_resolvable_address = None - if peer_resolvable_address == Address.ANY_RANDOM: - peer_resolvable_address = None - # Create a connection. connection = Connection( self, diff --git a/bumble/gatt.py b/bumble/gatt.py index 896cec01..438c17cf 100644 --- a/bumble/gatt.py +++ b/bumble/gatt.py @@ -39,7 +39,7 @@ ) from bumble.colors import color -from bumble.core import UUID +from bumble.core import BaseBumbleError, UUID from bumble.att import Attribute, AttributeValue if TYPE_CHECKING: @@ -320,6 +320,11 @@ def show_services(services: Iterable[Service]) -> None: print(color(' ' + str(descriptor), 'green')) +# ----------------------------------------------------------------------------- +class InvalidServiceError(BaseBumbleError): + """The service is not compliant with the spec/profile""" + + # ----------------------------------------------------------------------------- class Service(Attribute): ''' diff --git a/bumble/gatt_client.py b/bumble/gatt_client.py index 6d4dcf6a..f2b8df65 100644 --- a/bumble/gatt_client.py +++ b/bumble/gatt_client.py @@ -253,7 +253,7 @@ class ProfileServiceProxy: SERVICE_CLASS: Type[TemplateService] @classmethod - def from_client(cls, client: Client) -> ProfileServiceProxy: + def from_client(cls, client: Client) -> Optional[ProfileServiceProxy]: return ServiceProxy.from_client(cls, client, cls.SERVICE_CLASS.UUID) @@ -283,6 +283,8 @@ def __init__(self, connection: Connection) -> None: self.services = [] self.cached_values = {} + connection.on('disconnection', self.on_disconnection) + def send_gatt_pdu(self, pdu: bytes) -> None: self.connection.send_l2cap_pdu(ATT_CID, pdu) @@ -405,7 +407,7 @@ def on_service_discovered(self, service): if not already_known: self.services.append(service) - async def discover_services(self, uuids: Iterable[UUID] = []) -> List[ServiceProxy]: + async def discover_services(self, uuids: Iterable[UUID] = ()) -> List[ServiceProxy]: ''' See Vol 3, Part G - 4.4.1 Discover All Primary Services ''' @@ -1072,6 +1074,10 @@ async def write_value( ) ) + def on_disconnection(self, _) -> None: + if self.pending_response and not self.pending_response.done(): + self.pending_response.cancel() + def on_gatt_pdu(self, att_pdu: ATT_PDU) -> None: logger.debug( f'GATT Response to client: [0x{self.connection.handle:04X}] {att_pdu}' diff --git a/bumble/hci.py b/bumble/hci.py index 7e83f2ff..af39976c 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -4529,18 +4529,6 @@ class HCI_LE_Periodic_Advertising_Terminate_Sync_Command(HCI_Command): ''' -# ----------------------------------------------------------------------------- -@HCI_Command.command([('sync_handle', 2), ('enable', 1)]) -class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command): - ''' - See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command - ''' - - class Enable(enum.IntFlag): - REPORTING_ENABLED = 1 << 0 - DUPLICATE_FILTERING_ENABLED = 1 << 1 - - # ----------------------------------------------------------------------------- @HCI_Command.command( [ @@ -4576,6 +4564,32 @@ def privacy_mode_name(cls, privacy_mode): return name_or_number(cls.PRIVACY_MODE_NAMES, privacy_mode) +# ----------------------------------------------------------------------------- +@HCI_Command.command([('sync_handle', 2), ('enable', 1)]) +class HCI_LE_Set_Periodic_Advertising_Receive_Enable_Command(HCI_Command): + ''' + See Bluetooth spec @ 7.8.88 LE Set Periodic Advertising Receive Enable Command + ''' + + class Enable(enum.IntFlag): + REPORTING_ENABLED = 1 << 0 + DUPLICATE_FILTERING_ENABLED = 1 << 1 + + +# ----------------------------------------------------------------------------- +@HCI_Command.command( + fields=[('connection_handle', 2), ('service_data', 2), ('sync_handle', 2)], + return_parameters_fields=[ + ('status', STATUS_SPEC), + ('connection_handle', 2), + ], +) +class HCI_LE_Periodic_Advertising_Sync_Transfer_Command(HCI_Command): + ''' + See Bluetooth spec @ 7.8.89 LE Periodic Advertising Sync Transfer Command + ''' + + # ----------------------------------------------------------------------------- @HCI_Command.command( fields=[ diff --git a/bumble/profiles/ascs.py b/bumble/profiles/ascs.py new file mode 100644 index 00000000..35f45941 --- /dev/null +++ b/bumble/profiles/ascs.py @@ -0,0 +1,739 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for + +"""LE Audio - Audio Stream Control Service""" + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import enum +import logging +import struct +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union + +from bumble import colors +from bumble.profiles.bap import CodecSpecificConfiguration +from bumble.profiles import le_audio +from bumble import device +from bumble import gatt +from bumble import gatt_client +from bumble import hci + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# ASE Operations +# ----------------------------------------------------------------------------- + + +class ASE_Operation: + ''' + See Audio Stream Control Service - 5 ASE Control operations. + ''' + + classes: Dict[int, Type[ASE_Operation]] = {} + op_code: int + name: str + fields: Optional[Sequence[Any]] = None + ase_id: List[int] + + class Opcode(enum.IntEnum): + # fmt: off + CONFIG_CODEC = 0x01 + CONFIG_QOS = 0x02 + ENABLE = 0x03 + RECEIVER_START_READY = 0x04 + DISABLE = 0x05 + RECEIVER_STOP_READY = 0x06 + UPDATE_METADATA = 0x07 + RELEASE = 0x08 + + @staticmethod + def from_bytes(pdu: bytes) -> ASE_Operation: + op_code = pdu[0] + + cls = ASE_Operation.classes.get(op_code) + if cls is None: + instance = ASE_Operation(pdu) + instance.name = ASE_Operation.Opcode(op_code).name + instance.op_code = op_code + return instance + self = cls.__new__(cls) + ASE_Operation.__init__(self, pdu) + if self.fields is not None: + self.init_from_bytes(pdu, 1) + return self + + @staticmethod + def subclass(fields): + def inner(cls: Type[ASE_Operation]): + try: + operation = ASE_Operation.Opcode[cls.__name__[4:].upper()] + cls.name = operation.name + cls.op_code = operation + except: + raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode') + cls.fields = fields + + # Register a factory for this class + ASE_Operation.classes[cls.op_code] = cls + + return cls + + return inner + + def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None: + if self.fields is not None and kwargs: + hci.HCI_Object.init_from_fields(self, self.fields, kwargs) + if pdu is None: + pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes( + kwargs, self.fields + ) + self.pdu = pdu + + def init_from_bytes(self, pdu: bytes, offset: int): + return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields) + + def __bytes__(self) -> bytes: + return self.pdu + + def __str__(self) -> str: + result = f'{colors.color(self.name, "yellow")} ' + if fields := getattr(self, 'fields', None): + result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ') + else: + if len(self.pdu) > 1: + result += f': {self.pdu.hex()}' + return result + + +@ASE_Operation.subclass( + [ + [ + ('ase_id', 1), + ('target_latency', 1), + ('target_phy', 1), + ('codec_id', hci.CodingFormat.parse_from_bytes), + ('codec_specific_configuration', 'v'), + ], + ] +) +class ASE_Config_Codec(ASE_Operation): + ''' + See Audio Stream Control Service 5.1 - Config Codec Operation + ''' + + target_latency: List[int] + target_phy: List[int] + codec_id: List[hci.CodingFormat] + codec_specific_configuration: List[bytes] + + +@ASE_Operation.subclass( + [ + [ + ('ase_id', 1), + ('cig_id', 1), + ('cis_id', 1), + ('sdu_interval', 3), + ('framing', 1), + ('phy', 1), + ('max_sdu', 2), + ('retransmission_number', 1), + ('max_transport_latency', 2), + ('presentation_delay', 3), + ], + ] +) +class ASE_Config_QOS(ASE_Operation): + ''' + See Audio Stream Control Service 5.2 - Config Qos Operation + ''' + + cig_id: List[int] + cis_id: List[int] + sdu_interval: List[int] + framing: List[int] + phy: List[int] + max_sdu: List[int] + retransmission_number: List[int] + max_transport_latency: List[int] + presentation_delay: List[int] + + +@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) +class ASE_Enable(ASE_Operation): + ''' + See Audio Stream Control Service 5.3 - Enable Operation + ''' + + metadata: bytes + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Receiver_Start_Ready(ASE_Operation): + ''' + See Audio Stream Control Service 5.4 - Receiver Start Ready Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Disable(ASE_Operation): + ''' + See Audio Stream Control Service 5.5 - Disable Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Receiver_Stop_Ready(ASE_Operation): + ''' + See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation + ''' + + +@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) +class ASE_Update_Metadata(ASE_Operation): + ''' + See Audio Stream Control Service 5.7 - Update Metadata Operation + ''' + + metadata: List[bytes] + + +@ASE_Operation.subclass([[('ase_id', 1)]]) +class ASE_Release(ASE_Operation): + ''' + See Audio Stream Control Service 5.8 - Release Operation + ''' + + +class AseResponseCode(enum.IntEnum): + # fmt: off + SUCCESS = 0x00 + UNSUPPORTED_OPCODE = 0x01 + INVALID_LENGTH = 0x02 + INVALID_ASE_ID = 0x03 + INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04 + INVALID_ASE_DIRECTION = 0x05 + UNSUPPORTED_AUDIO_CAPABILITIES = 0x06 + UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07 + REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08 + INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09 + UNSUPPORTED_METADATA = 0x0A + REJECTED_METADATA = 0x0B + INVALID_METADATA = 0x0C + INSUFFICIENT_RESOURCES = 0x0D + UNSPECIFIED_ERROR = 0x0E + + +class AseReasonCode(enum.IntEnum): + # fmt: off + NONE = 0x00 + CODEC_ID = 0x01 + CODEC_SPECIFIC_CONFIGURATION = 0x02 + SDU_INTERVAL = 0x03 + FRAMING = 0x04 + PHY = 0x05 + MAXIMUM_SDU_SIZE = 0x06 + RETRANSMISSION_NUMBER = 0x07 + MAX_TRANSPORT_LATENCY = 0x08 + PRESENTATION_DELAY = 0x09 + INVALID_ASE_CIS_MAPPING = 0x0A + + +# ----------------------------------------------------------------------------- +class AudioRole(enum.IntEnum): + SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST + SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER + + +# ----------------------------------------------------------------------------- +class AseStateMachine(gatt.Characteristic): + class State(enum.IntEnum): + # fmt: off + IDLE = 0x00 + CODEC_CONFIGURED = 0x01 + QOS_CONFIGURED = 0x02 + ENABLING = 0x03 + STREAMING = 0x04 + DISABLING = 0x05 + RELEASING = 0x06 + + cis_link: Optional[device.CisLink] = None + + # Additional parameters in CODEC_CONFIGURED State + preferred_framing = 0 # Unframed PDU supported + preferred_phy = 0 + preferred_retransmission_number = 13 + preferred_max_transport_latency = 100 + supported_presentation_delay_min = 0 + supported_presentation_delay_max = 0 + preferred_presentation_delay_min = 0 + preferred_presentation_delay_max = 0 + codec_id = hci.CodingFormat(hci.CodecID.LC3) + codec_specific_configuration: Union[CodecSpecificConfiguration, bytes] = b'' + + # Additional parameters in QOS_CONFIGURED State + cig_id = 0 + cis_id = 0 + sdu_interval = 0 + framing = 0 + phy = 0 + max_sdu = 0 + retransmission_number = 0 + max_transport_latency = 0 + presentation_delay = 0 + + # Additional parameters in ENABLING, STREAMING, DISABLING State + metadata = le_audio.Metadata() + + def __init__( + self, + role: AudioRole, + ase_id: int, + service: AudioStreamControlService, + ) -> None: + self.service = service + self.ase_id = ase_id + self._state = AseStateMachine.State.IDLE + self.role = role + + uuid = ( + gatt.GATT_SINK_ASE_CHARACTERISTIC + if role == AudioRole.SINK + else gatt.GATT_SOURCE_ASE_CHARACTERISTIC + ) + super().__init__( + uuid=uuid, + properties=gatt.Characteristic.Properties.READ + | gatt.Characteristic.Properties.NOTIFY, + permissions=gatt.Characteristic.Permissions.READABLE, + value=gatt.CharacteristicValue(read=self.on_read), + ) + + self.service.device.on('cis_request', self.on_cis_request) + self.service.device.on('cis_establishment', self.on_cis_establishment) + + def on_cis_request( + self, + acl_connection: device.Connection, + cis_handle: int, + cig_id: int, + cis_id: int, + ) -> None: + if ( + cig_id == self.cig_id + and cis_id == self.cis_id + and self.state == self.State.ENABLING + ): + acl_connection.abort_on( + 'flush', self.service.device.accept_cis_request(cis_handle) + ) + + def on_cis_establishment(self, cis_link: device.CisLink) -> None: + if ( + cis_link.cig_id == self.cig_id + and cis_link.cis_id == self.cis_id + and self.state == self.State.ENABLING + ): + cis_link.on('disconnection', self.on_cis_disconnection) + + async def post_cis_established(): + await self.service.device.send_command( + hci.HCI_LE_Setup_ISO_Data_Path_Command( + connection_handle=cis_link.handle, + data_path_direction=self.role, + data_path_id=0x00, # Fixed HCI + codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT), + controller_delay=0, + codec_configuration=b'', + ) + ) + if self.role == AudioRole.SINK: + self.state = self.State.STREAMING + await self.service.device.notify_subscribers(self, self.value) + + cis_link.acl_connection.abort_on('flush', post_cis_established()) + self.cis_link = cis_link + + def on_cis_disconnection(self, _reason) -> None: + self.cis_link = None + + def on_config_codec( + self, + target_latency: int, + target_phy: int, + codec_id: hci.CodingFormat, + codec_specific_configuration: bytes, + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + self.State.IDLE, + self.State.CODEC_CONFIGURED, + self.State.QOS_CONFIGURED, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.max_transport_latency = target_latency + self.phy = target_phy + self.codec_id = codec_id + if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC: + self.codec_specific_configuration = codec_specific_configuration + else: + self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes( + codec_specific_configuration + ) + + self.state = self.State.CODEC_CONFIGURED + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_config_qos( + self, + cig_id: int, + cis_id: int, + sdu_interval: int, + framing: int, + phy: int, + max_sdu: int, + retransmission_number: int, + max_transport_latency: int, + presentation_delay: int, + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.CODEC_CONFIGURED, + AseStateMachine.State.QOS_CONFIGURED, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.cig_id = cig_id + self.cis_id = cis_id + self.sdu_interval = sdu_interval + self.framing = framing + self.phy = phy + self.max_sdu = max_sdu + self.retransmission_number = retransmission_number + self.max_transport_latency = max_transport_latency + self.presentation_delay = presentation_delay + + self.state = self.State.QOS_CONFIGURED + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state != AseStateMachine.State.QOS_CONFIGURED: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + + self.metadata = le_audio.Metadata.from_bytes(metadata) + self.state = self.State.ENABLING + + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state != AseStateMachine.State.ENABLING: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.STREAMING + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.ENABLING, + AseStateMachine.State.STREAMING, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + if self.role == AudioRole.SINK: + self.state = self.State.QOS_CONFIGURED + else: + self.state = self.State.DISABLING + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: + if ( + self.role != AudioRole.SOURCE + or self.state != AseStateMachine.State.DISABLING + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.QOS_CONFIGURED + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_update_metadata( + self, metadata: bytes + ) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state not in ( + AseStateMachine.State.ENABLING, + AseStateMachine.State.STREAMING, + ): + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.metadata = le_audio.Metadata.from_bytes(metadata) + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]: + if self.state == AseStateMachine.State.IDLE: + return ( + AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, + AseReasonCode.NONE, + ) + self.state = self.State.RELEASING + + async def remove_cis_async(): + await self.service.device.send_command( + hci.HCI_LE_Remove_ISO_Data_Path_Command( + connection_handle=self.cis_link.handle, + data_path_direction=self.role, + ) + ) + self.state = self.State.IDLE + await self.service.device.notify_subscribers(self, self.value) + + self.service.device.abort_on('flush', remove_cis_async()) + return (AseResponseCode.SUCCESS, AseReasonCode.NONE) + + @property + def state(self) -> State: + return self._state + + @state.setter + def state(self, new_state: State) -> None: + logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}') + self._state = new_state + self.emit('state_change') + + @property + def value(self): + '''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.''' + + if self.state == self.State.CODEC_CONFIGURED: + codec_specific_configuration_bytes = bytes( + self.codec_specific_configuration + ) + additional_parameters = ( + struct.pack( + ' bytes: + return self.value + + def __str__(self) -> str: + return ( + f'AseStateMachine(id={self.ase_id}, role={self.role.name} ' + f'state={self._state.name})' + ) + + +# ----------------------------------------------------------------------------- +class AudioStreamControlService(gatt.TemplateService): + UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE + + ase_state_machines: Dict[int, AseStateMachine] + ase_control_point: gatt.Characteristic + _active_client: Optional[device.Connection] = None + + def __init__( + self, + device: device.Device, + source_ase_id: Sequence[int] = (), + sink_ase_id: Sequence[int] = (), + ) -> None: + self.device = device + self.ase_state_machines = { + **{ + id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self) + for id in sink_ase_id + }, + **{ + id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self) + for id in source_ase_id + }, + } # ASE state machines, by ASE ID + + self.ase_control_point = gatt.Characteristic( + uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC, + properties=gatt.Characteristic.Properties.WRITE + | gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE + | gatt.Characteristic.Properties.NOTIFY, + permissions=gatt.Characteristic.Permissions.WRITEABLE, + value=gatt.CharacteristicValue(write=self.on_write_ase_control_point), + ) + + super().__init__([self.ase_control_point, *self.ase_state_machines.values()]) + + def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args): + if ase := self.ase_state_machines.get(ase_id): + handler = getattr(ase, 'on_' + opcode.name.lower()) + return (ase_id, *handler(*args)) + else: + return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE) + + def _on_client_disconnected(self, _reason: int) -> None: + for ase in self.ase_state_machines.values(): + ase.state = AseStateMachine.State.IDLE + self._active_client = None + + def on_write_ase_control_point(self, connection, data): + if not self._active_client and connection: + self._active_client = connection + connection.once('disconnection', self._on_client_disconnected) + + operation = ASE_Operation.from_bytes(data) + responses = [] + logger.debug(f'*** ASCS Write {operation} ***') + + if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC: + for ase_id, *args in zip( + operation.ase_id, + operation.target_latency, + operation.target_phy, + operation.codec_id, + operation.codec_specific_configuration, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS: + for ase_id, *args in zip( + operation.ase_id, + operation.cig_id, + operation.cis_id, + operation.sdu_interval, + operation.framing, + operation.phy, + operation.max_sdu, + operation.retransmission_number, + operation.max_transport_latency, + operation.presentation_delay, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code in ( + ASE_Operation.Opcode.ENABLE, + ASE_Operation.Opcode.UPDATE_METADATA, + ): + for ase_id, *args in zip( + operation.ase_id, + operation.metadata, + ): + responses.append(self.on_operation(operation.op_code, ase_id, args)) + elif operation.op_code in ( + ASE_Operation.Opcode.RECEIVER_START_READY, + ASE_Operation.Opcode.DISABLE, + ASE_Operation.Opcode.RECEIVER_STOP_READY, + ASE_Operation.Opcode.RELEASE, + ): + for ase_id in operation.ase_id: + responses.append(self.on_operation(operation.op_code, ase_id, [])) + + control_point_notification = bytes( + [operation.op_code, len(responses)] + ) + b''.join(map(bytes, responses)) + self.device.abort_on( + 'flush', + self.device.notify_subscribers( + self.ase_control_point, control_point_notification + ), + ) + + for ase_id, *_ in responses: + if ase := self.ase_state_machines.get(ase_id): + self.device.abort_on( + 'flush', + self.device.notify_subscribers(ase, ase.value), + ) + + +# ----------------------------------------------------------------------------- +class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy): + SERVICE_CLASS = AudioStreamControlService + + sink_ase: List[gatt_client.CharacteristicProxy] + source_ase: List[gatt_client.CharacteristicProxy] + ase_control_point: gatt_client.CharacteristicProxy + + def __init__(self, service_proxy: gatt_client.ServiceProxy): + self.service_proxy = service_proxy + + self.sink_ase = service_proxy.get_characteristics_by_uuid( + gatt.GATT_SINK_ASE_CHARACTERISTIC + ) + self.source_ase = service_proxy.get_characteristics_by_uuid( + gatt.GATT_SOURCE_ASE_CHARACTERISTIC + ) + self.ase_control_point = service_proxy.get_characteristics_by_uuid( + gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC + )[0] diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index 117e95e6..8a00eafe 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -24,15 +24,12 @@ import struct import functools import logging -from typing import Optional, List, Union, Type, Dict, Any, Tuple +from typing import List from typing_extensions import Self from bumble import core -from bumble import colors -from bumble import device from bumble import hci from bumble import gatt -from bumble import gatt_client from bumble import utils from bumble.profiles import le_audio @@ -251,231 +248,6 @@ class AnnouncementType(utils.OpenIntEnum): TARGETED = 0x01 -# ----------------------------------------------------------------------------- -# ASE Operations -# ----------------------------------------------------------------------------- - - -class ASE_Operation: - ''' - See Audio Stream Control Service - 5 ASE Control operations. - ''' - - classes: Dict[int, Type[ASE_Operation]] = {} - op_code: int - name: str - fields: Optional[Sequence[Any]] = None - ase_id: List[int] - - class Opcode(enum.IntEnum): - # fmt: off - CONFIG_CODEC = 0x01 - CONFIG_QOS = 0x02 - ENABLE = 0x03 - RECEIVER_START_READY = 0x04 - DISABLE = 0x05 - RECEIVER_STOP_READY = 0x06 - UPDATE_METADATA = 0x07 - RELEASE = 0x08 - - @staticmethod - def from_bytes(pdu: bytes) -> ASE_Operation: - op_code = pdu[0] - - cls = ASE_Operation.classes.get(op_code) - if cls is None: - instance = ASE_Operation(pdu) - instance.name = ASE_Operation.Opcode(op_code).name - instance.op_code = op_code - return instance - self = cls.__new__(cls) - ASE_Operation.__init__(self, pdu) - if self.fields is not None: - self.init_from_bytes(pdu, 1) - return self - - @staticmethod - def subclass(fields): - def inner(cls: Type[ASE_Operation]): - try: - operation = ASE_Operation.Opcode[cls.__name__[4:].upper()] - cls.name = operation.name - cls.op_code = operation - except: - raise KeyError(f'PDU name {cls.name} not found in Ase_Operation.Opcode') - cls.fields = fields - - # Register a factory for this class - ASE_Operation.classes[cls.op_code] = cls - - return cls - - return inner - - def __init__(self, pdu: Optional[bytes] = None, **kwargs) -> None: - if self.fields is not None and kwargs: - hci.HCI_Object.init_from_fields(self, self.fields, kwargs) - if pdu is None: - pdu = bytes([self.op_code]) + hci.HCI_Object.dict_to_bytes( - kwargs, self.fields - ) - self.pdu = pdu - - def init_from_bytes(self, pdu: bytes, offset: int): - return hci.HCI_Object.init_from_bytes(self, pdu, offset, self.fields) - - def __bytes__(self) -> bytes: - return self.pdu - - def __str__(self) -> str: - result = f'{colors.color(self.name, "yellow")} ' - if fields := getattr(self, 'fields', None): - result += ':\n' + hci.HCI_Object.format_fields(self.__dict__, fields, ' ') - else: - if len(self.pdu) > 1: - result += f': {self.pdu.hex()}' - return result - - -@ASE_Operation.subclass( - [ - [ - ('ase_id', 1), - ('target_latency', 1), - ('target_phy', 1), - ('codec_id', hci.CodingFormat.parse_from_bytes), - ('codec_specific_configuration', 'v'), - ], - ] -) -class ASE_Config_Codec(ASE_Operation): - ''' - See Audio Stream Control Service 5.1 - Config Codec Operation - ''' - - target_latency: List[int] - target_phy: List[int] - codec_id: List[hci.CodingFormat] - codec_specific_configuration: List[bytes] - - -@ASE_Operation.subclass( - [ - [ - ('ase_id', 1), - ('cig_id', 1), - ('cis_id', 1), - ('sdu_interval', 3), - ('framing', 1), - ('phy', 1), - ('max_sdu', 2), - ('retransmission_number', 1), - ('max_transport_latency', 2), - ('presentation_delay', 3), - ], - ] -) -class ASE_Config_QOS(ASE_Operation): - ''' - See Audio Stream Control Service 5.2 - Config Qos Operation - ''' - - cig_id: List[int] - cis_id: List[int] - sdu_interval: List[int] - framing: List[int] - phy: List[int] - max_sdu: List[int] - retransmission_number: List[int] - max_transport_latency: List[int] - presentation_delay: List[int] - - -@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) -class ASE_Enable(ASE_Operation): - ''' - See Audio Stream Control Service 5.3 - Enable Operation - ''' - - metadata: bytes - - -@ASE_Operation.subclass([[('ase_id', 1)]]) -class ASE_Receiver_Start_Ready(ASE_Operation): - ''' - See Audio Stream Control Service 5.4 - Receiver Start Ready Operation - ''' - - -@ASE_Operation.subclass([[('ase_id', 1)]]) -class ASE_Disable(ASE_Operation): - ''' - See Audio Stream Control Service 5.5 - Disable Operation - ''' - - -@ASE_Operation.subclass([[('ase_id', 1)]]) -class ASE_Receiver_Stop_Ready(ASE_Operation): - ''' - See Audio Stream Control Service 5.6 - Receiver Stop Ready Operation - ''' - - -@ASE_Operation.subclass([[('ase_id', 1), ('metadata', 'v')]]) -class ASE_Update_Metadata(ASE_Operation): - ''' - See Audio Stream Control Service 5.7 - Update Metadata Operation - ''' - - metadata: List[bytes] - - -@ASE_Operation.subclass([[('ase_id', 1)]]) -class ASE_Release(ASE_Operation): - ''' - See Audio Stream Control Service 5.8 - Release Operation - ''' - - -class AseResponseCode(enum.IntEnum): - # fmt: off - SUCCESS = 0x00 - UNSUPPORTED_OPCODE = 0x01 - INVALID_LENGTH = 0x02 - INVALID_ASE_ID = 0x03 - INVALID_ASE_STATE_MACHINE_TRANSITION = 0x04 - INVALID_ASE_DIRECTION = 0x05 - UNSUPPORTED_AUDIO_CAPABILITIES = 0x06 - UNSUPPORTED_CONFIGURATION_PARAMETER_VALUE = 0x07 - REJECTED_CONFIGURATION_PARAMETER_VALUE = 0x08 - INVALID_CONFIGURATION_PARAMETER_VALUE = 0x09 - UNSUPPORTED_METADATA = 0x0A - REJECTED_METADATA = 0x0B - INVALID_METADATA = 0x0C - INSUFFICIENT_RESOURCES = 0x0D - UNSPECIFIED_ERROR = 0x0E - - -class AseReasonCode(enum.IntEnum): - # fmt: off - NONE = 0x00 - CODEC_ID = 0x01 - CODEC_SPECIFIC_CONFIGURATION = 0x02 - SDU_INTERVAL = 0x03 - FRAMING = 0x04 - PHY = 0x05 - MAXIMUM_SDU_SIZE = 0x06 - RETRANSMISSION_NUMBER = 0x07 - MAX_TRANSPORT_LATENCY = 0x08 - PRESENTATION_DELAY = 0x09 - INVALID_ASE_CIS_MAPPING = 0x0A - - -class AudioRole(enum.IntEnum): - SINK = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.CONTROLLER_TO_HOST - SOURCE = hci.HCI_LE_Setup_ISO_Data_Path_Command.Direction.HOST_TO_CONTROLLER - - @dataclasses.dataclass class UnicastServerAdvertisingData: """Advertising Data for ASCS.""" @@ -683,54 +455,6 @@ def __bytes__(self) -> bytes: ) -@dataclasses.dataclass -class PacRecord: - '''Published Audio Capabilities Service, Table 3.2/3.4.''' - - coding_format: hci.CodingFormat - codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes] - metadata: le_audio.Metadata = dataclasses.field(default_factory=le_audio.Metadata) - - @classmethod - def from_bytes(cls, data: bytes) -> PacRecord: - offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0) - codec_specific_capabilities_size = data[offset] - - offset += 1 - codec_specific_capabilities_bytes = data[ - offset : offset + codec_specific_capabilities_size - ] - offset += codec_specific_capabilities_size - metadata_size = data[offset] - offset += 1 - metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size]) - - codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes] - if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC: - codec_specific_capabilities = codec_specific_capabilities_bytes - else: - codec_specific_capabilities = CodecSpecificCapabilities.from_bytes( - codec_specific_capabilities_bytes - ) - - return PacRecord( - coding_format=coding_format, - codec_specific_capabilities=codec_specific_capabilities, - metadata=metadata, - ) - - def __bytes__(self) -> bytes: - capabilities_bytes = bytes(self.codec_specific_capabilities) - metadata_bytes = bytes(self.metadata) - return ( - bytes(self.coding_format) - + bytes([len(capabilities_bytes)]) - + capabilities_bytes - + bytes([len(metadata_bytes)]) - + metadata_bytes - ) - - @dataclasses.dataclass class BroadcastAudioAnnouncement: broadcast_id: int @@ -822,603 +546,3 @@ def from_bytes(cls, data: bytes) -> Self: ) return cls(presentation_delay, subgroups) - - -# ----------------------------------------------------------------------------- -# Server -# ----------------------------------------------------------------------------- -class PublishedAudioCapabilitiesService(gatt.TemplateService): - UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE - - sink_pac: Optional[gatt.Characteristic] - sink_audio_locations: Optional[gatt.Characteristic] - source_pac: Optional[gatt.Characteristic] - source_audio_locations: Optional[gatt.Characteristic] - available_audio_contexts: gatt.Characteristic - supported_audio_contexts: gatt.Characteristic - - def __init__( - self, - supported_source_context: ContextType, - supported_sink_context: ContextType, - available_source_context: ContextType, - available_sink_context: ContextType, - sink_pac: Sequence[PacRecord] = (), - sink_audio_locations: Optional[AudioLocation] = None, - source_pac: Sequence[PacRecord] = (), - source_audio_locations: Optional[AudioLocation] = None, - ) -> None: - characteristics = [] - - self.supported_audio_contexts = gatt.Characteristic( - uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC, - properties=gatt.Characteristic.Properties.READ, - permissions=gatt.Characteristic.Permissions.READABLE, - value=struct.pack(' None: - self.service = service - self.ase_id = ase_id - self._state = AseStateMachine.State.IDLE - self.role = role - - uuid = ( - gatt.GATT_SINK_ASE_CHARACTERISTIC - if role == AudioRole.SINK - else gatt.GATT_SOURCE_ASE_CHARACTERISTIC - ) - super().__init__( - uuid=uuid, - properties=gatt.Characteristic.Properties.READ - | gatt.Characteristic.Properties.NOTIFY, - permissions=gatt.Characteristic.Permissions.READABLE, - value=gatt.CharacteristicValue(read=self.on_read), - ) - - self.service.device.on('cis_request', self.on_cis_request) - self.service.device.on('cis_establishment', self.on_cis_establishment) - - def on_cis_request( - self, - acl_connection: device.Connection, - cis_handle: int, - cig_id: int, - cis_id: int, - ) -> None: - if ( - cig_id == self.cig_id - and cis_id == self.cis_id - and self.state == self.State.ENABLING - ): - acl_connection.abort_on( - 'flush', self.service.device.accept_cis_request(cis_handle) - ) - - def on_cis_establishment(self, cis_link: device.CisLink) -> None: - if ( - cis_link.cig_id == self.cig_id - and cis_link.cis_id == self.cis_id - and self.state == self.State.ENABLING - ): - cis_link.on('disconnection', self.on_cis_disconnection) - - async def post_cis_established(): - await self.service.device.send_command( - hci.HCI_LE_Setup_ISO_Data_Path_Command( - connection_handle=cis_link.handle, - data_path_direction=self.role, - data_path_id=0x00, # Fixed HCI - codec_id=hci.CodingFormat(hci.CodecID.TRANSPARENT), - controller_delay=0, - codec_configuration=b'', - ) - ) - if self.role == AudioRole.SINK: - self.state = self.State.STREAMING - await self.service.device.notify_subscribers(self, self.value) - - cis_link.acl_connection.abort_on('flush', post_cis_established()) - self.cis_link = cis_link - - def on_cis_disconnection(self, _reason) -> None: - self.cis_link = None - - def on_config_codec( - self, - target_latency: int, - target_phy: int, - codec_id: hci.CodingFormat, - codec_specific_configuration: bytes, - ) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state not in ( - self.State.IDLE, - self.State.CODEC_CONFIGURED, - self.State.QOS_CONFIGURED, - ): - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - - self.max_transport_latency = target_latency - self.phy = target_phy - self.codec_id = codec_id - if codec_id.codec_id == hci.CodecID.VENDOR_SPECIFIC: - self.codec_specific_configuration = codec_specific_configuration - else: - self.codec_specific_configuration = CodecSpecificConfiguration.from_bytes( - codec_specific_configuration - ) - - self.state = self.State.CODEC_CONFIGURED - - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_config_qos( - self, - cig_id: int, - cis_id: int, - sdu_interval: int, - framing: int, - phy: int, - max_sdu: int, - retransmission_number: int, - max_transport_latency: int, - presentation_delay: int, - ) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state not in ( - AseStateMachine.State.CODEC_CONFIGURED, - AseStateMachine.State.QOS_CONFIGURED, - ): - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - - self.cig_id = cig_id - self.cis_id = cis_id - self.sdu_interval = sdu_interval - self.framing = framing - self.phy = phy - self.max_sdu = max_sdu - self.retransmission_number = retransmission_number - self.max_transport_latency = max_transport_latency - self.presentation_delay = presentation_delay - - self.state = self.State.QOS_CONFIGURED - - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_enable(self, metadata: bytes) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state != AseStateMachine.State.QOS_CONFIGURED: - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - - self.metadata = le_audio.Metadata.from_bytes(metadata) - self.state = self.State.ENABLING - - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_receiver_start_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state != AseStateMachine.State.ENABLING: - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - self.state = self.State.STREAMING - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state not in ( - AseStateMachine.State.ENABLING, - AseStateMachine.State.STREAMING, - ): - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - if self.role == AudioRole.SINK: - self.state = self.State.QOS_CONFIGURED - else: - self.state = self.State.DISABLING - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: - if ( - self.role != AudioRole.SOURCE - or self.state != AseStateMachine.State.DISABLING - ): - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - self.state = self.State.QOS_CONFIGURED - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_update_metadata( - self, metadata: bytes - ) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state not in ( - AseStateMachine.State.ENABLING, - AseStateMachine.State.STREAMING, - ): - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - self.metadata = le_audio.Metadata.from_bytes(metadata) - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - def on_release(self) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state == AseStateMachine.State.IDLE: - return ( - AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, - AseReasonCode.NONE, - ) - self.state = self.State.RELEASING - - async def remove_cis_async(): - await self.service.device.send_command( - hci.HCI_LE_Remove_ISO_Data_Path_Command( - connection_handle=self.cis_link.handle, - data_path_direction=self.role, - ) - ) - self.state = self.State.IDLE - await self.service.device.notify_subscribers(self, self.value) - - self.service.device.abort_on('flush', remove_cis_async()) - return (AseResponseCode.SUCCESS, AseReasonCode.NONE) - - @property - def state(self) -> State: - return self._state - - @state.setter - def state(self, new_state: State) -> None: - logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}') - self._state = new_state - self.emit('state_change') - - @property - def value(self): - '''Returns ASE_ID, ASE_STATE, and ASE Additional Parameters.''' - - if self.state == self.State.CODEC_CONFIGURED: - codec_specific_configuration_bytes = bytes( - self.codec_specific_configuration - ) - additional_parameters = ( - struct.pack( - ' bytes: - return self.value - - def __str__(self) -> str: - return ( - f'AseStateMachine(id={self.ase_id}, role={self.role.name} ' - f'state={self._state.name})' - ) - - -class AudioStreamControlService(gatt.TemplateService): - UUID = gatt.GATT_AUDIO_STREAM_CONTROL_SERVICE - - ase_state_machines: Dict[int, AseStateMachine] - ase_control_point: gatt.Characteristic - _active_client: Optional[device.Connection] = None - - def __init__( - self, - device: device.Device, - source_ase_id: Sequence[int] = [], - sink_ase_id: Sequence[int] = [], - ) -> None: - self.device = device - self.ase_state_machines = { - **{ - id: AseStateMachine(role=AudioRole.SINK, ase_id=id, service=self) - for id in sink_ase_id - }, - **{ - id: AseStateMachine(role=AudioRole.SOURCE, ase_id=id, service=self) - for id in source_ase_id - }, - } # ASE state machines, by ASE ID - - self.ase_control_point = gatt.Characteristic( - uuid=gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC, - properties=gatt.Characteristic.Properties.WRITE - | gatt.Characteristic.Properties.WRITE_WITHOUT_RESPONSE - | gatt.Characteristic.Properties.NOTIFY, - permissions=gatt.Characteristic.Permissions.WRITEABLE, - value=gatt.CharacteristicValue(write=self.on_write_ase_control_point), - ) - - super().__init__([self.ase_control_point, *self.ase_state_machines.values()]) - - def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args): - if ase := self.ase_state_machines.get(ase_id): - handler = getattr(ase, 'on_' + opcode.name.lower()) - return (ase_id, *handler(*args)) - else: - return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE) - - def _on_client_disconnected(self, _reason: int) -> None: - for ase in self.ase_state_machines.values(): - ase.state = AseStateMachine.State.IDLE - self._active_client = None - - def on_write_ase_control_point(self, connection, data): - if not self._active_client and connection: - self._active_client = connection - connection.once('disconnection', self._on_client_disconnected) - - operation = ASE_Operation.from_bytes(data) - responses = [] - logger.debug(f'*** ASCS Write {operation} ***') - - if operation.op_code == ASE_Operation.Opcode.CONFIG_CODEC: - for ase_id, *args in zip( - operation.ase_id, - operation.target_latency, - operation.target_phy, - operation.codec_id, - operation.codec_specific_configuration, - ): - responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif operation.op_code == ASE_Operation.Opcode.CONFIG_QOS: - for ase_id, *args in zip( - operation.ase_id, - operation.cig_id, - operation.cis_id, - operation.sdu_interval, - operation.framing, - operation.phy, - operation.max_sdu, - operation.retransmission_number, - operation.max_transport_latency, - operation.presentation_delay, - ): - responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif operation.op_code in ( - ASE_Operation.Opcode.ENABLE, - ASE_Operation.Opcode.UPDATE_METADATA, - ): - for ase_id, *args in zip( - operation.ase_id, - operation.metadata, - ): - responses.append(self.on_operation(operation.op_code, ase_id, args)) - elif operation.op_code in ( - ASE_Operation.Opcode.RECEIVER_START_READY, - ASE_Operation.Opcode.DISABLE, - ASE_Operation.Opcode.RECEIVER_STOP_READY, - ASE_Operation.Opcode.RELEASE, - ): - for ase_id in operation.ase_id: - responses.append(self.on_operation(operation.op_code, ase_id, [])) - - control_point_notification = bytes( - [operation.op_code, len(responses)] - ) + b''.join(map(bytes, responses)) - self.device.abort_on( - 'flush', - self.device.notify_subscribers( - self.ase_control_point, control_point_notification - ), - ) - - for ase_id, *_ in responses: - if ase := self.ase_state_machines.get(ase_id): - self.device.abort_on( - 'flush', - self.device.notify_subscribers(ase, ase.value), - ) - - -# ----------------------------------------------------------------------------- -# Client -# ----------------------------------------------------------------------------- -class PublishedAudioCapabilitiesServiceProxy(gatt_client.ProfileServiceProxy): - SERVICE_CLASS = PublishedAudioCapabilitiesService - - sink_pac: Optional[gatt_client.CharacteristicProxy] = None - sink_audio_locations: Optional[gatt_client.CharacteristicProxy] = None - source_pac: Optional[gatt_client.CharacteristicProxy] = None - source_audio_locations: Optional[gatt_client.CharacteristicProxy] = None - available_audio_contexts: gatt_client.CharacteristicProxy - supported_audio_contexts: gatt_client.CharacteristicProxy - - def __init__(self, service_proxy: gatt_client.ServiceProxy): - self.service_proxy = service_proxy - - self.available_audio_contexts = service_proxy.get_characteristics_by_uuid( - gatt.GATT_AVAILABLE_AUDIO_CONTEXTS_CHARACTERISTIC - )[0] - self.supported_audio_contexts = service_proxy.get_characteristics_by_uuid( - gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC - )[0] - - if characteristics := service_proxy.get_characteristics_by_uuid( - gatt.GATT_SINK_PAC_CHARACTERISTIC - ): - self.sink_pac = characteristics[0] - - if characteristics := service_proxy.get_characteristics_by_uuid( - gatt.GATT_SOURCE_PAC_CHARACTERISTIC - ): - self.source_pac = characteristics[0] - - if characteristics := service_proxy.get_characteristics_by_uuid( - gatt.GATT_SINK_AUDIO_LOCATION_CHARACTERISTIC - ): - self.sink_audio_locations = characteristics[0] - - if characteristics := service_proxy.get_characteristics_by_uuid( - gatt.GATT_SOURCE_AUDIO_LOCATION_CHARACTERISTIC - ): - self.source_audio_locations = characteristics[0] - - -class AudioStreamControlServiceProxy(gatt_client.ProfileServiceProxy): - SERVICE_CLASS = AudioStreamControlService - - sink_ase: List[gatt_client.CharacteristicProxy] - source_ase: List[gatt_client.CharacteristicProxy] - ase_control_point: gatt_client.CharacteristicProxy - - def __init__(self, service_proxy: gatt_client.ServiceProxy): - self.service_proxy = service_proxy - - self.sink_ase = service_proxy.get_characteristics_by_uuid( - gatt.GATT_SINK_ASE_CHARACTERISTIC - ) - self.source_ase = service_proxy.get_characteristics_by_uuid( - gatt.GATT_SOURCE_ASE_CHARACTERISTIC - ) - self.ase_control_point = service_proxy.get_characteristics_by_uuid( - gatt.GATT_ASE_CONTROL_POINT_CHARACTERISTIC - )[0] diff --git a/bumble/profiles/bass.py b/bumble/profiles/bass.py new file mode 100644 index 00000000..57531dbd --- /dev/null +++ b/bumble/profiles/bass.py @@ -0,0 +1,440 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for + +"""LE Audio - Broadcast Audio Scan Service""" + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import dataclasses +import logging +import struct +from typing import ClassVar, List, Optional, Sequence + +from bumble import core +from bumble import device +from bumble import gatt +from bumble import gatt_client +from bumble import hci +from bumble import utils + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +class ApplicationError(utils.OpenIntEnum): + OPCODE_NOT_SUPPORTED = 0x80 + INVALID_SOURCE_ID = 0x81 + + +# ----------------------------------------------------------------------------- +def encode_subgroups(subgroups: Sequence[SubgroupInfo]) -> bytes: + return bytes([len(subgroups)]) + b"".join( + struct.pack(" List[SubgroupInfo]: + num_subgroups = data[0] + offset = 1 + subgroups = [] + for _ in range(num_subgroups): + bis_sync = struct.unpack(" ControlPointOperation: + op_code = data[0] + + if op_code == cls.OpCode.REMOTE_SCAN_STOPPED: + return RemoteScanStoppedOperation() + + if op_code == cls.OpCode.REMOTE_SCAN_STARTED: + return RemoteScanStartedOperation() + + if op_code == cls.OpCode.ADD_SOURCE: + return AddSourceOperation.from_parameters(data[1:]) + + if op_code == cls.OpCode.MODIFY_SOURCE: + return ModifySourceOperation.from_parameters(data[1:]) + + if op_code == cls.OpCode.SET_BROADCAST_CODE: + return SetBroadcastCodeOperation.from_parameters(data[1:]) + + if op_code == cls.OpCode.REMOVE_SOURCE: + return RemoveSourceOperation.from_parameters(data[1:]) + + raise core.InvalidArgumentError("invalid op code") + + def __init__(self, op_code: OpCode, parameters: bytes = b"") -> None: + self.op_code = op_code + self.parameters = parameters + + def __bytes__(self) -> bytes: + return bytes([self.op_code]) + self.parameters + + +class RemoteScanStoppedOperation(ControlPointOperation): + def __init__(self) -> None: + super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STOPPED) + + +class RemoteScanStartedOperation(ControlPointOperation): + def __init__(self) -> None: + super().__init__(ControlPointOperation.OpCode.REMOTE_SCAN_STARTED) + + +class AddSourceOperation(ControlPointOperation): + @classmethod + def from_parameters(cls, parameters: bytes) -> AddSourceOperation: + instance = cls.__new__(cls) + instance.op_code = ControlPointOperation.OpCode.ADD_SOURCE + instance.parameters = parameters + instance.advertiser_address = hci.Address.parse_address_preceded_by_type( + parameters, 1 + )[1] + instance.advertising_sid = parameters[7] + instance.broadcast_id = int.from_bytes(parameters[8:11], "little") + instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[11]) + instance.pa_interval = struct.unpack(" None: + super().__init__( + ControlPointOperation.OpCode.ADD_SOURCE, + struct.pack( + " ModifySourceOperation: + instance = cls.__new__(cls) + instance.op_code = ControlPointOperation.OpCode.MODIFY_SOURCE + instance.parameters = parameters + instance.source_id = parameters[0] + instance.pa_sync = PeriodicAdvertisingSyncParams(parameters[1]) + instance.pa_interval = struct.unpack(" None: + super().__init__( + ControlPointOperation.OpCode.MODIFY_SOURCE, + struct.pack(" SetBroadcastCodeOperation: + instance = cls.__new__(cls) + instance.op_code = ControlPointOperation.OpCode.SET_BROADCAST_CODE + instance.parameters = parameters + instance.source_id = parameters[0] + instance.broadcast_code = parameters[1:17] + return instance + + def __init__( + self, + source_id: int, + broadcast_code: bytes, + ) -> None: + super().__init__( + ControlPointOperation.OpCode.SET_BROADCAST_CODE, + bytes([source_id]) + broadcast_code, + ) + self.source_id = source_id + self.broadcast_code = broadcast_code + + if len(self.broadcast_code) != 16: + raise core.InvalidArgumentError("broadcast_code must be 16 bytes") + + +class RemoveSourceOperation(ControlPointOperation): + @classmethod + def from_parameters(cls, parameters: bytes) -> RemoveSourceOperation: + instance = cls.__new__(cls) + instance.op_code = ControlPointOperation.OpCode.REMOVE_SOURCE + instance.parameters = parameters + instance.source_id = parameters[0] + return instance + + def __init__(self, source_id: int) -> None: + super().__init__(ControlPointOperation.OpCode.REMOVE_SOURCE, bytes([source_id])) + self.source_id = source_id + + +@dataclasses.dataclass +class BroadcastReceiveState: + class PeriodicAdvertisingSyncState(utils.OpenIntEnum): + NOT_SYNCHRONIZED_TO_PA = 0x00 + SYNCINFO_REQUEST = 0x01 + SYNCHRONIZED_TO_PA = 0x02 + FAILED_TO_SYNCHRONIZE_TO_PA = 0x03 + NO_PAST = 0x04 + + class BigEncryption(utils.OpenIntEnum): + NOT_ENCRYPTED = 0x00 + BROADCAST_CODE_REQUIRED = 0x01 + DECRYPTING = 0x02 + BAD_CODE = 0x03 + + source_id: int + source_address: hci.Address + source_adv_sid: int + broadcast_id: int + pa_sync_state: PeriodicAdvertisingSyncState + big_encryption: BigEncryption + bad_code: bytes + subgroups: List[SubgroupInfo] + + @classmethod + def from_bytes(cls, data: bytes) -> Optional[BroadcastReceiveState]: + if not data: + return None + + source_id = data[0] + _, source_address = hci.Address.parse_address_preceded_by_type(data, 2) + source_adv_sid = data[8] + broadcast_id = int.from_bytes(data[9:12], "little") + pa_sync_state = cls.PeriodicAdvertisingSyncState(data[12]) + big_encryption = cls.BigEncryption(data[13]) + if big_encryption == cls.BigEncryption.BAD_CODE: + bad_code = data[14:30] + subgroups = decode_subgroups(data[30:]) + else: + bad_code = b"" + subgroups = decode_subgroups(data[14:]) + + return cls( + source_id, + source_address, + source_adv_sid, + broadcast_id, + pa_sync_state, + big_encryption, + bad_code, + subgroups, + ) + + def __bytes__(self) -> bytes: + return ( + struct.pack( + " None: + pass + + +# ----------------------------------------------------------------------------- +class BroadcastAudioScanServiceProxy(gatt_client.ProfileServiceProxy): + SERVICE_CLASS = BroadcastAudioScanService + + broadcast_audio_scan_control_point: gatt_client.CharacteristicProxy + broadcast_receive_states: List[gatt.DelegatedCharacteristicAdapter] + + def __init__(self, service_proxy: gatt_client.ServiceProxy): + self.service_proxy = service_proxy + + if not ( + characteristics := service_proxy.get_characteristics_by_uuid( + gatt.GATT_BROADCAST_AUDIO_SCAN_CONTROL_POINT_CHARACTERISTIC + ) + ): + raise gatt.InvalidServiceError( + "Broadcast Audio Scan Control Point characteristic not found" + ) + self.broadcast_audio_scan_control_point = characteristics[0] + + if not ( + characteristics := service_proxy.get_characteristics_by_uuid( + gatt.GATT_BROADCAST_RECEIVE_STATE_CHARACTERISTIC + ) + ): + raise gatt.InvalidServiceError( + "Broadcast Receive State characteristic not found" + ) + self.broadcast_receive_states = [ + gatt.DelegatedCharacteristicAdapter( + characteristic, decode=BroadcastReceiveState.from_bytes + ) + for characteristic in characteristics + ] + + async def send_control_point_operation( + self, operation: ControlPointOperation + ) -> None: + await self.broadcast_audio_scan_control_point.write_value( + bytes(operation), with_response=True + ) + + async def remote_scan_started(self) -> None: + await self.send_control_point_operation(RemoteScanStartedOperation()) + + async def remote_scan_stopped(self) -> None: + await self.send_control_point_operation(RemoteScanStoppedOperation()) + + async def add_source( + self, + advertiser_address: hci.Address, + advertising_sid: int, + broadcast_id: int, + pa_sync: PeriodicAdvertisingSyncParams, + pa_interval: int, + subgroups: Sequence[SubgroupInfo], + ) -> None: + await self.send_control_point_operation( + AddSourceOperation( + advertiser_address, + advertising_sid, + broadcast_id, + pa_sync, + pa_interval, + subgroups, + ) + ) + + async def modify_source( + self, + source_id: int, + pa_sync: PeriodicAdvertisingSyncParams, + pa_interval: int, + subgroups: Sequence[SubgroupInfo], + ) -> None: + await self.send_control_point_operation( + ModifySourceOperation( + source_id, + pa_sync, + pa_interval, + subgroups, + ) + ) + + async def remove_source(self, source_id: int) -> None: + await self.send_control_point_operation(RemoveSourceOperation(source_id)) diff --git a/bumble/profiles/gap.py b/bumble/profiles/gap.py new file mode 100644 index 00000000..0dd6e512 --- /dev/null +++ b/bumble/profiles/gap.py @@ -0,0 +1,110 @@ +# Copyright 2021-2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic Access Profile""" + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import logging +import struct +from typing import Optional, Tuple, Union + +from bumble.core import Appearance +from bumble.gatt import ( + TemplateService, + Characteristic, + CharacteristicAdapter, + DelegatedCharacteristicAdapter, + UTF8CharacteristicAdapter, + GATT_GENERIC_ACCESS_SERVICE, + GATT_DEVICE_NAME_CHARACTERISTIC, + GATT_APPEARANCE_CHARACTERISTIC, +) +from bumble.gatt_client import ProfileServiceProxy, ServiceProxy + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Classes +# ----------------------------------------------------------------------------- + + +# ----------------------------------------------------------------------------- +class GenericAccessService(TemplateService): + UUID = GATT_GENERIC_ACCESS_SERVICE + + def __init__( + self, device_name: str, appearance: Union[Appearance, Tuple[int, int], int] = 0 + ): + if isinstance(appearance, int): + appearance_int = appearance + elif isinstance(appearance, tuple): + appearance_int = (appearance[0] << 6) | appearance[1] + elif isinstance(appearance, Appearance): + appearance_int = int(appearance) + else: + raise TypeError() + + self.device_name_characteristic = Characteristic( + GATT_DEVICE_NAME_CHARACTERISTIC, + Characteristic.Properties.READ, + Characteristic.READABLE, + device_name.encode('utf-8')[:248], + ) + + self.appearance_characteristic = Characteristic( + GATT_APPEARANCE_CHARACTERISTIC, + Characteristic.Properties.READ, + Characteristic.READABLE, + struct.pack(' PacRecord: + offset, coding_format = hci.CodingFormat.parse_from_bytes(data, 0) + codec_specific_capabilities_size = data[offset] + + offset += 1 + codec_specific_capabilities_bytes = data[ + offset : offset + codec_specific_capabilities_size + ] + offset += codec_specific_capabilities_size + metadata_size = data[offset] + offset += 1 + metadata = le_audio.Metadata.from_bytes(data[offset : offset + metadata_size]) + + codec_specific_capabilities: Union[CodecSpecificCapabilities, bytes] + if coding_format.codec_id == hci.CodecID.VENDOR_SPECIFIC: + codec_specific_capabilities = codec_specific_capabilities_bytes + else: + codec_specific_capabilities = CodecSpecificCapabilities.from_bytes( + codec_specific_capabilities_bytes + ) + + return PacRecord( + coding_format=coding_format, + codec_specific_capabilities=codec_specific_capabilities, + metadata=metadata, + ) + + def __bytes__(self) -> bytes: + capabilities_bytes = bytes(self.codec_specific_capabilities) + metadata_bytes = bytes(self.metadata) + return ( + bytes(self.coding_format) + + bytes([len(capabilities_bytes)]) + + capabilities_bytes + + bytes([len(metadata_bytes)]) + + metadata_bytes + ) + + +# ----------------------------------------------------------------------------- +# Server +# ----------------------------------------------------------------------------- +class PublishedAudioCapabilitiesService(gatt.TemplateService): + UUID = gatt.GATT_PUBLISHED_AUDIO_CAPABILITIES_SERVICE + + sink_pac: Optional[gatt.Characteristic] + sink_audio_locations: Optional[gatt.Characteristic] + source_pac: Optional[gatt.Characteristic] + source_audio_locations: Optional[gatt.Characteristic] + available_audio_contexts: gatt.Characteristic + supported_audio_contexts: gatt.Characteristic + + def __init__( + self, + supported_source_context: ContextType, + supported_sink_context: ContextType, + available_source_context: ContextType, + available_sink_context: ContextType, + sink_pac: Sequence[PacRecord] = (), + sink_audio_locations: Optional[AudioLocation] = None, + source_pac: Sequence[PacRecord] = (), + source_audio_locations: Optional[AudioLocation] = None, + ) -> None: + characteristics = [] + + self.supported_audio_contexts = gatt.Characteristic( + uuid=gatt.GATT_SUPPORTED_AUDIO_CONTEXTS_CHARACTERISTIC, + properties=gatt.Characteristic.Properties.READ, + permissions=gatt.Characteristic.Permissions.READABLE, + value=struct.pack(' None: self.sink = sink def on_transport_lost(self) -> None: - self.terminated.set_result(None) + if not self.terminated.done(): + self.terminated.set_result(None) + if self.sink: if hasattr(self.sink, 'on_transport_lost'): self.sink.on_transport_lost() diff --git a/examples/run_mcp_client.py b/examples/run_mcp_client.py index f0c8162b..83dad5b5 100644 --- a/examples/run_mcp_client.py +++ b/examples/run_mcp_client.py @@ -35,15 +35,13 @@ CodingFormat, OwnAddressType, ) +from bumble.profiles.ascs import AudioStreamControlService from bumble.profiles.bap import ( CodecSpecificCapabilities, ContextType, AudioLocation, SupportedSamplingFrequency, SupportedFrameDuration, - PacRecord, - PublishedAudioCapabilitiesService, - AudioStreamControlService, UnicastServerAdvertisingData, ) from bumble.profiles.mcp import ( @@ -52,7 +50,7 @@ MediaState, MediaControlPointOpcode, ) - +from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService from bumble.transport import open_transport_or_link from typing import Optional diff --git a/examples/run_unicast_server.py b/examples/run_unicast_server.py index 95ae5510..3ff1c965 100644 --- a/examples/run_unicast_server.py +++ b/examples/run_unicast_server.py @@ -34,8 +34,8 @@ CodingFormat, HCI_IsoDataPacket, ) +from bumble.profiles.ascs import AseStateMachine, AudioStreamControlService from bumble.profiles.bap import ( - AseStateMachine, UnicastServerAdvertisingData, CodecSpecificConfiguration, CodecSpecificCapabilities, @@ -43,13 +43,10 @@ AudioLocation, SupportedSamplingFrequency, SupportedFrameDuration, - PacRecord, - PublishedAudioCapabilitiesService, - AudioStreamControlService, ) from bumble.profiles.cap import CommonAudioServiceService from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType - +from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService from bumble.transport import open_transport_or_link diff --git a/examples/run_vcp_renderer.py b/examples/run_vcp_renderer.py index 0cffbae5..ba9c8404 100644 --- a/examples/run_vcp_renderer.py +++ b/examples/run_vcp_renderer.py @@ -30,6 +30,7 @@ CodingFormat, OwnAddressType, ) +from bumble.profiles.ascs import AudioStreamControlService from bumble.profiles.bap import ( UnicastServerAdvertisingData, CodecSpecificCapabilities, @@ -37,10 +38,8 @@ AudioLocation, SupportedSamplingFrequency, SupportedFrameDuration, - PacRecord, - PublishedAudioCapabilitiesService, - AudioStreamControlService, ) +from bumble.profiles.pacs import PacRecord, PublishedAudioCapabilitiesService from bumble.profiles.cap import CommonAudioServiceService from bumble.profiles.csip import CoordinatedSetIdentificationService, SirkType from bumble.profiles.vcp import VolumeControlService diff --git a/tests/bap_test.py b/tests/bap_test.py index e276790c..0b57fcd2 100644 --- a/tests/bap_test.py +++ b/tests/bap_test.py @@ -23,8 +23,9 @@ from bumble import device from bumble.hci import CodecID, CodingFormat -from bumble.profiles.bap import ( - AudioLocation, +from bumble.profiles.ascs import ( + AudioStreamControlService, + AudioStreamControlServiceProxy, AseStateMachine, ASE_Operation, ASE_Config_Codec, @@ -35,6 +36,9 @@ ASE_Receiver_Stop_Ready, ASE_Release, ASE_Update_Metadata, +) +from bumble.profiles.bap import ( + AudioLocation, SupportedFrameDuration, SupportedSamplingFrequency, SamplingFrequency, @@ -42,9 +46,9 @@ CodecSpecificCapabilities, CodecSpecificConfiguration, ContextType, +) +from bumble.profiles.pacs import ( PacRecord, - AudioStreamControlService, - AudioStreamControlServiceProxy, PublishedAudioCapabilitiesService, PublishedAudioCapabilitiesServiceProxy, ) diff --git a/tests/bass_test.py b/tests/bass_test.py new file mode 100644 index 00000000..b893555f --- /dev/null +++ b/tests/bass_test.py @@ -0,0 +1,146 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import os +import logging + +from bumble import hci +from bumble.profiles import bass + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +def basic_operation_check(operation: bass.ControlPointOperation) -> None: + serialized = bytes(operation) + parsed = bass.ControlPointOperation.from_bytes(serialized) + assert bytes(parsed) == serialized + + +# ----------------------------------------------------------------------------- +def test_operations() -> None: + op1 = bass.RemoteScanStoppedOperation() + basic_operation_check(op1) + + op2 = bass.RemoteScanStartedOperation() + basic_operation_check(op2) + + op3 = bass.AddSourceOperation( + hci.Address("AA:BB:CC:DD:EE:FF"), + 34, + 123456, + bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, + 456, + (), + ) + basic_operation_check(op3) + + op4 = bass.AddSourceOperation( + hci.Address("AA:BB:CC:DD:EE:FF"), + 34, + 123456, + bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, + 456, + ( + bass.SubgroupInfo(6677, bytes.fromhex('aabbcc')), + bass.SubgroupInfo(8899, bytes.fromhex('ddeeff')), + ), + ) + basic_operation_check(op4) + + op5 = bass.ModifySourceOperation( + 12, + bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, + 567, + (), + ) + basic_operation_check(op5) + + op6 = bass.ModifySourceOperation( + 12, + bass.PeriodicAdvertisingSyncParams.SYNCHRONIZE_TO_PA_PAST_NOT_AVAILABLE, + 567, + ( + bass.SubgroupInfo(6677, bytes.fromhex('112233')), + bass.SubgroupInfo(8899, bytes.fromhex('4567')), + ), + ) + basic_operation_check(op6) + + op7 = bass.SetBroadcastCodeOperation( + 7, bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf') + ) + basic_operation_check(op7) + + op8 = bass.RemoveSourceOperation(7) + basic_operation_check(op8) + + +# ----------------------------------------------------------------------------- +def basic_broadcast_receive_state_check(brs: bass.BroadcastReceiveState) -> None: + serialized = bytes(brs) + parsed = bass.BroadcastReceiveState.from_bytes(serialized) + assert parsed is not None + assert bytes(parsed) == serialized + + +def test_broadcast_receive_state() -> None: + subgroups = [ + bass.SubgroupInfo(6677, bytes.fromhex('112233')), + bass.SubgroupInfo(8899, bytes.fromhex('4567')), + ] + + brs1 = bass.BroadcastReceiveState( + 12, + hci.Address("AA:BB:CC:DD:EE:FF"), + 123, + 123456, + bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA, + bass.BroadcastReceiveState.BigEncryption.DECRYPTING, + b'', + subgroups, + ) + basic_broadcast_receive_state_check(brs1) + + brs2 = bass.BroadcastReceiveState( + 12, + hci.Address("AA:BB:CC:DD:EE:FF"), + 123, + 123456, + bass.BroadcastReceiveState.PeriodicAdvertisingSyncState.SYNCHRONIZED_TO_PA, + bass.BroadcastReceiveState.BigEncryption.BAD_CODE, + bytes.fromhex('a0a1a2a3a4a5a6a7a8a9aaabacadaeaf'), + subgroups, + ) + basic_broadcast_receive_state_check(brs2) + + +# ----------------------------------------------------------------------------- +async def run(): + test_operations() + test_broadcast_receive_state() + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) + asyncio.run(run()) diff --git a/tests/import_test.py b/tests/import_test.py index e0b6e3ca..95425112 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -27,7 +27,6 @@ def test_import(): core, crypto, device, - gap, hci, hfp, host, @@ -41,6 +40,22 @@ def test_import(): utils, ) + from bumble.profiles import ( + ascs, + bap, + bass, + battery_service, + cap, + csip, + device_information_service, + gap, + heart_rate_service, + le_audio, + pacs, + pbp, + vcp, + ) + assert att assert bridge assert company_ids @@ -48,7 +63,6 @@ def test_import(): assert core assert crypto assert device - assert gap assert hci assert hfp assert host @@ -61,6 +75,20 @@ def test_import(): assert transport assert utils + assert ascs + assert bap + assert bass + assert battery_service + assert cap + assert csip + assert device_information_service + assert gap + assert heart_rate_service + assert le_audio + assert pacs + assert pbp + assert vcp + # ----------------------------------------------------------------------------- def test_app_imports():