 #include "../target/utils.h"
 #include "../transform/atomicadd_vectorize.h"
 #include "../transform/common/loop_fusion_utils.h"
+#include "../transform/common/loop_parallel_transform_utils.h"
 #include "../transform/loop_partition.h"
 #include "builtin.h"
 
@@ -21,31 +22,6 @@ namespace tl {
 
 using namespace tir;
 
-/**
- * @brief Extracts a numeric architecture identifier from a Target's "arch"
- * attribute.
- *
- * Reads the Target's "arch" string (must be defined) and, if it has the form
- * "sm_<N>", parses and returns N as an integer. For any other arch string,
- * returns 0.
- *
- * @param target Target whose "arch" attribute will be inspected (ICHECKs that
- * the attribute is defined).
- * @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
- */
-static int GetArchInt(Target target) {
-  int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
-  std::string arch = s.value();
-  if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
-  } else {
-    arch_int = 0;
-  }
-  return arch_int;
-}
-
 /**
  * @brief Construct an AtomicAdd operator from call arguments and a buffer map.
  *
@@ -328,6 +304,47 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   return Downcast<For>(body);
 }
 
+/**
+ * @brief Infer and return the layout map for the atomic add operator.
+ *
+ * Constructs a cached ParallelOp (by building the SIMT loop) if not already
+ * present, validates that local.fragment layouts for src and dst match when
+ * both are provided, and then delegates layout inference to the underlying
+ * ParallelOp.
+ *
+ * @param T Layout inference inputs, including an optional mapping of buffers
+ * to layouts.
+ * @param level Inference strictness level.
+ * @return LayoutMap The inferred layout mapping for buffers used by this
+ * operator.
+ *
+ * @note This method mutates the AtomicAddNode by creating and storing a
+ * ParallelOp on first invocation.
+ * @throws If both src and dst have layouts in `local.fragment` and their
+ * fragment layouts differ, an ICHECK failure is raised with diagnostic output.
+ */
+LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
+                                     InferLevel level) const {
+  if (!par_op_.defined()) {
+    arith::Analyzer analyzer;
+    par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
+  }
+  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
+    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
+      const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
+      const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
+      if (src_layout && dst_layout) {
+        ICHECK(src_layout->IsEqual(dst_layout, true))
+            << "Get different layout for " << src << " and " << dst
+            << "\nLHS = " << src_layout->DebugOutput()
+            << "\nRHS = " << dst_layout->DebugOutput()
+            << "\nYou may need to use a shared memory to transform the layout";
+      }
+    }
+  }
+  return par_op_->InferLayout(T, level);
+}
+
 /**
  * @brief Lower the atomic-add top-level operator into a parallel, vectorized
  * TIR loop.
@@ -389,70 +406,142 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   }
   auto simt_loop = MakeSIMTLoop(analyzer);
   auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
-  auto par_op = ParallelOp(fused_loop);
-
-  std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
-                                    InferLevel::kFree};
-  for (auto level : levels) {
-    (par_op)->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                           false, T.buffer_remap},
-                          level);
-  }
-  auto loop_layout = par_op->GetLoopLayout();
-  Var thread_var = T.thread_var;
-  Range thread_bounds = T.thread_bounds;
-  auto thread_loop =
-      PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
-  auto vectorized_thread_loop = VectorizeAtomicAdd(
-      thread_loop, thread_var, thread_bounds, GetArchInt(target));
+  auto transformed_loop =
+      Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
+
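+  // Parse the numeric SM version from the target's "arch" attribute
+  // (e.g. "sm_90" -> 90); returns 0 when the attribute is missing or is
+  // not of the form "sm_<N>".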
+  auto GetArchInt = [&](const Target &tgt) -> int {
+    int arch_int = 0;
+    if (auto s = tgt->GetAttr<String>("arch")) {
+      std::string arch = s.value();
+      if (arch.rfind("sm_", 0) == 0)
+        arch_int = std::stoi(arch.substr(3));
+    }
+    return arch_int;
+  };
 
-  if (par_op->GetPredicate(T.thread_var).defined()) {
-    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
-                      vectorized_thread_loop);
-  }
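+  // Walks the SIMT loop nest, collecting the parallel loop itervars and the
+  // index expressions of every local.fragment buffer that is read or written.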
+  struct AtomicLoopNestCollector : tir::StmtExprVisitor {
+    Array<IterVar> loop_vars;
+    Map<Buffer, Array<PrimExpr>> indice_map;
+    std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> writes;
+    arith::Analyzer analyzer;
 
-  return vectorized_thread_loop;
-}
+    void Run(const Stmt &s) { StmtExprVisitor::VisitStmt(s); }
 
-/**
- * @brief Infer and return the layout map for the atomic add operator.
- *
- * Constructs a cached ParallelOp (by building the SIMT loop) if not already
- * present, validates that local.fragment layouts for src and dst match when
- * both are provided, and then delegates layout inference to the underlying
- * ParallelOp.
- *
- * @param T Layout inference inputs, including an optional mapping of buffers
- * to layouts.
- * @param level Inference strictness level.
- * @return LayoutMap The inferred layout mapping for buffers used by this
- * operator.
- *
- * @note This method mutates the AtomicAddNode by creating and storing a
- * ParallelOp on first invocation.
- * @throws If both src and dst have layouts in `local.fragment` and their
- * fragment layouts differ, an ICHECK failure is raised with diagnostic output.
- */
-LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
-                                     InferLevel level) const {
-  if (!par_op_.defined()) {
-    arith::Analyzer analyzer;
-    par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
-  }
-  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
-    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
-      const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
-      const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
-      if (src_layout && dst_layout) {
-        ICHECK(src_layout->IsEqual(dst_layout, true))
-            << "Get different layout for " << src << " and " << dst
-            << "\nLHS = " << src_layout->DebugOutput()
-            << "\nRHS = " << dst_layout->DebugOutput()
-            << "\nYou may need to use a shared memory to transform the layout";
+    void VisitStmt_(const ForNode *op) final {
+      if (op->kind == ForKind::kParallel) {
+        loop_vars.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
+                                    IterVarType::kDataPar));
       }
+      analyzer.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+      StmtExprVisitor::VisitStmt_(op);
     }
-  }
-  return par_op_->InferLayout(T, level);
+    void VisitStmt_(const BufferStoreNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        indice_map.Set(op->buffer, op->indices);
+        writes.insert(op->buffer);
+      }
+      StmtExprVisitor::VisitStmt_(op);
+    }
+    void VisitExpr_(const BufferLoadNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        indice_map.Set(op->buffer, op->indices);
+      }
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  };
+
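+  // Derive a loop Fragment from a fragment buffer's known layout: forward the
+  // buffer indices through the layout's thread mapping (with a fresh
+  // replication variable) and bind the result to the thread bounds.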
+  auto ComputeLoopLayoutFromBuffer =
+      [&](const Buffer &buf, const Array<PrimExpr> &indices,
+          const LayoutMap &layout_map, const Range &thread_bounds,
+          const Array<IterVar> &loop_vars) -> Fragment {
+    Fragment src = layout_map[buf].as<Fragment>().value();
+    Var rep;
+    auto rep_iter =
+        IterVar(Range(0, src->ReplicateExtent()), rep, IterVarType::kDataPar);
+    PrimExpr fth = src->ForwardThread(indices, rep);
+    fth = analyzer->Simplify(fth);
+    Fragment out = Fragment(loop_vars, /*forward_index=*/{}, fth, rep_iter)
+                       ->BindThreadRange(thread_bounds);
+    return out;
+  };
+
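+  // Result of the local layout inference: the chosen loop layout plus an
+  // optional guard predicate produced by the vectorize planner.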
+  struct AtomicInferResult {
+    Fragment loop_layout;
+    Optional<PrimExpr> predicate;
+  };
+
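+  // Infer the loop layout for the atomic-add loop nest: pick a source
+  // fragment buffer when one is available, plan the vector width for the
+  // target architecture, and fall back to a plain loop partition otherwise.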
+  auto AtomicAddInferLayout =
+      [&](const For &loop, const LayoutInferArgs &args) -> AtomicInferResult {
+    AtomicLoopNestCollector C;
+    C.Run(loop);
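+    // Choose the local.fragment buffer with the highest-rank indices (and a
+    // known layout) as the source for layout inference.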
+    Optional<Buffer> read_src;
+    int best_rank = -1;
+    for (auto kv : C.indice_map) {
+      const Buffer &buf = kv.first;
+      if (buf.scope() != "local.fragment")
+        continue;
+      if (!args.layout_map.count(buf))
+        continue;
+      int rank = static_cast<int>(kv.second.size());
+      if (rank > best_rank) {
+        best_rank = rank;
+        read_src = buf;
+      }
+    }
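+    // Plan the vector width for atomic adds on this SM architecture; an
+    // explicit coalesced_width annotation must divide the planned width and
+    // overrides it.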
+    AtomicAddVectorizePlanner planner;
+    int sm = GetArchInt(target);
+    auto plan = planner.Plan(loop, sm);
+    int vec = std::max(plan.vector_size, 1);
+    if (auto cw = loop->annotations.Get("coalesced_width")) {
+      if (const auto *imm = cw->as<IntImmNode>()) {
+        int expected = imm->value;
+        ICHECK_GT(expected, 0);
+        ICHECK(vec % expected == 0)
+            << "vector_size " << vec << " not divisible by coalesced_width "
+            << expected;
+        vec = expected;
+      } else {
+        LOG(FATAL) << "coalesced_width should be IntImmNode.";
+      }
+    }
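+    // Halve the vector width until threads * vec evenly divides the flattened
+    // loop extent, so the thread partition is exact.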
+    PrimExpr total = 1;
+    for (Stmt s = loop; s.as<For>().has_value(); s = s.as<For>().value()->body)
+      total = total * s.as<For>().value()->extent;
+    PrimExpr denom = args.thread_bounds->extent * vec;
+    while (!analyzer->CanProve(floormod(total, denom) == 0) && vec > 1) {
+      vec >>= 1;
+      denom = args.thread_bounds->extent * vec;
+    }
+    if (vec < 1)
+      vec = 1;
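+    // Derive the loop layout from the source fragment buffer if one was
+    // found; otherwise partition the loop over the thread range with the
+    // chosen vector width.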
+    Fragment loop_layout;
+    if (read_src) {
+      loop_layout = ComputeLoopLayoutFromBuffer(
+          read_src.value(), C.indice_map[read_src.value()], args.layout_map,
+          args.thread_bounds, C.loop_vars);
+    } else {
+      const For &remapped = loop;
+      loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
+    }
+
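+    // Carry the planner's boundary condition along as an optional predicate
+    // when the vectorization plan is dynamic.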
+    Optional<PrimExpr> pred;
+    if (plan.dynamic && plan.condition.defined()) {
+      pred = plan.condition;
+    }
+    DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec
+               << " loop_layout=" << loop_layout->DebugOutput();
+    return {loop_layout, pred};
+  };
+
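+  // Infer the loop layout for the transformed loop, partition it over the
+  // thread variable, and vectorize the atomic adds for the detected SM
+  // architecture.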
+  auto ret = AtomicAddInferLayout(transformed_loop,
+                                  {T.target, T.thread_bounds, T.layout_map,
+                                   analyzer, false, T.buffer_remap});
+  Fragment loop_layout = ret.loop_layout;
+  auto thread_loop =
+      PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
+  auto vectorized_thread_loop =
+      VectorizeAtomicAdd(thread_loop, GetArchInt(target));
+  return vectorized_thread_loop;
456545}
457546
458547TIR_REGISTER_TL_OP (AtomicAdd, atomicadd)