From 43b4d667ccf9922d04e9afd944dfdc1149404fab Mon Sep 17 00:00:00 2001 From: Gilles Boccon-Gibod Date: Mon, 7 Aug 2023 11:34:58 -0700 Subject: [PATCH] wip (+2 squashed commits) Squashed commits: [3dafaa8] wip [5844026] wip (+1 squashed commit) Squashed commits: [4cbb35a] wip (+1 squashed commit) Squashed commits: [4d2b6d3] wip (+4 squashed commits) Squashed commits: [f2da510] wip [318c119] wip [923b4eb] wip [9d46365] wip --- .vscode/settings.json | 2 + bumble/a2dp.py | 16 +- bumble/avc.py | 512 ++++++++++++++ bumble/avctp.py | 286 ++++++++ bumble/avdtp.py | 8 +- bumble/avrcp.py | 1114 ++++++++++++++++++++++++++++++ bumble/core.py | 18 +- bumble/helpers.py | 81 ++- bumble/sdp.py | 6 +- bumble/utils.py | 62 +- examples/run_avrcp_controller.py | 255 +++++++ examples/run_avrcp_target.py | 221 ++++++ tests/avrcp_test.py | 173 +++++ tests/utils_test.py | 36 +- 14 files changed, 2734 insertions(+), 56 deletions(-) create mode 100644 bumble/avc.py create mode 100644 bumble/avctp.py create mode 100644 bumble/avrcp.py create mode 100644 examples/run_avrcp_controller.py create mode 100644 examples/run_avrcp_target.py create mode 100644 tests/avrcp_test.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 04a7f409..e4cff0ba 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -12,7 +12,9 @@ "ASHA", "asyncio", "ATRAC", + "avctp", "avdtp", + "avrcp", "bitpool", "bitstruct", "BSCP", diff --git a/bumble/a2dp.py b/bumble/a2dp.py index eeecb1ee..f8c0df12 100644 --- a/bumble/a2dp.py +++ b/bumble/a2dp.py @@ -180,8 +180,12 @@ def make_audio_source_service_sdp_records(service_record_handle, version=(1, 3)) SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence( [ - DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), - DataElement.unsigned_integer_16(version_int), + DataElement.sequence( + [ + DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), + DataElement.unsigned_integer_16(version_int), + ] + ) ] ), ), @@ -230,8 +234,12 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)): SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, DataElement.sequence( [ - DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), - DataElement.unsigned_integer_16(version_int), + DataElement.sequence( + [ + DataElement.uuid(BT_ADVANCED_AUDIO_DISTRIBUTION_SERVICE), + DataElement.unsigned_integer_16(version_int), + ] + ) ] ), ), diff --git a/bumble/avc.py b/bumble/avc.py new file mode 100644 index 00000000..ab1b501b --- /dev/null +++ b/bumble/avc.py @@ -0,0 +1,512 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import enum +import struct +from typing import Dict, Type, Union, Tuple + +from bumble.utils import OpenIntEnum + + +# ----------------------------------------------------------------------------- +class Frame: + class SubunitType(enum.IntEnum): + # AV/C Digital Interface Command Set General Specification Version 4.1 + # Table 7.4 + MONITOR = 0x00 + AUDIO = 0x01 + PRINTER = 0x02 + DISC = 0x03 + TAPE_RECORDER_OR_PLAYER = 0x04 + TUNER = 0x05 + CA = 0x06 + CAMERA = 0x07 + PANEL = 0x09 + BULLETIN_BOARD = 0x0A + VENDOR_UNIQUE = 0x1C + EXTENDED = 0x1E + UNIT = 0x1F + + class OperationCode(OpenIntEnum): + # 0x00 - 0x0F: Unit and subunit commands + VENDOR_DEPENDENT = 0x00 + RESERVE = 0x01 + PLUG_INFO = 0x02 + + # 0x10 - 0x3F: Unit commands + DIGITAL_OUTPUT = 0x10 + DIGITAL_INPUT = 0x11 + CHANNEL_USAGE = 0x12 + OUTPUT_PLUG_SIGNAL_FORMAT = 0x18 + INPUT_PLUG_SIGNAL_FORMAT = 0x19 + GENERAL_BUS_SETUP = 0x1F + CONNECT_AV = 0x20 + DISCONNECT_AV = 0x21 + CONNECTIONS = 0x22 + CONNECT = 0x24 + DISCONNECT = 0x25 + UNIT_INFO = 0x30 + SUBUNIT_INFO = 0x31 + + # 0x40 - 0x7F: Subunit commands + PASS_THROUGH = 0x7C + GUI_UPDATE = 0x7D + PUSH_GUI_DATA = 0x7E + USER_ACTION = 0x7F + + # 0xA0 - 0xBF: Unit and subunit commands + VERSION = 0xB0 + POWER = 0xB2 + + subunit_type: SubunitType + subunit_id: int + opcode: OperationCode + operands: bytes + + @staticmethod + def subclass(subclass): + # Infer the opcode from the class name + if subclass.__name__.endswith("CommandFrame"): + short_name = subclass.__name__.replace("CommandFrame", "") + category_class = CommandFrame + elif subclass.__name__.endswith("ResponseFrame"): + short_name = subclass.__name__.replace("ResponseFrame", "") + category_class = ResponseFrame + else: + raise ValueError(f"invalid subclass name {subclass.__name__}") + + uppercase_indexes = [ + i for i in range(len(short_name)) if short_name[i].isupper() + ] + uppercase_indexes.append(len(short_name)) + words = [ + short_name[uppercase_indexes[i] : uppercase_indexes[i + 1]].upper() + for i in range(len(uppercase_indexes) - 1) + ] + opcode_name = "_".join(words) + opcode = Frame.OperationCode[opcode_name] + category_class.subclasses[opcode] = subclass + return subclass + + @staticmethod + def from_bytes(data: bytes) -> Frame: + if data[0] >> 4 != 0: + raise ValueError("first 4 bits must be 0s") + + ctype_or_response = data[0] & 0xF + subunit_type = Frame.SubunitType(data[1] >> 3) + subunit_id = data[1] & 7 + + if subunit_type == Frame.SubunitType.EXTENDED: + # Not supported + raise NotImplementedError("extended subunit types not supported") + + if subunit_id < 5: + opcode_offset = 2 + elif subunit_id == 5: + # Extended to the next byte + extension = data[2] + if extension == 0: + raise ValueError("extended subunit ID value reserved") + if extension == 0xFF: + subunit_id = 5 + 254 + data[3] + opcode_offset = 4 + else: + subunit_id = 5 + extension + opcode_offset = 3 + + elif subunit_id == 6: + raise ValueError("reserved subunit ID") + + opcode = Frame.OperationCode(data[opcode_offset]) + operands = data[opcode_offset + 1 :] + + # Look for a registered subclass + if ctype_or_response < 8: + # Command + ctype = CommandFrame.CommandType(ctype_or_response) + if c_subclass := CommandFrame.subclasses.get(opcode): + return c_subclass( + ctype, + subunit_type, + subunit_id, + *c_subclass.parse_operands(operands), + ) + return CommandFrame(ctype, subunit_type, subunit_id, opcode, operands) + else: + # Response + response = ResponseFrame.ResponseCode(ctype_or_response) + if r_subclass := ResponseFrame.subclasses.get(opcode): + return r_subclass( + response, + subunit_type, + subunit_id, + *r_subclass.parse_operands(operands), + ) + return ResponseFrame(response, subunit_type, subunit_id, opcode, operands) + + def to_bytes( + self, + ctype_or_response: Union[CommandFrame.CommandType, ResponseFrame.ResponseCode], + ) -> bytes: + # TODO: support extended subunit types and ids. + return ( + bytes( + [ + ctype_or_response, + self.subunit_type << 3 | self.subunit_id, + self.opcode, + ] + ) + + self.operands + ) + + def to_string(self, extra: str) -> str: + return ( + f"{self.__class__.__name__}({extra}" + f"subunit_type={self.subunit_type.name}, " + f"subunit_id=0x{self.subunit_id:02X}, " + f"opcode={self.opcode.name}, " + f"operands={self.operands.hex()})" + ) + + def __init__( + self, + subunit_type: SubunitType, + subunit_id: int, + opcode: OperationCode, + operands: bytes, + ) -> None: + self.subunit_type = subunit_type + self.subunit_id = subunit_id + self.opcode = opcode + self.operands = operands + + +# ----------------------------------------------------------------------------- +class CommandFrame(Frame): + class CommandType(OpenIntEnum): + # AV/C Digital Interface Command Set General Specification Version 4.1 + # Table 7.1 + CONTROL = 0x00 + STATUS = 0x01 + SPECIFIC_INQUIRY = 0x02 + NOTIFY = 0x03 + GENERAL_INQUIRY = 0x04 + + subclasses: Dict[Frame.OperationCode, Type[CommandFrame]] = {} + ctype: CommandType + + def __init__( + self, + ctype: CommandType, + subunit_type: Frame.SubunitType, + subunit_id: int, + opcode: Frame.OperationCode, + operands: bytes, + ) -> None: + super().__init__(subunit_type, subunit_id, opcode, operands) + self.ctype = ctype + + def __bytes__(self): + return self.to_bytes(self.ctype) + + def __str__(self): + return self.to_string(f"ctype={self.ctype.name}, ") + + +# ----------------------------------------------------------------------------- +class ResponseFrame(Frame): + class ResponseCode(OpenIntEnum): + # AV/C Digital Interface Command Set General Specification Version 4.1 + # Table 7.2 + NOT_IMPLEMENTED = 0x08 + ACCEPTED = 0x09 + REJECTED = 0x0A + IN_TRANSITION = 0x0B + IMPLEMENTED_OR_STABLE = 0x0C + CHANGED = 0x0D + INTERIM = 0x0F + + subclasses: Dict[Frame.OperationCode, Type[ResponseFrame]] = {} + response: ResponseCode + + def __init__( + self, + response: ResponseCode, + subunit_type: Frame.SubunitType, + subunit_id: int, + opcode: Frame.OperationCode, + operands: bytes, + ) -> None: + super().__init__(subunit_type, subunit_id, opcode, operands) + self.response = response + + def __bytes__(self): + return self.to_bytes(self.response) + + def __str__(self): + return self.to_string(f"response={self.response.name}, ") + + +# ----------------------------------------------------------------------------- +class VendorDependentFrame: + company_id: int + vendor_dependent_data: bytes + + @staticmethod + def parse_operands(operands: bytes) -> Tuple: + return ( + struct.unpack(">I", b"\x00" + operands[:3])[0], + operands[3:], + ) + + def make_operands(self) -> bytes: + return struct.pack(">I", self.company_id)[1:] + self.vendor_dependent_data + + def __init__(self, company_id: int, vendor_dependent_data: bytes): + self.company_id = company_id + self.vendor_dependent_data = vendor_dependent_data + + +# ----------------------------------------------------------------------------- +@Frame.subclass +class VendorDependentCommandFrame(VendorDependentFrame, CommandFrame): + def __init__( + self, + ctype: CommandFrame.CommandType, + subunit_type: Frame.SubunitType, + subunit_id: int, + company_id: int, + vendor_dependent_data: bytes, + ) -> None: + VendorDependentFrame.__init__(self, company_id, vendor_dependent_data) + CommandFrame.__init__( + self, + ctype, + subunit_type, + subunit_id, + Frame.OperationCode.VENDOR_DEPENDENT, + self.make_operands(), + ) + + def __str__(self): + return ( + f"VendorDependentCommandFrame(ctype={self.ctype.name}, " + f"subunit_type={self.subunit_type.name}, " + f"subunit_id=0x{self.subunit_id:02X}, " + f"company_id=0x{self.company_id:06X}, " + f"vendor_dependent_data={self.vendor_dependent_data.hex()})" + ) + + +# ----------------------------------------------------------------------------- +@Frame.subclass +class VendorDependentResponseFrame(VendorDependentFrame, ResponseFrame): + def __init__( + self, + response: ResponseFrame.ResponseCode, + subunit_type: Frame.SubunitType, + subunit_id: int, + company_id: int, + vendor_dependent_data: bytes, + ) -> None: + VendorDependentFrame.__init__(self, company_id, vendor_dependent_data) + ResponseFrame.__init__( + self, + response, + subunit_type, + subunit_id, + Frame.OperationCode.VENDOR_DEPENDENT, + self.make_operands(), + ) + + def __str__(self): + return ( + f"VendorDependentResponseFrame(response={self.response.name}, " + f"subunit_type={self.subunit_type.name}, " + f"subunit_id=0x{self.subunit_id:02X}, " + f"company_id=0x{self.company_id:06X}, " + f"vendor_dependent_data={self.vendor_dependent_data.hex()})" + ) + + +# ----------------------------------------------------------------------------- +class PassThroughFrame: + """ + See AV/C Panel Subunit Specification 1.1 - 9.4 PASS THROUGH control command + """ + + class StateFlag(enum.IntEnum): + PRESSED = 0 + RELEASED = 1 + + class OperationId(OpenIntEnum): + SELECT = 0x00 + UP = 0x01 + DOWN = 0x01 + LEFT = 0x03 + RIGHT = 0x04 + RIGHT_UP = 0x05 + RIGHT_DOWN = 0x06 + LEFT_UP = 0x07 + LEFT_DOWN = 0x08 + ROOT_MENU = 0x09 + SETUP_MENU = 0x0A + CONTENTS_MENU = 0x0B + FAVORITE_MENU = 0x0C + EXIT = 0x0D + NUMBER_0 = 0x20 + NUMBER_1 = 0x21 + NUMBER_2 = 0x22 + NUMBER_3 = 0x23 + NUMBER_4 = 0x24 + NUMBER_5 = 0x25 + NUMBER_6 = 0x26 + NUMBER_7 = 0x27 + NUMBER_8 = 0x28 + NUMBER_9 = 0x29 + DOT = 0x2A + ENTER = 0x2B + CLEAR = 0x2C + CHANNEL_UP = 0x30 + CHANNEL_DOWN = 0x31 + PREVIOUS_CHANNEL = 0x32 + SOUND_SELECT = 0x33 + INPUT_SELECT = 0x34 + DISPLAY_INFORMATION = 0x35 + HELP = 0x36 + PAGE_UP = 0x37 + PAGE_DOWN = 0x38 + POWER = 0x40 + VOLUME_UP = 0x41 + VOLUME_DOWN = 0x42 + MUTE = 0x43 + PLAY = 0x44 + STOP = 0x45 + PAUSE = 0x46 + RECORD = 0x47 + REWIND = 0x48 + FAST_FORWARD = 0x49 + EJECT = 0x4A + FORWARD = 0x4B + BACKWARD = 0x4C + ANGLE = 0x50 + SUBPICTURE = 0x51 + F1 = 0x71 + F2 = 0x72 + F3 = 0x73 + F4 = 0x74 + F5 = 0x75 + VENDOR_UNIQUE = 0x7E + + state_flag: StateFlag + operation_id: OperationId + operation_data: bytes + + @staticmethod + def parse_operands(operands: bytes) -> Tuple: + return ( + PassThroughFrame.StateFlag(operands[0] >> 7), + PassThroughFrame.OperationId(operands[0] & 0x7F), + operands[1 : 1 + operands[1]], + ) + + def make_operands(self): + return ( + bytes([self.state_flag << 7 | self.operation_id, len(self.operation_data)]) + + self.operation_data + ) + + def __init__( + self, + state_flag: StateFlag, + operation_id: OperationId, + operation_data: bytes, + ) -> None: + if len(operation_data) > 255: + raise ValueError("operation data must be <= 255 bytes") + self.state_flag = state_flag + self.operation_id = operation_id + self.operation_data = operation_data + + +# ----------------------------------------------------------------------------- +@Frame.subclass +class PassThroughCommandFrame(PassThroughFrame, CommandFrame): + def __init__( + self, + ctype: CommandFrame.CommandType, + subunit_type: Frame.SubunitType, + subunit_id: int, + state_flag: PassThroughFrame.StateFlag, + operation_id: PassThroughFrame.OperationId, + operation_data: bytes, + ) -> None: + PassThroughFrame.__init__(self, state_flag, operation_id, operation_data) + CommandFrame.__init__( + self, + ctype, + subunit_type, + subunit_id, + Frame.OperationCode.PASS_THROUGH, + self.make_operands(), + ) + + def __str__(self): + return ( + f"PassThroughCommandFrame(ctype={self.ctype.name}, " + f"subunit_type={self.subunit_type.name}, " + f"subunit_id=0x{self.subunit_id:02X}, " + f"state_flag={self.state_flag.name}, " + f"operation_id={self.operation_id.name}, " + f"operation_data={self.operation_data.hex()})" + ) + + +# ----------------------------------------------------------------------------- +@Frame.subclass +class PassThroughResponseFrame(PassThroughFrame, ResponseFrame): + def __init__( + self, + response: ResponseFrame.ResponseCode, + subunit_type: Frame.SubunitType, + subunit_id: int, + state_flag: PassThroughFrame.StateFlag, + operation_id: PassThroughFrame.OperationId, + operation_data: bytes, + ) -> None: + PassThroughFrame.__init__(self, state_flag, operation_id, operation_data) + ResponseFrame.__init__( + self, + response, + subunit_type, + subunit_id, + Frame.OperationCode.PASS_THROUGH, + self.make_operands(), + ) + + def __str__(self): + return ( + f"PassThroughResponseFrame(response={self.response.name}, " + f"subunit_type={self.subunit_type.name}, " + f"subunit_id=0x{self.subunit_id:02X}, " + f"state_flag={self.state_flag.name}, " + f"operation_id={self.operation_id.name}, " + f"operation_data={self.operation_data.hex()})" + ) diff --git a/bumble/avctp.py b/bumble/avctp.py new file mode 100644 index 00000000..e4835a09 --- /dev/null +++ b/bumble/avctp.py @@ -0,0 +1,286 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +from enum import IntEnum +import logging +import struct +from typing import Callable, Dict, Optional + +from bumble.colors import color +from bumble import avc +from bumble import l2cap + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +AVCTP_PSM = 0x0017 +AVCTP_BROWSING_PSM = 0x001B + + +# ----------------------------------------------------------------------------- +class MessageAssembler: + Callback = Callable[[int, bool, bool, int, bytes], None] + + transaction_label: int + pid: int + c_r: int + ipid: int + payload: bytes + number_of_packets: int + packets_received: int + + def __init__(self, callback: Callback) -> None: + self.callback = callback + self.reset() + + def reset(self) -> None: + self.packets_received = 0 + self.transaction_label = -1 + self.pid = -1 + self.c_r = -1 + self.ipid = -1 + self.payload = b'' + self.number_of_packets = 0 + self.packet_count = 0 + + def on_pdu(self, pdu: bytes) -> None: + self.packets_received += 1 + + transaction_label = pdu[0] >> 4 + packet_type = Protocol.PacketType((pdu[0] >> 2) & 3) + c_r = (pdu[0] >> 1) & 1 + ipid = pdu[0] & 1 + + if c_r == 0 and ipid != 0: + logger.warning("invalid IPID in command frame") + self.reset() + return + + pid_offset = 1 + if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START): + if self.transaction_label >= 0: + # We are already in a transaction + logger.warning("received START or SINGLE fragment while in transaction") + self.reset() + self.packets_received = 1 + + if packet_type == Protocol.PacketType.START: + self.number_of_packets = pdu[1] + pid_offset = 2 + + pid = struct.unpack_from(">H", pdu, pid_offset)[0] + self.payload += pdu[pid_offset + 2 :] + + if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END): + if transaction_label != self.transaction_label: + logger.warning("transaction label does not match") + self.reset() + return + + if pid != self.pid: + logger.warning("PID does not match") + self.reset() + return + + if c_r != self.c_r: + logger.warning("C/R does not match") + self.reset() + return + + if self.packets_received > self.number_of_packets: + logger.warning("too many fragments in transaction") + self.reset() + return + + if packet_type == Protocol.PacketType.END: + if self.packets_received != self.number_of_packets: + logger.warning("premature END") + self.reset() + return + else: + self.transaction_label = transaction_label + self.c_r = c_r + self.ipid = ipid + self.pid = pid + + if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END): + self.on_message_complete() + + def on_message_complete(self): + try: + self.callback( + self.transaction_label, + self.c_r == 0, + self.ipid != 0, + self.pid, + self.payload, + ) + except Exception as error: + logger.exception(color(f"!!! exception in callback: {error}", "red")) + + self.reset() + + +# ----------------------------------------------------------------------------- +class Protocol: + Handler = Callable[[int, Optional[bytes]], None] + command_handlers: Dict[int, Handler] # Command handlers, by PID + response_handlers: Dict[int, Handler] # Response handlers, by PID + next_transaction_label: int + message_assembler: MessageAssembler + + class PacketType(IntEnum): + SINGLE = 0b00 + START = 0b01 + CONTINUE = 0b10 + END = 0b11 + + def __init__(self, l2cap_channel: l2cap.Channel) -> None: + self.command_handlers = {} + self.response_handlers = {} + self.l2cap_channel = l2cap_channel + self.message_assembler = MessageAssembler(self.on_message) + + # Register to receive PDUs from the channel + l2cap_channel.sink = self.on_pdu + l2cap_channel.on("open", self.on_l2cap_channel_open) + l2cap_channel.on("close", self.on_l2cap_channel_close) + + def on_l2cap_channel_open(self): + logger.debug(color("<<< AVCTP channel open", "magenta")) + + def on_l2cap_channel_close(self): + logger.debug(color("<<< AVCTP channel closed", "magenta")) + + def on_pdu(self, pdu: bytes) -> None: + self.message_assembler.on_pdu(pdu) + + def on_message( + self, + transaction_label: int, + is_command: bool, + ipid: bool, + pid: int, + payload: bytes, + ) -> None: + logger.debug( + f"<<< AVCTP Message: pid={pid}, " + f"transaction_label={transaction_label}, " + f"is_command={is_command}, " + f"ipid={ipid}, " + f"payload={payload.hex()}" + ) + + # Find the appropriate handler. + handlers = self.command_handlers if is_command else self.response_handlers + if pid not in handlers: + logger.warning(f"no handler for PID {pid}") + if is_command: + self.send_ipid(transaction_label, pid) + return + + # Check for invalid PID responses. + if ipid: + logger.debug(f"received IPID for PID={pid}") + + # Invoke the handler. + # By convention, for an ipid, send a None payload to the response handler. + if ipid: + frame = None + else: + frame = avc.Frame.from_bytes(payload) + handlers[pid](transaction_label, frame) + + def send_message( + self, + transaction_label: int, + is_command: bool, + ipid: bool, + pid: int, + payload: bytes, + ): + # TODO: fragment large messages + packet_type = Protocol.PacketType.SINGLE + pdu = ( + struct.pack( + ">BH", + transaction_label << 4 + | packet_type << 2 + | (0 if is_command else 1) << 1 + | (1 if ipid else 0), + pid, + ) + + payload + ) + self.l2cap_channel.send_pdu(pdu) + + def send_command(self, transaction_label: int, pid: int, payload: bytes) -> None: + logger.debug( + ">>> AVCTP command: " + f"transaction_label={transaction_label}, " + f"pid={pid}, " + f"payload={payload.hex()}" + ) + self.send_message(transaction_label, True, False, pid, payload) + + def send_response(self, transaction_label: int, pid: int, payload: bytes): + logger.debug( + ">>> AVCTP response: " + f"transaction_label={transaction_label}, " + f"pid={pid}, " + f"payload={payload.hex()}" + ) + self.send_message(transaction_label, False, False, pid, payload) + + def send_ipid(self, transaction_label: int, pid: int) -> None: + logger.debug( + ">>> AVCTP ipid: " + f"transaction_label={transaction_label}, " + f"pid={pid}" + ) + self.send_message(transaction_label, False, True, pid, b'') + + def register_handler( + self, pid: int, handler: Protocol.Handler, registry: Dict[int, Handler] + ) -> None: + registry[pid] = handler + + def unregister_handler( + self, pid: int, handler: Protocol.Handler, registry: Dict[int, Handler] + ) -> None: + if pid not in registry or registry[pid] != handler: + raise ValueError("handler not registered") + del registry[pid] + + def register_command_handler(self, pid: int, handler: Protocol.Handler) -> None: + self.register_handler(pid, handler, self.command_handlers) + + def unregister_command_handler(self, pid: int, handler: Protocol.Handler) -> None: + self.unregister_handler(pid, handler, self.command_handlers) + + def register_response_handler(self, pid: int, handler: Protocol.Handler) -> None: + self.register_handler(pid, handler, self.response_handlers) + + def unregister_response_handler(self, pid: int, handler: Protocol.Handler) -> None: + self.unregister_handler(pid, handler, self.response_handlers) diff --git a/bumble/avdtp.py b/bumble/avdtp.py index 9a332f45..f46162fd 100644 --- a/bumble/avdtp.py +++ b/bumble/avdtp.py @@ -241,7 +241,10 @@ async def find_avdtp_service_with_sdp_client( ) if profile_descriptor_list: for profile_descriptor in profile_descriptor_list.value: - if len(profile_descriptor.value) >= 2: + if ( + profile_descriptor.type == sdp.DataElement.SEQUENCE + and len(profile_descriptor.value) >= 2 + ): avdtp_version_major = profile_descriptor.value[1].value >> 8 avdtp_version_minor = profile_descriptor.value[1].value & 0xFF return (avdtp_version_major, avdtp_version_minor) @@ -511,7 +514,8 @@ def on_message_complete(self) -> None: try: self.callback(self.transaction_label, message) except Exception as error: - logger.warning(color(f'!!! exception in callback: {error}')) + logger.exception(color(f'!!! exception in callback: {error}', 'red')) + self.reset() diff --git a/bumble/avrcp.py b/bumble/avrcp.py new file mode 100644 index 00000000..fa485f24 --- /dev/null +++ b/bumble/avrcp.py @@ -0,0 +1,1114 @@ +# Copyright 2021-2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations +import asyncio +from dataclasses import dataclass +import enum +import logging +import struct +from typing import ( + Awaitable, + Callable, + ClassVar, + Dict, + List, + Optional, + SupportsBytes, + Tuple, +) + +import pyee + +from bumble.colors import color +from bumble.device import Device, Connection +from bumble.sdp import ( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + SDP_PUBLIC_BROWSE_ROOT, + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID, + DataElement, + ServiceAttribute, +) +from bumble.utils import OpenIntEnum +from bumble.core import ( + ProtocolError, + BT_L2CAP_PROTOCOL_ID, + BT_AVCTP_PROTOCOL_ID, + BT_AV_REMOTE_CONTROL_SERVICE, + BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE, + BT_AV_REMOTE_CONTROL_TARGET_SERVICE, +) +from bumble import l2cap +from bumble import avc +from bumble import avctp + + +# ----------------------------------------------------------------------------- +# Logging +# ----------------------------------------------------------------------------- +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +# Constants +# ----------------------------------------------------------------------------- +AVRCP_PID = 0x110E +AVRCP_BLUETOOTH_SIG_COMPANY_ID = 0x001958 + + +# ----------------------------------------------------------------------------- +def make_controller_service_sdp_records( + service_record_handle: int, + avctp_version: Tuple[int, int] = (1, 4), + avrcp_version: Tuple[int, int] = (1, 6), + supported_features: int = 1, +) -> List[ServiceAttribute]: + # TODO: support a way to compute the supported features from a feature list + avctp_version_int = avctp_version[0] << 8 | avctp_version[1] + avrcp_version_int = avrcp_version[0] << 8 | avrcp_version[1] + + return [ + ServiceAttribute( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + DataElement.unsigned_integer_32(service_record_handle), + ), + ServiceAttribute( + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), + ), + ServiceAttribute( + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.uuid(BT_AV_REMOTE_CONTROL_SERVICE), + DataElement.uuid(BT_AV_REMOTE_CONTROL_CONTROLLER_SERVICE), + ] + ), + ), + ServiceAttribute( + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_L2CAP_PROTOCOL_ID), + DataElement.unsigned_integer_16(avctp.AVCTP_PSM), + ] + ), + DataElement.sequence( + [ + DataElement.uuid(BT_AVCTP_PROTOCOL_ID), + DataElement.unsigned_integer_16(avctp_version_int), + ] + ), + ] + ), + ), + ServiceAttribute( + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_AV_REMOTE_CONTROL_SERVICE), + DataElement.unsigned_integer_16(avrcp_version_int), + ] + ), + ] + ), + ), + ServiceAttribute( + SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID, + DataElement.unsigned_integer_16(supported_features), + ), + ] + + +# ----------------------------------------------------------------------------- +def make_target_service_sdp_records( + service_record_handle: int, + avctp_version: Tuple[int, int] = (1, 4), + avrcp_version: Tuple[int, int] = (1, 6), + supported_features: int = 0x23, +) -> List[ServiceAttribute]: + # TODO: support a way to compute the supported features from a feature list + avctp_version_int = avctp_version[0] << 8 | avctp_version[1] + avrcp_version_int = avrcp_version[0] << 8 | avrcp_version[1] + + return [ + ServiceAttribute( + SDP_SERVICE_RECORD_HANDLE_ATTRIBUTE_ID, + DataElement.unsigned_integer_32(service_record_handle), + ), + ServiceAttribute( + SDP_BROWSE_GROUP_LIST_ATTRIBUTE_ID, + DataElement.sequence([DataElement.uuid(SDP_PUBLIC_BROWSE_ROOT)]), + ), + ServiceAttribute( + SDP_SERVICE_CLASS_ID_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.uuid(BT_AV_REMOTE_CONTROL_TARGET_SERVICE), + ] + ), + ), + ServiceAttribute( + SDP_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_L2CAP_PROTOCOL_ID), + DataElement.unsigned_integer_16(avctp.AVCTP_PSM), + ] + ), + DataElement.sequence( + [ + DataElement.uuid(BT_AVCTP_PROTOCOL_ID), + DataElement.unsigned_integer_16(avctp_version_int), + ] + ), + ] + ), + ), + ServiceAttribute( + SDP_BLUETOOTH_PROFILE_DESCRIPTOR_LIST_ATTRIBUTE_ID, + DataElement.sequence( + [ + DataElement.sequence( + [ + DataElement.uuid(BT_AV_REMOTE_CONTROL_SERVICE), + DataElement.unsigned_integer_16(avrcp_version_int), + ] + ), + ] + ), + ), + ServiceAttribute( + SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID, + DataElement.unsigned_integer_16(supported_features), + ), + ] + + +# ----------------------------------------------------------------------------- +class PduAssembler: + """ + PDU Assembler to support fragmented PDUs are defined in: + Audio/Video Remote Control / Profile Specification + 6.3.1 AVRCP specific AV//C commands + """ + + pdu_id: Optional[Protocol.PduId] + payload: bytes + + def __init__(self, callback: Callable[[Protocol.PduId, bytes], None]) -> None: + self.callback = callback + self.reset() + + def reset(self) -> None: + self.pdu_id = None + self.parameter = b'' + + def on_pdu(self, pdu: bytes) -> None: + pdu_id = Protocol.PduId(pdu[0]) + packet_type = Protocol.PacketType(pdu[1] & 3) + parameter_length = struct.unpack_from('>H', pdu, 2)[0] + parameter = pdu[4 : 4 + parameter_length] + if len(parameter) != parameter_length: + logger.warning("parameter length exceeds pdu size") + self.reset() + return + + if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.START): + if self.pdu_id is not None: + # We are already in a PDU + logger.warning("received START or SINGLE fragment while in pdu") + self.reset() + + if packet_type in (Protocol.PacketType.CONTINUE, Protocol.PacketType.END): + if pdu_id != self.pdu_id: + logger.warning("PID does not match") + self.reset() + return + else: + self.pdu_id = pdu_id + + self.parameter += parameter + + if packet_type in (Protocol.PacketType.SINGLE, Protocol.PacketType.END): + self.on_pdu_complete() + + def on_pdu_complete(self) -> None: + try: + self.callback(self.pdu_id, self.parameter) + except Exception as error: + logger.exception(color(f'!!! exception in callback: {error}', 'red')) + + self.reset() + + +# ----------------------------------------------------------------------------- +@dataclass +class Command: + pdu_id: Protocol.PduId + parameter: bytes + + def to_string(self, properties: Dict[str, str]) -> str: + properties_str = ",".join( + [f"{name}={value}" for name, value in properties.items()] + ) + return f"Command[{self.pdu_id.name}]({properties_str})" + + def __str__(self) -> str: + return self.to_string({"parameters": self.parameter.hex()}) + + def __repr__(self) -> str: + return str(self) + + +# ----------------------------------------------------------------------------- +class GetCapabilitiesCommand(Command): + class CapabilityId(OpenIntEnum): + COMPANY_ID = 0x02 + EVENTS_SUPPORTED = 0x03 + + capability_id: CapabilityId + + @classmethod + def from_bytes(cls, pdu: bytes) -> GetCapabilitiesCommand: + return cls(cls.CapabilityId(pdu[0])) + + def __init__(self, capability_id: CapabilityId) -> None: + super().__init__(Protocol.PduId.GET_CAPABILITIES, bytes([capability_id])) + self.capability_id = capability_id + + def __str__(self) -> str: + return self.to_string({"capability_id": self.capability_id.name}) + + +# ----------------------------------------------------------------------------- +class SetAbsoluteVolumeCommand(Command): + MAXIMUM_VOLUME = 0x7F + + volume: int + + @classmethod + def from_bytes(cls, pdu: bytes) -> SetAbsoluteVolumeCommand: + return cls(pdu[0]) + + def __init__(self, volume: int) -> None: + super().__init__(Protocol.PduId.SET_ABSOLUTE_VOLUME, bytes([volume])) + self.volume = volume + + def __str__(self) -> str: + return self.to_string({"volume": str(self.volume)}) + + +# ----------------------------------------------------------------------------- +class RegisterNotificationCommand(Command): + event_id: EventId + playback_interval: int + + @classmethod + def from_bytes(cls, pdu: bytes) -> RegisterNotificationCommand: + event_id = EventId(pdu[0]) + playback_interval = struct.unpack_from(">I", pdu, 1)[0] + return cls(event_id, playback_interval) + + def __init__(self, event_id: EventId, playback_interval: int) -> None: + super().__init__( + Protocol.PduId.REGISTER_NOTIFICATION, + struct.pack(">BI", int(event_id), playback_interval), + ) + self.event_id = event_id + self.playback_interval = playback_interval + + def __str__(self) -> str: + return self.to_string( + { + "event_id": self.event_id.name, + "playback_interval": str(self.playback_interval), + } + ) + + +# ----------------------------------------------------------------------------- +@dataclass +class Response: + pdu_id: Protocol.PduId + parameter: bytes + + def to_string(self, properties: Dict[str, str]) -> str: + properties_str = ",".join( + [f"{name}={value}" for name, value in properties.items()] + ) + return f"Response[{self.pdu_id.name}]({properties_str})" + + def __str__(self) -> str: + return self.to_string({"parameter": self.parameter.hex()}) + + def __repr__(self) -> str: + return str(self) + + +# ----------------------------------------------------------------------------- +class RejectedResponse(Response): + status_code: Protocol.StatusCode + + @classmethod + def from_bytes(cls, pdu_id: Protocol.PduId, pdu: bytes) -> RejectedResponse: + return cls(pdu_id, Protocol.StatusCode(pdu[0])) + + def __init__( + self, pdu_id: Protocol.PduId, status_code: Protocol.StatusCode + ) -> None: + super().__init__(pdu_id, bytes([int(status_code)])) + self.status_code = status_code + + def __str__(self) -> str: + return self.to_string( + { + "status_code": self.status_code.name, + } + ) + + +# ----------------------------------------------------------------------------- +class NotImplementedResponse(Response): + pass + + +# ----------------------------------------------------------------------------- +class GetCapabilitiesResponse(Response): + capability_id: GetCapabilitiesCommand.CapabilityId + capabilities: List[SupportsBytes] + + @classmethod + def from_bytes(cls, pdu: bytes) -> GetCapabilitiesResponse: + if len(pdu) < 2: + # Possibly a reject response. + return cls(GetCapabilitiesCommand.CapabilityId(0), []) + + # Assume that the payloads all follow the same pattern: + # + capability_id = GetCapabilitiesCommand.CapabilityId(pdu[0]) + capability_count = pdu[1] + + if capability_id == GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED: + capabilities = [EventId(pdu[2 + x]) for x in range(capability_count)] + else: + capability_size = (len(pdu) - 2) // capability_count + capabilities = [ + pdu[x : x + capability_size] + for x in range(2, len(pdu), capability_size) + ] + + return cls(capability_id, capabilities) + + def __init__( + self, + capability_id: GetCapabilitiesCommand.CapabilityId, + capabilities: List[SupportsBytes], + ) -> None: + super().__init__( + Protocol.PduId.GET_CAPABILITIES, + bytes([capability_id, len(capabilities)]) + + b''.join(bytes(capability) for capability in capabilities), + ) + self.capability_id = capability_id + self.capabilities = capabilities + + def __str__(self) -> str: + return self.to_string( + { + "capability_id": self.capability_id.name, + "capabilities": str(self.capabilities), + } + ) + + +# ----------------------------------------------------------------------------- +class SetAbsoluteVolumeResponse(Response): + volume: int + + @classmethod + def from_bytes(cls, pdu: bytes) -> SetAbsoluteVolumeResponse: + return cls(pdu[0]) + + def __init__(self, volume: int) -> None: + super().__init__(Protocol.PduId.SET_ABSOLUTE_VOLUME, bytes([volume])) + self.volume = volume + + def __str__(self) -> str: + return self.to_string({"volume": str(self.volume)}) + + +# ----------------------------------------------------------------------------- +class RegisterNotificationResponse(Response): + event: Event + + @classmethod + def from_bytes(cls, pdu: bytes) -> RegisterNotificationResponse: + return cls(Event.from_bytes(pdu)) + + def __init__(self, event: Event) -> None: + super().__init__( + Protocol.PduId.REGISTER_NOTIFICATION, + bytes(event), + ) + self.event = event + + def __str__(self) -> str: + return self.to_string( + { + "event": self.event, + } + ) + + +# ----------------------------------------------------------------------------- +class EventId(OpenIntEnum): + PLAYBACK_STATUS_CHANGED = 0x01 + TRACK_CHANGED = 0x02 + TRACK_REACHED_END = 0x03 + TRACK_REACHED_START = 0x04 + PLAYBACK_POS_CHANGED = 0x05 + BATT_STATUS_CHANGED = 0x06 + SYSTEM_STATUS_CHANGED = 0x07 + PLAYER_APPLICATION_SETTING_CHANGED = 0x08 + NOW_PLAYING_CONTENT_CHANGED = 0x09 + AVAILABLE_PLAYERS_CHANGED = 0x0A + ADDRESSED_PLAYER_CHANGED = 0x0B + UIDS_CHANGED = 0x0C + VOLUME_CHANGED = 0x0D + + def __bytes__(self) -> bytes: + return bytes([int(self)]) + + +# ----------------------------------------------------------------------------- +class Event: + event_id: EventId + + @staticmethod + def from_bytes(pdu: bytes) -> Event: + event_id = pdu[0] + if event_id == EventId.PLAYBACK_STATUS_CHANGED: + subclass = PlaybackStatusChangedEvent + elif event_id == EventId.VOLUME_CHANGED: + subclass = VolumeChangedEvent + else: + subclass = GenericEvent + + return subclass.from_bytes(pdu) + + def __bytes__(self) -> bytes: + return bytes(self.event_id) + + +# ----------------------------------------------------------------------------- +@dataclass +class GenericEvent(Event): + event_id: EventId + data: bytes + + @classmethod + def from_bytes(cls, pdu: bytes) -> GenericEvent: + return cls(event_id=pdu[0], data=pdu[1:]) + + def __bytes__(self) -> bytes: + return bytes(self.event_id) + self.data + + +# ----------------------------------------------------------------------------- +@dataclass +class PlaybackStatusChangedEvent(Event): + class PlayStatus(OpenIntEnum): + STOPPED = 0x00 + PLAYING = 0x01 + PAUSED = 0x02 + FWD_SEEK = 0x03 + REV_SEEK = 0x04 + ERROR = 0xFF + + event_id: ClassVar[EventId] = EventId.PLAYBACK_STATUS_CHANGED + play_status: PlayStatus + + @classmethod + def from_bytes(cls, pdu: bytes) -> PlaybackStatusChangedEvent: + return cls(play_status=cls.PlayStatus(pdu[1])) + + def __bytes__(self) -> bytes: + return bytes(self.event_id) + bytes([self.play_status]) + + +# ----------------------------------------------------------------------------- +@dataclass +class VolumeChangedEvent(Event): + event_id: ClassVar[EventId] = EventId.VOLUME_CHANGED + volume: int + + @classmethod + def from_bytes(cls, pdu: bytes) -> PlaybackStatusChangedEvent: + return cls(volume=pdu[1]) + + def __bytes__(self) -> bytes: + return bytes(self.event_id) + bytes([self.volume]) + + +# ----------------------------------------------------------------------------- +class Protocol(pyee.EventEmitter): + class PacketType(enum.IntEnum): + SINGLE = 0b00 + START = 0b01 + CONTINUE = 0b10 + END = 0b11 + + class PduId(OpenIntEnum): + GET_CAPABILITIES = 0x10 + GET_ELEMENT_ATTRIBUTES = 0x20 + GET_PLAY_STATUS = 0x30 + REGISTER_NOTIFICATION = 0x31 + SET_ABSOLUTE_VOLUME = 0x50 + + class StatusCode(OpenIntEnum): + INVALID_COMMAND = 0x00 + INVALID_PARAMETER = 0x01 + PARAMETER_CONTENT_ERROR = 0x02 + INTERNAL_ERROR = 0x03 + OPERATION_COMPLETED = 0x04 + UID_CHANGED = 0x05 + INVALID_DIRECTION = 0x07 + NOT_A_DIRECTORY = 0x08 + DOES_NOT_EXIST = 0x09 + INVALID_SCOPE = 0x0A + RANGE_OUT_OF_BOUNDS = 0x0B + FOLDER_ITEM_IS_NOT_PLAYABLE = 0x0C + MEDIA_IN_USE = 0x0D + NOW_PLAYING_LIST_FULL = 0x0E + SEARCH_NOT_SUPPORTED = 0x0F + SEARCH_IN_PROGRESS = 0x10 + INVALID_PLAYER_ID = 0x11 + PLAYER_NOT_BROWSABLE = 0x12 + PLAYER_NOT_ADDRESSED = 0x13 + NO_VALID_SEARCH_RESULTS = 0x14 + NO_AVAILABLE_PLAYERS = 0x15 + ADDRESSED_PLAYER_CHANGED = 0x16 + + class InvalidPidError(Exception): + """A response frame with ipid==1 was received.""" + + class NotPendingError(Exception): + """There is no pending command for a transaction label.""" + + class PendingCommand: + response: Awaitable + + def __init__(self, transaction_label: int) -> None: + self.transaction_label = transaction_label + self.reset() + + def reset(self): + self.response = asyncio.get_running_loop().create_future() + + @dataclass + class ReceiveCommandState: + transaction_label: int + command_type: avc.CommandFrame.CommandType + + @dataclass + class ReceiveResponseState: + transaction_label: int + response_code: avc.ResponseFrame.ResponseCode + + @dataclass + class ResponseContext: + transaction_label: int + response: Response + + @dataclass + class FinalResponse(ResponseContext): + response_code: avc.ResponseFrame.ResponseCode + + @dataclass + class InterimResponse(ResponseContext): + final: Awaitable[Protocol.FinalResponse] + + send_transaction_label: int + command_pdu_assembler: PduAssembler + receive_command_state: Optional[ReceiveCommandState] + response_pdu_assembler: PduAssembler + receive_response_state: Optional[ReceiveResponseState] + avctp_protocol: Optional[avctp.Protocol] + free_commands: asyncio.Queue + pending_commands: Dict[int, PendingCommand] # Pending commands, by label + + @staticmethod + def check_vendor_dependent_frame(frame: avc.VendorDependentFrame) -> bool: + if frame.company_id != AVRCP_BLUETOOTH_SIG_COMPANY_ID: + logger.debug("unsupported company id, ignoring") + return False + + if frame.subunit_type != avc.Frame.SubunitType.PANEL or frame.subunit_id != 0: + logger.debug("unsupported subunit") + return False + + return True + + def __init__(self) -> None: + super().__init__() + self.command_pdu_assembler = PduAssembler(self.on_command_pdu) + self.receive_command_state = None + self.response_pdu_assembler = PduAssembler(self.on_response_pdu) + self.receive_response_state = None + self.avctp_protocol = None + + # Create an initial pool of free commands + self.pending_commands = {} + self.free_commands = asyncio.Queue() + for transaction_label in range(16): + self.free_commands.put_nowait(self.PendingCommand(transaction_label)) + + def listen(self, device: Device) -> None: + device.register_l2cap_server(avctp.AVCTP_PSM, self.on_avctp_connection) + + async def connect(self, connection: Connection) -> None: + avctp_channel = await connection.create_l2cap_connector(avctp.AVCTP_PSM)() + self.on_avctp_channel_open(avctp_channel) + + async def obtain_pending_command(self) -> PendingCommand: + pending_command = await self.free_commands.get() + self.pending_commands[pending_command.transaction_label] = pending_command + return pending_command + + def recycle_pending_command(self, pending_command: PendingCommand) -> None: + pending_command.reset() + del self.pending_commands[pending_command.transaction_label] + self.free_commands.put_nowait(pending_command) + logger.debug(f"recycled pending command, {self.free_commands.qsize()} free") + + # async def final_avrcp_response( + # self, interim_response: InterimResponse + # ) -> FinalResponse: + # if not ( + # pending_command := self.pending_commands.get( + # interim_response.transaction_label + # ) + # ): + # raise self.NotPendingError() + + # return await pending_command.response + + def on_avctp_connection(self, l2cap_channel: l2cap.Channel) -> None: + logger.debug("AVCTP connection established") + l2cap_channel.on("open", lambda: self.on_avctp_channel_open(l2cap_channel)) + + def on_avctp_channel_open(self, l2cap_channel: l2cap.Channel) -> None: + logger.debug("AVCTP channel open") + if self.avctp_protocol is not None: + # TODO: find a better strategy instead of just closing + logger.warning("AVCTP protocol already active, closing connection") + l2cap_channel.disconnect() + return + + self.avctp_protocol = avctp.Protocol(l2cap_channel) + self.avctp_protocol.register_command_handler(AVRCP_PID, self.on_avctp_command) + self.avctp_protocol.register_response_handler(AVRCP_PID, self.on_avctp_response) + l2cap_channel.on("close", self.on_avctp_channel_close) + + self.emit("start") + + def on_avctp_channel_close(self) -> None: + logger.debug("AVCTP channel closed") + self.avctp_protocol = None + self.emit("stop") + + def on_avctp_command( + self, transaction_label: int, command: avc.CommandFrame + ) -> None: + logger.debug( + f"<<< AVCTP Command, transaction_label={transaction_label}: " f"{command}" + ) + + # Only the PANEL subunit type with subunit ID 0 is supported in this profile. + if ( + command.subunit_type != avc.Frame.SubunitType.PANEL + or command.subunit_id != 0 + ): + logger.debug("subunit not supported") + self.send_not_implemented_response(transaction_label, command) + return + + if isinstance(command, avc.VendorDependentCommandFrame): + if not self.check_vendor_dependent_frame(command): + return + + if self.receive_command_state is None: + self.receive_command_state = self.ReceiveCommandState( + transaction_label=transaction_label, command_type=command.ctype + ) + elif ( + self.receive_command_state.transaction_label != transaction_label + or self.receive_command_state.command_type != command.ctype + ): + # We're in the middle of some other PDU + logger.warning("received interleaved PDU, resetting state") + self.command_pdu_assembler.reset() + self.receive_command_state = None + return + else: + self.receive_command_state.command_type = command.ctype + self.receive_command_state.transaction_label = transaction_label + + self.command_pdu_assembler.on_pdu(command.vendor_dependent_data) + return + + if isinstance(command, avc.PassThroughCommandFrame): + # TODO: delegate + response = avc.PassThroughResponseFrame( + avc.ResponseFrame.ResponseCode.ACCEPTED, + avc.Frame.SubunitType.PANEL, + 0, + command.state_flag, + command.operation_id, + command.operation_data, + ) + self.send_response(transaction_label, response) + return + + # TODO handle other types + self.send_not_implemented_response(transaction_label, command) + + def on_avctp_response( + self, transaction_label: int, response: Optional[avc.ResponseFrame] + ) -> None: + logger.debug( + f"<<< AVCTP Response, transaction_label={transaction_label}: {response}" + ) + + # Check that we have a pending command that matches this response. + if not (pending_command := self.pending_commands.get(transaction_label)): + logger.warning("no pending command with this transaction label") + return + + # A None response means an invalid PID was used in the request. + if response is None: + pending_command.set_exception(self.InvalidPidError()) + + if isinstance(response, avc.VendorDependentResponseFrame): + if not self.check_vendor_dependent_frame(response): + return + + if self.receive_response_state is None: + self.receive_response_state = self.ReceiveResponseState( + transaction_label=transaction_label, response_code=response.response + ) + elif ( + self.receive_response_state.transaction_label != transaction_label + or self.receive_response_state.response_code != response.response + ): + # We're in the middle of some other PDU + logger.warning("received interleaved PDU, resetting state") + self.response_pdu_assembler.reset() + self.receive_response_state = None + return + else: + self.receive_response_state.response_code = response.response + self.receive_response_state.transaction_label = transaction_label + + self.response_pdu_assembler.on_pdu(response.vendor_dependent_data) + return + + if isinstance(response, avc.PassThroughResponseFrame): + pending_command.response.set_result(response) + + # TODO handle other types + + self.recycle_pending_command(pending_command) + + def on_command_pdu(self, pdu_id: PduId, pdu: bytes) -> None: + logger.debug(f"<<< AVRCP command PDU [pdu_id={pdu_id.name}]: {pdu.hex()}") + + # Dispatch the command. + # NOTE: with a small number of supported commands, a manual dispatch like this + # is Ok, but if/when more commands are supported, a lookup dispatch mechanism + # would be more appropriate. + # TODO: switch on ctype + if self.receive_command_state.command_type in ( + avc.CommandFrame.CommandType.CONTROL, + avc.CommandFrame.CommandType.STATUS, + avc.CommandFrame.CommandType.NOTIFY, + ): + if pdu_id == self.PduId.GET_CAPABILITIES: + self.on_get_capabilities_command(GetCapabilitiesCommand.from_bytes(pdu)) + elif pdu_id == self.PduId.SET_ABSOLUTE_VOLUME: + self.on_set_absolute_volume_command( + SetAbsoluteVolumeCommand.from_bytes(pdu) + ) + elif pdu_id == self.PduId.REGISTER_NOTIFICATION: + self.on_register_notification_command( + RegisterNotificationCommand.from_bytes(pdu) + ) + else: + # Not supported. + # TODO: check that this is the right way to respond in this case. + logger.debug("unsupported PDU ID") + self.send_rejected_response_pdu(self.StatusCode.INVALID_PARAMETER) + else: + logger.debug("unsupported command type") + self.send_rejected_response_pdu(self.StatusCode.INVALID_COMMAND) + + self.receive_command_state = None + + def on_response_pdu(self, pdu_id: PduId, pdu: bytes) -> None: + logger.debug(f"<<< AVRCP response PDU [pdu_id={pdu_id.name}]: {pdu.hex()}") + + transaction_label = self.receive_response_state.transaction_label + response_code = self.receive_response_state.response_code + self.receive_response_state = None + + # Check that we have a pending command that matches this response. + if not (pending_command := self.pending_commands.get(transaction_label)): + logger.warning("no pending command with this transaction label") + return + + # Convert the PDU bytes into a response object. + # NOTE: with a small number of supported responses, a manual switch like this + # is Ok, but if/when more responses are supported, a lookup mechanism would be + # more appropriate. + response: Optional[Response] = None + if response_code == avc.ResponseFrame.ResponseCode.REJECTED: + response = RejectedResponse.from_bytes(pdu_id, pdu) + elif response_code == avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED: + response = NotImplementedResponse.from_bytes(pdu_id, pdu) + elif response_code in ( + avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE, + avc.ResponseFrame.ResponseCode.INTERIM, + avc.ResponseFrame.ResponseCode.CHANGED, + avc.ResponseFrame.ResponseCode.ACCEPTED, + ): + if pdu_id == self.PduId.GET_CAPABILITIES: + response = GetCapabilitiesResponse.from_bytes(pdu) + elif pdu_id == self.PduId.SET_ABSOLUTE_VOLUME: + response = SetAbsoluteVolumeResponse.from_bytes(pdu) + elif pdu_id == self.PduId.REGISTER_NOTIFICATION: + response = RegisterNotificationResponse.from_bytes(pdu) + else: + logger.debug("unexpected PDU ID") + pending_command.response.set_exception( + ProtocolError( + error_code=None, + error_namespace="avrcp", + details="unexpected PDU ID", + ) + ) + else: + logger.debug("unexpected response code") + pending_command.response.set_exception( + ProtocolError( + error_code=None, + error_namespace="avrcp", + details="unexpected response code", + ) + ) + + if response is None: + self.recycle_pending_command(pending_command) + return + + # Make the response available to the waiter. + if response_code == avc.ResponseFrame.ResponseCode.INTERIM: + pending_interim_response = pending_command.response + pending_command.reset() + pending_interim_response.set_result( + self.InterimResponse( + pending_command.transaction_label, + response, + pending_command.response, + ) + ) + else: + pending_command.response.set_result( + self.FinalResponse( + pending_command.transaction_label, + response, + response_code, + ) + ) + self.recycle_pending_command(pending_command) + + def send_command(self, transaction_label: int, command: avc.CommandFrame) -> None: + logger.debug(f">>> AVRCP command: {command}") + + self.avctp_protocol.send_command(transaction_label, AVRCP_PID, bytes(command)) + + async def send_passthrough_command( + self, command: avc.PassThroughCommandFrame + ) -> avc.PassThroughResponseFrame: + # TODO + # Wait for a free command slot. + pending_command = await self.obtain_pending_command() + + # Send the command. + self.send_command(pending_command.transaction_label, command) + + # Wait for the response. + response = await pending_command.response + return response + + async def send_avrcp_command( + self, command_type: avc.CommandFrame.CommandType, command: Command + ) -> avc.VendorDependentResponseFrame: + # Wait for a free command slot. + pending_command = await self.obtain_pending_command() + + # TODO: fragmentation + # Send the command. + logger.debug(f">>> AVRCP command PDU: {command}") + pdu = ( + struct.pack(">BBH", command.pdu_id, 0, len(command.parameter)) + + command.parameter + ) + command_frame = avc.VendorDependentCommandFrame( + command_type, + avc.Frame.SubunitType.PANEL, + 0, + AVRCP_BLUETOOTH_SIG_COMPANY_ID, + pdu, + ) + self.send_command(pending_command.transaction_label, command_frame) + + # Wait for the response. + return await pending_command.response + + def send_response( + self, transaction_label: int, response: avc.ResponseFrame + ) -> None: + logger.debug(f">>> AVRCP response: {response}") + self.avctp_protocol.send_response(transaction_label, AVRCP_PID, bytes(response)) + + def send_passthrough_response( + self, + transaction_label: int, + command: avc.PassThroughCommandFrame, + response_code: avc.ResponseFrame.ResponseCode, + ): + response = avc.PassThroughResponseFrame( + response_code, + avc.Frame.SubunitType.PANEL, + 0, + command.state_flag, + command.operation_id, + command.operation_data, + ) + self.send_response(transaction_label, response) + + def send_avrcp_response( + self, response_code: avc.ResponseFrame.ResponseCode, response: Response + ) -> None: + # TODO: fragmentation + logger.debug(f">>> AVRCP response PDU: {response}") + pdu = ( + struct.pack(">BBH", response.pdu_id, 0, len(response.parameter)) + + response.parameter + ) + response_frame = avc.VendorDependentResponseFrame( + response_code, + avc.Frame.SubunitType.PANEL, + 0, + AVRCP_BLUETOOTH_SIG_COMPANY_ID, + pdu, + ) + self.send_response(self.receive_command_state.transaction_label, response_frame) + + def send_not_implemented_response( + self, transaction_label: int, command: avc.CommandFrame + ) -> None: + response = avc.ResponseFrame( + avc.ResponseFrame.ResponseCode.NOT_IMPLEMENTED, + command.subunit_type, + command.subunit_id, + command.opcode, + command.operands, + ) + self.send_response(transaction_label, response) + + def send_rejected_avrcp_response( + self, pdu_id: Protocol.PduId, status_code: StatusCode + ) -> None: + self.send_avrcp_response( + avc.ResponseFrame.ResponseCode.REJECTED, + RejectedResponse(pdu_id, status_code), + ) + + def on_get_capabilities_command(self, command: GetCapabilitiesCommand) -> None: + logger.debug(f"<<< AVRCP command PDU: {command}") + + if ( + command.capability_id + == GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED + ): + # TEST: hardcoded values for testing only + supported_events = [EventId.VOLUME_CHANGED, EventId.PLAYBACK_STATUS_CHANGED] + self.send_avrcp_response( + avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE, + GetCapabilitiesResponse(command.capability_id, supported_events), + ) + return + + self.send_rejected_avrcp_response( + self.PduId.GET_CAPABILITIES, self.StatusCode.INVALID_PARAMETER + ) + + def on_set_absolute_volume_command(self, command: SetAbsoluteVolumeCommand) -> None: + logger.debug(f"<<< AVRCP command PDU: {command}") + + # TODO implement a delegate + self.send_avrcp_response( + avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE, + SetAbsoluteVolumeResponse(command.volume), + ) + + def on_register_notification_command( + self, command: RegisterNotificationCommand + ) -> None: + logger.debug(f"<<< AVRCP command PDU: {command}") + + if command.event_id == EventId.VOLUME_CHANGED: + # TODO: testing only + response = RegisterNotificationResponse(VolumeChangedEvent(volume=10)) + self.send_avrcp_response(avc.ResponseFrame.ResponseCode.INTERIM, response) + return + + if command.event_id == EventId.PLAYBACK_STATUS_CHANGED: + # TODO: testing only + response = RegisterNotificationResponse( + PlaybackStatusChangedEvent( + play_status=PlaybackStatusChangedEvent.PlayStatus.PLAYING + ) + ) + self.send_avrcp_response(avc.ResponseFrame.ResponseCode.INTERIM, response) + return + + self.send_rejected_avrcp_response( + self.PduId.REGISTER_NOTIFICATION, self.StatusCode.INVALID_PARAMETER + ) diff --git a/bumble/core.py b/bumble/core.py index 4a67d6ec..6ba8149b 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -96,12 +96,16 @@ def __str__(self): namespace = f'{self.error_namespace}/' else: namespace = '' - error_text = { - (True, True): f'{self.error_name} [0x{self.error_code:X}]', - (True, False): self.error_name, - (False, True): f'0x{self.error_code:X}', - (False, False): '', - }[(self.error_name != '', self.error_code is not None)] + have_name = self.error_name != '' + have_code = self.error_code is not None + if have_name and have_code: + error_text = f'{self.error_name} [0x{self.error_code:X}]' + elif have_name and not have_code: + error_text = self.error_name + elif not have_name and have_code: + error_text = f'0x{self.error_code:X}' + else: + error_text = '' return f'{type(self).__name__}({namespace}{error_text})' @@ -318,7 +322,7 @@ def __str__(self) -> str: BT_HARDCOPY_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0012, 'HardcopyControlChannel') BT_HARDCOPY_DATA_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x0014, 'HardcopyDataChannel') BT_HARDCOPY_NOTIFICATION_PROTOCOL_ID = UUID.from_16_bits(0x0016, 'HardcopyNotification') -BT_AVTCP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP') +BT_AVCTP_PROTOCOL_ID = UUID.from_16_bits(0x0017, 'AVCTP') BT_AVDTP_PROTOCOL_ID = UUID.from_16_bits(0x0019, 'AVDTP') BT_CMTP_PROTOCOL_ID = UUID.from_16_bits(0x001B, 'CMTP') BT_MCAP_CONTROL_CHANNEL_PROTOCOL_ID = UUID.from_16_bits(0x001E, 'MCAPControlChannel') diff --git a/bumble/helpers.py b/bumble/helpers.py index 83c7c6df..680c2085 100644 --- a/bumble/helpers.py +++ b/bumble/helpers.py @@ -39,6 +39,10 @@ from .rfcomm import RFCOMM_Frame, RFCOMM_PSM from .sdp import SDP_PDU, SDP_PSM from .avdtp import MessageAssembler as AVDTP_MessageAssembler, AVDTP_PSM +from .avctp import MessageAssembler as AVCTP_MessageAssembler, AVCTP_PSM +from .avrcp import AVRCP_PID +from .avc import Frame as AVC_Frame + # ----------------------------------------------------------------------------- # Logging @@ -50,10 +54,12 @@ PSM_NAMES = { RFCOMM_PSM: 'RFCOMM', SDP_PSM: 'SDP', - AVDTP_PSM: 'AVDTP' + AVDTP_PSM: 'AVDTP', + AVCTP_PSM: 'AVCTP' # TODO: add more PSM values } +AVCTP_PID_NAMES = {AVRCP_PID: 'AVRCP'} # ----------------------------------------------------------------------------- class PacketTracer: @@ -62,12 +68,15 @@ def __init__(self, analyzer): self.analyzer = analyzer self.packet_assembler = HCI_AclDataPacketAssembler(self.on_acl_pdu) self.avdtp_assemblers = {} # AVDTP assemblers, by source_cid + self.avctp_assemblers = {} # AVCTP assemblers, by source_cid + self.avrcp_assemblers = {} # AVRCP assemblers, by source_cid self.psms = {} # PSM, by source_cid self.peer = None # ACL stream in the other direction # pylint: disable=too-many-nested-blocks def on_acl_pdu(self, pdu): l2cap_pdu = L2CAP_PDU.from_bytes(pdu) + self.analyzer.emit(l2cap_pdu) if l2cap_pdu.cid == ATT_CID: att_pdu = ATT_PDU.from_bytes(l2cap_pdu.payload) @@ -82,28 +91,33 @@ def on_acl_pdu(self, pdu): # Check if this signals a new channel if control_frame.code == L2CAP_CONNECTION_REQUEST: self.psms[control_frame.source_cid] = control_frame.psm - elif control_frame.code == L2CAP_CONNECTION_RESPONSE: - if ( - control_frame.result - == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL + elif ( + control_frame.code == L2CAP_CONNECTION_RESPONSE + and control_frame.result + == L2CAP_Connection_Response.CONNECTION_SUCCESSFUL + ): + if self.peer and ( + psm := self.peer.psms.get(control_frame.source_cid) ): - if self.peer: - if psm := self.peer.psms.get(control_frame.source_cid): - # Found a pending connection - self.psms[control_frame.destination_cid] = psm - - # For AVDTP connections, create a packet assembler for - # each direction - if psm == AVDTP_PSM: - self.avdtp_assemblers[ - control_frame.source_cid - ] = AVDTP_MessageAssembler(self.on_avdtp_message) - self.peer.avdtp_assemblers[ - control_frame.destination_cid - ] = AVDTP_MessageAssembler( - self.peer.on_avdtp_message - ) - + # Found a pending connection + self.psms[control_frame.destination_cid] = psm + + # For AVDTP and AVCTP connections, create a packet + # assembler for each direction + if psm == AVDTP_PSM: + self.avdtp_assemblers[ + control_frame.source_cid + ] = AVDTP_MessageAssembler(self.on_avdtp_message) + self.peer.avdtp_assemblers[ + control_frame.destination_cid + ] = AVDTP_MessageAssembler(self.peer.on_avdtp_message) + elif psm == AVCTP_PSM: + self.avctp_assemblers[ + control_frame.source_cid + ] = AVCTP_MessageAssembler(self.on_avctp_message) + self.peer.avctp_assemblers[ + control_frame.destination_cid + ] = AVCTP_MessageAssembler(self.peer.on_avctp_message) else: # Try to find the PSM associated with this PDU if self.peer and (psm := self.peer.psms.get(l2cap_pdu.cid)): @@ -121,6 +135,14 @@ def on_acl_pdu(self, pdu): assembler = self.avdtp_assemblers.get(l2cap_pdu.cid) if assembler: assembler.on_pdu(l2cap_pdu.payload) + elif psm == AVCTP_PSM: + self.analyzer.emit( + f'{color("L2CAP", "green")} [CID={l2cap_pdu.cid}, ' + f'PSM=AVCTP]: {l2cap_pdu.payload.hex()}' + ) + assembler = self.avctp_assemblers.get(l2cap_pdu.cid) + if assembler: + assembler.on_pdu(l2cap_pdu.payload) else: psm_string = name_or_number(PSM_NAMES, psm) self.analyzer.emit( @@ -135,6 +157,21 @@ def on_avdtp_message(self, transaction_label, message): f'{color("AVDTP", "green")} [{transaction_label}] {message}' ) + def on_avctp_message(self, transaction_label, is_command, ipid, pid, payload): + if pid == AVRCP_PID: + avc_frame = AVC_Frame.from_bytes(payload) + details = str(avc_frame) + else: + details = payload.hex() + + c_r = 'Command' if is_command else 'Response' + self.analyzer.emit( + f'{color("AVCTP", "green")} ' + f'{c_r}[{transaction_label}][{name_or_number(AVCTP_PID_NAMES, pid)}] ' + f'{"#" if ipid else ""}' + f'{details}' + ) + def feed_packet(self, packet): self.packet_assembler.feed_packet(packet) diff --git a/bumble/sdp.py b/bumble/sdp.py index bc8303c8..062b45d9 100644 --- a/bumble/sdp.py +++ b/bumble/sdp.py @@ -97,7 +97,8 @@ SDP_ICON_URL_ATTRIBUTE_ID = 0X000C SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID = 0X000D -# Attribute Identifier (cf. Assigned Numbers for Service Discovery) + +# Profile-specific Attribute Identifiers (cf. Assigned Numbers for Service Discovery) # used by AVRCP, HFP and A2DP SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID = 0x0311 @@ -115,7 +116,8 @@ SDP_DOCUMENTATION_URL_ATTRIBUTE_ID: 'SDP_DOCUMENTATION_URL_ATTRIBUTE_ID', SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID: 'SDP_CLIENT_EXECUTABLE_URL_ATTRIBUTE_ID', SDP_ICON_URL_ATTRIBUTE_ID: 'SDP_ICON_URL_ATTRIBUTE_ID', - SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID' + SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID: 'SDP_ADDITIONAL_PROTOCOL_DESCRIPTOR_LIST_ATTRIBUTE_ID', + SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID: 'SDP_SUPPORTED_FEATURES_ATTRIBUTE_ID', } SDP_PUBLIC_BROWSE_ROOT = core.UUID.from_16_bits(0x1002, 'PublicBrowseRoot') diff --git a/bumble/utils.py b/bumble/utils.py index a562618f..946fac26 100644 --- a/bumble/utils.py +++ b/bumble/utils.py @@ -17,9 +17,10 @@ # ----------------------------------------------------------------------------- from __future__ import annotations import asyncio -import logging -import traceback import collections +import enum +import functools +import logging import sys import warnings from typing import ( @@ -34,7 +35,8 @@ Union, overload, ) -from functools import wraps, partial +import traceback + from pyee import EventEmitter from .colors import color @@ -131,13 +133,14 @@ def on( Args: emitter: EventEmitter to watch event: Event name - handler: (Optional) Event handler. When nothing is passed, this method works as a decorator. + handler: (Optional) Event handler. When nothing is passed, this method + works as a decorator. ''' - def wrapper(f: _Handler) -> _Handler: - self.handlers.append((emitter, event, f)) - emitter.on(event, f) - return f + def wrapper(wrapped: _Handler) -> _Handler: + self.handlers.append((emitter, event, wrapped)) + emitter.on(event, wrapped) + return wrapped return wrapper if handler is None else wrapper(handler) @@ -157,13 +160,14 @@ def once( Args: emitter: EventEmitter to watch event: Event name - handler: (Optional) Event handler. When nothing passed, this method works as a decorator. + handler: (Optional) Event handler. When nothing passed, this method works + as a decorator. ''' - def wrapper(f: _Handler) -> _Handler: - self.handlers.append((emitter, event, f)) - emitter.once(event, f) - return f + def wrapper(wrapped: _Handler) -> _Handler: + self.handlers.append((emitter, event, wrapped)) + emitter.once(event, wrapped) + return wrapped return wrapper if handler is None else wrapper(handler) @@ -276,7 +280,7 @@ def run_in_task(queue=None): """ def decorator(func): - @wraps(func) + @functools.wraps(func) def wrapper(*args, **kwargs): coroutine = func(*args, **kwargs) if queue is None: @@ -413,23 +417,28 @@ async def pump(self): self.check_pump() +# ----------------------------------------------------------------------------- async def async_call(function, *args, **kwargs): """ - Immediately calls the function with provided args and kwargs, wrapping it in an async function. - Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject a running loop. + Immediately calls the function with provided args and kwargs, wrapping it in an + async function. + Rust's `pyo3_asyncio` library needs functions to be marked async to properly inject + a running loop. result = await async_call(some_function, ...) """ return function(*args, **kwargs) +# ----------------------------------------------------------------------------- def wrap_async(function): """ Wraps the provided function in an async function. """ - return partial(async_call, function) + return functools.partial(async_call, function) +# ----------------------------------------------------------------------------- def deprecated(msg: str): """ Throw deprecation warning before execution @@ -444,3 +453,22 @@ def inner(*args, **kwargs): return inner return wrapper + + +# ----------------------------------------------------------------------------- +class OpenIntEnum(enum.IntEnum): + """ + Subclass of enum.IntEnum that can hold integer values outside the set of + predefined values. This is convenient for implementing protocols where some + integer constants may be added over time. + """ + + @classmethod + def _missing_(cls, value): + if not isinstance(value, int): + return None + + obj = int.__new__(cls, value) + obj._value_ = value + obj._name_ = f"{cls.__name__}[{value}]" + return obj diff --git a/examples/run_avrcp_controller.py b/examples/run_avrcp_controller.py new file mode 100644 index 00000000..49c85e48 --- /dev/null +++ b/examples/run_avrcp_controller.py @@ -0,0 +1,255 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import json +import sys +import os +import logging +import websockets + +from bumble.device import Device +from bumble.transport import open_transport_or_link +from bumble.core import BT_BR_EDR_TRANSPORT +from bumble import avc +from bumble import avrcp +from bumble import avdtp +from bumble import a2dp +from bumble import utils + + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +def sdp_records(): + a2dp_sink_service_record_handle = 0x00010001 + avrcp_controller_service_record_handle = 0x00010002 + avrcp_target_service_record_handle = 0x00010003 + # pylint: disable=line-too-long + return { + a2dp_sink_service_record_handle: a2dp.make_audio_sink_service_sdp_records( + a2dp_sink_service_record_handle + ), + avrcp_controller_service_record_handle: avrcp.make_controller_service_sdp_records( + avrcp_controller_service_record_handle + ), + avrcp_target_service_record_handle: avrcp.make_target_service_sdp_records( + avrcp_controller_service_record_handle + ), + } + + +# ----------------------------------------------------------------------------- +def codec_capabilities(): + return avdtp.MediaCodecCapabilities( + media_type=avdtp.AVDTP_AUDIO_MEDIA_TYPE, + media_codec_type=a2dp.A2DP_SBC_CODEC_TYPE, + media_codec_information=a2dp.SbcMediaCodecInformation.from_lists( + sampling_frequencies=[48000, 44100, 32000, 16000], + channel_modes=[ + a2dp.SBC_MONO_CHANNEL_MODE, + a2dp.SBC_DUAL_CHANNEL_MODE, + a2dp.SBC_STEREO_CHANNEL_MODE, + a2dp.SBC_JOINT_STEREO_CHANNEL_MODE, + ], + block_lengths=[4, 8, 12, 16], + subbands=[4, 8], + allocation_methods=[ + a2dp.SBC_LOUDNESS_ALLOCATION_METHOD, + a2dp.SBC_SNR_ALLOCATION_METHOD, + ], + minimum_bitpool_value=2, + maximum_bitpool_value=53, + ), + ) + + +# ----------------------------------------------------------------------------- +def on_avdtp_connection(server): + # Add a sink endpoint to the server + sink = server.add_sink(codec_capabilities()) + sink.on('rtp_packet', on_rtp_packet) + + +# ----------------------------------------------------------------------------- +def on_rtp_packet(packet): + print(f'RTP: {packet}') + + +# ----------------------------------------------------------------------------- +def on_avrcp_start(avrcp_protocol): + async def get_events_supported(): + response = await avrcp_protocol.send_avrcp_command( + avc.CommandFrame.CommandType.STATUS, + avrcp.GetCapabilitiesCommand( + avrcp.GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED + ), + ) + print("EVENTS SUPPORTED:", response) + + utils.AsyncRunner.spawn(get_events_supported()) + + async def get_volume(): + await asyncio.sleep(5) + while True: + response = await avrcp_protocol.send_avrcp_command( + avc.CommandFrame.CommandType.NOTIFY, + avrcp.RegisterNotificationCommand(avrcp.EventId.VOLUME_CHANGED, 0), + ) + print("VOLUME CHANGE:", response) + if isinstance(response, avrcp.Protocol.InterimResponse): + print("INTERIM!") + response = await response.final + print("FINAL:", response) + + + utils.AsyncRunner.spawn(get_volume()) + + async def play(): + await asyncio.sleep(10) + response = await avrcp_protocol.send_passthrough_command( + avc.PassThroughCommandFrame( + avc.CommandFrame.CommandType.CONTROL, + avc.Frame.SubunitType.PANEL, + 0, + avc.PassThroughFrame.StateFlag.PRESSED, + avc.PassThroughFrame.OperationId.PLAY, + b'', + ) + ) + + print("PLAY:", response) + #utils.AsyncRunner.spawn(play()) + + # avrcp_protocol.send_command_pdu( + # avc.CommandFrame.CommandType.STATUS, + # avrcp.GetCapabilitiesCommand(avrcp.GetCapabilitiesCommand.CapabilityId.COMPANY_ID), + # ) + # self.send_command_pdu( + # avc.CommandFrame.CommandType.STATUS, + # GetCapabilitiesCommand(GetCapabilitiesCommand.CapabilityId(123)), + # ) + + # for event_id in ( + # avrcp.EventId.TRACK_REACHED_END, + # avrcp.EventId.SYSTEM_STATUS_CHANGED, + # avrcp.EventId.PLAYBACK_STATUS_CHANGED, + # avrcp.EventId.TRACK_CHANGED, + # avrcp.EventId.PLAYER_APPLICATION_SETTING_CHANGED, + # avrcp.EventId.NOW_PLAYING_CONTENT_CHANGED, + # avrcp.EventId.AVAILABLE_PLAYERS_CHANGED, + # avrcp.EventId(111), + # ): + # avrcp_protocol.send_command_pdu( + # avc.CommandFrame.CommandType.NOTIFY, + # avrcp.RegisterNotificationCommand(event_id, 0), + # ) + + +# ----------------------------------------------------------------------------- +async def serve_websocket_controller(avrcp_protocol): + # Start a Websocket server to receive events from a web page + async def serve(websocket, _path): + while True: + try: + message = await websocket.recv() + print('Received: ', str(message)) + + parsed = json.loads(message) + message_type = parsed['type'] + if message_type == 'keydown': + pass + elif message_type == 'keyup': + pass + except websockets.exceptions.ConnectionClosedOK: + pass + + # pylint: disable-next=no-member + await websockets.serve(serve, 'localhost', 8989) + await asyncio.get_event_loop().create_future() + + +# ----------------------------------------------------------------------------- +async def main(): + if len(sys.argv) < 3: + print( + 'Usage: run_avrcp_controller.py ' + ' []' + ) + print('example: run_avrcp_controller.py classic1.json usb:0') + return + + print('<<< connecting to HCI...') + async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink): + print('<<< connected') + + # Create a device + device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink) + device.classic_enabled = True + + # Setup the SDP to expose the sink service + device.sdp_service_records = sdp_records() + + # Start the controller + await device.power_on() + + # Create a listener to wait for AVDTP connections + listener = avdtp.Listener(avdtp.Listener.create_registrar(device)) + listener.on('connection', on_avdtp_connection) + + avrcp_protocol = avrcp.Protocol() + # avrcp_protocol.listen(device) + avrcp_protocol.on("start", lambda: on_avrcp_start(avrcp_protocol)) + + if len(sys.argv) >= 5: + # Connect to the source + target_address = sys.argv[4] + print(f'=== Connecting to {target_address}...') + connection = await device.connect( + target_address, transport=BT_BR_EDR_TRANSPORT + ) + print(f'=== Connected to {connection.peer_address}!') + + # Request authentication + print('*** Authenticating...') + await connection.authenticate() + print('*** Authenticated') + + # Enable encryption + print('*** Enabling encryption...') + await connection.encrypt() + print('*** Encryption on') + + server = await avdtp.Protocol.connect(connection) + listener.set_server(connection, server) + sink = server.add_sink(codec_capabilities()) + sink.on('rtp_packet', on_rtp_packet) + + await avrcp_protocol.connect(connection) + + else: + # Start being discoverable and connectable + await device.set_discoverable(True) + await device.set_connectable(True) + + await serve_websocket_controller(avrcp_protocol) + + +# ----------------------------------------------------------------------------- +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +asyncio.run(main()) diff --git a/examples/run_avrcp_target.py b/examples/run_avrcp_target.py new file mode 100644 index 00000000..45dfb664 --- /dev/null +++ b/examples/run_avrcp_target.py @@ -0,0 +1,221 @@ +# Copyright 2021-2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import asyncio +import sys +import os +import logging + +from bumble.colors import color +from bumble.device import Device +from bumble.transport import open_transport_or_link +from bumble.core import BT_BR_EDR_TRANSPORT +from bumble.avdtp import ( + find_avdtp_service_with_connection, + AVDTP_AUDIO_MEDIA_TYPE, + MediaCodecCapabilities, + MediaPacketPump, + Protocol, + Listener, +) +from bumble import avc +from bumble import avctp +from bumble import avrcp +from bumble.a2dp import ( + SBC_JOINT_STEREO_CHANNEL_MODE, + SBC_LOUDNESS_ALLOCATION_METHOD, + make_audio_source_service_sdp_records, + A2DP_SBC_CODEC_TYPE, + SbcMediaCodecInformation, + SbcPacketSource, +) +from bumble import l2cap + + +logger = logging.getLogger(__name__) + + +# ----------------------------------------------------------------------------- +def sdp_records(): + service_record_handle = 0x00010001 + return { + service_record_handle: make_audio_source_service_sdp_records( + service_record_handle + ) + } + + +# ----------------------------------------------------------------------------- +def codec_capabilities(): + # NOTE: this shouldn't be hardcoded, but should be inferred from the input file + # instead + return MediaCodecCapabilities( + media_type=AVDTP_AUDIO_MEDIA_TYPE, + media_codec_type=A2DP_SBC_CODEC_TYPE, + media_codec_information=SbcMediaCodecInformation.from_discrete_values( + sampling_frequency=44100, + channel_mode=SBC_JOINT_STEREO_CHANNEL_MODE, + block_length=16, + subbands=8, + allocation_method=SBC_LOUDNESS_ALLOCATION_METHOD, + minimum_bitpool_value=2, + maximum_bitpool_value=53, + ), + ) + + +# ----------------------------------------------------------------------------- +def on_avdtp_connection(): + logger.debug("$$$ AVDTP Connection") + + +# ----------------------------------------------------------------------------- +async def stream_packets(read_function, protocol): + # Discover all endpoints on the remote device + endpoints = await protocol.discover_remote_endpoints() + for endpoint in endpoints: + print('@@@', endpoint) + + # Select a sink + sink = protocol.find_remote_sink_by_codec( + AVDTP_AUDIO_MEDIA_TYPE, A2DP_SBC_CODEC_TYPE + ) + if sink is None: + print(color('!!! no SBC sink found', 'red')) + return + print(f'### Selected sink: {sink.seid}') + + # Stream the packets + packet_source = SbcPacketSource( + read_function, protocol.l2cap_channel.mtu, codec_capabilities() + ) + packet_pump = MediaPacketPump(packet_source.packets) + source = protocol.add_source(packet_source.codec_capabilities, packet_pump) + stream = await protocol.create_stream(source, sink) + await stream.start() + await asyncio.sleep(5) + await stream.stop() + await asyncio.sleep(5) + await stream.start() + await asyncio.sleep(5) + await stream.stop() + await stream.close() + + +# ----------------------------------------------------------------------------- +def on_avctp_connection(l2cap_channel: l2cap.Channel) -> None: + logger.debug(f'+++ new L2CAP connection: {l2cap_channel}') + l2cap_channel.on('open', lambda: on_avctp_l2cap_channel_open(l2cap_channel)) + + +def on_avctp_l2cap_channel_open(l2cap_channel: l2cap.Channel) -> None: + logger.debug(f'$$$ L2CAP channel open: {l2cap_channel}') + l2cap_channel.sink = on_avctp_pdu + + +def on_avctp_pdu(pdu): + logger.debug(f'AVCTP PDU: {pdu.hex()}') + + +def on_avctp_channel_open(): + logger.debug('### AVCTP channel open') + + +def on_avctp_channel_close(): + logger.debug('&&& AVCTP channel close') + + +def on_avctp_command(transaction_label, command): + print(f"<<< AVCTP Command, transaction_label={transaction_label}: {command.hex()}") + frame = avc.Frame.from_bytes(command) + print(frame) + + +# ----------------------------------------------------------------------------- +async def main(): + if len(sys.argv) < 3: + print( + 'Usage: run_avrcp_target.py ' + '[]' + ) + print('example: run_avrcp_target.py classic1.json usb:0 E1:CA:72:48:C4:E8') + return + + print('<<< connecting to HCI...') + async with await open_transport_or_link(sys.argv[2]) as (hci_source, hci_sink): + print('<<< connected') + + # Create a device + device = Device.from_config_file_with_hci(sys.argv[1], hci_source, hci_sink) + device.classic_enabled = True + + # Setup the SDP to expose the SRC service + device.sdp_service_records = sdp_records() + + # Start + await device.power_on() + + device.register_l2cap_server(avctp.AVCTP_PSM, on_avctp_connection) + + if len(sys.argv) > 3: + # Connect to a peer + target_address = sys.argv[3] + print(f'=== Connecting to {target_address}...') + connection = await device.connect( + target_address, transport=BT_BR_EDR_TRANSPORT + ) + print(f'=== Connected to {connection.peer_address}!') + + # Request authentication + print('*** Authenticating...') + await connection.authenticate() + print('*** Authenticated') + + # Enable encryption + print('*** Enabling encryption...') + await connection.encrypt() + print('*** Encryption on') + + # Create a client to interact with the remote device + await Protocol.connect(connection, (1, 2)) + + print("------------------ connecting to AVCTP -----------") + connector = connection.create_l2cap_connector(avctp.AVCTP_PSM) + avctp_channel = await connector() + print("++++++ connected") + avctp_channel.sink = on_avctp_pdu + avctp_channel.on('open', on_avctp_channel_open) + avctp_channel.on('close', on_avctp_channel_close) + + avrcp_protocol = avrcp.Protocol() + avctp_protocol = avctp.Protocol(avctp_channel) + avctp_protocol.register_command_handler(avrcp.AVRCP_PID, on_avctp_command) + else: + # Create a listener to wait for AVDTP connections + listener = Listener(Listener.create_registrar(device), version=(1, 2)) + listener.on('connection', lambda protocol: on_avdtp_connection()) + + # Become connectable and wait for a connection + await device.set_discoverable(True) + await device.set_connectable(True) + + await hci_source.wait_for_termination() + + +# ----------------------------------------------------------------------------- +logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'DEBUG').upper()) +asyncio.run(main()) diff --git a/tests/avrcp_test.py b/tests/avrcp_test.py new file mode 100644 index 00000000..beaaa259 --- /dev/null +++ b/tests/avrcp_test.py @@ -0,0 +1,173 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +import pytest +import struct + +from bumble import avc +from bumble import avrcp +from bumble import avctp + + +# ----------------------------------------------------------------------------- +def test_frame_parser(): + with pytest.raises(ValueError) as error: + avc.Frame.from_bytes(bytes.fromhex("11480000")) + + x = bytes.fromhex("014D0208") + frame = avc.Frame.from_bytes(x) + assert frame.subunit_type == avc.Frame.SubunitType.PANEL + assert frame.subunit_id == 7 + assert frame.opcode == 8 + + x = bytes.fromhex("014DFF0108") + frame = avc.Frame.from_bytes(x) + assert frame.subunit_type == avc.Frame.SubunitType.PANEL + assert frame.subunit_id == 260 + assert frame.opcode == 8 + + x = bytes.fromhex("0148000019581000000103") + + frame = avc.Frame.from_bytes(x) + + assert isinstance(frame, avc.CommandFrame) + assert frame.ctype == avc.CommandFrame.CommandType.STATUS + assert frame.subunit_type == avc.Frame.SubunitType.PANEL + assert frame.subunit_id == 0 + assert frame.opcode == 0 + + +# ----------------------------------------------------------------------------- +def test_vendor_dependent_command(): + x = bytes.fromhex("0148000019581000000103") + frame = avc.Frame.from_bytes(x) + assert isinstance(frame, avc.VendorDependentCommandFrame) + assert frame.company_id == 0x1958 + assert frame.vendor_dependent_data == bytes.fromhex("1000000103") + + frame = avc.VendorDependentCommandFrame( + avc.CommandFrame.CommandType.STATUS, + avc.Frame.SubunitType.PANEL, + 0, + 0x1958, + bytes.fromhex("1000000103"), + ) + assert bytes(frame) == x + + +# ----------------------------------------------------------------------------- +def test_avctp_message_assembler(): + received_message = [] + + def on_message(transaction_label, is_response, ipid, pid, payload): + received_message.append((transaction_label, is_response, ipid, pid, payload)) + + assembler = avctp.MessageAssembler(on_message) + + payload = bytes.fromhex("01") + assembler.on_pdu(bytes([1 << 4 | 0b00 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload) + assert received_message + assert received_message[0] == (1, False, False, 0x1122, payload) + + received_message = [] + payload = bytes.fromhex("010203") + assembler.on_pdu(bytes([1 << 4 | 0b01 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload) + assert len(received_message) == 0 + assembler.on_pdu(bytes([1 << 4 | 0b00 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload) + assert received_message + assert received_message[0] == (1, False, False, 0x1122, payload) + + received_message = [] + payload = bytes.fromhex("010203") + assembler.on_pdu( + bytes([1 << 4 | 0b01 << 2 | 1 << 1 | 0, 3, 0x11, 0x22]) + payload[0:1] + ) + assembler.on_pdu( + bytes([1 << 4 | 0b10 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload[1:2] + ) + assembler.on_pdu( + bytes([1 << 4 | 0b11 << 2 | 1 << 1 | 0, 0x11, 0x22]) + payload[2:3] + ) + assert received_message + assert received_message[0] == (1, False, False, 0x1122, payload) + + # received_message = [] + # parameter = bytes.fromhex("010203") + # assembler.on_pdu(struct.pack(">BBH", 0x10, 0b11, len(parameter)) + parameter) + # assert len(received_message) == 0 + + +# ----------------------------------------------------------------------------- +def test_avrcp_pdu_assembler(): + received_pdus = [] + + def on_pdu(pdu_id, parameter): + received_pdus.append((pdu_id, parameter)) + + assembler = avrcp.PduAssembler(on_pdu) + + parameter = bytes.fromhex("01") + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b00, len(parameter)) + parameter) + assert received_pdus + assert received_pdus[0] == (0x10, parameter) + + received_pdus = [] + parameter = bytes.fromhex("010203") + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b01, len(parameter)) + parameter) + assert len(received_pdus) == 0 + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b00, len(parameter)) + parameter) + assert received_pdus + assert received_pdus[0] == (0x10, parameter) + + received_pdus = [] + parameter = bytes.fromhex("010203") + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b01, 1) + parameter[0:1]) + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b10, 1) + parameter[1:2]) + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b11, 1) + parameter[2:3]) + assert received_pdus + assert received_pdus[0] == (0x10, parameter) + + received_pdus = [] + parameter = bytes.fromhex("010203") + assembler.on_pdu(struct.pack(">BBH", 0x10, 0b11, len(parameter)) + parameter) + assert len(received_pdus) == 0 + + +def test_passthrough_commands(): + play_pressed = avc.PassThroughCommandFrame( + avc.CommandFrame.CommandType.CONTROL, + avc.CommandFrame.SubunitType.PANEL, + 0, + avc.PassThroughCommandFrame.StateFlag.PRESSED, + avc.PassThroughCommandFrame.OperationId.PLAY, + b'', + ) + + play_pressed_bytes = bytes(play_pressed) + parsed = avc.Frame.from_bytes(play_pressed_bytes) + assert isinstance(parsed, avc.PassThroughCommandFrame) + assert parsed.operation_id == avc.PassThroughCommandFrame.OperationId.PLAY + assert bytes(parsed) == play_pressed_bytes + + +# ----------------------------------------------------------------------------- +if __name__ == '__main__': + test_frame_parser() + test_vendor_dependent_command() + test_avctp_message_assembler() + test_avrcp_pdu_assembler() + test_passthrough_commands() diff --git a/tests/utils_test.py b/tests/utils_test.py index d6f57807..6266f9ef 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -12,15 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- import contextlib import logging import os +from unittest.mock import MagicMock -from bumble import utils from pyee import EventEmitter -from unittest.mock import MagicMock + +from bumble import utils +# ----------------------------------------------------------------------------- def test_on() -> None: emitter = EventEmitter() with contextlib.closing(utils.EventWatcher()) as context: @@ -33,6 +38,7 @@ def test_on() -> None: assert mock.call_count == 1 +# ----------------------------------------------------------------------------- def test_on_decorator() -> None: emitter = EventEmitter() with contextlib.closing(utils.EventWatcher()) as context: @@ -48,6 +54,7 @@ def on_event(*_) -> None: assert mock.call_count == 1 +# ----------------------------------------------------------------------------- def test_multiple_handlers() -> None: emitter = EventEmitter() with contextlib.closing(utils.EventWatcher()) as context: @@ -64,6 +71,30 @@ def test_multiple_handlers() -> None: mock.assert_called_once_with('b') +# ----------------------------------------------------------------------------- +def test_open_int_enums(): + class Foo(utils.OpenIntEnum): + FOO = 1 + BAR = 2 + BLA = 3 + + x = Foo(1) + assert x.name == "FOO" + assert x.value == 1 + assert int(x) == 1 + assert x == 1 + assert x + 1 == 2 + + x = Foo(4) + assert x.name == "Foo[4]" + assert x.value == 4 + assert int(x) == 4 + assert x == 4 + assert x + 1 == 5 + + print(list(Foo)) + + # ----------------------------------------------------------------------------- def run_tests(): test_on() @@ -75,3 +106,4 @@ def run_tests(): if __name__ == '__main__': logging.basicConfig(level=os.environ.get('BUMBLE_LOGLEVEL', 'INFO').upper()) run_tests() + test_open_int_enums()