Skip to content

[Bug] Crash with tensorize + prefetch after PR#11589 #11716

@kparzysz-quic

Description

@kparzysz-quic

Testcase:

import tvm


def compute_something(A, C):
    """Emit the intrinsic body: a single vector store into ``C``.

    The values loaded from ``A`` at indices 0 and 1 are added, the sum is
    passed through ``tir.reinterpret`` with result dtype ``uint8x2``, and the
    result is stored at index 0 of ``C``.

    Parameters
    ----------
    A :
        Input operand; must provide ``vload``.
    C :
        Output operand; must provide ``vstore``.

    Returns
    -------
    The statement produced by the IR builder.
    """
    builder = tvm.tir.ir_builder.create()

    reinterpret_op = tvm.ir.Op.get("tir.reinterpret")
    pair_sum = A.vload(0) + A.vload(1)
    packed = tvm.tir.call_intrin("uint8x2", reinterpret_op, pair_sum)

    builder.emit(C.vstore(0, packed))
    return builder.get()


def intrin_compute_something(S):
    """Declare the tensor intrinsic used to tensorize a 128-element slice.

    The intrinsic describes ``C[i] = (A[i] * S).astype("uint8")`` over a
    128-element ``int16`` placeholder and lowers it through
    ``compute_something``.

    Parameters
    ----------
    S :
        Scale factor multiplied into every element of the input.

    Returns
    -------
    The object returned by ``tvm.te.decl_tensor_intrin``.
    """
    extent = (128,)
    src = tvm.te.placeholder(extent, dtype="int16", name="A")
    dst = tvm.te.compute(extent, lambda i: (src[i] * S).astype("uint8"), name="C")

    # Both buffers are declared with a symbolic element offset.
    src_offset = tvm.te.var("b_offset", "int32")
    src_buf = tvm.tir.decl_buffer(src.shape, src.dtype, name="Ab", elem_offset=src_offset)
    dst_offset = tvm.te.var("c_offset", "int32")
    dst_buf = tvm.tir.decl_buffer(dst.shape, dst.dtype, name="Cb", elem_offset=dst_offset)

    def intrin_func(ins, outs):
        # Only the first element of the returned triple is provided; the
        # other two are None.
        return compute_something(ins[0], outs[0]), None, None

    buffer_bindings = {src: src_buf, dst: dst_buf}
    return tvm.te.decl_tensor_intrin(
        dst.op,
        intrin_func,
        binds=buffer_bindings,
        default_buffer_params={"offset_factor": 128},
    )


def some_op(target):
    """Build the tensorize + prefetch schedule that reproduces the reported crash.

    Defines ``C = (A * S).astype("uint8")`` element-wise over an ``int16``
    placeholder of shape ``(H, W, D * 128)``, tensorizes the inner 128-wide
    part of the last axis, adds a prefetch of ``A`` on the outer part of the
    first axis, and builds the result.

    Parameters
    ----------
    target :
        Target passed straight through to ``tvm.build`` (e.g. ``"llvm"``).

    Returns
    -------
    The module returned by ``tvm.build``.
    """
    # Symbolic extents: the input shape is (H, W, D * 128).
    D, H, W = tvm.te.var("D"), tvm.te.var("H"), tvm.te.var("W")
    S = tvm.te.var("S", dtype="uint16")
    A = tvm.te.placeholder((H, W, D * 128), name="A", dtype="int16")

    C = tvm.te.compute(
        A.shape, lambda yy, xx, cc: (A[yy, xx, cc] * S).astype("uint8"), name="C"
    )

    # Create the schedule; tensorize and prefetch are applied to it below.
    s = tvm.te.create_schedule(C.op)

    cy, cx, cc = s[C].op.axis
    # Split the last axis by 128 and replace the inner part with the intrinsic.
    co, ci = s[C].split(cc, factor=128)
    s[C].tensorize(ci, intrin_compute_something(S))
    # Split the first axis by 32 and prefetch A on its outer part
    # (third prefetch argument is 1).
    yo, yi = s[C].split(cy, factor=32)
    s[C].prefetch(A, yo, 1)

    # NOTE(review): H and W are not in the argument list -- presumably they
    # are bound through A's shape; confirm.
    module = tvm.build(s, [A, C, D, S], target)
    return module


def test_some_op():
    """Build ``some_op`` for the ``llvm`` target.

    Nothing is asserted on the result; the call completing is the check.
    """
    some_op("llvm")


# Module-level call: the reproducer runs as soon as this file is executed.
test_some_op()

Run: python3 testcase.py (it fails with the error below; earlier stack frames are elided)

[...]
  7: tvm::arith::BufferTouchedDomain::VisitStmt_(tvm::tir::BufferStoreNode const*)
  6: void tvm::arith::BufferTouchedDomain::Touch<tvm::runtime::Array<tvm::PrimExpr, void> >(std::__1::vector<std::__1::vector<tvm::arith::IntSet, std::__1::allocator<tvm::arith::IntSet> >, std::__1::allocator<std::__1::vector<tvm::arith::IntSet, std::__1::allocator<tvm::arith::IntSet> > > >*, tvm::runtime::Array<tvm::PrimExpr, void> const&) const
  5: tvm::arith::EvalSet(tvm::PrimExpr, std::__1::unordered_map<tvm::tir::VarNode const*, tvm::arith::IntSet, std::__1::hash<tvm::tir::VarNode const*>, std::__1::equal_to<tvm::tir::VarNode const*>, std::__1::allocator<std::__1::pair<tvm::tir::VarNode const* const, tvm::arith::IntSet> > > const&)
  4: tvm::arith::EvalSet(tvm::PrimExpr, tvm::runtime::Map<tvm::tir::Var, tvm::arith::IntSet, void, void> const&)
  3: tvm::tir::ExprFunctor<tvm::arith::IntervalSet (tvm::PrimExpr const&)>::VisitExpr(tvm::PrimExpr const&)
  2: tvm::NodeFunctor<tvm::arith::IntervalSet (tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::arith::IntervalSet (tvm::PrimExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::ExprFunctor<tvm::arith::IntervalSet (tvm::PrimExpr const&)>*) const
  1: _ZZN3tvm3tir11ExprFunctorIF
  0: tvm::arith::IntervalSetEvaluator::VisitExpr_(tvm::tir::RampNode const*)
  File "/w/src/aitools/tvm-upstream/src/arith/int_set.cc", line 453
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (eval_vec_) is false:

cc: @csullivan

Metadata

Metadata

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions