Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 180 additions & 1 deletion src/op/atomic_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ namespace tl {

using namespace tir;

/**
* @brief Extracts a numeric architecture identifier from a Target's "arch"
* attribute.
*
* Reads the Target's "arch" string (must be defined) and, if it has the form
* "sm_<N>", parses and returns N as an integer. For any other arch string,
* returns 0.
*
* @param target Target whose "arch" attribute will be inspected (ICHECKs that
* the attribute is defined).
* @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
*/
static int GetArchInt(Target target) {
int arch_int = 0;
auto s = target->GetAttr<String>("arch");
Expand All @@ -34,6 +46,25 @@ static int GetArchInt(Target target) {
return arch_int;
}

/**
* @brief Construct an AtomicAdd operator from call arguments and a buffer map.
*
* Builds the internal AtomicAddNode, extracts the source and destination
* regions and their backing Buffers from the first two call-style expressions
* in `args` (via RegionOp), and stores them along with their ranges. If a third
* argument is provided, it is interpreted as an integer immediate and stored as
* the node's coalesced width.
*
* @param args Call-style PrimExprs where:
* - args[0] is the source region call,
* - args[1] is the destination region call,
* - args[2] (optional) is an IntImm specifying coalesced width.
* @param vmap Mapping from buffers used by RegionOp to concrete Buffer objects.
*
* Notes:
* - The constructor checks that args[0] and args[1] are CallNodes.
* - The constructed node is stored in this->data_.
*/
AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
Array<Range> rgs[2];
Expand All @@ -54,6 +85,15 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
data_ = std::move(node);
}

/**
* @brief Create a deep copy of this AtomicAdd node wrapped as a TileOperator.
*
* Produces a new AtomicAddNode object copied from this node. If this node has
* an associated ParallelOp (par_op_), the parallel op is cloned and attached to
* the new node so the cloned operator preserves parallelization state.
*
* @return TileOperator A TileOperator owning the cloned AtomicAddNode.
*/
TileOperator AtomicAddNode::Clone() const {
auto op = make_object<AtomicAddNode>(*this);
if (par_op_.defined()) {
Expand All @@ -62,6 +102,19 @@ TileOperator AtomicAddNode::Clone() const {
return AtomicAdd(op);
}

/**
* @brief Create data-parallel iteration variables for non-singleton dimensions
* of the source.
*
* Constructs an Array of IterVar corresponding to each dimension in `src_range`
* whose extent is not equal to 1. Each IterVar has domain Range(0, extent), a
* Var named sequentially ("i", "j", "k", ...) with the same dtype as the
* extent, and type IterVarType::kDataPar. The ordering of returned itervars
* matches the order of dimensions in `src_range`.
*
* @return Array<IterVar> Iteration variables for all non-singleton extents in
* `src_range`.
*/
Array<IterVar> AtomicAddNode::MakeIterVars() const {
Array<IterVar> loop_vars;
size_t idx = 0;
Expand All @@ -77,7 +130,26 @@ Array<IterVar> AtomicAddNode::MakeIterVars() const {
}

// ivs: itervars returned by MakeIterVars()
// src_dst: 0 for src_indices, 1 for dst_indices
/**
* @brief Build index expressions for either source or destination from loop
* iter vars.
*
* Given a list of iteration variables that correspond to the non-singleton
* extents of the selected region (source when src_dst == 0, destination when
* src_dst == 1), return an array of index expressions matching the full rank of
* that region. For dimensions with extent == 1, the corresponding index is the
* range's minimum; otherwise the index is `min + ivar`.
*
* @param ivs Iteration variables in order for all non-singleton dimensions of
* the chosen region.
* @param src_dst Selects which region to index: 0 for source (src_range), 1 for
* destination (dst_range).
* @return Array<PrimExpr> Index expressions for every dimension of the selected
* region, in original dimension order.
*
* @note The function checks that the number of provided iter vars equals the
* number of non-singleton extents; it will abort (ICHECK) if they differ.
*/
Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
int src_dst) const {
Array<PrimExpr> indices;
Expand All @@ -97,6 +169,31 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
return indices;
}

/**
* @brief Build a combined bound-check predicate for indexed access.
*
* Constructs an AND'd predicate ensuring each non-singleton index (derived from
* `ivs`) stays within [0, extent) for the selected operand (source when
* `src_dst==0`, destination otherwise). For each non-unit Range in the chosen
* range list this produces two conditions:
* - range.min + iv >= 0
* - range.min + iv < extent
*
* Conditions that the analyzer can prove (with symbolic bounds) are omitted.
* If no uncertain conditions remain, an empty PrimExpr is returned.
*
* Note: the function ICHECKs that `extents.size()` equals the number of ranges
* for the selected operand.
*
* @param ivs Iteration variables corresponding to non-singleton extents (order
* matches the non-unit ranges of the chosen operand).
* @param extents Per-dimension upper bounds to check against; must have the
* same size as the selected range list.
* @param src_dst Selects which ranges to validate: 0 => `src_range`, else
* `dst_range`.
* @return PrimExpr A conjunction of remaining (non-provable) bounds checks, or
* an empty PrimExpr when no checks are required.
*/
PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
const Array<IterVar> &ivs,
Array<PrimExpr> extents,
Expand Down Expand Up @@ -128,6 +225,34 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
}
}

/**
* @brief Build a SIMT-style loop nest that performs element-wise atomic
* additions from src to dst.
*
* Constructs a nested loop (parallelized per iter var) that loads a value from
* the source buffer, optionally casts it to the destination dtype, and performs
* an extern atomic add into the destination buffer address. For scalar
* (zero-dimensional) operations a trivial serial For with a single BufferStore
* is returned.
*
* The method:
* - Creates iter vars for all non-singleton extents and binds them into the
* provided analyzer.
* - Validates loop variable counts against src/dst ranges (ICHECK on mismatch).
* - Computes indexed accesses and emits optional bound predicates;
* out-of-bounds accesses are masked to zero when predicates are uncertain.
* - Emits an extern `call_extern("AtomicAdd", address_of(dst_value),
* src_value)` call wrapped in an Evaluate statement.
* - Wraps the body with a parallel For at each loop level. If `coalesced_width`
* is defined it is attached as the "coalesced_width" annotation on each loop.
*
* Note: This function mutates the analyzer binding state by binding loop
* variables and may fail via ICHECK if internal assumptions about shapes are
* violated.
*
* @return A nested For loop (parallel loops) implementing the atomic-add
* kernel. For scalar cases a serial For of extent 1 is returned.
*/
For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Array<IterVar> loop_vars = MakeIterVars();
bool is_scalar = loop_vars.size() == 0;
Expand Down Expand Up @@ -191,6 +316,41 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
return Downcast<For>(body);
}

/**
* @brief Lower the atomic-add top-level operator into a parallel, vectorized
* TIR loop.
*
* Constructs a SIMT-style loop for the atomic-add, fuses parallel loops, runs
* layout inference at multiple levels, partitions the root loop by the provided
* thread variable, vectorizes the thread loop, and returns the final
* (optionally predicate-guarded) statement.
*
* The lowering pipeline:
* - Build the SIMT loop via MakeSIMTLoop.
* - Fuse parallel loops into a single For and wrap as a ParallelOp.
* - Run layout inference at kCommon, kStrict, and kFree levels using fields
* from `T`.
* - Obtain the loop layout, partition the root loop with PartitionLoop by
* `T.thread_var`.
* - Vectorize the partitioned thread loop via VectorizeLoop.
* - If the ParallelOp produced a predicate for `T.thread_var`, return an
* IfThenElse that guards the vectorized loop with that predicate; otherwise
* return the vectorized loop.
*
* @param T Lowering context whose fields are used:
* - T.target: target architecture for layout inference and lowering
* decisions.
* - T.thread_var: the Var used to partition the outer loop for thread-level
* parallelism.
* - T.thread_bounds: bounds associated with the thread dimension (used during
* partitioning).
* - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used
* during InferLayout.
* @param analyzer Analyzer used for symbolic reasoning during partitioning and
* folding (omitted from detailed param docs as a common analysis utility).
* @return Stmt A lowered TIR statement representing the parallelized and
* vectorized atomic-add.
*/
Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
auto simt_loop = MakeSIMTLoop(analyzer);
Expand Down Expand Up @@ -221,6 +381,25 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
}

/**
* @brief Infer and return the layout map for the atomic add operator.
*
* Constructs a cached ParallelOp (by building the SIMT loop) if not already
* present, validates that local.fragment layouts for src and dst match when
* both are provided, and then delegates layout inference to the underlying
* ParallelOp.
*
* @param T Layout inference inputs, including an optional mapping of buffers to
* layouts.
* @param level Inference strictness level.
* @return LayoutMap The inferred layout mapping for buffers used by this
* operator.
*
* @note This method mutates the AtomicAddNode by creating and storing a
* ParallelOp on first invocation.
* @throws If both src and dst have layouts in `local.fragment` and their
* fragment layouts differ, an ICHECK failure is raised with diagnostic output.
*/
LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (!par_op_.defined()) {
Expand Down
73 changes: 73 additions & 0 deletions src/op/atomic_add.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,79 @@
#include "operator.h"
#include "parallel.h"

/**
* Lower this tile operator into a TIR statement for the given lowering context.
*
* @param T Lowering context containing mapped buffers and iteration
* information.
* @param analyzer Arithmetic analyzer used to simplify and reason about
* expressions.
* @return A TIR Stmt that implements the atomic-add tile operation for the
* provided context.
*/
/**
* Infer memory/layout mapping for tensors and buffers used by this operator.
*
* @param T Layout inference context providing buffer and shape information.
* @param level Inference aggressiveness level; higher levels may perform more
* speculative decisions.
* @return A LayoutMap describing inferred layouts for the operator's inputs and
* outputs.
*/
/**
* Get the Op registration that identifies this tile operator.
*
* @return A reference to the registered Op representing this operator.
*/
/**
* Create a deep copy of this tile operator node wrapped as a TileOperator.
*
* @return A TileOperator handle owning a cloned AtomicAddNode.
*/
/**
* Construct a SIMT-style For loop nest (thread/block mapping) appropriate for
* the operator.
*
* @param analyzer Arithmetic analyzer used to simplify loop bounds and
* predicates.
* @return A For loop node representing the SIMT-parallel loop structure.
*/
/**
* Create iteration variables used by this operator's loop nest.
*
* @return An array of IterVar objects describing the loop iteration axes.
*/
/**
* Produce index expressions for either source or destination buffer access
* based on iteration vars.
*
* @param ivs IterVars created by MakeIterVars().
* @param src_dst Selects which indices to produce: 0 for source indices, 1 for
* destination indices.
* @return An array of PrimExpr index expressions suitable for indexing the
* selected buffer.
*/
/**
* Build a predicate expression that guards out-of-bounds or conditional
* accesses for src or dst.
*
* @param analyzer Arithmetic analyzer used to simplify the predicate.
* @param ivs IterVars created by MakeIterVars().
* @param extents The loop extents corresponding to the itervars.
* @param src_dst Selects which side the predicate is for: 0 for source, 1 for
* destination.
* @return A PrimExpr boolean predicate that evaluates to true for valid
* iterations.
*/
/**
* Construct an AtomicAdd tile operator from operation arguments and a buffer
* mapping.
*
* @param args Operation arguments (e.g., values or indices) specific to the
* atomic-add semantics.
* @param vmap Mapping from buffer names to Buffer objects used by this
* operator.
*/
namespace tvm {
namespace tl {

Expand Down
Loading