diff --git a/.vscode/settings.json b/.vscode/settings.json index b535ada8..777c47b4 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,7 @@ { "cSpell.words": [ "Abortable", + "aiohttp", "altsetting", "ansiblue", "ansicyan", @@ -9,6 +10,7 @@ "ansired", "ansiyellow", "appendleft", + "ascs", "ASHA", "asyncio", "ATRAC", @@ -43,6 +45,7 @@ "keyup", "levelname", "libc", + "liblc", "libusb", "MITM", "MSBC", @@ -78,6 +81,7 @@ "unmuted", "usbmodem", "vhci", + "wasmtime", "websockets", "xcursor", "ycursor" diff --git a/apps/lea_unicast/app.py b/apps/lea_unicast/app.py new file mode 100644 index 00000000..ae3b4422 --- /dev/null +++ b/apps/lea_unicast/app.py @@ -0,0 +1,577 @@ +# Copyright 2021-2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import asyncio +import datetime +import enum +import functools +from importlib import resources +import json +import os +import logging +import pathlib +from typing import Optional, List, cast +import weakref +import struct + +import ctypes +import wasmtime +import wasmtime.loader +import liblc3 # type: ignore +import logging + +import click +import aiohttp.web + +import bumble +from bumble.core import AdvertisingData +from bumble.colors import color +from bumble.device import Device, DeviceConfiguration, AdvertisingParameters +from bumble.transport import open_transport +from bumble.profiles import bap +from bumble.hci import Address, CodecID, CodingFormat, HCI_IsoDataPacket + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +DEFAULT_UI_PORT = 7654 + + +def _sink_pac_record() -> bap.PacRecord: + return bap.PacRecord( + coding_format=CodingFormat(CodecID.LC3), + codec_specific_capabilities=bap.CodecSpecificCapabilities( + supported_sampling_frequencies=( + bap.SupportedSamplingFrequency.FREQ_8000 + | bap.SupportedSamplingFrequency.FREQ_16000 + | bap.SupportedSamplingFrequency.FREQ_24000 + | bap.SupportedSamplingFrequency.FREQ_32000 + | bap.SupportedSamplingFrequency.FREQ_48000 + ), + supported_frame_durations=( + bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED + ), + supported_audio_channel_count=[1, 2], + min_octets_per_codec_frame=26, + max_octets_per_codec_frame=240, + supported_max_codec_frames_per_sdu=2, + ), + ) + + +def _source_pac_record() -> bap.PacRecord: + return bap.PacRecord( + coding_format=CodingFormat(CodecID.LC3), + codec_specific_capabilities=bap.CodecSpecificCapabilities( + supported_sampling_frequencies=( + bap.SupportedSamplingFrequency.FREQ_8000 + | bap.SupportedSamplingFrequency.FREQ_16000 + | bap.SupportedSamplingFrequency.FREQ_24000 + | bap.SupportedSamplingFrequency.FREQ_32000 + | bap.SupportedSamplingFrequency.FREQ_48000 + ), + supported_frame_durations=( + bap.SupportedFrameDuration.DURATION_10000_US_SUPPORTED + ), + supported_audio_channel_count=[1], + min_octets_per_codec_frame=30, + max_octets_per_codec_frame=100, + supported_max_codec_frames_per_sdu=1, + ), + ) + + +# ----------------------------------------------------------------------------- +# WASM - liblc3 +# ----------------------------------------------------------------------------- +store = wasmtime.loader.store +_memory = cast(wasmtime.Memory, liblc3.memory) +STACK_POINTER = _memory.data_len(store) +_memory.grow(store, 1) +# Mapping wasmtime memory to linear address +memory = (ctypes.c_ubyte * _memory.data_len(store)).from_address( + ctypes.addressof(_memory.data_ptr(store).contents) # type: ignore +) + + +class Liblc3PcmFormat(enum.IntEnum): + S16 = 0 + S24 = 1 + S24_3LE = 2 + FLOAT = 3 + + +MAX_DECODER_SIZE = liblc3.lc3_decoder_size(10000, 48000) +MAX_ENCODER_SIZE = liblc3.lc3_encoder_size(10000, 48000) + +DECODER_STACK_POINTER = STACK_POINTER +ENCODER_STACK_POINTER = DECODER_STACK_POINTER + MAX_DECODER_SIZE * 2 +DECODE_BUFFER_STACK_POINTER = ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * 2 +ENCODE_BUFFER_STACK_POINTER = DECODE_BUFFER_STACK_POINTER + 8192 +DEFAULT_PCM_SAMPLE_RATE = 48000 +DEFAULT_PCM_FORMAT = Liblc3PcmFormat.S16 +DEFAULT_PCM_BYTES_PER_SAMPLE = 2 + + +encoders: List[int] = [] +decoders: List[int] = [] + + +def setup_encoders( + sample_rate_hz: int, frame_duration_us: int, num_channels: int +) -> None: + logger.info( + f"setup_encoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels" + ) + encoders[:num_channels] = [ + liblc3.lc3_setup_encoder( + frame_duration_us, + sample_rate_hz, + DEFAULT_PCM_SAMPLE_RATE, # Input sample rate + ENCODER_STACK_POINTER + MAX_ENCODER_SIZE * i, + ) + for i in range(num_channels) + ] + + +def setup_decoders( + sample_rate_hz: int, frame_duration_us: int, num_channels: int +) -> None: + logger.info( + f"setup_decoders {sample_rate_hz}Hz {frame_duration_us}us {num_channels}channels" + ) + decoders[:num_channels] = [ + liblc3.lc3_setup_decoder( + frame_duration_us, + sample_rate_hz, + DEFAULT_PCM_SAMPLE_RATE, # Output sample rate + DECODER_STACK_POINTER + MAX_DECODER_SIZE * i, + ) + for i in range(num_channels) + ] + + +def decode( + frame_duration_us: int, + num_channels: int, + input_bytes: bytes, +) -> bytes: + if not input_bytes: + return b'' + + input_buffer_offset = DECODE_BUFFER_STACK_POINTER + input_buffer_size = len(input_bytes) + input_bytes_per_frame = input_buffer_size // num_channels + + # Copy into wasm + memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore + + output_buffer_offset = input_buffer_offset + input_buffer_size + output_buffer_size = ( + liblc3.lc3_frame_samples(frame_duration_us, DEFAULT_PCM_SAMPLE_RATE) + * DEFAULT_PCM_BYTES_PER_SAMPLE + * num_channels + ) + + for i in range(num_channels): + res = liblc3.lc3_decode( + decoders[i], + input_buffer_offset + input_bytes_per_frame * i, + input_bytes_per_frame, + DEFAULT_PCM_FORMAT, + output_buffer_offset + i * DEFAULT_PCM_BYTES_PER_SAMPLE, + num_channels, # Stride + ) + + if res != 0: + logging.error(f"Parsing failed, res={res}") + + # Extract decoded data from the output buffer + return bytes( + memory[output_buffer_offset : output_buffer_offset + output_buffer_size] + ) + + +def encode( + sdu_length: int, + num_channels: int, + stride: int, + input_bytes: bytes, +) -> bytes: + if not input_bytes: + return b'' + + input_buffer_offset = ENCODE_BUFFER_STACK_POINTER + input_buffer_size = len(input_bytes) + + # Copy into wasm + memory[input_buffer_offset : input_buffer_offset + input_buffer_size] = input_bytes # type: ignore + + output_buffer_offset = input_buffer_offset + input_buffer_size + output_buffer_size = sdu_length + output_frame_size = output_buffer_size // num_channels + + for i in range(num_channels): + res = liblc3.lc3_encode( + encoders[i], + DEFAULT_PCM_FORMAT, + input_buffer_offset + DEFAULT_PCM_BYTES_PER_SAMPLE * i, + stride, + output_frame_size, + output_buffer_offset + output_frame_size * i, + ) + + if res != 0: + logging.error(f"Parsing failed, res={res}") + + # Extract decoded data from the output buffer + return bytes( + memory[output_buffer_offset : output_buffer_offset + output_buffer_size] + ) + + +async def lc3_source_task( + filename: str, + sdu_length: int, + frame_duration_us: int, + device: Device, + cis_handle: int, +) -> None: + with open(filename, 'rb') as f: + header = f.read(44) + assert header[8:12] == b'WAVE' + + pcm_num_channel, pcm_sample_rate, _byte_rate, _block_align, bits_per_sample = ( + struct.unpack(" None: + self.speaker = weakref.ref(speaker) + self.port = port + self.channel_socket = None + + async def start_http(self) -> None: + """Start the UI HTTP server.""" + + app = aiohttp.web.Application() + app.add_routes( + [ + aiohttp.web.get('/', self.get_static), + aiohttp.web.get('/index.html', self.get_static), + aiohttp.web.get('/channel', self.get_channel), + ] + ) + + runner = aiohttp.web.AppRunner(app) + await runner.setup() + site = aiohttp.web.TCPSite(runner, 'localhost', self.port) + print('UI HTTP server at ' + color(f'http://127.0.0.1:{self.port}', 'green')) + await site.start() + + async def get_static(self, request): + path = request.path + if path == '/': + path = '/index.html' + if path.endswith('.html'): + content_type = 'text/html' + elif path.endswith('.js'): + content_type = 'text/javascript' + elif path.endswith('.css'): + content_type = 'text/css' + elif path.endswith('.svg'): + content_type = 'image/svg+xml' + else: + content_type = 'text/plain' + text = ( + resources.files("bumble.apps.lea_unicast") + .joinpath(pathlib.Path(path).relative_to('/')) + .read_text(encoding="utf-8") + ) + return aiohttp.web.Response(text=text, content_type=content_type) + + async def get_channel(self, request): + ws = aiohttp.web.WebSocketResponse() + await ws.prepare(request) + + # Process messages until the socket is closed. + self.channel_socket = ws + async for message in ws: + if message.type == aiohttp.WSMsgType.TEXT: + logger.debug(f'<<< received message: {message.data}') + await self.on_message(message.data) + elif message.type == aiohttp.WSMsgType.ERROR: + logger.debug( + f'channel connection closed with exception {ws.exception()}' + ) + + self.channel_socket = None + logger.debug('--- channel connection closed') + + return ws + + async def on_message(self, message_str: str): + # Parse the message as JSON + message = json.loads(message_str) + + # Dispatch the message + message_type = message['type'] + message_params = message.get('params', {}) + handler = getattr(self, f'on_{message_type}_message') + if handler: + await handler(**message_params) + + async def on_hello_message(self): + await self.send_message( + 'hello', + bumble_version=bumble.__version__, + codec=self.speaker().codec, + streamState=self.speaker().stream_state.name, + ) + if connection := self.speaker().connection: + await self.send_message( + 'connection', + peer_address=connection.peer_address.to_string(False), + peer_name=connection.peer_name, + ) + + async def send_message(self, message_type: str, **kwargs) -> None: + if self.channel_socket is None: + return + + message = {'type': message_type, 'params': kwargs} + await self.channel_socket.send_json(message) + + async def send_audio(self, data: bytes) -> None: + if self.channel_socket is None: + return + + try: + await self.channel_socket.send_bytes(data) + except Exception as error: + logger.warning(f'exception while sending audio packet: {error}') + + +# ----------------------------------------------------------------------------- +class Speaker: + + def __init__( + self, + device_config_path: Optional[str], + ui_port: int, + transport: str, + lc3_input_file_path: str, + ): + self.device_config_path = device_config_path + self.transport = transport + self.lc3_input_file_path = lc3_input_file_path + + # Create an HTTP server for the UI + self.ui_server = UiServer(speaker=self, port=ui_port) + + async def run(self) -> None: + await self.ui_server.start_http() + + async with await open_transport(self.transport) as hci_transport: + # Create a device + if self.device_config_path: + device_config = DeviceConfiguration.from_file(self.device_config_path) + else: + device_config = DeviceConfiguration( + name="Bumble LE Headphone", + class_of_device=0x244418, + keystore="JsonKeyStore", + advertising_interval_min=25, + advertising_interval_max=25, + address=Address('F1:F2:F3:F4:F5:F6'), + ) + + device_config.le_enabled = True + device_config.cis_enabled = True + self.device = Device.from_config_with_hci( + device_config, hci_transport.source, hci_transport.sink + ) + + self.device.add_service( + bap.PublishedAudioCapabilitiesService( + supported_source_context=bap.ContextType(0xFFFF), + available_source_context=bap.ContextType(0xFFFF), + supported_sink_context=bap.ContextType(0xFFFF), # All context types + available_sink_context=bap.ContextType(0xFFFF), # All context types + sink_audio_locations=( + bap.AudioLocation.FRONT_LEFT | bap.AudioLocation.FRONT_RIGHT + ), + sink_pac=[_sink_pac_record()], + source_audio_locations=bap.AudioLocation.FRONT_LEFT, + source_pac=[_source_pac_record()], + ) + ) + + ascs = bap.AudioStreamControlService( + self.device, sink_ase_id=[1], source_ase_id=[2] + ) + self.device.add_service(ascs) + + advertising_data = bytes( + AdvertisingData( + [ + ( + AdvertisingData.COMPLETE_LOCAL_NAME, + bytes(device_config.name, 'utf-8'), + ), + ( + AdvertisingData.FLAGS, + bytes([AdvertisingData.LE_GENERAL_DISCOVERABLE_MODE_FLAG]), + ), + ( + AdvertisingData.INCOMPLETE_LIST_OF_16_BIT_SERVICE_CLASS_UUIDS, + bytes(bap.PublishedAudioCapabilitiesService.UUID), + ), + ] + ) + ) + bytes(bap.UnicastServerAdvertisingData()) + + def on_pdu(pdu: HCI_IsoDataPacket, ase: bap.AseStateMachine): + codec_config = ase.codec_specific_configuration + assert isinstance(codec_config, bap.CodecSpecificConfiguration) + pcm = decode( + codec_config.frame_duration.us, + codec_config.audio_channel_allocation.channel_count, + pdu.iso_sdu_fragment, + ) + self.device.abort_on('disconnection', self.ui_server.send_audio(pcm)) + + def on_ase_state_change(ase: bap.AseStateMachine) -> None: + if ase.state == bap.AseStateMachine.State.STREAMING: + codec_config = ase.codec_specific_configuration + assert isinstance(codec_config, bap.CodecSpecificConfiguration) + assert ase.cis_link + if ase.role == bap.AudioRole.SOURCE: + ase.cis_link.abort_on( + 'disconnection', + lc3_source_task( + filename=self.lc3_input_file_path, + sdu_length=( + codec_config.codec_frames_per_sdu + * codec_config.octets_per_codec_frame + ), + frame_duration_us=codec_config.frame_duration.us, + device=self.device, + cis_handle=ase.cis_link.handle, + ), + ) + else: + ase.cis_link.sink = functools.partial(on_pdu, ase=ase) + elif ase.state == bap.AseStateMachine.State.CODEC_CONFIGURED: + codec_config = ase.codec_specific_configuration + assert isinstance(codec_config, bap.CodecSpecificConfiguration) + if ase.role == bap.AudioRole.SOURCE: + setup_encoders( + codec_config.sampling_frequency.hz, + codec_config.frame_duration.us, + codec_config.audio_channel_allocation.channel_count, + ) + else: + setup_decoders( + codec_config.sampling_frequency.hz, + codec_config.frame_duration.us, + codec_config.audio_channel_allocation.channel_count, + ) + + for ase in ascs.ase_state_machines.values(): + ase.on('state_change', functools.partial(on_ase_state_change, ase=ase)) + + await self.device.power_on() + await self.device.create_advertising_set( + advertising_data=advertising_data, + auto_restart=True, + advertising_parameters=AdvertisingParameters( + primary_advertising_interval_min=100, + primary_advertising_interval_max=100, + ), + ) + + await hci_transport.source.terminated + + +@click.command() +@click.option( + '--ui-port', + 'ui_port', + metavar='HTTP_PORT', + default=DEFAULT_UI_PORT, + show_default=True, + help='HTTP port for the UI server', +) +@click.option('--device-config', metavar='FILENAME', help='Device configuration file') +@click.argument('transport') +@click.argument('lc3_file') +def speaker(ui_port: int, device_config: str, transport: str, lc3_file: str) -> None: + """Run the speaker.""" + + asyncio.run(Speaker(device_config, ui_port, transport, lc3_file).run()) + + +# ----------------------------------------------------------------------------- +def main(): + logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'WARNING').upper()) + speaker() + + +# ----------------------------------------------------------------------------- +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/apps/lea_unicast/index.html b/apps/lea_unicast/index.html new file mode 100644 index 00000000..fb1e61c9 --- /dev/null +++ b/apps/lea_unicast/index.html @@ -0,0 +1,68 @@ + + + + + + + + + +
+ +
+ + +
+ + + + + + + \ No newline at end of file diff --git a/apps/lea_unicast/liblc3.wasm b/apps/lea_unicast/liblc3.wasm new file mode 100755 index 00000000..e9051058 Binary files /dev/null and b/apps/lea_unicast/liblc3.wasm differ diff --git a/bumble/device.py b/bumble/device.py index 6bc945a3..8d6b8831 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -23,7 +23,13 @@ import asyncio import logging import secrets -from contextlib import asynccontextmanager, AsyncExitStack, closing +import sys +from contextlib import ( + asynccontextmanager, + AsyncExitStack, + closing, + AbstractAsyncContextManager, +) from dataclasses import dataclass, field from collections.abc import Iterable from typing import ( @@ -961,8 +967,9 @@ class ScoLink(CompositeEventEmitter): acl_connection: Connection handle: int link_type: int + sink: Optional[Callable[[HCI_SynchronousDataPacket], Any]] = None - def __post_init__(self): + def __post_init__(self) -> None: super().__init__() async def disconnect( @@ -984,8 +991,9 @@ class State(IntEnum): cis_id: int # CIS ID assigned by Central device cig_id: int # CIG ID assigned by Central device state: State = State.PENDING + sink: Optional[Callable[[HCI_IsoDataPacket], Any]] = None - def __post_init__(self): + def __post_init__(self) -> None: super().__init__() async def disconnect( @@ -1533,6 +1541,12 @@ def __init__( Address.ANY: [] } # Futures, by BD address OR [Futures] for Address.ANY + # In Python <= 3.9 + Rust Runtime, asyncio.Lock cannot be properly initiated. + if sys.version_info >= (3, 10): + self._cis_lock = asyncio.Lock() + else: + self._cis_lock = AsyncExitStack() + # Own address type cache self.connect_own_address_type = None @@ -3406,49 +3420,71 @@ async def create_cis(self, cis_acl_pairs: List[Tuple[int, int]]) -> List[CisLink for cis_handle, _ in cis_acl_pairs } - @watcher.on(self, 'cis_establishment') def on_cis_establishment(cis_link: CisLink) -> None: if pending_future := pending_cis_establishments.get(cis_link.handle): pending_future.set_result(cis_link) - result = await self.send_command( + def on_cis_establishment_failure(cis_handle: int, status: int) -> None: + if pending_future := pending_cis_establishments.get(cis_handle): + pending_future.set_exception(HCI_Error(status)) + + watcher.on(self, 'cis_establishment', on_cis_establishment) + watcher.on(self, 'cis_establishment_failure', on_cis_establishment_failure) + await self.send_command( HCI_LE_Create_CIS_Command( cis_connection_handle=[p[0] for p in cis_acl_pairs], acl_connection_handle=[p[1] for p in cis_acl_pairs], ), + check_result=True, ) - if result.status != HCI_COMMAND_STATUS_PENDING: - logger.warning( - 'HCI_LE_Create_CIS_Command failed: ' - f'{HCI_Constant.error_name(result.status)}' - ) - raise HCI_StatusError(result) return await asyncio.gather(*pending_cis_establishments.values()) # [LE only] @experimental('Only for testing.') async def accept_cis_request(self, handle: int) -> CisLink: - result = await self.send_command( - HCI_LE_Accept_CIS_Request_Command(connection_handle=handle), - ) - if result.status != HCI_COMMAND_STATUS_PENDING: - logger.warning( - 'HCI_LE_Accept_CIS_Request_Command failed: ' - f'{HCI_Constant.error_name(result.status)}' - ) - raise HCI_StatusError(result) + """[LE Only] Accepts an incoming CIS request. - pending_cis_establishment = asyncio.get_running_loop().create_future() + When the specified CIS handle is already created, this method returns the + existed CIS link object immediately. - with closing(EventWatcher()) as watcher: + Args: + handle: CIS handle to accept. - @watcher.on(self, 'cis_establishment') - def on_cis_establishment(cis_link: CisLink) -> None: - if cis_link.handle == handle: - pending_cis_establishment.set_result(cis_link) + Returns: + CIS link object on the given handle. + """ + if not (cis_link := self.cis_links.get(handle)): + raise InvalidStateError(f'No pending CIS request of handle {handle}') + + # There might be multiple ASE sharing a CIS channel. + # If one of them has accepted the request, the others should just leverage it. + async with self._cis_lock: + if cis_link.state == CisLink.State.ESTABLISHED: + return cis_link + + with closing(EventWatcher()) as watcher: + pending_establishment = asyncio.get_running_loop().create_future() + + def on_establishment() -> None: + pending_establishment.set_result(None) + + def on_establishment_failure(status: int) -> None: + pending_establishment.set_exception(HCI_Error(status)) + + watcher.on(cis_link, 'establishment', on_establishment) + watcher.on(cis_link, 'establishment_failure', on_establishment_failure) + + await self.send_command( + HCI_LE_Accept_CIS_Request_Command(connection_handle=handle), + check_result=True, + ) + + await pending_establishment + return cis_link - return await pending_cis_establishment + # Mypy believes this is reachable when context is an ExitStack. + raise InvalidStateError('Unreachable') # [LE only] @experimental('Only for testing.') @@ -3457,15 +3493,10 @@ async def reject_cis_request( handle: int, reason: int = HCI_REMOTE_USER_TERMINATED_CONNECTION_ERROR, ) -> None: - result = await self.send_command( + await self.send_command( HCI_LE_Reject_CIS_Request_Command(connection_handle=handle, reason=reason), + check_result=True, ) - if result.status != HCI_COMMAND_STATUS_PENDING: - logger.warning( - 'HCI_LE_Reject_CIS_Request_Command failed: ' - f'{HCI_Constant.error_name(result.status)}' - ) - raise HCI_StatusError(result) async def get_remote_le_features(self, connection: Connection) -> LeFeatureMask: """[LE Only] Reads remote LE supported features. @@ -3485,11 +3516,17 @@ def on_le_remote_features(handle: int, features: int): if handle == connection.handle: read_feature_future.set_result(LeFeatureMask(features)) + def on_failure(handle: int, status: int): + if handle == connection.handle: + read_feature_future.set_exception(HCI_Error(status)) + watcher.on(self.host, 'le_remote_features', on_le_remote_features) + watcher.on(self.host, 'le_remote_features_failure', on_failure) await self.send_command( HCI_LE_Read_Remote_Features_Command( connection_handle=connection.handle ), + check_result=True, ) return await read_feature_future @@ -4111,8 +4148,8 @@ def on_sco_connection_failure( @host_event_handler @experimental('Only for testing') def on_sco_packet(self, sco_handle: int, packet: HCI_SynchronousDataPacket) -> None: - if sco_link := self.sco_links.get(sco_handle): - sco_link.emit('pdu', packet) + if (sco_link := self.sco_links.get(sco_handle)) and sco_link.sink: + sco_link.sink(packet) # [LE only] @host_event_handler @@ -4168,15 +4205,15 @@ def on_cis_establishment(self, cis_handle: int) -> None: def on_cis_establishment_failure(self, cis_handle: int, status: int) -> None: logger.debug(f'*** CIS Establishment Failure: cis=[0x{cis_handle:04X}] ***') if cis_link := self.cis_links.pop(cis_handle): - cis_link.emit('establishment_failure') + cis_link.emit('establishment_failure', status) self.emit('cis_establishment_failure', cis_handle, status) # [LE only] @host_event_handler @experimental('Only for testing') def on_iso_packet(self, handle: int, packet: HCI_IsoDataPacket) -> None: - if cis_link := self.cis_links.get(handle): - cis_link.emit('pdu', packet) + if (cis_link := self.cis_links.get(handle)) and cis_link.sink: + cis_link.sink(packet) @host_event_handler @with_connection_from_handle diff --git a/bumble/hci.py b/bumble/hci.py index fba89515..9ef40bf2 100644 --- a/bumble/hci.py +++ b/bumble/hci.py @@ -23,7 +23,7 @@ import logging import secrets import struct -from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union, ClassVar from bumble import crypto from .colors import color @@ -2003,7 +2003,7 @@ class HCI_Packet: Abstract Base class for HCI packets ''' - hci_packet_type: int + hci_packet_type: ClassVar[int] @staticmethod def from_bytes(packet: bytes) -> HCI_Packet: @@ -6192,12 +6192,23 @@ def __str__(self) -> str: # ----------------------------------------------------------------------------- +@dataclasses.dataclass class HCI_IsoDataPacket(HCI_Packet): ''' See Bluetooth spec @ 5.4.5 HCI ISO Data Packets ''' - hci_packet_type = HCI_ISO_DATA_PACKET + hci_packet_type: ClassVar[int] = HCI_ISO_DATA_PACKET + + connection_handle: int + data_total_length: int + iso_sdu_fragment: bytes + pb_flag: int + ts_flag: int = 0 + time_stamp: Optional[int] = None + packet_sequence_number: Optional[int] = None + iso_sdu_length: Optional[int] = None + packet_status_flag: Optional[int] = None @staticmethod def from_bytes(packet: bytes) -> HCI_IsoDataPacket: @@ -6241,28 +6252,6 @@ def from_bytes(packet: bytes) -> HCI_IsoDataPacket: iso_sdu_fragment=iso_sdu_fragment, ) - def __init__( - self, - connection_handle: int, - pb_flag: int, - ts_flag: int, - data_total_length: int, - time_stamp: Optional[int], - packet_sequence_number: Optional[int], - iso_sdu_length: Optional[int], - packet_status_flag: Optional[int], - iso_sdu_fragment: bytes, - ) -> None: - self.connection_handle = connection_handle - self.pb_flag = pb_flag - self.ts_flag = ts_flag - self.data_total_length = data_total_length - self.time_stamp = time_stamp - self.packet_sequence_number = packet_sequence_number - self.iso_sdu_length = iso_sdu_length - self.packet_status_flag = packet_status_flag - self.iso_sdu_fragment = iso_sdu_fragment - def __bytes__(self) -> bytes: return self.to_bytes() diff --git a/bumble/host.py b/bumble/host.py index 4223d0c4..64b66688 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -721,14 +721,16 @@ def on_hci_number_of_completed_packets_event(self, event): for connection_handle, num_completed_packets in zip( event.connection_handles, event.num_completed_packets ): - if not (connection := self.connections.get(connection_handle)): + if connection := self.connections.get(connection_handle): + connection.acl_packet_queue.on_packets_completed(num_completed_packets) + elif not ( + self.cis_links.get(connection_handle) + or self.sco_links.get(connection_handle) + ): logger.warning( 'received packet completion event for unknown handle ' f'0x{connection_handle:04X}' ) - continue - - connection.acl_packet_queue.on_packets_completed(num_completed_packets) # Classic only def on_hci_connection_request_event(self, event): diff --git a/bumble/profiles/bap.py b/bumble/profiles/bap.py index b54ad1dc..c0123b11 100644 --- a/bumble/profiles/bap.py +++ b/bumble/profiles/bap.py @@ -78,6 +78,10 @@ class AudioLocation(enum.IntFlag): LEFT_SURROUND = 0x04000000 RIGHT_SURROUND = 0x08000000 + @property + def channel_count(self) -> int: + return bin(self.value).count('1') + class AudioInputType(enum.IntEnum): '''Bluetooth Assigned Numbers, Section 6.12.2 - Audio Input Type''' @@ -218,6 +222,13 @@ class FrameDuration(enum.IntEnum): DURATION_7500_US = 0x00 DURATION_10000_US = 0x01 + @property + def us(self) -> int: + return { + FrameDuration.DURATION_7500_US: 7500, + FrameDuration.DURATION_10000_US: 10000, + }[self] + class SupportedFrameDuration(enum.IntFlag): '''Bluetooth Assigned Numbers, Section 6.12.4.2 - Frame Duration''' @@ -534,7 +545,7 @@ class Type(enum.IntEnum): supported_sampling_frequencies: SupportedSamplingFrequency supported_frame_durations: SupportedFrameDuration - supported_audio_channel_counts: Sequence[int] + supported_audio_channel_count: Sequence[int] min_octets_per_codec_frame: int max_octets_per_codec_frame: int supported_max_codec_frames_per_sdu: int @@ -543,7 +554,7 @@ class Type(enum.IntEnum): def from_bytes(cls, data: bytes) -> CodecSpecificCapabilities: offset = 0 # Allowed default values. - supported_audio_channel_counts = [1] + supported_audio_channel_count = [1] supported_max_codec_frames_per_sdu = 1 while offset < len(data): length, type = struct.unpack_from('BB', data, offset) @@ -556,7 +567,7 @@ def from_bytes(cls, data: bytes) -> CodecSpecificCapabilities: elif type == CodecSpecificCapabilities.Type.FRAME_DURATION: supported_frame_durations = SupportedFrameDuration(value) elif type == CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT: - supported_audio_channel_counts = bits_to_channel_counts(value) + supported_audio_channel_count = bits_to_channel_counts(value) elif type == CodecSpecificCapabilities.Type.OCTETS_PER_FRAME: min_octets_per_sample = value & 0xFFFF max_octets_per_sample = value >> 16 @@ -567,7 +578,7 @@ def from_bytes(cls, data: bytes) -> CodecSpecificCapabilities: return CodecSpecificCapabilities( supported_sampling_frequencies=supported_sampling_frequencies, supported_frame_durations=supported_frame_durations, - supported_audio_channel_counts=supported_audio_channel_counts, + supported_audio_channel_count=supported_audio_channel_count, min_octets_per_codec_frame=min_octets_per_sample, max_octets_per_codec_frame=max_octets_per_sample, supported_max_codec_frames_per_sdu=supported_max_codec_frames_per_sdu, @@ -584,7 +595,7 @@ def __bytes__(self) -> bytes: self.supported_frame_durations, 2, CodecSpecificCapabilities.Type.AUDIO_CHANNEL_COUNT, - channel_counts_to_bits(self.supported_audio_channel_counts), + channel_counts_to_bits(self.supported_audio_channel_count), 5, CodecSpecificCapabilities.Type.OCTETS_PER_FRAME, self.min_octets_per_codec_frame, @@ -870,15 +881,22 @@ def on_cis_request( cig_id: int, cis_id: int, ) -> None: - if cis_id == self.cis_id and self.state == self.State.ENABLING: + if ( + cig_id == self.cig_id + and cis_id == self.cis_id + and self.state == self.State.ENABLING + ): acl_connection.abort_on( 'flush', self.service.device.accept_cis_request(cis_handle) ) def on_cis_establishment(self, cis_link: device.CisLink) -> None: - if cis_link.cis_id == self.cis_id and self.state == self.State.ENABLING: - self.state = self.State.STREAMING - self.cis_link = cis_link + if ( + cis_link.cig_id == self.cig_id + and cis_link.cis_id == self.cis_id + and self.state == self.State.ENABLING + ): + cis_link.on('disconnection', self.on_cis_disconnection) async def post_cis_established(): await self.service.device.send_command( @@ -891,9 +909,15 @@ async def post_cis_established(): codec_configuration=b'', ) ) + if self.role == AudioRole.SINK: + self.state = self.State.STREAMING await self.service.device.notify_subscribers(self, self.value) cis_link.acl_connection.abort_on('flush', post_cis_established()) + self.cis_link = cis_link + + def on_cis_disconnection(self, _reason) -> None: + self.cis_link = None def on_config_codec( self, @@ -991,11 +1015,17 @@ def on_disable(self) -> Tuple[AseResponseCode, AseReasonCode]: AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseReasonCode.NONE, ) - self.state = self.State.DISABLING + if self.role == AudioRole.SINK: + self.state = self.State.QOS_CONFIGURED + else: + self.state = self.State.DISABLING return (AseResponseCode.SUCCESS, AseReasonCode.NONE) def on_receiver_stop_ready(self) -> Tuple[AseResponseCode, AseReasonCode]: - if self.state != AseStateMachine.State.DISABLING: + if ( + self.role != AudioRole.SOURCE + or self.state != AseStateMachine.State.DISABLING + ): return ( AseResponseCode.INVALID_ASE_STATE_MACHINE_TRANSITION, AseReasonCode.NONE, @@ -1046,6 +1076,7 @@ def state(self) -> State: def state(self, new_state: State) -> None: logger.debug(f'{self} state change -> {colors.color(new_state.name, "cyan")}') self._state = new_state + self.emit('state_change') @property def value(self): @@ -1118,6 +1149,7 @@ class AudioStreamControlService(gatt.TemplateService): ase_state_machines: Dict[int, AseStateMachine] ase_control_point: gatt.Characteristic + _active_client: Optional[device.Connection] = None def __init__( self, @@ -1155,7 +1187,16 @@ def on_operation(self, opcode: ASE_Operation.Opcode, ase_id: int, args): else: return (ase_id, AseResponseCode.INVALID_ASE_ID, AseReasonCode.NONE) + def _on_client_disconnected(self, _reason: int) -> None: + for ase in self.ase_state_machines.values(): + ase.state = AseStateMachine.State.IDLE + self._active_client = None + def on_write_ase_control_point(self, connection, data): + if not self._active_client and connection: + self._active_client = connection + connection.once('disconnection', self._on_client_disconnected) + operation = ASE_Operation.from_bytes(data) responses = [] logger.debug(f'*** ASCS Write {operation} ***') diff --git a/examples/run_hfp_gateway.py b/examples/run_hfp_gateway.py index 8e596f80..851f97cf 100644 --- a/examples/run_hfp_gateway.py +++ b/examples/run_hfp_gateway.py @@ -26,7 +26,7 @@ from typing import Optional import bumble.core -from bumble.device import Device +from bumble.device import Device, ScoLink from bumble.transport import open_transport_or_link from bumble.core import ( BT_BR_EDR_TRANSPORT, @@ -217,11 +217,11 @@ def on_dlc(dlc: rfcomm.DLC): 1: hfp.make_ag_sdp_records(1, channel, configuration) } - def on_sco_connection(sco_link): + def on_sco_connection(sco_link: ScoLink): assert ag_protocol on_sco_state_change(ag_protocol.active_codec) sco_link.on('disconnection', lambda _: on_sco_state_change(0)) - sco_link.on('pdu', on_sco_packet) + sco_link.sink = on_sco_packet device.on('sco_connection', on_sco_connection) if len(sys.argv) >= 4: diff --git a/examples/run_unicast_server.py b/examples/run_unicast_server.py index 60d2f4ae..95ae5510 100644 --- a/examples/run_unicast_server.py +++ b/examples/run_unicast_server.py @@ -16,20 +16,28 @@ # Imports # ----------------------------------------------------------------------------- import asyncio +import datetime +import functools import logging import sys import os +import io import struct import secrets + +from typing import Dict + from bumble.core import AdvertisingData -from bumble.device import Device, CisLink +from bumble.device import Device from bumble.hci import ( CodecID, CodingFormat, HCI_IsoDataPacket, ) from bumble.profiles.bap import ( + AseStateMachine, UnicastServerAdvertisingData, + CodecSpecificConfiguration, CodecSpecificCapabilities, ContextType, AudioLocation, @@ -45,6 +53,32 @@ from bumble.transport import open_transport_or_link +def _sink_pac_record() -> PacRecord: + return PacRecord( + coding_format=CodingFormat(CodecID.LC3), + codec_specific_capabilities=CodecSpecificCapabilities( + supported_sampling_frequencies=( + SupportedSamplingFrequency.FREQ_8000 + | SupportedSamplingFrequency.FREQ_16000 + | SupportedSamplingFrequency.FREQ_24000 + | SupportedSamplingFrequency.FREQ_32000 + | SupportedSamplingFrequency.FREQ_48000 + ), + supported_frame_durations=( + SupportedFrameDuration.DURATION_7500_US_SUPPORTED + | SupportedFrameDuration.DURATION_10000_US_SUPPORTED + ), + supported_audio_channel_count=[1, 2], + min_octets_per_codec_frame=26, + max_octets_per_codec_frame=240, + supported_max_codec_frames_per_sdu=2, + ), + ) + + +file_outputs: Dict[AseStateMachine, io.BufferedWriter] = {} + + # ----------------------------------------------------------------------------- async def main() -> None: if len(sys.argv) < 3: @@ -71,49 +105,17 @@ async def main() -> None: PublishedAudioCapabilitiesService( supported_source_context=ContextType.PROHIBITED, available_source_context=ContextType.PROHIBITED, - supported_sink_context=ContextType.MEDIA, - available_sink_context=ContextType.MEDIA, + supported_sink_context=ContextType(0xFF), # All context types + available_sink_context=ContextType(0xFF), # All context types sink_audio_locations=( AudioLocation.FRONT_LEFT | AudioLocation.FRONT_RIGHT ), - sink_pac=[ - # Codec Capability Setting 16_2 - PacRecord( - coding_format=CodingFormat(CodecID.LC3), - codec_specific_capabilities=CodecSpecificCapabilities( - supported_sampling_frequencies=( - SupportedSamplingFrequency.FREQ_16000 - ), - supported_frame_durations=( - SupportedFrameDuration.DURATION_10000_US_SUPPORTED - ), - supported_audio_channel_counts=[1], - min_octets_per_codec_frame=40, - max_octets_per_codec_frame=40, - supported_max_codec_frames_per_sdu=1, - ), - ), - # Codec Capability Setting 24_2 - PacRecord( - coding_format=CodingFormat(CodecID.LC3), - codec_specific_capabilities=CodecSpecificCapabilities( - supported_sampling_frequencies=( - SupportedSamplingFrequency.FREQ_48000 - ), - supported_frame_durations=( - SupportedFrameDuration.DURATION_10000_US_SUPPORTED - ), - supported_audio_channel_counts=[1], - min_octets_per_codec_frame=120, - max_octets_per_codec_frame=120, - supported_max_codec_frames_per_sdu=1, - ), - ), - ], + sink_pac=[_sink_pac_record()], ) ) - device.add_service(AudioStreamControlService(device, sink_ase_id=[1, 2])) + ascs = AudioStreamControlService(device, sink_ase_id=[1], source_ase_id=[2]) + device.add_service(ascs) advertising_data = ( bytes( @@ -143,44 +145,57 @@ async def main() -> None: + csis.get_advertising_data() + bytes(UnicastServerAdvertisingData()) ) - subprocess = await asyncio.create_subprocess_shell( - f'dlc3 | ffplay pipe:0', - stdin=asyncio.subprocess.PIPE, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - ) - - stdin = subprocess.stdin - assert stdin - - # Write a fake LC3 header to dlc3. - stdin.write( - bytes([0x1C, 0xCC]) # Header. - + struct.pack( - ' None: + if state != AseStateMachine.State.STREAMING: + if file_output := file_outputs.pop(ase): + file_output.close() + else: + file_output = open(f'{datetime.datetime.now().isoformat()}.lc3', 'wb') + codec_configuration = ase.codec_specific_configuration + assert isinstance(codec_configuration, CodecSpecificConfiguration) + # Write a LC3 header. + file_output.write( + bytes([0x1C, 0xCC]) # Header. + + struct.pack( + ' None: supported_frame_durations=( SupportedFrameDuration.DURATION_10000_US_SUPPORTED ), - supported_audio_channel_counts=[1], + supported_audio_channel_count=[1], min_octets_per_codec_frame=120, max_octets_per_codec_frame=120, supported_max_codec_frames_per_sdu=1, diff --git a/setup.cfg b/setup.cfg index 129ae51f..4dd551e4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -96,6 +96,7 @@ development = types-appdirs >= 1.4.3 types-invoke >= 1.7.3 types-protobuf >= 4.21.0 + wasmtime == 20.0.0 avatar = pandora-avatar == 0.0.9 rootcanal == 1.10.0 ; python_version>='3.10' diff --git a/tests/bap_test.py b/tests/bap_test.py index bc223c14..0b6db1a7 100644 --- a/tests/bap_test.py +++ b/tests/bap_test.py @@ -72,7 +72,7 @@ def test_codec_specific_capabilities() -> None: cap = CodecSpecificCapabilities( supported_sampling_frequencies=SAMPLE_FREQUENCY, supported_frame_durations=FRAME_SURATION, - supported_audio_channel_counts=AUDIO_CHANNEL_COUNTS, + supported_audio_channel_count=AUDIO_CHANNEL_COUNTS, min_octets_per_codec_frame=40, max_octets_per_codec_frame=40, supported_max_codec_frames_per_sdu=1, @@ -88,7 +88,7 @@ def test_pac_record() -> None: cap = CodecSpecificCapabilities( supported_sampling_frequencies=SAMPLE_FREQUENCY, supported_frame_durations=FRAME_SURATION, - supported_audio_channel_counts=AUDIO_CHANNEL_COUNTS, + supported_audio_channel_count=AUDIO_CHANNEL_COUNTS, min_octets_per_codec_frame=40, max_octets_per_codec_frame=40, supported_max_codec_frames_per_sdu=1, @@ -216,7 +216,7 @@ async def test_pacs(): supported_frame_durations=( SupportedFrameDuration.DURATION_10000_US_SUPPORTED ), - supported_audio_channel_counts=[1], + supported_audio_channel_count=[1], min_octets_per_codec_frame=40, max_octets_per_codec_frame=40, supported_max_codec_frames_per_sdu=1, @@ -232,7 +232,7 @@ async def test_pacs(): supported_frame_durations=( SupportedFrameDuration.DURATION_10000_US_SUPPORTED ), - supported_audio_channel_counts=[1], + supported_audio_channel_count=[1], min_octets_per_codec_frame=60, max_octets_per_codec_frame=60, supported_max_codec_frames_per_sdu=1, diff --git a/tests/device_test.py b/tests/device_test.py index 5d872826..814fed6f 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -16,6 +16,7 @@ # Imports # ----------------------------------------------------------------------------- import asyncio +import functools import logging import os from types import LambdaType @@ -35,12 +36,14 @@ HCI_COMMAND_STATUS_PENDING, HCI_CREATE_CONNECTION_COMMAND, HCI_SUCCESS, + HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR, Address, OwnAddressType, HCI_Command_Complete_Event, HCI_Command_Status_Event, HCI_Connection_Complete_Event, HCI_Connection_Request_Event, + HCI_Error, HCI_Packet, ) from bumble.gatt import ( @@ -52,6 +55,10 @@ from .test_utils import TwoDevices, async_barrier +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +_TIMEOUT = 0.1 # ----------------------------------------------------------------------------- # Logging @@ -385,6 +392,29 @@ async def test_get_remote_le_features(): assert (await devices.connections[0].get_remote_le_features()) is not None +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_get_remote_le_features_failed(): + devices = TwoDevices() + await devices.setup_connection() + + def on_hci_le_read_remote_features_complete_event(event): + devices[0].host.emit( + 'le_remote_features_failure', + event.connection_handle, + HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR, + ) + + devices[0].host.on_hci_le_read_remote_features_complete_event = ( + on_hci_le_read_remote_features_complete_event + ) + + with pytest.raises(HCI_Error): + await asyncio.wait_for( + devices.connections[0].get_remote_le_features(), _TIMEOUT + ) + + # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_cis(): @@ -433,6 +463,65 @@ def on_cis_request( await cis_links[1].disconnect() +# ----------------------------------------------------------------------------- +@pytest.mark.asyncio +async def test_cis_setup_failure(): + devices = TwoDevices() + await devices.setup_connection() + + cis_requests = asyncio.Queue() + + def on_cis_request( + acl_connection: Connection, + cis_handle: int, + cig_id: int, + cis_id: int, + ): + del acl_connection, cig_id, cis_id + cis_requests.put_nowait(cis_handle) + + devices[1].on('cis_request', on_cis_request) + + cis_handles = await devices[0].setup_cig( + cig_id=1, + cis_id=[2], + sdu_interval=(0, 0), + framing=0, + max_sdu=(0, 0), + retransmission_number=0, + max_transport_latency=(0, 0), + ) + assert len(cis_handles) == 1 + + cis_create_task = asyncio.create_task( + devices[0].create_cis( + [ + (cis_handles[0], devices.connections[0].handle), + ] + ) + ) + + def on_hci_le_cis_established_event(host, event): + host.emit( + 'cis_establishment_failure', + event.connection_handle, + HCI_CONNECTION_FAILED_TO_BE_ESTABLISHED_ERROR, + ) + + for device in devices: + device.host.on_hci_le_cis_established_event = functools.partial( + on_hci_le_cis_established_event, device.host + ) + + cis_request = await asyncio.wait_for(cis_requests.get(), _TIMEOUT) + + with pytest.raises(HCI_Error): + await asyncio.wait_for(devices[1].accept_cis_request(cis_request), _TIMEOUT) + + with pytest.raises(HCI_Error): + await asyncio.wait_for(cis_create_task, _TIMEOUT) + + # ----------------------------------------------------------------------------- def test_gatt_services_with_gas(): device = Device(host=Host(None, None))