Skip to content

Commit

Permalink
Fix a wire_matrix bug with single-element wire_matrices of Inputs/Reg…
Browse files Browse the repository at this point in the history
…isters. Add tests.

The logic that decides whether to slice or concatenate did not work in this
case. See the new comment in wire_matrix's constructor for details.
  • Loading branch information
fdxmw committed May 31, 2024
1 parent bc48bc9 commit b724196
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 3 deletions.
23 changes: 22 additions & 1 deletion pyrtl/helperfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,14 +1630,35 @@ def __init__(self, name: str = '', block: Block = None,
name=component_name, bitwidth=self._component_bitwidth,
type=component_schema))

# By default, slice the concatenated value into components iff
# exactly one value is provided.
slicing = len(values) == 1

# Handle Input and Register special cases.
if concatenated_type is Input or concatenated_type is Register:
# Slice the concatenated value. Override the default 'slicing'
# because 'values' is empty when slicing a concatenated Input
# or Register.
#
# Note that we can't just check len(values) == 1 after we set
# values to [None] because that doesn't work when there is only
# one element in the wire_matrix. We must distinguish between:
#
# 1. Slicing values to produce values[0] (this case).
# 2. Concatenating values[0] to produce values (next case).
#
# But len(values) == 1 in both cases. The slice in (1) and
# concatenate in (2) are both no-ops, but we have to get the
# direction right. In the first case, values[0] is driven by
# values, and in the second case, values is driven by
# values[0].
slicing = True
values = [None]
elif component_type is Input or component_type is Register:
values = [None for _ in range(self._size)]

self._components = [None for i in range(len(schema))]
if len(values) == 1:
if slicing:
# Concatenated value was provided. Slice it into components.
_slice(block=block, schema=schema, bitwidth=self._bitwidth,
component_type=component_type, name=name,
Expand Down
48 changes: 46 additions & 2 deletions tests/test_helperfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,11 +1549,15 @@ def test_anonymous_pixel_concatenate(self):
BitPair = pyrtl.wire_matrix(component_schema=1, size=2)


# Word is an array of two Bytes. This checks that a @wire_struct (Byte) can be a
# component of a wire_matrix (Word).
# Word is an array of two Bytes. This checks that a @wire_struct (Byte) can be
# a component of a wire_matrix (Word).
Word = pyrtl.wire_matrix(component_schema=Byte, size=2)


# ByteMatrix tests the corner case of a single-element wire_matrix.
ByteMatrix = pyrtl.wire_matrix(component_schema=Byte, size=1)


# DWord is an array of two Words, or effectively a 2x2 array of Bytes. This
# check that a wire_matrix (Word) can be a component of a wire_matrix (DWord).
DWord = pyrtl.wire_matrix(component_schema=Word, size=2)
Expand Down Expand Up @@ -1727,6 +1731,46 @@ def test_wire_matrix_composition_cached_data(self):
self.assertEqual(sim.inspect('cached_data.data[1].high'), 0xC)
self.assertEqual(sim.inspect('cached_data.data[1].low'), 0xD)

def test_byte_matrix_const(self):
byte_matrix = ByteMatrix(name='byte_matrix', values=[0xAB])
self.assertEqual(len(byte_matrix), 1)
self.assertEqual(byte_matrix.bitwidth, 8)
self.assertEqual(len(pyrtl.as_wires(byte_matrix)), 8)

# Constants are sliced immediately.
self.assertEqual(byte_matrix.val, 0xAB)
self.assertEqual(byte_matrix[0].val, 0xAB)
self.assertEqual(byte_matrix[0].high.val, 0xA)
self.assertEqual(byte_matrix[0].low.val, 0xB)

def test_byte_matrix_input_slice(self):
byte_matrix = ByteMatrix(name='byte_matrix',
component_type=pyrtl.Input)

self.assertTrue(isinstance(pyrtl.as_wires(byte_matrix[0]),
pyrtl.Input))

sim = pyrtl.Simulation()
sim.step(provided_inputs={'byte_matrix[0]': 0xAB})
self.assertEqual(sim.inspect('byte_matrix'), 0xAB)
self.assertEqual(sim.inspect('byte_matrix[0]'), 0xAB)
self.assertEqual(sim.inspect('byte_matrix[0].high'), 0xA)
self.assertEqual(sim.inspect('byte_matrix[0].low'), 0xB)

def test_byte_matrix_input_concatenate(self):
byte_matrix = ByteMatrix(name='byte_matrix',
concatenated_type=pyrtl.Input)

self.assertTrue(isinstance(pyrtl.as_wires(byte_matrix),
pyrtl.Input))

sim = pyrtl.Simulation()
sim.step(provided_inputs={'byte_matrix': 0xAB})
self.assertEqual(sim.inspect('byte_matrix'), 0xAB)
self.assertEqual(sim.inspect('byte_matrix[0]'), 0xAB)
self.assertEqual(sim.inspect('byte_matrix[0].high'), 0xA)
self.assertEqual(sim.inspect('byte_matrix[0].low'), 0xB)


if __name__ == "__main__":
unittest.main()

0 comments on commit b724196

Please sign in to comment.