Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Typing A2DP #335

Merged
merged 1 commit into from
Nov 17, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 83 additions & 68 deletions bumble/a2dp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
# -----------------------------------------------------------------------------
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import dataclasses
import struct
import logging
from collections import namedtuple
from collections.abc import AsyncGenerator
from typing import List, Callable, Awaitable

from .company_ids import COMPANY_IDENTIFIERS
from .sdp import (
Expand Down Expand Up @@ -239,24 +243,20 @@ def make_audio_sink_service_sdp_records(service_record_handle, version=(1, 3)):


# -----------------------------------------------------------------------------
class SbcMediaCodecInformation(
namedtuple(
'SbcMediaCodecInformation',
[
'sampling_frequency',
'channel_mode',
'block_length',
'subbands',
'allocation_method',
'minimum_bitpool_value',
'maximum_bitpool_value',
],
)
):
@dataclasses.dataclass
class SbcMediaCodecInformation:
'''
A2DP spec - 4.3.2 Codec Specific Information Elements
'''

sampling_frequency: int
channel_mode: int
block_length: int
subbands: int
allocation_method: int
minimum_bitpool_value: int
maximum_bitpool_value: int

SAMPLING_FREQUENCY_BITS = {16000: 1 << 3, 32000: 1 << 2, 44100: 1 << 1, 48000: 1}
CHANNEL_MODE_BITS = {
SBC_MONO_CHANNEL_MODE: 1 << 3,
Expand All @@ -272,7 +272,7 @@ class SbcMediaCodecInformation(
}

@staticmethod
def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
def from_bytes(data: bytes) -> SbcMediaCodecInformation:
sampling_frequency = (data[0] >> 4) & 0x0F
channel_mode = (data[0] >> 0) & 0x0F
block_length = (data[1] >> 4) & 0x0F
Expand All @@ -293,14 +293,14 @@ def from_bytes(data: bytes) -> 'SbcMediaCodecInformation':
@classmethod
def from_discrete_values(
cls,
sampling_frequency,
channel_mode,
block_length,
subbands,
allocation_method,
minimum_bitpool_value,
maximum_bitpool_value,
):
sampling_frequency: int,
channel_mode: int,
block_length: int,
subbands: int,
allocation_method: int,
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
channel_mode=cls.CHANNEL_MODE_BITS[channel_mode],
Expand All @@ -314,14 +314,14 @@ def from_discrete_values(
@classmethod
def from_lists(
cls,
sampling_frequencies,
channel_modes,
block_lengths,
subbands,
allocation_methods,
minimum_bitpool_value,
maximum_bitpool_value,
):
sampling_frequencies: List[int],
channel_modes: List[int],
block_lengths: List[int],
subbands: List[int],
allocation_methods: List[int],
minimum_bitpool_value: int,
maximum_bitpool_value: int,
) -> SbcMediaCodecInformation:
return SbcMediaCodecInformation(
sampling_frequency=sum(
cls.SAMPLING_FREQUENCY_BITS[x] for x in sampling_frequencies
Expand All @@ -348,7 +348,7 @@ def __bytes__(self) -> bytes:
]
)

def __str__(self):
def __str__(self) -> str:
channel_modes = ['MONO', 'DUAL_CHANNEL', 'STEREO', 'JOINT_STEREO']
allocation_methods = ['SNR', 'Loudness']
return '\n'.join(
Expand All @@ -367,16 +367,19 @@ def __str__(self):


# -----------------------------------------------------------------------------
class AacMediaCodecInformation(
namedtuple(
'AacMediaCodecInformation',
['object_type', 'sampling_frequency', 'channels', 'rfa', 'vbr', 'bitrate'],
)
):
@dataclasses.dataclass
class AacMediaCodecInformation:
'''
A2DP spec - 4.5.2 Codec Specific Information Elements
'''

object_type: int
sampling_frequency: int
channels: int
rfa: int
vbr: int
bitrate: int

OBJECT_TYPE_BITS = {
MPEG_2_AAC_LC_OBJECT_TYPE: 1 << 7,
MPEG_4_AAC_LC_OBJECT_TYPE: 1 << 6,
Expand All @@ -400,7 +403,7 @@ class AacMediaCodecInformation(
CHANNELS_BITS = {1: 1 << 1, 2: 1}

@staticmethod
def from_bytes(data: bytes) -> 'AacMediaCodecInformation':
def from_bytes(data: bytes) -> AacMediaCodecInformation:
object_type = data[0]
sampling_frequency = (data[1] << 4) | ((data[2] >> 4) & 0x0F)
channels = (data[2] >> 2) & 0x03
Expand All @@ -413,8 +416,13 @@ def from_bytes(data: bytes) -> 'AacMediaCodecInformation':

@classmethod
def from_discrete_values(
cls, object_type, sampling_frequency, channels, vbr, bitrate
):
cls,
object_type: int,
sampling_frequency: int,
channels: int,
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=cls.OBJECT_TYPE_BITS[object_type],
sampling_frequency=cls.SAMPLING_FREQUENCY_BITS[sampling_frequency],
Expand All @@ -425,7 +433,14 @@ def from_discrete_values(
)

@classmethod
def from_lists(cls, object_types, sampling_frequencies, channels, vbr, bitrate):
def from_lists(
cls,
object_types: List[int],
sampling_frequencies: List[int],
channels: List[int],
vbr: int,
bitrate: int,
) -> AacMediaCodecInformation:
return AacMediaCodecInformation(
object_type=sum(cls.OBJECT_TYPE_BITS[x] for x in object_types),
sampling_frequency=sum(
Expand All @@ -449,7 +464,7 @@ def __bytes__(self) -> bytes:
]
)

def __str__(self):
def __str__(self) -> str:
object_types = [
'MPEG_2_AAC_LC',
'MPEG_4_AAC_LC',
Expand All @@ -474,26 +489,26 @@ def __str__(self):
)


@dataclasses.dataclass
# -----------------------------------------------------------------------------
class VendorSpecificMediaCodecInformation:
'''
A2DP spec - 4.7.2 Codec Specific Information Elements
'''

vendor_id: int
codec_id: int
value: bytes

@staticmethod
def from_bytes(data):
def from_bytes(data: bytes) -> VendorSpecificMediaCodecInformation:
(vendor_id, codec_id) = struct.unpack_from('<IH', data, 0)
return VendorSpecificMediaCodecInformation(vendor_id, codec_id, data[6:])

def __init__(self, vendor_id, codec_id, value):
self.vendor_id = vendor_id
self.codec_id = codec_id
self.value = value

def __bytes__(self):
def __bytes__(self) -> bytes:
return struct.pack('<IH', self.vendor_id, self.codec_id, self.value)

def __str__(self):
def __str__(self) -> str:
# pylint: disable=line-too-long
return '\n'.join(
[
Expand All @@ -506,29 +521,27 @@ def __str__(self):


# -----------------------------------------------------------------------------
@dataclasses.dataclass
class SbcFrame:
def __init__(
self, sampling_frequency, block_count, channel_mode, subband_count, payload
):
self.sampling_frequency = sampling_frequency
self.block_count = block_count
self.channel_mode = channel_mode
self.subband_count = subband_count
self.payload = payload
sampling_frequency: int
block_count: int
channel_mode: int
subband_count: int
payload: bytes

@property
def sample_count(self):
def sample_count(self) -> int:
return self.subband_count * self.block_count

@property
def bitrate(self):
def bitrate(self) -> int:
return 8 * ((len(self.payload) * self.sampling_frequency) // self.sample_count)

@property
def duration(self):
def duration(self) -> float:
return self.sample_count / self.sampling_frequency

def __str__(self):
def __str__(self) -> str:
return (
f'SBC(sf={self.sampling_frequency},'
f'cm={self.channel_mode},'
Expand All @@ -540,12 +553,12 @@ def __str__(self):

# -----------------------------------------------------------------------------
class SbcParser:
def __init__(self, read):
def __init__(self, read: Callable[[int], Awaitable[bytes]]) -> None:
self.read = read

@property
def frames(self):
async def generate_frames():
def frames(self) -> AsyncGenerator[SbcFrame, None]:
async def generate_frames() -> AsyncGenerator[SbcFrame, None]:
while True:
# Read 4 bytes of header
header = await self.read(4)
Expand Down Expand Up @@ -589,7 +602,9 @@ async def generate_frames():

# -----------------------------------------------------------------------------
class SbcPacketSource:
def __init__(self, read, mtu, codec_capabilities):
def __init__(
self, read: Callable[[int], Awaitable[bytes]], mtu: int, codec_capabilities
) -> None:
self.read = read
self.mtu = mtu
self.codec_capabilities = codec_capabilities
Expand Down