
[Bug] [Relax] InternalError: Check failed: last_sep < self->shape.size() Last output axis must contain at least one input axis #17215

@Cookiee235

Description


Actual behavior

Traceback (most recent call last):
  File "simple_test.py", line 62, in <module>
    ex = relax.build(mod, target='llvm')  #crash here!
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/relax/vm_build.py", line 341, in build
    return _vmlink(
           ^^^^^^^^
  File "/software/tvm/python/tvm/relax/vm_build.py", line 247, in _vmlink
    lib = tvm.build(
          ^^^^^^^^^^
  File "/software/tvm/python/tvm/driver/build_module.py", line 239, in build
    input_mod = lower(inputs)
                ^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/driver/build_module.py", line 130, in lower
    return ffi.lower_module(inp, simple_mode)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/software/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/software/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  22: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, bool)>::AssignTypedLambda<tvm::{lambda(tvm::IRModule, bool)#3}>(tvm::{lambda(tvm::IRModule, bool)#3}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  21: tvm::LowerModule(tvm::IRModule, bool)
  20: tvm::LowerWithPassList(tvm::IRModule, tvm::runtime::Array<tvm::transform::Pass, void>)
  19: tvm::transform::Pass::operator()(tvm::IRModule) const
  18: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  17: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  15: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  14: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_3tir8PrimFuncES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_9transform13FlattenBufferEvEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SG_SK_
  13: tvm::tir::FlattenBuffer(tvm::tir::PrimFunc)
  12: tvm::tir::BufferFlattener::Flatten(tvm::tir::PrimFunc)
  11: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS6_E
  10: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
  9: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  8: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  7: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS6_E
  6: tvm::arith::IRMutatorWithAnalyzer::VisitStmt_(tvm::tir::ForNode const*)
  5: tvm::tir::StmtMutator::VisitStmt(tvm::tir::Stmt const&)
  4: tvm::tir::StmtFunctor<tvm::tir::Stmt (tvm::tir::Stmt const&)>::VisitStmt(tvm::tir::Stmt const&)
  3: _ZZN3tvm3tir11StmtFunctorIFNS0_4StmtERKS2_EE10InitVTableEvENUlRKNS_7runtime9ObjectRefEPS6_E
  2: tvm::tir::BufferFlattener::VisitStmt_(tvm::tir::BufferStoreNode const*)
  1: tvm::tir::BufferFlattener::GetFlattenedBuffer(tvm::tir::Buffer)
  0: tvm::tir::Buffer::GetFlattenedBuffer() const
  File "/software/tvm/src/tir/ir/buffer.cc", line 352
InternalError: Check failed: last_sep < self->shape.size() (1 vs. 1) : Last output axis must contain at least one input axis.
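
For context: axis_separators partition a buffer's input axes into groups, and Buffer::GetFlattenedBuffer collapses each group into one output axis. The check that fails requires the last group to be non-empty; here last_sep == 1 equals the shape rank 1, so the final group contains no input axes. Below is a plain-Python sketch of that grouping as I understand it (an illustration, not TVM's implementation):

import math

def flattened_shape(shape, axis_separators):
    # Each output axis is the product of the input axes between two
    # consecutive separators; the final separator must leave at least
    # one input axis for the last output axis.
    last_sep = 0
    out = []
    for sep in axis_separators:
        assert last_sep < sep, "separators must be increasing"
        out.append(math.prod(shape[last_sep:sep]))
        last_sep = sep
    # The check that fires at src/tir/ir/buffer.cc line 352:
    assert last_sep < len(shape), \
        "Last output axis must contain at least one input axis"
    out.append(math.prod(shape[last_sep:]))
    return out

print(flattened_shape((4, 4), [1]))  # [4, 4]: fine, two output axes
print(flattened_shape((16,), [1]))   # AssertionError: mirrors this crash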

Environment

  • TVM: 0.17.dev0

Steps to reproduce

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")):
        T.func_attr({"operator_name": "relax.some_op"})
        # with T.block("root"):
        for ax0, ax1 in T.grid(4, 4):
            with T.block("T_add"):
                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
                T.reads(arg0[v_ax0, v_ax1], arg1[v_ax0, v_ax1])
                T.writes(output0[v_ax0, v_ax1], output1[v_ax0, v_ax1])
                output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
                output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]

    @T.prim_func(private=True)
    def te_layout_transform_axis_separator(x: T.Buffer((T.int64(16),), "float32"), var_te_layout_transform_axis_separator: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        te_layout_transform_axis_separator = T.match_buffer(var_te_layout_transform_axis_separator, (T.int64(4), T.int64(4)), axis_separators=[1])
        # with T.block("root"):
        for self in range(T.int64(16)):
            with T.block("te_layout_transform_axis_separator"):
                v_self = T.axis.spatial(T.int64(16), self)
                T.reads(x[v_self])
                T.writes(te_layout_transform_axis_separator[v_self // T.int64(4), v_self % T.int64(4)])
                te_layout_transform_axis_separator[v_self // T.int64(4), v_self % T.int64(4)] = x[v_self]

    @T.prim_func(private=True)
    def te_layout_transform_axis_separator1(lv3: T.Buffer((T.int64(4), T.int64(4)), "float32"), var_te_layout_transform_axis_separator: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        te_layout_transform_axis_separator = T.match_buffer(var_te_layout_transform_axis_separator, (T.int64(16),), axis_separators=[1])
        # with T.block("root"):
        for self, i0 in T.grid(T.int64(4), T.int64(4)):
            with T.block("te_layout_transform_axis_separator"):
                v_self, v_i0 = T.axis.remap("SS", [self, i0])
                T.reads(lv3[v_self, v_i0])
                T.writes(te_layout_transform_axis_separator[v_self * T.int64(4) + v_i0])
                te_layout_transform_axis_separator[v_self * T.int64(4) + v_i0] = lv3[v_self, v_i0]

    @R.function
    def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
        cls = Module
        with R.dataflow():
            lv = R.call_tir(cls.te_layout_transform_axis_separator, (x,), out_sinfo=R.Tensor((4, 4), dtype="float32"))
            lv1 = R.call_tir(cls.te_layout_transform_axis_separator, (y,), out_sinfo=R.Tensor((4, 4), dtype="float32"))
            lv2 = R.call_tir(cls.relax_some_op_replacement, (lv, lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")])
            lv3: R.Tensor((4, 4), dtype="float32") = lv2[0]
            lv4 = R.call_tir(cls.te_layout_transform_axis_separator1, (lv3,), out_sinfo=R.Tensor((16,), dtype="float32"))
            R.output(lv4)
        return lv4

mod = Module
mod.show()
mod = relax.transform.LegalizeOps()(mod)
mod = relax.transform.FuseTIR()(mod)
mod = relax.transform.LambdaLift()(mod)
ex = relax.build(mod, target='llvm')  #crash here!
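
The trigger is the output buffer of te_layout_transform_axis_separator1: a rank-1 shape (T.int64(16),) combined with axis_separators=[1] places the separator at the end of the shape, leaving no input axis for the second output axis. Assuming the Python Buffer API on main (tvm.tir.decl_buffer accepting axis_separators, and Buffer.get_flattened_buffer), the failing check can be reproduced in isolation:

import tvm

# A rank-1 buffer with a separator at the end of its shape: the group
# after the separator is empty, so flattening fails the same check.
buf = tvm.tir.decl_buffer((16,), "float32", axis_separators=[1])
buf.get_flattened_buffer()  # raises the same InternalError as above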

cc @Lunderberg @junrushao
