diff --git a/magma/bit.py b/magma/bit.py index e31bc9eca..6731d7ff3 100644 --- a/magma/bit.py +++ b/magma/bit.py @@ -3,21 +3,18 @@ * Subtype of the Digital type * Implementation of hwtypes.AbstractBit """ -import keyword import typing as tp import functools import hwtypes as ht from hwtypes.bit_vector_abc import AbstractBit, TypeFamily -from .t import Direction +from .t import Direction, Type from .digital import Digital, DigitalMeta from .digital import VCC, GND # TODO(rsetaluri): only here for b.c. from magma.compatibility import IntegerTypes from magma.debug import debug_wire from magma.family import get_family -from magma.interface import IO -from magma.language_utils import primitive_to_python -from magma.protocol_type import magma_type, MagmaProtocol +from magma.protocol_type import magma_value from magma.operator_utils import output_only @@ -113,11 +110,20 @@ def ite(self, t_branch, f_branch): @debug_wire def wire(self, o, debug_info): + o = magma_value(o) # Cast to Bit here so we don't get a Digital instead if isinstance(o, (IntegerTypes, bool, ht.Bit)): o = Bit(o) + if type(o).is_bits_1(): + o = o[0] return super().wire(o, debug_info) + @classmethod + def is_wireable(cls, rhs): + if issubclass(rhs, Type) and rhs.is_bits_1(): + return True + return DigitalMeta.is_wireable(cls, rhs) + BitIn = Bit[Direction.In] BitOut = Bit[Direction.Out] diff --git a/magma/bits.py b/magma/bits.py index 4fbd6a52b..fdd22ba6e 100644 --- a/magma/bits.py +++ b/magma/bits.py @@ -200,10 +200,15 @@ def is_wireable(cls, rhs): return True if issubclass(cls, UInt) and issubclass(rhs, SInt): return False - elif issubclass(cls, SInt) and issubclass(rhs, UInt): + if issubclass(cls, SInt) and issubclass(rhs, UInt): return False + if len(cls) == 1 and issubclass(rhs, Bit): + return True return super().is_wireable(rhs) + def is_bits_1(cls): + return len(cls) == 1 + class Bits(Array, AbstractBitVector, metaclass=BitsMeta): __hash__ = Array.__hash__ @@ -237,6 +242,7 @@ def __int__(self): @debug_wire def wire(self, other, debug_info): + from .conversions import bits if isinstance(other, (IntegerTypes, BitVector)): N = (other.bit_length() if isinstance(other, IntegerTypes) @@ -245,8 +251,9 @@ def wire(self, other, debug_info): raise ValueError( f"Cannot convert integer {other} " f"(bit_length={other.bit_length()}) to Bits ({len(self)})") - from .conversions import bits other = bits(other, len(self)) + if isinstance(other, Bit) and len(self) == 1: + other = bits(other, 1) super().wire(other, debug_info) @classmethod diff --git a/magma/t.py b/magma/t.py index bf439512c..5d045813d 100644 --- a/magma/t.py +++ b/magma/t.py @@ -177,6 +177,9 @@ def undirected_t(cls): def is_directed(cls): return cls is not cls.qualify(Direction.Undirected) + def is_bits_1(self): + return False + @lru_cache() def In(T): diff --git a/tests/test_circuit/test_new_style_syntax.py b/tests/test_circuit/test_new_style_syntax.py index 33acaccff..12723607a 100644 --- a/tests/test_circuit/test_new_style_syntax.py +++ b/tests/test_circuit/test_new_style_syntax.py @@ -99,7 +99,7 @@ def definition(io): def test_defn_wiring_error(caplog): class _Foo(m.Circuit): - io = m.IO(I=m.In(m.Bit), O=m.In(m.Bit), O1=m.Out(m.Bits[1])) + io = m.IO(I=m.In(m.Bit), O=m.In(m.Bit), O1=m.Out(m.Bits[2])) m.wire(io.I, io.O) m.wire(io.I, io.O1) @@ -108,13 +108,13 @@ class _Foo(m.Circuit): assert has_error(caplog, "Cannot wire _Foo.I (Out(Bit)) to _Foo.O (Out(Bit))") assert has_error(caplog, - "Cannot wire _Foo.I (Out(Bit)) to _Foo.O1 (In(Bits[1]))") + "Cannot wire _Foo.I (Out(Bit)) to _Foo.O1 (In(Bits[2]))") @wrap_with_context_manager(logging_level("DEBUG")) def test_inst_wiring_error(caplog): class _Bar(m.Circuit): - io = m.IO(I=m.In(m.Bits[1]), O=m.Out(m.Bits[1])) + io = m.IO(I=m.In(m.Bits[2]), O=m.Out(m.Bits[2])) class _Foo(m.Circuit): io = m.IO(I=m.In(m.Bit), O=m.Out(m.Bit)) @@ -125,10 +125,10 @@ class _Foo(m.Circuit): assert has_error( caplog, - "Cannot wire _Foo.I (Out(Bit)) to _Foo._Bar_inst0.I (In(Bits[1]))") + "Cannot wire _Foo.I (Out(Bit)) to _Foo._Bar_inst0.I (In(Bits[2]))") assert has_error( caplog, - "Cannot wire _Foo._Bar_inst0.O (Out(Bits[1])) to _Foo.O (In(Bit))") + "Cannot wire _Foo._Bar_inst0.O (Out(Bits[2])) to _Foo.O (In(Bit))") assert has_error(caplog, "_Foo.O not driven") assert has_debug(caplog, "_Foo.O: Unconnected") assert has_error(caplog, "_Foo._Bar_inst0.I not driven") diff --git a/tests/test_wire/test_wireable.py b/tests/test_wire/test_wireable.py index 4735b3311..67f16086e 100644 --- a/tests/test_wire/test_wireable.py +++ b/tests/test_wire/test_wireable.py @@ -1,3 +1,4 @@ +import pytest import magma as m @@ -20,3 +21,17 @@ class Main2(m.Circuit): Cannot wire Main2.a (Out(SInt[16])) to Main2.b (In(UInt[16]))\ """ assert caplog.messages[1][-len(expected):] == expected + + +@pytest.mark.parametrize('Ts', [ + (m.Bit, m.Bits[1]), + (m.Bits[1], m.Bit), +]) +def test_bit_bits1(Ts): + class Main(m.Circuit): + io = m.IO(a=m.In(Ts[0]), b=m.Out(Ts[1])) + io.b @= io.a + + # NOTE: We call compile here to ensure a wiring error was not reported + # (otherwise it would raise an exception) + m.compile('build/Main', Main)