From 5eccce59f03d59c76e60c6aebb3292e3a2918b25 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 01/21] feat(compression): implement model_editor for TFLite model manipulation Implement unified module for creating, reading, and modifying TFLite models with a clean API. The module eliminates manual index tracking and buffer management through automatic bookkeeping, supporting both declarative and imperative construction styles. Core design uses first-class Buffer objects that can be shared between tensors, with automatic deduplication during build. Tensors reference Buffers directly, matching the TFLite schema structure. The compiler automatically extracts inline tensor declarations, builds operator code tables, and handles index assignment according to TFLite conventions. Supports quantization parameters (per-tensor and per-channel), metadata key-value pairs, and read-modify-write workflows. The read() function preserves the object graph structure, enabling models to be read, modified, and rebuilt. Add comprehensive test coverage for core functionality, advanced features, quantization, and modification workflows. --- tensorflow/lite/micro/compression/BUILD | 22 + .../lite/micro/compression/model_editor.py | 557 +++++++++++++ .../micro/compression/model_editor_test.py | 735 ++++++++++++++++++ 3 files changed, 1314 insertions(+) create mode 100644 tensorflow/lite/micro/compression/model_editor.py create mode 100644 tensorflow/lite/micro/compression/model_editor_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index afb19abc425..01876f98122 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -237,6 +237,28 @@ py_test( ], ) +py_library( + name = "model_editor", + srcs = ["model_editor.py"], + deps = [ + "//tensorflow/lite/python:schema_py", + requirement("flatbuffers"), + requirement("numpy"), + ], +) + +py_test( + name = "model_editor_test", + size = "small", + srcs = ["model_editor_test.py"], + deps = [ + ":model_editor", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + requirement("tensorflow"), + ], +) + py_binary( name = "view", srcs = [ diff --git a/tensorflow/lite/micro/compression/model_editor.py b/tensorflow/lite/micro/compression/model_editor.py new file mode 100644 index 00000000000..541636b7e42 --- /dev/null +++ b/tensorflow/lite/micro/compression/model_editor.py @@ -0,0 +1,557 @@ +# Copyright 2024 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unified TFLite model manipulation module. + +Provides a clean API for creating, reading, and modifying TFLite models. +""" + +from dataclasses import dataclass, field +from typing import Optional, Union, List +import numpy as np +import flatbuffers +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +class _BufferList(list): + """Custom list that auto-sets buffer.index on append. 
+ + When a buffer is appended, automatically sets buffer.index to its position. + This enables append-only workflows to work seamlessly. + """ + + def append(self, buf): + """Append buffer and auto-set its index.""" + buf.index = len(self) + super().append(buf) + + +@dataclass +class Buffer: + """Buffer holding tensor data. + + The index field indicates the buffer's position in the model's buffer array. + It is automatically populated during: + - read(): Set from flatbuffer + - build(): Set during compilation + - model.buffers.append(): Auto-set to len(model.buffers) - 1 + + The index may become stale after: + - Deleting buffers from model.buffers + - Reordering buffers in model.buffers + + For append-only workflows (the common case), buffer.index can be trusted. + """ + data: bytes + index: Optional[int] = None + + def __len__(self): + return len(self.data) + + def __bytes__(self): + return self.data + + +@dataclass +class Quantization: + """Quantization parameters helper.""" + scales: Union[float, List[float]] + zero_points: Union[int, List[int]] = 0 + axis: Optional[int] = None + + def to_tflite(self) -> tflite.QuantizationParametersT: + """Convert to TFLite schema object.""" + q = tflite.QuantizationParametersT() + + # Normalize to lists + scales = [self.scales] if isinstance(self.scales, + (int, float)) else self.scales + zeros = [self.zero_points] if isinstance(self.zero_points, + int) else self.zero_points + + q.scale = scales + q.zeroPoint = zeros + if self.axis is not None: + q.quantizedDimension = self.axis + + return q + + +@dataclass +class Tensor: + """Declarative tensor specification. + + Supports both buffer= and data= parameters for flexibility: + - buffer=: Explicitly provide a Buffer object (can be shared between tensors) + - data=: Convenience parameter that auto-creates a Buffer + + Cannot specify both buffer and data at initialization. + """ + shape: tuple + dtype: tflite.TensorType + buffer: Optional[Buffer] = None + quantization: Optional[Quantization] = None + name: Optional[str] = None + + # Internal field for data initialization only + _data_init: Optional[Union[bytes, np.ndarray]] = field(default=None, + init=False, + repr=False) + + # Auto-populated during build/read + _index: Optional[int] = field(default=None, init=False, repr=False) + + def __init__(self, + shape, + dtype, + buffer=None, + data=None, + quantization=None, + name=None): + """Initialize Tensor. + + Args: + shape: Tensor shape as tuple + dtype: TensorType enum value + buffer: Optional Buffer object (for explicit buffer sharing) + data: Optional numpy array or bytes (convenience parameter, creates Buffer) + quantization: Optional Quantization object + name: Optional tensor name + + Raises: + ValueError: If both buffer and data are specified + """ + if data is not None and buffer is not None: + raise ValueError("Cannot specify both data and buffer") + + self.shape = shape + self.dtype = dtype + self.buffer = buffer + self.quantization = quantization + self.name = name + self._index = None + + # Convert data to buffer if provided + if data is not None: + buf_data = data if isinstance(data, bytes) else data.tobytes() + self.buffer = Buffer(data=buf_data) + + @property + def array(self) -> Optional[np.ndarray]: + """Get tensor data as properly-shaped numpy array. + + Returns: + numpy array with shape matching tensor.shape and dtype matching + tensor.dtype, or None if tensor has no data. + + For low-level byte access, use tensor.buffer.data instead. 
+ """ + if self.buffer is None: + return None + return np.frombuffer(self.buffer.data, + dtype=_dtype_to_numpy(self.dtype)).reshape(self.shape) + + @array.setter + def array(self, value: np.ndarray): + """Set tensor data from numpy array. + + Args: + value: New tensor data as numpy array. Will be converted to bytes + using tobytes() and stored in the buffer. + + Creates a new Buffer if tensor has no buffer, or updates the existing + buffer's data in place. + + For low-level byte access, use tensor.buffer.data instead. + """ + buf_data = value.tobytes() + if self.buffer is None: + self.buffer = Buffer(data=buf_data) + else: + self.buffer.data = buf_data + + @property + def index(self) -> Optional[int]: + """Tensor index in the subgraph's tensor list. + + Returns index after read() or build(). May be None or stale after + modifications. Use with caution. + """ + return self._index + + @property + def numpy_dtype(self) -> np.dtype: + """Get numpy dtype corresponding to tensor's TFLite dtype. + + Returns: + numpy dtype object for use with np.frombuffer, np.array, etc. + """ + return _dtype_to_numpy(self.dtype) + + +@dataclass +class OperatorCode: + """Operator code specification.""" + builtin_code: tflite.BuiltinOperator + custom_code: Optional[str] = None + version: int = 1 + + +@dataclass +class Operator: + """Declarative operator specification.""" + opcode: Union[tflite.BuiltinOperator, int] + inputs: List[Tensor] + outputs: List[Tensor] + custom_code: Optional[str] = None + + # Set when reading from existing model + opcode_index: Optional[int] = None + + _index: Optional[int] = field(default=None, init=False, repr=False) + + +@dataclass +class Subgraph: + """Declarative subgraph specification with imperative methods.""" + tensors: List[Tensor] = field(default_factory=list) + operators: List[Operator] = field(default_factory=list) + name: Optional[str] = None + + _index: Optional[int] = field(default=None, init=False, repr=False) + + def add_tensor(self, **kwargs) -> Tensor: + """Add tensor imperatively and return it.""" + t = Tensor(**kwargs) + t._index = len(self.tensors) + self.tensors.append(t) + return t + + def add_operator(self, **kwargs) -> Operator: + """Add operator imperatively and return it.""" + op = Operator(**kwargs) + op._index = len(self.operators) + self.operators.append(op) + return op + + @property + def index(self) -> Optional[int]: + """Subgraph index in the model's subgraph list. + + Returns index after read() or build(). May be None or stale after + modifications. Use with caution. 
+ """ + return self._index + + +@dataclass +class Model: + """Top-level model specification.""" + subgraphs: List[Subgraph] = field(default_factory=list) + buffers: _BufferList = field( + default_factory=_BufferList) # Auto-sets buffer.index on append + operator_codes: List[OperatorCode] = field(default_factory=list) + metadata: dict = field(default_factory=dict) + description: Optional[str] = None + + def add_subgraph(self, **kwargs) -> Subgraph: + """Add subgraph imperatively and return it.""" + sg = Subgraph(**kwargs) + sg._index = len(self.subgraphs) + self.subgraphs.append(sg) + return sg + + def build(self) -> bytearray: + """Compile to flatbuffer with automatic bookkeeping.""" + compiler = _ModelCompiler(self) + return compiler.compile() + + +def read(buffer: bytes) -> Model: + """Read a TFLite flatbuffer and return a Model object.""" + fb_model = tflite.ModelT.InitFromPackedBuf(buffer, 0) + + # Create Model with basic fields + # Decode bytes to strings where needed + description = fb_model.description + if isinstance(description, bytes): + description = description.decode('utf-8') + + model = Model(description=description) + + # Create all buffers first (so tensors can reference them) + for i, fb_buf in enumerate(fb_model.buffers): + buf_data = bytes(fb_buf.data) if fb_buf.data is not None else b'' + buf = Buffer(data=buf_data, index=i) + model.buffers.append(buf) + + # Read operator codes + for fb_opcode in fb_model.operatorCodes: + custom_code = fb_opcode.customCode + if isinstance(custom_code, bytes): + custom_code = custom_code.decode('utf-8') + + opcode = OperatorCode( + builtin_code=fb_opcode.builtinCode, + custom_code=custom_code, + version=fb_opcode.version if fb_opcode.version else 1) + model.operator_codes.append(opcode) + + # Read subgraphs + for sg_idx, fb_sg in enumerate(fb_model.subgraphs): + sg = Subgraph() + sg._index = sg_idx + + # Read tensors + for tensor_idx, fb_tensor in enumerate(fb_sg.tensors): + # Decode tensor name + name = fb_tensor.name + if isinstance(name, bytes): + name = name.decode('utf-8') + + # Create tensor referencing the appropriate buffer + # Buffer 0 is the empty buffer (TFLite convention), so treat it as None + buf = None if fb_tensor.buffer == 0 else model.buffers[fb_tensor.buffer] + + # Read quantization parameters if present + quant = None + if fb_tensor.quantization: + fb_quant = fb_tensor.quantization + if fb_quant.scale is not None and len(fb_quant.scale) > 0: + # Quantization parameters present + scales = list(fb_quant.scale) + zeros = list( + fb_quant.zeroPoint + ) if fb_quant.zeroPoint is not None else [0] * len(scales) + + # Handle axis: only set if per-channel (more than one scale) + axis = None + if len(scales) > 1 and fb_quant.quantizedDimension is not None: + axis = fb_quant.quantizedDimension + + quant = Quantization(scales=scales, zero_points=zeros, axis=axis) + + tensor = Tensor(shape=tuple(fb_tensor.shape), + dtype=fb_tensor.type, + buffer=buf, + name=name, + quantization=quant) + tensor._index = tensor_idx + + sg.tensors.append(tensor) + + # Read operators + for fb_op in fb_sg.operators: + # Get operator code info + opcode_obj = model.operator_codes[fb_op.opcodeIndex] + + op = Operator(opcode=opcode_obj.builtin_code, + inputs=[sg.tensors[i] for i in fb_op.inputs], + outputs=[sg.tensors[i] for i in fb_op.outputs], + custom_code=opcode_obj.custom_code, + opcode_index=fb_op.opcodeIndex) + sg.operators.append(op) + + model.subgraphs.append(sg) + + # Read metadata + if fb_model.metadata: + for entry in fb_model.metadata: + # 
Decode metadata name + name = entry.name + if isinstance(name, bytes): + name = name.decode('utf-8') + + # Get metadata value from buffer + buffer = fb_model.buffers[entry.buffer] + value = bytes(buffer.data) if buffer.data is not None else b'' + + model.metadata[name] = value + + return model + + +def _dtype_to_numpy(dtype: tflite.TensorType) -> np.dtype: + """Convert TFLite dtype to numpy dtype.""" + type_map = { + tflite.TensorType.INT8: np.int8, + tflite.TensorType.INT16: np.int16, + tflite.TensorType.INT32: np.int32, + tflite.TensorType.UINT8: np.uint8, + tflite.TensorType.FLOAT32: np.float32, + } + return type_map.get(dtype, np.uint8) + + +class _ModelCompiler: + """Internal: compiles Model to flatbuffer with automatic bookkeeping.""" + + def __init__(self, model: Model): + self.model = model + self._buffers = [] + self._buffer_map = {} # Map Buffer object id to index + self._operator_codes = {} + + def compile(self) -> bytearray: + """Compile model to flatbuffer.""" + root = tflite.ModelT() + root.version = 3 + + # Set description + root.description = self.model.description + + # Initialize buffers + # If model.buffers exists (from read()), preserve those buffers + if self.model.buffers: + for buf in self.model.buffers: + fb_buf = tflite.BufferT() + fb_buf.data = list(buf.data) if buf.data else [] + self._buffers.append(fb_buf) + self._buffer_map[id(buf)] = buf.index + else: + # Creating model from scratch: initialize buffer 0 as empty (TFLite convention) + empty_buffer = tflite.BufferT() + empty_buffer.data = [] + self._buffers = [empty_buffer] + # Note: buffer 0 should not be in _buffer_map since tensors without data use it + + # Auto-collect and register operator codes + self._collect_operator_codes() + root.operatorCodes = list(self._operator_codes.values()) + + # Process subgraphs + root.subgraphs = [] + for sg in self.model.subgraphs: + root.subgraphs.append(self._compile_subgraph(sg)) + + # Process buffers + root.buffers = self._buffers + + # Process metadata + root.metadata = self._compile_metadata() + + # Pack and return + builder = flatbuffers.Builder(4 * 2**20) + builder.Finish(root.Pack(builder)) + return builder.Output() + + def _collect_operator_codes(self): + """Scan all operators and build operator code table.""" + for sg in self.model.subgraphs: + for op in sg.operators: + key = (op.opcode, op.custom_code) + if key not in self._operator_codes: + opcode = tflite.OperatorCodeT() + opcode.builtinCode = op.opcode + if op.custom_code: + opcode.customCode = op.custom_code + self._operator_codes[key] = opcode + + def _compile_subgraph(self, sg: Subgraph) -> tflite.SubGraphT: + """Compile subgraph, extracting inline tensors from operators.""" + sg_t = tflite.SubGraphT() + sg_t.name = sg.name + + # Collect all tensors (from tensor list and inline in operators) + all_tensors = list(sg.tensors) + tensor_to_index = {} + for i, t in enumerate(all_tensors): + t._index = i + tensor_to_index[id(t)] = i + + # Extract inline tensors from operators + for op in sg.operators: + for tensor in op.inputs + op.outputs: + if id(tensor) not in tensor_to_index: + tensor._index = len(all_tensors) + tensor_to_index[id(tensor)] = tensor._index + all_tensors.append(tensor) + + # Compile all tensors + sg_t.tensors = [] + for tensor in all_tensors: + sg_t.tensors.append(self._compile_tensor(tensor)) + + # Compile operators + sg_t.operators = [] + for op in sg.operators: + sg_t.operators.append(self._compile_operator(op, tensor_to_index)) + + return sg_t + + def _compile_operator(self, op: Operator, 
+ tensor_to_index: dict) -> tflite.OperatorT: + """Compile operator, resolving tensor references and opcodes.""" + op_t = tflite.OperatorT() + + # Get opcode index + key = (op.opcode, op.custom_code) + opcode_index = list(self._operator_codes.keys()).index(key) + op_t.opcodeIndex = opcode_index + + # Resolve tensor references to indices + op_t.inputs = [tensor_to_index[id(inp)] for inp in op.inputs] + op_t.outputs = [tensor_to_index[id(outp)] for outp in op.outputs] + + return op_t + + def _compile_tensor(self, tensor: Tensor) -> tflite.TensorT: + """Compile tensor, reusing or creating buffer as needed.""" + t = tflite.TensorT() + t.shape = list(tensor.shape) + t.type = tensor.dtype + t.name = tensor.name + + # Handle buffer assignment + if tensor.buffer is None: + # No data: use buffer 0 + t.buffer = 0 + else: + # Has buffer: get or create index for it + buf_id = id(tensor.buffer) + if buf_id not in self._buffer_map: + # First time seeing this buffer, add it + fb_buf = tflite.BufferT() + fb_buf.data = list(tensor.buffer.data) + self._buffers.append(fb_buf) + buf_index = len(self._buffers) - 1 + self._buffer_map[buf_id] = buf_index + tensor.buffer.index = buf_index + t.buffer = self._buffer_map[buf_id] + + # Handle quantization + if tensor.quantization: + t.quantization = tensor.quantization.to_tflite() + + return t + + def _compile_metadata(self): + """Compile metadata, creating buffers for metadata values.""" + if not self.model.metadata: + return [] + + metadata_entries = [] + for name, value in self.model.metadata.items(): + # Create buffer for metadata value + buf = tflite.BufferT() + buf.data = list(value) if isinstance(value, bytes) else list(value) + self._buffers.append(buf) + buf_index = len(self._buffers) - 1 + + # Create metadata entry + entry = tflite.MetadataT() + entry.name = name + entry.buffer = buf_index + metadata_entries.append(entry) + + return metadata_entries diff --git a/tensorflow/lite/micro/compression/model_editor_test.py b/tensorflow/lite/micro/compression/model_editor_test.py new file mode 100644 index 00000000000..a6c5de56629 --- /dev/null +++ b/tensorflow/lite/micro/compression/model_editor_test.py @@ -0,0 +1,735 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for model_editor module. 
+""" + +import numpy as np +import tensorflow as tf +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression.model_editor import ( + Buffer, Model, Operator, OperatorCode, Quantization, Subgraph, Tensor) + + +class TestBasicModel(tf.test.TestCase): + """Test basic model with tensors and operators.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + cls.input_data = np.array([[1, 2, 3, 4, 5]], dtype=np.int8) + cls.weights_data = np.array([[1], [2], [3], [4], [5]], dtype=np.int8) + + cls.model = Model( + description="Test model", + subgraphs=[ + Subgraph(operators=[ + Operator(opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[ + Tensor(shape=(1, 5), + dtype=tflite.TensorType.INT8, + data=cls.input_data, + name="input"), + Tensor(shape=(5, 1), + dtype=tflite.TensorType.INT8, + data=cls.weights_data, + name="weights") + ], + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="output") + ]) + ]) + ]) + + # Build the model to a flatbuffer byte array. This exercises the + # model_editor's build path, which converts the high-level Model API + # representation into the binary TFLite format. + fb = cls.model.build() + + # Read the flatbuffer back through model_editor.read() to create a + # loopback model. This exercises the read path, which parses the + # flatbuffer and reconstructs a high-level Model representation. The + # loopback model should be semantically equivalent to cls.model, + # demonstrating that build() and read() are inverse operations. + cls.loopback_model = model_editor.read(fb) + + # Parse the same flatbuffer using the low-level TFLite schema interface + # (ModelT from schema_py_generated). This provides direct access to the + # raw flatbuffer structure, allowing us to verify that model_editor + # encodes data correctly at the binary level. We compare fb_model + # (low-level) against loopback_model (high-level) to ensure both + # representations are consistent. 
+ cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_description(self): + """Verify model description is preserved through loopback.""" + self.assertEqual(self.fb_model.description, b"Test model") + self.assertEqual(self.loopback_model.description, "Test model") + + def test_counts(self): + """Verify subgraph, tensor, and operator counts.""" + self.assertEqual(len(self.fb_model.subgraphs), 1) + self.assertEqual(len(self.loopback_model.subgraphs), 1) + + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertEqual(len(fb_sg.tensors), 3) + self.assertEqual(len(loopback_sg.tensors), 3) + + self.assertEqual(len(fb_sg.operators), 1) + self.assertEqual(len(loopback_sg.operators), 1) + + def test_tensor_names(self): + """Verify tensor names are preserved.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Check that all expected tensor names are present + fb_names = {t.name for t in fb_sg.tensors} + self.assertEqual(fb_names, {b"input", b"weights", b"output"}) + + loopback_names = {t.name for t in loopback_sg.tensors} + self.assertEqual(loopback_names, {"input", "weights", "output"}) + + def test_tensor_properties(self): + """Verify tensor shapes and dtypes.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Input tensor + input_fb = next(t for t in fb_sg.tensors if t.name == b"input") + input_loopback = next(t for t in loopback_sg.tensors if t.name == "input") + self.assertEqual(list(input_fb.shape), [1, 5]) + self.assertEqual(input_loopback.shape, (1, 5)) + self.assertEqual(input_fb.type, tflite.TensorType.INT8) + self.assertEqual(input_loopback.dtype, tflite.TensorType.INT8) + + # Weights tensor + weights_fb = next(t for t in fb_sg.tensors if t.name == b"weights") + weights_loopback = next(t for t in loopback_sg.tensors + if t.name == "weights") + self.assertEqual(list(weights_fb.shape), [5, 1]) + self.assertEqual(weights_loopback.shape, (5, 1)) + self.assertEqual(weights_fb.type, tflite.TensorType.INT8) + self.assertEqual(weights_loopback.dtype, tflite.TensorType.INT8) + + # Output tensor + output_fb = next(t for t in fb_sg.tensors if t.name == b"output") + output_loopback = next(t for t in loopback_sg.tensors + if t.name == "output") + self.assertEqual(list(output_fb.shape), [1, 1]) + self.assertEqual(output_loopback.shape, (1, 1)) + self.assertEqual(output_fb.type, tflite.TensorType.INT8) + self.assertEqual(output_loopback.dtype, tflite.TensorType.INT8) + + def test_tensor_data(self): + """Verify tensor data and buffer access.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Input tensor data + input_buffer = self.fb_model.buffers[fb_sg.tensors[0].buffer] + self.assertIsNotNone(input_buffer.data) + self.assertEqual(bytes(input_buffer.data), self.input_data.tobytes()) + + self.assertIsNotNone(loopback_sg.tensors[0].array) + self.assertAllEqual(loopback_sg.tensors[0].array, self.input_data) + + # Weights tensor data + weights_buffer = self.fb_model.buffers[fb_sg.tensors[1].buffer] + self.assertIsNotNone(weights_buffer.data) + self.assertEqual(bytes(weights_buffer.data), self.weights_data.tobytes()) + + self.assertIsNotNone(loopback_sg.tensors[1].array) + self.assertAllEqual(loopback_sg.tensors[1].array, self.weights_data) + + # Output tensor has no data + self.assertEqual(fb_sg.tensors[2].buffer, 0) + self.assertIsNone(loopback_sg.tensors[2].array) + + def test_buffer_allocation(self): + """Verify 
buffer allocation and zero convention.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Exact buffer count: buffer 0 (empty) + input + weights = 3 total + self.assertEqual(len(self.fb_model.buffers), 3) + self.assertEqual(len(self.loopback_model.buffers), 3) + + # Buffer 0 is empty + buffer_zero = self.fb_model.buffers[0] + self.assertTrue(buffer_zero.data is None or len(buffer_zero.data) == 0) + + # Verify each buffer is referenced by exactly the expected tensor + # Buffer 0 -> output tensor (no data) + output_tensor = next(t for t in fb_sg.tensors if t.name == b"output") + self.assertEqual(output_tensor.buffer, 0) + + # Buffer 1 and 2 -> input and weights (order may vary) + input_tensor = next(t for t in fb_sg.tensors if t.name == b"input") + weights_tensor = next(t for t in fb_sg.tensors if t.name == b"weights") + self.assertNotEqual(input_tensor.buffer, 0) + self.assertNotEqual(weights_tensor.buffer, 0) + self.assertIn(input_tensor.buffer, [1, 2]) + self.assertIn(weights_tensor.buffer, [1, 2]) + + # Tensors with data point to non-zero buffers in loopback model + loopback_input_tensor = next(t for t in loopback_sg.tensors + if t.name == "input") + self.assertIsNotNone(loopback_input_tensor.buffer) + self.assertIsNotNone(loopback_input_tensor.buffer.index) + self.assertNotEqual(loopback_input_tensor.buffer.index, 0) + self.assertEqual(len(loopback_input_tensor.buffer.data), 5) + self.assertEqual(bytes(loopback_input_tensor.buffer.data), + self.input_data.tobytes()) + + def test_operator_references(self): + """Verify operators reference correct tensors.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Operator input/output references + self.assertEqual(len(fb_sg.operators[0].inputs), 2) + self.assertEqual([t.name for t in loopback_sg.operators[0].inputs], + ["input", "weights"]) + + self.assertEqual(len(fb_sg.operators[0].outputs), 1) + self.assertEqual([t.name for t in loopback_sg.operators[0].outputs], + ["output"]) + + # Operator indices are in bounds + num_tensors = len(fb_sg.tensors) + for idx in list(fb_sg.operators[0].inputs) + list( + fb_sg.operators[0].outputs): + self.assertGreaterEqual(idx, 0) + self.assertLess(idx, num_tensors) + + def test_operator_codes(self): + """Verify operator code table is correctly populated.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertIsNotNone(self.fb_model.operatorCodes) + self.assertEqual(len(self.fb_model.operatorCodes), 1) + self.assertEqual(self.fb_model.operatorCodes[0].builtinCode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + self.assertEqual(len(self.loopback_model.operator_codes), 1) + self.assertIsNotNone(loopback_sg.operators[0].opcode_index) + loopback_opcode = self.loopback_model.operator_codes[ + loopback_sg.operators[0].opcode_index] + self.assertEqual(loopback_opcode.builtin_code, + tflite.BuiltinOperator.FULLY_CONNECTED) + + +class TestAdvancedModel(tf.test.TestCase): + """Test multiple operators, custom ops, shared tensors, and mixed references.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + cls.input_data = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], dtype=np.int8) + cls.weights_data = np.array( + [[1], [2], [3], [4], [5], [6], [7], [8], [9], [10]], dtype=np.int8) + cls.bias_data = np.array([10], dtype=np.int8) + # Int16 data to test endianness: values that will show byte order issues + cls.int16_data = np.array([256, 512, 1024], 
+ dtype=np.int16) # 0x0100, 0x0200, 0x0400 + + # Pre-declare shared tensor (output of FC, input to custom op) + cls.hidden = Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="hidden") + + # Create explicit shared buffer to test buffer sharing between tensors + cls.shared_buffer_data = np.array([100, 127], dtype=np.int8) + cls.shared_buf = Buffer(data=cls.shared_buffer_data.tobytes()) + + cls.model = Model( + description="Advanced model", + metadata={ + "version": b"1.0.0", + "author": b"test_suite", + "custom_data": bytes([0xDE, 0xAD, 0xBE, 0xEF]) + }, + subgraphs=[ + Subgraph( + tensors=[ + cls.hidden, # Mixed: pre-declared shared tensor + # Int16 tensor to test endianness + Tensor(shape=(3, ), + dtype=tflite.TensorType.INT16, + data=cls.int16_data, + name="int16_tensor"), + # Two tensors sharing same buffer to test buffer deduplication + Tensor(shape=(2, ), + dtype=tflite.TensorType.INT8, + buffer=cls.shared_buf, + name="shared_buf_tensor1"), + Tensor(shape=(2, ), + dtype=tflite.TensorType.INT8, + buffer=cls.shared_buf, + name="shared_buf_tensor2") + ], + operators=[ + # Multiple operators: FULLY_CONNECTED + Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[ + Tensor(shape=(1, 10), + dtype=tflite.TensorType.INT8, + data=cls.input_data, + name="input"), + Tensor(shape=(10, 1), + dtype=tflite.TensorType.INT8, + data=cls.weights_data, + name="weights") + ], + outputs=[cls.hidden + ] # Shared: reference to pre-declared + ), + # Custom operator + Operator( + opcode=tflite.BuiltinOperator.CUSTOM, + custom_code="MyCustomOp", + inputs=[cls.hidden], # Shared: reuse hidden tensor + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="processed") + ]), + # Multiple operators: ADD + Operator( + opcode=tflite.BuiltinOperator.ADD, + inputs=[ + Tensor( + shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="processed_ref" # Mixed: inline tensor + ), + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + data=cls.bias_data, + name="bias") + ], + outputs=[ + Tensor(shape=(1, 1), + dtype=tflite.TensorType.INT8, + name="output") + ]) + ]) + ]) + + fb = cls.model.build() + cls.loopback_model = model_editor.read(fb) + cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_operator_counts(self): + """Verify correct number of operators.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + self.assertEqual(len(fb_sg.operators), 3) + self.assertEqual(len(loopback_sg.operators), 3) + + def test_operator_code_table(self): + """Verify operator code table contains all operator types.""" + self.assertEqual(len(self.fb_model.operatorCodes), 3) + self.assertEqual(len(self.loopback_model.operator_codes), 3) + + opcodes_fb = {op.builtinCode for op in self.fb_model.operatorCodes} + self.assertIn(tflite.BuiltinOperator.FULLY_CONNECTED, opcodes_fb) + self.assertIn(tflite.BuiltinOperator.CUSTOM, opcodes_fb) + self.assertIn(tflite.BuiltinOperator.ADD, opcodes_fb) + + opcodes_loopback = { + op.builtin_code + for op in self.loopback_model.operator_codes + } + self.assertIn(tflite.BuiltinOperator.FULLY_CONNECTED, opcodes_loopback) + self.assertIn(tflite.BuiltinOperator.CUSTOM, opcodes_loopback) + self.assertIn(tflite.BuiltinOperator.ADD, opcodes_loopback) + + def test_custom_operator(self): + """Verify custom operator code preservation.""" + loopback_sg = self.loopback_model.subgraphs[0] + + # Custom code in operator code table + custom_opcode_fb = next(op for op in self.fb_model.operatorCodes + if op.builtinCode == 
tflite.BuiltinOperator.CUSTOM)
+    self.assertEqual(custom_opcode_fb.customCode, b"MyCustomOp")
+
+    custom_opcode_loopback = next(
+        op for op in self.loopback_model.operator_codes
+        if op.builtin_code == tflite.BuiltinOperator.CUSTOM)
+    self.assertEqual(custom_opcode_loopback.custom_code, "MyCustomOp")
+
+    # Custom operator references custom code
+    custom_op_loopback = loopback_sg.operators[1]
+    self.assertEqual(custom_op_loopback.opcode, tflite.BuiltinOperator.CUSTOM)
+    self.assertEqual(custom_op_loopback.custom_code, "MyCustomOp")
+
+  def test_shared_tensor_references(self):
+    """Verify tensors shared between operators."""
+    fb_sg = self.fb_model.subgraphs[0]
+    loopback_sg = self.loopback_model.subgraphs[0]
+
+    # Hidden tensor is at index 0 (pre-declared)
+    self.assertEqual(fb_sg.tensors[0].name, b"hidden")
+    self.assertEqual(loopback_sg.tensors[0].name, "hidden")
+
+    # FC operator outputs to hidden
+    self.assertEqual([t.name for t in loopback_sg.operators[0].outputs],
+                     ["hidden"])
+
+    # Custom operator inputs from hidden
+    self.assertEqual([t.name for t in loopback_sg.operators[1].inputs],
+                     ["hidden"])
+
+    # Same Tensor object is referenced by both operators
+    fc_output = loopback_sg.operators[0].outputs[0]
+    custom_input = loopback_sg.operators[1].inputs[0]
+    self.assertIs(fc_output, custom_input)
+
+  def test_mixed_tensor_references(self):
+    """Verify mix of pre-declared and inline tensors."""
+    fb_sg = self.fb_model.subgraphs[0]
+    loopback_sg = self.loopback_model.subgraphs[0]
+
+    # Total: hidden, int16_tensor, shared_buf_tensor1, shared_buf_tensor2 (pre-declared)
+    # + input, weights, processed, processed_ref, bias, output (inline from operators)
+    self.assertEqual(len(fb_sg.tensors), 10)
+    self.assertEqual(len(loopback_sg.tensors), 10)
+
+  def test_int16_endianness(self):
+    """Verify int16 data is stored in little-endian byte order."""
+    fb_sg = self.fb_model.subgraphs[0]
+    loopback_sg = self.loopback_model.subgraphs[0]
+
+    # Find int16 tensor by name
+    int16_tensor_fb = next(t for t in fb_sg.tensors
+                           if t.name == b"int16_tensor")
+    int16_tensor_loopback = next(t for t in loopback_sg.tensors
+                                 if t.name == "int16_tensor")
+
+    # Verify dtype
+    self.assertEqual(int16_tensor_fb.type, tflite.TensorType.INT16)
+    self.assertEqual(int16_tensor_loopback.dtype, tflite.TensorType.INT16)
+
+    # Check flatbuffer buffer has correct little-endian bytes
+    # For [256, 512, 1024] = [0x0100, 0x0200, 0x0400]
+    # Little-endian bytes: [0x00, 0x01, 0x00, 0x02, 0x00, 0x04]
+    int16_buffer_fb = self.fb_model.buffers[int16_tensor_fb.buffer]
+    self.assertIsNotNone(int16_buffer_fb.data)
+    expected_bytes = self.int16_data.astype(np.int16).astype('<i2').tobytes()
+    self.assertEqual(bytes(int16_buffer_fb.data), expected_bytes)
+
+    # Loopback array round-trips the original values
+    self.assertAllEqual(int16_tensor_loopback.array, self.int16_data)
+
+  def test_metadata(self):
+    """Verify metadata key-value pairs are preserved."""
+    # Check flatbuffer metadata entries
+    self.assertIsNotNone(self.fb_model.metadata)
+    self.assertEqual(len(self.fb_model.metadata), 3)
+
+    # Build name->buffer mapping from flatbuffer
+    metadata_map_fb = {}
+    for entry in self.fb_model.metadata:
+      buffer_idx = entry.buffer
+      self.assertLess(buffer_idx, len(self.fb_model.buffers))
+      buffer = self.fb_model.buffers[buffer_idx]
+      if buffer.data is not None:
+        metadata_map_fb[entry.name] = bytes(buffer.data)
+
+    # Verify flatbuffer metadata values
+    self.assertEqual(metadata_map_fb[b"version"], b"1.0.0")
+    self.assertEqual(metadata_map_fb[b"author"], b"test_suite")
+    self.assertEqual(metadata_map_fb[b"custom_data"],
+                     bytes([0xDE, 0xAD, 0xBE, 0xEF]))
+
+    # Check loopback model metadata
+    self.assertIsNotNone(self.loopback_model.metadata)
+    self.assertEqual(len(self.loopback_model.metadata), 3)
+
+    # Verify loopback metadata values (stored as bytes)
+    self.assertEqual(self.loopback_model.metadata["version"], b"1.0.0")
+
self.assertEqual(self.loopback_model.metadata["author"], b"test_suite") + self.assertEqual(self.loopback_model.metadata["custom_data"], + bytes([0xDE, 0xAD, 0xBE, 0xEF])) + + def test_buffer_allocation(self): + """Verify no orphaned buffers and shared buffer deduplication.""" + fb_sg = self.fb_model.subgraphs[0] + loopback_sg = self.loopback_model.subgraphs[0] + + # Collect all buffer references (from tensors and metadata) + referenced_buffers = {0} # Buffer 0 is special (always referenced) + + # Collect buffer references from tensors + for tensor in fb_sg.tensors: + referenced_buffers.add(tensor.buffer) + + # Collect buffer references from metadata + for entry in self.fb_model.metadata: + referenced_buffers.add(entry.buffer) + + # Verify no orphaned buffers (all buffers are referenced) + for i in range(len(self.fb_model.buffers)): + self.assertIn( + i, referenced_buffers, + f"Buffer {i} is orphaned (not referenced by any tensor or metadata)") + + # Verify shared buffer deduplication: two tensors share one buffer + tensor1_fb = next(t for t in fb_sg.tensors + if t.name == b"shared_buf_tensor1") + tensor2_fb = next(t for t in fb_sg.tensors + if t.name == b"shared_buf_tensor2") + + # Both tensors should point to the same buffer index + self.assertEqual(tensor1_fb.buffer, tensor2_fb.buffer) + self.assertNotEqual(tensor1_fb.buffer, 0) + + # Verify loopback preserves shared buffer (same Buffer object) + tensor1_loopback = next(t for t in loopback_sg.tensors + if t.name == "shared_buf_tensor1") + tensor2_loopback = next(t for t in loopback_sg.tensors + if t.name == "shared_buf_tensor2") + + self.assertIs(tensor1_loopback.buffer, tensor2_loopback.buffer) + self.assertEqual(bytes(tensor1_loopback.buffer.data), + self.shared_buffer_data.tobytes()) + self.assertEqual(bytes(tensor2_loopback.buffer.data), + self.shared_buffer_data.tobytes()) + + +class TestQuantization(tf.test.TestCase): + """Test per-tensor and per-channel quantization parameters.""" + + @classmethod + def setUpClass(cls): + """Build model once for all tests in this class.""" + # Per-channel quantization parameters + cls.per_channel_scales = [0.1, 0.2, 0.3, 0.4] + cls.per_channel_zeros = [0, 1, 2, 3] + + cls.model = Model( + description="Quantization test model", + subgraphs=[ + Subgraph(tensors=[ + # Per-tensor quantized tensor (single scale/zero_point) + Tensor(shape=(1, 10), + dtype=tflite.TensorType.INT8, + data=np.ones((1, 10), dtype=np.int8), + name="per_tensor", + quantization=Quantization(scales=0.5, zero_points=10)), + # Per-channel quantized tensor (array of scales/zero_points, axis) + Tensor(shape=(4, 10), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 10), dtype=np.int8), + name="per_channel", + quantization=Quantization( + scales=cls.per_channel_scales, + zero_points=cls.per_channel_zeros, + axis=0)) + ]) + ]) + + fb = cls.model.build() + cls.loopback_model = model_editor.read(fb) + cls.fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + + def test_per_tensor_quantization_flatbuffer(self): + """Verify per-tensor quantization in flatbuffer encoding.""" + fb_sg = self.fb_model.subgraphs[0] + + tensor = next(t for t in fb_sg.tensors if t.name == b"per_tensor") + self.assertIsNotNone(tensor.quantization) + + # Scale and zero_point encoded as single-element arrays + self.assertIsNotNone(tensor.quantization.scale) + self.assertEqual(len(tensor.quantization.scale), 1) + self.assertEqual(tensor.quantization.scale[0], 0.5) + + self.assertIsNotNone(tensor.quantization.zeroPoint) + 
self.assertEqual(len(tensor.quantization.zeroPoint), 1) + self.assertEqual(tensor.quantization.zeroPoint[0], 10) + + def test_per_tensor_quantization_loopback(self): + """Verify per-tensor quantization in loopback model.""" + loopback_sg = self.loopback_model.subgraphs[0] + + tensor = next(t for t in loopback_sg.tensors if t.name == "per_tensor") + self.assertIsNotNone(tensor.quantization) + + # Read back as lists + self.assertEqual(tensor.quantization.scales, [0.5]) + self.assertEqual(tensor.quantization.zero_points, [10]) + self.assertIsNone(tensor.quantization.axis) + + def test_per_channel_quantization_flatbuffer(self): + """Verify per-channel quantization in flatbuffer encoding.""" + fb_sg = self.fb_model.subgraphs[0] + + tensor = next(t for t in fb_sg.tensors if t.name == b"per_channel") + self.assertIsNotNone(tensor.quantization) + + # All scales encoded + self.assertIsNotNone(tensor.quantization.scale) + self.assertEqual(len(tensor.quantization.scale), 4) + self.assertEqual(list(tensor.quantization.scale), self.per_channel_scales) + + # All zero_points encoded + self.assertIsNotNone(tensor.quantization.zeroPoint) + self.assertEqual(len(tensor.quantization.zeroPoint), 4) + self.assertEqual(list(tensor.quantization.zeroPoint), + self.per_channel_zeros) + + # Axis encoded as quantizedDimension + self.assertEqual(tensor.quantization.quantizedDimension, 0) + + def test_per_channel_quantization_loopback(self): + """Verify per-channel quantization in loopback model.""" + loopback_sg = self.loopback_model.subgraphs[0] + + tensor = next(t for t in loopback_sg.tensors if t.name == "per_channel") + self.assertIsNotNone(tensor.quantization) + + # Read back as lists + self.assertEqual(tensor.quantization.scales, self.per_channel_scales) + self.assertEqual(tensor.quantization.zero_points, self.per_channel_zeros) + self.assertEqual(tensor.quantization.axis, 0) + + +class TestReadModifyWrite(tf.test.TestCase): + """Test read-modify-write workflows.""" + + @classmethod + def setUpClass(cls): + """Create a simple base model for modification tests.""" + cls.original_data = np.array([[1, 2, 3]], dtype=np.int8) + cls.model = Model( + description="Base model", + metadata={"original": b"metadata"}, + subgraphs=[ + Subgraph(tensors=[ + Tensor(shape=(1, 3), + dtype=tflite.TensorType.INT8, + data=cls.original_data, + name="weights"), + Tensor( + shape=(1, 3), dtype=tflite.TensorType.INT8, name="input"), + Tensor( + shape=(1, 3), dtype=tflite.TensorType.INT8, name="output") + ]) + ]) + + cls.fb = cls.model.build() + + def test_modify_tensor_data(self): + """Read model, modify tensor data, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + + # Modify tensor data using array setter (high-level API) + weights_tensor = next(t for t in model2.subgraphs[0].tensors + if t.name == "weights") + new_data = np.array([[10, 20, 30]], dtype=np.int8) + weights_tensor.array = new_data # Uses array setter + + # Build modified model + fb2 = model2.build() + + # Read back and verify modification + model3 = model_editor.read(fb2) + modified_weights = next(t for t in model3.subgraphs[0].tensors + if t.name == "weights") + self.assertAllEqual(modified_weights.array, new_data) + + # Verify other tensors unchanged + self.assertEqual(len(model3.subgraphs[0].tensors), 3) + + def test_add_tensor_and_operator(self): + """Read model, add new tensor and operator, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + sg = model2.subgraphs[0] + + # Get existing tensors + 
input_tensor = next(t for t in sg.tensors if t.name == "input") + output_tensor = next(t for t in sg.tensors if t.name == "output") + + # Add new tensor using imperative API + new_weights = np.array([[5, 10, 15]], dtype=np.int8) + new_weights_tensor = sg.add_tensor(shape=(1, 3), + dtype=tflite.TensorType.INT8, + data=new_weights, + name="new_weights") + + # Add new operator using imperative API + sg.add_operator(opcode=tflite.BuiltinOperator.ADD, + inputs=[input_tensor, new_weights_tensor], + outputs=[output_tensor]) + + # Build modified model + fb2 = model2.build() + + # Read back and verify additions + model3 = model_editor.read(fb2) + sg3 = model3.subgraphs[0] + + # Verify tensor was added + self.assertEqual(len(sg3.tensors), 4) + added_tensor = next(t for t in sg3.tensors if t.name == "new_weights") + self.assertIsNotNone(added_tensor) + self.assertAllEqual(added_tensor.array, new_weights) + + # Verify operator was added + self.assertEqual(len(sg3.operators), 1) + added_op = sg3.operators[0] + self.assertEqual([t.name for t in added_op.inputs], + ["input", "new_weights"]) + self.assertEqual([t.name for t in added_op.outputs], ["output"]) + + def test_modify_metadata(self): + """Read model, modify metadata, write back, verify.""" + # Read the model + model2 = model_editor.read(self.fb) + + # Modify existing metadata + model2.metadata["original"] = b"modified_metadata" + + # Add new metadata + model2.metadata["new_key"] = b"new_value" + + # Build modified model + fb2 = model2.build() + + # Read back and verify modifications + model3 = model_editor.read(fb2) + + self.assertEqual(len(model3.metadata), 2) + self.assertEqual(model3.metadata["original"], b"modified_metadata") + self.assertEqual(model3.metadata["new_key"], b"new_value") + + +if __name__ == "__main__": + tf.test.main() From 953b734e8bf1f0888fa1aa9e651be84f5c570bba Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 02/21] refactor(compression): migrate compress.py from model_facade to model_editor Replace model_facade with model_editor in compress.py and tests. model_editor provides a cleaner API with better buffer and metadata handling. Buffers appended during compression are automatically indexed, and quantization parameters are accessed through a wrapper object. Update BUILD dependencies accordingly. 
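For illustration, the buffer-append pattern this migration relies on
looks roughly like the following sketch (packed_tables stands in for the
packed lookup-table bytes produced by the compression loop; see the
compress.py hunk below for the exact change):

    # model_facade: ask the model to allocate and track a new buffer.
    value_buffer = model.add_buffer()
    value_buffer.data = packed_tables

    # model_editor: create a first-class Buffer; its index is assigned
    # automatically when it is appended to model.buffers.
    value_buffer = model_editor.Buffer(data=packed_tables)
    model.buffers.append(value_buffer)  # auto-sets value_buffer.index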
--- tensorflow/lite/micro/compression/BUILD | 4 +-- tensorflow/lite/micro/compression/compress.py | 36 ++++++++++--------- .../lite/micro/compression/compress_test.py | 19 +++++----- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 01876f98122..9ddbe1024ae 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -115,7 +115,7 @@ py_library( ], deps = [ ":metadata_py", - ":model_facade", + ":model_editor", ":spec", "//tensorflow/lite/micro/tools:tflite_flatbuffer_align", requirement("absl_py"), @@ -149,7 +149,7 @@ py_test( deps = [ ":compress", ":metadata_py", - ":model_facade", + ":model_editor", ":spec", ":test_models", "//tensorflow/lite/python:schema_py", diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index 79959a7f612..61d3ce8df17 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -29,7 +29,7 @@ import flatbuffers import numpy as np -from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper @@ -177,7 +177,7 @@ def _check_lut_compression(compression) -> spec.LookUpTableCompression: return compression[0] -def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]: +def _identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: """Determines the axis along which to compress. The axis along which to compress is inferred from the tensor's quantization @@ -191,16 +191,18 @@ def _identify_compression_axis(tensor: model_facade._Tensor) -> Optional[int]: CompressionError: If the axis cannot be determined. """ q = tensor.quantization - if q is not None \ - and q.scale is not None \ - and q.quantizedDimension < len(tensor.shape): - quantization_channels = len(q.scale) + if q is not None: + # model_editor wraps quantization, access scales/axis from wrapper + scales = q.scales if isinstance(q.scales, list) else [q.scales] + quantization_channels = len(scales) + if quantization_channels == 1: # Use one value table for the entire tensor return None - if quantization_channels == tensor.shape[q.quantizedDimension]: - return q.quantizedDimension + if q.axis is not None and q.axis < len(tensor.shape): + if quantization_channels == tensor.shape[q.axis]: + return q.axis raise CompressionError( f"Invalid or no quanitzation parameters from which to " @@ -300,7 +302,7 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: Returns: A compressed flatbuffer. 
""" - model = model_facade.read(model_in) + model = model_editor.read(model_in) metadata = _MetadataBuilder() for spec in specs: @@ -316,12 +318,14 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: tensor.buffer.data = _pack_indices(compressed.indices, spec_bitwidth) # write value buffer - value_buffer = model.add_buffer() - value_buffer.data = _pack_lookup_tables(compressed.lookup_tables, + value_buffer_data = _pack_lookup_tables(compressed.lookup_tables, 2**spec_bitwidth) + value_buffer = model_editor.Buffer(data=value_buffer_data) + model.buffers.append(value_buffer) # Auto-sets value_buffer.index + # add compression metadata for tensor - lut_tensor = metadata.add_lut_tensor(subgraph_id=tensor.subgraph.index) - lut_tensor.tensor = tensor.index + lut_tensor = metadata.add_lut_tensor(subgraph_id=spec.subgraph) + lut_tensor.tensor = spec.tensor lut_tensor.valueBuffer = value_buffer.index lut_tensor.indexBitwidth = spec_bitwidth @@ -329,10 +333,10 @@ def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: raise CompressionError(f"error compressing {spec}") from e # add compression metadata to model - model.add_metadata(TFLITE_METADATA_KEY, metadata.compile()) + model.metadata[TFLITE_METADATA_KEY] = metadata.compile() - # Compile the model and apply proper alignment - unaligned_model = model.compile() + # Build the model and apply proper alignment + unaligned_model = model.build() return _apply_flatbuffer_alignment(unaligned_model) diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index 1167c421f6e..e40a2969df3 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -19,7 +19,7 @@ from tflite_micro.tensorflow.lite.micro.compression import compress from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema -from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec from tflite_micro.tensorflow.lite.micro.compression import test_models from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite @@ -368,12 +368,12 @@ class TestsCompression(tf.test.TestCase): def setUpClass(cls): super().setUpClass() cls.flatbuffer = test_models.build(TEST_MODEL) - cls.uncompressed = model_facade.read(cls.flatbuffer) + cls.uncompressed = model_editor.read(cls.flatbuffer) def test_compression_metadata(self): """The compressed model has compression metadata.""" compressed = compress.compress(self.flatbuffer, TEST_COMPRESSION_SPEC) - model = model_facade.read(compressed) + model = model_editor.read(compressed) self.assertIn("metadata0", self.uncompressed.metadata) self.assertIn(compress.TFLITE_METADATA_KEY, model.metadata) @@ -461,16 +461,17 @@ def setUpClass(cls): super().setUpClass() # Create a model uncompressed_fb = test_models.build(TEST_MODEL) - cls.uncompressed = model_facade.read(uncompressed_fb) + cls.uncompressed = model_editor.read(uncompressed_fb) # Compress the model compressed_fb = compress.compress(uncompressed_fb, TEST_COMPRESSION_SPEC) - cls.compressed = model_facade.read(compressed_fb) + cls.compressed = model_editor.read(compressed_fb) # Extract the compression metadata - metadata_flatbuffer = cls.compressed.metadata[compress.TFLITE_METADATA_KEY] - cls.metadata = 
schema.MetadataT.InitFromPackedBuf(metadata_flatbuffer.data, - 0) + metadata_flatbuffer_bytes = cls.compressed.metadata[ + compress.TFLITE_METADATA_KEY] + cls.metadata = schema.MetadataT.InitFromPackedBuf( + metadata_flatbuffer_bytes, 0) def test_uncompressed_tensors(self): """Tensors not in compression spec are not compressed. @@ -515,7 +516,7 @@ def _get_compressed( indices = indices[:n_indices * bitwidth] # trim possible padding value_buffer = self.compressed.buffers[lut_tensor.valueBuffer] - values = np.frombuffer(value_buffer.data, dtype=tensor_obj.dtype) + values = np.frombuffer(value_buffer.data, dtype=tensor_obj.numpy_dtype) return bitwidth, indices, values From fdf9c601e8bc9276266523957328e86316f6d8be Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 03/21] chore(compression): remove model_facade.py and model_facade_test.py Remove model_facade module and its tests as they are superseded by model_editor. --- tensorflow/lite/micro/compression/BUILD | 20 -- .../lite/micro/compression/model_facade.py | 276 ------------------ .../micro/compression/model_facade_test.py | 144 --------- 3 files changed, 440 deletions(-) delete mode 100644 tensorflow/lite/micro/compression/model_facade.py delete mode 100644 tensorflow/lite/micro/compression/model_facade_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 9ddbe1024ae..ec3c903ce22 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -159,26 +159,6 @@ py_test( ], ) -py_library( - name = "model_facade", - srcs = ["model_facade.py"], - deps = [ - "//tensorflow/lite/python:schema_py", - requirement("flatbuffers"), - ], -) - -py_test( - name = "model_facade_test", - size = "small", - srcs = ["model_facade_test.py"], - deps = [ - ":model_facade", - ":test_models", - requirement("tensorflow"), - ], -) - py_library( name = "spec", srcs = ["spec.py"], diff --git a/tensorflow/lite/micro/compression/model_facade.py b/tensorflow/lite/micro/compression/model_facade.py deleted file mode 100644 index 2e58d8080f1..00000000000 --- a/tensorflow/lite/micro/compression/model_facade.py +++ /dev/null @@ -1,276 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""A facade for working with tflite.Model. - -This module provides convenient navigation, data type conversions, and -utilities for working with a tflite.Model, which can be tedious and verbose to -work with directly. 
-
-Usage:
-  model = model_facade.read(flatbuffer)
-  # manipulate
-  new_flatbuffer = model.compile()
-"""
-
-from __future__ import annotations
-
-import flatbuffers
-import numpy as np
-from numpy.typing import NDArray
-from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite
-from typing import ByteString, Generic, TypeVar
-
-_IteratorTo = TypeVar("_IteratorTo")
-
-
-class _Iterator(Generic[_IteratorTo]):
-
-  def __init__(self, sequence, cls, parent):
-    self._sequence = sequence
-    self._cls = cls
-    self._index = 0
-    self._parent = parent
-
-  def __getitem__(self, key) -> _IteratorTo:
-    return self._cls(self._sequence[key], key, self._parent)
-
-  def __len__(self):
-    return len(self._sequence)
-
-  def __iter__(self):
-    self._index = 0
-    return self
-
-  def __next__(self):
-    try:
-      result = self[self._index]
-      self._index += 1
-      return result
-    except IndexError:
-      raise StopIteration
-
-
-class _IndirectIterator(Generic[_IteratorTo]):
-
-  def __init__(self, indices, sequence):
-    self._indices = indices
-    self._index = 0
-    self._sequence = sequence
-
-  def __getitem__(self, key) -> _IteratorTo:
-    index = self._indices[key]
-    return self._sequence[index]
-
-  def __len__(self):
-    return len(self._indices)
-
-  def __iter__(self):
-    self._index = 0
-    return self
-
-  def __next__(self):
-    try:
-      result = self[self._index]
-      self._index += 1
-      return result
-    except IndexError:
-      raise StopIteration
-
-
-class _Operator:
-
-  def __init__(self, operator, index, subgraph):
-    self.operator = operator
-    self.index = index
-    self.subgraph = subgraph
-
-  @property
-  def opcode(self) -> tflite.OperatorCodeT:
-    return self.subgraph.model.operatorCodes[self.operator.opcodeIndex]
-
-  @property
-  def inputs(self):
-    return _IndirectIterator(self.operator.inputs, self.subgraph.tensors)
-
-
-_NP_DTYPES = {
-    tflite.TensorType.FLOAT16: np.dtype("<f2"),
-    tflite.TensorType.FLOAT32: np.dtype("<f4"),
-    tflite.TensorType.FLOAT64: np.dtype("<f8"),
-    tflite.TensorType.INT8: np.dtype("<i1"),
-    tflite.TensorType.INT16: np.dtype("<i2"),
-    tflite.TensorType.INT32: np.dtype("<i4"),
-    tflite.TensorType.INT64: np.dtype("<i8"),
-    tflite.TensorType.UINT8: np.dtype("<u1"),
-    tflite.TensorType.UINT16: np.dtype("<u2"),
-    tflite.TensorType.UINT32: np.dtype("<u4"),
-    tflite.TensorType.UINT64: np.dtype("<u8"),
-}
-
-
-class _Tensor:
-
-  def __init__(self, tensor_t: tflite.TensorT, index, subgraph):
-    self._tensor_t = tensor_t
-    self.index = index
-    self.subgraph = subgraph
-
-  @property
-  def name(self):
-    return self._tensor_t.name.decode("utf-8")
-
-  @property
-  def shape(self):
-    return self._tensor_t.shape
-
-  @property
-  def buffer(self) -> _Buffer:
-    return self.subgraph.model.buffers[self._tensor_t.buffer]
-
-  @property
-  def data(self) -> bytes:
-    return self.buffer.data
-
-  @property
-  def dtype(self) -> np.dtype:
-    return _NP_DTYPES[self._tensor_t.type]
-
-  @property
-  def array(self) -> np.ndarray:
-    """Returns an array created from the Tensor's data, type, and shape.
-
-    Note the bytes in the data buffer and the Tensor's type and shape may be
-    inconsistent, and thus the returned array invalid, if the data buffer has
-    been altered according to the compression schema, in which the data buffer
-    is an array of fixed-width, integer fields.
- """ - return np.frombuffer(self.data, - dtype=self.dtype).reshape(self._tensor_t.shape) - - @property - def quantization(self) -> tflite.QuantizationParametersT | None: - return self._tensor_t.quantization - - -class _Buffer: - - def __init__(self, buffer_t: tflite.BufferT, index, model): - self._buffer_t = buffer_t - self.index = index - self.model = model - - @property - def data(self) -> bytes: - return bytes(self._buffer_t.data) - - @data.setter - def data(self, value: ByteString): - self._buffer_t.data = list(value) - - def extend(self, values: NDArray): - self._buffer_t.data.extend(values.tobytes()) - - -class _Subgraph: - - def __init__(self, subgraph_t: tflite.SubGraphT, index: int, model: _Model): - self._subgraph_t = subgraph_t - self.index = index - self.model = model - - @property - def operators(self) -> _Iterator[_Operator]: - return _Iterator(self._subgraph_t.operators, _Operator, parent=self) - - @property - def tensors(self) -> _Iterator[_Tensor]: - return _Iterator(self._subgraph_t.tensors, _Tensor, parent=self) - - -class _Model: - """A facade for manipulating tflite.Model. - """ - - def __init__(self, model_t: tflite.ModelT): - self._model_t = model_t - - def compile(self) -> bytearray: - """Returns a tflite.Model flatbuffer. - """ - size_hint = 4 * 2**10 - builder = flatbuffers.Builder(size_hint) - builder.Finish(self._model_t.Pack(builder)) - return builder.Output() - - def add_buffer(self) -> _Buffer: - """Adds a buffer to the model. - """ - buffer = tflite.BufferT() - buffer.data = [] - self._model_t.buffers.append(buffer) - index = len(self._model_t.buffers) - 1 - return _Buffer(buffer, index, self._model_t) - - def add_metadata(self, key, value): - """Adds a key-value pair, writing value to a newly created buffer. - """ - metadata = tflite.MetadataT() - metadata.name = key - buffer = self.add_buffer() - buffer.data = value - metadata.buffer = buffer.index - self._model_t.metadata.append(metadata) - - @property - def metadata(self) -> dict[str, _Buffer]: - """Returns the model's metadata as a dictionary to Buffer objects. - """ - result = {} - for m in self._model_t.metadata: - name = m.name.decode("utf-8") # type: ignore (fb library is wrong) - buffer = _Buffer(self._model_t.buffers[m.buffer], m.buffer, - self._model_t) - result[name] = buffer - - return result - - @property - def operatorCodes(self): - return self._model_t.operatorCodes - - @property - def subgraphs(self) -> _Iterator[_Subgraph]: - return _Iterator(self._model_t.subgraphs, _Subgraph, parent=self) - - @property - def buffers(self) -> _Iterator[_Buffer]: - return _Iterator(self._model_t.buffers, _Buffer, parent=self) - - -def read(buffer: ByteString): - """Reads a tflite.Model and returns a model facade. - """ - schema_model = tflite.ModelT.InitFromPackedBuf(buffer, 0) - return _Model(schema_model) diff --git a/tensorflow/lite/micro/compression/model_facade_test.py b/tensorflow/lite/micro/compression/model_facade_test.py deleted file mode 100644 index e931e578f2b..00000000000 --- a/tensorflow/lite/micro/compression/model_facade_test.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import tensorflow as tf -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -from tflite_micro.tensorflow.lite.micro.compression import model_facade -from tflite_micro.tensorflow.lite.micro.compression import test_models - -TEST_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.FULLY_CONNECTED, - }, - 1: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - 1: { - "name": "metadata1", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 1, # ADD - "inputs": ( - 1, - 2, - ), - "outputs": (3, ), - }, - 1: { - "opcode_index": 0, # FULLY_CONNECTED - "inputs": ( - 3, - 4, - 5, - ), - "outputs": (6, ), - }, - }, - "tensors": { - 0: { - "name": "tensor0", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 1: { - "name": "tensor1", - "shape": (8, 1), - "type": tflite.TensorType.INT16, - "buffer": 2, - }, - 2: { - "name": "tensor2", - "shape": (4, 1), - "type": tflite.TensorType.INT32, - "buffer": 3, - }, - 3: { - "name": "tensor3", - "shape": (2, 1), - "type": tflite.TensorType.INT64, - "buffer": 4, - }, - }, - }, - }, - "buffers": { - 0: None, - 1: np.array(range(16), dtype=np.dtype(" Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 04/21] refactor(compression): replace test_models with model_editor in compress_test Replace dictionary-based test_models.build() with model_editor's declarative API. Add _build_test_model() function that uses model_editor to create the same test model more cleanly. 
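For illustration, one dictionary tensor entry from the old TEST_MODEL (shape
(16, 1), type UINT8, buffer 1, per-axis quantization) becomes an inline
declaration. This is only a sketch using the model_editor Tensor and
Quantization classes; the surrounding subgraph construction is omitted:

    import numpy as np
    from tflite_micro.tensorflow.lite.micro.compression import model_editor
    from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite

    tensor0 = model_editor.Tensor(
        shape=(16, 1),
        dtype=tflite.TensorType.UINT8,
        data=np.arange(16, dtype=np.uint8),  # auto-creates a Buffer
        quantization=model_editor.Quantization(scales=1.0,
                                               zero_points=0,
                                               axis=1),
    )

Buffer creation and index assignment happen automatically at build time,
which is what eliminates the parallel "buffers" dictionary.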
--- tensorflow/lite/micro/compression/BUILD | 1 - .../lite/micro/compression/compress_test.py | 218 ++++++------------ 2 files changed, 66 insertions(+), 153 deletions(-) diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index ec3c903ce22..e70c152e469 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -151,7 +151,6 @@ py_test( ":metadata_py", ":model_editor", ":spec", - ":test_models", "//tensorflow/lite/python:schema_py", requirement("bitarray"), requirement("numpy"), diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index e40a2969df3..41a40c8ae47 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -21,7 +21,6 @@ from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec -from tflite_micro.tensorflow.lite.micro.compression import test_models from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite @@ -170,153 +169,70 @@ def test_multiple_tables_with_padding(self): self.assertEqual(result, expected_output) -# yapf: disable -TEST_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 0, - "inputs": ( - 0, - 1, - ), - "outputs": (2, ), - }, - }, - "tensors": { - 0: { - "shape": (16, 1), - "type": tflite.TensorType.UINT8, - "buffer": 1, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 1: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 2, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 2: { - "shape": (16, 1), - "type": tflite.TensorType.INT16, - "buffer": 3, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 3: { - "shape": (16, 1), - "type": tflite.TensorType.INT32, - "buffer": 4, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 4: { - "shape": (16, 1), - "type": tflite.TensorType.INT32, - "buffer": 5, - "quantization": { - "quantized_dimension": 1, - "scale": (1,), - "zero_point": (0,), - }, - }, - 5: { - "shape": (4, 5), - "type": tflite.TensorType.INT16, - "buffer": 6, - "quantization": { - "quantized_dimension": 1, - "scale": (1, 1, 1, 1, 1), - "zero_point": (0, 0, 0, 0, 0), - }, - }, - 6: { - "shape": (5, 4), - "type": tflite.TensorType.INT16, - "buffer": 7, - "quantization": { - "quantized_dimension": 0, - "scale": (1, 1, 1, 1, 1), - "zero_point": (0, 0, 0, 0, 0), - }, - }, - 7: { - "shape": (5, 4), - "type": tflite.TensorType.INT16, - "buffer": 8, - "quantization": { - "quantized_dimension": 0, - "scale": (1,), - "zero_point": (0,), - }, - }, - 8: { - "shape": (16, 1), - "type": tflite.TensorType.UINT8, - "buffer": 9, - }, - }, - }, - }, - "buffers": { - 0: None, - - 1: np.array(range(16), dtype=np.dtype(" Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 05/21] chore(compression): remove test_models.py and test_models_test.py Remove test_models module and its tests as they are superseded by model_editor. 
--- tensorflow/lite/micro/compression/BUILD | 21 -- .../lite/micro/compression/test_models.py | 190 ------------------ .../micro/compression/test_models_test.py | 32 --- 3 files changed, 243 deletions(-) delete mode 100644 tensorflow/lite/micro/compression/test_models.py delete mode 100644 tensorflow/lite/micro/compression/test_models_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index e70c152e469..936c09a9134 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -195,27 +195,6 @@ py_test( ], ) -py_library( - name = "test_models", - srcs = ["test_models.py"], - deps = [ - "//tensorflow/lite/python:schema_py", - requirement("flatbuffers"), - requirement("numpy"), - ], -) - -py_test( - name = "test_models_test", - size = "small", - srcs = ["test_models_test.py"], - deps = [ - ":test_models", - "//tensorflow/lite/python:schema_py", - requirement("tensorflow"), - ], -) - py_library( name = "model_editor", srcs = ["model_editor.py"], diff --git a/tensorflow/lite/micro/compression/test_models.py b/tensorflow/lite/micro/compression/test_models.py deleted file mode 100644 index 80286d17359..00000000000 --- a/tensorflow/lite/micro/compression/test_models.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -"""Tools for constructing flatbuffers for testing. - -This module provides tools for constructing .tflite flatbuffers from a Python -dictionary representation of a model, a prototype of which can be found in -EXAMPLE_MODEL. - -Example usage: - model_definition = {...} # use EXAMPLE_MODEL as prototype - flatbuffer: bytearray = test_models.build(model_definition) -""" - -# This module must remain low-level and independent from any helpers in this -# project which make constructing model and flatbuffers easier, because this -# module is used to define tests for those helpers. 
- -import flatbuffers -import numpy as np -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite - -EXAMPLE_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.FULLY_CONNECTED, - }, - 1: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "metadata": { - 0: { - "name": "metadata0", - "buffer": 0 - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 1, - "inputs": ( - 0, - 1, - ), - "outputs": (3, ), - }, - 1: { - "opcode_index": 0, - "inputs": ( - 3, - 2, - ), - "outputs": (4, ), - }, - }, - "tensors": { - 0: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 1: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 2: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 3: { - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - "quantization": { - "quantized_dimension": 0, - }, - }, - }, - }, - }, - "buffers": { - 0: None, - 1: np.array(range(16), dtype=np.dtype(" bytearray: - """Builds a .tflite flatbuffer from a model definition. - - Args: - model_definition: A dictionary representation of the model, a prototype of - which can be found in the EXAMPLE_MODEL attribute of this module. - - Returns: - A tflite flatbuffer. - """ - root = tflite.ModelT() - description = model_definition.get("description") - if description is not None: - root.description = description - - root.operatorCodes = [] - for id, operator_code in model_definition["operator_codes"].items(): - assert id == len(root.operatorCodes) - opcode_t = tflite.OperatorCodeT() - root.operatorCodes.append(opcode_t) - opcode_t.builtinCode = operator_code["builtin_code"] - - root.metadata = [] - if "metadata" in model_definition: - for _, metadata in model_definition["metadata"].items(): - metadata_t = tflite.MetadataT() - metadata_t.name = metadata["name"] - metadata_t.buffer = metadata["buffer"] - root.metadata.append(metadata_t) - - root.subgraphs = [] - for id, subgraph in model_definition["subgraphs"].items(): - assert id == len(root.subgraphs) - subgraph_t = tflite.SubGraphT() - root.subgraphs.append(subgraph_t) - - subgraph_t.operators = [] - for id, operator in subgraph["operators"].items(): - assert id == len(subgraph_t.operators) - operator_t = tflite.OperatorT() - operator_t.opcodeIndex = operator["opcode_index"] - operator_t.inputs = operator["inputs"] - operator_t.outputs = operator["outputs"] - subgraph_t.operators.append(operator_t) - - subgraph_t.tensors = [] - for id, tensor in subgraph["tensors"].items(): - assert id == len(subgraph_t.tensors) - tensor_t = tflite.TensorT() - tensor_t.name = tensor.get("name", None) - tensor_t.shape = tensor["shape"] - tensor_t.type = tensor["type"] - tensor_t.buffer = tensor["buffer"] - - if "quantization" in tensor: - tensor_t.quantization = tflite.QuantizationParametersT() - tensor_t.quantization.quantizedDimension = \ - tensor["quantization"].get("quantized_dimension", None) - tensor_t.quantization.scale = \ - tensor["quantization"].get("scale", None) - tensor_t.quantization.zeroPoint = \ - tensor["quantization"].get("zero_point", None) - - subgraph_t.tensors.append(tensor_t) - - root.buffers = [] - for id, data in model_definition["buffers"].items(): - assert id == len(root.buffers) - buffer_t = tflite.BufferT() - - if data is None: - buffer_t.data = [] - elif isinstance(data, np.ndarray): - array = data.astype(data.dtype.newbyteorder("<")) # ensure little-endian - buffer_t.data = list(array.tobytes()) - else: 
- raise TypeError(f"buffer_id {id} must be None or an np.ndarray") - - root.buffers.append(buffer_t) - - size_hint = 1 * 2**20 - builder = flatbuffers.Builder(size_hint) - builder.Finish(root.Pack(builder)) - flatbuffer = builder.Output() - return flatbuffer diff --git a/tensorflow/lite/micro/compression/test_models_test.py b/tensorflow/lite/micro/compression/test_models_test.py deleted file mode 100644 index 854526f2118..00000000000 --- a/tensorflow/lite/micro/compression/test_models_test.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tensorflow as tf -from tflite_micro.tensorflow.lite.micro.compression import test_models -from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite - - -class TestBuild(tf.test.TestCase): - - def setUp(self): - self.flatbuffer = test_models.build(test_models.EXAMPLE_MODEL) - - def testNotDegenerate(self): - model = tflite.ModelT.InitFromPackedBuf(self.flatbuffer, 0) - self.assertEqual(model.operatorCodes[0].builtinCode, - tflite.BuiltinOperator.FULLY_CONNECTED) - - -if __name__ == "__main__": - tf.test.main() From f55c78081e755446cad82712aa61f7d361823e15 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 06/21] feat(compression): add DecodeType class for type-safe decode operations Add DecodeType class to replace raw integer decode_type field with named constants and factory methods. Provides predefined constants for built-in types (LUT, HUFFMAN, PRUNING) and a factory method for custom types (128-255). Custom types are automatically named with CUSTOM_{code} prefix for clarity in debugging. The class supports serialization via __int__() and comparison with both DecodeType objects and integers. Update DecodeCommonMetadata to use DecodeType and update tests to use named constants. Add decode module BUILD targets. 
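A short usage sketch of the API added in decode.py:

    from tflite_micro.tensorflow.lite.micro.compression import decode

    lut = decode.DecodeType.LUT              # built-in, code 0
    custom = decode.DecodeType.custom(200)   # auto-named CUSTOM_200

    assert int(lut) == 0      # serialization via __int__()
    assert lut == 0           # compares equal to plain integers
    assert custom.is_custom and custom.name == "CUSTOM_200"

    # DecodeCommonMetadata serializes to the 16-byte DCM header.
    header = decode.DecodeCommonMetadata(decode_type=lut).to_bytes()
    assert len(header) == 16 and header[0] == 0 and header[1] == 1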
--- tensorflow/lite/micro/compression/BUILD | 15 ++ tensorflow/lite/micro/compression/decode.py | 240 ++++++++++++++++++ .../lite/micro/compression/decode_test.py | 155 +++++++++++ 3 files changed, 410 insertions(+) create mode 100644 tensorflow/lite/micro/compression/decode.py create mode 100644 tensorflow/lite/micro/compression/decode_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 936c09a9134..72bd1e25d6d 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -158,6 +158,21 @@ py_test( ], ) +py_library( + name = "decode", + srcs = ["decode.py"], +) + +py_test( + name = "decode_test", + size = "small", + srcs = ["decode_test.py"], + deps = [ + ":decode", + requirement("tensorflow"), + ], +) + py_library( name = "spec", srcs = ["spec.py"], diff --git a/tensorflow/lite/micro/compression/decode.py b/tensorflow/lite/micro/compression/decode.py new file mode 100644 index 00000000000..ac6943a856b --- /dev/null +++ b/tensorflow/lite/micro/compression/decode.py @@ -0,0 +1,240 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DECODE compression module.""" + +# Implements the DECODE operator compression scheme described in the +# "TFLM DECODE Operator Design" document, revised May 20, 2025. +# +# The DECODE operator transforms an encoded tensor, alongside a paired +# ancillary data tensor, into a tensor ready for use as input to any +# operator. For example, an encoded tensor might contain compressed +# data, while the paired ancillary data tensor holds the information +# necessary for decompression. The DECODE operator's output is a fully +# decompressed tensor. +# +# DECODE operators are inserted into the TfLite model subgraph +# immediately before each operation that uses a decodable tensor as +# input. +# +# Ancillary Data Tensor +# +# The ancillary data tensor contains the information necessary for +# decoding. It begins with a 16-byte DECODE Common Metadata (DCM) +# header, followed by decode-type-specific ancillary data. +# +# DECODE Common Metadata (DCM) +# +# Byte 0: Decode type +# 0-127: TFLM-supported decode operations (see below) +# 128-255: Custom operations requiring application-registered +# handlers +# +# Supported decode types: +# +# 0: LUT decompression +# All TFLM tensor types supported in reference and optimized +# code. +# +# 1: Huffman decompression using Xtensa format decode tables +# INT8 and INT16 tensor types only, in reference and optimized +# code. +# +# 2: Pruning decompression +# All TFLM tensor types supported in reference and optimized +# code. +# +# 3-127: Reserved +# +# 128-255: Custom decode types +# Requires user-supplied encoding module and decoding ancillary +# data. +# +# Byte 1: DCM version (currently 1) +# +# Bytes 2-3: Reserved +# +# Bytes 4-15: User-defined +# Used by TFLM decode types to avoid requiring additional alignment +# of metadata or ancillary data. 
+# +# The 16-byte DCM size ensures that subsequent metadata and ancillary +# data are 128-bit aligned, which is required for some optimized +# decoding operations such as Xtensa LUT decompression. +# +# For TFLM decode types, ancillary data starts immediately after the +# DCM. For custom decode types, the location is determined by +# user-defined metadata. + +from dataclasses import dataclass +from typing import Protocol + + +class DecodeType: + """Decode operation type (0-255). + + Use predefined constants for built-in types or DecodeType.custom() + for custom types: + DecodeType.LUT # 0 + DecodeType.HUFFMAN # 1 + DecodeType.PRUNING # 2 + DecodeType.custom(200) # Custom type 128-255 + """ + + # Built-in decode types (class variables set after class definition) + LUT: 'DecodeType' + HUFFMAN: 'DecodeType' + PRUNING: 'DecodeType' + + def __init__(self, code: int, name: str = None): + """Initialize DecodeType. + + Args: + code: Integer code 0-255 + name: Optional name for the type. If not provided: + - Codes 0-127: Named "TYPE_{code}" + - Codes 128-255: Named "CUSTOM_{code}" + """ + if not 0 <= code <= 255: + raise ValueError(f"Decode type must be 0-255, got {code}") + self.code = code + + # Auto-generate name if not provided + if name is None: + self.name = f"CUSTOM_{code}" if code >= 128 else f"TYPE_{code}" + else: + self.name = name + + self._is_custom = code >= 128 + + @property + def is_custom(self) -> bool: + """True if this is a custom decode type (128-255).""" + return self._is_custom + + @classmethod + def custom(cls, code: int) -> 'DecodeType': + """Create custom decode type (128-255). + + Args: + code: Integer code 128-255 + + Returns: + DecodeType with name CUSTOM_{code} + """ + if not 128 <= code <= 255: + raise ValueError(f"Custom decode type must be 128-255, got {code}") + return cls(code) + + def __int__(self): + """Convert to integer for serialization.""" + return self.code + + def __eq__(self, other): + if isinstance(other, DecodeType): + return self.code == other.code + return self.code == other + + def __repr__(self): + return f"DecodeType.{self.name}({self.code})" + + +# Define built-in decode type constants +DecodeType.LUT = DecodeType(0, "LUT") +DecodeType.HUFFMAN = DecodeType(1, "HUFFMAN") +DecodeType.PRUNING = DecodeType(2, "PRUNING") + + +@dataclass +class DecodeCommonMetadata: + """16-byte DECODE Common Metadata (DCM) header. + + Attributes: + decode_type: Decode operation type. Use DecodeType constants or + DecodeType.custom(code) for custom types. + version: DCM version (currently 1). + user_data: 12 bytes of user-defined data (bytes 4-15 of DCM). Used by TFLM + decode types to avoid requiring additional alignment of metadata + or ancillary data. + """ + decode_type: DecodeType + version: int = 1 + user_data: bytes = b'\x00' * 12 + + def to_bytes(self) -> bytes: + """Serialize DCM to 16-byte sequence.""" + decode_code = int(self.decode_type) + if not 0 <= self.version <= 255: + raise ValueError(f"version must be 0-255, got {self.version}") + if len(self.user_data) < 12: + # Pad with zeros if user_data is too short + user_data = self.user_data + b'\x00' * (12 - len(self.user_data)) + else: + user_data = self.user_data[:12] + + result = bytearray(16) + result[0] = decode_code + result[1] = self.version + # bytes 2-3 remain zero (reserved) + result[4:16] = user_data + return bytes(result) + + +class AncillaryDataSerializer(Protocol): + """Protocol for objects that can serialize ancillary data.""" + + def to_bytes(self) -> bytes: + ... 
+ + +@dataclass +class AncillaryDataTensor: + """Complete Ancillary Data Tensor (ADT): DCM + decode-type-specific data. + + The ADT is stored as a buffer in the TFLite model. It begins with a 16-byte + DCM header, followed by decode-type-specific ancillary data. + + Attributes: + dcm: The DECODE Common Metadata header. + ancillary_data: The decode-type-specific ancillary data, either as raw bytes + or as an object implementing the AncillaryDataSerializer + protocol. May be None if only the DCM is needed. + """ + dcm: DecodeCommonMetadata + ancillary_data: AncillaryDataSerializer | bytes | None = None + + def with_ancillary_data( + self, data: AncillaryDataSerializer | bytes) -> 'AncillaryDataTensor': + """Create new ADT with ancillary data added. + + Args: + data: Ancillary data to add, either as raw bytes or as an object + implementing AncillaryDataSerializer. + + Returns: + New AncillaryDataTensor with the specified ancillary data. + """ + return AncillaryDataTensor(self.dcm, data) + + def to_bytes(self) -> bytes: + """Serialize entire ADT to bytes. + + Returns: + Byte sequence containing DCM followed by ancillary data (if present). + """ + dcm_bytes = self.dcm.to_bytes() + if self.ancillary_data is None: + return dcm_bytes + if isinstance(self.ancillary_data, bytes): + return dcm_bytes + self.ancillary_data + return dcm_bytes + self.ancillary_data.to_bytes() diff --git a/tensorflow/lite/micro/compression/decode_test.py b/tensorflow/lite/micro/compression/decode_test.py new file mode 100644 index 00000000000..1882a667003 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_test.py @@ -0,0 +1,155 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tensorflow as tf + +from tflite_micro.tensorflow.lite.micro.compression import decode + + +class TestDecodeCommonMetadata(tf.test.TestCase): + + def testBasicSerialization(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT) + result = dcm.to_bytes() + + # Should be exactly 16 bytes + self.assertEqual(len(result), 16) + + # Byte 0: decode_type + self.assertEqual(result[0], 0) + + # Byte 1: version (default 1) + self.assertEqual(result[1], 1) + + # Bytes 2-3: reserved (should be zero) + self.assertEqual(result[2], 0) + self.assertEqual(result[3], 0) + + # Bytes 4-15: user_data (default all zeros) + self.assertEqual(result[4:16], b'\x00' * 12) + + def testCustomVersion(self): + dcm = decode.DecodeCommonMetadata(decode_type=1, version=2) + result = dcm.to_bytes() + + self.assertEqual(result[0], 1) + self.assertEqual(result[1], 2) + + def testUserData(self): + user_data = b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + self.assertEqual(result[4:16], user_data) + + def testUserDataPadding(self): + # User data shorter than 12 bytes should be padded with zeros + user_data = b'\x01\x02\x03' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + expected = b'\x01\x02\x03' + b'\x00' * 9 + self.assertEqual(result[4:16], expected) + + def testUserDataTruncation(self): + # User data longer than 12 bytes should be truncated + user_data = b'\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f' + dcm = decode.DecodeCommonMetadata(decode_type=0, user_data=user_data) + result = dcm.to_bytes() + + self.assertEqual(result[4:16], user_data[:12]) + + def testDecodeTypeRange(self): + # Valid decode types: 0-255 + decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT).to_bytes() + decode.DecodeCommonMetadata(decode_type=decode.DecodeType(127)).to_bytes() + decode.DecodeCommonMetadata( + decode_type=decode.DecodeType.custom(255)).to_bytes() + + # Invalid decode types should raise ValueError + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=decode.DecodeType(-1)).to_bytes() + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata( + decode_type=decode.DecodeType(256)).to_bytes() + + def testVersionRange(self): + # Valid versions: 0-255 + decode.DecodeCommonMetadata(decode_type=0, version=0).to_bytes() + decode.DecodeCommonMetadata(decode_type=0, version=255).to_bytes() + + # Invalid versions should raise ValueError + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=0, version=-1).to_bytes() + with self.assertRaises(ValueError): + decode.DecodeCommonMetadata(decode_type=0, version=256).to_bytes() + + +class TestAncillaryDataTensor(tf.test.TestCase): + + def testDcmOnly(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.LUT) + adt = decode.AncillaryDataTensor(dcm) + result = adt.to_bytes() + + # Should be just the 16-byte DCM + self.assertEqual(len(result), 16) + self.assertEqual(result, dcm.to_bytes()) + + def testWithBytesAncillaryData(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.HUFFMAN) + ancillary = b'\xaa\xbb\xcc\xdd' + adt = decode.AncillaryDataTensor(dcm, ancillary) + result = adt.to_bytes() + + # Should be DCM + ancillary data + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], ancillary) + + def 
testWithAncillaryDataMethod(self): + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType.PRUNING) + adt = decode.AncillaryDataTensor(dcm) + + ancillary = b'\x11\x22\x33\x44' + adt_with_data = adt.with_ancillary_data(ancillary) + result = adt_with_data.to_bytes() + + # Original ADT should be unchanged + self.assertEqual(adt.to_bytes(), dcm.to_bytes()) + + # New ADT should have ancillary data + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], ancillary) + + def testWithSerializerProtocol(self): + # Test with an object that implements AncillaryDataSerializer + class MockSerializer: + + def to_bytes(self): + return b'\xff\xee\xdd\xcc' + + dcm = decode.DecodeCommonMetadata(decode_type=decode.DecodeType(3)) + serializer = MockSerializer() + adt = decode.AncillaryDataTensor(dcm, serializer) + result = adt.to_bytes() + + self.assertEqual(len(result), 20) + self.assertEqual(result[:16], dcm.to_bytes()) + self.assertEqual(result[16:], b'\xff\xee\xdd\xcc') + + +if __name__ == '__main__': + tf.test.main() From 62e28055483192f51268f14342a8a2c14676ffa2 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 07/21] docs(compression): add docstring to LookUpTableCompression --- tensorflow/lite/micro/compression/spec.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/lite/micro/compression/spec.py b/tensorflow/lite/micro/compression/spec.py index 6f782e92d7a..e33f3ba575e 100644 --- a/tensorflow/lite/micro/compression/spec.py +++ b/tensorflow/lite/micro/compression/spec.py @@ -58,7 +58,11 @@ class Tensor: @dataclass class LookUpTableCompression(CompressionMethod): + """LUT compression using lookup tables. + Attributes: + index_bitwidth: Number of bits per index (1-7). + """ index_bitwidth: int From e80725dc6636efd561ce91ec4c97d98d6ab1e393 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 08/21] feat(compression): add HuffmanCompression and PruningCompression types Add dataclass placeholders for future compression methods. These will be used by the plugin architecture to dispatch to compression-specific implementations. --- tensorflow/lite/micro/compression/spec.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tensorflow/lite/micro/compression/spec.py b/tensorflow/lite/micro/compression/spec.py index e33f3ba575e..c5e5ecc5bed 100644 --- a/tensorflow/lite/micro/compression/spec.py +++ b/tensorflow/lite/micro/compression/spec.py @@ -66,6 +66,24 @@ class LookUpTableCompression(CompressionMethod): index_bitwidth: int +@dataclass +class HuffmanCompression(CompressionMethod): + """Huffman compression using Xtensa-format decode tables. + + Supported tensor types: INT8, INT16 only. + """ + pass + + +@dataclass +class PruningCompression(CompressionMethod): + """Pruning (sparsity) compression. + + Supported tensor types: All TFLM tensor types. + """ + pass + + class ParseError(Exception): "Raised when the spec string cannot be parsed." From 5b2072f92e1565c09b07d884f5fd9a0b676ed5ed Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 09/21] refactor(compression): extract _parse_compression_method helper Factor out compression method parsing into a dedicated function that dispatches on the YAML key. This enables parse_yaml to iterate over multiple compression methods per tensor and makes adding new compression types straightforward. 
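After this change, a spec may list several methods for one tensor. A sketch
of the YAML layout consumed by _parse_compression_method and the resulting
objects:

    from tflite_micro.tensorflow.lite.micro.compression import spec

    y = '''
    tensors:
      - subgraph: 0
        tensor: 2
        compression:
          - lut:
              index_bitwidth: 4
          - huffman:
    '''
    methods = spec.parse_yaml(y)[0].compression
    assert isinstance(methods[0], spec.LookUpTableCompression)
    assert methods[0].index_bitwidth == 4
    assert isinstance(methods[1], spec.HuffmanCompression)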
--- tensorflow/lite/micro/compression/spec.py | 29 ++++++++++++++++++----- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/tensorflow/lite/micro/compression/spec.py b/tensorflow/lite/micro/compression/spec.py index c5e5ecc5bed..5c0f81885bc 100644 --- a/tensorflow/lite/micro/compression/spec.py +++ b/tensorflow/lite/micro/compression/spec.py @@ -92,6 +92,18 @@ def __init__(self, message="error parsing spec", wrapped_exception=None): self.original_exception = wrapped_exception +def _parse_compression_method(comp: dict) -> CompressionMethod: + """Parse a single compression method from YAML dict.""" + if "lut" in comp: + return LookUpTableCompression(index_bitwidth=comp["lut"]["index_bitwidth"]) + elif "huffman" in comp: + return HuffmanCompression() + elif "pruning" in comp: + return PruningCompression() + else: + raise ParseError(f"Unknown compression method: {list(comp.keys())}") + + def parse_yaml(y: str) -> list[Tensor]: "Parses a compression spec in a YAML string into its Python representation." try: @@ -99,14 +111,19 @@ def parse_yaml(y: str) -> list[Tensor]: tensors = [] for item in config["tensors"]: - bitwidth = item["compression"][0]["lut"]["index_bitwidth"] - tensor = Tensor(subgraph=item["subgraph"], - tensor=item["tensor"], - compression=[ - LookUpTableCompression(index_bitwidth=bitwidth), - ]) + methods = [] + for comp in item["compression"]: + methods.append(_parse_compression_method(comp)) + + tensor = Tensor( + subgraph=item["subgraph"], + tensor=item["tensor"], + compression=methods, + ) tensors.append(tensor) + except ParseError: + raise except Exception as e: raise ParseError() from e From bda4f77bccc6df5b54f141cef1a8254388885b8e Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 10/21] feat(compression): add Compressor protocol and CompressionResult Define the plugin interface for compression methods. Each compressor implements the Compressor protocol with a compress() method that returns encoded data and ancillary data. CompressionError provides a common exception type for compression failures. --- tensorflow/lite/micro/compression/BUILD | 10 +++ .../lite/micro/compression/compressor.py | 80 +++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 tensorflow/lite/micro/compression/compressor.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 72bd1e25d6d..5632bb04932 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -158,6 +158,16 @@ py_test( ], ) +py_library( + name = "compressor", + srcs = ["compressor.py"], + deps = [ + ":decode", + ":model_editor", + ":spec", + ], +) + py_library( name = "decode", srcs = ["decode.py"], diff --git a/tensorflow/lite/micro/compression/compressor.py b/tensorflow/lite/micro/compression/compressor.py new file mode 100644 index 00000000000..5c8f7e91222 --- /dev/null +++ b/tensorflow/lite/micro/compression/compressor.py @@ -0,0 +1,80 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Compression plugin interface.""" + +from dataclasses import dataclass +from typing import Protocol + +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class CompressionError(Exception): + """Raised when compression fails for the reason documented in the message.""" + + def __init__(self, message, wrapped_exception=None): + if wrapped_exception: + super().__init__(f"{message}: {str(wrapped_exception)}") + else: + super().__init__(message) + self.original_exception = wrapped_exception + + +@dataclass +class CompressionResult: + """Result of compressing a tensor. + + Attributes: + encoded_data: The compressed tensor data (e.g., packed indices for LUT). + ancillary_data: The complete ancillary data tensor bytes (DCM + type-specific + data). This is the full buffer contents for the ancillary + tensor. + """ + encoded_data: bytes + ancillary_data: bytes + + +class Compressor(Protocol): + """Protocol that compression plugins must implement. + + Each compression method (LUT, Huffman, Pruning) provides a class implementing + this protocol. The compress() function uses duck typing to call the plugin. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """The DecodeType constant for this compression method.""" + ... + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> CompressionResult: + """Compress a tensor according to the specified method. + + Args: + tensor: The tensor to compress. Must have data (tensor.array is not None) + and quantization parameters for axis inference. + method: The compression method spec (e.g., LookUpTableCompression). + + Returns: + CompressionResult with encoded tensor data and ancillary data bytes. + + Raises: + CompressionError: If compression fails (e.g., too many unique values + for specified bitwidth, missing quantization, etc.). + """ + ... From 69e8e5099ccac95495527a7e6c6935522d52d78f Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 11/21] feat(compression): add LUT compression plugin Extract LUT compression logic from compress.py into a dedicated plugin module. The LutCompressor class implements the Compressor protocol, producing packed indices and ancillary data in the format expected by the C++ DECODE kernel. 
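The core transformation in miniature; the expected bytes below are worked
out by hand from the big-endian packing rule in pack_indices:

    import numpy as np
    from tflite_micro.tensorflow.lite.micro.compression import lut

    array = np.array([3, 7, 3, 7, 9, 9], dtype=np.int8)
    c = lut.compress_array(array, axis=None)
    # One value table for the whole tensor: [3, 7, 9];
    # indices are [0, 1, 0, 1, 2, 2] and map back to the original values.
    assert (c.lookup_tables[0][c.indices] == array).all()
    assert c.index_bitwidth == 2

    packed = lut.pack_indices(c.indices, bitwidth=2)
    # Big-endian 2-bit fields: 00 01 00 01 | 10 10 (zero-padded) -> 0x11 0xA0
    assert packed == bytes([0x11, 0xA0])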
--- tensorflow/lite/micro/compression/BUILD | 34 ++ tensorflow/lite/micro/compression/lut.py | 317 +++++++++++++++ tensorflow/lite/micro/compression/lut_test.py | 370 ++++++++++++++++++ 3 files changed, 721 insertions(+) create mode 100644 tensorflow/lite/micro/compression/lut.py create mode 100644 tensorflow/lite/micro/compression/lut_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 5632bb04932..af7a857695e 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -168,6 +168,40 @@ py_library( ], ) +py_library( + name = "lut", + srcs = ["lut.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + requirement("bitarray"), + requirement("numpy"), + ], +) + +py_test( + name = "lut_test", + size = "small", + srcs = ["lut_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + deps = [ + ":compressor", + ":decode", + ":lut", + ":model_editor", + ":spec", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + requirement("tensorflow"), + ], +) + py_library( name = "decode", srcs = ["decode.py"], diff --git a/tensorflow/lite/micro/compression/lut.py b/tensorflow/lite/micro/compression/lut.py new file mode 100644 index 00000000000..79bf33f4eae --- /dev/null +++ b/tensorflow/lite/micro/compression/lut.py @@ -0,0 +1,317 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LUT (Look-Up Table) compression plugin.""" + +import sys +from dataclasses import dataclass, field +from typing import Optional + +import bitarray +import bitarray.util +import numpy as np + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +@dataclass +class LutCompressedArray: + """Intermediate representation of LUT-compressed data. + + Attributes: + compression_axis: The axis along which compression was performed, or None + for per-tensor compression. + lookup_tables: List of value lookup tables. One table for per-tensor + compression, or one per channel for per-channel compression. + indices: Array of indices into the lookup tables, same shape as original. + """ + compression_axis: Optional[int] = None + lookup_tables: list[np.ndarray] = field(default_factory=list) + indices: np.ndarray = field(default_factory=lambda: np.array([])) + + @property + def index_bitwidth(self) -> int: + """Returns the number of bits required to encode the indices.""" + if self.indices is None or self.indices.size == 0: + raise ValueError("No indices to compute bitwidth from") + max_index = int(np.max(self.indices)) + return max_index.bit_length() or 1 + + +@dataclass +class LutAncillaryData: + """LUT-specific ancillary data matching C++ decode_state_lut.cc format. 
+ + The LUT ancillary data uses the DCM user_data bytes (4-15) plus value tables: + - Byte 4: LUT version (currently 1) + - Byte 5: Params (lower 3 bits = bitwidth, 1-7) + - Byte 6: Value table channel stride (elements per channel) + - Bytes 7-15: Reserved (zeros) + - Bytes 16+: Value tables (concatenated, stride elements per channel) + + Attributes: + lut_version: LUT format version (currently 1). + bitwidth: Number of bits per index (1-7). + value_table_stride: Number of elements per channel in value tables. + value_tables: Packed value table data following the DCM. + """ + lut_version: int = 1 + bitwidth: int = 4 + value_table_stride: int = 16 + value_tables: bytes = b'' + + def __post_init__(self): + if not 1 <= self.bitwidth <= 7: + raise ValueError(f"bitwidth must be 1-7, got {self.bitwidth}") + if not 0 <= self.value_table_stride <= 128: + raise ValueError( + f"value_table_stride must be 0-128, got {self.value_table_stride}") + + def to_user_data(self) -> bytes: + """Serialize to 12-byte user_data for DCM bytes 4-15.""" + user_data = bytearray(12) + user_data[0] = self.lut_version + user_data[1] = self.bitwidth & 0x07 + user_data[2] = self.value_table_stride + # Bytes 3-11 (DCM bytes 7-15) remain zero (reserved) + return bytes(user_data) + + def to_bytes(self) -> bytes: + """Serialize for use as AncillaryDataTensor.ancillary_data.""" + # This returns the type-specific data that follows the DCM header. + # For LUT, that's just the value tables since user_data is in DCM. + return self.value_tables + + +def compress_array(tensor: np.ndarray, + axis: Optional[int]) -> LutCompressedArray: + """Compresses the given tensor using lookup tables. + + Args: + tensor: The tensor to be compressed. + axis: The axis along which to compress. If an axis is given, a lookup table + is created for each slice along the axis. If axis is None, a single + lookup table is used for the entire tensor. + + Compressing a tensor with a lookup table per slice along a particular + axis is analogous to quantizing a tensor with different quantization + parameters per slice along a particular axis (dimension). + + Returns: + LutCompressedArray containing lookup tables and indices. + """ + compressed = LutCompressedArray() + compressed.compression_axis = axis + + if axis is None: + # Compute unique values and indices for the entire tensor + values, indices = np.unique(tensor, return_inverse=True) + compressed.lookup_tables.append(values) + compressed.indices = indices.reshape(tensor.shape) + else: + # Iterate over slices along the compression axis + slice_indices = [] + for slice in np.moveaxis(tensor, axis, 0): + values, indices = np.unique(slice, return_inverse=True) + compressed.lookup_tables.append(values) + indices = indices.reshape(slice.shape) + slice_indices.append(indices) + + # Reconstruct a tensor of indices from the slices + stacked = np.stack(slice_indices, axis=0) + compressed.indices = np.moveaxis(stacked, 0, axis) + + return compressed + + +def identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: + """Determines the axis along which to compress. + + The axis along which to compress is inferred from the tensor's quantization + parameters. + + Args: + tensor: The tensor to analyze. + + Returns: + The axis along which to compress, or None to indicate one value table for + the entire tensor. + + Raises: + CompressionError: If the axis cannot be determined from quantization. 
+ """ + q = tensor.quantization + if q is not None: + # model_editor wraps quantization, access scales/axis from wrapper + scales = q.scales if isinstance(q.scales, list) else [q.scales] + quantization_channels = len(scales) + + if quantization_channels == 1: + # Use one value table for the entire tensor + return None + + if q.axis is not None and q.axis < len(tensor.shape): + if quantization_channels == tensor.shape[q.axis]: + return q.axis + + raise compressor.CompressionError( + "Invalid or no quantization parameters from which to " + "infer the axis along which tensor should be compressed.") + + +def check_bitwidth(compressed: int, specified: int, tensor_spec: spec.Tensor): + """Validates that the specified bitwidth is sufficient. + + It is an error if the bitwidth required to compress a tensor exceeds the + specified bitwith, and a warning if the tensor can be compressed in less than + the specified bitwidth. The latter is allowed, and is not an error, to permit + testing with larger bitwidths without re-binning a model. + + Args: + compressed: The bitwidth required by the compressed data. + specified: The bitwidth specified in the compression spec. + tensor_spec: The tensor spec, for error messages. + + Raises: + CompressionError: If specified bitwidth is too small. + """ + if compressed > specified: + raise compressor.CompressionError( + f"index_bitwidth too small: {compressed} bits needed to " + f"enumerate unique values in tensor specified in {tensor_spec}") + elif compressed < specified: + print( + f"warning: index_bitwidth too large: only {compressed} " + f"bits needed to enumerate unique values in tensor specified in " + f"{tensor_spec}", + file=sys.stderr) + + +def pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: + """Packs indices into a bytearray using bitwidth-sized fields. + + Args: + indices: Array of indices to pack. + bitwidth: Number of bits per index. + + Returns: + Packed bytes with indices in big-endian bit order. + """ + endianness = "big" + bits = bitarray.bitarray(endian=endianness) + for i in indices.ravel(): + bits.extend( + bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) + return bits.tobytes() + + +def pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytes: + """Packs the value tables of a LutCompressedArray. + + Pack the value tables of a LutCompressedArray into a bytes object in the + format writable to a value_table buffer in the .tflite flatbuffer. The + tables are concatenated. + + Args: + tables: List of numpy arrays containing lookup table values. + table_len: Length to pad each table to (typically 2**bitwidth). + + Returns: + Packed bytes containing all tables concatenated. + """ + buffer = bytearray() + for t in tables: + padding_needed = table_len - len(t) + padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) + buffer.extend(padded.tobytes()) + return bytes(buffer) + + +class LutCompressor: + """LUT compression plugin implementing the Compressor protocol.""" + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.LUT.""" + return decode.DecodeType.LUT + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using LUT compression. + + Args: + tensor: The tensor to compress. + method: Must be a LookUpTableCompression instance. + + Returns: + CompressionResult with packed indices and ancillary data. + + Raises: + CompressionError: If compression fails. 
+ """ + if not isinstance(method, spec.LookUpTableCompression): + raise compressor.CompressionError( + f"LutCompressor requires LookUpTableCompression, got {type(method)}") + + if tensor.array is None: + raise compressor.CompressionError("Tensor has no data to compress") + + spec_bitwidth = method.index_bitwidth + axis = identify_compression_axis(tensor) + compressed = compress_array(tensor.array, axis) + # Note: check_bitwidth requires a spec.Tensor but we don't have it here. + # We'll do a simpler check. + actual_bitwidth = compressed.index_bitwidth + if actual_bitwidth > spec_bitwidth: + raise compressor.CompressionError( + f"index_bitwidth too small: {actual_bitwidth} bits needed, " + f"but only {spec_bitwidth} specified") + elif actual_bitwidth < spec_bitwidth: + print( + f"warning: index_bitwidth larger than necessary: only " + f"{actual_bitwidth} bits needed, but {spec_bitwidth} specified", + file=sys.stderr) + + # Pack indices into bytes + encoded_data = pack_indices(compressed.indices, spec_bitwidth) + + # Pack value tables + table_len = 2**spec_bitwidth + value_tables_bytes = pack_lookup_tables(compressed.lookup_tables, + table_len) + + # Build ancillary data + lut_data = LutAncillaryData( + lut_version=1, + bitwidth=spec_bitwidth, + value_table_stride=table_len, + value_tables=value_tables_bytes, + ) + + # Build complete ancillary data tensor bytes: DCM header + value tables + dcm = decode.DecodeCommonMetadata( + decode_type=self.decode_type, + user_data=lut_data.to_user_data(), + ) + ancillary_data = dcm.to_bytes() + lut_data.to_bytes() + + return compressor.CompressionResult( + encoded_data=encoded_data, + ancillary_data=ancillary_data, + ) diff --git a/tensorflow/lite/micro/compression/lut_test.py b/tensorflow/lite/micro/compression/lut_test.py new file mode 100644 index 00000000000..39e2d3d8944 --- /dev/null +++ b/tensorflow/lite/micro/compression/lut_test.py @@ -0,0 +1,370 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Unit tests for LUT compression plugin.""" + +import numpy as np +import tensorflow as tf + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import lut +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +class TestCompressArray(tf.test.TestCase): + """Tests for the compress_array function.""" + + def test_per_tensor_basic(self): + """Per-tensor compression extracts unique values.""" + array = np.array([1, 2, 1, 2, 3, 3], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + + self.assertIsNone(compressed.compression_axis) + self.assertEqual(len(compressed.lookup_tables), 1) + self.assertAllEqual(compressed.lookup_tables[0], [1, 2, 3]) + # Indices should map back to original values + reconstructed = compressed.lookup_tables[0][compressed.indices] + self.assertAllEqual(reconstructed, array) + + def test_per_tensor_preserves_shape(self): + """Indices array has same shape as input.""" + # yapf: disable + array = np.array([[1, 2], + [3, 1], + [2, 3]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=None) + + self.assertEqual(compressed.indices.shape, array.shape) + + def test_per_channel_axis0(self): + """Per-channel compression along axis 0.""" + # Each row gets its own value table + # yapf: disable + array = np.array([[1, 1, 1], + [5, 5, 5], + [9, 9, 9]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=0) + + self.assertEqual(compressed.compression_axis, 0) + self.assertEqual(len(compressed.lookup_tables), 3) + self.assertAllEqual(compressed.lookup_tables[0], [1]) + self.assertAllEqual(compressed.lookup_tables[1], [5]) + self.assertAllEqual(compressed.lookup_tables[2], [9]) + + def test_per_channel_axis1(self): + """Per-channel compression along axis 1.""" + # Each column gets its own value table + # yapf: disable + array = np.array([[1, 5], + [1, 5], + [1, 5]], dtype=np.int8) + # yapf: enable + compressed = lut.compress_array(array, axis=1) + + self.assertEqual(compressed.compression_axis, 1) + self.assertEqual(len(compressed.lookup_tables), 2) + self.assertAllEqual(compressed.lookup_tables[0], [1]) + self.assertAllEqual(compressed.lookup_tables[1], [5]) + + def test_single_value(self): + """Array with single unique value.""" + array = np.array([7, 7, 7, 7], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + + self.assertEqual(len(compressed.lookup_tables), 1) + self.assertAllEqual(compressed.lookup_tables[0], [7]) + self.assertAllEqual(compressed.indices, [0, 0, 0, 0]) + + def test_bitwidth_calculation(self): + """Index bitwidth is computed correctly.""" + # 3 unique values -> 2 bits needed + array = np.array([0, 1, 2], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 2) + + # 4 unique values -> 2 bits needed + array = np.array([0, 1, 2, 3], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 2) + + # 5 unique values -> 3 bits needed + array = np.array([0, 1, 2, 3, 4], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 3) + + def test_bitwidth_single_value(self): + """Single unique value requires 1 bit.""" 
+ array = np.array([42, 42, 42], dtype=np.int8) + compressed = lut.compress_array(array, axis=None) + self.assertEqual(compressed.index_bitwidth, 1) + + +class TestPackIndices(tf.test.TestCase): + """Tests for the pack_indices function.""" + + def test_4bit_packing(self): + """Pack indices into 4-bit fields.""" + indices = np.array([1, 2, 3, 0]) + result = lut.pack_indices(indices, bitwidth=4) + # Big-endian: 0001 0010 | 0011 0000 = 0x12 0x30 + self.assertEqual(result, bytes([0x12, 0x30])) + + def test_2bit_packing(self): + """Pack indices into 2-bit fields.""" + indices = np.array([0, 1, 2, 3]) + result = lut.pack_indices(indices, bitwidth=2) + # Big-endian: 00 01 10 11 = 0x1B + self.assertEqual(result, bytes([0x1B])) + + def test_3bit_packing(self): + """Pack indices into 3-bit fields.""" + indices = np.array([0, 1, 2, 3, 4, 5, 6, 7]) + result = lut.pack_indices(indices, bitwidth=3) + # 000 001 010 011 | 100 101 110 111 + # 00000101 | 00111001 | 01110111 = 0x05 0x39 0x77 + self.assertEqual(result, bytes([0x05, 0x39, 0x77])) + + def test_1bit_packing(self): + """Pack indices into 1-bit fields.""" + indices = np.array([0, 1, 0, 1, 1, 0, 1, 0]) + result = lut.pack_indices(indices, bitwidth=1) + # 0 1 0 1 1 0 1 0 = 0x5A + self.assertEqual(result, bytes([0x5A])) + + def test_multidimensional_flattens(self): + """Multidimensional indices are flattened row-major.""" + # yapf: disable + indices = np.array([[0, 1], + [2, 3]]) + # yapf: enable + result = lut.pack_indices(indices, bitwidth=4) + # 0000 0001 | 0010 0011 = 0x01 0x23 + self.assertEqual(result, bytes([0x01, 0x23])) + + +class TestPackLookupTables(tf.test.TestCase): + """Tests for the pack_lookup_tables function.""" + + def test_single_table_int8(self): + """Pack single INT8 lookup table.""" + tables = [np.array([10, 20, 30], dtype=np.int8)] + result = lut.pack_lookup_tables(tables, table_len=4) + # Values: 10, 20, 30, 0 (padding) + self.assertEqual(result, bytes([10, 20, 30, 0])) + + def test_multiple_tables(self): + """Pack multiple lookup tables.""" + tables = [ + np.array([1, 2], dtype=np.int8), + np.array([3, 4], dtype=np.int8), + ] + result = lut.pack_lookup_tables(tables, table_len=4) + # Table 1: 1, 2, 0, 0 | Table 2: 3, 4, 0, 0 + self.assertEqual(result, bytes([1, 2, 0, 0, 3, 4, 0, 0])) + + def test_int16_little_endian(self): + """INT16 values are packed in native byte order.""" + tables = [np.array([0x1234, 0x5678], dtype=' Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 12/21] feat(compression): add Huffman and Pruning plugin stubs Add placeholder implementations that raise CompressionError when invoked. These validate the plugin architecture and will be replaced with working implementations later. 
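The stubs are callable, so plugin dispatch and error handling can be
exercised end to end. For example:

    from tflite_micro.tensorflow.lite.micro.compression import compressor
    from tflite_micro.tensorflow.lite.micro.compression import decode
    from tflite_micro.tensorflow.lite.micro.compression import huffman

    h = huffman.HuffmanCompressor()
    assert h.decode_type == decode.DecodeType.HUFFMAN

    try:
      h.compress(tensor=None, method=None)  # stub raises unconditionally
    except compressor.CompressionError as e:
      assert "not yet implemented" in str(e)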
--- tensorflow/lite/micro/compression/BUILD | 22 +++++++ tensorflow/lite/micro/compression/huffman.py | 60 ++++++++++++++++++++ tensorflow/lite/micro/compression/pruning.py | 59 +++++++++++++++++++ 3 files changed, 141 insertions(+) create mode 100644 tensorflow/lite/micro/compression/huffman.py create mode 100644 tensorflow/lite/micro/compression/pruning.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index af7a857695e..ddaf6891fcc 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -202,6 +202,28 @@ py_test( ], ) +py_library( + name = "huffman", + srcs = ["huffman.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + ], +) + +py_library( + name = "pruning", + srcs = ["pruning.py"], + deps = [ + ":compressor", + ":decode", + ":model_editor", + ":spec", + ], +) + py_library( name = "decode", srcs = ["decode.py"], diff --git a/tensorflow/lite/micro/compression/huffman.py b/tensorflow/lite/micro/compression/huffman.py new file mode 100644 index 00000000000..173101335f2 --- /dev/null +++ b/tensorflow/lite/micro/compression/huffman.py @@ -0,0 +1,60 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Huffman compression plugin (stub). + +This module provides a placeholder for Huffman compression using Xtensa-format +decode tables. The actual implementation is not yet available. + +Supported tensor types (when implemented): INT8, INT16 +""" + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class HuffmanCompressor: + """Huffman compression plugin (stub). + + This stub exists to validate the plugin architecture. The actual Huffman + compression algorithm using Xtensa-format decode tables is not yet + implemented. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.HUFFMAN.""" + return decode.DecodeType.HUFFMAN + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using Huffman encoding. + + Args: + tensor: The tensor to compress. + method: Must be a HuffmanCompression instance. + + Returns: + CompressionResult (not implemented). + + Raises: + CompressionError: Always, since this is a stub. + """ + raise compressor.CompressionError( + "Huffman compression not yet implemented. " + "This stub exists to validate the plugin architecture.") diff --git a/tensorflow/lite/micro/compression/pruning.py b/tensorflow/lite/micro/compression/pruning.py new file mode 100644 index 00000000000..98d81dc4ade --- /dev/null +++ b/tensorflow/lite/micro/compression/pruning.py @@ -0,0 +1,59 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pruning compression plugin (stub). + +This module provides a placeholder for pruning (sparsity) compression. +The actual implementation is not yet available. + +Supported tensor types (when implemented): All TFLM tensor types +""" + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec + + +class PruningCompressor: + """Pruning compression plugin (stub). + + This stub exists to validate the plugin architecture. The actual pruning + compression algorithm for sparse tensors is not yet implemented. + """ + + @property + def decode_type(self) -> decode.DecodeType: + """Returns DecodeType.PRUNING.""" + return decode.DecodeType.PRUNING + + def compress( + self, + tensor: model_editor.Tensor, + method: spec.CompressionMethod, + ) -> compressor.CompressionResult: + """Compress a tensor using pruning (sparsity) encoding. + + Args: + tensor: The tensor to compress. + method: Must be a PruningCompression instance. + + Returns: + CompressionResult (not implemented). + + Raises: + CompressionError: Always, since this is a stub. + """ + raise compressor.CompressionError( + "Pruning compression not yet implemented. " + "This stub exists to validate the plugin architecture.") From 84f671cc7f6c958dbcef772529e6a8b099a80040 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 13/21] feat(compression): add DECODE operator insertion logic Implement graph modification to insert DECODE operators before consumers of compressed tensors. Each compressed tensor gets a DECODE operator with two inputs (encoded tensor and ancillary data tensor) and one output (decompressed tensor). Consumer operators are rewired to use the DECODE output. 
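A sketch of the intended effect on a graph, using the API added below; here `model` is assumed to be a model_editor.Model whose subgraph 0, tensor 0 already holds encoded bytes, and `ancillary` is assumed to hold the DCM header plus decode-type-specific data:

from tflite_micro.tensorflow.lite.micro.compression import compressor
from tflite_micro.tensorflow.lite.micro.compression import decode_insert

results = {
    (0, 0): compressor.CompressionResult(encoded_data=b"\x00\x00",
                                         ancillary_data=ancillary),
}
decode_insert.insert_decode_operators(model, results)
# The subgraph now begins with a CUSTOM "TFLM_DECODE" operator whose inputs
# are [encoded_tensor, ancillary_tensor]; former consumers of the encoded
# tensor read the DECODE output instead.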
--- tensorflow/lite/micro/compression/BUILD | 30 ++ .../lite/micro/compression/decode_insert.py | 217 ++++++++++ .../micro/compression/decode_insert_test.py | 385 ++++++++++++++++++ 3 files changed, 632 insertions(+) create mode 100644 tensorflow/lite/micro/compression/decode_insert.py create mode 100644 tensorflow/lite/micro/compression/decode_insert_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index ddaf6891fcc..04958e28ff6 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -224,6 +224,36 @@ py_library( ], ) +py_library( + name = "decode_insert", + srcs = ["decode_insert.py"], + deps = [ + ":compressor", + ":model_editor", + "//tensorflow/lite/python:schema_py", + ], +) + +py_test( + name = "decode_insert_test", + size = "small", + srcs = ["decode_insert_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + deps = [ + ":compressor", + ":decode", + ":decode_insert", + ":model_editor", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + requirement("tensorflow"), + ], +) + py_library( name = "decode", srcs = ["decode.py"], diff --git a/tensorflow/lite/micro/compression/decode_insert.py b/tensorflow/lite/micro/compression/decode_insert.py new file mode 100644 index 00000000000..be976795aa2 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_insert.py @@ -0,0 +1,217 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DECODE operator insertion into TFLite model graphs. + +This module inserts DECODE operators into a compressed model. DECODE operators +transform encoded tensors (with their paired ancillary data tensors) into +tensors ready for use by downstream operators. + +The DECODE operator is registered as a custom operator named "TFLM_DECODE". +Each DECODE output requires two inputs: the encoded tensor and the ancillary +data tensor (containing the DCM header and decode-type-specific data). 
+""" + +from collections import defaultdict +from dataclasses import dataclass +from typing import Optional + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + +# Custom operator name for DECODE +DECODE_CUSTOM_OP_NAME = "TFLM_DECODE" + + +@dataclass +class _CompressedTensorInfo: + """Information about a compressed tensor for DECODE insertion.""" + subgraph_idx: int + tensor_idx: int + tensor: model_editor.Tensor + ancillary_data: bytes + consumers: list[model_editor.Operator] + + +def _find_tensor_consumers( + subgraph: model_editor.Subgraph, + tensor: model_editor.Tensor, +) -> list[model_editor.Operator]: + """Find all operators in subgraph that use tensor as an input.""" + consumers = [] + for op in subgraph.operators: + if tensor in op.inputs: + consumers.append(op) + return consumers + + +def _find_earliest_consumer_position( + subgraph: model_editor.Subgraph, + consumers: list[model_editor.Operator], +) -> int: + """Find the position of the earliest consumer in operator list.""" + min_pos = len(subgraph.operators) + for consumer in consumers: + try: + pos = subgraph.operators.index(consumer) + min_pos = min(min_pos, pos) + except ValueError: + pass + return min_pos + + +def _create_ancillary_tensor( + ancillary_data: bytes, + original_tensor: model_editor.Tensor, +) -> model_editor.Tensor: + """Create an ancillary data tensor for a compressed tensor. + + Args: + ancillary_data: The complete ancillary data (DCM + type-specific data). + original_tensor: The original tensor being decoded, for naming. + + Returns: + A new Tensor containing the ancillary data. + """ + name = None + if original_tensor.name: + name = f"{original_tensor.name}_ancillary" + + return model_editor.Tensor( + shape=(len(ancillary_data), ), + dtype=tflite.TensorType.UINT8, + data=ancillary_data, + name=name, + ) + + +def _create_output_tensor( + original_tensor: model_editor.Tensor, ) -> model_editor.Tensor: + """Create the output tensor for a DECODE operator. + + The output tensor has the same shape, dtype, and quantization as the + original tensor would have when decoded. It has no data---the DECODE + operator produces its values at runtime. + + Args: + original_tensor: The original tensor being decoded. + + Returns: + A new Tensor for the DECODE output. + """ + name = None + if original_tensor.name: + name = f"{original_tensor.name}_decoded" + + return model_editor.Tensor( + shape=original_tensor.shape, + dtype=original_tensor.dtype, + quantization=original_tensor.quantization, + name=name, + ) + + +def _rewire_consumers( + consumers: list[model_editor.Operator], + old_tensor: model_editor.Tensor, + new_tensor: model_editor.Tensor, +) -> None: + """Replace old_tensor with new_tensor in all consumer inputs.""" + for consumer in consumers: + consumer.inputs = [ + new_tensor if t is old_tensor else t for t in consumer.inputs + ] + + +def insert_decode_operators( + model: model_editor.Model, + compression_results: dict[tuple[int, int], compressor.CompressionResult], +) -> None: + """Insert DECODE operators for all compressed tensors. + + This function modifies the model in-place, inserting DECODE operators + before any operator that uses a compressed tensor as input. + + For each compressed tensor: + 1. Create an ancillary data tensor containing DCM + type-specific data + 2. 
Create an output tensor with the same shape/dtype as the decoded tensor + 3. Insert a DECODE operator before the first consumer + 4. Rewire all consumers to use the DECODE output instead of the encoded tensor + + Args: + model: The model to modify in-place. + compression_results: Map from (subgraph_idx, tensor_idx) to the + CompressionResult containing ancillary_data. + """ + # Group compressed tensors by subgraph + by_subgraph: dict[int, list[_CompressedTensorInfo]] = defaultdict(list) + + for (sg_idx, tensor_idx), result in compression_results.items(): + subgraph = model.subgraphs[sg_idx] + tensor = subgraph.tensors[tensor_idx] + consumers = _find_tensor_consumers(subgraph, tensor) + + if not consumers: + # Tensor not used as input anywhere---no DECODE needed + continue + + info = _CompressedTensorInfo( + subgraph_idx=sg_idx, + tensor_idx=tensor_idx, + tensor=tensor, + ancillary_data=result.ancillary_data, + consumers=consumers, + ) + by_subgraph[sg_idx].append(info) + + # Process each subgraph + for sg_idx, tensor_infos in by_subgraph.items(): + subgraph = model.subgraphs[sg_idx] + + # Sort by earliest consumer position (process in reverse order to maintain + # valid positions as we insert) + tensor_infos.sort( + key=lambda info: _find_earliest_consumer_position( + subgraph, info.consumers), + reverse=True, + ) + + for info in tensor_infos: + # Create ancillary data tensor + ancillary_tensor = _create_ancillary_tensor( + info.ancillary_data, + info.tensor, + ) + subgraph.tensors.append(ancillary_tensor) + + # Create output tensor + output_tensor = _create_output_tensor(info.tensor) + subgraph.tensors.append(output_tensor) + + # Create DECODE operator + decode_op = model_editor.Operator( + opcode=tflite.BuiltinOperator.CUSTOM, + custom_code=DECODE_CUSTOM_OP_NAME, + inputs=[info.tensor, ancillary_tensor], + outputs=[output_tensor], + ) + + # Find insertion position (before first consumer) + insert_pos = _find_earliest_consumer_position(subgraph, info.consumers) + + # Insert DECODE operator + subgraph.operators.insert(insert_pos, decode_op) + + # Rewire all consumers to use the decoded output + _rewire_consumers(info.consumers, info.tensor, output_tensor) diff --git a/tensorflow/lite/micro/compression/decode_insert_test.py b/tensorflow/lite/micro/compression/decode_insert_test.py new file mode 100644 index 00000000000..b2f3f5e6438 --- /dev/null +++ b/tensorflow/lite/micro/compression/decode_insert_test.py @@ -0,0 +1,385 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Unit tests for DECODE operator insertion.""" + +import numpy as np +import tensorflow as tf + +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +def _build_simple_fc_model(): + """Build a simple model with one FC operator and compressible weights.""" + # yapf: disable + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.array([[1, 2, 1, 2], + [3, 4, 3, 4], + [1, 2, 1, 2], + [3, 4, 3, 4]], dtype=np.int8), + name="weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + # yapf: enable + input_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + return model + + +def _build_shared_weights_model(): + """Build model where one tensor is used by multiple operators.""" + weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="shared_weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + input1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input1", + ) + input2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input2", + ) + output1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output1", + ) + output2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output2", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input1, weights], + outputs=[output1], + ), + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input2, weights], + outputs=[output2], + ), + ], + ) + ]) + return model + + +def _make_dummy_ancillary_data() -> bytes: + """Create dummy ancillary data for testing.""" + dcm = decode.DecodeCommonMetadata( + decode_type=decode.DecodeType.LUT, + user_data=b'\x01\x04\x10' + b'\x00' * 9, # lut_version, bitwidth, stride + ) + value_tables = bytes([1, 2, 3, 4] + [0] * 12) # 16-byte padded table + return dcm.to_bytes() + value_tables + + +class TestDecodeInsertion(tf.test.TestCase): + """Tests for insert_decode_operators function.""" + + def test_insert_single_decode_operator(self): + """DECODE operator inserted before FC that uses compressed weights.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensors[0] + + # Create compression result + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + # Insert DECODE operators + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # Should have 2 operators: DECODE then FC + self.assertEqual(len(sg.operators), 2) + 
self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[0].custom_code, + decode_insert.DECODE_CUSTOM_OP_NAME) + self.assertEqual(sg.operators[1].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + + def test_decode_inputs_structure(self): + """DECODE operator has correct inputs: encoded tensor + ancillary.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensors[0] + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + + # DECODE has 2 inputs + self.assertEqual(len(decode_op.inputs), 2) + # First input is the encoded tensor (original weights) + self.assertIs(decode_op.inputs[0], weights_tensor) + # Second input is ancillary tensor + self.assertEqual(decode_op.inputs[1].dtype, tflite.TensorType.UINT8) + + def test_decode_output_structure(self): + """DECODE operator output has correct shape and dtype.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensors[0] + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + output = decode_op.outputs[0] + + # Output matches original tensor shape and dtype + self.assertEqual(output.shape, weights_tensor.shape) + self.assertEqual(output.dtype, weights_tensor.dtype) + + def test_consumer_rewired_to_decode_output(self): + """FC operator input rewired to use DECODE output.""" + model = _build_simple_fc_model() + weights_tensor = model.subgraphs[0].tensors[0] + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + fc_op = model.subgraphs[0].operators[1] + + # FC's second input (weights) should now be DECODE's output + self.assertIs(fc_op.inputs[1], decode_op.outputs[0]) + # Original weights tensor should NOT be in FC inputs + self.assertNotIn(weights_tensor, fc_op.inputs) + + def test_shared_tensor_single_decode(self): + """Tensor used by multiple ops gets single DECODE, both rewired.""" + model = _build_shared_weights_model() + weights_tensor = model.subgraphs[0].tensors[0] + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + sg = model.subgraphs[0] + + # Should have 3 operators: 1 DECODE + 2 FC + self.assertEqual(len(sg.operators), 3) + self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM) + + decode_op = sg.operators[0] + fc_op1 = sg.operators[1] + fc_op2 = sg.operators[2] + + # Both FCs should use DECODE's output + self.assertIs(fc_op1.inputs[1], decode_op.outputs[0]) + self.assertIs(fc_op2.inputs[1], decode_op.outputs[0]) + + def test_ancillary_tensor_contains_dcm(self): + """Ancillary tensor data contains valid DCM header.""" + model = _build_simple_fc_model() + + ancillary_data = _make_dummy_ancillary_data() + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + 
ancillary_data=ancillary_data, + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + ancillary_tensor = decode_op.inputs[1] + + # Ancillary tensor data should match what we provided + self.assertEqual(bytes(ancillary_tensor.array), ancillary_data) + + # Verify DCM header + dcm_bytes = ancillary_tensor.array[:16] + self.assertEqual(dcm_bytes[0], 0) # decode_type = LUT + self.assertEqual(dcm_bytes[1], 1) # DCM version + + def test_no_consumers_no_decode(self): + """Tensor with no consumers gets no DECODE operator.""" + # Create model where compressed tensor is not used as input + unused_tensor = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="unused", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + input_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output", + ) + other_weights = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.ones((4, 4), dtype=np.int8), + name="other_weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[unused_tensor, other_weights], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, other_weights], + outputs=[output_t], + ) + ], + ) + ]) + + # Compress the unused tensor + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + # Should still have just 1 operator (no DECODE inserted) + self.assertEqual(len(model.subgraphs[0].operators), 1) + + def test_tensor_naming(self): + """Output and ancillary tensors get appropriate names.""" + model = _build_simple_fc_model() + + compression_results = { + (0, 0): + compressor.CompressionResult( + encoded_data=b'\x00\x00', + ancillary_data=_make_dummy_ancillary_data(), + ) + } + + decode_insert.insert_decode_operators(model, compression_results) + + decode_op = model.subgraphs[0].operators[0] + ancillary = decode_op.inputs[1] + output = decode_op.outputs[0] + + self.assertEqual(ancillary.name, "weights_ancillary") + self.assertEqual(output.name, "weights_decoded") + + +class TestHelperFunctions(tf.test.TestCase): + """Tests for internal helper functions.""" + + def test_find_tensor_consumers(self): + """_find_tensor_consumers finds all ops using a tensor.""" + model = _build_shared_weights_model() + sg = model.subgraphs[0] + weights = sg.tensors[0] + + consumers = decode_insert._find_tensor_consumers(sg, weights) + + self.assertEqual(len(consumers), 2) + + def test_find_earliest_consumer_position(self): + """_find_earliest_consumer_position returns minimum position.""" + model = _build_shared_weights_model() + sg = model.subgraphs[0] + + # Both operators are consumers, first is at position 0 + pos = decode_insert._find_earliest_consumer_position(sg, sg.operators) + self.assertEqual(pos, 0) + + # Just the second operator + pos = decode_insert._find_earliest_consumer_position(sg, [sg.operators[1]]) + self.assertEqual(pos, 1) + + +if __name__ == "__main__": + tf.test.main() From a98304fd53265dac7f78b2449d18bef2a65cb57f Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 
2025 22:10:41 -0600 Subject: [PATCH 14/21] refactor(compression): use plugin architecture in compress.py Replace monolithic compression logic with a dispatch table that routes compression requests to plugin modules based on the spec's compression method type. After compressing tensors, insert DECODE operators into the model graph. The old metadata flatbuffer approach is removed in favor of the DECODE operator format. --- tensorflow/lite/micro/compression/BUILD | 13 +- tensorflow/lite/micro/compression/compress.py | 296 ++------ .../lite/micro/compression/compress_test.py | 656 +++++++----------- 3 files changed, 330 insertions(+), 635 deletions(-) diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 04958e28ff6..dca1c68d424 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -114,14 +114,15 @@ py_library( "compress.py", ], deps = [ - ":metadata_py", + ":compressor", + ":decode_insert", + ":huffman", + ":lut", ":model_editor", + ":pruning", ":spec", "//tensorflow/lite/micro/tools:tflite_flatbuffer_align", requirement("absl_py"), - "@flatbuffers//:runtime_py", - requirement("bitarray"), - requirement("numpy"), ], ) @@ -148,11 +149,11 @@ py_test( ], deps = [ ":compress", - ":metadata_py", + ":compressor", + ":decode_insert", ":model_editor", ":spec", "//tensorflow/lite/python:schema_py", - requirement("bitarray"), requirement("numpy"), requirement("tensorflow"), ], diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index 61d3ce8df17..29b3e97087b 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -16,22 +16,21 @@ See USAGE. """ -import bitarray -import bitarray.util -from dataclasses import dataclass, field import os import sys import tempfile -from typing import ByteString, Iterable, Optional +from typing import ByteString, Iterable, Type import absl.app import absl.flags -import flatbuffers -import numpy as np +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import huffman +from tflite_micro.tensorflow.lite.micro.compression import lut from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import pruning from tflite_micro.tensorflow.lite.micro.compression import spec -from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema from tflite_micro.tensorflow.lite.micro.tools import tflite_flatbuffer_align_wrapper USAGE = f"""\ @@ -49,221 +48,48 @@ {spec.EXAMPLE_YAML_SPEC} --- -The only compression method currently implemented is "lut", i.e., -Look-Up-Table. This method requires the tensor in the input model to have a -small number of unique values, fewer than or equal to 2**index_bitwidth. LUT -compression collects these values into a lookup table, and rewrites the tensor -as bitwidth-wide integer indices into that lookup table. Presumably, the input -model has been trained or preprocessed in a way that the tensor values -are binned into a meaningful, limited set. -""" - -# A compressed model augments the usual .tflite flatbuffer with a flatbuffer of -# its own containing compression metadata, stored at the buffer index stored at -# the following key in the .tflite flatbuffer's metadata map. 
-TFLITE_METADATA_KEY = "COMPRESSION_METADATA" - - -class CompressionError(Exception): - """Raised when compression fails for the reason documented in the message.""" - - def __init__(self, message, wrapped_exception=None): - super().__init__(f"{message}: {str(wrapped_exception)}") - self.original_exception = wrapped_exception - - -class _MetadataBuilder: - """Builder for the compression metadata flatbuffer.""" - - def __init__(self): - self._metadata = schema.MetadataT() - self._metadata.subgraphs = [] - - def compile(self) -> bytearray: - """Packs the metadata into a binary array and returns it. - """ - builder = flatbuffers.Builder(1 * 2**10) - root = self._metadata.Pack(builder) - builder.Finish(root) - return builder.Output() - - def subgraph(self, index: int): - """Return subgraph at index, adding subgraphs if necessary. - """ - while len(self._metadata.subgraphs) <= index: - self._add_subgraph() - return self._metadata.subgraphs[index] - - def add_lut_tensor(self, subgraph_id: int): - """Add LUT tensor to the given subgraph and return it. - """ - tensor = schema.LutTensorT() - self.subgraph(subgraph_id).lutTensors.append(tensor) - return tensor - - def _add_subgraph(self): - subgraph = schema.SubgraphT() - subgraph.lutTensors = [] - self._metadata.subgraphs.append(subgraph) - return subgraph - - -@dataclass -class _LutCompressedArray: - compression_axis: Optional[int] = None - lookup_tables: list[np.ndarray] = field(default_factory=list) - indices: np.ndarray = field(default_factory=lambda: np.array([])) - - @property - def index_bitwidth(self) -> int: - """Returns the number of bits required to encode the indices.""" - if self.indices is None: - raise ValueError - - max_index = int(np.max(self.indices)) - return max_index.bit_length() or 1 - - -def _lut_compress_array(tensor: np.ndarray, - axis: Optional[int]) -> _LutCompressedArray: - """Compresses the given tensor using lookup tables. - - Args: - tensor (np.ndarray): The tensor to be compressed. - - axis (Optional[int]): The axis along which to compress the tensor. If an - axis is given, a lookup table is created for each slice along the - axis. If axis is None, a single lookup table is used for the entire - tensor. - - Compressing a tensor with a lookup table per slice along a - particular axis is analogous to quantizing a tensor with different - quantization parameters per slice along a particular axis (dimension). - - Returns: - _LutCompressedArray: An object containing the compressed tensor data, - including the lookup tables and indices. - """ - compressed = _LutCompressedArray() - compressed.compression_axis = axis - - if axis is None: - # Compute unique values and indices for the entire tensor - values, indices = np.unique(tensor, return_inverse=True) - compressed.lookup_tables.append(values) - compressed.indices = indices.reshape(tensor.shape) - else: - # Iterate over slices along the compression axis - slice_indices = [] - for slice in np.moveaxis(tensor, axis, 0): - values, indices = np.unique(slice, return_inverse=True) - compressed.lookup_tables.append(values) - indices = indices.reshape(slice.shape) - slice_indices.append(indices) - - # Reconstruct a tensor of indices from the slices - stacked = np.stack(slice_indices, axis=0) - compressed.indices = np.moveaxis(stacked, 0, axis) +Supported compression methods: - return compressed + lut: Look-Up-Table compression. Requires the tensor to have a small number of + unique values, fewer than or equal to 2**index_bitwidth. 
LUT compression + collects these values into a lookup table, and rewrites the tensor as + bitwidth-wide integer indices into that lookup table. + huffman: Huffman compression using Xtensa-format decode tables. (Not yet + implemented.) -def _check_lut_compression(compression) -> spec.LookUpTableCompression: - if len(compression) != 1: - raise CompressionError("Each tensor must have exactly one compression") - if not isinstance(compression[0], spec.LookUpTableCompression): - raise CompressionError('Only "lut" compression may be specified') + pruning: Pruning (sparsity) compression for sparse tensors. (Not yet + implemented.) - return compression[0] - - -def _identify_compression_axis(tensor: model_editor.Tensor) -> Optional[int]: - """Determines the axis along which to compress. - - The axis along which to compress is inferred from the tensor's quantization - parameters. - - Returns: - The axis along which to compress, or None to indicate one value table for - the entire tensor. - - Raises: - CompressionError: If the axis cannot be determined. - """ - q = tensor.quantization - if q is not None: - # model_editor wraps quantization, access scales/axis from wrapper - scales = q.scales if isinstance(q.scales, list) else [q.scales] - quantization_channels = len(scales) - - if quantization_channels == 1: - # Use one value table for the entire tensor - return None - - if q.axis is not None and q.axis < len(tensor.shape): - if quantization_channels == tensor.shape[q.axis]: - return q.axis - - raise CompressionError( - f"Invalid or no quanitzation parameters from which to " - f"infer the axis along which tensor should be compressed.") - - -def _check_bitwidth(compressed: int, specified: int, spec: spec.Tensor): - """Applies business logic regarding specified bitwidth. - - It is an error if the bitwidth required to compress a tensor exceeds the - specified bitwith, and a warning if the tensor can be compressed in less than - the specified bitwidth. The latter is allowed, and is not an error, to permit - testing with larger bitwidths without re-binning a model. - """ - if compressed > specified: - raise CompressionError( - f"index_bitwidth too small: {compressed} bits needed to " - f"enumerate unique values in tensor specified in {spec}") - elif compressed < specified: - print( - f"warning: index_bitwidth too large: only {compressed} " - f"bits needed to enumerate unique values in tensor specified in {spec}", - file=sys.stderr) - - -def _pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: - """Packs indices into a bytearray using bitwidth-sized fields. - """ - endianness = "big" - bits = bitarray.bitarray(endian=endianness) - for i in indices.ravel(): - bits.extend( - bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) - return bits.tobytes() +Compressed models use DECODE operators to decompress tensors at runtime. +""" +# Plugin dispatch table: maps CompressionMethod subclasses to compressor instances +_COMPRESSORS: dict[Type[spec.CompressionMethod], compressor.Compressor] = { + spec.LookUpTableCompression: lut.LutCompressor(), + spec.HuffmanCompression: huffman.HuffmanCompressor(), + spec.PruningCompression: pruning.PruningCompressor(), +} -def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray: - """Packs the value tables of a LutCompressedArray. - Pack the value tables of a LutCompressedArray into a bytes object in the - format writable to a value_table buffer in the .tflite flatbuffer. The - tables are concatinated. 
- """ - buffer = bytearray() - for t in tables: - padding_needed = table_len - len(t) - padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) - buffer.extend(padded.tobytes()) - - return buffer +def _get_compressor(method: spec.CompressionMethod) -> compressor.Compressor: + """Get the compressor plugin for a given compression method.""" + compressor_instance = _COMPRESSORS.get(type(method)) + if compressor_instance is None: + raise compressor.CompressionError( + f"No compressor registered for {type(method).__name__}") + return compressor_instance def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray: """Applies proper FlatBuffer alignment to a model. - + The Python flatbuffers library doesn't respect `force_align` schema attributes, so we use the C++ wrapper which properly handles alignment requirements. - + Args: model_bytes: The model flatbuffer to align - + Returns: The properly aligned model flatbuffer """ @@ -295,45 +121,47 @@ def _apply_flatbuffer_alignment(model_bytes: bytearray) -> bytearray: def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: """Compresses a model .tflite flatbuffer. + Compresses tensors according to the given specs and inserts DECODE operators + to decompress them at runtime. + Args: model_in: the original, uncompressed .tflite flatbuffer specs: an iterable of compression specs, see module spec.py Returns: - A compressed flatbuffer. + A compressed flatbuffer with DECODE operators inserted. """ model = model_editor.read(model_in) - metadata = _MetadataBuilder() + compression_results: dict[tuple[int, int], compressor.CompressionResult] = {} - for spec in specs: + for tensor_spec in specs: try: - tensor = model.subgraphs[spec.subgraph].tensors[spec.tensor] - lut_compression = _check_lut_compression(spec.compression) - spec_bitwidth = lut_compression.index_bitwidth - axis = _identify_compression_axis(tensor) - compressed = _lut_compress_array(tensor.array, axis) - _check_bitwidth(compressed.index_bitwidth, spec_bitwidth, spec) - - # overwrite tensor data with indices - tensor.buffer.data = _pack_indices(compressed.indices, spec_bitwidth) - - # write value buffer - value_buffer_data = _pack_lookup_tables(compressed.lookup_tables, - 2**spec_bitwidth) - value_buffer = model_editor.Buffer(data=value_buffer_data) - model.buffers.append(value_buffer) # Auto-sets value_buffer.index - - # add compression metadata for tensor - lut_tensor = metadata.add_lut_tensor(subgraph_id=spec.subgraph) - lut_tensor.tensor = spec.tensor - lut_tensor.valueBuffer = value_buffer.index - lut_tensor.indexBitwidth = spec_bitwidth + tensor = model.subgraphs[tensor_spec.subgraph].tensors[ + tensor_spec.tensor] + + # Currently only one compression method per tensor + if len(tensor_spec.compression) != 1: + raise compressor.CompressionError( + "Each tensor must have exactly one compression method") + + method = tensor_spec.compression[0] + plugin = _get_compressor(method) + result = plugin.compress(tensor, method) + + # Replace tensor data with encoded data + tensor.buffer.data = result.encoded_data + + # Store result for DECODE insertion + compression_results[(tensor_spec.subgraph, tensor_spec.tensor)] = result + except compressor.CompressionError: + raise except Exception as e: - raise CompressionError(f"error compressing {spec}") from e + raise compressor.CompressionError( + f"error compressing {tensor_spec}") from e - # add compression metadata to model - model.metadata[TFLITE_METADATA_KEY] = metadata.compile() + # Insert DECODE 
operators into the graph + decode_insert.insert_decode_operators(model, compression_results) # Build the model and apply proper alignment unaligned_model = model.build() diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index 41a40c8ae47..d8a00690a35 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -11,164 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Integration tests for the compression system.""" -import bitarray -import bitarray.util import numpy as np import tensorflow as tf from tflite_micro.tensorflow.lite.micro.compression import compress -from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema +from tflite_micro.tensorflow.lite.micro.compression import compressor +from tflite_micro.tensorflow.lite.micro.compression import decode_insert from tflite_micro.tensorflow.lite.micro.compression import model_editor from tflite_micro.tensorflow.lite.micro.compression import spec from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -class TestPackIndices(tf.test.TestCase): - - def test_basic_case(self): - indices = np.array([1, 2, 3]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0001_0010, 0b0011_0000]) - self.assertEqual(result, expected_bytes) - - def test_single_element(self): - indices = np.array([10]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0000_1010]) - self.assertEqual(result, expected_bytes) - - def test_different_bitwidth(self): - indices = np.array([1, 2, 3]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0000_0001, 0b0000_0010, 0b0000_0011]) - self.assertEqual(result, expected_bytes) - - def test_large_numbers(self): - indices = np.array([255, 128, 64]) - bitwidth = 8 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b1111_1111, 0b1000_0000, 0b0100_0000]) - self.assertEqual(result, expected_bytes) - - def test_multidimensional_array(self): - indices = np.array([[1, 2], [3, 4]]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b0001_0010, 0b0011_0100]) - self.assertEqual(result, expected_bytes) - - def test_zero_bitwidth(self): - indices = np.array([0, 1, 2]) - bitwidth = 0 - with self.assertRaises(ValueError): - compress._pack_indices(indices, bitwidth) - - def test_empty_array(self): - indices = np.array([]) - bitwidth = 4 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = b"" - self.assertEqual(result, expected_bytes) - - def test_bitwidth_1(self): - indices = np.array([1, 0, 1, 1, 0, 1]) - bitwidth = 1 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b101101_00]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_2(self): - indices = np.array([1, 2, 3, 0]) - bitwidth = 2 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b01_10_11_00]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_3(self): - indices = np.array([1, 3, 5, 7]) - bitwidth = 3 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b001_011_10, 0b1_111_0000]) - self.assertEqual(result, expected_bytes) - - def 
test_bitwidth_5(self): - indices = np.array([1, 2, 16, 31]) - bitwidth = 5 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes([0b00001_000, 0b10_10000_1, 0b1111_0000]) - self.assertEqual(result, expected_bytes) - - def test_bitwidth_7(self): - indices = np.array([1, 64, 127, 32]) - bitwidth = 7 - result = compress._pack_indices(indices, bitwidth) - expected_bytes = bytes( - [0b0000001_1, 0b000000_11, 0b11111_010, 0b0000_0000]) - self.assertEqual(result, expected_bytes) - - -class TestPackLookupTables(tf.test.TestCase): - - def test_int16_positive(self): - tables = [np.array([0x1234, 0x5678], dtype=' tuple[int, bitarray.bitarray, np.ndarray]: - """Helper: extracts the compressed tensor parts for a given spec. - - Returns: - bitwidth - indices - values - """ - subgraph_obj = self.compressed.subgraphs[subgraph] - tensor_obj = subgraph_obj.tensors[tensor] - lut_tensors = self.metadata.subgraphs[subgraph_obj.index].lutTensors - lut_tensor = next(t for t in lut_tensors if t.tensor == tensor_obj.index) - bitwidth = lut_tensor.indexBitwidth - - indices = bitarray.bitarray(buffer=tensor_obj.buffer.data, endian="big") - n_indices = np.prod(tensor_obj.shape) - indices = indices[:n_indices * bitwidth] # trim possible padding - - value_buffer = self.compressed.buffers[lut_tensor.valueBuffer] - values = np.frombuffer(value_buffer.data, dtype=tensor_obj.numpy_dtype) - - return bitwidth, indices, values - - def _make_indices(self, s: str) -> bitarray.bitarray: - """Helper: makes indices from "01" strings for use as expected values.""" - return bitarray.bitarray(s, endian="big") - - def test_compressed_uint8(self): - bitwidth, indices, values = self._get_compressed(subgraph=0, tensor=0) - self.assertEqual(bitwidth, 4) - - # yapf: disable - expected_indices = self._make_indices(""" - 0000 0001 0010 0011 - 0100 0101 0110 0111 - 1000 1001 1010 1011 - 1100 1101 1110 1111 - """) - # yapf: enable - self.assertEqual(indices, expected_indices) - - expected_values = np.array(range(16), dtype=" Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 15/21] feat(model_editor): add subgraph inputs and outputs fields The TFLM interpreter requires subgraph inputs/outputs to be set in the flatbuffer to know which tensors are model inputs and outputs. Without these, models built with model_editor cannot be executed. Add inputs and outputs fields to Subgraph dataclass, populate them in _compile_subgraph when building, and read them back in read(). 
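A minimal sketch of the new fields in use, mirroring the tests added below (tensor shapes and names are illustrative):

from tflite_micro.tensorflow.lite.micro.compression import model_editor
from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite

inp = model_editor.Tensor(shape=(1, 4), dtype=tflite.TensorType.INT8,
                          name="input")
out = model_editor.Tensor(shape=(1, 4), dtype=tflite.TensorType.INT8,
                          name="output")
model = model_editor.Model(
    subgraphs=[model_editor.Subgraph(inputs=[inp], outputs=[out])])
flatbuffer = model.build()  # SubGraphT.inputs/outputs now carry the indices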
--- .../lite/micro/compression/model_editor.py | 21 ++++- .../micro/compression/model_editor_test.py | 84 +++++++++++++++++++ 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/tensorflow/lite/micro/compression/model_editor.py b/tensorflow/lite/micro/compression/model_editor.py index 541636b7e42..9cedb25af1d 100644 --- a/tensorflow/lite/micro/compression/model_editor.py +++ b/tensorflow/lite/micro/compression/model_editor.py @@ -226,6 +226,8 @@ class Subgraph: """Declarative subgraph specification with imperative methods.""" tensors: List[Tensor] = field(default_factory=list) operators: List[Operator] = field(default_factory=list) + inputs: List[Tensor] = field(default_factory=list) + outputs: List[Tensor] = field(default_factory=list) name: Optional[str] = None _index: Optional[int] = field(default=None, init=False, repr=False) @@ -362,6 +364,12 @@ def read(buffer: bytes) -> Model: opcode_index=fb_op.opcodeIndex) sg.operators.append(op) + # Read subgraph inputs/outputs + if fb_sg.inputs is not None and len(fb_sg.inputs) > 0: + sg.inputs = [sg.tensors[i] for i in fb_sg.inputs] + if fb_sg.outputs is not None and len(fb_sg.outputs) > 0: + sg.outputs = [sg.tensors[i] for i in fb_sg.outputs] + model.subgraphs.append(sg) # Read metadata @@ -469,9 +477,12 @@ def _compile_subgraph(self, sg: Subgraph) -> tflite.SubGraphT: t._index = i tensor_to_index[id(t)] = i - # Extract inline tensors from operators - for op in sg.operators: - for tensor in op.inputs + op.outputs: + # Extract inline tensors from operators and subgraph inputs/outputs + inline_sources = [op.inputs + op.outputs for op in sg.operators] + inline_sources.append(sg.inputs) + inline_sources.append(sg.outputs) + for source in inline_sources: + for tensor in source: if id(tensor) not in tensor_to_index: tensor._index = len(all_tensors) tensor_to_index[id(tensor)] = tensor._index @@ -487,6 +498,10 @@ def _compile_subgraph(self, sg: Subgraph) -> tflite.SubGraphT: for op in sg.operators: sg_t.operators.append(self._compile_operator(op, tensor_to_index)) + # Set subgraph inputs/outputs + sg_t.inputs = [tensor_to_index[id(t)] for t in sg.inputs] + sg_t.outputs = [tensor_to_index[id(t)] for t in sg.outputs] + return sg_t def _compile_operator(self, op: Operator, diff --git a/tensorflow/lite/micro/compression/model_editor_test.py b/tensorflow/lite/micro/compression/model_editor_test.py index a6c5de56629..704f78cff57 100644 --- a/tensorflow/lite/micro/compression/model_editor_test.py +++ b/tensorflow/lite/micro/compression/model_editor_test.py @@ -731,5 +731,89 @@ def test_modify_metadata(self): self.assertEqual(model3.metadata["new_key"], b"new_value") +class TestSubgraphInputsOutputs(tf.test.TestCase): + """Test subgraph inputs and outputs are set correctly.""" + + def test_subgraph_inputs_outputs_set(self): + """Verify subgraph inputs/outputs are set in the flatbuffer.""" + input_t = Tensor(shape=(1, 4), dtype=tflite.TensorType.INT8, name="input") + output_t = Tensor(shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output") + weights = Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.array([[1, 2, 3, 4]] * 4, dtype=np.int8), + name="weights", + ) + + model = Model(subgraphs=[ + Subgraph( + tensors=[weights], + inputs=[input_t], + outputs=[output_t], + operators=[ + Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + + fb = model.build() + fb_model = tflite.ModelT.InitFromPackedBuf(fb, 0) + fb_sg = fb_model.subgraphs[0] + + # Verify 
inputs/outputs are set (as tensor indices) + self.assertEqual(len(fb_sg.inputs), 1) + self.assertEqual(len(fb_sg.outputs), 1) + + # Verify indices point to correct tensors + input_idx = fb_sg.inputs[0] + output_idx = fb_sg.outputs[0] + self.assertEqual(fb_sg.tensors[input_idx].name, b"input") + self.assertEqual(fb_sg.tensors[output_idx].name, b"output") + + def test_subgraph_inputs_outputs_loopback(self): + """Verify inputs/outputs survive read/build loopback.""" + input_t = Tensor(shape=(1, 4), dtype=tflite.TensorType.INT8, name="input") + output_t = Tensor(shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output") + weights = Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=np.array([[1, 2, 3, 4]] * 4, dtype=np.int8), + name="weights", + ) + + model = Model(subgraphs=[ + Subgraph( + tensors=[weights], + inputs=[input_t], + outputs=[output_t], + operators=[ + Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + + fb = model.build() + loopback = model_editor.read(fb) + sg = loopback.subgraphs[0] + + # Verify high-level inputs/outputs are populated + self.assertEqual(len(sg.inputs), 1) + self.assertEqual(len(sg.outputs), 1) + self.assertEqual(sg.inputs[0].name, "input") + self.assertEqual(sg.outputs[0].name, "output") + + if __name__ == "__main__": tf.test.main() From 01cfab5081c9ef850446879761a258bc7d6541e2 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 24 Nov 2025 22:10:41 -0600 Subject: [PATCH 16/21] test(compression): add integration tests with TFLM interpreter Add tests that compress models with LUT compression, run them through the TFLM Python interpreter, and verify outputs match uncompressed originals. Also verify DECODE operators are inserted and that compressed models are smaller than originals. Tests only run when compression is enabled (--//:with_compression). Placeholder tests for Huffman and Pruning are skipped until implemented. --- tensorflow/lite/micro/compression/BUILD | 26 +++ .../compression_integration_test.py | 220 ++++++++++++++++++ 2 files changed, 246 insertions(+) create mode 100644 tensorflow/lite/micro/compression/compression_integration_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index dca1c68d424..c4f092a7b8b 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -159,6 +159,32 @@ py_test( ], ) +py_test( + name = "compression_integration_test", + size = "small", + srcs = ["compression_integration_test.py"], + tags = [ + "noasan", + "nomsan", + "noubsan", + ], + # Only run when compression IS enabled + target_compatible_with = select({ + "//:with_compression_enabled": [], + "//conditions:default": ["@platforms//:incompatible"], + }), + deps = [ + ":compress_lib", + ":decode_insert", + ":model_editor", + ":spec", + "//python/tflite_micro:runtime", + "//tensorflow/lite/python:schema_py", + requirement("numpy"), + requirement("tensorflow"), + ], +) + py_library( name = "compressor", srcs = ["compressor.py"], diff --git a/tensorflow/lite/micro/compression/compression_integration_test.py b/tensorflow/lite/micro/compression/compression_integration_test.py new file mode 100644 index 00000000000..e8893eb9014 --- /dev/null +++ b/tensorflow/lite/micro/compression/compression_integration_test.py @@ -0,0 +1,220 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Integration tests for compression with TFLM interpreter. + +These tests verify that compressed models produce correct inference results +when run through the TFLM Python interpreter. Tests compress models and +compare outputs against uncompressed originals. + +These tests only run when compression is enabled (--//:with_compression). +""" + +import os +import unittest +import numpy as np +import tensorflow as tf + +from tflite_micro.python.tflite_micro import runtime +from tflite_micro.tensorflow.lite.micro.compression import compress +from tflite_micro.tensorflow.lite.micro.compression import decode_insert +from tflite_micro.tensorflow.lite.micro.compression import model_editor +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite + + +def _build_compressible_model(weight_shape=(4, 4)): + """Build a model with clustered weights for compression testing. + + Args: + weight_shape: Shape of the weight tensor as (rows, cols). + + Returns: + A TFLite flatbuffer (bytes) containing a simple FULLY_CONNECTED model + with weights that have only 4 unique values. + """ + rows, cols = weight_shape + + # Create weights with only 4 unique values (compressible with 2-bit indices) + pattern = np.array([1, 2, 3, 4], dtype=np.int8) + weight_data = np.resize(pattern, (rows, cols)) + + weights = model_editor.Tensor( + shape=weight_shape, + dtype=tflite.TensorType.INT8, + data=weight_data, + name="weights", + quantization=model_editor.Quantization(scales=0.5, zero_points=0), + ) + + input_t = model_editor.Tensor( + shape=(1, cols), + dtype=tflite.TensorType.INT8, + name="input", + ) + output_t = model_editor.Tensor( + shape=(1, rows), + dtype=tflite.TensorType.INT8, + name="output", + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights], + inputs=[input_t], + outputs=[output_t], + operators=[ + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input_t, weights], + outputs=[output_t], + ) + ], + ) + ]) + return model.build() + + +class LutCompressionTest(tf.test.TestCase): + """Integration tests for LUT (lookup table) compression.""" + + def test_lut_compressed_model_matches_uncompressed(self): + """LUT-compressed model produces same outputs as uncompressed.""" + flatbuffer = _build_compressible_model() + + # Create compression spec for weights tensor (index 0 in tensors list) + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + # Compress + compressed_fb = compress.compress(flatbuffer, specs) + + # Run inference on both (convert bytearray to bytes for interpreter) + uncompressed_interp = runtime.Interpreter.from_bytes(bytes(flatbuffer)) + compressed_interp = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Test with multiple random inputs + np.random.seed(42) + for _ in range(10): + test_input = np.random.randint(-128, 127, 
(1, 4), dtype=np.int8) + + uncompressed_interp.set_input(test_input, 0) + uncompressed_interp.invoke() + expected = uncompressed_interp.get_output(0) + + compressed_interp.set_input(test_input, 0) + compressed_interp.invoke() + actual = compressed_interp.get_output(0) + + self.assertAllEqual(expected, actual) + + def test_lut_decode_operators_present(self): + """DECODE operators are inserted for LUT-compressed tensors.""" + flatbuffer = _build_compressible_model() + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + model = model_editor.read(compressed_fb) + sg = model.subgraphs[0] + + # Find DECODE operators + decode_ops = [ + op for op in sg.operators if op.opcode == tflite.BuiltinOperator.CUSTOM + and op.custom_code == decode_insert.DECODE_CUSTOM_OP_NAME + ] + + self.assertGreater(len(decode_ops), 0, + "DECODE operators should be present") + + def test_lut_compressed_model_is_smaller(self): + """LUT-compressed model is smaller than original. + + Uses a large enough weight tensor (64x64 = 4096 bytes) that compression + savings outweigh the overhead from lookup tables and DECODE operators. + With 2-bit indices, 4096 bytes becomes 1024 bytes of indices. + """ + flatbuffer = _build_compressible_model(weight_shape=(64, 64)) + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ) + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + original_size = len(flatbuffer) + compressed_size = len(compressed_fb) + + self.assertLess( + compressed_size, original_size, + f"Compressed model ({compressed_size} bytes) should be smaller than " + f"original ({original_size} bytes)") + + +class HuffmanCompressionTest(tf.test.TestCase): + """Integration tests for Huffman compression.""" + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_compressed_model_matches_uncompressed(self): + """Huffman-compressed model produces same outputs as uncompressed.""" + pass + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_decode_operators_present(self): + """DECODE operators are inserted for Huffman-compressed tensors.""" + pass + + @unittest.skip("Huffman compression not yet implemented") + def test_huffman_compressed_model_is_smaller(self): + """Huffman-compressed model is smaller than original.""" + pass + + +class PruningCompressionTest(tf.test.TestCase): + """Integration tests for pruning compression.""" + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_compressed_model_matches_uncompressed(self): + """Pruning-compressed model produces same outputs as uncompressed.""" + pass + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_decode_operators_present(self): + """DECODE operators are inserted for pruning-compressed tensors.""" + pass + + @unittest.skip("Pruning compression not yet implemented") + def test_pruning_compressed_model_is_smaller(self): + """Pruning-compressed model is smaller than original.""" + pass + + +if __name__ == "__main__": + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0" + tf.test.main() From bee756025f00cca77a81dbffd4861904f95eb34a Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 1 Dec 2025 23:22:02 -0600 Subject: [PATCH 17/21] feat(python): add alt decompression memory parameter to interpreter Add alt_decompression_memory_size 
parameter to the Python interpreter API. When non-zero, allocates
a separate memory region for DECODE operator outputs and calls
SetDecompressionMemory before AllocateTensors.

SetDecompressionMemory stores a pointer to its initializer_list
argument, requiring the list to outlive the interpreter. Per C++
standard, an initializer_list's backing array lifetime is only
extended to match the list's when initialized in a declaration, not
when assigned. This makes the API difficult to use correctly.
---
 python/tflite_micro/_runtime.cc            |  9 +++++----
 python/tflite_micro/interpreter_wrapper.cc | 22 ++++++++++++++++++++--
 python/tflite_micro/interpreter_wrapper.h  | 13 ++++++++++++-
 python/tflite_micro/runtime.py             | 12 ++++++++++++
 4 files changed, 49 insertions(+), 7 deletions(-)

diff --git a/python/tflite_micro/_runtime.cc b/python/tflite_micro/_runtime.cc
index 246545fd016..53825f14f0d 100644
--- a/python/tflite_micro/_runtime.cc
+++ b/python/tflite_micro/_runtime.cc
@@ -33,10 +33,11 @@ PYBIND11_MODULE(_runtime, m) {
       .def(py::init([](const py::bytes& data,
                        const std::vector<std::string>& registerers_by_name,
                        size_t arena_size, int num_resource_variables,
-                       tflite::InterpreterConfig config) {
-        return std::unique_ptr<InterpreterWrapper>(
-            new InterpreterWrapper(data.ptr(), registerers_by_name, arena_size,
-                                   num_resource_variables, config));
+                       tflite::InterpreterConfig config,
+                       size_t alt_decompression_memory_size) {
+        return std::unique_ptr<InterpreterWrapper>(new InterpreterWrapper(
+            data.ptr(), registerers_by_name, arena_size, num_resource_variables,
+            config, alt_decompression_memory_size));
       }))
       .def("PrintAllocations", &InterpreterWrapper::PrintAllocations)
       .def("Invoke", &InterpreterWrapper::Invoke)
diff --git a/python/tflite_micro/interpreter_wrapper.cc b/python/tflite_micro/interpreter_wrapper.cc
index 669589890ad..43c3b3a4365 100644
--- a/python/tflite_micro/interpreter_wrapper.cc
+++ b/python/tflite_micro/interpreter_wrapper.cc
@@ -238,7 +238,18 @@ InterpreterWrapper::~InterpreterWrapper() {
 
 InterpreterWrapper::InterpreterWrapper(
     PyObject* model_data, const std::vector<std::string>& registerers_by_name,
-    size_t arena_size, int num_resource_variables, InterpreterConfig config) {
+    size_t arena_size, int num_resource_variables, InterpreterConfig config,
+    size_t alt_decompression_memory_size)
+    // Members initialized in declaration order. alt_decompression_regions_
+    // MUST be initialized here (not assigned in body) so its backing array
+    // lifetime is extended to match the member's lifetime.
+    : memory_arena_(new uint8_t[arena_size]),
+      alt_decompression_memory_(alt_decompression_memory_size > 0
+                                    ? new uint8_t[alt_decompression_memory_size]
+                                    : nullptr),
+      alt_decompression_region_{alt_decompression_memory_.get(),
+                                alt_decompression_memory_size},
+      alt_decompression_regions_{alt_decompression_region_} {
   interpreter_ = nullptr;
 
   // `model_data` is used as a raw pointer beyond the scope of this
@@ -266,7 +277,6 @@ InterpreterWrapper::InterpreterWrapper(
         "--//:with_compression=true to enable compression support.");
   }
 
-  memory_arena_ = std::unique_ptr<uint8_t[]>(new uint8_t[arena_size]);
   for (const std::string& registerer : registerers_by_name) {
     if (!AddCustomOpRegistererByName(registerer.c_str(),
                                      &python_ops_resolver_)) {
@@ -296,6 +306,14 @@ InterpreterWrapper::InterpreterWrapper(
   interpreter_ = new MicroInterpreter(model, python_ops_resolver_, allocator_,
                                       resource_variables_);
 
+  if (alt_decompression_memory_size > 0) {
+    TfLiteStatus status =
+        interpreter_->SetDecompressionMemory(alt_decompression_regions_);
+    if (status != kTfLiteOk) {
+      ThrowRuntimeError("TFLM failed to set decompression memory");
+    }
+  }
+
   TfLiteStatus status = interpreter_->AllocateTensors();
   if (status != kTfLiteOk) {
     ThrowRuntimeError("TFLM failed to allocate tensors");
diff --git a/python/tflite_micro/interpreter_wrapper.h b/python/tflite_micro/interpreter_wrapper.h
index 9bb31b067fe..aa3da6ace79 100644
--- a/python/tflite_micro/interpreter_wrapper.h
+++ b/python/tflite_micro/interpreter_wrapper.h
@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "python/tflite_micro/python_ops_resolver.h"
 #include "tensorflow/lite/micro/micro_allocator.h"
+#include "tensorflow/lite/micro/micro_context.h"
 #include "tensorflow/lite/micro/micro_interpreter.h"
 #include "tensorflow/lite/micro/recording_micro_allocator.h"
 
@@ -40,7 +41,8 @@ class InterpreterWrapper {
   InterpreterWrapper(
       PyObject* model_data, const std::vector<std::string>& registerers_by_name,
       size_t arena_size, int num_resource_variables,
-      InterpreterConfig config = InterpreterConfig::kAllocationRecording);
+      InterpreterConfig config = InterpreterConfig::kAllocationRecording,
+      size_t alt_decompression_memory_size = 0);
   ~InterpreterWrapper();
 
   void PrintAllocations();
@@ -57,6 +59,15 @@ class InterpreterWrapper {
   tflite::RecordingMicroAllocator* recording_allocator_ = nullptr;
   const PyObject* model_;
   std::unique_ptr<uint8_t[]> memory_arena_;
+  std::unique_ptr<uint8_t[]> alt_decompression_memory_;
+  tflite::MicroContext::AlternateMemoryRegion alt_decompression_region_;
+  // SetDecompressionMemory stores a pointer to its initializer_list argument,
+  // requiring the list to outlive the interpreter. Per C++ standard, an
+  // initializer_list's backing array lifetime is only extended to match the
+  // list's when initialized in a declaration, not when assigned. This makes
+  // the API difficult to use correctly; see the constructor init list.
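+  //
+  // A minimal illustration of the rule, using hypothetical block-scope
+  // variables (not members of this class):
+  //
+  //   std::initializer_list<int> ok = {1, 2, 3};  // backing array lives
+  //                                               // as long as `ok` does
+  //   std::initializer_list<int> bad;
+  //   bad = {1, 2, 3};  // backing array is destroyed at the end of the
+  //                     // full-expression; `bad` now dangles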
+  std::initializer_list<tflite::MicroContext::AlternateMemoryRegion>
+      alt_decompression_regions_;
   tflite::PythonOpsResolver python_ops_resolver_;
   tflite::MicroInterpreter* interpreter_;
 };
diff --git a/python/tflite_micro/runtime.py b/python/tflite_micro/runtime.py
index fbf2f205a50..f3b9029fae1 100644
--- a/python/tflite_micro/runtime.py
+++ b/python/tflite_micro/runtime.py
@@ -71,6 +71,7 @@ def __init__(
       custom_op_registerers,
       arena_size,
       intrepreter_config=InterpreterConfig.kAllocationRecording,
+      alt_decompression_memory_size=0,
   ):
     if model_data is None:
       raise ValueError("Model must not be None")
@@ -94,6 +95,7 @@ def __init__(
         arena_size,
         num_resource_variables,
         _ENUM_TRANSLATOR[intrepreter_config],
+        alt_decompression_memory_size,
     )
 
   @classmethod
@@ -103,6 +105,7 @@ def from_file(
       custom_op_registerers=[],
       arena_size=None,
       intrepreter_config=InterpreterConfig.kAllocationRecording,
+      alt_decompression_memory_size=0,
   ):
     """Instantiates a TFLM interpreter from a model .tflite filepath.
 
@@ -112,6 +115,9 @@ def from_file(
         custom OP registerer
       arena_size: Tensor arena size in bytes. If unused, tensor arena size
         will default to 10 times the model size.
+      alt_decompression_memory_size: Size in bytes of alternate decompression
+        memory. If non-zero, DECODE operators will use this memory instead of
+        the main arena for decompressed tensor outputs.
 
     Returns:
       An Interpreter instance
@@ -127,6 +133,7 @@ def from_file(
         custom_op_registerers,
         arena_size,
         intrepreter_config,
+        alt_decompression_memory_size,
     )
 
   @classmethod
@@ -136,6 +143,7 @@ def from_bytes(
       custom_op_registerers=[],
       arena_size=None,
       intrepreter_config=InterpreterConfig.kAllocationRecording,
+      alt_decompression_memory_size=0,
   ):
     """Instantiates a TFLM interpreter from a model in byte array.
 
@@ -145,6 +153,9 @@ def from_bytes(
         custom OP registerer
       arena_size: Tensor arena size in bytes. If unused, tensor arena size
         will default to 10 times the model size.
+      alt_decompression_memory_size: Size in bytes of alternate decompression
+        memory. If non-zero, DECODE operators will use this memory instead of
+        the main arena for decompressed tensor outputs.
 
     Returns:
       An Interpreter instance
@@ -155,6 +166,7 @@ def from_bytes(
         custom_op_registerers,
         arena_size,
         intrepreter_config,
+        alt_decompression_memory_size,
     )
 
   def print_allocations(self):

From b05656bd4e6d30820ebdacef9b09be10c6ec407f Mon Sep 17 00:00:00 2001
From: Ryan Kuester
Date: Mon, 1 Dec 2025 23:27:03 -0600
Subject: [PATCH 18/21] test(compression): add alt decompression memory
 integration test

Add test for shared compressed tensors with alternate decompression
memory. The test is marked expectedFailure to document the current
mismatch between interpreter and DECODE insertion: the interpreter's
alt decompression memory resets allocations for each DECODE, but the
insertion code shares one DECODE output among all consumers. The
workaround is to insert a separate DECODE before each consumer. The
expectedFailure decorator should be removed once this is implemented.
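
For illustration, with one shared DECODE the test graph executes
roughly as

    DECODE(weights1) -> FC1 -> DECODE(weights2) -> FC2 -> FC3

where FC3 reads a weights1 output that DECODE(weights2) has already
overwritten (tensor and operator names are those used by the new test).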
--- .../compression_integration_test.py | 181 ++++++++++++++++++ 1 file changed, 181 insertions(+) diff --git a/tensorflow/lite/micro/compression/compression_integration_test.py b/tensorflow/lite/micro/compression/compression_integration_test.py index e8893eb9014..1264d37179f 100644 --- a/tensorflow/lite/micro/compression/compression_integration_test.py +++ b/tensorflow/lite/micro/compression/compression_integration_test.py @@ -176,6 +176,187 @@ def test_lut_compressed_model_is_smaller(self): f"original ({original_size} bytes)") +def _build_shared_weights_model(): + """Build a model where one compressed tensor is shared between two operators. + + Model structure: + input1 -> [FC1 with weights1] -> output1 + input2 -> [FC2 with weights2] -> intermediate -> [FC3 with weights1] -> output2 + + weights1 is shared between FC1 and FC3. weights2 is used only by FC2, which + runs between the two consumers of weights1. + """ + # 4 unique values per tensor for 2-bit LUT compression. Small values avoid + # saturation in chained layers. Different row sums produce varied outputs. + weights1_data = np.array([ + [-1, 0, 0, 1], + [-1, 0, 1, 1], + [-1, 1, 1, 1], + [0, 1, 1, 1], + ], + dtype=np.int8) + weights1 = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=weights1_data, + name="weights1", + quantization=model_editor.Quantization(scales=1.0, zero_points=0), + ) + + weights2_data = np.array([ + [1, 1, 1, 1], + [1, 1, 2, 2], + [1, 2, 2, 3], + [2, 2, 3, 3], + ], + dtype=np.int8) + weights2 = model_editor.Tensor( + shape=(4, 4), + dtype=tflite.TensorType.INT8, + data=weights2_data, + name="weights2", + quantization=model_editor.Quantization(scales=1.0, zero_points=0), + ) + + # All tensors need matching quantization for FULLY_CONNECTED + quant = model_editor.Quantization(scales=1.0, zero_points=0) + + input1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input1", + quantization=quant, + ) + input2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="input2", + quantization=quant, + ) + output1 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output1", + quantization=quant, + ) + intermediate = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="intermediate", + quantization=quant, + ) + output2 = model_editor.Tensor( + shape=(1, 4), + dtype=tflite.TensorType.INT8, + name="output2", + quantization=quant, + ) + + model = model_editor.Model(subgraphs=[ + model_editor.Subgraph( + tensors=[weights1, weights2], + inputs=[input1, input2], + outputs=[output1, output2], + operators=[ + # FC1: uses weights1 + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input1, weights1], + outputs=[output1], + ), + # FC2: uses weights2 (runs between FC1 and FC3) + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[input2, weights2], + outputs=[intermediate], + ), + # FC3: uses weights1 (second consumer, after DECODE(weights2)) + model_editor.Operator( + opcode=tflite.BuiltinOperator.FULLY_CONNECTED, + inputs=[intermediate, weights1], + outputs=[output2], + ), + ], + ) + ]) + return model.build() + + +class AltDecompressionMemoryTest(tf.test.TestCase): + """Tests for alternate decompression memory with shared compressed tensors. + + These tests verify correct behavior when compressed tensors are shared + between multiple operators and alternate decompression memory is enabled. 
+ """ + + @unittest.expectedFailure + def test_shared_compressed_tensor_with_alt_memory(self): + """Verify correct results when a shared compressed tensor is used with alt + decompression memory. + + This test uses a graph where a compressed tensor (weights1) is consumed by + two operators (FC1 and FC3), with an intervening DECODE of a different + compressed tensor (weights2) between them. + + The interpreter's alternate decompression memory has a limitation: each + DECODE's Prepare resets the allocation offset to zero. This means all + DECODE outputs are allocated at the same address, so they overwrite each + other. A DECODE output can only be used until the next DECODE runs. + + To work around this limitation, the DECODE insertion code must insert a + separate DECODE immediately before each consumer of a compressed tensor, + rather than sharing one DECODE output among all consumers. + + This test is expected to fail because the current insertion code does not + yet implement this workaround. + """ + flatbuffer = _build_shared_weights_model() + + specs = [ + spec.Tensor( + subgraph=0, + tensor=0, # weights1 + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ), + spec.Tensor( + subgraph=0, + tensor=1, # weights2 + compression=[spec.LookUpTableCompression(index_bitwidth=2)], + ), + ] + + compressed_fb = compress.compress(flatbuffer, specs) + + # Run without alt decompression memory (baseline) + interp_no_alt = runtime.Interpreter.from_bytes(bytes(compressed_fb)) + + # Run with alt decompression memory + interp_with_alt = runtime.Interpreter.from_bytes( + bytes(compressed_fb), + alt_decompression_memory_size=256, + ) + + test_input1 = np.array([[1, 1, 1, 1]], dtype=np.int8) + test_input2 = np.array([[1, 1, 1, 1]], dtype=np.int8) + + interp_no_alt.set_input(test_input1, 0) + interp_no_alt.set_input(test_input2, 1) + interp_no_alt.invoke() + expected1 = interp_no_alt.get_output(0) + expected2 = interp_no_alt.get_output(1) + + interp_with_alt.set_input(test_input1, 0) + interp_with_alt.set_input(test_input2, 1) + interp_with_alt.invoke() + actual1 = interp_with_alt.get_output(0) + actual2 = interp_with_alt.get_output(1) + + self.assertAllEqual(expected1, actual1, + "Output 1 mismatch with alt decompression memory") + self.assertAllEqual(expected2, actual2, + "Output 2 mismatch with alt decompression memory") + + class HuffmanCompressionTest(tf.test.TestCase): """Integration tests for Huffman compression.""" From 58cc12fe1173484e3e5f7681abae3d307a08f98a Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Mon, 1 Dec 2025 23:36:45 -0600 Subject: [PATCH 19/21] fix(compression): insert DECODE per consumer for alt decompression memory Insert a separate DECODE immediately before each consumer of a compressed tensor, rather than sharing one DECODE output among all consumers. The interpreter's alternate decompression memory resets its allocation offset for each DECODE's Prepare, causing all DECODE outputs to be allocated at the same address. If two consumers share one DECODE and another DECODE runs between them, the intervening DECODE overwrites the shared output, corrupting data for the second consumer. Update test expectations to reflect the new DECODE-per-consumer behavior and change the integration test from expected-failure to expected-pass. 
--- .../compression_integration_test.py | 6 +-- .../lite/micro/compression/decode_insert.py | 44 ++++++++++++------- .../micro/compression/decode_insert_test.py | 28 +++++++----- 3 files changed, 46 insertions(+), 32 deletions(-) diff --git a/tensorflow/lite/micro/compression/compression_integration_test.py b/tensorflow/lite/micro/compression/compression_integration_test.py index 1264d37179f..e755ac06ed0 100644 --- a/tensorflow/lite/micro/compression/compression_integration_test.py +++ b/tensorflow/lite/micro/compression/compression_integration_test.py @@ -289,7 +289,6 @@ class AltDecompressionMemoryTest(tf.test.TestCase): between multiple operators and alternate decompression memory is enabled. """ - @unittest.expectedFailure def test_shared_compressed_tensor_with_alt_memory(self): """Verify correct results when a shared compressed tensor is used with alt decompression memory. @@ -303,12 +302,9 @@ def test_shared_compressed_tensor_with_alt_memory(self): DECODE outputs are allocated at the same address, so they overwrite each other. A DECODE output can only be used until the next DECODE runs. - To work around this limitation, the DECODE insertion code must insert a + To work around this limitation, the DECODE insertion code inserts a separate DECODE immediately before each consumer of a compressed tensor, rather than sharing one DECODE output among all consumers. - - This test is expected to fail because the current insertion code does not - yet implement this workaround. """ flatbuffer = _build_shared_weights_model() diff --git a/tensorflow/lite/micro/compression/decode_insert.py b/tensorflow/lite/micro/compression/decode_insert.py index be976795aa2..d471ca63e31 100644 --- a/tensorflow/lite/micro/compression/decode_insert.py +++ b/tensorflow/lite/micro/compression/decode_insert.py @@ -143,11 +143,19 @@ def insert_decode_operators( This function modifies the model in-place, inserting DECODE operators before any operator that uses a compressed tensor as input. - For each compressed tensor: + A separate DECODE is inserted before each consumer, rather than sharing one + DECODE output among all consumers. This is required because the interpreter's + alternate decompression memory resets its allocation offset for each DECODE's + Prepare, causing all DECODE outputs to be allocated at the same address. If + two consumers share one DECODE and another DECODE runs between them, the + intervening DECODE overwrites the shared output, corrupting data for the + second consumer. + + For each consumer of a compressed tensor: 1. Create an ancillary data tensor containing DCM + type-specific data 2. Create an output tensor with the same shape/dtype as the decoded tensor - 3. Insert a DECODE operator before the first consumer - 4. Rewire all consumers to use the DECODE output instead of the encoded tensor + 3. Insert a DECODE operator immediately before the consumer + 4. Rewire the consumer to use the DECODE output Args: model: The model to modify in-place. 
@@ -179,23 +187,27 @@ def insert_decode_operators( for sg_idx, tensor_infos in by_subgraph.items(): subgraph = model.subgraphs[sg_idx] - # Sort by earliest consumer position (process in reverse order to maintain - # valid positions as we insert) - tensor_infos.sort( - key=lambda info: _find_earliest_consumer_position( - subgraph, info.consumers), + # Collect all (consumer, tensor_info) pairs and sort by consumer position + # in reverse order so insertions don't invalidate positions + consumer_pairs = [] + for info in tensor_infos: + for consumer in info.consumers: + consumer_pairs.append((consumer, info)) + + consumer_pairs.sort( + key=lambda pair: subgraph.operators.index(pair[0]), reverse=True, ) - for info in tensor_infos: - # Create ancillary data tensor + for consumer, info in consumer_pairs: + # Create ancillary data tensor (one per DECODE) ancillary_tensor = _create_ancillary_tensor( info.ancillary_data, info.tensor, ) subgraph.tensors.append(ancillary_tensor) - # Create output tensor + # Create output tensor (one per DECODE) output_tensor = _create_output_tensor(info.tensor) subgraph.tensors.append(output_tensor) @@ -207,11 +219,9 @@ def insert_decode_operators( outputs=[output_tensor], ) - # Find insertion position (before first consumer) - insert_pos = _find_earliest_consumer_position(subgraph, info.consumers) - - # Insert DECODE operator + # Insert DECODE immediately before this consumer + insert_pos = subgraph.operators.index(consumer) subgraph.operators.insert(insert_pos, decode_op) - # Rewire all consumers to use the decoded output - _rewire_consumers(info.consumers, info.tensor, output_tensor) + # Rewire only this consumer to use the decoded output + _rewire_consumers([consumer], info.tensor, output_tensor) diff --git a/tensorflow/lite/micro/compression/decode_insert_test.py b/tensorflow/lite/micro/compression/decode_insert_test.py index b2f3f5e6438..d293cf1f5d7 100644 --- a/tensorflow/lite/micro/compression/decode_insert_test.py +++ b/tensorflow/lite/micro/compression/decode_insert_test.py @@ -222,8 +222,8 @@ def test_consumer_rewired_to_decode_output(self): # Original weights tensor should NOT be in FC inputs self.assertNotIn(weights_tensor, fc_op.inputs) - def test_shared_tensor_single_decode(self): - """Tensor used by multiple ops gets single DECODE, both rewired.""" + def test_shared_tensor_decode_per_consumer(self): + """Tensor used by multiple ops gets separate DECODE for each consumer.""" model = _build_shared_weights_model() weights_tensor = model.subgraphs[0].tensors[0] @@ -239,17 +239,25 @@ def test_shared_tensor_single_decode(self): sg = model.subgraphs[0] - # Should have 3 operators: 1 DECODE + 2 FC - self.assertEqual(len(sg.operators), 3) + # Should have 4 operators: 2 DECODEs + 2 FCs (DECODE before each FC) + self.assertEqual(len(sg.operators), 4) self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[1].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) + self.assertEqual(sg.operators[2].opcode, tflite.BuiltinOperator.CUSTOM) + self.assertEqual(sg.operators[3].opcode, + tflite.BuiltinOperator.FULLY_CONNECTED) - decode_op = sg.operators[0] + decode_op1 = sg.operators[0] fc_op1 = sg.operators[1] - fc_op2 = sg.operators[2] - - # Both FCs should use DECODE's output - self.assertIs(fc_op1.inputs[1], decode_op.outputs[0]) - self.assertIs(fc_op2.inputs[1], decode_op.outputs[0]) + decode_op2 = sg.operators[2] + fc_op2 = sg.operators[3] + + # Each FC should use its own DECODE's output + self.assertIs(fc_op1.inputs[1], 
decode_op1.outputs[0]) + self.assertIs(fc_op2.inputs[1], decode_op2.outputs[0]) + # The two DECODEs should have different outputs + self.assertIsNot(decode_op1.outputs[0], decode_op2.outputs[0]) def test_ancillary_tensor_contains_dcm(self): """Ancillary tensor data contains valid DCM header.""" From aed2a7e6e687f8d7f7b34f1e84da3899b60de4d1 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Fri, 5 Dec 2025 15:35:30 -0600 Subject: [PATCH 20/21] test(model_editor): add expected-failure tests for read() edge cases Add tests demonstrating bugs in model_editor.read() when parsing models with None values for tensor shape, operator inputs, or operator outputs. These edge cases can occur in real models from the TFLite converter but cause TypeError crashes in the current implementation. Tests construct models using the low-level TFLite schema to reproduce these conditions. Marked as expectedFailure until the fix is applied. --- .../micro/compression/model_editor_test.py | 175 ++++++++++++++++++ 1 file changed, 175 insertions(+) diff --git a/tensorflow/lite/micro/compression/model_editor_test.py b/tensorflow/lite/micro/compression/model_editor_test.py index 704f78cff57..9e804282a0c 100644 --- a/tensorflow/lite/micro/compression/model_editor_test.py +++ b/tensorflow/lite/micro/compression/model_editor_test.py @@ -14,6 +14,8 @@ """Tests for model_editor module. """ +import unittest + import numpy as np import tensorflow as tf from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite @@ -815,5 +817,178 @@ def test_subgraph_inputs_outputs_loopback(self): self.assertEqual(sg.outputs[0].name, "output") +class TestReadEdgeCases(tf.test.TestCase): + """Test model_editor.read() with edge cases from real-world models. + + These tests construct models using the low-level TFLite schema to create + edge cases that may not be producible via model_editor.build(), but can + appear in models from other sources (e.g., TFLite converter). + """ + + def _build_model_with_schema(self, model_t): + """Build a flatbuffer from a ModelT using the low-level schema.""" + import flatbuffers + builder = flatbuffers.Builder(1024) + builder.Finish(model_t.Pack(builder)) + return bytes(builder.Output()) + + @unittest.expectedFailure + def test_read_scalar_tensor(self): + """Verify read() handles tensors with None shape (scalars). + + Some TFLite models have scalar tensors where shape is None rather than + an empty list. This can occur with constant scalars produced by certain + converters. 
+ """ + # Build a minimal model with a scalar tensor (shape=None) + model_t = tflite.ModelT() + model_t.version = 3 + + # Buffer 0 is always empty, buffer 1 holds scalar data + buf0 = tflite.BufferT() + buf0.data = [] + buf1 = tflite.BufferT() + buf1.data = [42] # Single byte scalar value + + model_t.buffers = [buf0, buf1] + + # Create operator code + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.ADD + model_t.operatorCodes = [opcode] + + # Create subgraph with scalar tensor + sg = tflite.SubGraphT() + + # Tensor with shape=None (scalar) + scalar_tensor = tflite.TensorT() + scalar_tensor.name = b"scalar" + scalar_tensor.type = tflite.TensorType.INT8 + scalar_tensor.buffer = 1 + scalar_tensor.shape = None # This is the edge case + + # Normal tensor for comparison + normal_tensor = tflite.TensorT() + normal_tensor.name = b"normal" + normal_tensor.type = tflite.TensorType.INT8 + normal_tensor.buffer = 0 + normal_tensor.shape = [1, 4] + + sg.tensors = [scalar_tensor, normal_tensor] + sg.inputs = [1] + sg.outputs = [1] + sg.operators = [] + + model_t.subgraphs = [sg] + + # Build and read + fb = self._build_model_with_schema(model_t) + model = model_editor.read(fb) + + # Verify scalar tensor was read with empty shape tuple + self.assertEqual(model.subgraphs[0].tensors[0].shape, ()) + self.assertEqual(model.subgraphs[0].tensors[0].name, "scalar") + + # Verify normal tensor shape is preserved + self.assertEqual(model.subgraphs[0].tensors[1].shape, (1, 4)) + + @unittest.expectedFailure + def test_read_operator_with_empty_inputs(self): + """Verify read() handles operators with None inputs/outputs. + + Some operators (e.g., certain control flow or custom ops) may have + empty input or output lists represented as None in the flatbuffer. + """ + model_t = tflite.ModelT() + model_t.version = 3 + + buf0 = tflite.BufferT() + buf0.data = [] + model_t.buffers = [buf0] + + # Custom op that might have unusual input/output patterns + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.CUSTOM + opcode.customCode = b"NoInputOp" + model_t.operatorCodes = [opcode] + + sg = tflite.SubGraphT() + + # Single output tensor + output_tensor = tflite.TensorT() + output_tensor.name = b"output" + output_tensor.type = tflite.TensorType.INT8 + output_tensor.buffer = 0 + output_tensor.shape = [1] + + sg.tensors = [output_tensor] + sg.inputs = [] + sg.outputs = [0] + + # Operator with None inputs (edge case) + op = tflite.OperatorT() + op.opcodeIndex = 0 + op.inputs = None # This is the edge case + op.outputs = [0] + + sg.operators = [op] + model_t.subgraphs = [sg] + + # Build and read + fb = self._build_model_with_schema(model_t) + model = model_editor.read(fb) + + # Verify operator was read with empty inputs list + self.assertEqual(len(model.subgraphs[0].operators), 1) + self.assertEqual(model.subgraphs[0].operators[0].inputs, []) + self.assertEqual(len(model.subgraphs[0].operators[0].outputs), 1) + + @unittest.expectedFailure + def test_read_operator_with_empty_outputs(self): + """Verify read() handles operators with None outputs. + + Similar to empty inputs, some operators may have None outputs. 
+ """ + model_t = tflite.ModelT() + model_t.version = 3 + + buf0 = tflite.BufferT() + buf0.data = [] + model_t.buffers = [buf0] + + opcode = tflite.OperatorCodeT() + opcode.builtinCode = tflite.BuiltinOperator.CUSTOM + opcode.customCode = b"NoOutputOp" + model_t.operatorCodes = [opcode] + + sg = tflite.SubGraphT() + + input_tensor = tflite.TensorT() + input_tensor.name = b"input" + input_tensor.type = tflite.TensorType.INT8 + input_tensor.buffer = 0 + input_tensor.shape = [1] + + sg.tensors = [input_tensor] + sg.inputs = [0] + sg.outputs = [] + + # Operator with None outputs (edge case) + op = tflite.OperatorT() + op.opcodeIndex = 0 + op.inputs = [0] + op.outputs = None # This is the edge case + + sg.operators = [op] + model_t.subgraphs = [sg] + + fb = self._build_model_with_schema(model_t) + model = model_editor.read(fb) + + self.assertEqual(len(model.subgraphs[0].operators), 1) + self.assertEqual(len(model.subgraphs[0].operators[0].inputs), 1) + self.assertEqual(model.subgraphs[0].operators[0].outputs, []) + + if __name__ == "__main__": tf.test.main() From ec2bb913c83a82c5bdb8c6c075151e20536fecf8 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Fri, 5 Dec 2025 15:40:12 -0600 Subject: [PATCH 21/21] fix(model_editor): handle None shape and inputs/outputs in read() The TFLite flatbuffer schema allows None values for tensor shape (representing scalars) and operator inputs/outputs (for certain ops). Handle these cases in read() to avoid TypeError when iterating. Remove expectedFailure decorators from edge case tests now that the fix is applied. --- tensorflow/lite/micro/compression/model_editor.py | 9 ++++++--- tensorflow/lite/micro/compression/model_editor_test.py | 5 ----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tensorflow/lite/micro/compression/model_editor.py b/tensorflow/lite/micro/compression/model_editor.py index 9cedb25af1d..b60736c980d 100644 --- a/tensorflow/lite/micro/compression/model_editor.py +++ b/tensorflow/lite/micro/compression/model_editor.py @@ -343,7 +343,8 @@ def read(buffer: bytes) -> Model: quant = Quantization(scales=scales, zero_points=zeros, axis=axis) - tensor = Tensor(shape=tuple(fb_tensor.shape), + shape = tuple(fb_tensor.shape) if fb_tensor.shape is not None else () + tensor = Tensor(shape=shape, dtype=fb_tensor.type, buffer=buf, name=name, @@ -357,9 +358,11 @@ def read(buffer: bytes) -> Model: # Get operator code info opcode_obj = model.operator_codes[fb_op.opcodeIndex] + inputs = [sg.tensors[i] for i in fb_op.inputs] if fb_op.inputs is not None else [] + outputs = [sg.tensors[i] for i in fb_op.outputs] if fb_op.outputs is not None else [] op = Operator(opcode=opcode_obj.builtin_code, - inputs=[sg.tensors[i] for i in fb_op.inputs], - outputs=[sg.tensors[i] for i in fb_op.outputs], + inputs=inputs, + outputs=outputs, custom_code=opcode_obj.custom_code, opcode_index=fb_op.opcodeIndex) sg.operators.append(op) diff --git a/tensorflow/lite/micro/compression/model_editor_test.py b/tensorflow/lite/micro/compression/model_editor_test.py index 9e804282a0c..0ce1fc6e371 100644 --- a/tensorflow/lite/micro/compression/model_editor_test.py +++ b/tensorflow/lite/micro/compression/model_editor_test.py @@ -14,8 +14,6 @@ """Tests for model_editor module. 
""" -import unittest - import numpy as np import tensorflow as tf from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite @@ -832,7 +830,6 @@ def _build_model_with_schema(self, model_t): builder.Finish(model_t.Pack(builder)) return bytes(builder.Output()) - @unittest.expectedFailure def test_read_scalar_tensor(self): """Verify read() handles tensors with None shape (scalars). @@ -892,7 +889,6 @@ def test_read_scalar_tensor(self): # Verify normal tensor shape is preserved self.assertEqual(model.subgraphs[0].tensors[1].shape, (1, 4)) - @unittest.expectedFailure def test_read_operator_with_empty_inputs(self): """Verify read() handles operators with None inputs/outputs. @@ -943,7 +939,6 @@ def test_read_operator_with_empty_inputs(self): self.assertEqual(model.subgraphs[0].operators[0].inputs, []) self.assertEqual(len(model.subgraphs[0].operators[0].outputs), 1) - @unittest.expectedFailure def test_read_operator_with_empty_outputs(self): """Verify read() handles operators with None outputs.