296 changes: 44 additions & 252 deletions src/op/atomic_add.cc
@@ -42,6 +42,9 @@ using namespace tir;
* - The constructed node is stored in this->data_.
*/
AtomicAdd::AtomicAdd(Array<PrimExpr> args, Map<String, ObjectRef> annotations) {
ICHECK(args.size() >= 2)
<< "AtomicAdd expects at least 2 arguments (src, dst), got "
<< args.size();
ObjectPtr<AtomicAddNode> node = tvm::ffi::make_object<AtomicAddNode>();
Array<Range> rgs[2];
Buffer bf[2];
@@ -74,71 +77,29 @@ TileOperator AtomicAddNode::Clone() const {
return AtomicAdd(op);
}

/**
* @brief Create data-parallel iteration variables for non-singleton dimensions
* of the source.
*
* Constructs an Array of IterVar corresponding to each dimension in `src_range`
* whose extent is not equal to 1. Each IterVar has domain Range(0, extent), a
* Var named sequentially ("i", "j", "k", ...) with the same dtype as the
* extent, and type IterVarType::kDataPar. The ordering of returned itervars
* matches the order of dimensions in `src_range`.
*
* @return Array<IterVar> Iteration variables for all non-singleton extents in
* `src_range`.
*/
Array<IterVar> AtomicAddNode::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
for (size_t i = 0; i < src_range.size(); i++) {
if (is_one(src_range[i]->extent))
continue;
Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype);
idx++;
loop_vars.push_back(
{Range(0, src_range[i]->extent), var, IterVarType::kDataPar});
}
return loop_vars;
}
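// Returns the element-wise atomic-add intrinsic; MakeSIMTLoop below emits one
// such call per element of the tile.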
const Op &AtomicAddNode::GetElemOp() const { return atomic_add_elem_op(); }

// ivs: itervars returned by MakeIterVars()
/**
* @brief Build index expressions for either source or destination from loop
* iter vars.
* @brief Get vectorization length based on dst dtype and target SM version.
*
* Given a list of iteration variables that correspond to the non-singleton
* extents of the selected region (source when src_dst == 0, destination when
* src_dst == 1), return an array of index expressions matching the full rank of
* that region. For dimensions with extent == 1, the corresponding index is the
* range's minimum; otherwise the index is `min + ivar`.
* Returns:
* - 2 for float16/bfloat16
* - 4 for float32 on SM >= 90
* - 1 for all other cases
*
* @param ivs Iteration variables in order for all non-singleton dimensions of
* the chosen region.
* @param src_dst Selects which region to index: 0 for source (src_range), 1 for
* destination (dst_range).
* @return Array<PrimExpr> Index expressions for every dimension of the selected
* region, in original dimension order.
*
* @note The function checks that the number of provided iter vars equals the
* number of non-singleton extents; it will abort (ICHECK) if they differ.
* @param target The target architecture, used to query the SM version.
* @return int The vectorization length.
*/
Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices;
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent))
indices.push_back(ranges[i]->min);
else {
indices.push_back(ranges[i]->min + ivs[idx]->var);
idx++;
}
int AtomicAddNode::GetVectorizeLength(Target target) const {
DataType dtype = dst->dtype;
if (dtype.is_float16() || dtype.is_bfloat16()) {
return 2;
}
if (dtype.is_float() && dtype.bits() == 32 &&
TargetHasSMVersionGE(target, 90)) {
return 4;
}
ICHECK(idx == ivs.size())
<< "idx = " << idx << ", ivs.size() = " << ivs.size()
<< "src name = " << src->name << ", dst name = " << dst->name;
return indices;
return 1;
}
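// Illustrative sketch, not part of this file: the same dispatch written as a
// standalone helper, assuming the dtype flags and SM version have already
// been extracted from dst->dtype and the Target. The hardware rationale in
// the comments is an assumption about why these widths were chosen.
inline int VectorizeLengthFor(bool is_fp16_or_bf16, bool is_fp32, int sm) {
  if (is_fp16_or_bf16)
    return 2; // e.g. one packed half2 / bfloat162 atomic per lane
  if (is_fp32 && sm >= 90)
    return 4; // wider fp32 atomics are assumed profitable on SM90+
  return 1;   // scalar atomic add for everything else
}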

std::pair<Array<PrimExpr>, PrimExpr>
@@ -153,62 +114,6 @@ AtomicAddNode::ReturnIndicesAndSize(int src_dst) const {
return {indices, size};
}

/**
* @brief Build a combined bound-check predicate for indexed access.
*
* Constructs an AND'd predicate ensuring each non-singleton index (derived from
* `ivs`) stays within [0, extent) for the selected operand (source when
* `src_dst==0`, destination otherwise). For each non-unit Range in the chosen
* range list this produces two conditions:
* - range.min + iv >= 0
* - range.min + iv < extent
*
* Conditions that the analyzer can prove (with symbolic bounds) are omitted.
* If no uncertain conditions remain, an empty PrimExpr is returned.
*
* Note: the function ICHECKs that `extents.size()` equals the number of ranges
* for the selected operand.
*
* @param ivs Iteration variables corresponding to non-singleton extents (order
* matches the non-unit ranges of the chosen operand).
* @param extents Per-dimension upper bounds to check against; must have the
* same size as the selected range list.
* @param src_dst Selects which ranges to validate: 0 => `src_range`, else
* `dst_range`.
* @return PrimExpr A conjunction of remaining (non-provable) bounds checks, or
* an empty PrimExpr when no checks are required.
*/
PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs,
Array<PrimExpr> extents,
int src_dst) const {
Array<Range> ranges = src_dst == 0 ? src_range : dst_range;
Array<PrimExpr> cond_list;
ICHECK(extents.size() == ranges.size()) << extents << " " << ranges;
size_t idx = 0;
for (size_t i = 0; i < ranges.size(); i++) {
if (is_one(ranges[i]->extent))
continue;
PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i];
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
cond = ranges[i]->min + ivs[idx]->var >= 0;
if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) {
cond_list.push_back(cond);
}
idx++;
}
if (cond_list.empty())
return {};
else {
PrimExpr cond = cond_list[0];
for (size_t i = 1; i < cond_list.size(); i++)
cond = And(cond, cond_list[i]);
return cond;
}
}

/**
* @brief Build a SIMT-style loop nest that performs element-wise atomic
* additions from src to dst.
@@ -225,8 +130,9 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
* - Validates loop variable counts against src/dst ranges (ICHECK on mismatch).
* - Computes indexed accesses and emits optional bound predicates;
* out-of-bounds accesses are masked to zero when predicates are uncertain.
* - Emits an extern `call_extern("AtomicAdd", address_of(dst_value),
* src_value)` call wrapped in an Evaluate statement.
* - Emits a `tl.atomic_add_elem_op` intrinsic call,
*   `Call(dst->dtype, atomic_add_elem_op(), {address_of(dst_value),
*   src_value}, annotations)`, wrapped in an Evaluate statement.
* - Wraps the body with a parallel For at each loop level. If `coalesced_width`
* is defined it is attached as the "coalesced_width" annotation on each loop.
*
@@ -284,7 +190,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
auto annotations = this->annotations;
annotations.erase("use_tma");
Call atomicadd_call =
tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args, annotations);
tvm::tir::Call(dst->dtype, atomic_add_elem_op(), new_args, annotations);

Stmt body = tvm::tir::Evaluate(atomicadd_call);
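// Illustrative only, assuming a 2-D region with non-singleton extents (M, N);
// the nest produced above is expected to look roughly like:
//
//   for (i, 0, M) "parallel":
//     for (j, 0, N) "parallel":
//       if (pred)  // only emitted when bounds cannot be proven in-range
//         tl.atomic_add_elem_op(&dst[dmin0 + i][dmin1 + j],
//                               src[smin0 + i][smin1 + j])
//
// with "coalesced_width", when defined, attached as an annotation on each
// parallel loop.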

@@ -657,140 +563,26 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
}
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));

auto GetArchInt = [&](const Target &tgt) -> int {
int arch_int = 0;
if (auto s = tgt->GetAttr<String>("arch")) {
std::string arch = s.value();
if (arch.rfind("sm_", 0) == 0)
arch_int = std::stoi(arch.substr(3));
}
return arch_int;
};

struct AtomicLoopNestCollector : tir::StmtExprVisitor {
Array<IterVar> loop_vars;
Map<Buffer, Array<PrimExpr>> indice_map;
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> writes;
arith::Analyzer analyzer;

void Run(const Stmt &s) { StmtExprVisitor::VisitStmt(s); }

void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kParallel) {
loop_vars.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
IterVarType::kDataPar));
}
analyzer.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const BufferStoreNode *op) final {
if (IsFragmentBuffer(op->buffer)) {
indice_map.Set(op->buffer, op->indices);
writes.insert(op->buffer);
}
StmtExprVisitor::VisitStmt_(op);
}
void VisitExpr_(const BufferLoadNode *op) final {
if (IsFragmentBuffer(op->buffer)) {
indice_map.Set(op->buffer, op->indices);
}
StmtExprVisitor::VisitExpr_(op);
}
};

auto ComputeLoopLayoutFromBuffer =
[&](const Buffer &buf, const Array<PrimExpr> &indices,
const LayoutMap &layout_map, const Range &thread_bounds,
const Array<IterVar> &loop_vars) -> Fragment {
Fragment src = layout_map[buf].as<Fragment>().value();
Var rep;
auto rep_iter =
IterVar(Range(0, src->ReplicateExtent()), rep, IterVarType::kDataPar);
PrimExpr fth = src->ForwardThread(indices, rep);
fth = analyzer->Simplify(fth);
Fragment out = Fragment(loop_vars, /*forward_index=*/{}, fth, rep_iter)
->BindThreadRange(thread_bounds);
return out;
};

struct AtomicInferResult {
Fragment loop_layout;
Optional<PrimExpr> predicate;
};

auto AtomicAddInferLayout =
[&](const For &loop, const LayoutInferArgs &args) -> AtomicInferResult {
AtomicLoopNestCollector C;
C.Run(loop);
Optional<Buffer> read_src;
int best_rank = -1;
for (auto kv : C.indice_map) {
const Buffer &buf = kv.first;
if (!IsFragmentBuffer(buf))
continue;
if (!args.layout_map.count(buf))
continue;
int rank = static_cast<int>(kv.second.size());
if (rank > best_rank) {
best_rank = rank;
read_src = buf;
}
}
AtomicAddVectorizePlanner planner;
int sm = GetArchInt(target);
auto plan = planner.Plan(loop, sm);
int vec = std::max(plan.vector_size, 1);
if (auto cw = loop->annotations.Get(attr::kCoalescedWidth)) {
if (const auto *imm = cw->as<IntImmNode>()) {
int expected = imm->value;
ICHECK_GT(expected, 0);
ICHECK(vec % expected == 0)
<< "vector_size " << vec << " not divisible by coalesced_width "
<< expected;
vec = expected;
} else {
LOG(FATAL) << "coalesced_width should be IntImmNode.";
}
}
PrimExpr total = 1;
for (Stmt s = loop; s.as<For>().has_value(); s = s.as<For>().value()->body)
total = total * s.as<For>().value()->extent;
PrimExpr denom = args.thread_bounds->extent * vec;
while (!analyzer->CanProve(floormod(total, denom) == 0) && vec > 1) {
vec >>= 1;
denom = args.thread_bounds->extent * vec;
}
if (vec < 1)
vec = 1;
Fragment loop_layout;
if (read_src) {
loop_layout = ComputeLoopLayoutFromBuffer(
read_src.value(), C.indice_map[read_src.value()], args.layout_map,
args.thread_bounds, C.loop_vars);
} else {
const For &remapped = loop;
loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds);
}

Optional<PrimExpr> pred;
if (plan.dynamic && plan.condition.defined()) {
pred = plan.condition;
}
DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec
<< " loop_layout=" << loop_layout->DebugOutput();
return {loop_layout, pred};
};

auto ret =
AtomicAddInferLayout(fused_loop, {T.target, T.thread_bounds, T.layout_map,
analyzer, false, T.buffer_remap});
Fragment loop_layout = ret.loop_layout;
auto thread_loop =
PartitionLoop(fused_loop, T.thread_var, analyzer, loop_layout);
auto vectorized_thread_loop =
VectorizeAtomicAdd(thread_loop, GetArchInt(target));
return vectorized_thread_loop;
auto par_op = ParallelOp(fused_loop);
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
// Give par_op a recommended vectorize size; this only takes effect during
// free layout inference.
for (auto level : levels) {
par_op->InferLayout({T.target,
T.thread_bounds,
T.layout_map,
analyzer,
false,
T.buffer_remap,
{}},
level);
}
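// The inferred Fragment maps loop iterations onto the threads in
// T.thread_bounds; LowerParallelLoop then partitions the fused loop across
// T.thread_var and guards it with the predicate (if any) produced during
// inference.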
auto loop_layout = par_op->GetLoopLayout();
auto lowered_loop =
LowerParallelLoop(fused_loop, loop_layout, T.thread_var, analyzer,
T.layout_map, par_op->GetPredicate(T.thread_var));
return lowered_loop;
}

TIR_REGISTER_TL_TILE_OP(AtomicAdd, atomicadd)