From a47fbb1104adefb792cd64edc833822bf4c15972 Mon Sep 17 00:00:00 2001 From: Josh Wu Date: Thu, 8 Feb 2024 01:11:44 +0800 Subject: [PATCH] New Packet ABC --- bumble/core.py | 179 ++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 178 insertions(+), 1 deletion(-) diff --git a/bumble/core.py b/bumble/core.py index dce721a4..aae9f3ff 100644 --- a/bumble/core.py +++ b/bumble/core.py @@ -16,9 +16,28 @@ # Imports # ----------------------------------------------------------------------------- from __future__ import annotations + +import abc +import dataclasses import enum import struct -from typing import List, Optional, Tuple, Union, cast, Dict +from typing import ( + Any, + List, + Callable, + Optional, + Tuple, + Union, + ClassVar, + Type, + TypeVar, + Literal, + Dict, + cast, + get_args, + get_origin, +) +from typing_extensions import Self, Annotated from .company_ids import COMPANY_IDENTIFIERS @@ -1066,3 +1085,161 @@ class LeRole(enum.IntEnum): CENTRAL_ONLY = 0x01 BOTH_PERIPHERAL_PREFERRED = 0x02 BOTH_CENTRAL_PREFERRED = 0x03 + + +# ----------------------------------------------------------------------------- +# Data Unit +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class FieldSpec: + format: Union[int, str, Type[List[DataUnit]], None] = None + mapper: Optional[Callable] = None + serializer: Optional[Callable[..., bytes]] = None + deserializer: Optional[Callable[[bytes, int], Tuple[int, Any]]] = None + + +class DataUnit(abc.ABC): + @classmethod + def init_from_bytes(cls: Type[Self], data: bytes) -> Self: + return cls.parse_from_bytes(data, 0)[0] + + @classmethod + def parse_from_bytes(cls: Type[Self], data: bytes, offset: int) -> Tuple[Self, int]: + kwargs = {} + for field, annotation in cls.__annotations__.items(): + if hasattr(annotation, '__metadata__'): + kwargs[field], offset = cls.parse_field( + data, + offset, + annotation.__metadata__[0], + ) + elif get_origin(annotation) is List and (args := get_args(annotation)): + field_value = [] + size = data[offset] + offset += 1 + list_type: Type[DataUnit] = args[0] + for _ in range(size): + item, offset = list_type.parse_from_bytes(data, offset) + field_value.append(item) + kwargs[field] = field_value + + return (cls(**kwargs), offset) + + @classmethod + def parse_field( + cls: Type[Self], + data: bytes, + offset: int, + field_spec: FieldSpec, + ) -> Tuple[Any, int]: + # Parse the field + if field_spec.format == '*': + # The rest of the bytes + field_value = data[offset:] + return (field_value, len(field_value)) + if field_spec.format == 'v': + # Variable-length bytes field, with 1-byte length at the beginning + field_length = data[offset] + offset += 1 + field_value = data[offset : offset + field_length] + return (field_value, field_length + 1) + if field_spec.format == 1: + # 8-bit unsigned + return (data[offset], 1) + if field_spec.format == -1: + # 8-bit signed + return (struct.unpack_from('b', data, offset)[0], 1) + if field_spec.format == 2: + # 16-bit unsigned + return (struct.unpack_from('2': + # 16-bit unsigned big-endian + return (struct.unpack_from('>H', data, offset)[0], 2) + if field_spec.format == -2: + # 16-bit signed + return (struct.unpack_from('4': + # 32-bit unsigned big-endian + return (struct.unpack_from('>I', data, offset)[0], 4) + if isinstance(field_spec.format, int) and 4 < field_spec.format <= 256: + # Byte array (from 5 up to 256 bytes) + return (data[offset : offset + field_spec.format], field_spec.format) + if field_spec.deserializer: + new_offset, field_value = field_spec.deserializer(data, offset) + return (field_value, new_offset - offset) + + raise ValueError(f'Unknown field type {field_spec}') + + @classmethod + def serialize_field( + cls: Type[Self], field_value: Any, field_spec: FieldSpec + ) -> bytes: + # Serialize the field + if field_spec.serializer: + field_bytes = field_spec.serializer(field_value) + elif field_spec.format == 1: + # 8-bit unsigned + field_bytes = bytes([field_value]) + elif field_spec.format == -1: + # 8-bit signed + field_bytes = struct.pack('b', field_value) + elif field_spec.format == 2: + # 16-bit unsigned + field_bytes = struct.pack('2': + # 16-bit unsigned big-endian + field_bytes = struct.pack('>H', field_value) + elif field_spec.format == -2: + # 16-bit signed + field_bytes = struct.pack('4': + # 32-bit unsigned big-endian + field_bytes = struct.pack('>I', field_value) + elif field_spec.format == '*': + if isinstance(field_value, int): + if 0 <= field_value <= 255: + field_bytes = bytes([field_value]) + else: + raise ValueError('value too large for *-typed field') + else: + field_bytes = bytes(field_value) + elif field_spec.format == 'v': + # Variable-length bytes field, with 1-byte length at the beginning + field_bytes = bytes(field_value) + field_length = len(field_bytes) + field_bytes = bytes([field_length]) + field_bytes + elif isinstance(field_value, (bytes, bytearray)) or hasattr( + field_value, 'to_bytes' + ): + field_bytes = bytes(field_value) + if isinstance(field_spec.format, int) and 4 < field_spec.format <= 256: + # Truncate or pad with zeros if the field is too long or too short + if len(field_bytes) < field_spec.format: + field_bytes += bytes(field_spec.format - len(field_bytes)) + elif len(field_bytes) > field_spec.format: + field_bytes = field_bytes[: field_spec.format] + elif isinstance(field_value, list): + field_bytes = bytes(field_value) + if isinstance(field_spec.format, int) and 4 < field_spec.format <= 256: + # Truncate or pad with zeros if the field is too long or too short + if len(field_bytes) < field_spec.format: + field_bytes += bytes(field_spec.format - len(field_bytes)) + elif len(field_bytes) > field_spec.format: + field_bytes = field_bytes[: field_spec.format] + else: + raise ValueError(f"don't know how to serialize type {type(field_value)}") + + return field_bytes