
Commit 340bfc5

[Bugfix] Fix atomicadd auto vectorize identify var error (#883)
* update
* update
* update
* update
1 parent 4a229dd commit 340bfc5

File tree

3 files changed: +355 additions, -331 deletions

src/op/atomic_add.cc

Lines changed: 173 additions & 84 deletions
```diff
@@ -13,6 +13,7 @@
 #include "../target/utils.h"
 #include "../transform/atomicadd_vectorize.h"
 #include "../transform/common/loop_fusion_utils.h"
+#include "../transform/common/loop_parallel_transform_utils.h"
 #include "../transform/loop_partition.h"
 #include "builtin.h"
 
```
```diff
@@ -21,31 +22,6 @@ namespace tl {
 
 using namespace tir;
 
-/**
- * @brief Extracts a numeric architecture identifier from a Target's "arch"
- * attribute.
- *
- * Reads the Target's "arch" string (must be defined) and, if it has the form
- * "sm_<N>", parses and returns N as an integer. For any other arch string,
- * returns 0.
- *
- * @param target Target whose "arch" attribute will be inspected (ICHECKs that
- * the attribute is defined).
- * @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
- */
-static int GetArchInt(Target target) {
-  int arch_int = 0;
-  auto s = target->GetAttr<String>("arch");
-  ICHECK(s.defined());
-  std::string arch = s.value();
-  if (arch.rfind("sm_", 0) == 0) {
-    arch_int = std::stoi(arch.substr(3));
-  } else {
-    arch_int = 0;
-  }
-  return arch_int;
-}
-
 /**
  * @brief Construct an AtomicAdd operator from call arguments and a buffer map.
  *
```
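Note: the helper removed here is not gone; its parsing logic reappears as a lambda inside `AtomicAddNode::Lower` (see the last hunk), which returns 0 instead of ICHECK-failing when the target has no "arch" attribute. A minimal standalone sketch of the "sm_<N>" parsing, outside of TVM:

```cpp
#include <iostream>
#include <string>

// Standalone sketch of the "sm_<N>" parsing from GetArchInt:
// rfind(prefix, 0) == 0 is the idiomatic C++ prefix test, and
// std::stoi parses the numeric suffix after "sm_".
static int ParseArchInt(const std::string &arch) {
  if (arch.rfind("sm_", 0) == 0)
    return std::stoi(arch.substr(3));
  return 0; // any non-"sm_" arch string maps to 0
}

int main() {
  std::cout << ParseArchInt("sm_90") << '\n';  // prints 90
  std::cout << ParseArchInt("gfx90a") << '\n'; // prints 0
}
```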
```diff
@@ -328,6 +304,47 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   return Downcast<For>(body);
 }
 
+/**
+ * @brief Infer and return the layout map for the atomic add operator.
+ *
+ * Constructs a cached ParallelOp (by building the SIMT loop) if not already
+ * present, validates that local.fragment layouts for src and dst match when
+ * both are provided, and then delegates layout inference to the underlying
+ * ParallelOp.
+ *
+ * @param T Layout inference inputs, including an optional mapping of buffers to
+ * layouts.
+ * @param level Inference strictness level.
+ * @return LayoutMap The inferred layout mapping for buffers used by this
+ * operator.
+ *
+ * @note This method mutates the AtomicAddNode by creating and storing a
+ * ParallelOp on first invocation.
+ * @throws If both src and dst have layouts in `local.fragment` and their
+ * fragment layouts differ, an ICHECK failure is raised with diagnostic output.
+ */
+LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
+                                     InferLevel level) const {
+  if (!par_op_.defined()) {
+    arith::Analyzer analyzer;
+    par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
+  }
+  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
+    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
+      const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
+      const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
+      if (src_layout && dst_layout) {
+        ICHECK(src_layout->IsEqual(dst_layout, true))
+            << "Get different layout for " << src << " and " << dst
+            << "\nLHS = " << src_layout->DebugOutput()
+            << "\nRHS = " << dst_layout->DebugOutput()
+            << "\nYou may need to use a shared memory to transform the layout";
+      }
+    }
+  }
+  return par_op_->InferLayout(T, level);
+}
+
 /**
  * @brief Lower the atomic-add top-level operator into a parallel, vectorized
  * TIR loop.
```
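Note: this is the same `InferLayout` that the last hunk deletes from its old position; the definition is moved up, unchanged, to sit next to `MakeSIMTLoop`. Its `@note` (a `const` method that stores a `ParallelOp` on first invocation) implies `par_op_` is a mutable cache member. A minimal standalone sketch of that lazy-initialization idiom, using hypothetical stand-in types rather than the real `AtomicAddNode`/`ParallelOp`:

```cpp
#include <iostream>
#include <memory>
#include <string>

// Hypothetical stand-in for the cached helper object (not the real ParallelOp).
struct ParallelOpLike {
  std::string layout() const { return "fragment<...>"; }
};

struct AtomicAddNodeLike {
  // `mutable` lets a logically-const query build its cache on first use.
  mutable std::unique_ptr<ParallelOpLike> par_op_;

  std::string InferLayout() const {
    if (!par_op_)                                   // first invocation only
      par_op_ = std::make_unique<ParallelOpLike>(); // cache for later calls
    return par_op_->layout();                       // delegate to the cached op
  }
};

int main() {
  AtomicAddNodeLike node;
  std::cout << node.InferLayout() << '\n'; // builds the cache
  std::cout << node.InferLayout() << '\n'; // reuses it
}
```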
```diff
@@ -389,70 +406,142 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   }
   auto simt_loop = MakeSIMTLoop(analyzer);
   auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
-  auto par_op = ParallelOp(fused_loop);
-
-  std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
-                                    InferLevel::kFree};
-  for (auto level : levels) {
-    (par_op)->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                           false, T.buffer_remap},
-                          level);
-  }
-  auto loop_layout = par_op->GetLoopLayout();
-  Var thread_var = T.thread_var;
-  Range thread_bounds = T.thread_bounds;
-  auto thread_loop =
-      PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
-  auto vectorized_thread_loop = VectorizeAtomicAdd(
-      thread_loop, thread_var, thread_bounds, GetArchInt(target));
+  auto transformed_loop =
+      Downcast<For>(ParallelLoopTransformer::Substitute(fused_loop));
+
+  auto GetArchInt = [&](const Target &tgt) -> int {
+    int arch_int = 0;
+    if (auto s = tgt->GetAttr<String>("arch")) {
+      std::string arch = s.value();
+      if (arch.rfind("sm_", 0) == 0)
+        arch_int = std::stoi(arch.substr(3));
+    }
+    return arch_int;
+  };
 
-  if (par_op->GetPredicate(T.thread_var).defined()) {
-    return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
-                      vectorized_thread_loop);
-  }
+  struct AtomicLoopNestCollector : tir::StmtExprVisitor {
+    Array<IterVar> loop_vars;
+    Map<Buffer, Array<PrimExpr>> indice_map;
+    std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> writes;
+    arith::Analyzer analyzer;
 
-  return vectorized_thread_loop;
-}
+    void Run(const Stmt &s) { StmtExprVisitor::VisitStmt(s); }
 
-/**
- * @brief Infer and return the layout map for the atomic add operator.
- *
- * Constructs a cached ParallelOp (by building the SIMT loop) if not already
- * present, validates that local.fragment layouts for src and dst match when
- * both are provided, and then delegates layout inference to the underlying
- * ParallelOp.
- *
- * @param T Layout inference inputs, including an optional mapping of buffers to
- * layouts.
- * @param level Inference strictness level.
- * @return LayoutMap The inferred layout mapping for buffers used by this
- * operator.
- *
- * @note This method mutates the AtomicAddNode by creating and storing a
- * ParallelOp on first invocation.
- * @throws If both src and dst have layouts in `local.fragment` and their
- * fragment layouts differ, an ICHECK failure is raised with diagnostic output.
- */
-LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
-                                     InferLevel level) const {
-  if (!par_op_.defined()) {
-    arith::Analyzer analyzer;
-    par_op_ = ParallelOp(MakeSIMTLoop(&analyzer));
-  }
-  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
-    if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") {
-      const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
-      const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
-      if (src_layout && dst_layout) {
-        ICHECK(src_layout->IsEqual(dst_layout, true))
-            << "Get different layout for " << src << " and " << dst
-            << "\nLHS = " << src_layout->DebugOutput()
-            << "\nRHS = " << dst_layout->DebugOutput()
-            << "\nYou may need to use a shared memory to transform the layout";
+    void VisitStmt_(const ForNode *op) final {
+      if (op->kind == ForKind::kParallel) {
+        loop_vars.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
+                                    IterVarType::kDataPar));
       }
+      analyzer.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
+      StmtExprVisitor::VisitStmt_(op);
     }
-  }
-  return par_op_->InferLayout(T, level);
+    void VisitStmt_(const BufferStoreNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        indice_map.Set(op->buffer, op->indices);
+        writes.insert(op->buffer);
+      }
+      StmtExprVisitor::VisitStmt_(op);
+    }
+    void VisitExpr_(const BufferLoadNode *op) final {
+      if (op->buffer.scope() == "local.fragment") {
+        indice_map.Set(op->buffer, op->indices);
+      }
+      StmtExprVisitor::VisitExpr_(op);
+    }
+  };
+
+  auto ComputeLoopLayoutFromBuffer =
+      [&](const Buffer &buf, const Array<PrimExpr> &indices,
+          const LayoutMap &layout_map, const Range &thread_bounds,
+          const Array<IterVar> &loop_vars) -> Fragment {
+    Fragment src = layout_map[buf].as<Fragment>().value();
+    Var rep;
+    auto rep_iter =
+        IterVar(Range(0, src->ReplicateExtent()), rep, IterVarType::kDataPar);
+    PrimExpr fth = src->ForwardThread(indices, rep);
+    fth = analyzer->Simplify(fth);
+    Fragment out = Fragment(loop_vars, /*forward_index=*/{}, fth, rep_iter)
+                       ->BindThreadRange(thread_bounds);
+    return out;
+  };
+
+  struct AtomicInferResult {
+    Fragment loop_layout;
+    Optional<PrimExpr> predicate;
+  };
+
+  auto AtomicAddInferLayout =
+      [&](const For &loop, const LayoutInferArgs &args) -> AtomicInferResult {
+    AtomicLoopNestCollector C;
+    C.Run(loop);
+    Optional<Buffer> read_src;
+    int best_rank = -1;
+    for (auto kv : C.indice_map) {
+      const Buffer &buf = kv.first;
+      if (buf.scope() != "local.fragment")
+        continue;
+      if (!args.layout_map.count(buf))
+        continue;
+      int rank = static_cast<int>(kv.second.size());
+      if (rank > best_rank) {
+        best_rank = rank;
+        read_src = buf;
+      }
+    }
+    AtomicAddVectorizePlanner planner;
+    int sm = GetArchInt(target);
+    auto plan = planner.Plan(loop, sm);
+    int vec = std::max(plan.vector_size, 1);
+    if (auto cw = loop->annotations.Get("coalesced_width")) {
+      if (const auto *imm = cw->as<IntImmNode>()) {
+        int expected = imm->value;
+        ICHECK_GT(expected, 0);
+        ICHECK(vec % expected == 0)
+            << "vector_size " << vec << " not divisible by coalesced_width "
+            << expected;
+        vec = expected;
+      } else {
+        LOG(FATAL) << "coalesced_width should be IntImmNode.";
+      }
+    }
+    PrimExpr total = 1;
+    for (Stmt s = loop; s.as<For>().has_value(); s = s.as<For>().value()->body)
+      total = total * s.as<For>().value()->extent;
+    PrimExpr denom = args.thread_bounds->extent * vec;
+    while (!analyzer->CanProve(floormod(total, denom) == 0) && vec > 1) {
+      vec >>= 1;
+      denom = args.thread_bounds->extent * vec;
+    }
+    if (vec < 1)
+      vec = 1;
+    Fragment loop_layout;
+    if (read_src) {
+      loop_layout = ComputeLoopLayoutFromBuffer(
+          read_src.value(), C.indice_map[read_src.value()], args.layout_map,
+          args.thread_bounds, C.loop_vars);
+    } else {
+      const For &remapped = loop;
+      loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
+    }
+
+    Optional<PrimExpr> pred;
+    if (plan.dynamic && plan.condition.defined()) {
+      pred = plan.condition;
+    }
+    DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec
+               << " loop_layout=" << loop_layout->DebugOutput();
+    return {loop_layout, pred};
+  };
+
+  auto ret = AtomicAddInferLayout(transformed_loop,
+                                  {T.target, T.thread_bounds, T.layout_map,
+                                   analyzer, false, T.buffer_remap});
+  Fragment loop_layout = ret.loop_layout;
+  auto thread_loop =
+      PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout);
+  auto vectorized_thread_loop =
+      VectorizeAtomicAdd(thread_loop, GetArchInt(target));
+  return vectorized_thread_loop;
 }
 
 TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
```
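Note: `AtomicLoopNestCollector` follows TVM's `StmtExprVisitor` convention: each override records what it needs (parallel loop vars, per-buffer indices for `local.fragment` buffers, written buffers) and then calls the base-class method so traversal continues into children; `AtomicAddInferLayout` then picks the collected buffer with the highest index rank as the layout source. A simplified standalone analogue of that collect-and-recurse pattern, with a toy AST in place of TIR:

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Toy AST standing in for TIR: a loop nest whose innermost body
// stores through a named buffer with some index expressions.
struct Node {
  virtual ~Node() = default;
};
struct BufferStore : Node {
  std::string buffer;
  std::vector<std::string> indices;
};
struct For : Node {
  std::string loop_var;
  bool parallel = false;
  std::unique_ptr<Node> body;
};

// Analogue of AtomicLoopNestCollector: record parallel loop vars and
// the indices of every buffer touched, recursing into children.
struct Collector {
  std::vector<std::string> loop_vars;
  std::map<std::string, std::vector<std::string>> indice_map;

  void Visit(const Node *n) {
    if (auto *f = dynamic_cast<const For *>(n)) {
      if (f->parallel) loop_vars.push_back(f->loop_var);
      Visit(f->body.get()); // keep descending, like StmtExprVisitor::VisitStmt_
    } else if (auto *s = dynamic_cast<const BufferStore *>(n)) {
      indice_map[s->buffer] = s->indices;
    }
  }
};

int main() {
  auto store = std::make_unique<BufferStore>();
  store->buffer = "dst";
  store->indices = {"i", "j"};
  auto inner = std::make_unique<For>();
  inner->loop_var = "j"; inner->parallel = true; inner->body = std::move(store);
  auto outer = std::make_unique<For>();
  outer->loop_var = "i"; outer->parallel = true; outer->body = std::move(inner);

  Collector c;
  c.Visit(outer.get());
  std::cout << c.loop_vars.size() << " parallel loops, "
            << c.indice_map["dst"].size() << " indices on dst\n"; // 2, 2
}
```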

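Note: the width-selection logic in `AtomicAddInferLayout` starts from the planner's vector size, optionally pins it to the `coalesced_width` annotation, then halves it until the flattened loop extent is provably divisible by `thread_bounds->extent * vec`. A plain-integer sketch of that fallback, assuming static extents so `%` stands in for `CanProve(floormod(...) == 0)`:

```cpp
#include <cassert>
#include <iostream>

// Halve vec until the flattened loop extent divides evenly by
// threads * vec, mirroring the divisibility loop over PrimExprs.
static int ChooseVectorSize(int total, int threads, int planned_vec) {
  int vec = planned_vec > 1 ? planned_vec : 1;
  while (total % (threads * vec) != 0 && vec > 1)
    vec >>= 1; // fall back to the next smaller power of two
  return vec;
}

int main() {
  // 4096 elements over 128 threads: a planned width of 4 survives,
  assert(ChooseVectorSize(4096, 128, 4) == 4);
  // but 4160 elements force the fallback all the way down.
  std::cout << ChooseVectorSize(4160, 128, 4) << '\n'; // prints 1
}
```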