Skip to content

Commit b702e13

Browse files
committed
Add unit test for round-trip of opaque function
1 parent e328963 commit b702e13

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

src/script/printer/relax/utils.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,16 @@ inline Optional<ExprDoc> StructInfoAsAnn(const relax::Var& v, const ObjectPath&
8484
if (!v->struct_info_.defined()) {
8585
return NullOpt;
8686
}
87+
bool attempt_to_hide_struct_info = !d->cfg->show_all_struct_info;
88+
8789
if (const auto* call = rhs.as<relax::CallNode>()) {
8890
static const Op& call_tir_op = Op::Get("relax.call_tir");
8991
static const Op& call_dps_packed_op = Op::Get("relax.call_dps_packed");
9092
if (call->op.same_as(call_tir_op) || call->op.same_as(call_dps_packed_op)) {
91-
return NullOpt;
93+
attempt_to_hide_struct_info = true;
9294
}
9395
}
94-
if (!d->cfg->show_all_struct_info) {
96+
if (attempt_to_hide_struct_info) {
9597
Optional<relax::StructInfo> inferred_sinfo = NullOpt;
9698
if (auto opt = rhs.as<relax::Call>()) {
9799
auto call = opt.value();

tests/python/tvmscript/test_tvmscript_roundtrip.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import tvm
2222
import tvm.testing
2323
from tvm import tir
24-
from tvm.script import tir as T, ir as I
24+
from tvm.script import tir as T, ir as I, relax as R
2525

2626
import numpy as np
2727

@@ -3996,6 +3996,24 @@ def func():
39963996
yield make_ir_generator(op, arg)
39973997

39983998

3999+
def relax_extern_func():
4000+
@R.function
4001+
def func(A: R.Tensor([10, 20], "float32")):
4002+
func = R.ExternFunc("dummy_func")
4003+
4004+
B: R.Tensor([10, 20], "float32") = R.call_dps_packed(
4005+
func, [A], out_sinfo=R.Tensor([10, 20], "float32")
4006+
)
4007+
4008+
C: R.Tensor(ndim=2, dtype="float32") = R.call_dps_packed(
4009+
func, [B], out_sinfo=R.Tensor([10, 20], "float32")
4010+
)
4011+
4012+
return C
4013+
4014+
return func
4015+
4016+
39994017
ir_generator = tvm.testing.parameter(
40004018
launch_env_thread,
40014019
opt_gemm_normalize,
@@ -4081,13 +4099,35 @@ def func():
40814099
*op_of_literal(),
40824100
)
40834101

4102+
relax_ir_generator = tvm.testing.parameter(
4103+
relax_extern_func,
4104+
)
4105+
4106+
show_all_relax_struct_info = tvm.testing.parameter(
4107+
by_dict={
4108+
"show_all_struct_info": True,
4109+
"hide_inferable_struct_info": False,
4110+
}
4111+
)
4112+
40844113

40854114
def test_roundtrip(ir_generator):
40864115
original = ir_generator()
40874116
after_roundtrip = tvm.script.from_source(original.script(show_meta=True))
40884117
tvm.ir.assert_structural_equal(original, after_roundtrip, True)
40894118

40904119

4120+
def test_relax_roundtrip(relax_ir_generator, show_all_relax_struct_info):
4121+
original = relax_ir_generator()
4122+
after_roundtrip = tvm.script.from_source(
4123+
original.script(
4124+
show_meta=True,
4125+
show_all_struct_info=show_all_relax_struct_info,
4126+
)
4127+
)
4128+
tvm.ir.assert_structural_equal(original, after_roundtrip, True)
4129+
4130+
40914131
def test_return_none_no_trailing_type():
40924132
func = return_none()
40934133
script = func.script()

0 commit comments

Comments
 (0)