Closed
Labels: needs-triage, type: bug
Description
relax.build crashes with an InternalError raised from Buffer::GetFlattenedBuffer while lowering the Relax module in the reproduction below, whose layout-transform PrimFuncs declare output buffers with axis_separators.
Actual behavior
Traceback (most recent call last):
  File "simple_test.py", line 62, in <module>
    ex = relax.build(mod, target='llvm') #crash here!
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/relax/vm_build.py", line 341, in build
    return _vmlink(
           ^^^^^^^^
  File "/software/tvm/python/tvm/relax/vm_build.py", line 247, in _vmlink
    lib = tvm.build(
          ^^^^^^^^^^
  File "/software/tvm/python/tvm/driver/build_module.py", line 239, in build
    input_mod = lower(inputs)
                ^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/driver/build_module.py", line 130, in lower
    return ffi.lower_module(inp, simple_mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
tvm.error.InternalError: Traceback (most recent call last):
22: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, bool)>::AssignTypedLambda<tvm::{lambda(tvm::IRModule, bool)#3}>(tvm::{lambda(tvm::IRModule, bool)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
21: tvm::LowerModule(tvm::IRModule, bool)
20: tvm::LowerWithPassList(tvm::IRModule, tvm::runtime::Array<tvm::transform::Pass, void>)
19: tvm::transform::Pass::operator()(tvm::IRModule) const
18: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
17: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
15: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
14: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_9transform13FlattenBufferEvEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
13: tvm::tir::FlattenBuffer(tvm::tir::PrimFunc)
12: tvm::tir::BufferFlattener::Flatten(tvm::tir::PrimFunc)
11: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS6_E
10: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
9: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
8: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
7: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS6_E
6: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
5: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
4: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
3: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS6_E
2: tvm::tir::BufferFlattener::VisitStmt_(tvm::tir::BufferStoreNode const*)
1: tvm::tir::BufferFlattener::GetFlattenedBuffer(tvm::tir::Buffer)
0: tvm::tir::Buffer::GetFlattenedBuffer() const
File "/software/tvm/src/tir/ir/buffer.cc", line 352
InternalError: Check failed: last_sep < self->shape.size() (1 vs. 1) : Last output axis must contain at least one input axis.
Environment
- TVM: 0.17.dev0
Steps to reproduce
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R


@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")):
        T.func_attr({"operator_name": "relax.some_op"})
        # with T.block("root"):
        for ax0, ax1 in T.grid(4, 4):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
                T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
                output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
                output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def te_layout_transform_axis_separator(x: T.Buffer((T.int64(16),), "float32"), var_te_layout_transform_axis_separator: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        te_layout_transform_axis_separator = T.match_buffer(var_te_layout_transform_axis_separator, (T.int64(4), T.int64(4)), axis_separators=[1])
        # with T.block("root"):
        for self in range(T.int64(16)):
            with T.block("te_layout_transform_axis_separator"):
                v_self = T.axis.spatial(T.int64(16), self)
                T.reads(x[v_self])
                T.writes(te_layout_transform_axis_separator[v_self // T.int64(4), v_self % T.int64(4)])
                te_layout_transform_axis_separator[v_self // T.int64(4), v_self % T.int64(4)] = x[v_self]

    @T.prim_func(private=True)
    def te_layout_transform_axis_separator1(lv3: T.Buffer((T.int64(4), T.int64(4)), "float32"), var_te_layout_transform_axis_separator: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        te_layout_transform_axis_separator = T.match_buffer(var_te_layout_transform_axis_separator, (T.int64(16),), axis_separators=[1])
        # with T.block("root"):
        for self, i0 in T.grid(T.int64(4), T.int64(4)):
            with T.block("te_layout_transform_axis_separator"):
                v_self, v_i0 = T.axis.remap("SS", [self, i0])
                T.reads(lv3[v_self, v_i0])
                T.writes(te_layout_transform_axis_separator[v_self * T.int64(4) + v_i0])
                te_layout_transform_axis_separator[v_self * T.int64(4) + v_i0] = lv3[v_self, v_i0]

    @R.function
    def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.te_layout_transform_axis_separator, (x,), out_sinfo=R.Tensor((4, 4), dtype="float32"))
            lv1 = R.call_tir(cls.te_layout_transform_axis_separator, (y,), out_sinfo=R.Tensor((4, 4), dtype="float32"))
            lv2 = R.call_tir(cls.relax_some_op_replacement, (lv, lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")])
            lv3: R.Tensor((4, 4), dtype="float32") = lv2[0]
            lv4 = R.call_tir(cls.te_layout_transform_axis_separator1, (lv3,), out_sinfo=R.Tensor((16,), dtype="float32"))
            R.output(lv4)
        return lv4


mod = Module
mod.show()
mod = tvm.relax.transform.LegalizeOps()(mod)
mod = relax.transform.FuseTIR()(mod)
mod = relax.transform.LambdaLift()(mod)
ex = relax.build(mod, target='llvm') #crash here!
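For what it's worth, the failing check sits in Buffer::GetFlattenedBuffer (src/tir/ir/buffer.cc), and the "1 vs. 1" comparison matches the output buffer declared in te_layout_transform_axis_separator1: a rank-1 buffer of shape (16,) with axis_separators=[1]. The sketch below isolates that pattern; it is only an observation from the reproduction script, not a confirmed root-cause analysis, and the function name copy_with_separator is made up for illustration. Defining such a PrimFunc parses fine; the check only fires later when FlattenBuffer runs during build.

from tvm.script import tir as T

@T.prim_func(private=True)
def copy_with_separator(x: T.Buffer((T.int64(16),), "float32"), var_out: T.handle):
    # Rank-1 output buffer whose single axis separator sits at index 1, i.e. at
    # the end of the shape, so the last output axis group would contain zero
    # input axes once FlattenBuffer asks for the flattened buffer.
    out = T.match_buffer(var_out, (T.int64(16),), axis_separators=[1])
    for i in range(T.int64(16)):
        with T.block("copy"):
            v_i = T.axis.spatial(T.int64(16), i)
            T.reads(x[v_i])
            T.writes(out[v_i])
            out[v_i] = x[v_i]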