diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 9f86998640be..a2101263082d 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -547,7 +547,14 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor(var); } - Expr VisitExpr_(const VarNode* var) final { return VisitVar_(var); } + Expr VisitExpr_(const VarNode* var_ptr) final { + auto var = VisitVar_(var_ptr); + if (HasVoidStructInfo(var)) { + return VisitExpr(Tuple(Array{})); + } else { + return var; + } + } Expr VisitExpr_(const DataflowVarNode* var) final { return VisitVar_(var); } diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h index 395e027bce57..7fd7e21a6739 100644 --- a/src/script/ir_builder/relax/utils.h +++ b/src/script/ir_builder/relax/utils.h @@ -20,6 +20,7 @@ #define TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ #include +#include #include #include @@ -109,12 +110,12 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String GetStructInfo(last_binding->var)); tvm::relax::Expr body; - if (const auto* var_binding = last_binding.as(); - var_binding && var_binding->value->IsInstance()) { + const auto* var_binding = last_binding.as(); + + if (var_binding && tvm::relax::IsLeafOrTuple(var_binding->value)) { body = var_binding->value; - } else if (const auto* var_binding = last_binding.as()) { - last_block_bindings.push_back(last_binding = - tvm::relax::VarBinding(new_var, var_binding->value)); + } else if (var_binding) { + last_block_bindings.push_back(tvm::relax::VarBinding(new_var, var_binding->value)); body = new_var; } else if (const auto* match_cast = last_binding.as()) { last_block_bindings.push_back( diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index 5aa99878f951..44a2cd338c5e 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -69,24 +69,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Doc ret = d->AsDoc(n->value, n_p->Attr("value")); d->cfg->binding_names.pop_back(); return ret; - - // Uncommenting this section hides the variable binding - // when the StructInfo is void. For example, printing - // `R.assert_op(expr)` instead of `_ = R.assert_op(expr)`. - // However, Relax represents void values as an empty - // tuple, and a void-type variable may still be used later - // in the function. Hiding bindings of these void-type - // variables would result in use of an undefined variable. - // - // TODO(Lunderberg): Inline void-type variable to use - // `R.tuple()` during normalization. This will avoid the - // cases that trigger the undefined variables, and allow - // this syntax sugar to be enabled. - // - // } else if (d->cfg->syntax_sugar && relax::HasVoidStructInfo(n->value) && - // relax::HasVoidStructInfo(n->var)) { - // ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); - // return ExprStmtDoc(rhs); + } else if (d->cfg->syntax_sugar && relax::HasVoidStructInfo(n->value) && + relax::HasVoidStructInfo(n->var)) { + ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); + return ExprStmtDoc(rhs); } else { ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py index d75aeedf822c..80de52ca6621 100644 --- a/tests/python/relax/test_transform_lift_transform_params.py +++ b/tests/python/relax/test_transform_lift_transform_params.py @@ -548,8 +548,10 @@ def main_transform_params(params: R.Tuple) -> R.Tuple: R.func_attr({"num_input": 0}) with R.dataflow(): gv: R.Tuple = R.tuple() - R.output(gv) - return gv + R.output() + # All instance of the empty tuple are normalized to be + # in-line. + return R.tuple() @R.function def main(shape: R.Shape(["n"])) -> R.Shape(["n"]): @@ -612,8 +614,8 @@ def main_transform_params(params: R.Tuple) -> R.Tuple: R.func_attr({"num_input": 0}) with R.dataflow(): gv: R.Tuple = R.tuple() - R.output(gv) - return gv + R.output() + return R.tuple() @R.function def main(shape: R.Shape(["n"])) -> R.Shape(["n"]): diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py index a6feb0b8abca..f37df4d07969 100644 --- a/tests/python/relax/test_transform_normalize.py +++ b/tests/python/relax/test_transform_normalize.py @@ -552,5 +552,34 @@ def test_nesting_non_dataflow_in_dataflow_error(): # should fail due to a normal binding block being inside a dataflowblock +def test_remove_usage_of_void_type_variables(): + """All empty tuples should be constructed in-line + + For readability, TVMScript hides the variable binding if the + variable has a void type. For example, `R.assert_op(condition)` + instead of `void_var: R.Tuple([]) = R.assert_op(condition)`. + However, Relax follows standard convention of functional + languages, and uses an empty tuple to represent void. Since an + empty tuple may be legally used later in the function, the + `void_var` may require a binding. + + This is avoided by normalizing all usage of a void-type + variable with an in-line `R.tuple()`. + """ + x = relax.Var("x", R.Tuple([])) + bindings = [ + relax.VarBinding(x, R.assert_op(R.const(True, "bool"))), + ] + seq = relax.SeqExpr([relax.BindingBlock(bindings)], x) + before = relax.Function([], seq, ret_struct_info=R.Tuple([])) + + after = relax.transform.Normalize()(tvm.IRModule({"main": before}))["main"] + + @R.function(private=True) + def expected(): + x = R.assert_op(R.const(True, "bool")) + return R.tuple() + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 667fb0a132b6..7b64eb1dee39 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -16,8 +16,6 @@ # under the License. # pylint: disable=missing-docstring -import pytest - import tvm import tvm.testing from tvm import IRModule, relax, tir @@ -636,7 +634,6 @@ def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32 ) -@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled") def test_assert_op(): @I.ir_module class AssertOpMod: @@ -661,7 +658,6 @@ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): ) -@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled") def test_print(): @I.ir_module class PrintMod: @@ -710,7 +706,6 @@ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): ) -@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled") def test_directly_construct_private_funcs(): # public @R.function