Skip to content

Commit a6ca1ec

Browse files
Tristan Konoligeyangulei
authored andcommitted
[TVMSCRIPT] Misc error message improvements (apache#9543)
* [TVMSCRIPT] Misc error message improvements * only prevent indexing into handles with multiple indexes * lint
1 parent 4b4a3ee commit a6ca1ec

File tree

4 files changed

+170
-37
lines changed

4 files changed

+170
-37
lines changed

python/tvm/script/parser.py

Lines changed: 75 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,18 @@ def transform_Assign(self, node):
566566
self.context.remove_symbol(var.name)
567567
return tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
568568

569-
self.report_error("Unsupported Assign stmt", node.span)
569+
self.report_error(
570+
"""Assignments should be either
571+
1. A "special statement" with return value
572+
1.1 Buffer = T.match_buffer()/T.buffer_decl()
573+
1.2 Var = T.var()
574+
1.3 Var = T.env_thread()
575+
2. A store into a buffer: Buffer[PrimExpr, PrimExpr, ..., PrimExpr] = PrimExpr
576+
3. A store into a variable: Var[PrimExpr] = PrimExpr
577+
4. A with scope handler with concise scoping and var def
578+
4.1 var = T.allocate()""",
579+
node.span,
580+
)
570581

571582
def transform_SubscriptAssign(self, node):
572583
"""Visitor for statements of the form :code:`x[1] = 2`."""
@@ -583,6 +594,12 @@ def transform_SubscriptAssign(self, node):
583594
span=tvm_span_from_synr(node.span),
584595
)
585596
else:
597+
if symbol.dtype == "handle" and len(indexes) != 1:
598+
self.report_error(
599+
"Handles only support one-dimensional indexing. Use `T.match_buffer` to "
600+
"construct a multidimensional buffer from a handle.",
601+
node.params[0].span,
602+
)
586603
if len(indexes) != 1:
587604
self.report_error(
588605
f"Store is only allowed with one index, but {len(indexes)} were provided.",
@@ -736,9 +753,35 @@ def transform_Call(self, node):
736753
return self.transform_Subscript(node)
737754
if node.func_name.name in self._binop_maker:
738755
lhs = self.transform(node.params[0])
756+
# There is no supertype for everything that can appear in
757+
# an expression, so we manually add what we might get here.
758+
if not isinstance(lhs, (tvm.tir.PrimExpr, BufferSlice)):
759+
# We would really like to report a more specific
760+
# error here, but this parser contains no distinction
761+
# between parsing statements and parsing expressions. All
762+
# rules just call `transform`.
763+
self.report_error(
764+
f"Left hand side of binary op must be a PrimExpr, "
765+
"but it is a {type(lhs).__name__}",
766+
node.params[0].span,
767+
)
739768
rhs = self.transform(node.params[1])
740-
return self._binop_maker[node.func_name.name](
741-
lhs, rhs, span=tvm_span_from_synr(node.span)
769+
if not isinstance(rhs, (tvm.tir.PrimExpr, BufferSlice)):
770+
self.report_error(
771+
f"Right hand side of binary op must be a PrimExpr, "
772+
"but it is a {type(rhs).__name__}",
773+
node.params[1].span,
774+
)
775+
return call_with_error_reporting(
776+
self.report_error,
777+
node.span,
778+
lambda node, lhs, rhs, span: self._binop_maker[node.func_name.name](
779+
lhs, rhs, span=span
780+
),
781+
node,
782+
lhs,
783+
rhs,
784+
tvm_span_from_synr(node.span),
742785
)
743786
if node.func_name.name in self._unaryop_maker:
744787
rhs = self.transform(node.params[0])
@@ -764,6 +807,8 @@ def transform_Call(self, node):
764807
self.transform(k): self.transform(v) for k, v in node.keyword_params.items()
765808
}
766809
if isinstance(func, tvm.tir.op.Op):
810+
if not "dtype" in kw_args.keys():
811+
self.report_error(f"{func} requires a dtype keyword argument.", node.span)
767812
# pattern 2
768813
return tvm.tir.Call(
769814
kw_args["dtype"], func, args, span=tvm_span_from_synr(node.span)
@@ -862,15 +907,33 @@ def transform_Subscript(self, node):
862907

863908
indexes = [self.transform(x) for x in node.params[1].values]
864909
if isinstance(symbol, tvm.tir.expr.Var):
865-
for index in indexes:
866-
if not isinstance(index, (tvm.tir.PrimExpr, int)):
867-
self.report_error(
868-
"Buffer load indexes should be int or PrimExpr, but they are "
869-
+ type(index),
870-
node.span,
871-
)
872-
return tvm.tir.Load(
873-
"float32", symbol, indexes, True, span=tvm_span_from_synr(node.span)
910+
if symbol.dtype == "handle":
911+
self.report_error(
912+
"Cannot read directly from a handle, use `T.match_buffer` "
913+
"to create a buffer to read from.",
914+
node.params[0].span,
915+
)
916+
if len(indexes) > 1:
917+
self.report_error(
918+
"Only a single index can be provided when indexing into a `var`.",
919+
node.params[1].span,
920+
)
921+
index = indexes[0]
922+
if not isinstance(index, (tvm.tir.PrimExpr, int)):
923+
self.report_error(
924+
"Var load index should be an int or PrimExpr, but it is a" + type(index),
925+
node.span,
926+
)
927+
928+
return call_with_error_reporting(
929+
self.report_error,
930+
node.span,
931+
tvm.tir.Load,
932+
"float32",
933+
symbol,
934+
index,
935+
True,
936+
span=tvm_span_from_synr(node.span),
874937
)
875938
elif isinstance(symbol, tvm.tir.Buffer):
876939
return BufferSlice(

src/tir/ir/expr.cc

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,33 +31,34 @@
3131
namespace tvm {
3232
namespace tir {
3333

34-
#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \
35-
Name::Name(PrimExpr a, PrimExpr b, Span span) { \
36-
using T = Name::ContainerType; \
37-
ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
38-
ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
39-
ICHECK(a.dtype() == b.dtype()) \
40-
<< "TypeError: mismatched types. " << a.dtype() << " vs. " << b.dtype() << "\n"; \
41-
ObjectPtr<T> node = make_object<T>(); \
42-
node->dtype = a.dtype(); \
43-
node->a = std::move(a); \
44-
node->b = std::move(b); \
45-
node->span = std::move(span); \
46-
data_ = std::move(node); \
34+
#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \
35+
Name::Name(PrimExpr a, PrimExpr b, Span span) { \
36+
using T = Name::ContainerType; \
37+
ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
38+
ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
39+
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
40+
<< b.dtype() << "\n"; \
41+
ObjectPtr<T> node = make_object<T>(); \
42+
node->dtype = a.dtype(); \
43+
node->a = std::move(a); \
44+
node->b = std::move(b); \
45+
node->span = std::move(span); \
46+
data_ = std::move(node); \
4747
}
4848

49-
#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \
50-
Name::Name(PrimExpr a, PrimExpr b, Span span) { \
51-
using T = Name::ContainerType; \
52-
ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
53-
ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
54-
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \
55-
ObjectPtr<T> node = make_object<T>(); \
56-
node->dtype = DataType::Bool(a.dtype().lanes()); \
57-
node->a = std::move(a); \
58-
node->b = std::move(b); \
59-
node->span = std::move(span); \
60-
data_ = std::move(node); \
49+
#define TVM_DEFINE_CMPOP_CONSTRUCTOR(Name) \
50+
Name::Name(PrimExpr a, PrimExpr b, Span span) { \
51+
using T = Name::ContainerType; \
52+
ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
53+
ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
54+
CHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types. " << a.dtype() << " vs. " \
55+
<< b.dtype() << "\n"; \
56+
ObjectPtr<T> node = make_object<T>(); \
57+
node->dtype = DataType::Bool(a.dtype().lanes()); \
58+
node->a = std::move(a); \
59+
node->b = std::move(b); \
60+
node->span = std::move(span); \
61+
data_ = std::move(node); \
6162
}
6263

6364
// Var

tests/python/unittest/test_tvmscript_complete.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,12 @@ def test_complete_alloc_buffer():
314314
tvm.ir.assert_structural_equal(alloc_buffer_func, expect_alloc_buffer_func)
315315

316316

317+
@T.prim_func
318+
def load_var() -> None:
319+
d = T.var("float32")
320+
d[1] = d[1]
321+
322+
317323
if __name__ == "__main__":
318324
test_complete_matmul()
319325
test_complete_matmul_original()

tests/python/unittest/test_tvmscript_error_report.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,5 +614,68 @@ def test_fuse_fail_nested_loop_outer():
614614
assert expected_sub_error_message in str(execinfo.value)
615615

616616

617+
def load_var_multiple() -> None:
618+
d = T.var("float32")
619+
d[2] = d[2, 1] # error cannot provide two indices to load
620+
621+
622+
def test_load_var():
623+
check_error(load_var_multiple, 3)
624+
625+
626+
def store_var_multiple() -> None:
627+
d = T.var("float32")
628+
d[2, 1] = d[1] # error cannot provide two indices to store
629+
630+
631+
def test_store_var():
632+
check_error(store_var_multiple, 3)
633+
634+
635+
def load_handle(h: T.handle) -> None:
636+
h_ = T.match_buffer(h, [1])
637+
h_[0] = h[0] # error cannot load from handle
638+
639+
640+
def test_load_handle():
641+
check_error(load_var_multiple, 3)
642+
643+
644+
def store_handle(h: T.handle) -> None:
645+
h_ = T.match_buffer(h, [1])
646+
h[0] = h_[0] # error cannot store to handle
647+
648+
649+
def test_store_handle():
650+
check_error(store_var_multiple, 3)
651+
652+
653+
def binop_bad_ast_type(h: T.handle):
654+
h_ = T.match_buffer(h, [1])
655+
h_[0] = h + [2] # error rhs should be a primexpr
656+
657+
658+
def test_binop_bad_ast_type():
659+
check_error(binop_bad_ast_type, 3)
660+
661+
662+
def binop_bad_type(h: T.handle):
663+
h_ = T.match_buffer(h, [1])
664+
h_[0] = h + 2 # error lhs and rhs should be the same type
665+
666+
667+
def test_binop_bad_type():
668+
check_error(binop_bad_type, 3)
669+
670+
671+
def floor_dtype(h: T.handle):
672+
h_ = T.match_buffer(h, [1])
673+
h_[0] = T.floor(2) # error floor requires a dtype
674+
675+
676+
def test_floor_dtype():
677+
check_error(floor_dtype, 3)
678+
679+
617680
if __name__ == "__main__":
618681
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)