Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
barbibulle committed Nov 11, 2023
1 parent 43b4d66 commit d8bfbab
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 58 deletions.
2 changes: 1 addition & 1 deletion bumble/avctp.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class PacketType(IntEnum):
CONTINUE = 0b10
END = 0b11

def __init__(self, l2cap_channel: l2cap.Channel) -> None:
def __init__(self, l2cap_channel: l2cap.ClassicChannel) -> None:
self.command_handlers = {}
self.response_handlers = {}
self.l2cap_channel = l2cap_channel
Expand Down
82 changes: 63 additions & 19 deletions bumble/avrcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import logging
import struct
from typing import (
Any,
Awaitable,
Callable,
ClassVar,
Dict,
List,
Optional,
Expand All @@ -47,7 +47,7 @@
DataElement,
ServiceAttribute,
)
from bumble.utils import OpenIntEnum
from bumble.utils import AsyncRunner, OpenIntEnum
from bumble.core import (
ProtocolError,
BT_L2CAP_PROTOCOL_ID,
Expand Down Expand Up @@ -258,6 +258,7 @@ def on_pdu(self, pdu: bytes) -> None:
self.on_pdu_complete()

def on_pdu_complete(self) -> None:
assert self.pdu_id is not None
try:
self.callback(self.pdu_id, self.parameter)
except Exception as error:
Expand Down Expand Up @@ -394,7 +395,9 @@ def __str__(self) -> str:

# -----------------------------------------------------------------------------
class NotImplementedResponse(Response):
pass
@classmethod
def from_bytes(cls, pdu_id: Protocol.PduId, pdu: bytes) -> NotImplementedResponse:
return cls(pdu_id, pdu[1:])


# -----------------------------------------------------------------------------
Expand All @@ -413,6 +416,7 @@ def from_bytes(cls, pdu: bytes) -> GetCapabilitiesResponse:
capability_id = GetCapabilitiesCommand.CapabilityId(pdu[0])
capability_count = pdu[1]

capabilities: List[SupportsBytes]
if capability_id == GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED:
capabilities = [EventId(pdu[2 + x]) for x in range(capability_count)]
else:
Expand Down Expand Up @@ -480,7 +484,7 @@ def __init__(self, event: Event) -> None:
def __str__(self) -> str:
return self.to_string(
{
"event": self.event,
"event": str(self.event),
}
)

Expand All @@ -506,12 +510,14 @@ def __bytes__(self) -> bytes:


# -----------------------------------------------------------------------------
@dataclass
class Event:
event_id: EventId

@staticmethod
def from_bytes(pdu: bytes) -> Event:
event_id = pdu[0]
subclass: Any
if event_id == EventId.PLAYBACK_STATUS_CHANGED:
subclass = PlaybackStatusChangedEvent
elif event_id == EventId.VOLUME_CHANGED:
Expand All @@ -528,12 +534,11 @@ def __bytes__(self) -> bytes:
# -----------------------------------------------------------------------------
@dataclass
class GenericEvent(Event):
event_id: EventId
data: bytes

@classmethod
def from_bytes(cls, pdu: bytes) -> GenericEvent:
return cls(event_id=pdu[0], data=pdu[1:])
return cls(event_id=EventId(pdu[0]), data=pdu[1:])

def __bytes__(self) -> bytes:
return bytes(self.event_id) + self.data
Expand All @@ -550,27 +555,33 @@ class PlayStatus(OpenIntEnum):
REV_SEEK = 0x04
ERROR = 0xFF

event_id: ClassVar[EventId] = EventId.PLAYBACK_STATUS_CHANGED
play_status: PlayStatus

@classmethod
def from_bytes(cls, pdu: bytes) -> PlaybackStatusChangedEvent:
return cls(play_status=cls.PlayStatus(pdu[1]))

def __init__(self, play_status: PlayStatus) -> None:
super().__init__(EventId.PLAYBACK_STATUS_CHANGED)
self.play_status = play_status

def __bytes__(self) -> bytes:
return bytes(self.event_id) + bytes([self.play_status])


# -----------------------------------------------------------------------------
@dataclass
class VolumeChangedEvent(Event):
event_id: ClassVar[EventId] = EventId.VOLUME_CHANGED
volume: int

@classmethod
def from_bytes(cls, pdu: bytes) -> PlaybackStatusChangedEvent:
def from_bytes(cls, pdu: bytes) -> VolumeChangedEvent:
return cls(volume=pdu[1])

def __init__(self, volume: int) -> None:
super().__init__(EventId.VOLUME_CHANGED)
self.volume = volume

def __bytes__(self) -> bytes:
return bytes(self.event_id) + bytes([self.volume])

Expand Down Expand Up @@ -621,7 +632,7 @@ class NotPendingError(Exception):
"""There is no pending command for a transaction label."""

class PendingCommand:
response: Awaitable
response: asyncio.Future

def __init__(self, transaction_label: int) -> None:
self.transaction_label = transaction_label
Expand Down Expand Up @@ -692,7 +703,9 @@ def listen(self, device: Device) -> None:
device.register_l2cap_server(avctp.AVCTP_PSM, self.on_avctp_connection)

async def connect(self, connection: Connection) -> None:
avctp_channel = await connection.create_l2cap_connector(avctp.AVCTP_PSM)()
avctp_channel = await connection.create_l2cap_channel(
l2cap.ClassicChannelSpec(psm=avctp.AVCTP_PSM)
)
self.on_avctp_channel_open(avctp_channel)

async def obtain_pending_command(self) -> PendingCommand:
Expand All @@ -718,16 +731,16 @@ def recycle_pending_command(self, pending_command: PendingCommand) -> None:

# return await pending_command.response

def on_avctp_connection(self, l2cap_channel: l2cap.Channel) -> None:
def on_avctp_connection(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug("AVCTP connection established")
l2cap_channel.on("open", lambda: self.on_avctp_channel_open(l2cap_channel))

def on_avctp_channel_open(self, l2cap_channel: l2cap.Channel) -> None:
def on_avctp_channel_open(self, l2cap_channel: l2cap.ClassicChannel) -> None:
logger.debug("AVCTP channel open")
if self.avctp_protocol is not None:
# TODO: find a better strategy instead of just closing
logger.warning("AVCTP protocol already active, closing connection")
l2cap_channel.disconnect()
AsyncRunner.spawn(l2cap_channel.disconnect())
return

self.avctp_protocol = avctp.Protocol(l2cap_channel)
Expand Down Expand Up @@ -812,7 +825,7 @@ def on_avctp_response(

# A None response means an invalid PID was used in the request.
if response is None:
pending_command.set_exception(self.InvalidPidError())
pending_command.response.set_exception(self.InvalidPidError())

if isinstance(response, avc.VendorDependentResponseFrame):
if not self.check_vendor_dependent_frame(response):
Expand Down Expand Up @@ -848,6 +861,8 @@ def on_avctp_response(
def on_command_pdu(self, pdu_id: PduId, pdu: bytes) -> None:
logger.debug(f"<<< AVRCP command PDU [pdu_id={pdu_id.name}]: {pdu.hex()}")

assert self.receive_command_state is not None

# Dispatch the command.
# NOTE: with a small number of supported commands, a manual dispatch like this
# is Ok, but if/when more commands are supported, a lookup dispatch mechanism
Expand All @@ -872,16 +887,20 @@ def on_command_pdu(self, pdu_id: PduId, pdu: bytes) -> None:
# Not supported.
# TODO: check that this is the right way to respond in this case.
logger.debug("unsupported PDU ID")
self.send_rejected_response_pdu(self.StatusCode.INVALID_PARAMETER)
self.send_rejected_avrcp_response(
pdu_id, self.StatusCode.INVALID_PARAMETER
)
else:
logger.debug("unsupported command type")
self.send_rejected_response_pdu(self.StatusCode.INVALID_COMMAND)
self.send_rejected_avrcp_response(pdu_id, self.StatusCode.INVALID_COMMAND)

self.receive_command_state = None

def on_response_pdu(self, pdu_id: PduId, pdu: bytes) -> None:
logger.debug(f"<<< AVRCP response PDU [pdu_id={pdu_id.name}]: {pdu.hex()}")

assert self.receive_response_state is not None

transaction_label = self.receive_response_state.transaction_label
response_code = self.receive_response_state.response_code
self.receive_response_state = None
Expand Down Expand Up @@ -959,12 +978,15 @@ def on_response_pdu(self, pdu_id: PduId, pdu: bytes) -> None:
def send_command(self, transaction_label: int, command: avc.CommandFrame) -> None:
logger.debug(f">>> AVRCP command: {command}")

if self.avctp_protocol is None:
logger.warning("trying to send command while avctp_protocol is None")
return

self.avctp_protocol.send_command(transaction_label, AVRCP_PID, bytes(command))

async def send_passthrough_command(
self, command: avc.PassThroughCommandFrame
) -> avc.PassThroughResponseFrame:
# TODO
# Wait for a free command slot.
pending_command = await self.obtain_pending_command()

Expand All @@ -975,6 +997,22 @@ async def send_passthrough_command(
response = await pending_command.response
return response

async def send_key_event(
self, key: avc.PassThroughCommandFrame.OperationId, pressed: bool
) -> avc.PassThroughResponseFrame:
return await self.send_passthrough_command(
avc.PassThroughCommandFrame(
avc.CommandFrame.CommandType.CONTROL,
avc.Frame.SubunitType.PANEL,
0,
avc.PassThroughFrame.StateFlag.PRESSED
if pressed
else avc.PassThroughFrame.StateFlag.RELEASED,
key,
b'',
)
)

async def send_avrcp_command(
self, command_type: avc.CommandFrame.CommandType, command: Command
) -> avc.VendorDependentResponseFrame:
Expand Down Expand Up @@ -1003,6 +1041,7 @@ async def send_avrcp_command(
def send_response(
self, transaction_label: int, response: avc.ResponseFrame
) -> None:
assert self.avctp_protocol is not None
logger.debug(f">>> AVRCP response: {response}")
self.avctp_protocol.send_response(transaction_label, AVRCP_PID, bytes(response))

Expand All @@ -1025,6 +1064,8 @@ def send_passthrough_response(
def send_avrcp_response(
self, response_code: avc.ResponseFrame.ResponseCode, response: Response
) -> None:
assert self.receive_command_state is not None

# TODO: fragmentation
logger.debug(f">>> AVRCP response PDU: {response}")
pdu = (
Expand Down Expand Up @@ -1068,7 +1109,10 @@ def on_get_capabilities_command(self, command: GetCapabilitiesCommand) -> None:
== GetCapabilitiesCommand.CapabilityId.EVENTS_SUPPORTED
):
# TEST: hardcoded values for testing only
supported_events = [EventId.VOLUME_CHANGED, EventId.PLAYBACK_STATUS_CHANGED]
supported_events: List[SupportsBytes] = [
EventId.VOLUME_CHANGED,
EventId.PLAYBACK_STATUS_CHANGED,
]
self.send_avrcp_response(
avc.ResponseFrame.ResponseCode.IMPLEMENTED_OR_STABLE,
GetCapabilitiesResponse(command.capability_id, supported_events),
Expand Down
1 change: 0 additions & 1 deletion bumble/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
HCI_LE_1M_PHY_BIT,
HCI_LE_2M_PHY,
HCI_LE_2M_PHY_LE_SUPPORTED_FEATURE,
HCI_LE_CLEAR_RESOLVING_LIST_COMMAND,
HCI_LE_CODED_PHY,
HCI_LE_CODED_PHY_BIT,
HCI_LE_CODED_PHY_LE_SUPPORTED_FEATURE,
Expand Down
2 changes: 1 addition & 1 deletion bumble/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def deprecated(msg: str):
"""

def wrapper(function):
@wraps(function)
@functools.wraps(function)
def inner(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return function(*args, **kwargs)
Expand Down
Loading

0 comments on commit d8bfbab

Please sign in to comment.