From 37dc108452c172f4023ca082408a935a2ee580c9 Mon Sep 17 00:00:00 2001 From: Ryan Zambrotta <33207723+rtzam@users.noreply.github.com> Date: Wed, 28 Aug 2024 13:37:14 -0700 Subject: [PATCH] Improves the error message of invalid index expression use (#696) This commit adds a test to verify that when a user attempts to allocate a temporary variable with an index type, a descriptive error is provided --------- Co-authored-by: Yuka Ikarashi --- src/exo/pyparser.py | 71 +++++++++++++++++++++-------------------- tests/test_typecheck.py | 12 +++++++ 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/src/exo/pyparser.py b/src/exo/pyparser.py index 25403fc41..f997543c6 100644 --- a/src/exo/pyparser.py +++ b/src/exo/pyparser.py @@ -128,6 +128,24 @@ def getsrcinfo(node): # Parser Pass object +# detect which sort of type we have here +_is_size = lambda x: isinstance(x, pyast.Name) and x.id == "size" +_is_index = lambda x: isinstance(x, pyast.Name) and x.id == "index" +_is_bool = lambda x: isinstance(x, pyast.Name) and x.id == "bool" +_is_stride = lambda x: isinstance(x, pyast.Name) and x.id == "stride" + +_prim_types = { + "R": UAST.Num(), + "f16": UAST.F16(), + "f32": UAST.F32(), + "f64": UAST.F64(), + "i8": UAST.INT8(), + "ui8": UAST.UINT8(), + "ui16": UAST.UINT16(), + "i32": UAST.INT32(), +} + + class Parser: def __init__( self, @@ -312,8 +330,8 @@ def parse_config_field(self, stmt): "index": UAST.Index(), "stride": UAST.Stride(), } - for k in Parser._prim_types: - typ_list[k] = Parser._prim_types[k] + for k in _prim_types: + typ_list[k] = _prim_types[k] del typ_list["R"] if ( @@ -348,35 +366,29 @@ def is_at(x): typ_node = node mem_node = None - # detect which sort of type we have here - is_size = lambda x: isinstance(x, pyast.Name) and x.id == "size" - is_index = lambda x: isinstance(x, pyast.Name) and x.id == "index" - is_bool = lambda x: isinstance(x, pyast.Name) and x.id == "bool" - is_stride = lambda x: isinstance(x, pyast.Name) and x.id == "stride" - # parse each kind of type here - if is_size(typ_node): + if _is_size(typ_node): if mem_node is not None: self.err( node, "size types should not be annotated with " "memory locations" ) return UAST.Size(), None - elif is_index(typ_node): + elif _is_index(typ_node): if mem_node is not None: self.err( node, "size types should not be annotated with " "memory locations" ) return UAST.Index(), None - elif is_bool(typ_node): + elif _is_bool(typ_node): if mem_node is not None: self.err( node, "size types should not be annotated with " "memory locations" ) return UAST.Bool(), None - elif is_stride(typ_node): + elif _is_stride(typ_node): if mem_node is not None: self.err( node, @@ -403,17 +415,6 @@ def parse_alloc_typmem(self, node): typ = self.parse_num_type(node) return typ, mem - _prim_types = { - "R": UAST.Num(), - "f16": UAST.F16(), - "f32": UAST.F32(), - "f64": UAST.F64(), - "i8": UAST.INT8(), - "ui8": UAST.UINT8(), - "ui16": UAST.UINT16(), - "i32": UAST.INT32(), - } - def parse_num_type(self, node, is_arg=False): if isinstance(node, pyast.Subscript): if isinstance(node.value, pyast.List): @@ -431,23 +432,17 @@ def parse_num_type(self, node, is_arg=False): ) base = node.value.elts[0] - if ( - not isinstance(base, pyast.Name) - or base.id not in Parser._prim_types - ): + if not isinstance(base, pyast.Name) or base.id not in _prim_types: self.err( node, "expected window type to be of " "the form '[R][...]', '[f32][...]', etc.", ) - typ = Parser._prim_types[base.id] + typ = _prim_types[base.id] is_window = True - elif ( - isinstance(node.value, pyast.Name) - and node.value.id in Parser._prim_types - ): - typ = Parser._prim_types[node.value.id] + elif isinstance(node.value, pyast.Name) and node.value.id in _prim_types: + typ = _prim_types[node.value.id] is_window = False else: self.err( @@ -481,8 +476,14 @@ def parse_num_type(self, node, is_arg=False): return typ - elif isinstance(node, pyast.Name) and node.id in Parser._prim_types: - return Parser._prim_types[node.id] + elif isinstance(node, pyast.Name) and node.id in _prim_types: + return _prim_types[node.id] + elif isinstance(node, pyast.Name) and ( + _is_size(node) or _is_stride(node) or _is_index(node) or _is_bool(node) + ): + raise ParseError( + node, f"Cannot allocate an intermediate value of type {node.id}" + ) else: self.err(node, "unrecognized type: " + pyast.dump(node)) diff --git a/tests/test_typecheck.py b/tests/test_typecheck.py index 0b15a14a6..393bc4d74 100644 --- a/tests/test_typecheck.py +++ b/tests/test_typecheck.py @@ -4,6 +4,7 @@ from exo import proc, config from exo.libs.memories import GEMM_SCRATCH +from exo.pyparser import ParseError # --- Typechecking tests --- @@ -18,6 +19,17 @@ class ConfigLoad: return ConfigLoad +def test_size0(): + with pytest.raises( + ParseError, match="Cannot allocate an intermediate value of type" + ): + + @proc + def foo(x: size): + size: size + pass + + def test_stride1(): ConfigLoad = new_config_ld()