From b724196adc64558bec8e0df2eb2972992053ec0c Mon Sep 17 00:00:00 2001 From: Jeremy Lau <30300826+fdxmw@users.noreply.github.com> Date: Fri, 31 May 2024 15:45:41 -0700 Subject: [PATCH] Fix a wire_matrix bug with single-element wire_matrices of Inputs/Registers. Add tests. The logic that decides whether to slice or concatenate did not work in this case. See the new comment in wire_matrix's constructor for details. --- pyrtl/helperfuncs.py | 23 ++++++++++++++++++- tests/test_helperfuncs.py | 48 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/pyrtl/helperfuncs.py b/pyrtl/helperfuncs.py index 301667c2..728dc2b3 100644 --- a/pyrtl/helperfuncs.py +++ b/pyrtl/helperfuncs.py @@ -1630,14 +1630,35 @@ def __init__(self, name: str = '', block: Block = None, name=component_name, bitwidth=self._component_bitwidth, type=component_schema)) + # By default, slice the concatenated value into components iff + # exactly one value is provided. + slicing = len(values) == 1 + # Handle Input and Register special cases. if concatenated_type is Input or concatenated_type is Register: + # Slice the concatenated value. Override the default 'slicing' + # because 'values' is empty when slicing a concatenated Input + # or Register. + # + # Note that we can't just check len(values) == 1 after we set + # values to [None] because that doesn't work when there is only + # one element in the wire_matrix. We must distinguish between: + # + # 1. Slicing values to produce values[0] (this case). + # 2. Concatenating values[0] to produce values (next case). + # + # But len(values) == 1 in both cases. The slice in (1) and + # concatenate in (2) are both no-ops, but we have to get the + # direction right. In the first case, values[0] is driven by + # values, and in the second case, values is driven by + # values[0]. + slicing = True values = [None] elif component_type is Input or component_type is Register: values = [None for _ in range(self._size)] self._components = [None for i in range(len(schema))] - if len(values) == 1: + if slicing: # Concatenated value was provided. Slice it into components. _slice(block=block, schema=schema, bitwidth=self._bitwidth, component_type=component_type, name=name, diff --git a/tests/test_helperfuncs.py b/tests/test_helperfuncs.py index 3f228a3f..4d9e75ef 100644 --- a/tests/test_helperfuncs.py +++ b/tests/test_helperfuncs.py @@ -1549,11 +1549,15 @@ def test_anonymous_pixel_concatenate(self): BitPair = pyrtl.wire_matrix(component_schema=1, size=2) -# Word is an array of two Bytes. This checks that a @wire_struct (Byte) can be a -# component of a wire_matrix (Word). +# Word is an array of two Bytes. This checks that a @wire_struct (Byte) can be +# a component of a wire_matrix (Word). Word = pyrtl.wire_matrix(component_schema=Byte, size=2) +# ByteMatrix tests the corner case of a single-element wire_matrix. +ByteMatrix = pyrtl.wire_matrix(component_schema=Byte, size=1) + + # DWord is an array of two Words, or effectively a 2x2 array of Bytes. This # check that a wire_matrix (Word) can be a component of a wire_matrix (DWord). DWord = pyrtl.wire_matrix(component_schema=Word, size=2) @@ -1727,6 +1731,46 @@ def test_wire_matrix_composition_cached_data(self): self.assertEqual(sim.inspect('cached_data.data[1].high'), 0xC) self.assertEqual(sim.inspect('cached_data.data[1].low'), 0xD) + def test_byte_matrix_const(self): + byte_matrix = ByteMatrix(name='byte_matrix', values=[0xAB]) + self.assertEqual(len(byte_matrix), 1) + self.assertEqual(byte_matrix.bitwidth, 8) + self.assertEqual(len(pyrtl.as_wires(byte_matrix)), 8) + + # Constants are sliced immediately. + self.assertEqual(byte_matrix.val, 0xAB) + self.assertEqual(byte_matrix[0].val, 0xAB) + self.assertEqual(byte_matrix[0].high.val, 0xA) + self.assertEqual(byte_matrix[0].low.val, 0xB) + + def test_byte_matrix_input_slice(self): + byte_matrix = ByteMatrix(name='byte_matrix', + component_type=pyrtl.Input) + + self.assertTrue(isinstance(pyrtl.as_wires(byte_matrix[0]), + pyrtl.Input)) + + sim = pyrtl.Simulation() + sim.step(provided_inputs={'byte_matrix[0]': 0xAB}) + self.assertEqual(sim.inspect('byte_matrix'), 0xAB) + self.assertEqual(sim.inspect('byte_matrix[0]'), 0xAB) + self.assertEqual(sim.inspect('byte_matrix[0].high'), 0xA) + self.assertEqual(sim.inspect('byte_matrix[0].low'), 0xB) + + def test_byte_matrix_input_concatenate(self): + byte_matrix = ByteMatrix(name='byte_matrix', + concatenated_type=pyrtl.Input) + + self.assertTrue(isinstance(pyrtl.as_wires(byte_matrix), + pyrtl.Input)) + + sim = pyrtl.Simulation() + sim.step(provided_inputs={'byte_matrix': 0xAB}) + self.assertEqual(sim.inspect('byte_matrix'), 0xAB) + self.assertEqual(sim.inspect('byte_matrix[0]'), 0xAB) + self.assertEqual(sim.inspect('byte_matrix[0].high'), 0xA) + self.assertEqual(sim.inspect('byte_matrix[0].low'), 0xB) + if __name__ == "__main__": unittest.main()