From e1fdb126476555173329eac5d62131d74e7ff950 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Fri, 17 Nov 2023 17:29:35 +0800 Subject: [PATCH] Typing A2DP --- bumble/a2dp.py | 151 +++++++++++++++++++++++++++---------------------- 1 file changed, 83 insertions(+), 68 deletions(-) diff --git a/bumble/a2dp.py b/bumble/a2dp.py index eeecb1ee..6d8fc478 100644 --- a/bumble/a2dp.py +++ b/bumble/a2dp.py @@ -15,9 +15,13 @@ # ----------------------------------------------------------------------------- # Imports # ----------------------------------------------------------------------------- +from __future__ import annotations + +import dataclasses import struct import logging -from collections import namedtuple +from collections.abc import AsyncGenerator +from typing import List, Callable, Awaitable from .company_ids import COMPANY_IDENTIFIERS from .sdp import ( @@ -239,24 +243,20 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)): # ----------------------------------------------------------------------------- -class SbcMediaCodecInformation( - namedtuple( - 'SbcMediaCodecInformation', - [ - 'sampling_frequency', - 'channel_mode', - 'block_length', - 'subbands', - 'allocation_method', - 'minimum_bitpool_value', - 'maximum_bitpool_value', - ], - ) -): +@dataclasses.dataclass +class SbcMediaCodecInformation: ''' A2DP spec - 4.3.2 Codec Specific Information Elements ''' + sampling_frequency: int + channel_mode: int + block_length: int + subbands: int + allocation_method: int + minimum_bitpool_value: int + maximum_bitpool_value: int + SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1} CHANNEL_MODE_BITS = { SBC_MONO_CHANNEL_MODE: 1 << 3, @@ -272,7 +272,7 @@ class SbcMediaCodecInformation( } @staticmethod - def from_bytes(data: bytes) -> 'SbcMediaCodecInformation': + def from_bytes(data: bytes) -> SbcMediaCodecInformation: sampling_frequency = (data[0] >> 4) & 0x0F channel_mode = (data[0] >> 0) & 0x0F block_length = (data[1] >> 4) & 0x0F @@ -293,14 +293,14 @@ def from_bytes(data: bytes) -> 'SbcMediaCodecInformation': @classmethod def from_discrete_values( cls, - sampling_frequency, - channel_mode, - block_length, - subbands, - allocation_method, - minimum_bitpool_value, - maximum_bitpool_value, - ): + sampling_frequency: int, + channel_mode: int, + block_length: int, + subbands: int, + allocation_method: int, + minimum_bitpool_value: int, + maximum_bitpool_value: int, + ) -> SbcMediaCodecInformation: return SbcMediaCodecInformation( sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], channel_mode=cls.CHANNEL_MODE_BITS[channel_mode], @@ -314,14 +314,14 @@ def from_discrete_values( @classmethod def from_lists( cls, - sampling_frequencies, - channel_modes, - block_lengths, - subbands, - allocation_methods, - minimum_bitpool_value, - maximum_bitpool_value, - ): + sampling_frequencies: List[int], + channel_modes: List[int], + block_lengths: List[int], + subbands: List[int], + allocation_methods: List[int], + minimum_bitpool_value: int, + maximum_bitpool_value: int, + ) -> SbcMediaCodecInformation: return SbcMediaCodecInformation( sampling_frequency=sum( cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies @@ -348,7 +348,7 @@ def __bytes__(self) -> bytes: ] ) - def __str__(self): + def __str__(self) -> str: channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO'] allocation_methods = ['SNR', 'Loudness'] return '\n'.join( @@ -367,16 +367,19 @@ def __str__(self): # ----------------------------------------------------------------------------- -class AacMediaCodecInformation( - namedtuple( - 'AacMediaCodecInformation', - ['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'], - ) -): +@dataclasses.dataclass +class AacMediaCodecInformation: ''' A2DP spec - 4.5.2 Codec Specific Information Elements ''' + object_type: int + sampling_frequency: int + channels: int + rfa: int + vbr: int + bitrate: int + OBJECT_TYPE_BITS = { MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7, MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6, @@ -400,7 +403,7 @@ class AacMediaCodecInformation( CHANNELS_BITS = {1: 1 << 1, 2: 1} @staticmethod - def from_bytes(data: bytes) -> 'AacMediaCodecInformation': + def from_bytes(data: bytes) -> AacMediaCodecInformation: object_type = data[0] sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F) channels = (data[2] >> 2) & 0x03 @@ -413,8 +416,13 @@ def from_bytes(data: bytes) -> 'AacMediaCodecInformation': @classmethod def from_discrete_values( - cls, object_type, sampling_frequency, channels, vbr, bitrate - ): + cls, + object_type: int, + sampling_frequency: int, + channels: int, + vbr: int, + bitrate: int, + ) -> AacMediaCodecInformation: return AacMediaCodecInformation( object_type=cls.OBJECT_TYPE_BITS[object_type], sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency], @@ -425,7 +433,14 @@ def from_discrete_values( ) @classmethod - def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate): + def from_lists( + cls, + object_types: List[int], + sampling_frequencies: List[int], + channels: List[int], + vbr: int, + bitrate: int, + ) -> AacMediaCodecInformation: return AacMediaCodecInformation( object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types), sampling_frequency=sum( @@ -449,7 +464,7 @@ def __bytes__(self) -> bytes: ] ) - def __str__(self): + def __str__(self) -> str: object_types = [ 'MPEG_2_AAC_LC', 'MPEG_4_AAC_LC', @@ -474,26 +489,26 @@ def __str__(self): ) +@dataclasses.dataclass # ----------------------------------------------------------------------------- class VendorSpecificMediaCodecInformation: ''' A2DP spec - 4.7.2 Codec Specific Information Elements ''' + vendor_id: int + codec_id: int + value: bytes + @staticmethod - def from_bytes(data): + def from_bytes(data: bytes) -> VendorSpecificMediaCodecInformation: (vendor_id, codec_id) = struct.unpack_from(' bytes: return struct.pack(' str: # pylint: disable=line-too-long return '\n'.join( [ @@ -506,29 +521,27 @@ def __str__(self): # ----------------------------------------------------------------------------- +@dataclasses.dataclass class SbcFrame: - def __init__( - self, sampling_frequency, block_count, channel_mode, subband_count, payload - ): - self.sampling_frequency = sampling_frequency - self.block_count = block_count - self.channel_mode = channel_mode - self.subband_count = subband_count - self.payload = payload + sampling_frequency: int + block_count: int + channel_mode: int + subband_count: int + payload: bytes @property - def sample_count(self): + def sample_count(self) -> int: return self.subband_count * self.block_count @property - def bitrate(self): + def bitrate(self) -> int: return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count) @property - def duration(self): + def duration(self) -> float: return self.sample_count / self.sampling_frequency - def __str__(self): + def __str__(self) -> str: return ( f'SBC(sf={self.sampling_frequency},' f'cm={self.channel_mode},' @@ -540,12 +553,12 @@ def __str__(self): # ----------------------------------------------------------------------------- class SbcParser: - def __init__(self, read): + def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None: self.read = read @property - def frames(self): - async def generate_frames(): + def frames(self) -> AsyncGenerator[SbcFrame, None]: + async def generate_frames() -> AsyncGenerator[SbcFrame, None]: while True: # Read 4 bytes of header header = await self.read(4) @@ -589,7 +602,9 @@ async def generate_frames(): # ----------------------------------------------------------------------------- class SbcPacketSource: - def __init__(self, read, mtu, codec_capabilities): + def __init__( + self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities + ) -> None: self.read = read self.mtu = mtu self.codec_capabilities = codec_capabilities