From 661ad5ea5be195e0ee236200b173ba641f366158 Mon Sep 17 00:00:00 2001 From: Ryan Kuester Date: Thu, 14 Nov 2024 11:31:23 -0600 Subject: [PATCH] feat(compression): update model compression tools --- tensorflow/lite/micro/compression/BUILD | 41 +- tensorflow/lite/micro/compression/compress.py | 396 ++++++++++-------- .../lite/micro/compression/compress_test.py | 160 ++++++- .../lite/micro/compression/discretize.py | 172 -------- tensorflow/lite/micro/compression/lib.py | 71 ---- tensorflow/lite/micro/compression/lib_test.py | 178 -------- .../lite/micro/compression/model_facade.py | 284 +++++++------ .../micro/compression/model_facade_test.py | 109 ++--- tensorflow/lite/micro/compression/spec.py | 64 +++ .../lite/micro/compression/spec_test.py | 37 ++ .../lite/micro/compression/test_models.py | 33 +- .../micro/compression/test_models_test.py | 13 +- tensorflow/lite/micro/compression/view.py | 1 - .../lite/micro/compression/view_test.py | 10 +- third_party/python_requirements.in | 1 + third_party/python_requirements.txt | 55 +++ 16 files changed, 737 insertions(+), 888 deletions(-) delete mode 100644 tensorflow/lite/micro/compression/discretize.py delete mode 100644 tensorflow/lite/micro/compression/lib.py delete mode 100644 tensorflow/lite/micro/compression/lib_test.py create mode 100644 tensorflow/lite/micro/compression/spec.py create mode 100644 tensorflow/lite/micro/compression/spec_test.py diff --git a/tensorflow/lite/micro/compression/BUILD b/tensorflow/lite/micro/compression/BUILD index 05713f207ab..1a3da4c2024 100644 --- a/tensorflow/lite/micro/compression/BUILD +++ b/tensorflow/lite/micro/compression/BUILD @@ -73,8 +73,9 @@ py_binary( "compress.py", ], deps = [ - ":lib", ":metadata_py", + ":model_facade", + ":spec", "//tensorflow/lite/python:schema_py", "@absl_py//absl:app", "@absl_py//absl/flags", @@ -93,6 +94,7 @@ py_test( ], deps = [ ":compress", + ":model_facade", requirement("tensorflow"), ], ) @@ -150,24 +152,6 @@ sh_test( ], ) -py_library( - name = "lib", - srcs = ["lib.py"], - deps = [ - "model_facade", - ], -) - -py_test( - name = "lib_test", - size = "small", - srcs = ["lib_test.py"], - deps = [ - "lib", - requirement("tensorflow"), - ], -) - py_library( name = "model_facade", srcs = ["model_facade.py"], @@ -183,6 +167,7 @@ py_library( deps = [ "//tensorflow/lite/python:schema_py", requirement("flatbuffers"), + requirement("numpy"), ], ) @@ -206,3 +191,21 @@ py_test( requirement("tensorflow"), ], ) + +py_library( + name = "spec", + srcs = ["spec.py"], + deps = [ + requirement("pyyaml"), + ], +) + +py_test( + name = "spec_test", + size = "small", + srcs = ["spec_test.py"], + deps = [ + ":spec", + requirement("tensorflow"), + ], +) diff --git a/tensorflow/lite/micro/compression/compress.py b/tensorflow/lite/micro/compression/compress.py index 01c3415bef7..1f6132a16fa 100644 --- a/tensorflow/lite/micro/compression/compress.py +++ b/tensorflow/lite/micro/compression/compress.py @@ -11,80 +11,86 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Usage: - bazel run tensorflow/lite/micro/tools:compress -- \\ - $(realpath ) [] - -Transform applicable tensors into compressed, look-up-table tensors. This is -the last stage of model compression. A prior stage must rewrite the elements of -those tensors to a small number of discrete values. This stage reduces such -tensors elements into indices into a value table. - -Identify tensors to compress according to the criteria: - 1. command line argument: --tensors [0:]3,[0:]4 - 2. metadata["COMPRESSION_INSTRUCTIONS"] json - 3. all inputs to operators known to understand compression (default) -""" -from dataclasses import dataclass -from functools import reduce -from typing import Sequence -import math -import os +import bitarray +import bitarray.util +from collections.abc import ByteString +from dataclasses import dataclass, field import sys -import textwrap - -from tflite_micro.tensorflow.lite.micro.compression import ( - lib, - model_facade, - metadata_py_generated as schema, -) +from typing import Iterable import absl.app import absl.flags -import bitarray -import bitarray.util import flatbuffers +import numpy as np + +from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import spec +from tflite_micro.tensorflow.lite.micro.compression import metadata_py_generated as schema + +USAGE = f"""\ +Usage: compress.py --input --spec [--output ] + +Produce a compressed model from the input model by compressing tensors +according to the instructions in the spec file. The spec file lists the tensors +to compress, the compression methods to use on each tensor, and any parameters +for each compression method. + +The spec file is a YAML-format file with a dictionary at the root, containing a +key "tensors" with a list of tensors to compress as its value. E.g.: + +--- +{spec.EXAMPLE_YAML_SPEC} +--- + +The only compression method currently implemented is "lut", i.e., +Look-Up-Table. This method requires the tensor in the input model to have a +small number of unique values, fewer than or equal to 2**index_bitwidth. LUT +compression collects these values into a lookup table, and rewrites the tensor +as bitwidth-wide integer indices into that lookup table. Presumably, the input +model has been trained or preprocessed in a way that the tensor values +are binned into a meaningful, limited set. +""" + +TFLITE_METADATA_KEY = "COMPRESSION_METADATA" + + +class CompressionError(Exception): + """Raised when compression fails for the reason documented in the message.""" + + def __init__(self, message, wrapped_exception=None): + super().__init__(f"{message}: {str(wrapped_exception)}") + self.original_exception = wrapped_exception -class MetadataBuilder: +class _MetadataBuilder: def __init__(self): self._metadata = schema.MetadataT() self._metadata.subgraphs = [] - def pack(self) -> bytearray: + def compile(self) -> bytearray: + """Packs the metadata into a binary array and returns it. + """ builder = flatbuffers.Builder(1 * 2**10) root = self._metadata.Pack(builder) builder.Finish(root) return builder.Output() - def subgraph(self, index): - """Return subgraph at index, adding subgraphs if necessary.""" - try: - subgraph = self._metadata.subgraphs[index] - except IndexError: - need = index + 1 - len(self._metadata.subgraphs) - for _ in range(0, need): - subgraph = self._add_subgraph() - return subgraph + def subgraph(self, index: int): + """Return subgraph at index, adding subgraphs if necessary. + """ + while len(self._metadata.subgraphs) <= index: + self._add_subgraph() + return self._metadata.subgraphs[index] - def add_lut_tensor(self, subgraph_id): - """Add LUT tensor to the given subgraph and return it.""" + def add_lut_tensor(self, subgraph_id: int): + """Add LUT tensor to the given subgraph and return it. + """ tensor = schema.LutTensorT() self.subgraph(subgraph_id).lutTensors.append(tensor) return tensor - def get_lut_by_tensor(self, tensor: model_facade.Tensor): - for subgraph_index in range(len(self._metadata.subgraphs)): - for item in self.subgraph(subgraph_index).lutTensors: - buffer_index = tensor.subgraph.model.subgraphs[subgraph_index].tensors[ - item.tensor].buffer_index - if tensor.buffer_index == buffer_index: - return item - return None - def _add_subgraph(self): subgraph = schema.SubgraphT() subgraph.lutTensors = [] @@ -92,166 +98,190 @@ def _add_subgraph(self): return subgraph -def pack(indices: Sequence[int], bitwidth: int) -> bytes: - """Pack an iterable of indices into a bytearray using bitwidth-sized fields. +@dataclass +class LutCompressedArray: + compression_axis: int = 0 + lookup_tables: list[np.ndarray] = field(default_factory=list) + indices: np.ndarray = field(default_factory=lambda: np.array([])) + + @property + def index_bitwidth(self) -> int: + """Returns the number of bits required to encode the indices.""" + if self.indices is None: + raise ValueError + + max_index = np.max(self.indices) + return int(np.ceil(np.log2(max_index) or 1)) + + +def _lut_compress_array(tensor: np.ndarray, axis: int) -> LutCompressedArray: + """Compresses using a lookup table per subarray along the given axis. + + Compressing a tensor with a lookup table per subarray along a particular axis + is analogous to quantizing a tensor with different quantization parameters + per subarray along a particular axis (dimension). + """ + compressed = LutCompressedArray() + compressed.compression_axis = axis + + # Iterate over subarrays along the compression axis + subarray_indices = [] + for subarray in np.moveaxis(tensor, axis, 0): + values, indices = np.unique(subarray, return_inverse=True) + compressed.lookup_tables.append(values) + indices = indices.reshape(subarray.shape) + subarray_indices.append(indices) + + # Reconstruct a tensor of indices from the subarrays + stacked = np.stack(subarray_indices, axis=0) + compressed.indices = np.moveaxis(stacked, 0, axis) + + return compressed + + +def _assert_lut_only(compression): + if len(compression) != 1: + raise CompressionError("Each tensor must have exactly one compression") + if not isinstance(compression[0], spec.LookUpTableCompression): + raise CompressionError('Only "lut" compression may be specified') + + +def _identify_compression_axis(tensor: model_facade._Tensor) -> int: + """Finds the axis along which to compress. + + Use the quantization axis, else the NWHC channel dimension. If necessary, + an user-specified override could be added to the compression spec schema. + """ + if tensor.quantization is not None: + axis = tensor.quantization.quantizedDimension + else: + axis = tensor.array.ndim - 1 + + return axis + + +def _check_bitwidth(compressed: int, specified: int, spec: spec.Tensor): + """Applies business logic regarding specified bitwidth. + + It is an error if the bitwidth required to compress a tensor exceeds the + specified bitwith, and a warning if the tensor can be compressed in less than + the specified bitwidth. The latter is allowed, and is not an error, to permit + testing with larger bitwidths without re-binning a model. + """ + if compressed > specified: + raise CompressionError( + f"index_bitwidth too small: {compressed} bits needed to " + f"enumerate unique values in tensor specified in {spec}") + elif compressed < specified: + print( + f"warning: index_bitwidth too large: only {compressed} " + f"bits needed to enumerate unique values in tensor specified in {spec}", + file=sys.stderr) + + +def _pack_indices(indices: np.ndarray, bitwidth: int) -> bytes: + """Packs indices into a bytearray using bitwidth-sized fields. """ endianness = "big" bits = bitarray.bitarray(endian=endianness) - for i in indices: - bits.extend(bitarray.util.int2ba(i, length=bitwidth, endian=endianness)) + for i in indices.ravel(): + bits.extend( + bitarray.util.int2ba(int(i), length=bitwidth, endian=endianness)) return bits.tobytes() -def add_lut_tensor(metadata: MetadataBuilder, *, subgraph_index: int, - tensor_index: int, buffer_index: int, bitwidth: int): - lut_tensor = metadata.add_lut_tensor(subgraph_id=subgraph_index) - lut_tensor.tensor = tensor_index - lut_tensor.valueBuffer = buffer_index - lut_tensor.indexBitwidth = bitwidth - +def _pack_lookup_tables(tables: list[np.ndarray], table_len: int) -> bytearray: + """Packs the value tables of a LutCompressedArray. -def lut_compress(tensor: model_facade.Tensor, metadata: MetadataBuilder, *, - alt_axis: bool): - """ Transform the given tensor into a compressed LUT tensor. + Pack the value tables of a LutCompressedArray into a bytes object in the + format writable to a value_table buffer in the .tflite flatbuffer. The + tables, one per subarray, are concatinated. """ - assert len(tensor.values) == reduce(lambda x, y: x * y, tensor.shape) - - # Identify levels per channel - nr_channels = tensor.channel_count - levels = [] - stride = len(tensor.values) // nr_channels - for channel in range(0, nr_channels): - if alt_axis: - channel_values = tensor.values[channel::nr_channels] - else: - start = channel * stride - end = start + stride - channel_values = tensor.values[start:end] - channel_levels = sorted(set(channel_values)) - levels.append(channel_levels) - - nr_levels = max((len(ch) for ch in levels)) - index_bitwidth = math.ceil(math.log2(nr_levels)) if nr_levels > 1 else 1 - - # create and write value buffer with levels - value_buffer = tensor.subgraph.model.add_buffer() - for channel in range(0, nr_channels): - values = levels[channel] - values.extend([0] * (nr_levels - len(values))) - value_buffer.extend_values(values, tensor.type) - - # rewrite original buffer with indices - indices = [] - for i, value in enumerate(tensor.values): - if alt_axis: - channel = i % nr_channels - else: - channel = i // stride - indices.append(levels[channel].index(value)) - tensor.buffer.data = pack(indices, index_bitwidth) - - # add metadata - add_lut_tensor(metadata, - subgraph_index=tensor.subgraph.index, - tensor_index=tensor.index, - buffer_index=value_buffer.index, - bitwidth=index_bitwidth) + buffer = bytearray() + for t in tables: + padding_needed = table_len - len(t) + padded = np.pad(t, (0, padding_needed), mode='constant', constant_values=0) + buffer.extend(padded.tobytes()) + return buffer -@dataclass -class TensorSpec: - subgraph_id: int - tensor_id: int +def compress(model_in: ByteString, specs: Iterable[spec.Tensor]) -> bytearray: + model = model_facade.read(model_in) + metadata = _MetadataBuilder() -def strategy_lut_listed_tensors(tensors: Sequence[TensorSpec], - alt_axis_tensors: Sequence[TensorSpec]): - """Return a strategy that lut-compresses each tensor listed in args. - """ + for spec in specs: + try: + tensor = model.subgraphs[spec.subgraph].tensors[spec.tensor] + _assert_lut_only(spec.compression) + axis = _identify_compression_axis(tensor) + compressed = _lut_compress_array(tensor.array, axis) + spec_bitwidth = spec.compression[0].index_bitwidth + _check_bitwidth(compressed.index_bitwidth, spec_bitwidth, spec) - def _strategy(model: model_facade.Model, metadata: MetadataBuilder): - for spec in tensors: - tensor = model.subgraphs[spec.subgraph_id].tensors[spec.tensor_id] - lut_data = metadata.get_lut_by_tensor(tensor) - if lut_data is not None: - add_lut_tensor(metadata, - subgraph_index=spec.subgraph_id, - tensor_index=spec.tensor_id, - buffer_index=lut_data.valueBuffer, - bitwidth=lut_data.indexBitwidth) - else: - lut_compress(tensor, metadata, alt_axis=False) - for spec in alt_axis_tensors: - tensor = model.subgraphs[spec.subgraph_id].tensors[spec.tensor_id] - lut_compress(tensor, metadata, alt_axis=True) - - return _strategy - - -def compress_model(buffer, strategy): - model = model_facade.read(buffer) - metadata = MetadataBuilder() - strategy(model, metadata) - model.add_metadata(lib.METADATA_KEY, metadata.pack()) - return model.pack() - - -def compress_file(input_path, output_path, strategy): - with open(input_path, "rb") as file: - buffer = bytes(file.read()) - compressed = compress_model(buffer, strategy) - with open(output_path, "wb") as file: - file.write(compressed) + # overwrite tensor data with indices + tensor.buffer.data = _pack_indices(compressed.indices, spec_bitwidth) + # write value buffer + value_buffer = model.add_buffer() + value_buffer.data = _pack_lookup_tables(compressed.lookup_tables, + 2**spec_bitwidth) + # add compression metadata for tensor + lut_tensor = metadata.add_lut_tensor(subgraph_id=tensor.subgraph.index) + lut_tensor.tensor = tensor.index + lut_tensor.valueBuffer = value_buffer.index + lut_tensor.indexBitwidth = spec_bitwidth -FLAGS = absl.flags.FLAGS -absl.flags.DEFINE_string("tensors", None, - "List of [subgraph]tensor,... indices to compress") -absl.flags.DEFINE_string("alt_axis_tensors", None, - "List of [subgraph]tensor,... indices to compress") + except Exception as e: + raise CompressionError(f"error compressing {spec}") from e + # add compression metadata to model + model.add_metadata(TFLITE_METADATA_KEY, metadata.compile()) -def parse_tensors_flag(arg): - if arg is None: - return [] + return model.compile() - specs = [] - for element in arg.split(","): - parts = [int(part) for part in element.split(":")] - if len(parts) == 1: - specs.append(TensorSpec(subgraph_id=0, tensor_id=parts[0])) - elif len(parts) == 2: - specs.append(TensorSpec(subgraph_id=parts[0], tensor_id=parts[1])) - return specs +FLAGS = absl.flags.FLAGS +absl.flags.DEFINE_string("input", None, None) +absl.flags.DEFINE_string("spec", None, None) +absl.flags.DEFINE_string("output", None, None) + + +def _fail_w_usage() -> int: + absl.app.usage() + return 1 def main(argv): - try: - input_path = argv[1] - except IndexError: - absl.app.usage() - return 1 + if len(argv) > 1: + return _fail_w_usage() + + in_path = FLAGS.input + if in_path is None: + return _fail_w_usage() + else: + with open(in_path, "rb") as in_file: + in_model = in_file.read() + + spec_path = FLAGS.spec + if spec_path is None: + return _fail_w_usage() + else: + with open(spec_path, "rb") as spec_file: + specs = spec.parse_yaml(spec_file.read()) - try: - output_path = argv[2] - except IndexError: - output_path = input_path.split(".tflite")[0] + ".compressed.tflite" + out_path = FLAGS.output + if out_path is None: + out_path = in_path.split(".tflite")[0] + ".compressed.tflite" - specs = parse_tensors_flag(FLAGS.tensors) - alt_axis_specs = parse_tensors_flag(FLAGS.alt_axis_tensors) - strategy = strategy_lut_listed_tensors(specs, alt_axis_specs) + compressed = compress(in_model, specs) - print(f"compressing {input_path} to {output_path}") - compress_file(input_path, output_path, strategy) + with open(out_path, "wb") as out_file: + out_file.write(compressed) return 0 if __name__ == "__main__": - name = os.path.basename(sys.argv[0]) - usage = textwrap.dedent(f"""\ - Usage: {name} [--tensors=] [--alt_axis_tensors=] - Compress a .tflite model.""") - sys.modules['__main__'].__doc__ = usage + sys.modules['__main__'].__doc__ = USAGE absl.app.run(main) diff --git a/tensorflow/lite/micro/compression/compress_test.py b/tensorflow/lite/micro/compression/compress_test.py index 228b5eea43b..59865946697 100644 --- a/tensorflow/lite/micro/compression/compress_test.py +++ b/tensorflow/lite/micro/compression/compress_test.py @@ -12,31 +12,155 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np +import tensorflow as tf + from tflite_micro.tensorflow.lite.micro.compression import compress -import tensorflow as tf +class TestPackIndices(tf.test.TestCase): -class TestParseTensorsOption(tf.test.TestCase): + def test_basic_case(self): + indices = np.array([1, 2, 3]) + bitwidth = 4 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b0001_0010, 0b0011_0000]) + self.assertEqual(result, expected_bytes) - def testParseTensorsList(self): - arg = "3,4,0:5,1:6" - specs = compress.parse_tensors_flag(arg) - expect = [ - compress.TensorSpec(subgraph_id=0, tensor_id=3), - compress.TensorSpec(subgraph_id=0, tensor_id=4), - compress.TensorSpec(subgraph_id=0, tensor_id=5), - compress.TensorSpec(subgraph_id=1, tensor_id=6), - ] - self.assertEqual(specs, expect) + def test_single_element(self): + indices = np.array([10]) + bitwidth = 8 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b0000_1010]) + self.assertEqual(result, expected_bytes) + + def test_different_bitwidth(self): + indices = np.array([1, 2, 3]) + bitwidth = 8 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b0000_0001, 0b0000_0010, 0b0000_0011]) + self.assertEqual(result, expected_bytes) + + def test_large_numbers(self): + indices = np.array([255, 128, 64]) + bitwidth = 8 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b1111_1111, 0b1000_0000, 0b0100_0000]) + self.assertEqual(result, expected_bytes) + + def test_multidimensional_array(self): + indices = np.array([[1, 2], [3, 4]]) + bitwidth = 4 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b0001_0010, 0b0011_0100]) + self.assertEqual(result, expected_bytes) + + def test_zero_bitwidth(self): + indices = np.array([0, 1, 2]) + bitwidth = 0 + with self.assertRaises(ValueError): + compress._pack_indices(indices, bitwidth) + + def test_empty_array(self): + indices = np.array([]) + bitwidth = 4 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = b"" + self.assertEqual(result, expected_bytes) + + def test_bitwidth_1(self): + indices = np.array([1, 0, 1, 1, 0, 1]) + bitwidth = 1 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b101101_00]) + self.assertEqual(result, expected_bytes) + + def test_bitwidth_2(self): + indices = np.array([1, 2, 3, 0]) + bitwidth = 2 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b01_10_11_00]) + self.assertEqual(result, expected_bytes) - def testParseTensorsSingleton(self): - arg = "1:3" - specs = compress.parse_tensors_flag(arg) - expect = [ - compress.TensorSpec(subgraph_id=1, tensor_id=3), + def test_bitwidth_3(self): + indices = np.array([1, 3, 5, 7]) + bitwidth = 3 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b001_011_10, 0b1_111_0000]) + self.assertEqual(result, expected_bytes) + + def test_bitwidth_5(self): + indices = np.array([1, 2, 16, 31]) + bitwidth = 5 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes([0b00001_000, 0b10_10000_1, 0b1111_0000]) + self.assertEqual(result, expected_bytes) + + def test_bitwidth_7(self): + indices = np.array([1, 64, 127, 32]) + bitwidth = 7 + result = compress._pack_indices(indices, bitwidth) + expected_bytes = bytes( + [0b0000001_1, 0b000000_11, 0b11111_010, 0b0000_0000]) + self.assertEqual(result, expected_bytes) + + +class TestPackLookupTables(tf.test.TestCase): + + def test_int16_positive(self): + tables = [np.array([0x1234, 0x5678], dtype=') [] - -Given a model, rewrite the elements of certain tensors to a small number of -discrete values. This is the first stage of model compression using -look-up-table tensors. A future stage, in a different program, transforms the -rewritten tensors into look-up-table tensors, wherein the elements are reduced -to indices into a value table containing the discrete values. - -This program is meant as a test and reference implementation for other -first-stage programs which similarly rewrite elements, using more sophesticated -methods for determining the discrete values. -""" - -from tensorflow.lite.python import schema_py_generated as tflite_schema - -import absl.app -import flatbuffers -import numpy as np -import sklearn.cluster -import struct -import sys - -TENSOR_TYPE_TO_STRUCT_FORMAT = { - tflite_schema.TensorType.INT8: "b", - tflite_schema.TensorType.INT16: "h", - tflite_schema.TensorType.INT32: "i", - tflite_schema.TensorType.FLOAT32: "f", -} - - -def unpack_buffer_data(data, struct_format): - little_endian = "<" - unpacker = struct.Struct(little_endian + struct_format) - values = [v[0] for v in unpacker.iter_unpack(bytes(data))] - return values - - -def bin_and_quant(sequence, num_values): - """Quantize a sequence of integers, minimizing the total error using k-means - clustering. - - Parameters: - sequence :list - a sequence of integers to be quanized - num_values :int - the number of quantization levels - - Returns: - The input sequence, with all values quantized to one of the discovered - quantization levels. - """ - sequence = np.array(sequence).reshape(-1, 1) - kmeans = sklearn.cluster.KMeans(n_clusters=num_values, - random_state=0).fit(sequence) - indices = kmeans.predict(sequence).tolist() - values = kmeans.cluster_centers_.flatten() - values = np.round(values).astype(int).tolist() - quantized = [values[i] for i in indices] - return quantized - - -def replace_buffer_data(buffer, values, format): - new = bytearray() - little_endian = "<" - packer = struct.Struct(little_endian + format) - for v in values: - new.extend(packer.pack(v)) - - assert (len(buffer.data) == len(new)) - buffer.data = new - - -def discretize_tensor(tensor, buffer): - format = TENSOR_TYPE_TO_STRUCT_FORMAT[tensor.type] - values = unpack_buffer_data(buffer.data, format) - levels = 4 - if len(values) > levels: - discretized = bin_and_quant(values, 4) - replace_buffer_data(buffer, discretized, format) - - -def map_actionable_opcodes(model): - """Sparsely map operator code indices to indices of input tensors to - discretize.""" - - actionable_operators = { - tflite_schema.BuiltinOperator.FULLY_CONNECTED: (1, 2) - } - - opcodes = {} - for index, operator_code in enumerate(model.operatorCodes): - inputs = actionable_operators.get(operator_code.builtinCode, None) - if inputs is not None: - opcodes[index] = inputs - - return opcodes - - -def discretize(model): - # Discretize the input tensors of which operator_codes? - actionable_opcodes = map_actionable_opcodes(model) - - # Walk graph nodes (operators) and build list of tensors to discretize - tensors = set() - for subgraph_id, subgraph in enumerate(model.subgraphs): - for operator_id, operator in enumerate(subgraph.operators): - inputs = actionable_opcodes.get(operator.opcodeIndex, None) - if inputs is not None: - for input in (operator.inputs[i] for i in inputs): - tensors.add(subgraph.tensors[input]) - - # Discretize tensors - for t in tensors: - discretize_tensor(t, model.buffers[t.buffer]) - - return model - - -def read_model(path): - with open(path, 'rb') as file: - buffer = bytearray(file.read()) - return tflite_schema.ModelT.InitFromPackedBuf(buffer, 0) - - -def write_model(model, path): - builder = flatbuffers.Builder(32) - root = model.Pack(builder) - builder.Finish(root) - buffer: bytearray = builder.Output() - - with open(path, 'wb') as file: - file.write(buffer) - - -def main(argv) -> None: - try: - input_path = argv[1] - except IndexError: - absl.app.usage() - return 1 - - try: - output_path = argv[2] - except IndexError: - output_path = input_path.split(".tflite")[0] + ".discretized.tflite" - - print(f"discretizing {input_path} to {output_path}") - - model = read_model(input_path) - model = discretize(model) - write_model(model, output_path) - - return 0 - - -if __name__ == "__main__": - rc = absl.app.run(main) - sys.exit(rc) diff --git a/tensorflow/lite/micro/compression/lib.py b/tensorflow/lite/micro/compression/lib.py deleted file mode 100644 index 04ae3d90682..00000000000 --- a/tensorflow/lite/micro/compression/lib.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import struct -import typing - -from tensorflow.lite.python import schema_py_generated as tflite - -import model_facade - -METADATA_KEY = "COMPRESSION_METADATA" - -# Operator input indices for which LUT-compression is implemented, by opcode -LUT_COMPRESSABLE_INPUTS = { - tflite.BuiltinOperator.FULLY_CONNECTED: (1, 2), -} - - -def lut_compressable_inputs( - model: model_facade.Model, ) -> typing.Sequence[model_facade.Tensor]: - """LUT-compressable input tensors in the model.""" - - tensors = set() - for subgraph in model.subgraphs: - for op in subgraph.operators: - indices = LUT_COMPRESSABLE_INPUTS.get(op.opcode.builtinCode, ()) - for i in indices: - tensors.add(op.inputs[i]) - return tensors - - -_struct_formats = { - tflite.TensorType.FLOAT32: "f", - tflite.TensorType.FLOAT16: "e", - tflite.TensorType.FLOAT64: "d", - tflite.TensorType.INT8: "b", - tflite.TensorType.INT16: "h", - tflite.TensorType.INT32: "i", - tflite.TensorType.INT64: "q", - tflite.TensorType.UINT8: "B", - tflite.TensorType.UINT16: "H", - tflite.TensorType.UINT32: "I", - tflite.TensorType.UINT64: "Q", - tflite.TensorType.BOOL: "?", -} - - -def buffer_values(buffer: model_facade.Buffer, type_: tflite.TensorType): - """Return properly-typed values unpacked from the given buffer. - """ - little_endian = "<" # always, per tflite schema - format = little_endian + _struct_formats[type_] - # iter_unpack yields tuples of length 1, unpack the tuples - return [t[0] for t in struct.iter_unpack(format, buffer.data)] - - -def tensor_values(tensor: model_facade.Tensor): - """Return properly-typed values for the given tensor. - """ - return buffer_values(tensor.data, tensor.type) diff --git a/tensorflow/lite/micro/compression/lib_test.py b/tensorflow/lite/micro/compression/lib_test.py deleted file mode 100644 index ed01f1e4191..00000000000 --- a/tensorflow/lite/micro/compression/lib_test.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2024 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import tensorflow as tf - -from tensorflow.lite.python import schema_py_generated as tflite - -import lib -import model_facade -import test_models - -TEST_MODEL = { - "operator_codes": { - 0: { - "builtin_code": tflite.BuiltinOperator.FULLY_CONNECTED, - }, - 1: { - "builtin_code": tflite.BuiltinOperator.ADD, - }, - }, - "subgraphs": { - 0: { - "operators": { - 0: { - "opcode_index": 1, # ADD - "inputs": ( - 1, - 2, - ), - "outputs": (3, ), - }, - 1: { - "opcode_index": 0, # FULLY_CONNECTED - "inputs": ( - 3, - 4, - 5, - ), - "outputs": (6, ), - }, - }, - "tensors": { - 0: { - "name": "tensor0", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 1, - }, - 1: { - "name": "tensor1", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 2, - }, - 2: { - "name": "tensor2", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 3, - }, - 3: { - "name": "tensor3", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 4, - }, - 4: { - "name": "tensor4", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 5, - }, - 5: { - "name": "tensor5", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 6, - }, - 6: { - "name": "tensor6", - "shape": (16, 1), - "type": tflite.TensorType.INT8, - "buffer": 6, - }, - }, - }, - }, - "buffers": { - 0: - bytes(), - 1: - bytes((206, 185, 109, 109, 212, 205, 25, 47, 42, 209, 94, 138, 182, 3, - 76, 2)), - 2: - bytes((148, 182, 190, 244, 159, 22, 165, 201, 178, 97, 85, 161, 126, - 39, 36, 107)), - 3: - bytes((67, 84, 53, 155, 137, 191, 63, 251, 102, 53, 123, 189, 34, 212, - 164, 199)), - 4: - bytes((243, 242, 195, 117, 196, 198, 158, 26, 76, 47, 246, 162, 222, - 94, 6, 255)), - 5: - bytes((137, 54, 208, 227, 58, 118, 231, 43, 81, 217, 169, 205, 202, - 138, 4, 145)), - 6: - bytes((234, 181, 174, 210, 0, 49, 101, 145, 0, 13, 167, 230, 86, 78, - 87, 106)), - } -} - - -class TestLUT(tf.test.TestCase): - - def setUp(self): - flatbuffer = test_models.build(TEST_MODEL) - self.facade = model_facade.read(flatbuffer) - - def testCompressableInputs(self): - tensors = lib.lut_compressable_inputs(self.facade) - self.assertEqual(len(tensors), 2) - names = set(t.name for t in tensors) - self.assertEqual(names, set(("tensor4", "tensor5"))) - - -class TestBufferValues(tf.test.TestCase): - - def setUp(self): - flatbuffer = test_models.build(TEST_MODEL) - self.facade = model_facade.read(flatbuffer) - - def testInt8(self): - get = lib.buffer_values(self.facade.buffers[1], tflite.TensorType.INT8) - expect = [] - data = TEST_MODEL["buffers"][1] - for value in data: - expect.append((value - 0x100) if (value & 0x80) else value) - self.assertAllEqual(get, expect) - - def testInt16(self): - get = lib.buffer_values(self.facade.buffers[1], tflite.TensorType.INT16) - expect = [] - data = TEST_MODEL["buffers"][1] - for i in range(0, len(data), 2): - value = (data[i + 1] << 8) + data[i] - expect.append((value - 0x1_0000) if (value & 0x8000) else value) - self.assertAllEqual(get, expect) - - -class TestTensorValues(tf.test.TestCase): - - def setUp(self): - flatbuffer = test_models.build(TEST_MODEL) - self.facade = model_facade.read(flatbuffer) - - def testInt8(self): - get = lib.tensor_values(self.facade.subgraphs[0].tensors[1]) - expect = [] - data = TEST_MODEL["buffers"][2] - for value in data: - expect.append((value - 0x100) if (value & 0x80) else value) - self.assertAllEqual(get, expect) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/lite/micro/compression/model_facade.py b/tensorflow/lite/micro/compression/model_facade.py index c934ec18ca9..61276a8c968 100644 --- a/tensorflow/lite/micro/compression/model_facade.py +++ b/tensorflow/lite/micro/compression/model_facade.py @@ -12,121 +12,82 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""A facade for manipulating tflite.Model. +"""A facade for working with tflite.Model. -Usage: - model = model_facade.read_file(path) - # manipulate model - model.write_file(path) +Provide convenient navigation and data types for working with tflite.Model, +which can be tedious and verbose to working with directly. -A tflite.Model can be tedious and verbose to navigate. +Usage: + model = model_facade.read(flatbuffer) + # manipulate + new_flatbuffer = model.compile() """ -# TODO: make a better distinction between object representation objects -# and facade objects. - -from typing import Sequence - -from tensorflow.lite.python import schema_py_generated as tflite - import flatbuffers -import struct - - -def read(buffer: bytes): - """Read a tflite.Model from a buffer and return a model facade.""" - schema_model = tflite.ModelT.InitFromPackedBuf(buffer, 0) - return Model(schema_model) +import numpy as np +from numpy.typing import NDArray +from typing import ByteString, Generic, TypeVar +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite -class Model: - """A facade for manipulating tflite.Model.""" +_IteratorTo = TypeVar("_IteratorTo") - def __init__(self, representation: tflite.ModelT): - self.root = representation - def pack(self) -> bytearray: - """Pack and return the tflite.Model as a flatbuffer.""" - size_hint = 4 * 2**10 - builder = flatbuffers.Builder(size_hint) - builder.Finish(self.root.Pack(builder)) - return builder.Output() - - def add_buffer(self): - """Add a buffer to the model and return a Buffer facade.""" - buffer = tflite.BufferT() - buffer.data = [] - self.root.buffers.append(buffer) - index = len(self.root.buffers) - 1 - return Buffer(buffer, index, self.root) - - def add_metadata(self, key, value): - """Add a key-value pair, writing value to newly created tflite.Buffer.""" - metadata = tflite.MetadataT() - metadata.name = key - buffer = self.add_buffer() - buffer.data = value - metadata.buffer = buffer.index - self.root.metadata.append(metadata) - - @property - def operatorCodes(self): - return self.root.operatorCodes - - @property - def subgraphs(self): - return Iterator(self.root.subgraphs, Subgraph, parent=self) - - @property - def buffers(self): - return Iterator(self.root.buffers, Buffer, parent=self) - - -class Iterator: +class _Iterator(Generic[_IteratorTo]): def __init__(self, sequence, cls, parent): self._sequence = sequence self._cls = cls + self._index = 0 self._parent = parent - def __getitem__(self, key): + def __getitem__(self, key) -> _IteratorTo: return self._cls(self._sequence[key], key, self._parent) def __len__(self): return len(self._sequence) + def __iter__(self): + self._index = 0 + return self -class IndirectIterator: + def __next__(self): + try: + result = self[self._index] + self._index += 1 + return result + except IndexError: + raise StopIteration + + +class _IndirectIterator(Generic[_IteratorTo]): def __init__(self, indices, sequence): self._indices = indices + self._index = 0 self._sequence = sequence - def __getitem__(self, key): + def __getitem__(self, key) -> _IteratorTo: index = self._indices[key] return self._sequence[index] def __len__(self): return len(self._indices) + def __iter__(self): + self._index = 0 + return self -class Subgraph: - - def __init__(self, subgraph, index, model): - self.subgraph = subgraph - self.index = index - self.model = model + def __next__(self): + try: + result = self[self._index] + self._index += 1 + return result + except IndexError: + raise StopIteration - @property - def operators(self): - return Iterator(self.subgraph.operators, Operator, parent=self) - - @property - def tensors(self): - return Iterator(self.subgraph.tensors, Tensor, parent=self) - -class Operator: +class _Operator: def __init__(self, operator, index, subgraph): self.operator = operator @@ -139,82 +100,79 @@ def opcode(self) -> tflite.OperatorCodeT: @property def inputs(self): - return IndirectIterator(self.operator.inputs, self.subgraph.tensors) + return _IndirectIterator(self.operator.inputs, self.subgraph.tensors) + + +_NP_DTYPES = { + tflite.TensorType.FLOAT16: np.dtype(" np.dtype: + return _NP_DTYPES[self._tensor.type] @property - def values(self): - reader = struct.iter_unpack(_struct_formats[self.type], self.data) - # iter_unpack yields tuples of length 1, unpack the tuples - return [value[0] for value in reader] - - @property - def channel_count(self): - if (self.tensor.quantization is None - or self.tensor.quantization.scale is None): - return 1 - return len(self.tensor.quantization.scale) + def array(self) -> np.ndarray: + """Returns an array created from the Tensor's data, type, and shape. + Note the bytes in the data buffer and the Tensor's type and shape may be + inconsistent, and thus the returned array invalid, if the data buffer has + been altered according to the compression schema, in which the data buffer + is an array of fixed-width, integer fields. + """ + return np.frombuffer(self.data, + dtype=self.dtype).reshape(self._tensor.shape) -_struct_formats = { - tflite.TensorType.FLOAT32: " _Iterator[_Operator]: + return _Iterator(self.subgraph.operators, _Operator, parent=self) + + @property + def tensors(self) -> _Iterator[_Tensor]: + return _Iterator(self.subgraph.tensors, _Tensor, parent=self) + + +class _Model: + """A facade for manipulating tflite.Model. + """ + + def __init__(self, representation: tflite.ModelT): + self.root = representation + + def compile(self) -> bytearray: + """Returns a tflite.Model flatbuffer. + """ + size_hint = 4 * 2**10 + builder = flatbuffers.Builder(size_hint) + builder.Finish(self.root.Pack(builder)) + return builder.Output() + + def add_buffer(self) -> _Buffer: + """Adds a buffer to the model. + """ + buffer = tflite.BufferT() + buffer.data = [] + self.root.buffers.append(buffer) + index = len(self.root.buffers) - 1 + return _Buffer(buffer, index, self.root) + + def add_metadata(self, key, value): + """Adds a key-value pair, writing value to a newly created buffer. + """ + metadata = tflite.MetadataT() + metadata.name = key + buffer = self.add_buffer() + buffer.data = value + metadata.buffer = buffer.index + self.root.metadata.append(metadata) + + @property + def operatorCodes(self): + return self.root.operatorCodes + + @property + def subgraphs(self) -> _Iterator[_Subgraph]: + return _Iterator(self.root.subgraphs, _Subgraph, parent=self) + + @property + def buffers(self) -> _Iterator[_Buffer]: + return _Iterator(self.root.buffers, _Buffer, parent=self) + + +def read(buffer: ByteString): + """Reads a tflite.Model and returns a model facade. + """ + schema_model = tflite.ModelT.InitFromPackedBuf(buffer, 0) + return _Model(schema_model) diff --git a/tensorflow/lite/micro/compression/model_facade_test.py b/tensorflow/lite/micro/compression/model_facade_test.py index 739ae0f3cc0..0a75aa9f89d 100644 --- a/tensorflow/lite/micro/compression/model_facade_test.py +++ b/tensorflow/lite/micro/compression/model_facade_test.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numpy as np import tensorflow as tf -from tensorflow.lite.python import schema_py_generated as tflite - -import model_facade -import test_models +from tflite_micro.tensorflow.lite.python import schema_py_generated as tflite +from tflite_micro.tensorflow.lite.micro.compression import model_facade +from tflite_micro.tensorflow.lite.micro.compression import test_models TEST_MODEL = { "operator_codes": { @@ -60,29 +60,29 @@ "name": "tensor1", "shape": (8, 1), "type": tflite.TensorType.INT16, - "buffer": 1, + "buffer": 2, }, 2: { "name": "tensor2", "shape": (4, 1), "type": tflite.TensorType.INT32, - "buffer": 1, + "buffer": 3, }, 3: { "name": "tensor3", "shape": (2, 1), "type": tflite.TensorType.INT64, - "buffer": 1, + "buffer": 4, }, }, }, }, "buffers": { - 0: - bytes(), - 1: - bytes((206, 185, 109, 109, 212, 205, 25, 47, 42, 209, 94, 138, 182, 3, - 76, 2)), + 0: None, + 1: np.array(range(16), dtype=np.dtype(" bytearray: - """Build a tflite flatbuffer from a model spec. + """Builds a tflite flatbuffer from a model spec. Args: spec: A dictionary representation of the model, a prototype of which @@ -70,7 +70,15 @@ def build(spec: dict) -> bytearray: for id, data in spec["buffers"].items(): assert id == len(root.buffers) buffer_t = tflite.BufferT() - buffer_t.data = data + + if data is None: + buffer_t.data = [] + elif isinstance(data, np.ndarray): + array = data.astype(data.dtype.newbyteorder("<")) # ensure little-endian + buffer_t.data = list(array.tobytes()) + else: + raise TypeError(f"buffer_id {id} has invalid data {data}") + root.buffers.append(buffer_t) size_hint = 1 * 2**20 @@ -80,15 +88,6 @@ def build(spec: dict) -> bytearray: return flatbuffer -def get_buffer(spec: dict, subgraph: int, tensor: int) -> bytearray: - """Return the buffer for a given tensor in a model spec. - """ - tensor_spec = spec["subgraphs"][subgraph]["tensors"][tensor] - buffer_id = tensor_spec["buffer"] - buffer = spec["buffers"][buffer_id] - return buffer - - EXAMPLE_MODEL = { "operator_codes": { 0: { @@ -143,10 +142,10 @@ def get_buffer(spec: dict, subgraph: int, tensor: int) -> bytearray: }, }, "buffers": { - 0: bytes(), - 1: bytes((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), - 2: bytes((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), - 3: bytes((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), - 4: bytes((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)), + 0: None, + 1: np.array(range(16), dtype=np.dtype(" 50) -class TestSpecManipulation(tf.test.TestCase): - - def testGetBuffer(self): - buffer = test_models.get_buffer(test_models.EXAMPLE_MODEL, - subgraph=0, - tensor=0) - expect = test_models.EXAMPLE_MODEL["buffers"][1] - self.assertTrue(buffer is expect) - - if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/lite/micro/compression/view.py b/tensorflow/lite/micro/compression/view.py index dd7d520cdd8..1a72baf7e94 100644 --- a/tensorflow/lite/micro/compression/view.py +++ b/tensorflow/lite/micro/compression/view.py @@ -19,7 +19,6 @@ import os import sys -import lib from tensorflow.lite.micro.compression import metadata_py_generated as compression_schema from tensorflow.lite.python import schema_py_generated as tflite_schema diff --git a/tensorflow/lite/micro/compression/view_test.py b/tensorflow/lite/micro/compression/view_test.py index 47c02cfe5ea..c8882e8f3c0 100644 --- a/tensorflow/lite/micro/compression/view_test.py +++ b/tensorflow/lite/micro/compression/view_test.py @@ -13,9 +13,9 @@ # limitations under the License. from absl.testing import absltest - -import test_models -import view +import numpy as np +from tflite_micro.tensorflow.lite.micro.compression import test_models +from tflite_micro.tensorflow.lite.micro.compression import view _MODEL = { "description": "Test model", @@ -62,8 +62,8 @@ }, }, "buffers": { - 0: bytes(), - 1: bytes(i for i in range(1, 16)), + 0: None, + 1: np.array(range(16), dtype="