From 62dcce758feab303c25120fd8cab0d5b13d6094c Mon Sep 17 00:00:00 2001 From: Lenny Truong Date: Thu, 24 Sep 2020 14:27:54 -0700 Subject: [PATCH] Refactor implicit coercion logic Provides reuseable decorator and function to handle magma's implicit coercion of Python values. For now, I moved the logic from wiring and mux to a shared place to ensure consistency. There's probably other places to do this, but we can do this on demand as we enounter incosistencies. One issue is that `int` coercion can't really happen up front like this since it is determined by the place it's used (i.e. when it's wired to a Bits we check it can fit in the bit width). @rsetaluri perhaps we could use SmartBits for this? (convert it to a smartbits that's automatically extended to wire to a regular bits, but I think right now our rules don't allow implicit truncation, but we could define a special SmartBits subclass that only extends. My one concern would be whether this would make the error a bit harder to traceback) Fixes https://github.com/phanrahan/magma/issues/828 --- magma/bits.py | 6 ++--- magma/coerce.py | 39 ++++++++++++++++++++++++++++ magma/conversions.py | 7 ++--- magma/digital.py | 2 +- magma/primitives/mux.py | 20 +++++--------- magma/wire.py | 13 ++-------- tests/test_errors/test_mux_errors.py | 2 +- tests/test_primitives/test_mux.py | 2 +- 8 files changed, 57 insertions(+), 34 deletions(-) create mode 100644 magma/coerce.py diff --git a/magma/bits.py b/magma/bits.py index 2e5dbabcd..c816f6a3e 100644 --- a/magma/bits.py +++ b/magma/bits.py @@ -149,10 +149,8 @@ def __int__(self): @debug_wire def wire(self, other, debug_info): - if isinstance(other, (IntegerTypes, BitVector)): - N = (other.bit_length() - if isinstance(other, IntegerTypes) - else len(other)) + if isinstance(other, IntegerTypes): + N = other.bit_length() if N > len(self): raise ValueError( f"Cannot convert integer {other} " diff --git a/magma/coerce.py b/magma/coerce.py new file mode 100644 index 000000000..142fa0537 --- /dev/null +++ b/magma/coerce.py @@ -0,0 +1,39 @@ +from functools import wraps + +import hwtypes as ht + +from magma.protocol_type import MagmaProtocol +from magma.debug import debug_info + + +def python_to_magma_coerce(value): + if isinstance(value, debug_info): + # Short circuit tuple converion + return value + + # Circular import + from magma.conversions import tuple_, sint, uint, bits, bit + if isinstance(value, tuple): + return tuple_(value) + if isinstance(value, ht.SIntVector): + return sint(value, len(value)) + if isinstance(value, ht.UIntVector): + return uint(value, len(value)) + if isinstance(value, ht.BitVector): + return bits(value, len(value)) + if isinstance(value, (bool, ht.Bit)): + return bit(value) + + if isinstance(value, MagmaProtocol): + return value._get_magma_value_() + + return value + + +def python_to_magma_coerce_wrapper(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + args = (python_to_magma_coerce(a) for a in args) + kwargs = {k: python_to_magma_coerce(v) for k, v in kwargs.items()} + return fn(*args, **kwargs) + return wrapper diff --git a/magma/conversions.py b/magma/conversions.py index 790332515..65ac97aa9 100644 --- a/magma/conversions.py +++ b/magma/conversions.py @@ -31,7 +31,8 @@ __all__ += ['as_bits', 'from_bits'] def can_convert_to_bit(value): - return isinstance(magma_value(value), (Digital, Array, Tuple, IntegerTypes)) + return isinstance(magma_value(value), (Digital, Array, Tuple, IntegerTypes, + ht.Bit)) def can_convert_to_bit_type(value): @@ -59,9 +60,9 @@ def convertbit(value, totype): "bit can only be used on arrays and tuples of bits" f"; not {type(value)}") - assert isinstance(value, (IntegerTypes, Digital)) + assert isinstance(value, (IntegerTypes, Digital, ht.Bit)) - if isinstance(value, IntegerTypes): + if isinstance(value, (IntegerTypes, ht.Bit)): value = totype(1) if value else totype(0) if value.is_input(): diff --git a/magma/digital.py b/magma/digital.py index da9f5ab42..148382796 100644 --- a/magma/digital.py +++ b/magma/digital.py @@ -165,7 +165,7 @@ def wire(self, o, debug_info): i = self o = magma_value(o) # promote integer types to LOW/HIGH - if isinstance(o, (IntegerTypes, bool, ht.Bit)): + if isinstance(o, IntegerTypes): o = HIGH if o else LOW if not isinstance(o, Digital): diff --git a/magma/primitives/mux.py b/magma/primitives/mux.py index 9db247bd8..c67eea5c4 100644 --- a/magma/primitives/mux.py +++ b/magma/primitives/mux.py @@ -5,6 +5,7 @@ from magma.bits import Bits, UInt, SInt from magma.bitutils import clog2, seq2int from magma.circuit import coreir_port_mapping +from magma.coerce import python_to_magma_coerce from magma.generator import Generator2 from magma.interface import IO from magma.protocol_type import MagmaProtocol, magma_type @@ -87,22 +88,14 @@ def _infer_mux_type(args): """ T = None for arg in args: - if isinstance(arg, (Type, MagmaProtocol)): - next_T = type(arg).qualify(Direction.Undirected) - elif isinstance(arg, UIntVector): - next_T = UInt[len(arg)] - elif isinstance(arg, SIntVector): - next_T = SInt[len(arg)] - elif isinstance(arg, BitVector): - next_T = Bits[len(arg)] - elif isinstance(arg, (ht.Bit, bool)): - next_T = Bit - elif isinstance(arg, tuple): - next_T = type(tuple_(arg)) - elif isinstance(arg, int): + if isinstance(arg, int): # Cannot infer type without width, use wiring implicit coercion to # handle (or raise type error there) continue + if not isinstance(arg, (Type, MagmaProtocol)): + raise TypeError(f"Found unsupport argument {arg} of type" + f" {type(arg)}") + next_T = type(arg).qualify(Direction.Undirected) if T is not None: if issubclass(T, next_T): @@ -144,6 +137,7 @@ def mux(I: list, S, **kwargs): S = seq2int(S.bits()) if isinstance(S, int): return I[S] + I = tuple(python_to_magma_coerce(i) for i in I) T, I = _infer_mux_type(I) inst = Mux(len(I), T, **kwargs)() if len(I) == 2 and isinstance(S, Bits[1]): diff --git a/magma/wire.py b/magma/wire.py index cddbfd7af..5b18ffc8e 100644 --- a/magma/wire.py +++ b/magma/wire.py @@ -4,8 +4,8 @@ from .wire_container import Wire # TODO(rsetaluri): only here for b.c. from .debug import debug_wire from .logging import root_logger -from .protocol_type import magma_value +from magma.coerce import python_to_magma_coerce_wrapper from magma.wire_container import WiringLog @@ -15,18 +15,9 @@ _CONSTANTS = (IntegerTypes, BitVector, Bit) +@python_to_magma_coerce_wrapper @debug_wire def wire(o, i, debug_info=None): - o = magma_value(o) - i = magma_value(i) - - # Circular import - from .conversions import tuple_ - if isinstance(o, tuple): - o = tuple_(o) - if isinstance(i, tuple): - i = tuple_(i) - # Wire(o, Circuit). if hasattr(i, 'interface'): i.wire(o, debug_info) diff --git a/tests/test_errors/test_mux_errors.py b/tests/test_errors/test_mux_errors.py index 7d38b7554..1c0b6fb0d 100644 --- a/tests/test_errors/test_mux_errors.py +++ b/tests/test_errors/test_mux_errors.py @@ -10,7 +10,7 @@ class Foo(m.Circuit): with pytest.raises(TypeError) as e: m.mux([1, 2], io.S) assert str(e.value) == f"""\ -Could not infer mux type from [1, 2] +Could not infer mux type from (1, 2) Need at least one magma value, BitVector, bool or tuple\ """ diff --git a/tests/test_primitives/test_mux.py b/tests/test_primitives/test_mux.py index 03f3284b0..04c57519f 100644 --- a/tests/test_primitives/test_mux.py +++ b/tests/test_primitives/test_mux.py @@ -301,7 +301,7 @@ class test_mux_array_select_bits_1(m.Circuit): def test_mux_intv(ht_T, m_T): class Main(m.Circuit): O = m.mux([ht_T[4](1), m_T[4](2)], m.Bit()) - assert isinstance(O, m_T) + assert isinstance(O, m_T), type(O) @pytest.mark.parametrize("ht_T", [ht.UIntVector, ht.SIntVector])