Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Proposal] New Packet ABC #430

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 178 additions & 1 deletion bumble/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,28 @@
# Imports
# -----------------------------------------------------------------------------
from __future__ import annotations

import abc
import dataclasses
import enum
import struct
from typing import List, Optional, Tuple, Union, cast, Dict
from typing import (
Any,
List,
Callable,
Optional,
Tuple,
Union,
ClassVar,
Type,
TypeVar,
Literal,
Dict,
cast,
get_args,
get_origin,
)
from typing_extensions import Self, Annotated

from .company_ids import COMPANY_IDENTIFIERS

Expand Down Expand Up @@ -1066,3 +1085,161 @@ class LeRole(enum.IntEnum):
CENTRAL_ONLY = 0x01
BOTH_PERIPHERAL_PREFERRED = 0x02
BOTH_CENTRAL_PREFERRED = 0x03


# -----------------------------------------------------------------------------
# Data Unit
# -----------------------------------------------------------------------------
@dataclasses.dataclass
class FieldSpec:
format: Union[int, str, Type[List[DataUnit]], None] = None
mapper: Optional[Callable] = None
serializer: Optional[Callable[..., bytes]] = None
deserializer: Optional[Callable[[bytes, int], Tuple[int, Any]]] = None


class DataUnit(abc.ABC):
@classmethod
def init_from_bytes(cls: Type[Self], data: bytes) -> Self:
return cls.parse_from_bytes(data, 0)[0]

@classmethod
def parse_from_bytes(cls: Type[Self], data: bytes, offset: int) -> Tuple[Self, int]:
kwargs = {}
for field, annotation in cls.__annotations__.items():
if hasattr(annotation, '__metadata__'):
kwargs[field], offset = cls.parse_field(
data,
offset,
annotation.__metadata__[0],
)
elif get_origin(annotation) is List and (args := get_args(annotation)):
field_value = []
size = data[offset]
offset += 1
list_type: Type[DataUnit] = args[0]
for _ in range(size):
item, offset = list_type.parse_from_bytes(data, offset)
field_value.append(item)
kwargs[field] = field_value

return (cls(**kwargs), offset)

@classmethod
def parse_field(
cls: Type[Self],
data: bytes,
offset: int,
field_spec: FieldSpec,
) -> Tuple[Any, int]:
# Parse the field
if field_spec.format == '*':
# The rest of the bytes
field_value = data[offset:]
return (field_value, len(field_value))
if field_spec.format == 'v':
# Variable-length bytes field, with 1-byte length at the beginning
field_length = data[offset]
offset += 1
field_value = data[offset : offset + field_length]
return (field_value, field_length + 1)
if field_spec.format == 1:
# 8-bit unsigned
return (data[offset], 1)
if field_spec.format == -1:
# 8-bit signed
return (struct.unpack_from('b', data, offset)[0], 1)
if field_spec.format == 2:
# 16-bit unsigned
return (struct.unpack_from('<H', data, offset)[0], 2)
if field_spec.format == '>2':
# 16-bit unsigned big-endian
return (struct.unpack_from('>H', data, offset)[0], 2)
if field_spec.format == -2:
# 16-bit signed
return (struct.unpack_from('<h', data, offset)[0], 2)
if field_spec.format == 3:
# 24-bit unsigned
padded = data[offset : offset + 3] + bytes([0])
return (struct.unpack('<I', padded)[0], 3)
if field_spec.format == 4:
# 32-bit unsigned
return (struct.unpack_from('<I', data, offset)[0], 4)
if field_spec.format == '>4':
# 32-bit unsigned big-endian
return (struct.unpack_from('>I', data, offset)[0], 4)
if isinstance(field_spec.format, int) and 4 < field_spec.format <= 256:
# Byte array (from 5 up to 256 bytes)
return (data[offset : offset + field_spec.format], field_spec.format)
if field_spec.deserializer:
new_offset, field_value = field_spec.deserializer(data, offset)
return (field_value, new_offset - offset)

raise ValueError(f'Unknown field type {field_spec}')

@classmethod
def serialize_field(
cls: Type[Self], field_value: Any, field_spec: FieldSpec
) -> bytes:
# Serialize the field
if field_spec.serializer:
field_bytes = field_spec.serializer(field_value)
elif field_spec.format == 1:
# 8-bit unsigned
field_bytes = bytes([field_value])
elif field_spec.format == -1:
# 8-bit signed
field_bytes = struct.pack('b', field_value)
elif field_spec.format == 2:
# 16-bit unsigned
field_bytes = struct.pack('<H', field_value)
elif field_spec.format == '>2':
# 16-bit unsigned big-endian
field_bytes = struct.pack('>H', field_value)
elif field_spec.format == -2:
# 16-bit signed
field_bytes = struct.pack('<h', field_value)
elif field_spec.format == 3:
# 24-bit unsigned
field_bytes = struct.pack('<I', field_value)[0:3]
elif field_spec.format == 4:
# 32-bit unsigned
field_bytes = struct.pack('<I', field_value)
elif field_spec.format == '>4':
# 32-bit unsigned big-endian
field_bytes = struct.pack('>I', field_value)
elif field_spec.format == '*':
if isinstance(field_value, int):
if 0 <= field_value <= 255:
field_bytes = bytes([field_value])
else:
raise ValueError('value too large for *-typed field')
else:
field_bytes = bytes(field_value)
elif field_spec.format == 'v':
# Variable-length bytes field, with 1-byte length at the beginning
field_bytes = bytes(field_value)
field_length = len(field_bytes)
field_bytes = bytes([field_length]) + field_bytes
elif isinstance(field_value, (bytes, bytearray)) or hasattr(
field_value, 'to_bytes'
):
field_bytes = bytes(field_value)
if isinstance(field_spec.format, int) and 4 < field_spec.format <= 256:
# Truncate or pad with zeros if the field is too long or too short
if len(field_bytes) < field_spec.format:
field_bytes += bytes(field_spec.format - len(field_bytes))
elif len(field_bytes) > field_spec.format:
field_bytes = field_bytes[: field_spec.format]
elif isinstance(field_value, list):
field_bytes = bytes(field_value)
if isinstance(field_spec.format, int) and 4 < field_spec.format <= 256:
# Truncate or pad with zeros if the field is too long or too short
if len(field_bytes) < field_spec.format:
field_bytes += bytes(field_spec.format - len(field_bytes))
elif len(field_bytes) > field_spec.format:
field_bytes = field_bytes[: field_spec.format]
else:
raise ValueError(f"don't know how to serialize type {type(field_value)}")

return field_bytes
Loading