Skip to content

Commit

Permalink
Improves the error message of invalid index expression use (#696)
Browse files Browse the repository at this point in the history
This commit adds a test to verify that when a
user attempts to allocate a temporary
variable with an index type, a descriptive
error is provided

---------

Co-authored-by: Yuka Ikarashi <[email protected]>
  • Loading branch information
rtzam and yamaguchi1024 authored Aug 28, 2024
1 parent b12ee6f commit 37dc108
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 35 deletions.
71 changes: 36 additions & 35 deletions src/exo/pyparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,24 @@ def getsrcinfo(node):
# Parser Pass object


# detect which sort of type we have here
_is_size = lambda x: isinstance(x, pyast.Name) and x.id == "size"
_is_index = lambda x: isinstance(x, pyast.Name) and x.id == "index"
_is_bool = lambda x: isinstance(x, pyast.Name) and x.id == "bool"
_is_stride = lambda x: isinstance(x, pyast.Name) and x.id == "stride"

_prim_types = {
"R": UAST.Num(),
"f16": UAST.F16(),
"f32": UAST.F32(),
"f64": UAST.F64(),
"i8": UAST.INT8(),
"ui8": UAST.UINT8(),
"ui16": UAST.UINT16(),
"i32": UAST.INT32(),
}


class Parser:
def __init__(
self,
Expand Down Expand Up @@ -312,8 +330,8 @@ def parse_config_field(self, stmt):
"index": UAST.Index(),
"stride": UAST.Stride(),
}
for k in Parser._prim_types:
typ_list[k] = Parser._prim_types[k]
for k in _prim_types:
typ_list[k] = _prim_types[k]
del typ_list["R"]

if (
Expand Down Expand Up @@ -348,35 +366,29 @@ def is_at(x):
typ_node = node
mem_node = None

# detect which sort of type we have here
is_size = lambda x: isinstance(x, pyast.Name) and x.id == "size"
is_index = lambda x: isinstance(x, pyast.Name) and x.id == "index"
is_bool = lambda x: isinstance(x, pyast.Name) and x.id == "bool"
is_stride = lambda x: isinstance(x, pyast.Name) and x.id == "stride"

# parse each kind of type here
if is_size(typ_node):
if _is_size(typ_node):
if mem_node is not None:
self.err(
node, "size types should not be annotated with " "memory locations"
)
return UAST.Size(), None

elif is_index(typ_node):
elif _is_index(typ_node):
if mem_node is not None:
self.err(
node, "size types should not be annotated with " "memory locations"
)
return UAST.Index(), None

elif is_bool(typ_node):
elif _is_bool(typ_node):
if mem_node is not None:
self.err(
node, "size types should not be annotated with " "memory locations"
)
return UAST.Bool(), None

elif is_stride(typ_node):
elif _is_stride(typ_node):
if mem_node is not None:
self.err(
node,
Expand All @@ -403,17 +415,6 @@ def parse_alloc_typmem(self, node):
typ = self.parse_num_type(node)
return typ, mem

_prim_types = {
"R": UAST.Num(),
"f16": UAST.F16(),
"f32": UAST.F32(),
"f64": UAST.F64(),
"i8": UAST.INT8(),
"ui8": UAST.UINT8(),
"ui16": UAST.UINT16(),
"i32": UAST.INT32(),
}

def parse_num_type(self, node, is_arg=False):
if isinstance(node, pyast.Subscript):
if isinstance(node.value, pyast.List):
Expand All @@ -431,23 +432,17 @@ def parse_num_type(self, node, is_arg=False):
)

base = node.value.elts[0]
if (
not isinstance(base, pyast.Name)
or base.id not in Parser._prim_types
):
if not isinstance(base, pyast.Name) or base.id not in _prim_types:
self.err(
node,
"expected window type to be of "
"the form '[R][...]', '[f32][...]', etc.",
)

typ = Parser._prim_types[base.id]
typ = _prim_types[base.id]
is_window = True
elif (
isinstance(node.value, pyast.Name)
and node.value.id in Parser._prim_types
):
typ = Parser._prim_types[node.value.id]
elif isinstance(node.value, pyast.Name) and node.value.id in _prim_types:
typ = _prim_types[node.value.id]
is_window = False
else:
self.err(
Expand Down Expand Up @@ -481,8 +476,14 @@ def parse_num_type(self, node, is_arg=False):

return typ

elif isinstance(node, pyast.Name) and node.id in Parser._prim_types:
return Parser._prim_types[node.id]
elif isinstance(node, pyast.Name) and node.id in _prim_types:
return _prim_types[node.id]
elif isinstance(node, pyast.Name) and (
_is_size(node) or _is_stride(node) or _is_index(node) or _is_bool(node)
):
raise ParseError(
node, f"Cannot allocate an intermediate value of type {node.id}"
)
else:
self.err(node, "unrecognized type: " + pyast.dump(node))

Expand Down
12 changes: 12 additions & 0 deletions tests/test_typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from exo import proc, config
from exo.libs.memories import GEMM_SCRATCH
from exo.pyparser import ParseError


# --- Typechecking tests ---
Expand All @@ -18,6 +19,17 @@ class ConfigLoad:
return ConfigLoad


def test_size0():
with pytest.raises(
ParseError, match="Cannot allocate an intermediate value of type"
):

@proc
def foo(x: size):
size: size
pass


def test_stride1():
ConfigLoad = new_config_ld()

Expand Down

0 comments on commit 37dc108

Please sign in to comment.