Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from da7f19 to 0794c1
11 changes: 9 additions & 2 deletions src/op/atomic_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,11 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
new_args.push_back(src_value);
new_args.push_back(GetMemoryOrder());

// erase use_tma from annotations
auto annotations = this->annotations;
annotations.erase("use_tma");
Call atomicadd_call =
tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args);
tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args, annotations);

Stmt body = tvm::tir::Evaluate(atomicadd_call);

Expand Down Expand Up @@ -390,10 +393,14 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

int need_reduce = 1;
int eviction_policy = 0;
// erase use_tma from annotations
auto annotations = this->annotations;
annotations.erase("use_tma");
auto body = Evaluate(Call(DataType::Handle(), tma_store(),
{address_of_src, address_of_dst,
ceildiv(src_size * src->dtype.bits(), 8),
need_reduce, eviction_policy}));
need_reduce, eviction_policy},
annotations));
return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body);
}
auto simt_loop = MakeSIMTLoop(analyzer);
Expand Down
6 changes: 4 additions & 2 deletions src/target/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ struct CUDAWarpIntrinsic {

static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr &e) {
const CallNode *call = e.as<CallNode>();
return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args);
return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args,
call->annotations);
}

template <typename T> static PrimExpr DispatchCUDAShuffle(const PrimExpr &e) {
Expand All @@ -127,7 +128,8 @@ template <typename T> static PrimExpr DispatchCUDAShuffle(const PrimExpr &e) {
ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
Array<PrimExpr> cuda_args{
{call->args[0], call->args[1], call->args[2], call->args[3]}};
return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args,
call->annotations);
}

TVM_REGISTER_OP("tir.rsqrt")
Expand Down
14 changes: 7 additions & 7 deletions src/transform/inject_tma_barrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ class TmaExpectTxRewriter : public IRMutatorWithAnalyzer {
new_args.Set(is_1d_tma_load ? 2 : 1,
Call(DataType::Handle(), get_mbarrier(),
{IntImm(DataType::Int(32), 0)}));
return Call(op->dtype, op->op, new_args);
return Call(op->dtype, op->op, new_args, op->annotations);
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Expand Down Expand Up @@ -382,7 +382,7 @@ class BarrierCreationRewriter : public StmtExprMutator {
}
}

return Call(op->dtype, op->op, new_args);
return Call(op->dtype, op->op, new_args, op->annotations);
} else {
return StmtExprMutator::VisitExpr_(op);
}
Expand Down Expand Up @@ -521,7 +521,7 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
new_args.Set(2, Call(DataType::Handle(), get_mbarrier(),
{IntImm(DataType::Int(32),
static_cast<int>(imm->value))}));
return Call(op->dtype, op->op, new_args);
return Call(op->dtype, op->op, new_args, op->annotations);
}
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
Expand All @@ -537,7 +537,7 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
} else {
new_args.Set(1, barrier_id);
}
return Call(op->dtype, op->op, new_args);
return Call(op->dtype, op->op, new_args, op->annotations);
} else if (op->op.same_as(mbarrier_expect_tx())) {
auto call_ref = tvm::ffi::GetRef<Call>(op);
if (!tma_op_to_barrier_id_.count(call_ref)) {
Expand All @@ -552,17 +552,17 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
clear_arrive_ = clear_expect_list_[cur_expect_idx_++];
if (clear_arrive_) {
return Call(op->dtype, builtin::ptx_arrive_barrier_expect_tx(),
new_args);
new_args, op->annotations);
}
return Call(op->dtype, op->op, new_args);
return Call(op->dtype, op->op, new_args, op->annotations);
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
if (clear_arrive_) {
clear_arrive_ = false;
return 0;
}
// by default, all threads must wait.
auto new_args = op->args;
return Call(op->dtype, op->op, new_args);
return Call(op->dtype, op->op, new_args, op->annotations);
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
Expand Down
Loading