Skip to content

Commit c31dfcd

Browse files
committed
update
1 parent 0d56bc2 commit c31dfcd

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

src/op/atomic_add.cc

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,18 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
148148
return indices;
149149
}
150150

151+
std::pair<Array<PrimExpr>, PrimExpr>
152+
AtomicAddNode::ReturnIndicesAndSize(int src_dst) const {
153+
Array<PrimExpr> indices;
154+
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
155+
PrimExpr size = 1;
156+
for (size_t i = 0; i < ranges.size(); i++) {
157+
indices.push_back(ranges[i]->min);
158+
size *= ranges[i]->extent;
159+
}
160+
return {indices, size};
161+
}
162+
151163
/**
152164
* @brief Build a combined bound-check predicate for indexed access.
153165
*
@@ -370,6 +382,28 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
370382
*/
371383
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
372384
Target target = T.target;
385+
if (use_tma->value != 0) {
386+
Array<PrimExpr> src_indices, dst_indices;
387+
PrimExpr src_size, dst_size;
388+
std::tie(src_indices, src_size) = ReturnIndicesAndSize(0);
389+
std::tie(dst_indices, dst_size) = ReturnIndicesAndSize(1);
390+
ICHECK(analyzer->CanProveEqual(src_size, dst_size))
391+
<< "src_size = " << src_size << ", dst_size = " << dst_size;
392+
BufferLoad src_node = BufferLoad(src, src_indices);
393+
BufferLoad dst_node = BufferLoad(dst, dst_indices);
394+
Call address_of_src =
395+
Call(DataType::Handle(), builtin::address_of(), {src_node});
396+
Call address_of_dst =
397+
Call(DataType::Handle(), builtin::address_of(), {dst_node});
398+
399+
int need_reduce = 1;
400+
int eviction_policy = 0;
401+
auto body = Evaluate(Call(DataType::Handle(), tma_store(),
402+
{address_of_src, address_of_dst,
403+
ceildiv(src_size * src->dtype.bits(), 8),
404+
need_reduce, eviction_policy}));
405+
return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body);
406+
}
373407
auto simt_loop = MakeSIMTLoop(analyzer);
374408
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
375409
auto transformed_loop =
@@ -486,7 +520,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
486520
read_src.value(), C.indice_map[read_src.value()], args.layout_map,
487521
args.thread_bounds, C.loop_vars);
488522
} else {
489-
const For& remapped = loop;
523+
const For &remapped = loop;
490524
loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
491525
}
492526

0 commit comments

Comments
 (0)