-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Closed
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it) · type: bug
Description
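The Relax IR in the test case below passes the well-formedness check (`relax.analysis.well_formed`), but `DeadCodeElimination` (DCE) unexpectedly fails on it. The error is raised by the `UDChain` analysis, which reports that the local recursive function `while_loop` is used before its definition — `while_loop` calls itself from inside its own body.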
Actual behavior
Traceback (most recent call last):
File "test_sim.py", line 54, in <module>
mod = tvm.relax.transform.DeadCodeElimination()(mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/ir/transform.py", line 238, in __call__
return _ffi_transform_api.RunPass(self, mod)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/software/tvm-lunder/python/tvm/_ffi/_ctypes/packed_func.py", line 240, in __call__
raise_last_ffi_error()
File "/software/tvm-lunder/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
17: tvm::transform::Pass::operator()(tvm::IRModule) const
16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
15: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
14: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relax9transform19DeadCodeEliminationENS0_5ArrayINS0_6StringEvEEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SI_SM_
13: tvm::relax::DeadCodeElimination(tvm::IRModule const&, tvm::runtime::Array<tvm::runtime::String, void>)
12: tvm::relax::RemoveAllUnused(tvm::RelayExpr)
11: tvm::relax::CollectVarUsage(tvm::RelayExpr const&)
10: tvm::relax::UDChain::Collect(tvm::RelayExpr const&)
9: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
8: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::FunctionNode const*)
7: tvm::relax::ExprVisitor::VisitExpr(tvm::RelayExpr const&)
6: tvm::relax::ExprVisitor::VisitExpr_(tvm::relax::SeqExprNode const*)
5: tvm::relax::ExprVisitor::VisitBindingBlock(tvm::relax::BindingBlock const&)
4: tvm::relax::ExprVisitor::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
3: tvm::relax::ExprVisitor::VisitBinding(tvm::relax::Binding const&)
2: tvm::relax::UDChain::VisitBinding_(tvm::relax::VarBindingNode const*)
1: tvm::relax::ExprVisitor::VisitBinding_(tvm::relax::VarBindingNode const*)
0: tvm::relax::UDChain::VisitVarDef(tvm::relax::Var const&)
File "/software/tvm-lunder/src/relax/analysis/udchain.cc", line 75
TVMError: Check failed: (!usage_map.count(var)) is false: Variable while_loop was used before its definition
Steps to reproduce
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
# Minimal IRModule reproducing the DeadCodeElimination failure: `main` defines
# a local, self-recursive Relax function `while_loop` and calls it once.
@I.ir_module
class Module:
    # Scalar int32 add: T_add = i + c. Called by while_loop to bump the counter.
    @T.prim_func(private=True)
    def add(i: T.Buffer((), "int32"), c: T.Buffer((), "int32"), T_add: T.Buffer((), "int32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        with T.block("T_add"):
            vi = T.axis.spatial(1, T.int64(0))
            T.reads(i[()], c[()])
            T.writes(T_add[()])
            T_add[()] = i[()] + c[()]

    # Element-wise float32 add over (2, 3) buffers: T_add = s + x.
    # Called by while_loop to accumulate x into s.
    @T.prim_func(private=True)
    def add1(s: T.Buffer((T.int64(2), T.int64(3)), "float32"), x: T.Buffer((T.int64(2), T.int64(3)), "float32"), T_add: T.Buffer((T.int64(2), T.int64(3)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(s[v_ax0, v_ax1], x[v_ax0, v_ax1])
                T.writes(T_add[v_ax0, v_ax1])
                T_add[v_ax0, v_ax1] = s[v_ax0, v_ax1] + x[v_ax0, v_ax1]

    # Entry point: runs the local recursive `while_loop` starting from i = 0, s = x.
    @R.function
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
        # Handle to this module so the Relax code can reference cls.add / cls.add1.
        cls = Module

        # Local recursive function. While `cond` holds, it increments i by 1,
        # adds x to s, and recurses; otherwise it returns s.
        # NOTE(review): "test.vm.less" is a packed function registered outside
        # this script; presumably it computes i < 10 — confirm.
        @R.function
        def while_loop(i: R.Tensor((), dtype="int32"), s: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
            cond: R.Tensor((), dtype="bool") = R.call_pure_packed("test.vm.less", i, R.const(10, "int32"), sinfo_args=(R.Tensor((), dtype="bool"),))
            c: R.Tensor((), dtype="int32") = R.const(1, "int32")
            if cond:
                new_i = R.call_tir(cls.add, (i, c), out_sinfo=R.Tensor((), dtype="int32"))
                # `x` is captured from the enclosing `main`.
                new_s = R.call_tir(cls.add1, (s, x), out_sinfo=R.Tensor((2, 3), dtype="float32"))
                # Self-reference inside while_loop's own body. DCE's UDChain
                # analysis reports this as "Variable while_loop was used before
                # its definition", although well_formed accepts it.
                r_then: R.Tensor((2, 3), dtype="float32") = while_loop(new_i, new_s)
                r: R.Tensor((2, 3), dtype="float32") = r_then
            else:
                r: R.Tensor((2, 3), dtype="float32") = s
            return r

        gv: R.Tensor((2, 3), dtype="float32") = while_loop(R.const(0, "int32"), x)
        return gv
# Reproduction driver: print the module, confirm it is well-formed, then run
# DeadCodeElimination and print the result.
mod = Module
mod.show()

# Precondition of this report: the module is well-formed. Check explicitly
# instead of using `assert`, which is removed when Python runs with -O.
if not relax.analysis.well_formed(mod):
    raise RuntimeError("Module is expected to be well-formed before running DCE")

# On the affected TVM version this call fails inside UDChain with
# "Variable while_loop was used before its definition".
mod = relax.transform.DeadCodeElimination()(mod)
mod.show()
Metadata
Metadata
Assignees
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it) · type: bug