Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,14 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
return GetRef<Var>(var);
}

Expr VisitExpr_(const VarNode* var) final { return VisitVar_<Var>(var); }
Expr VisitExpr_(const VarNode* var_ptr) final {
auto var = VisitVar_<Var>(var_ptr);
if (HasVoidStructInfo(var)) {
return VisitExpr(Tuple(Array<Expr>{}));
} else {
return var;
}
}

Expr VisitExpr_(const DataflowVarNode* var) final { return VisitVar_<DataflowVar>(var); }

Expand Down
11 changes: 6 additions & 5 deletions src/script/ir_builder/relax/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#define TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_

#include <tvm/relax/struct_info_functor.h>
#include <tvm/relax/utils.h>
#include <tvm/script/ir_builder/relax/frame.h>
#include <tvm/script/ir_builder/relax/ir.h>

Expand Down Expand Up @@ -109,12 +110,12 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String
GetStructInfo(last_binding->var));
tvm::relax::Expr body;

if (const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>();
var_binding && var_binding->value->IsInstance<tvm::relax::VarNode>()) {
const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>();

if (var_binding && tvm::relax::IsLeafOrTuple(var_binding->value)) {
body = var_binding->value;
} else if (const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>()) {
last_block_bindings.push_back(last_binding =
tvm::relax::VarBinding(new_var, var_binding->value));
} else if (var_binding) {
last_block_bindings.push_back(tvm::relax::VarBinding(new_var, var_binding->value));
body = new_var;
} else if (const auto* match_cast = last_binding.as<tvm::relax::MatchCastNode>()) {
last_block_bindings.push_back(
Expand Down
22 changes: 4 additions & 18 deletions src/script/printer/relax/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,24 +69,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Doc ret = d->AsDoc(n->value, n_p->Attr("value"));
d->cfg->binding_names.pop_back();
return ret;

// Uncommenting this section hides the variable binding
// when the StructInfo is void. For example, printing
// `R.assert_op(expr)` instead of `_ = R.assert_op(expr)`.
// However, Relax represents void values as an empty
// tuple, and a void-type variable may still be used later
// in the function. Hiding bindings of these void-type
// variables would result in use of an undefined variable.
//
// TODO(Lunderberg): Inline void-type variable to use
// `R.tuple()` during normalization. This will avoid the
// cases that trigger the undefined variables, and allow
// this syntax sugar to be enabled.
//
// } else if (d->cfg->syntax_sugar && relax::HasVoidStructInfo(n->value) &&
// relax::HasVoidStructInfo(n->var)) {
// ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value"));
// return ExprStmtDoc(rhs);
} else if (d->cfg->syntax_sugar && relax::HasVoidStructInfo(n->value) &&
relax::HasVoidStructInfo(n->var)) {
ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value"));
return ExprStmtDoc(rhs);
} else {
ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value"));
Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
Expand Down
10 changes: 6 additions & 4 deletions tests/python/relax/test_transform_lift_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,10 @@ def main_transform_params(params: R.Tuple) -> R.Tuple:
R.func_attr({"num_input": 0})
with R.dataflow():
gv: R.Tuple = R.tuple()
R.output(gv)
return gv
R.output()
# All instance of the empty tuple are normalized to be
# in-line.
return R.tuple()

@R.function
def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
Expand Down Expand Up @@ -612,8 +614,8 @@ def main_transform_params(params: R.Tuple) -> R.Tuple:
R.func_attr({"num_input": 0})
with R.dataflow():
gv: R.Tuple = R.tuple()
R.output(gv)
return gv
R.output()
return R.tuple()

@R.function
def main(shape: R.Shape(["n"])) -> R.Shape(["n"]):
Expand Down
29 changes: 29 additions & 0 deletions tests/python/relax/test_transform_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,5 +552,34 @@ def test_nesting_non_dataflow_in_dataflow_error():
# should fail due to a normal binding block being inside a dataflowblock


def test_remove_usage_of_void_type_variables():
"""All empty tuples should be constructed in-line

For readability, TVMScript hides the variable binding if the
variable has a void type. For example, `R.assert_op(condition)`
instead of `void_var: R.Tuple([]) = R.assert_op(condition)`.
However, Relax follows standard convention of functional
languages, and uses an empty tuple to represent void. Since an
empty tuple may be legally used later in the function, the
`void_var` may require a binding.

This is avoided by normalizing all usage of a void-type
variable with an in-line `R.tuple()`.
"""
x = relax.Var("x", R.Tuple([]))
bindings = [
relax.VarBinding(x, R.assert_op(R.const(True, "bool"))),
]
seq = relax.SeqExpr([relax.BindingBlock(bindings)], x)
before = relax.Function([], seq, ret_struct_info=R.Tuple([]))

after = relax.transform.Normalize()(tvm.IRModule({"main": before}))["main"]

@R.function(private=True)
def expected():
x = R.assert_op(R.const(True, "bool"))
return R.tuple()


if __name__ == "__main__":
tvm.testing.main()
5 changes: 0 additions & 5 deletions tests/python/relax/test_tvmscript_printer_relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# under the License.
# pylint: disable=missing-docstring

import pytest

import tvm
import tvm.testing
from tvm import IRModule, relax, tir
Expand Down Expand Up @@ -636,7 +634,6 @@ def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32
)


@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled")
def test_assert_op():
@I.ir_module
class AssertOpMod:
Expand All @@ -661,7 +658,6 @@ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
)


@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled")
def test_print():
@I.ir_module
class PrintMod:
Expand Down Expand Up @@ -710,7 +706,6 @@ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
)


@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled")
def test_directly_construct_private_funcs():
# public
@R.function
Expand Down