Skip to content

[Bug] [Relax] Variable was used before its definition #17222

@Cookiee235

Description

@Cookiee235

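The Relax IR in the test case below passes the well-formedness check (`relax.analysis.well_formed`), but `relax.transform.DeadCodeElimination` unexpectedly fails on it. The local function `while_loop` calls itself recursively inside its own body. `UDChain` appears to record that recursive call as a use of `while_loop` before it visits the binding that defines `while_loop`, which trips the check in `udchain.cc`.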

Actual behavior

Traceback (most recent call last):
  File "test_sim.py", line 54, in <module>
    mod = tvm.relax.transform.DeadCodeElimination()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
    raise_last_ffi_error()
  File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  17: tvm::transform::Pass::operator()(tvm::IRModule) const
  16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  15: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  14: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform19DeadCodeEliminationENS0_5ArrayINS0_6StringEvEEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SI_SM_
  13: tvm::relax::DeadCodeElimination(tvm::IRModule const&, tvm::runtime::Array<tvm::runtime::String, void>)
  12: tvm::relax::RemoveAllUnused(tvm::RelayExpr)
  11: tvm::relax::CollectVarUsage(tvm::RelayExpr const&)
  10: tvm::relax::UDChain::Collect(tvm::RelayExpr const&)
  9: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  8: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::FunctionNode const*)
  7: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
  6: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
  5: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
  4: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  3: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
  2: tvm::relax::UDChain::VisitBinding_(tvm::relax::VarBindingNode const*)
  1: tvm::relax::ExprVisitor::VisitBinding_(tvm::relax::VarBindingNode const*)
  0: tvm::relax::UDChain::VisitVarDef(tvm::relax::Var const&)
  File "/software/tvm-lunder/src/relax/analysis/udchain.cc", line 75
TVMError: Check failed: (!usage_map.count(var)) is false: Variable while_loop was used before its definition

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

# Minimal IRModule that reproduces the DeadCodeElimination crash: `main`
# binds a local Relax function `while_loop` that calls itself recursively.
@I.ir_module
class Module:
    # Scalar int32 addition (T_add = i + c); `while_loop` uses it to
    # increment its counter `i` by `c` (the constant 1).
    @T.prim_func(private=True)
    def add(i: T.Buffer((), "int32"), c: T.Buffer((), "int32"), T_add: T.Buffer((), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        with T.block("T_add"):
            vi = T.axis.spatial(1, T.int64(0))
            T.reads(i[()], c[()])
            T.writes(T_add[()])
            T_add[()] = i[()] + c[()]

    # Element-wise addition of two (2, 3) float32 buffers (T_add = s + x);
    # `while_loop` uses it to add `x` to its accumulator `s`.
    @T.prim_func(private=True)
    def add1(s: T.Buffer((T.int64(2), T.int64(3)), "float32"), x: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(s[v_ax0, v_ax1], x[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = s[v_ax0, v_ax1] + x[v_ax0, v_ax1]

    # Entry point: calls the local `while_loop` with i = 0 and s = x and
    # returns its result.
    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
        cls = Module

        # Local recursive function. `cond` comes from the packed function
        # "test.vm.less" (presumably i < 10 -- confirm against its
        # registration). If `cond` holds, recurse with i + 1 and s + x;
        # otherwise return s. It reads `x` from the enclosing `main`, so it
        # is a closure over that parameter.
        @R.function
        def while_loop(i: R.Tensor((), dtype="int32"), s: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
            cond: R.Tensor((), dtype="bool") = R.call_pure_packed("test.vm.less", i, R.const(10, "int32"), sinfo_args=(R.Tensor((), dtype="bool"),))
            c: R.Tensor((), dtype="int32") = R.const(1, "int32")
            if cond:
                new_i = R.call_tir(cls.add, (i, c), out_sinfo=R.Tensor((), dtype="int32"))
                new_s = R.call_tir(cls.add1, (s, x), out_sinfo=R.Tensor((2, 3), dtype="float32"))
                # Self-reference: this call is the use of `while_loop` that
                # UDChain reports as occurring before its definition.
                r_then: R.Tensor((2, 3), dtype="float32") = while_loop(new_i, new_s)
                r: R.Tensor((2, 3), dtype="float32") = r_then
            else:
                r: R.Tensor((2, 3), dtype="float32") = s
            return r

        gv: R.Tensor((2, 3), dtype="float32") = while_loop(R.const(0, "int32"), x)
        return gv


# Drive the reproducer: print the module, confirm it passes the
# well-formedness check, then run DeadCodeElimination on it (this is the
# step that raises the UDChain "used before its definition" error) and
# print whatever it returns.
mod = Module
mod.show()
is_well_formed = relax.analysis.well_formed(mod)
assert is_well_formed
dce_pass = tvm.relax.transform.DeadCodeElimination()
mod = dce_pass(mod)
mod.show()

cc @Lunderberg @junrushao

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions