diff --git a/3rdparty/tvm b/3rdparty/tvm index da7f19b69..0794c13a0 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit da7f19b6908045a1f9bf94cb7e044beaa32421b6 +Subproject commit 0794c13a0900532f3b878fccab9a50c975d8a03c diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index e918b4ed7..c65641374 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -280,8 +280,11 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { new_args.push_back(src_value); new_args.push_back(GetMemoryOrder()); + // erase use_tma from annotations + auto annotations = this->annotations; + annotations.erase("use_tma"); Call atomicadd_call = - tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args); + tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args, annotations); Stmt body = tvm::tir::Evaluate(atomicadd_call); @@ -390,10 +393,14 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { int need_reduce = 1; int eviction_policy = 0; + // erase use_tma from annotations + auto annotations = this->annotations; + annotations.erase("use_tma"); auto body = Evaluate(Call(DataType::Handle(), tma_store(), {address_of_src, address_of_dst, ceildiv(src_size * src->dtype.bits(), 8), - need_reduce, eviction_policy})); + need_reduce, eviction_policy}, + annotations)); return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body); } auto simt_loop = MakeSIMTLoop(analyzer); diff --git a/src/target/intrin_rule_cuda.cc b/src/target/intrin_rule_cuda.cc index 1aacd7204..e3186c713 100644 --- a/src/target/intrin_rule_cuda.cc +++ b/src/target/intrin_rule_cuda.cc @@ -118,7 +118,8 @@ struct CUDAWarpIntrinsic { static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr &e) { const CallNode *call = e.as(); - return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args); + return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args, + call->annotations); } template static PrimExpr DispatchCUDAShuffle(const PrimExpr &e) { @@ -127,7 +128,8 @@ template static PrimExpr DispatchCUDAShuffle(const PrimExpr &e) { ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size Array cuda_args{ {call->args[0], call->args[1], call->args[2], call->args[3]}}; - return Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args); + return Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args, + call->annotations); } TVM_REGISTER_OP("tir.rsqrt") diff --git a/src/transform/inject_tma_barrier.cc b/src/transform/inject_tma_barrier.cc index 93beb15d4..77ebf649f 100644 --- a/src/transform/inject_tma_barrier.cc +++ b/src/transform/inject_tma_barrier.cc @@ -173,7 +173,7 @@ class TmaExpectTxRewriter : public IRMutatorWithAnalyzer { new_args.Set(is_1d_tma_load ? 2 : 1, Call(DataType::Handle(), get_mbarrier(), {IntImm(DataType::Int(32), 0)})); - return Call(op->dtype, op->op, new_args); + return Call(op->dtype, op->op, new_args, op->annotations); } return IRMutatorWithAnalyzer::VisitExpr_(op); } @@ -382,7 +382,7 @@ class BarrierCreationRewriter : public StmtExprMutator { } } - return Call(op->dtype, op->op, new_args); + return Call(op->dtype, op->op, new_args, op->annotations); } else { return StmtExprMutator::VisitExpr_(op); } @@ -521,7 +521,7 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { new_args.Set(2, Call(DataType::Handle(), get_mbarrier(), {IntImm(DataType::Int(32), static_cast(imm->value))})); - return Call(op->dtype, op->op, new_args); + return Call(op->dtype, op->op, new_args, op->annotations); } } return IRMutatorWithAnalyzer::VisitExpr_(op); @@ -537,7 +537,7 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { } else { new_args.Set(1, barrier_id); } - return Call(op->dtype, op->op, new_args); + return Call(op->dtype, op->op, new_args, op->annotations); } else if (op->op.same_as(mbarrier_expect_tx())) { auto call_ref = tvm::ffi::GetRef(op); if (!tma_op_to_barrier_id_.count(call_ref)) { @@ -552,9 +552,9 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { clear_arrive_ = clear_expect_list_[cur_expect_idx_++]; if (clear_arrive_) { return Call(op->dtype, builtin::ptx_arrive_barrier_expect_tx(), - new_args); + new_args, op->annotations); } - return Call(op->dtype, op->op, new_args); + return Call(op->dtype, op->op, new_args, op->annotations); } else if (op->op.same_as(builtin::ptx_arrive_barrier())) { if (clear_arrive_) { clear_arrive_ = false; @@ -562,7 +562,7 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer { } // by default, all threads must wait. auto new_args = op->args; - return Call(op->dtype, op->op, new_args); + return Call(op->dtype, op->op, new_args, op->annotations); } return IRMutatorWithAnalyzer::VisitExpr_(op); }