
Commit cf78500

fix(compression): insert DECODE per consumer for alt decompression memory

Insert a separate DECODE immediately before each consumer of a compressed
tensor, rather than sharing one DECODE output among all consumers.

The interpreter's alternate decompression memory resets its allocation offset
for each DECODE's Prepare, causing all DECODE outputs to be allocated at the
same address. If two consumers share one DECODE and another DECODE runs
between them, the intervening DECODE overwrites the shared output, corrupting
data for the second consumer.

Update test expectations to reflect the new DECODE-per-consumer behavior and
change the integration test from expected-failure to expected-pass.
1 parent ed35ee0 commit cf78500
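
To make the failure mode concrete, the following is a minimal, self-contained Python sketch (toy data only, not TFLM code) of the behavior the commit message describes: each DECODE's Prepare resets the alternate decompression memory's allocation offset, so every DECODE output lands at the same address, and a shared output survives only until the next DECODE runs.

# Toy simulation (not TFLM code) of alternate decompression memory whose
# allocation offset resets for every DECODE's Prepare.
alt_memory = [None] * 8


def run_decode(decoded_values):
  """Decode into the arena; the offset is reset to 0 for every DECODE."""
  offset = 0
  alt_memory[offset:offset + len(decoded_values)] = decoded_values
  return offset


# Shared-DECODE schedule: DECODE(w), FC1 reads w, DECODE(x), FC2 reads w.
w_at = run_decode([1, 2, 3])            # decode the shared weights w
fc1_sees = alt_memory[w_at:w_at + 3]    # FC1 reads [1, 2, 3] -- correct
run_decode([9, 9, 9])                   # an unrelated DECODE runs next
fc2_sees = alt_memory[w_at:w_at + 3]    # FC2 reads [9, 9, 9] -- corrupted
assert fc1_sees == [1, 2, 3] and fc2_sees == [9, 9, 9]

# Per-consumer schedule, as in this commit: re-decode w immediately before
# FC2, so FC2 reads correct data again.
w_at = run_decode([1, 2, 3])
assert alt_memory[w_at:w_at + 3] == [1, 2, 3]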

3 files changed: +46 -32 lines

tensorflow/lite/micro/compression/compression_integration_test.py

Lines changed: 1 addition & 5 deletions
@@ -287,7 +287,6 @@ class AltDecompressionMemoryTest(tf.test.TestCase):
   between multiple operators and alternate decompression memory is enabled.
   """
 
-  @unittest.expectedFailure
   def test_shared_compressed_tensor_with_alt_memory(self):
     """Verify correct results when a shared compressed tensor is used with alt
     decompression memory.
@@ -301,12 +300,9 @@ def test_shared_compressed_tensor_with_alt_memory(self):
     DECODE outputs are allocated at the same address, so they overwrite each
     other. A DECODE output can only be used until the next DECODE runs.
 
-    To work around this limitation, the DECODE insertion code must insert a
+    To work around this limitation, the DECODE insertion code inserts a
     separate DECODE immediately before each consumer of a compressed tensor,
     rather than sharing one DECODE output among all consumers.
-
-    This test is expected to fail because the current insertion code does not
-    yet implement this workaround.
     """
     flatbuffer = _build_shared_weights_model()
 
tensorflow/lite/micro/compression/decode_insert.py

Lines changed: 27 additions & 17 deletions
@@ -144,11 +144,19 @@ def insert_decode_operators(
   This function modifies the model in-place, inserting DECODE operators
   before any operator that uses a compressed tensor as input.
 
-  For each compressed tensor:
+  A separate DECODE is inserted before each consumer, rather than sharing one
+  DECODE output among all consumers. This is required because the interpreter's
+  alternate decompression memory resets its allocation offset for each DECODE's
+  Prepare, causing all DECODE outputs to be allocated at the same address. If
+  two consumers share one DECODE and another DECODE runs between them, the
+  intervening DECODE overwrites the shared output, corrupting data for the
+  second consumer.
+
+  For each consumer of a compressed tensor:
   1. Create an ancillary data tensor containing DCM + type-specific data
   2. Create an output tensor with the same shape/dtype as the decoded tensor
-  3. Insert a DECODE operator before the first consumer
-  4. Rewire all consumers to use the DECODE output instead of the encoded tensor
+  3. Insert a DECODE operator immediately before the consumer
+  4. Rewire the consumer to use the DECODE output
 
   Args:
     model: The model to modify in-place.
@@ -180,23 +188,27 @@ def insert_decode_operators(
   for sg_idx, tensor_infos in by_subgraph.items():
     subgraph = model.subgraphs[sg_idx]
 
-    # Sort by earliest consumer position (process in reverse order to maintain
-    # valid positions as we insert)
-    tensor_infos.sort(
-        key=lambda info: _find_earliest_consumer_position(
-            subgraph, info.consumers),
+    # Collect all (consumer, tensor_info) pairs and sort by consumer position
+    # in reverse order so insertions don't invalidate positions
+    consumer_pairs = []
+    for info in tensor_infos:
+      for consumer in info.consumers:
+        consumer_pairs.append((consumer, info))
+
+    consumer_pairs.sort(
+        key=lambda pair: subgraph.operators.index(pair[0]),
         reverse=True,
     )
 
-    for info in tensor_infos:
-      # Create ancillary data tensor
+    for consumer, info in consumer_pairs:
+      # Create ancillary data tensor (one per DECODE)
       ancillary_tensor = _create_ancillary_tensor(
           info.ancillary_data,
           info.tensor,
       )
       subgraph.tensors.append(ancillary_tensor)
 
-      # Create output tensor
+      # Create output tensor (one per DECODE)
       output_tensor = _create_output_tensor(info.tensor)
       subgraph.tensors.append(output_tensor)
 
@@ -208,11 +220,9 @@ def insert_decode_operators(
           outputs=[output_tensor],
       )
 
-      # Find insertion position (before first consumer)
-      insert_pos = _find_earliest_consumer_position(subgraph, info.consumers)
-
-      # Insert DECODE operator
+      # Insert DECODE immediately before this consumer
+      insert_pos = subgraph.operators.index(consumer)
       subgraph.operators.insert(insert_pos, decode_op)
 
-      # Rewire all consumers to use the decoded output
-      _rewire_consumers(info.consumers, info.tensor, output_tensor)
+      # Rewire only this consumer to use the decoded output
+      _rewire_consumers([consumer], info.tensor, output_tensor)
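
The reverse sort above relies on a basic list property noted in the diff's own comment: inserting into a list shifts only the elements at or after the insertion point, so handling consumers from last to first keeps the earlier positions valid. A tiny standalone illustration with toy operator names (not the real model classes):

# Inserting from the back keeps earlier insertion positions valid.
operators = ["FC_A", "ADD", "FC_B"]   # pretend FC_A and FC_B consume the tensor
consumer_positions = [0, 2]           # their indices in the operator list

for pos in sorted(consumer_positions, reverse=True):
  operators.insert(pos, "DECODE")     # insert immediately before the consumer

assert operators == ["DECODE", "FC_A", "ADD", "DECODE", "FC_B"]

The _rewire_consumers helper itself is not part of this diff; judging from its call sites here and the assertions in the tests below, it presumably swaps the encoded tensor for the DECODE output in each listed consumer's inputs. A hypothetical sketch, assuming operators expose a mutable inputs list and tensors compare by identity:

# Hypothetical sketch of _rewire_consumers (its real implementation is not
# shown in this diff): replace the old encoded tensor with the new decoded
# tensor wherever it appears in a consumer's inputs.
def _rewire_consumers(consumers, old_tensor, new_tensor):
  for op in consumers:
    op.inputs = [new_tensor if t is old_tensor else t for t in op.inputs]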

tensorflow/lite/micro/compression/decode_insert_test.py

Lines changed: 18 additions & 10 deletions
@@ -224,8 +224,8 @@ def test_consumer_rewired_to_decode_output(self):
     # Original weights tensor should NOT be in FC inputs
     self.assertNotIn(weights_tensor, fc_op.inputs)
 
-  def test_shared_tensor_single_decode(self):
-    """Tensor used by multiple ops gets single DECODE, both rewired."""
+  def test_shared_tensor_decode_per_consumer(self):
+    """Tensor used by multiple ops gets separate DECODE for each consumer."""
     model = _build_shared_weights_model()
     weights_tensor = model.subgraphs[0].tensors[0]
 
@@ -241,17 +241,25 @@ def test_shared_tensor_single_decode(self):
 
     sg = model.subgraphs[0]
 
-    # Should have 3 operators: 1 DECODE + 2 FC
-    self.assertEqual(len(sg.operators), 3)
+    # Should have 4 operators: 2 DECODEs + 2 FCs (DECODE before each FC)
+    self.assertEqual(len(sg.operators), 4)
     self.assertEqual(sg.operators[0].opcode, tflite.BuiltinOperator.CUSTOM)
+    self.assertEqual(sg.operators[1].opcode,
+                     tflite.BuiltinOperator.FULLY_CONNECTED)
+    self.assertEqual(sg.operators[2].opcode, tflite.BuiltinOperator.CUSTOM)
+    self.assertEqual(sg.operators[3].opcode,
+                     tflite.BuiltinOperator.FULLY_CONNECTED)
 
-    decode_op = sg.operators[0]
+    decode_op1 = sg.operators[0]
     fc_op1 = sg.operators[1]
-    fc_op2 = sg.operators[2]
-
-    # Both FCs should use DECODE's output
-    self.assertIs(fc_op1.inputs[1], decode_op.outputs[0])
-    self.assertIs(fc_op2.inputs[1], decode_op.outputs[0])
+    decode_op2 = sg.operators[2]
+    fc_op2 = sg.operators[3]
+
+    # Each FC should use its own DECODE's output
+    self.assertIs(fc_op1.inputs[1], decode_op1.outputs[0])
+    self.assertIs(fc_op2.inputs[1], decode_op2.outputs[0])
+    # The two DECODEs should have different outputs
+    self.assertIsNot(decode_op1.outputs[0], decode_op2.outputs[0])
 
   def test_ancillary_tensor_contains_dcm(self):
     """Ancillary tensor data contains valid DCM header."""
