Skip to content

Commit 1af82ad

Browse files
authored
[Unity] Validate struct info in relax::Call constructor (#16311)
* [Unity] Validate struct info in relax::Call constructor All operations called by a `relax::Call` node must have a `FuncStructInfo`. Prior to this commit, an invalid struct info would be caught by the `BlockBuilder` during normalization. This delay between the invalid `relax::Call` being constructed and the invalid `relax::Call` being detected makes debugging difficult. This commit adds an additional check during the `relax::Call` constructor, to provide earlier error detection. * Updated unit test to avoid using Tensor as callable function
1 parent 4a7e4fe commit 1af82ad

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

src/relax/ir/expr.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ Id::Id(String name_hint) {
3636
}
3737

3838
Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<StructInfo> sinfo_args, Span span) {
39+
CHECK(!op->struct_info_.defined() || op->struct_info_->IsInstance<FuncStructInfoNode>())
40+
<< "ValueError: "
41+
<< "Call expects its operator to have FuncStructInfo, "
42+
<< "but operator " << op << ", which was called with arguments " << args
43+
<< ", has struct info " << op->struct_info_;
44+
3945
ObjectPtr<CallNode> n = make_object<CallNode>();
4046
n->op = std::move(op);
4147
n->args = std::move(args);

tests/python/relax/test_expr.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,5 +271,25 @@ def test_datatype_imm():
271271
_check_json_roundtrip(d0)
272272

273273

274+
def test_call():
275+
dtype = rx.PrimStructInfo("int32")
276+
func = rx.Var("func", rx.FuncStructInfo([dtype], dtype))
277+
arg = rx.Var("arg", dtype)
278+
call = rx.Call(func, [arg])
279+
assert call.op.same_as(func)
280+
assert len(call.args) == 1
281+
assert call.args[0].same_as(arg)
282+
283+
284+
def test_call_raises_error_for_invalid_function():
285+
"""relax::Call requires the function to have FuncStructInfo"""
286+
dtype = rx.PrimStructInfo("int32")
287+
func = rx.Var("func", dtype)
288+
arg = rx.Var("arg", dtype)
289+
290+
with pytest.raises(ValueError):
291+
rx.Call(func, [arg])
292+
293+
274294
if __name__ == "__main__":
275295
tvm.testing.main()

tests/python/relax/test_op_misc.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,16 @@ def test_implicit_op():
6666
m, n = tvm.tir.Var("m", "int64"), tvm.tir.Var("n", "int64")
6767
x = rx.Var("x", R.Tensor([m, n], "float32"))
6868
y = rx.Var("y", R.Tensor([m, n], "float32"))
69+
func = rx.Var(
70+
"func",
71+
R.Callable(
72+
[R.Tensor([m, n], "float32")],
73+
R.Callable(
74+
[R.Tensor([m, n], "float32")],
75+
R.Tuple,
76+
),
77+
),
78+
)
6979

7080
def _check_call(expr, op_name: str):
7181
assert isinstance(expr, rx.Call)
@@ -94,9 +104,9 @@ def _check_call(expr, op_name: str):
94104
_check_call(x.astype("float32"), "astype")
95105

96106
# Call
97-
call_expr = x(y)(y)
107+
call_expr = func(y)(y)
98108
assert isinstance(call_expr.op, rx.Call)
99-
assert call_expr.op.op == x
109+
assert call_expr.op.op == func
100110

101111
# GetTupleItem
102112
## Eager get item for tuple

0 commit comments

Comments
 (0)