Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
5eccce5
feat(compression): implement model_editor for TFLite model manipulation
rkuester Nov 25, 2025
953b734
refactor(compression): migrate compress.py from model_facade to model…
rkuester Nov 25, 2025
fdf9c60
chore(compression): remove model_facade.py and model_facade_test.py
rkuester Nov 25, 2025
dcb9577
refactor(compression): replace test_models with model_editor in compr…
rkuester Nov 25, 2025
c252adf
chore(compression): remove test_models.py and test_models_test.py
rkuester Nov 25, 2025
f55c780
feat(compression): add DecodeType class for type-safe decode operations
rkuester Nov 25, 2025
62e2805
docs(compression): add docstring to LookUpTableCompression
rkuester Nov 25, 2025
e80725d
feat(compression): add HuffmanCompression and PruningCompression types
rkuester Nov 25, 2025
5b2072f
refactor(compression): extract _parse_compression_method helper
rkuester Nov 25, 2025
bda4f77
feat(compression): add Compressor protocol and CompressionResult
rkuester Nov 25, 2025
69e8e50
feat(compression): add LUT compression plugin
rkuester Nov 25, 2025
86136ed
feat(compression): add Huffman and Pruning plugin stubs
rkuester Nov 25, 2025
84f671c
feat(compression): add DECODE operator insertion logic
rkuester Nov 25, 2025
a98304f
refactor(compression): use plugin architecture in compress.py
rkuester Nov 25, 2025
93b12cd
feat(model_editor): add subgraph inputs and outputs fields
rkuester Nov 25, 2025
01cfab5
test(compression): add integration tests with TFLM interpreter
rkuester Nov 25, 2025
bee7560
feat(python): add alt decompression memory parameter to interpreter
rkuester Dec 2, 2025
b05656b
test(compression): add alt decompression memory integration test
rkuester Dec 2, 2025
58cc12f
fix(compression): insert DECODE per consumer for alt decompression me…
rkuester Dec 2, 2025
aed2a7e
test(model_editor): add expected-failure tests for read() edge cases
rkuester Dec 5, 2025
ec2bb91
fix(model_editor): handle None shape and inputs/outputs in read()
rkuester Dec 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions python/tflite_micro/_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ PYBIND11_MODULE(_runtime, m) {
.def(py::init([](const py::bytes& data,
const std::vector<std::string>& registerers_by_name,
size_t arena_size, int num_resource_variables,
tflite::InterpreterConfig config) {
return std::unique_ptr<InterpreterWrapper>(
new InterpreterWrapper(data.ptr(), registerers_by_name, arena_size,
num_resource_variables, config));
tflite::InterpreterConfig config,
size_t alt_decompression_memory_size) {
return std::unique_ptr<InterpreterWrapper>(new InterpreterWrapper(
data.ptr(), registerers_by_name, arena_size, num_resource_variables,
config, alt_decompression_memory_size));
}))
.def("PrintAllocations", &InterpreterWrapper::PrintAllocations)
.def("Invoke", &InterpreterWrapper::Invoke)
Expand Down
22 changes: 20 additions & 2 deletions python/tflite_micro/interpreter_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,18 @@ InterpreterWrapper::~InterpreterWrapper() {

InterpreterWrapper::InterpreterWrapper(
PyObject* model_data, const std::vector<std::string>& registerers_by_name,
size_t arena_size, int num_resource_variables, InterpreterConfig config) {
size_t arena_size, int num_resource_variables, InterpreterConfig config,
size_t alt_decompression_memory_size)
// Members initialized in declaration order. alt_decompression_regions_
// MUST be initialized here (not assigned in body) so its backing array
// lifetime is extended to match the member's lifetime.
: memory_arena_(new uint8_t[arena_size]),
alt_decompression_memory_(alt_decompression_memory_size > 0
? new uint8_t[alt_decompression_memory_size]
: nullptr),
alt_decompression_region_{alt_decompression_memory_.get(),
alt_decompression_memory_size},
alt_decompression_regions_{alt_decompression_region_} {
interpreter_ = nullptr;

// `model_data` is used as a raw pointer beyond the scope of this
Expand Down Expand Up @@ -266,7 +277,6 @@ InterpreterWrapper::InterpreterWrapper(
"--//:with_compression=true to enable compression support.");
}

memory_arena_ = std::unique_ptr<uint8_t[]>(new uint8_t[arena_size]);
for (const std::string& registerer : registerers_by_name) {
if (!AddCustomOpRegistererByName(registerer.c_str(),
&python_ops_resolver_)) {
Expand Down Expand Up @@ -296,6 +306,14 @@ InterpreterWrapper::InterpreterWrapper(
interpreter_ = new MicroInterpreter(model, python_ops_resolver_, allocator_,
resource_variables_);

if (alt_decompression_memory_size > 0) {
TfLiteStatus status =
interpreter_->SetDecompressionMemory(alt_decompression_regions_);
if (status != kTfLiteOk) {
ThrowRuntimeError("TFLM failed to set decompression memory");
}
}

TfLiteStatus status = interpreter_->AllocateTensors();
if (status != kTfLiteOk) {
ThrowRuntimeError("TFLM failed to allocate tensors");
Expand Down
13 changes: 12 additions & 1 deletion python/tflite_micro/interpreter_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

#include "python/tflite_micro/python_ops_resolver.h"
#include "tensorflow/lite/micro/micro_allocator.h"
#include "tensorflow/lite/micro/micro_context.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/recording_micro_allocator.h"

Expand All @@ -40,7 +41,8 @@ class InterpreterWrapper {
InterpreterWrapper(
PyObject* model_data, const std::vector<std::string>& registerers_by_name,
size_t arena_size, int num_resource_variables,
InterpreterConfig config = InterpreterConfig::kAllocationRecording);
InterpreterConfig config = InterpreterConfig::kAllocationRecording,
size_t alt_decompression_memory_size = 0);
~InterpreterWrapper();

void PrintAllocations();
Expand All @@ -57,6 +59,15 @@ class InterpreterWrapper {
tflite::RecordingMicroAllocator* recording_allocator_ = nullptr;
const PyObject* model_;
std::unique_ptr<uint8_t[]> memory_arena_;
std::unique_ptr<uint8_t[]> alt_decompression_memory_;
tflite::MicroContext::AlternateMemoryRegion alt_decompression_region_;
// SetDecompressionMemory stores a pointer to its initializer_list argument,
// requiring the list to outlive the interpreter. Per C++ standard, an
// initializer_list's backing array lifetime is only extended to match the
// list's when initialized in a declaration, not when assigned. This makes
// the API difficult to use correctly; see the constructor init list.
std::initializer_list<tflite::MicroContext::AlternateMemoryRegion>
alt_decompression_regions_;
tflite::PythonOpsResolver python_ops_resolver_;
tflite::MicroInterpreter* interpreter_;
};
Expand Down
12 changes: 12 additions & 0 deletions python/tflite_micro/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def __init__(
custom_op_registerers,
arena_size,
intrepreter_config=InterpreterConfig.kAllocationRecording,
alt_decompression_memory_size=0,
):
if model_data is None:
raise ValueError("Model must not be None")
Expand All @@ -94,6 +95,7 @@ def __init__(
arena_size,
num_resource_variables,
_ENUM_TRANSLATOR[intrepreter_config],
alt_decompression_memory_size,
)

@classmethod
Expand All @@ -103,6 +105,7 @@ def from_file(
custom_op_registerers=[],
arena_size=None,
intrepreter_config=InterpreterConfig.kAllocationRecording,
alt_decompression_memory_size=0,
):
"""Instantiates a TFLM interpreter from a model .tflite filepath.

Expand All @@ -112,6 +115,9 @@ def from_file(
custom OP registerer
arena_size: Tensor arena size in bytes. If unused, tensor arena size will
default to 10 times the model size.
alt_decompression_memory_size: Size in bytes of alternate decompression
memory. If non-zero, DECODE operators will use this memory instead of
the main arena for decompressed tensor outputs.

Returns:
An Interpreter instance
Expand All @@ -127,6 +133,7 @@ def from_file(
custom_op_registerers,
arena_size,
intrepreter_config,
alt_decompression_memory_size,
)

@classmethod
Expand All @@ -136,6 +143,7 @@ def from_bytes(
custom_op_registerers=[],
arena_size=None,
intrepreter_config=InterpreterConfig.kAllocationRecording,
alt_decompression_memory_size=0,
):
"""Instantiates a TFLM interpreter from a model in byte array.

Expand All @@ -145,6 +153,9 @@ def from_bytes(
custom OP registerer
arena_size: Tensor arena size in bytes. If unused, tensor arena size will
default to 10 times the model size.
alt_decompression_memory_size: Size in bytes of alternate decompression
memory. If non-zero, DECODE operators will use this memory instead of
the main arena for decompressed tensor outputs.

Returns:
An Interpreter instance
Expand All @@ -155,6 +166,7 @@ def from_bytes(
custom_op_registerers,
arena_size,
intrepreter_config,
alt_decompression_memory_size,
)

def print_allocations(self):
Expand Down
158 changes: 138 additions & 20 deletions tensorflow/lite/micro/compression/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,15 @@ py_library(
"compress.py",
],
deps = [
":metadata_py",
":model_facade",
":compressor",
":decode_insert",
":huffman",
":lut",
":model_editor",
":pruning",
":spec",
"//tensorflow/lite/micro/tools:tflite_flatbuffer_align",
requirement("absl_py"),
"@flatbuffers//:runtime_py",
requirement("bitarray"),
requirement("numpy"),
],
)

Expand All @@ -148,33 +149,149 @@ py_test(
],
deps = [
":compress",
":metadata_py",
":model_facade",
":compressor",
":decode_insert",
":model_editor",
":spec",
"//tensorflow/lite/python:schema_py",
requirement("numpy"),
requirement("tensorflow"),
],
)

py_test(
name = "compression_integration_test",
size = "small",
srcs = ["compression_integration_test.py"],
tags = [
"noasan",
"nomsan",
"noubsan",
],
# Only run when compression IS enabled
target_compatible_with = select({
"//:with_compression_enabled": [],
"//conditions:default": ["@platforms//:incompatible"],
}),
deps = [
":compress_lib",
":decode_insert",
":model_editor",
":spec",
":test_models",
"//python/tflite_micro:runtime",
"//tensorflow/lite/python:schema_py",
requirement("numpy"),
requirement("tensorflow"),
],
)

py_library(
name = "compressor",
srcs = ["compressor.py"],
deps = [
":decode",
":model_editor",
":spec",
],
)

py_library(
name = "lut",
srcs = ["lut.py"],
deps = [
":compressor",
":decode",
":model_editor",
":spec",
requirement("bitarray"),
requirement("numpy"),
],
)

py_test(
name = "lut_test",
size = "small",
srcs = ["lut_test.py"],
tags = [
"noasan",
"nomsan",
"noubsan",
],
deps = [
":compressor",
":decode",
":lut",
":model_editor",
":spec",
"//tensorflow/lite/python:schema_py",
requirement("numpy"),
requirement("tensorflow"),
],
)

py_library(
name = "model_facade",
srcs = ["model_facade.py"],
name = "huffman",
srcs = ["huffman.py"],
deps = [
":compressor",
":decode",
":model_editor",
":spec",
],
)

py_library(
name = "pruning",
srcs = ["pruning.py"],
deps = [
":compressor",
":decode",
":model_editor",
":spec",
],
)

py_library(
name = "decode_insert",
srcs = ["decode_insert.py"],
deps = [
":compressor",
":model_editor",
"//tensorflow/lite/python:schema_py",
requirement("flatbuffers"),
],
)

py_test(
name = "model_facade_test",
name = "decode_insert_test",
size = "small",
srcs = ["model_facade_test.py"],
srcs = ["decode_insert_test.py"],
tags = [
"noasan",
"nomsan",
"noubsan",
],
deps = [
":model_facade",
":test_models",
":compressor",
":decode",
":decode_insert",
":model_editor",
"//tensorflow/lite/python:schema_py",
requirement("numpy"),
requirement("tensorflow"),
],
)

py_library(
name = "decode",
srcs = ["decode.py"],
)

py_test(
name = "decode_test",
size = "small",
srcs = ["decode_test.py"],
deps = [
":decode",
requirement("tensorflow"),
],
)
Expand Down Expand Up @@ -217,8 +334,8 @@ py_test(
)

py_library(
name = "test_models",
srcs = ["test_models.py"],
name = "model_editor",
srcs = ["model_editor.py"],
deps = [
"//tensorflow/lite/python:schema_py",
requirement("flatbuffers"),
Expand All @@ -227,12 +344,13 @@ py_library(
)

py_test(
name = "test_models_test",
name = "model_editor_test",
size = "small",
srcs = ["test_models_test.py"],
srcs = ["model_editor_test.py"],
deps = [
":test_models",
":model_editor",
"//tensorflow/lite/python:schema_py",
requirement("numpy"),
requirement("tensorflow"),
],
)
Expand Down
Loading
Loading