@@ -176,6 +176,185 @@ def test_lut_compressed_model_is_smaller(self):
                     f"original ({original_size} bytes)")


+def _build_shared_weights_model():
+  """Build a model where one compressed tensor is shared between two operators.
+
+  Model structure:
+    input1 -> [FC1 with weights1] -> output1
+    input2 -> [FC2 with weights2] -> intermediate -> [FC3 with weights1] -> output2
+
+  weights1 is shared between FC1 and FC3. weights2 is used only by FC2, which
+  runs between the two consumers of weights1.
+  """
+  # At most 4 unique values per tensor allows 2-bit LUT compression (each
+  # tensor here uses 3). Small values avoid saturation in chained layers.
+  # Different row sums produce varied outputs.
+  weights1_data = np.array([
+      [-1, 0, 0, 1],
+      [-1, 0, 1, 1],
+      [-1, 1, 1, 1],
+      [0, 1, 1, 1],
+  ], dtype=np.int8)
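+  # Assuming standard TFLite FULLY_CONNECTED semantics (output = input @ W.T,
+  # and no bias tensor in this graph), an all-ones input to FC1 produces
+  # weights1's row sums: [0, 1, 2, 3].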
+  weights1 = model_editor.Tensor(
+      shape=(4, 4),
+      dtype=tflite.TensorType.INT8,
+      data=weights1_data,
+      name="weights1",
+      quantization=model_editor.Quantization(scales=1.0, zero_points=0),
+  )
+
+  weights2_data = np.array([
+      [1, 1, 1, 1],
+      [1, 1, 2, 2],
+      [1, 2, 2, 3],
+      [2, 2, 3, 3],
+  ], dtype=np.int8)
+  weights2 = model_editor.Tensor(
+      shape=(4, 4),
+      dtype=tflite.TensorType.INT8,
+      data=weights2_data,
+      name="weights2",
+      quantization=model_editor.Quantization(scales=1.0, zero_points=0),
+  )
+
+  # All tensors need matching quantization for FULLY_CONNECTED.
+  quant = model_editor.Quantization(scales=1.0, zero_points=0)
+
+  input1 = model_editor.Tensor(
+      shape=(1, 4),
+      dtype=tflite.TensorType.INT8,
+      name="input1",
+      quantization=quant,
+  )
+  input2 = model_editor.Tensor(
+      shape=(1, 4),
+      dtype=tflite.TensorType.INT8,
+      name="input2",
+      quantization=quant,
+  )
+  output1 = model_editor.Tensor(
+      shape=(1, 4),
+      dtype=tflite.TensorType.INT8,
+      name="output1",
+      quantization=quant,
+  )
+  intermediate = model_editor.Tensor(
+      shape=(1, 4),
+      dtype=tflite.TensorType.INT8,
+      name="intermediate",
+      quantization=quant,
+  )
+  output2 = model_editor.Tensor(
+      shape=(1, 4),
+      dtype=tflite.TensorType.INT8,
+      name="output2",
+      quantization=quant,
+  )
+
+  model = model_editor.Model(subgraphs=[
+      model_editor.Subgraph(
+          tensors=[weights1, weights2],
+          inputs=[input1, input2],
+          outputs=[output1, output2],
+          operators=[
+              # FC1: uses weights1.
+              model_editor.Operator(
+                  opcode=tflite.BuiltinOperator.FULLY_CONNECTED,
+                  inputs=[input1, weights1],
+                  outputs=[output1],
+              ),
+              # FC2: uses weights2 (runs between FC1 and FC3).
+              model_editor.Operator(
+                  opcode=tflite.BuiltinOperator.FULLY_CONNECTED,
+                  inputs=[input2, weights2],
+                  outputs=[intermediate],
+              ),
+              # FC3: uses weights1 (second consumer, after DECODE(weights2)).
+              model_editor.Operator(
+                  opcode=tflite.BuiltinOperator.FULLY_CONNECTED,
+                  inputs=[intermediate, weights1],
+                  outputs=[output2],
+              ),
+          ],
+      )
+  ])
+  return model.build()
+
+
+class AltDecompressionMemoryTest(tf.test.TestCase):
+  """Tests for alternate decompression memory with shared compressed tensors.
+
+  These tests verify correct behavior when compressed tensors are shared
+  between multiple operators and alternate decompression memory is enabled.
+  """
+
+  @unittest.expectedFailure
+  def test_shared_compressed_tensor_with_alt_memory(self):
+    """Verify correct results when a shared compressed tensor is used with
+    alt decompression memory.
+
+    This test uses a graph in which a compressed tensor (weights1) is consumed
+    by two operators (FC1 and FC3), with a DECODE of a different compressed
+    tensor (weights2) intervening between them.
+
+    The interpreter's alternate decompression memory has a limitation: each
+    DECODE's Prepare resets the allocation offset to zero, so all DECODE
+    outputs are allocated at the same address and overwrite each other. A
+    DECODE output is therefore valid only until the next DECODE runs.
+
+    To work around this limitation, the DECODE insertion code must insert a
+    separate DECODE immediately before each consumer of a compressed tensor,
+    rather than sharing one DECODE output among all consumers.
+
+    This test is expected to fail because the current insertion code does not
+    yet implement this workaround.
+    """
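+    # Sketch of the op orderings at issue (descriptive names; assumes DECODE
+    # is inserted before a tensor's first consumer, per the docstring above):
+    #   current:  DECODE(weights1), FC1, DECODE(weights2), FC2, FC3
+    #   needed:   DECODE(weights1), FC1, DECODE(weights2), FC2,
+    #             DECODE(weights1), FC3
+    # In the current ordering, FC3 reads alt memory that DECODE(weights2) has
+    # already overwritten.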
+    flatbuffer = _build_shared_weights_model()
+
+    specs = [
+        spec.Tensor(
+            subgraph=0,
+            tensor=0,  # weights1
+            compression=[spec.LookUpTableCompression(index_bitwidth=2)],
+        ),
+        spec.Tensor(
+            subgraph=0,
+            tensor=1,  # weights2
+            compression=[spec.LookUpTableCompression(index_bitwidth=2)],
+        ),
+    ]
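+    # The tensor indices above follow the subgraph's tensors list in
+    # _build_shared_weights_model(): index 0 is weights1, index 1 is weights2.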
+
+    compressed_fb = compress.compress(flatbuffer, specs)
+
+    # Run without alt decompression memory (baseline).
+    interp_no_alt = runtime.Interpreter.from_bytes(bytes(compressed_fb))
+
+    # Run with alt decompression memory.
+    interp_with_alt = runtime.Interpreter.from_bytes(
+        bytes(compressed_fb),
+        alt_decompression_memory_size=256,
+    )
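+    # 256 bytes comfortably holds one decoded 4x4 int8 tensor (16 bytes), so
+    # any mismatch comes from the offset-reset behavior described in the
+    # docstring, not from the region being too small.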
+
+    test_input1 = np.array([[1, 1, 1, 1]], dtype=np.int8)
+    test_input2 = np.array([[1, 1, 1, 1]], dtype=np.int8)
+
+    interp_no_alt.set_input(test_input1, 0)
+    interp_no_alt.set_input(test_input2, 1)
+    interp_no_alt.invoke()
+    expected1 = interp_no_alt.get_output(0)
+    expected2 = interp_no_alt.get_output(1)
+
+    interp_with_alt.set_input(test_input1, 0)
+    interp_with_alt.set_input(test_input2, 1)
+    interp_with_alt.invoke()
+    actual1 = interp_with_alt.get_output(0)
+    actual2 = interp_with_alt.get_output(1)
+
+    self.assertAllEqual(expected1, actual1,
+                        "Output 1 mismatch with alt decompression memory")
+    self.assertAllEqual(expected2, actual2,
+                        "Output 2 mismatch with alt decompression memory")
+
+
 class HuffmanCompressionTest(tf.test.TestCase):
   """Integration tests for Huffman compression."""
