Commit c317f26

update
1 parent 8fe3540 commit c317f26

File tree

3 files changed: +304 -275 lines changed

3 files changed

+304
-275
lines changed

src/op/atomic_add.cc

147 additions & 28 deletions
@@ -13,6 +13,7 @@
 #include "../target/utils.h"
 #include "../transform/atomicadd_vectorize.h"
 #include "../transform/common/loop_fusion_utils.h"
+#include "../transform/common/loop_parallel_transform_utils.h"
 #include "../transform/loop_partition.h"
 #include "builtin.h"
 
@@ -225,34 +226,6 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
   }
 }
 
-/**
- * @brief Build a SIMT-style loop nest that performs element-wise atomic
- * additions from src to dst.
- *
- * Constructs a nested loop (parallelized per iter var) that loads a value from
- * the source buffer, optionally casts it to the destination dtype, and performs
- * an extern atomic add into the destination buffer address. For scalar
- * (zero-dimensional) operations a trivial serial For with a single BufferStore
- * is returned.
- *
- * The method:
- * - Creates iter vars for all non-singleton extents and binds them into the
- *   provided analyzer.
- * - Validates loop variable counts against src/dst ranges (ICHECK on mismatch).
- * - Computes indexed accesses and emits optional bound predicates;
- *   out-of-bounds accesses are masked to zero when predicates are uncertain.
- * - Emits an extern `call_extern("AtomicAdd", address_of(dst_value),
- *   src_value)` call wrapped in an Evaluate statement.
- * - Wraps the body with a parallel For at each loop level. If `coalesced_width`
- *   is defined it is attached as the "coalesced_width" annotation on each loop.
- *
- * Note: This function mutates the analyzer binding state by binding loop
- * variables and may fail via ICHECK if internal assumptions about shapes are
- * violated.
- *
- * @return A nested For loop (parallel loops) implementing the atomic-add
- * kernel. For scalar cases a serial For of extent 1 is returned.
- */
 For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   Array<IterVar> loop_vars = MakeIterVars();
   bool is_scalar = loop_vars.empty();
@@ -418,6 +391,152 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
   return par_op_->InferLayout(T, level);
 }
 
+Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
+  Target target = T.target;
+  auto simt_loop = MakeSIMTLoop(analyzer);
+  auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
+  auto transformed_loop =
+      Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
+  LOG(INFO) << transformed_loop;
+
+  auto GetArchInt = [&](const Target &tgt) -> int {
+    int arch_int = 0;
+    if (auto s = tgt->GetAttr<String>("arch")) {
+      std::string arch = s.value();
+      if (arch.rfind("sm_", 0) == 0)
+        arch_int = std::stoi(arch.substr(3));
+    }
+    return arch_int;
+  };
+
+  struct AtomicLoopNestCollector : tir::StmtExprVisitor {
+    Array<IterVar> loop_vars;
+    Map<Buffer, Array<PrimExpr>> indice_map;
+    std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> writes;
+    arith::Analyzer analyzer;
+
+    void Run(const Stmt &s) { StmtExprVisitor::VisitStmt(s); }
+
+    void VisitStmt_(const ForNode *op) final {
+      if (op->kind == ForKind::kParallel) {
+        loop_vars.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
+                                    IterVarType::kDataPar));
+      }
+      analyzer.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+      StmtExprVisitor::VisitStmt_(op);
+    }
+    void VisitStmt_(const BufferStoreNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        indice_map.Set(op->buffer, op->indices);
+        writes.insert(op->buffer);
+      }
+      StmtExprVisitor::VisitStmt_(op);
+    }
+    void VisitExpr_(const BufferLoadNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        indice_map.Set(op->buffer, op->indices);
+      }
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  };
+
+  auto ComputeLoopLayoutFromBuffer =
+      [&](const Buffer &buf, const Array<PrimExpr> &indices,
+          const LayoutMap &layout_map, const Range &thread_bounds,
+          const Array<IterVar> &loop_vars) -> Fragment {
+    Fragment src = layout_map[buf].as<Fragment>().value();
+    Var rep;
+    auto rep_iter =
+        IterVar(Range(0, src->ReplicateExtent()), rep, IterVarType::kDataPar);
+    PrimExpr fth = src->ForwardThread(indices, rep);
+    fth = analyzer->Simplify(fth);
+    Fragment out = Fragment(loop_vars, /*forward_index=*/{}, fth, rep_iter)
+                       ->BindThreadRange(thread_bounds);
+    return out;
+  };
+
+  struct AtomicInferResult {
+    Fragment loop_layout;
+    Optional<PrimExpr> predicate;
+  };
+
+  auto AtomicAddInferLayout =
+      [&](const For &loop, const LayoutInferArgs &args) -> AtomicInferResult {
+    AtomicLoopNestCollector C;
+    C.Run(loop);
+    Optional<Buffer> read_src;
+    int best_rank = -1;
+    for (auto kv : C.indice_map) {
+      const Buffer &buf = kv.first;
+      if (buf.scope() != "local.fragment")
+        continue;
+      if (!args.layout_map.count(buf))
+        continue;
+      int rank = static_cast<int>(kv.second.size());
+      if (rank > best_rank) {
+        best_rank = rank;
+        read_src = buf;
+      }
+    }
+    AtomicAddVectorizePlanner planner;
+    int sm = GetArchInt(target);
+    auto plan = planner.Plan(loop, sm);
+    int vec = std::max(plan.vector_size, 1);
+    if (auto cw = loop->annotations.Get("coalesced_width")) {
+      if (const auto *imm = cw->as<IntImmNode>()) {
+        int expected = imm->value;
+        ICHECK_GT(expected, 0);
+        ICHECK(vec % expected == 0)
+            << "vector_size " << vec << " not divisible by coalesced_width "
+            << expected;
+        vec = expected;
+      } else {
+        LOG(FATAL) << "coalesced_width should be IntImmNode.";
+      }
+    }
+    PrimExpr total = 1;
+    for (Stmt s = loop; s.as<For>().has_value(); s = s.as<For>().value()->body)
+      total = total * s.as<For>().value()->extent;
+    PrimExpr denom = args.thread_bounds->extent * vec;
+    while (!analyzer->CanProve(floormod(total, denom) == 0) && vec > 1) {
+      vec >>= 1;
+      denom = args.thread_bounds->extent * vec;
+    }
+    if (vec < 1)
+      vec = 1;
+    Fragment loop_layout;
+    if (read_src) {
+      loop_layout = ComputeLoopLayoutFromBuffer(
+          read_src.value(), C.indice_map[read_src.value()], args.layout_map,
+          args.thread_bounds, C.loop_vars);
+    } else {
+      For remapped = loop; // simplified handling
+      loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
+    }
+
+    Optional<PrimExpr> pred;
+    if (plan.dynamic && plan.condition.defined()) {
+      pred = plan.condition;
+    }
+    DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec
+               << " loop_layout=" << loop_layout->DebugOutput();
+    return {loop_layout, pred};
+  };
+
+  auto ret = AtomicAddInferLayout(transformed_loop,
+                                  {T.target, T.thread_bounds, T.layout_map,
+                                   analyzer, false, T.buffer_remap});
+  Fragment loop_layout = ret.loop_layout;
+  LOG(INFO) << loop_layout->DebugOutput();
+  auto thread_loop =
+      PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
+  LOG(INFO) << thread_loop;
+  auto vectorized_thread_loop =
+      VectorizeAtomicAdd(thread_loop, GetArchInt(target));
+  LOG(INFO) << vectorized_thread_loop;
+  return vectorized_thread_loop;
+}
+
 TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
     .set_num_inputs(2)
     .set_attr<TCallEffectKind>("TCallEffectKind",
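
For reference, the vector-width selection inside the new AtomicAddInferLayout lambda can be pictured with plain integers. The sketch below is illustrative only: ChooseVectorWidth and its hard-coded extents are hypothetical stand-ins for the PrimExpr arithmetic and the analyzer's CanProve divisibility check used in the commit. It starts from the planner's proposed width, optionally clamps it to the coalesced_width annotation (which must divide the plan, matching the ICHECK), and then halves it until the flattened loop extent divides evenly across threads * vec.

#include <cassert>
#include <iostream>

// Hypothetical stand-in for the vector-width selection in AtomicAddInferLayout:
// plain integers replace the TIR PrimExprs and the analyzer's CanProve check.
int ChooseVectorWidth(long long total_elems, long long num_threads,
                      int planned_vec, int coalesced_width /* 0 = unset */) {
  int vec = planned_vec > 1 ? planned_vec : 1;
  if (coalesced_width > 0) {
    // Mirrors the ICHECK in the commit: the annotation must divide the plan.
    assert(vec % coalesced_width == 0);
    vec = coalesced_width;
  }
  // Halve vec until the flattened extent splits evenly over threads * vec.
  while (vec > 1 && total_elems % (num_threads * vec) != 0) {
    vec >>= 1;
  }
  return vec;
}

int main() {
  // 8192 elements over 128 threads: a planned width of 8 survives,
  // since 8192 % (128 * 8) == 0.
  std::cout << ChooseVectorWidth(8192, 128, 8, 0) << "\n";  // prints 8
  // 96 elements over 64 threads: widths 8, 4 and 2 all fail the divisibility
  // test (96 is not a multiple of 512, 256 or 128), so the width falls to 1.
  std::cout << ChooseVectorWidth(96, 64, 8, 0) << "\n";     // prints 1
  return 0;
}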

0 commit comments
