@@ -148,6 +148,18 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
148148 return indices;
149149}
150150
151+ std::pair<Array<PrimExpr>, PrimExpr>
152+ AtomicAddNode::ReturnIndicesAndSize (int src_dst) const {
153+ Array<PrimExpr> indices;
154+ Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
155+ PrimExpr size = 1 ;
156+ for (size_t i = 0 ; i < ranges.size (); i++) {
157+ indices.push_back (ranges[i]->min );
158+ size *= ranges[i]->extent ;
159+ }
160+ return {indices, size};
161+ }
162+
151163/* *
152164 * @brief Build a combined bound-check predicate for indexed access.
153165 *
@@ -370,6 +382,28 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
370382 */
371383Stmt AtomicAddNode::Lower (const LowerArgs &T, arith::Analyzer *analyzer) const {
372384 Target target = T.target ;
385+ if (use_tma->value != 0 ) {
386+ Array<PrimExpr> src_indices, dst_indices;
387+ PrimExpr src_size, dst_size;
388+ std::tie (src_indices, src_size) = ReturnIndicesAndSize (0 );
389+ std::tie (dst_indices, dst_size) = ReturnIndicesAndSize (1 );
390+ ICHECK (analyzer->CanProveEqual (src_size, dst_size))
391+ << " src_size = " << src_size << " , dst_size = " << dst_size;
392+ BufferLoad src_node = BufferLoad (src, src_indices);
393+ BufferLoad dst_node = BufferLoad (dst, dst_indices);
394+ Call address_of_src =
395+ Call (DataType::Handle (), builtin::address_of (), {src_node});
396+ Call address_of_dst =
397+ Call (DataType::Handle (), builtin::address_of (), {dst_node});
398+
399+ int need_reduce = 1 ;
400+ int eviction_policy = 0 ;
401+ auto body = Evaluate (Call (DataType::Handle (), tma_store (),
402+ {address_of_src, address_of_dst,
403+ ceildiv (src_size * src->dtype .bits (), 8 ),
404+ need_reduce, eviction_policy}));
405+ return IfThenElse (EQ (T.thread_var , T.thread_bounds ->min ), body);
406+ }
373407 auto simt_loop = MakeSIMTLoop (analyzer);
374408 auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse (simt_loop));
375409 auto transformed_loop =
@@ -486,7 +520,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
486520 read_src.value (), C.indice_map [read_src.value ()], args.layout_map ,
487521 args.thread_bounds , C.loop_vars );
488522 } else {
489- const For& remapped = loop;
523+ const For & remapped = loop;
490524 loop_layout = PlanLoopPartition (remapped, vec, args.thread_bounds );
491525 }
492526
0 commit comments