
Commit ed35ee0
test(compression): add alt decompression memory integration test
Add test for shared compressed tensors with alternate decompression memory. The test is marked expectedFailure to document the current mismatch between interpreter and DECODE insertion: the interpreter's alt decompression memory resets allocations for each DECODE, but the insertion code shares one DECODE output among all consumers. The workaround is to insert a separate DECODE before each consumer. The expectedFailure decorator should be removed once this is implemented.
1 parent 0e59927 commit ed35ee0
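
Illustrative sketch (not part of the commit): using the FC1/FC2/FC3 and weights1/weights2 names from the test below, and hypothetical w1/w2 names for the decoded outputs, the two operator orderings at issue look roughly like this:

    Current insertion (one DECODE output shared by both consumers):
      DECODE(weights1) -> w1            # placed at offset 0 of alt memory
      FC1(input1, w1) -> output1
      DECODE(weights2) -> w2            # offset reset to 0; w2 overwrites w1
      FC2(input2, w2) -> intermediate
      FC3(intermediate, w1) -> output2  # reads clobbered data

    Intended workaround (a separate DECODE immediately before each consumer):
      DECODE(weights1) -> w1a
      FC1(input1, w1a) -> output1
      DECODE(weights2) -> w2
      FC2(input2, w2) -> intermediate
      DECODE(weights1) -> w1b           # weights1 decoded again for FC3
      FC3(intermediate, w1b) -> output2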

File tree

1 file changed: +179 -0 lines changed

tensorflow/lite/micro/compression/compression_integration_test.py

Lines changed: 179 additions & 0 deletions
@@ -176,6 +176,185 @@ def test_lut_compressed_model_is_smaller(self):
                    f"original ({original_size} bytes)")


def _build_shared_weights_model():
  """Build a model where one compressed tensor is shared between two operators.

  Model structure:
    input1 -> [FC1 with weights1] -> output1
    input2 -> [FC2 with weights2] -> intermediate -> [FC3 with weights1] -> output2

  weights1 is shared between FC1 and FC3. weights2 is used only by FC2, which
  runs between the two consumers of weights1.
  """
  # 4 unique values per tensor for 2-bit LUT compression. Small values avoid
  # saturation in chained layers. Different row sums produce varied outputs.
  weights1_data = np.array([
      [-1, 0, 0, 1],
      [-1, 0, 1, 1],
      [-1, 1, 1, 1],
      [0, 1, 1, 1],
  ], dtype=np.int8)
  weights1 = model_editor.Tensor(
      shape=(4, 4),
      dtype=tflite.TensorType.INT8,
      data=weights1_data,
      name="weights1",
      quantization=model_editor.Quantization(scales=1.0, zero_points=0),
  )

  weights2_data = np.array([
      [1, 1, 1, 1],
      [1, 1, 2, 2],
      [1, 2, 2, 3],
      [2, 2, 3, 3],
  ], dtype=np.int8)
  weights2 = model_editor.Tensor(
      shape=(4, 4),
      dtype=tflite.TensorType.INT8,
      data=weights2_data,
      name="weights2",
      quantization=model_editor.Quantization(scales=1.0, zero_points=0),
  )

  # All tensors need matching quantization for FULLY_CONNECTED
  quant = model_editor.Quantization(scales=1.0, zero_points=0)

  input1 = model_editor.Tensor(
      shape=(1, 4),
      dtype=tflite.TensorType.INT8,
      name="input1",
      quantization=quant,
  )
  input2 = model_editor.Tensor(
      shape=(1, 4),
      dtype=tflite.TensorType.INT8,
      name="input2",
      quantization=quant,
  )
  output1 = model_editor.Tensor(
      shape=(1, 4),
      dtype=tflite.TensorType.INT8,
      name="output1",
      quantization=quant,
  )
  intermediate = model_editor.Tensor(
      shape=(1, 4),
      dtype=tflite.TensorType.INT8,
      name="intermediate",
      quantization=quant,
  )
  output2 = model_editor.Tensor(
      shape=(1, 4),
      dtype=tflite.TensorType.INT8,
      name="output2",
      quantization=quant,
  )

  model = model_editor.Model(subgraphs=[
      model_editor.Subgraph(
          tensors=[weights1, weights2],
          inputs=[input1, input2],
          outputs=[output1, output2],
          operators=[
              # FC1: uses weights1
              model_editor.Operator(
                  opcode=tflite.BuiltinOperator.FULLY_CONNECTED,
                  inputs=[input1, weights1],
                  outputs=[output1],
              ),
              # FC2: uses weights2 (runs between FC1 and FC3)
              model_editor.Operator(
                  opcode=tflite.BuiltinOperator.FULLY_CONNECTED,
                  inputs=[input2, weights2],
                  outputs=[intermediate],
              ),
              # FC3: uses weights1 (second consumer, after DECODE(weights2))
              model_editor.Operator(
                  opcode=tflite.BuiltinOperator.FULLY_CONNECTED,
                  inputs=[intermediate, weights1],
                  outputs=[output2],
              ),
          ],
      )
  ])
  return model.build()


class AltDecompressionMemoryTest(tf.test.TestCase):
  """Tests for alternate decompression memory with shared compressed tensors.

  These tests verify correct behavior when compressed tensors are shared
  between multiple operators and alternate decompression memory is enabled.
  """

  @unittest.expectedFailure
  def test_shared_compressed_tensor_with_alt_memory(self):
    """Verify correct results when a shared compressed tensor is used with
    alt decompression memory.

    This test uses a graph where a compressed tensor (weights1) is consumed
    by two operators (FC1 and FC3), with an intervening DECODE of a different
    compressed tensor (weights2) between them.

    The interpreter's alternate decompression memory has a limitation: each
    DECODE's Prepare resets the allocation offset to zero. This means all
    DECODE outputs are allocated at the same address, so they overwrite each
    other. A DECODE output can only be used until the next DECODE runs.

    To work around this limitation, the DECODE insertion code must insert a
    separate DECODE immediately before each consumer of a compressed tensor,
    rather than sharing one DECODE output among all consumers.

    This test is expected to fail because the current insertion code does not
    yet implement this workaround.
    """
    flatbuffer = _build_shared_weights_model()

    specs = [
        spec.Tensor(
            subgraph=0,
            tensor=0,  # weights1
            compression=[spec.LookUpTableCompression(index_bitwidth=2)],
        ),
        spec.Tensor(
            subgraph=0,
            tensor=1,  # weights2
            compression=[spec.LookUpTableCompression(index_bitwidth=2)],
        ),
    ]

    compressed_fb = compress.compress(flatbuffer, specs)

    # Run without alt decompression memory (baseline)
    interp_no_alt = runtime.Interpreter.from_bytes(bytes(compressed_fb))

    # Run with alt decompression memory
    interp_with_alt = runtime.Interpreter.from_bytes(
        bytes(compressed_fb),
        alt_decompression_memory_size=256,
    )

    test_input1 = np.array([[1, 1, 1, 1]], dtype=np.int8)
    test_input2 = np.array([[1, 1, 1, 1]], dtype=np.int8)

    interp_no_alt.set_input(test_input1, 0)
    interp_no_alt.set_input(test_input2, 1)
    interp_no_alt.invoke()
    expected1 = interp_no_alt.get_output(0)
    expected2 = interp_no_alt.get_output(1)

    interp_with_alt.set_input(test_input1, 0)
    interp_with_alt.set_input(test_input2, 1)
    interp_with_alt.invoke()
    actual1 = interp_with_alt.get_output(0)
    actual2 = interp_with_alt.get_output(1)

    self.assertAllEqual(expected1, actual1,
                        "Output 1 mismatch with alt decompression memory")
    self.assertAllEqual(expected2, actual2,
                        "Output 2 mismatch with alt decompression memory")


class HuffmanCompressionTest(tf.test.TestCase):
  """Integration tests for Huffman compression."""
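
The allocation behavior behind the expected failure can be illustrated with a minimal, self-contained Python sketch. This is hypothetical code, not TFLM's actual allocator; the class and method names are invented for illustration:

    # Hypothetical bump allocator whose offset resets on every DECODE Prepare,
    # mirroring the limitation documented in the test docstring above.
    class AltMemorySketch:

      def __init__(self, size):
        self.buffer = bytearray(size)
        self.offset = 0

      def prepare_decode(self, nbytes):
        self.offset = 0  # the problematic reset: every DECODE starts at zero
        start = self.offset
        self.offset += nbytes
        return start

    mem = AltMemorySketch(256)
    addr_w1 = mem.prepare_decode(16)  # output of DECODE(weights1)
    addr_w2 = mem.prepare_decode(16)  # output of DECODE(weights2)
    assert addr_w1 == addr_w2 == 0    # both outputs alias the same address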
