diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc
index acc54e9e0..166e6813d 100644
--- a/src/op/atomic_add.cc
+++ b/src/op/atomic_add.cc
@@ -21,6 +21,18 @@ namespace tl {
 
 using namespace tir;
 
+/**
+ * @brief Extract a numeric architecture identifier from a Target's "arch"
+ * attribute.
+ *
+ * Reads the Target's "arch" string (which must be defined) and, if it has the
+ * form "sm_<N>", parses and returns N as an integer. For any other arch
+ * string, returns 0.
+ *
+ * @param target Target whose "arch" attribute will be inspected (ICHECKs that
+ * the attribute is defined).
+ * @return int Parsed integer suffix when the arch is "sm_<N>", otherwise 0.
+ */
 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
@@ -34,6 +46,25 @@ static int GetArchInt(Target target) {
   return arch_int;
 }
 
+/**
+ * @brief Construct an AtomicAdd operator from call arguments and a buffer map.
+ *
+ * Builds the internal AtomicAddNode, extracts the source and destination
+ * regions and their backing Buffers from the first two call-style expressions
+ * in `args` (via RegionOp), and stores them along with their ranges. If a
+ * third argument is provided, it is interpreted as an integer immediate and
+ * stored as the node's coalesced width.
+ *
+ * @param args Call-style PrimExprs where:
+ *   - args[0] is the source region call,
+ *   - args[1] is the destination region call,
+ *   - args[2] (optional) is an IntImm specifying the coalesced width.
+ * @param vmap Mapping from buffers used by RegionOp to concrete Buffer
+ * objects.
+ *
+ * Notes:
+ * - The constructor checks that args[0] and args[1] are CallNodes.
+ * - The constructed node is stored in this->data_.
+ */
 AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<AtomicAddNode> node = make_object<AtomicAddNode>();
   Array<Range> rgs[2];
@@ -54,6 +85,15 @@ AtomicAdd::AtomicAdd(Array<PrimExpr> args, BufferMap vmap) {
   data_ = std::move(node);
 }
 
+/**
+ * @brief Create a deep copy of this AtomicAdd node wrapped as a TileOperator.
+ *
+ * Produces a new AtomicAddNode object copied from this node. If this node has
+ * an associated ParallelOp (par_op_), the parallel op is cloned and attached
+ * to the new node so the cloned operator preserves parallelization state.
+ *
+ * @return TileOperator A TileOperator owning the cloned AtomicAddNode.
+ */
 TileOperator AtomicAddNode::Clone() const {
   auto op = make_object<AtomicAddNode>(*this);
   if (par_op_.defined()) {
@@ -62,6 +102,19 @@ TileOperator AtomicAddNode::Clone() const {
   return AtomicAdd(op);
 }
 
+/**
+ * @brief Create data-parallel iteration variables for the non-singleton
+ * dimensions of the source.
+ *
+ * Constructs an Array of IterVars corresponding to each dimension in
+ * `src_range` whose extent is not equal to 1. Each IterVar has domain
+ * Range(0, extent), a Var named sequentially ("i", "j", "k", ...) with the
+ * same dtype as the extent, and type IterVarType::kDataPar. The ordering of
+ * the returned itervars matches the order of dimensions in `src_range`.
+ *
+ * @return Array<IterVar> Iteration variables for all non-singleton extents in
+ * `src_range`.
+ */
 Array<IterVar> AtomicAddNode::MakeIterVars() const {
   Array<IterVar> loop_vars;
   size_t idx = 0;
@@ -77,7 +130,26 @@ Array<IterVar> AtomicAddNode::MakeIterVars() const {
 }
 
 // ivs: itervars returned by MakeIterVars()
-// src_dst: 0 for src_indices, 1 for dst_indices
+/**
+ * @brief Build index expressions for either the source or destination from
+ * loop iter vars.
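+ *
+ * Example (illustrative, not from the source): with
+ * src_range = {Range(r0, 1), Range(c0, 16)} and ivs = {i},
+ * MakeIndices(ivs, 0) yields {r0, c0 + i}; the unit dimension keeps its
+ * minimum while the non-unit dimension is offset by the loop variable.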
+ *
+ * Given a list of iteration variables that correspond to the non-singleton
+ * extents of the selected region (source when src_dst == 0, destination when
+ * src_dst == 1), return an array of index expressions matching the full rank
+ * of that region. For dimensions with extent == 1, the corresponding index is
+ * the range's minimum; otherwise the index is `min + ivar`.
+ *
+ * @param ivs Iteration variables in order for all non-singleton dimensions of
+ * the chosen region.
+ * @param src_dst Selects which region to index: 0 for source (src_range), 1
+ * for destination (dst_range).
+ * @return Array<PrimExpr> Index expressions for every dimension of the
+ * selected region, in original dimension order.
+ *
+ * @note The function checks that the number of provided iter vars equals the
+ * number of non-singleton extents; it aborts (ICHECK) if they differ.
+ */
 Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
                                            int src_dst) const {
   Array<PrimExpr> indices;
@@ -97,6 +169,31 @@ Array<PrimExpr> AtomicAddNode::MakeIndices(const Array<IterVar> &ivs,
   return indices;
 }
 
+/**
+ * @brief Build a combined bound-check predicate for indexed access.
+ *
+ * Constructs an AND'd predicate ensuring each non-singleton index (derived
+ * from `ivs`) stays within [0, extent) for the selected operand (source when
+ * `src_dst == 0`, destination otherwise). For each non-unit Range in the
+ * chosen range list this produces two conditions:
+ *   - range.min + iv >= 0
+ *   - range.min + iv < extent
+ *
+ * Conditions that the analyzer can prove (with symbolic bounds) are omitted.
+ * If no uncertain conditions remain, an empty PrimExpr is returned.
+ *
+ * Note: the function ICHECKs that `extents.size()` equals the number of
+ * ranges for the selected operand.
+ *
+ * @param ivs Iteration variables corresponding to non-singleton extents
+ * (their order matches the non-unit ranges of the chosen operand).
+ * @param extents Per-dimension upper bounds to check against; must have the
+ * same size as the selected range list.
+ * @param src_dst Selects which ranges to validate: 0 => `src_range`, else
+ * `dst_range`.
+ * @return PrimExpr A conjunction of the remaining (non-provable) bounds
+ * checks, or an empty PrimExpr when no checks are required.
+ */
 PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
                                       const Array<IterVar> &ivs,
                                       Array<PrimExpr> extents,
@@ -128,6 +225,34 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer,
   }
 }
 
+/**
+ * @brief Build a SIMT-style loop nest that performs element-wise atomic
+ * additions from src to dst.
+ *
+ * Constructs a nested loop (parallelized per iter var) that loads a value
+ * from the source buffer, optionally casts it to the destination dtype, and
+ * performs an extern atomic add into the destination buffer address. For
+ * scalar (zero-dimensional) operations a trivial serial For with a single
+ * BufferStore is returned.
+ *
+ * The method:
+ * - Creates iter vars for all non-singleton extents and binds them into the
+ *   provided analyzer.
+ * - Validates loop variable counts against src/dst ranges (ICHECK on
+ *   mismatch).
+ * - Computes indexed accesses and emits optional bound predicates;
+ *   out-of-bounds accesses are masked to zero when predicates are uncertain.
+ * - Emits an extern `call_extern("AtomicAdd", address_of(dst_value),
+ *   src_value)` call wrapped in an Evaluate statement.
+ * - Wraps the body with a parallel For at each loop level. If
+ *   `coalesced_width` is defined, it is attached as the "coalesced_width"
+ *   annotation on each loop.
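+ *
+ * Shape of the emitted nest (illustrative pseudo-TIR, assuming a 2-D
+ * region; the index helpers di/si are placeholders, not source functions):
+ * @code
+ * for (i, 0, e0) "parallel":
+ *   for (j, 0, e1) "parallel":
+ *     AtomicAdd(address_of(dst[di(i, j)]), src[si(i, j)])
+ * @endcode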
+ *
+ * Note: This function mutates the analyzer binding state by binding loop
+ * variables and may fail via ICHECK if internal assumptions about shapes are
+ * violated.
+ *
+ * @return A nested For loop (parallel loops) implementing the atomic-add
+ * kernel. For scalar cases a serial For of extent 1 is returned.
+ */
 For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   Array<IterVar> loop_vars = MakeIterVars();
   bool is_scalar = loop_vars.size() == 0;
@@ -191,6 +316,41 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   return Downcast<For>(body);
 }
 
+/**
+ * @brief Lower the atomic-add top-level operator into a parallel, vectorized
+ * TIR loop.
+ *
+ * Constructs a SIMT-style loop for the atomic-add, fuses the parallel loops,
+ * runs layout inference at multiple levels, partitions the root loop by the
+ * provided thread variable, vectorizes the thread loop, and returns the final
+ * (optionally predicate-guarded) statement.
+ *
+ * The lowering pipeline:
+ * - Build the SIMT loop via MakeSIMTLoop.
+ * - Fuse the parallel loops into a single For and wrap it as a ParallelOp.
+ * - Run layout inference at the kCommon, kStrict, and kFree levels using
+ *   fields from `T`.
+ * - Obtain the loop layout and partition the root loop with PartitionLoop by
+ *   `T.thread_var`.
+ * - Vectorize the partitioned thread loop via VectorizeLoop.
+ * - If the ParallelOp produced a predicate for `T.thread_var`, return an
+ *   IfThenElse that guards the vectorized loop with that predicate; otherwise
+ *   return the vectorized loop.
+ *
+ * @param T Lowering context whose fields are used:
+ *   - T.target: target architecture for layout inference and lowering
+ *     decisions.
+ *   - T.thread_var: the Var used to partition the outer loop for thread-level
+ *     parallelism.
+ *   - T.thread_bounds: bounds associated with the thread dimension (used
+ *     during partitioning).
+ *   - T.layout_map, T.buffer_remap: layout and buffer remapping inputs used
+ *     during InferLayout.
+ * @param analyzer Analyzer used for symbolic reasoning during partitioning
+ * and simplification.
+ * @return Stmt A lowered TIR statement representing the parallelized and
+ * vectorized atomic-add.
+ */
 Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   Target target = T.target;
   auto simt_loop = MakeSIMTLoop(analyzer);
@@ -221,6 +381,25 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   return vectorized_thread_loop;
 }
 
+/**
+ * @brief Infer and return the layout map for the atomic-add operator.
+ *
+ * Constructs a cached ParallelOp (by building the SIMT loop) if one is not
+ * already present, validates that the local.fragment layouts of src and dst
+ * match when both are provided, and then delegates layout inference to the
+ * underlying ParallelOp.
+ *
+ * @param T Layout inference inputs, including an optional mapping of buffers
+ * to layouts.
+ * @param level Inference strictness level.
+ * @return LayoutMap The inferred layout mapping for buffers used by this
+ * operator.
+ *
+ * @note This method mutates the AtomicAddNode by creating and storing a
+ * ParallelOp on first invocation.
+ * @note If both src and dst have layouts in `local.fragment` and their
+ * fragment layouts differ, an ICHECK failure is raised with diagnostic
+ * output.
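+ *
+ * Typical call sequence during lowering (a sketch only; the argument struct
+ * is abbreviated and illustrative, not the exact LayoutInferArgs layout):
+ * @code
+ * for (auto level : {InferLevel::kCommon, InferLevel::kStrict,
+ *                    InferLevel::kFree}) {
+ *   layout_map = op->InferLayout({target, thread_bounds, layout_map}, level);
+ * }
+ * @endcode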
+ */ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { if (!par_op_.defined()) { diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index 678d62e55..d35422ee2 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -10,6 +10,79 @@ #include "operator.h" #include "parallel.h" +/** + * Lower this tile operator into a TIR statement for the given lowering context. + * + * @param T Lowering context containing mapped buffers and iteration + * information. + * @param analyzer Arithmetic analyzer used to simplify and reason about + * expressions. + * @return A TIR Stmt that implements the atomic-add tile operation for the + * provided context. + */ +/** + * Infer memory/layout mapping for tensors and buffers used by this operator. + * + * @param T Layout inference context providing buffer and shape information. + * @param level Inference aggressiveness level; higher levels may perform more + * speculative decisions. + * @return A LayoutMap describing inferred layouts for the operator's inputs and + * outputs. + */ +/** + * Get the Op registration that identifies this tile operator. + * + * @return A reference to the registered Op representing this operator. + */ +/** + * Create a deep copy of this tile operator node wrapped as a TileOperator. + * + * @return A TileOperator handle owning a cloned AtomicAddNode. + */ +/** + * Construct a SIMT-style For loop nest (thread/block mapping) appropriate for + * the operator. + * + * @param analyzer Arithmetic analyzer used to simplify loop bounds and + * predicates. + * @return A For loop node representing the SIMT-parallel loop structure. + */ +/** + * Create iteration variables used by this operator's loop nest. + * + * @return An array of IterVar objects describing the loop iteration axes. + */ +/** + * Produce index expressions for either source or destination buffer access + * based on iteration vars. + * + * @param ivs IterVars created by MakeIterVars(). + * @param src_dst Selects which indices to produce: 0 for source indices, 1 for + * destination indices. + * @return An array of PrimExpr index expressions suitable for indexing the + * selected buffer. + */ +/** + * Build a predicate expression that guards out-of-bounds or conditional + * accesses for src or dst. + * + * @param analyzer Arithmetic analyzer used to simplify the predicate. + * @param ivs IterVars created by MakeIterVars(). + * @param extents The loop extents corresponding to the itervars. + * @param src_dst Selects which side the predicate is for: 0 for source, 1 for + * destination. + * @return A PrimExpr boolean predicate that evaluates to true for valid + * iterations. + */ +/** + * Construct an AtomicAdd tile operator from operation arguments and a buffer + * mapping. + * + * @param args Operation arguments (e.g., values or indices) specific to the + * atomic-add semantics. + * @param vmap Mapping from buffer names to Buffer objects used by this + * operator. + */ namespace tvm { namespace tl { diff --git a/src/op/copy.cc b/src/op/copy.cc index 49261176a..3c1a15a38 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -107,10 +107,26 @@ template static Array ReverseArray(Array array) { } /*! - * \brief Constructor for Copy operator. - * \param args Array of PrimExpr representing the arguments of the copy - * operation. \param vmap BufferMap mapping original buffer names to new buffer - * names. + * \brief Construct a Copy operator node from call arguments and a buffer map. 
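+ *
+ * Illustrative argument tuple (hypothetical frontend call; the exact syntax
+ * is not taken from this source):
+ * \code
+ * tl.copy(region(src, "r"), region(dst, "w"), 8, false, 0)
+ * // args[2] = coalesced width, args[3] = disable_tma, args[4] = eviction
+ * \endcode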
+ * + * This constructor parses the first two entries of `args` as Call nodes + * describing source and destination Regions (via RegionOp), extracts their + * Buffers and Ranges, and stores them on the newly created CopyNode. It also + * reads optional arguments: + * - args[2] (IntImm): coalesced width (stored only if > 0), + * - args[3] (Bool): disable TMA lowering flag, + * - args[4] (IntImm): eviction policy. + * + * Preconditions: + * - `args` must contain at least two Call-compatible PrimExpr entries + * describing regions; an ICHECK will fail if they are not CallNodes. + * + * @param args Array of PrimExpr where: + * - args[0] is the source Region call, + * - args[1] is the destination Region call, + * - optional args[2..4] are coalesced width, disable_tma, and eviction + * policy. + * @param vmap BufferMap used to resolve RegionOp buffers and ranges. */ Copy::Copy(Array args, BufferMap vmap) { ObjectPtr node = make_object(); @@ -141,6 +157,16 @@ Copy::Copy(Array args, BufferMap vmap) { data_ = std::move(node); } +/** + * @brief Create a shallow clone of this CopyNode as a TileOperator. + * + * Produces a new CopyNode object copy-constructed from this node. If a parallel + * sub-operation (par_op_) is present, the sub-operation is cloned as well and + * attached to the new node. The returned value is a TileOperator wrapper + * around the newly created node. + * + * @return TileOperator A TileOperator owning the cloned CopyNode. + */ TileOperator CopyNode::Clone() const { auto op = make_object(*this); if (par_op_.defined()) { @@ -197,14 +223,27 @@ Array CopyNode::MakeIndices(const Array &ivs, return indices; } -/*! - * \brief Create predicate for the copy operation. - * This function generates boundary checks to ensure memory access safety. - * It creates conditions like (min + iv) < extent and (min + iv) >= 0 for each - * dimension. \param analyzer Arithmetic analyzer for simplification. \param ivs - * Array of IterVar. \param extents Array of PrimExpr representing the extents - * of the copy operation. \param src_dst 0 for src_indices, 1 for dst_indices. - * \return PrimExpr representing the predicate for the copy operation. +/** + * @brief Build a boundary predicate that guards memory accesses for the copy. + * + * Constructs a conjunction of per-dimension bounds checks (e.g. `min + iv < + * extent` and `min + iv >= 0`) for every dynamic dimension involved in the + * copy. Uses the provided arithmetic analyzer to elide checks that can be + * proven statically. + * + * The function ICHECKs that the supplied `extents` align with the operator's + * recorded ranges for the selected side (source when `src_dst == 0`, + * destination when `src_dst == 1`). + * + * @param ivs IterVars corresponding to the varying dimensions of the copy. Each + * IterVar maps to a non-unit extent dimension in the stored ranges. + * @param extents Extents of the tensor being accessed (must match the number of + * ranges); used as the upper bounds for generated checks. + * @param src_dst Selects which side's ranges to use: `0` for source, `1` for + * destination. + * @return PrimExpr A conjunction of necessary bounds checks, or an empty + * `PrimExpr` (null) if all checks are provably true and no predicate is + * required. */ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer, const Array &ivs, @@ -236,13 +275,25 @@ PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer, } } -/*! - * \brief Create SIMT loop for the copy operation. 
- * This function generates a single-threaded loop structure for the copy - * operation. It handles scalar copies (single element) and multi-dimensional - * copies with nested loops. \param analyzer Arithmetic analyzer for - * simplification. \return For representing the SIMT loop for the copy - * operation. +/** + * @brief Construct a SIMT-style nested loop that implements the copy. + * + * Builds a loop nest that performs element-wise loads from the source buffer + * and stores into the destination buffer. For a scalar copy (no varying + * iteration dimensions) this returns a single serial loop executing one + * store. For multi-dimensional copies it: + * - creates data-parallel loops (Parallel For) for each varying dimension, + * - binds the resulting iteration variables to the provided arithmetic + * analyzer for simplification, + * - computes source and destination index expressions, + * - applies per-buffer boundary predicates (if needed) to mask out-of-range + * accesses, + * - inserts a cast when src and dst dtypes differ, + * - applies an optional `coalesced_width` annotation to generated parallel + * loops when present. + * + * @param analyzer Analyzer used to simplify and bind loop variable domains. + * @return For A nested For statement representing the generated SIMT loop nest. */ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { Array loop_vars = MakeIterVars(); @@ -291,14 +342,19 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { return Downcast(body); } -/*! - * \brief Compute linear layout for TMA copy. - * This function creates a linear layout transformation for shared memory in TMA - * operations. It transforms multi-dimensional indices into a linear address - * using a 256-element block pattern. The transformation follows: [i, j] -> - * [i//256, j//256, i%256, j%256] \param shared_tensor Buffer representing the - * shared tensor. \return Layout representing the linear layout for the TMA - * copy. +/** + * @brief Compute a linearized shared-memory layout used for TMA transfers. + * + * Creates a Layout that maps an N-D shared tensor into a 1-D-like ordering + * suitable for TMA by blocking each dimension into 256-element tiles and + * splitting each original index into a quotient and remainder. Effectively + * transforms each index i_k into two coordinates: floor(i_k / 256) and + * i_k % 256, producing an ordering equivalent to concatenating all quotients + * followed by all remainders. + * + * @param shared_tensor The shared-memory buffer whose shape defines the input + * dimensions for the layout inference. + * @return Layout A Layout describing the linearized ordering for the TMA copy. */ Layout CopyNode::ComputeLinearLayout(const Buffer &shared_tensor) const { Array input_size = shared_tensor->shape; @@ -317,15 +373,27 @@ Layout CopyNode::ComputeLinearLayout(const Buffer &shared_tensor) const { return Layout(input_size, forward_index); } -/*! - * \brief Infer layout for the copy operation. - * This function determines the optimal memory layout for the copy operation - * based on the target architecture. For bulk load/store operations, it may - * apply swizzling layouts for better performance. For LDSM/STSM operations, it - * uses register layout inference from the underlying parallel op. \param T - * LayoutInferArgs containing target and layout map. \param level InferLevel - * indicating the level of layout inference. \return LayoutMap containing the - * inferred layout. +/** + * @brief Infer memory layouts for this Copy operation. 
+ *
+ * Determines an appropriate LayoutMap for the copy based on the target and
+ * the enabled lowering paths. For TMA-capable targets, when the chosen copy
+ * instruction is BulkLoad or BulkStore, this may produce a linearized shared
+ * memory layout suitable for TMA transfers (only when inference is invoked at
+ * InferLevel::kFree and no layout for the shared buffer is already
+ * annotated). For other cases (including LDSM/STSM and the normal copy path),
+ * layout inference is delegated to the SIMT parallel operation produced by
+ * MakeSIMTLoop().
+ *
+ * This method may read PassContext configuration (kDisableTMALower) and may
+ * lazily construct and cache the parallel operation in par_op_ as a side
+ * effect.
+ *
+ * @param T LayoutInferArgs containing the target and the current layout map.
+ * @param level The inference level controlling how aggressively layouts may
+ * be proposed.
+ * @return LayoutMap mapping buffers to inferred layouts (may be empty if no
+ * additional layouts are suggested).
+ */
 LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
                                 InferLevel level) const {
@@ -361,13 +429,24 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
   }
   return par_op_->InferLayout(T, level);
 }
-/*!
- * \brief Check if the copy operation is a bulk load.
- * This function verifies if the copy operation can be implemented using CUDA's
- * Bulk Load instruction. Requirements include: target supports bulk copy,
- * source is global memory, destination is shared.dyn, and both buffers have the
- * same data type. \param target Target device. \return True if the copy
- * operation is a bulk load, false otherwise.
+/**
+ * @brief Determine whether this CopyNode can be lowered to a Bulk Load (TMA)
+ * instruction.
+ *
+ * The function returns true when all of the following hold:
+ * - the target architecture advertises bulk-copy/TMA support;
+ * - the source buffer resides in global memory;
+ * - the destination buffer resides in shared memory (either "shared" or
+ *   "shared.dyn");
+ * - the source and destination have the same element data type.
+ *
+ * If the source and destination dtypes differ, a warning is logged and the
+ * function returns false (the caller is expected to fall back to a normal
+ * copy).
+ *
+ * @param target The compilation target to query for bulk-copy support.
+ * @return true if the copy can be implemented as a Bulk Load (TMA); false
+ * otherwise.
+ */
 bool CopyNode::CheckBulkLoad(Target target) const {
   // 1. arch must have bulk copy support
@@ -389,13 +468,17 @@
   return true;
 }
 
-/*!
- * \brief Check if the copy operation is a bulk store.
- * This function verifies if the copy operation can be implemented using CUDA's
- * Bulk Store instruction. Requirements include: target supports bulk copy,
- * source is shared.dyn, destination is global memory, and both buffers have the
- * same data type. \param target Target device. \return True if the copy
- * operation is a bulk store, false otherwise.
+/**
+ * @brief Determine whether this CopyNode can be lowered to a CUDA Bulk Store
+ * (TMA store).
+ *
+ * Checks whether the target supports bulk copy, the source buffer is in
+ * shared memory ("shared" or "shared.dyn"), the destination buffer is in
+ * global memory, and both buffers have the same element data type. If the
+ * data types differ, a warning is logged and false is returned.
+ *
+ * @param target Target device/architecture to check for bulk-copy support.
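+ *
+ * Illustrative case (assumed, not from the source): src.scope() ==
+ * "shared.dyn", dst.scope() == "global", matching dtypes, on a
+ * bulk-copy-capable target such as sm_90 -> true.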
+ * @return true if all conditions for a BulkStore are met; false otherwise. */ bool CopyNode::CheckBulkStore(Target target) const { // 1. arch must have bulk copy support @@ -431,12 +514,15 @@ bool CopyNode::CheckLDSMCopy(Target target) const { dst.scope() == "local.fragment"; } -/*! - * \brief Check if the copy operation is a STSM copy. - * This function verifies if the copy operation can be implemented using CUDA's - * Store Matrix (STSM) instruction. Requirements include: target supports - * STMATRIX, source is local.fragment, destination is shared.dyn. \param target - * Target device. \return True if the copy operation is a STSM copy, false +/** + * @brief Determine whether this copy can use the STMATRIX store (STSM) path. + * + * Returns true when the target supports STMATRIX and the source buffer is in + * the `local.fragment` scope while the destination buffer is in shared memory + * (`shared` or `shared.dyn`). + * + * @param target The compilation target to query for STMATRIX support. + * @return true if the copy may be lowered to an STSM instruction; false * otherwise. */ bool CopyNode::CheckSTSMCopy(Target target) const { @@ -444,13 +530,20 @@ bool CopyNode::CheckSTSMCopy(Target target) const { (dst.scope() == "shared.dyn" || dst.scope() == "shared"); } -/*! - * \brief Get the copy instruction type. - * This function determines the most appropriate copy instruction based on the - * target architecture and buffer memory scopes. It checks for specialized - * instructions (TMA, LDSM, STSM) in order of preference, falling back to normal - * copy if no specialized instruction is applicable. \param target Target - * device. \return CopyInst representing the copy instruction type. +/** + * @brief Selects the most specific copy instruction supported for the given + * target and buffers. + * + * Determines which specialized copy lowering to use (TMA bulk load/store, LDSM, + * STSM) based on target capabilities and the memory scopes of the + * source/destination buffers. If TMA lowering is disabled via the flag, + * BulkLoad/BulkStore are not selected. The selection priority is: BulkLoad, + * BulkStore, LDSM, STSM, then Normal (fallback). + * + * @param target The compilation target used to query hardware capabilities. + * @param disable_tma_lower If true, prevents selecting TMA-based bulk + * load/store instructions. + * @return CopyInst The chosen copy instruction enum value. */ CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower) const { // disable_tma_lower is from pass_configs @@ -503,14 +596,23 @@ Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } } -/*! - * \brief Lower the copy operation to a normal copy. - * This function generates standard load/store operations for targets that don't - * support specialized copy instructions. It applies loop fusion, - * parallelization, and vectorization transformations to optimize performance on - * both CPU and GPU targets. \param T LowerArgs containing target and layout - * map. \param analyzer Arithmetic analyzer for simplification. \return Stmt - * representing the normal copy code. +/** + * @brief Lower the copy operator using the generic (non-specialized) path. + * + * Generates standard load/store code paths for targets that cannot or should + * not use specialized copy instructions (TMA, LDSM/STSM). 
Builds a SIMT loop, + * fuses and transforms parallel loops, infers and applies loop layouts on GPU + * targets, partitions by thread, and applies vectorization appropriate to the + * device (CPU or GPU). If a thread-level predicate is required, the resulting + * body is guarded with an IfThenElse. + * + * @param T Lowering context including the target, thread bounds, thread var, + * layout map, and buffer remapping used during layout inference and + * loop partitioning. + * @param analyzer Arithmetic analyzer used to simplify and reason about bounds + * during loop partitioning and predicate construction. + * @return Stmt Lowered statement representing the transformed, vectorized + * normal-copy loop (possibly wrapped in a predicate). */ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, arith::Analyzer *analyzer) const { @@ -547,16 +649,29 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, return vectorized_thread_loop; } -/*! - * \brief Lower the copy operation to LDSM/STSM copy. - * This function generates PTX code for matrix load/store operations - * (LDSM/STSM). It handles 8x8 fragment layout validation, shared memory stride - * checking, and generates optimized matrix transfer instructions for tensor - * cores. Falls back to normal copy if layout constraints are not satisfied. - * \param T LowerArgs containing target and layout map. - * \param analyzer Arithmetic analyzer for simplification. - * \param copy_inst CopyInst representing the copy instruction type. - * \return Stmt representing the LDSM/STSM copy code. +/** + * @brief Lower a Copy operator to LDSM/STSM (warp-level 8x8 matrix) + * instructions. + * + * Lowers a CopyNode into PTX matrix load/store (LDSM/STSM) sequences when the + * access/layouts meet the hardware constraints required by warp-level 8x8 + * fragment transfers (thread-mapped 8x8 fragment layout, 16-byte contiguous + * shared memory accesses, full-range local tiles, matching dtypes for loads, + * and no access predicates). If these conditions are not met the function + * falls back to lowering via LowerNormalCopy(). + * + * The routine validates layout/thread-mapping compatibility (including support + * for transposed fragment layouts), determines vectorization factor (4/2/1) + * based on extent alignment, computes shared/local addresses, emits the + * appropriate ptx_ldmatrix/ptx_stmatrix call(s), and wraps them in a small + * loop that may be unrolled and adjusted for thread-bounds offsets. + * + * @param T Lowering context (target, layout/ buffer remaps, thread/ bounds). + * @param analyzer Arithmetic analyzer used to simplify and prove bounds. + * @param copy_inst Must be either CopyInst::kLDSM or CopyInst::kSTSM to select + * matrix-load vs matrix-store lowering. + * @return Stmt A statement implementing the LDSM/STSM lowering, or the result + * of LowerNormalCopy(...) when constraints require fallback. */ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, CopyInst copy_inst) const { @@ -740,16 +855,31 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, return for_node; } -/*! - * \brief Lower the copy operation to bulk copy using TMA. - * This function generates PTX code for Tensor Memory Accelerator (TMA) bulk - * copy operations. It creates TMA descriptors, handles shared memory layout - * detection (including swizzling), and generates optimized bulk load/store - * instructions for Hopper architecture. Falls back to normal copy if layout or - * shape constraints are not satisfied. 
\param T LowerArgs containing target and - * layout map. \param analyzer Arithmetic analyzer for simplification. \param - * copy_inst CopyInst representing the copy instruction type. \return Stmt - * representing the bulk copy code. +/** + * @brief Lower a Copy operator to a bulk TMA (Tensor Memory Accelerator) + * transfer. + * + * Lowers the copy to an optimized TMA load or store when the target and buffer + * layouts permit. Constructs a TMADesc, detects shared-memory + * swizzle/interleave patterns, encodes global shape/stride/SMEM parameters, and + * emits either a 1D TMA transfer (when global/shared are contiguous and element + * counts match, currently only for loads) or a full multi-dimensional TMA call. + * The emitted statement is guarded so only the thread with min thread id + * executes the TMA. + * + * If preconditions are not satisfied (unsupported swizzle, stride/size limits, + * mismatched element counts, OOB risks, or other hardware constraints), this + * function falls back to LowerNormalCopy. + * + * @param T LowerArgs containing target information, thread/bounds variables, + * and layout/ buffer remap information used for descriptor + * construction. + * @param analyzer Analyzer used to prove shapes/contiguity/equality + * constraints. + * @param copy_inst Indicates whether to emit a BulkLoad (TMA load) or BulkStore + * (TMA store). Must be CopyInst::kBulkLoad or kBulkStore. + * @return Stmt A TIR statement performing the bulk TMA copy (or the result of + * LowerNormalCopy when falling back). */ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, CopyInst copy_inst) const { @@ -1154,13 +1284,28 @@ Array TMADesc::EncodeCallArgs() const { return args; } -/*! - * \brief Constructor for Conv2DIm2ColOp. - * This operation performs im2col transformation for 2D convolution on GPU using - * TMA. It extracts patches from the input tensor and rearranges them for - * efficient matrix multiplication. \param args Array of PrimExpr representing - * the arguments of the Conv2DIm2ColOp. \param vmap BufferMap mapping original - * buffer names to new buffer names. +/** + * @brief Construct a Conv2DIm2ColOp node. + * + * Initializes a Conv2DIm2ColOpNode from raw TL-call arguments and a buffer map. + * The constructor extracts source and destination Buffers from vmap and reads + * convolution parameters encoded in args: + * - args[0]: source tensor access pointer + * - args[1]: destination tensor access pointer + * - args[2]: nhw_step (PrimExpr) + * - args[3]: c_step (PrimExpr) + * - args[4]: kernel (IntImm) + * - args[5]: stride (IntImm) + * - args[6]: dilation (IntImm) + * - args[7]: padding (IntImm) + * - args[8]: eviction_policy (IntImm) + * + * The created node stores these values (src, dst, nhw_step, c_step, kernel, + * stride, dilation, padding, eviction_policy) for later lowering to TMA-based + * GPU intrinsics. + * + * @param args Array of PrimExpr TL-call arguments (see list above). + * @param vmap Mapping from original buffer variables to actual Buffer objects. */ Conv2DIm2ColOp::Conv2DIm2ColOp(Array args, BufferMap vmap) { ObjectPtr node = make_object(); @@ -1176,20 +1321,49 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array args, BufferMap vmap) { data_ = std::move(node); } +/** + * @brief Create a shallow copy of this Conv2DIm2ColOpNode wrapped as a + * TileOperator. + * + * Produces a new Conv2DIm2ColOp that owns a freshly allocated + * Conv2DIm2ColOpNode initialized from this node (member-wise copy). 
This is + * used to duplicate the operator node for compiler passes that require + * independent operator instances. + * + * @return TileOperator A TileOperator containing the cloned Conv2DIm2ColOpNode. + */ TileOperator Conv2DIm2ColOpNode::Clone() const { auto op = make_object(*this); return Conv2DIm2ColOp(op); } -/*! - * \brief Lower the Conv2DIm2ColOp to PTX code. - * This function generates optimized im2col transformation using TMA - * instructions. It creates a TMA descriptor for the im2col operation, handling - * convolution parameters like kernel size, stride, padding, and dilation. The - * operation is optimized for Hopper architecture with support for different - * shared memory layouts. \param T LowerArgs containing target and layout map. - * \param analyzer Arithmetic analyzer for simplification. - * \return Stmt representing the PTX code for the Conv2DIm2ColOp. +/** + * @brief Lower Conv2D im2col into a TMA-backed PTX sequence for Hopper. + * + * Constructs a TMA im2col descriptor from the Conv2DIm2ColOp parameters + * (kernel, stride, dilation, padding, channel/image tiling, dtype and shapes), + * emits a call to create the im2col descriptor, and returns a statement that + * invokes the corresponding tma_load_im2col builtin guarded to a single + * thread. The lowering assumes the destination resides in shared memory and the + * source in global memory and uses the provided layout information (when + * available) to select the appropriate shared-memory swizzle. + * + * Preconditions (checked with ICHECK): + * - Target is Hopper. + * - src.scope() == "global" and dst.scope() is "shared.dyn" or "shared". + * - src->shape has rank 4 and dst->shape has rank 2. + * - src and dst have the same dtype. + * - When a shared layout is supplied it must match a recognized TMA swizzle + * pattern (32B/64B/128B) or an ICHECK will fail. + * + * @param T Lowering context (target, layout map, thread_var, thread_bounds, + * buffer remapping, etc.). Used to fetch target/layout and to emit a + * thread-guarded TMA call. + * @param analyzer Arithmetic analyzer used to prove divisibility and simplify + * expressions required by descriptor construction. + * @return Stmt A TIR statement that performs a tma_load_im2col call wrapped in + * a thread-min guard (IfThenElse). The returned statement is ready + * to be inserted into the lowered TIR. */ Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { @@ -1360,6 +1534,16 @@ TIR_REGISTER_TL_OP(Copy, copy) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +/** + * @brief Layout inference hook for Conv2DIm2ColOpNode. + * + * This operator does not provide any layout inference; the function + * intentionally returns an empty LayoutMap to indicate no layout suggestions. + * + * @param T Context for layout inference (ignored). + * @param level Inference level (ignored). + * @return LayoutMap An empty map. + */ LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { return {}; diff --git a/src/op/copy.h b/src/op/copy.h index 2b9f2d855..9ba48bc0b 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -90,9 +90,220 @@ struct TMAIm2ColDesc { /*! * \brief Copy operator for transferring data between buffers. * - * This class implements a generic copy operator in TensorIR Lowering for - * block-wise or element-wise data transfer, possibly optimized with - * parallelization or TMA hardware acceleration. 
+ * Performs element- or block-wise copies between `src` and `dst` buffers for + * TensorIR lowering. The operator supports thread-level parallelization, + * shared-memory layouts, and hardware-accelerated paths (TMA/LDSM/STMATRIX) + * when available. Public fields describe the copy ranges and tuning knobs + * (coalesced width, eviction policy, disable_tma). + */ + +/*! + * \brief Lower the copy operator to a TIR statement. + * + * Produces a TIR statement implementing the configured copy (normal, LDSM, + * STSM, or bulk TMA-based) for the given lowering context. + * + * \param T Lowering arguments that provide buffer bindings and context. + * \param analyzer Analyzer used for expression simplification and bounds + * checks. \return A TIR `Stmt` implementing the copy. + */ + +/*! + * \brief Infer buffer layouts after applying this operator. + * + * Computes resulting layouts (shape/stride mappings) for buffers affected by + * this copy operation. + * + * \param T Arguments for layout inference (buffer maps, shapes). + * \param level Granularity of inference to perform. + * \return A LayoutMap describing inferred layouts. + */ + +/*! + * \brief Check if bulk global->shared copy is supported on the target. + * + * Returns true if the target supports bulk (TMA) loads from global memory. + * + * \param target Target to query. + */ + +/*! + * \brief Check if bulk shared->global store is supported on the target. + * + * Returns true if the target supports bulk (TMA) stores to global memory. + * + * \param target Target to query. + */ + +/*! + * \brief Check if LDSM (LDMATRIX) memory-copy is supported on the target. + * + * \param target Target to query. + */ + +/*! + * \brief Check if STSM (STMATRIX) memory-copy is supported on the target. + * + * \param target Target to query. + */ + +/*! + * \brief Select the copy instruction type to use. + * + * Chooses between kNormal, kLDSM, kSTSM, kBulkLoad, and kBulkStore based on + * the target capabilities and whether TMA lowering is disabled. + * + * \param target Target to query. + * \param disable_tma_lower When true, force non-TMA copy paths. + * \return The selected CopyInst value. + */ + +/*! + * \brief Clone this copy operator. + * + * Returns a TileOperator reference that is a shallow clone of this operator + * object suitable for further modifications in pass pipelines. + */ + +/*! + * \brief Generate lowering for bulk (global-to-shared or shared-to-global) + * copy. + * + * Implements TMA-based bulk load/store lowering when `copy_inst` indicates a + * bulk path. The function encodes TMA descriptors and produces calls or + * loops required by the selected bulk mechanism. + * + * \param T Lowering context. + * \param analyzer Analyzer for simplification. + * \param copy_inst Copy instruction type indicating bulk load/store. + * \return A TIR `Stmt` implementing the bulk copy. + */ + +/*! + * \brief Generate lowering for LDS matrix-copy paths (LDMATRIX/STMATRIX). + * + * Emits the lowering for LDS-based matrix-copy instructions when the chosen + * `copy_inst` is an LDSM or STSM variant. + * + * \param T Lowering context. + * \param analyzer Analyzer for simplification. + * \param copy_inst Copy instruction type indicating an LDS matrix path. + * \return A TIR `Stmt` implementing the matrix-copy. + */ + +/*! + * \brief Generate lowering for the normal (non-bulk, scalar/vec) copy path. + * + * Emits element-wise or vectorized loads/stores using the computed iteration + * space and predicates to ensure in-bounds accesses. 
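+ *
+ * Conceptual form of the emitted body (an illustrative sketch only; helper
+ * names like pred/di/si are placeholders, not source functions):
+ * \code
+ * // after thread partitioning and vectorization:
+ * for (v, 0, vec) "vectorized":
+ *   if (pred(tid, v)) dst[di(tid, v)] = cast(dst_dtype, src[si(tid, v)]);
+ * \endcode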
+ * + * \param T Lowering context. + * \param analyzer Analyzer for simplification. + * \return A TIR `Stmt` implementing the normal copy. + */ + +/*! + * \brief Generate a SIMT-style thread-level loop for the copy. + * + * Produces a `For` loop that distributes copy work across SIMD/warp lanes or + * CUDA threads according to the operator's iteration strategy. + * + * \param analyzer Analyzer for simplification. + * \return A `For` loop representing the thread-level iteration. + */ + +/*! + * \brief Compute a linear shared-memory layout suitable for TMA copies. + * + * Returns a `Layout` that maps the shared-memory `shared_tensor` into a + * linearized representation required by bulk/TMA transfers. + * + * \param shared_tensor Buffer representing the shared-memory tensor. + * \return A `Layout` describing the linearized shared layout. + */ + +/*! + * \brief Create iterator variables for multi-dimensional copy loops. + * + * The returned `IterVar` array enumerates the loop indices used to traverse + * the copy extents in each tensor dimension. + * + * \return Array of iterator variables. + */ + +/*! + * \brief Calculate source or destination indices from iteration variables. + * + * Converts the iterator variables (from MakeIterVars) into concrete index + * expressions for either the source image or the destination tensor. + * + * \param ivs Iterator variables returned by MakeIterVars(). + * \param src_dst 0 to produce source indices, 1 to produce destination indices. + * \return Array of `PrimExpr` index expressions. + */ + +/*! + * \brief Construct the boundary predicate ensuring in-bounds accesses. + * + * Builds a boolean expression that guards loads/stores so they only occur + * when indices lie within the provided `extents`. + * + * \param analyzer Arithmetic analyzer used to simplify predicates. + * \param ivs Iterator variables. + * \param extents Extent expressions for the target buffer. + * \param src_dst 0 = predicate for source indices, 1 = predicate for + * destination. \return A `PrimExpr` boolean predicate. + */ + +/*! + * \brief Constructor. + * + * \param args Expression arguments for the copy (indices, sizes, etc.). + * \param vmap Buffer variable mapping for source and destination. + */ + +/*! + * \brief Get the TVM Op handle corresponding to this Copy op. + */ + +/*! + * \brief Special operator for Conv2D im2col transformation. + * + * Converts an input feature map into an im2col matrix layout used for GEMM- + * based convolution lowering. Public fields configure kernel geometry, + * stride/padding/dilation, and cache eviction behavior. + */ + +/*! + * \brief Lower to TIR statement. + * + * Emits TIR that performs the im2col extraction from `src` into `dst` + * according to kernel, stride, padding, and dilation parameters. + * + * \param T Lowering context with buffer bindings. + * \param analyzer Analyzer for expression simplification and bounds reasoning. + * \return A TIR `Stmt` performing the im2col transform. + */ + +/*! + * \brief Infer layout for this operator. + * + * Produces the layout mapping for the destination im2col matrix given the + * source layout and convolution parameters. + * + * \param T Layout inference arguments. + * \param level Inference granularity level. + * \return A LayoutMap with inferred layouts for affected buffers. + */ + +/*! + * \brief Get TVM Op handle for Conv2DIm2Col. + */ + +/*! + * \brief Clone this Conv2DIm2Col operator. + * + * Returns a TileOperator reference that is a shallow clone of this operator. 
*/ class CopyNode : public TileOperatorNode { public: @@ -208,6 +419,24 @@ class CopyNode : public TileOperatorNode { PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array &ivs, Array extents, int src_dst) const; + /** + * \brief Create a deep copy of this operator. + * + * Returns a TileOperator that is a copy of the current node, preserving all + * configuration (buffers, parameters, and layout-related fields). + * @return A TileOperator owning the cloned operator node. + */ + + /** + * \brief Constructor. + * \param args Expression arguments for the Conv2D im2col operator. + * \param vmap Buffer variable mapping. + */ + + /** + * \brief Get the TVM Op handle corresponding to this Conv2DIm2Col operator. + * @return Reference to the singleton TVM Op representing this operator. + */ TileOperator Clone() const; }; diff --git a/src/op/elem.cc b/src/op/elem.cc index a3b5b469e..a46935879 100644 --- a/src/op/elem.cc +++ b/src/op/elem.cc @@ -22,6 +22,42 @@ namespace tl { using namespace tir; +/** + * @brief Construct a Fill operator node from call arguments and a buffer map. + * + * This constructor builds a FillNode describing an element-wise fill of a + * destination buffer region with a scalar/vector value and stores it in + * `data_`. + * + * Detailed behavior: + * - If `args[0]` is a `BufferLoad`, the loaded buffer becomes the destination + * and the load indices are converted to per-dimension ranges: + * - `Ramp(base, lanes, stride)` is converted to `Range(base, lanes)`. Only + * stride == 1 and constant `lanes` are supported. + * - Non-ramp indices become `Range(index, 1)`. + * - Otherwise `args[0]` is treated as an access pointer; the destination buffer + * is resolved via `vmap[GetVarFromAccessPtr(args[0])]` and the region is the + * full buffer shape for each dimension. + * - `args[1]` is used as the fill value; it is cast to the destination buffer's + * dtype if necessary. + * - Performs validation: + * - Region dimensionality must match destination rank. + * - For statically-known region mins and extents, checks that mins >= 0 and + * extents do not exceed the corresponding destination shape extents. + * + * Parameters: + * @param args Call arguments: expected layout is [dst_access_or_bufferload, + * value]. + * - args[0]: destination access (BufferLoad or pointer expression). + * - args[1]: value to fill (scalar or vector). + * @param vmap Mapping from buffer variables to Buffer objects; used to resolve + * the destination when args[0] is not a BufferLoad. + * + * Notes: + * - The constructor enforces constraints (e.g., stride == 1 ramps, constant + * lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out + * of bounds. + */ Fill::Fill(Array args, BufferMap vmap) { ObjectPtr node = make_object(); @@ -71,11 +107,31 @@ Fill::Fill(Array args, BufferMap vmap) { data_ = std::move(node); } +/** + * @brief Create a copy of this FillNode and return it as a TileOperator. + * + * Constructs a new FillNode by copying the current node and wraps the copy in a + * Fill TileOperator. + * + * @return TileOperator A TileOperator that owns the copied FillNode. + */ TileOperator FillNode::Clone() const { auto op = make_object(*this); return Fill(op); } +/** + * @brief Build a SIMT-style nested parallel loop that fills the destination + * buffer. 
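+ *
+ * Emitted shape (illustrative pseudo-TIR for a 2-D destination):
+ * @code
+ * for (i, 0, e0) "parallel":
+ *   for (j, 0, e1) "parallel":
+ *     dst[i, j] = value
+ * @endcode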
+ * + * Constructs per-dimension data-parallel loop iterators matching this node's + * region extents, emits a BufferStore that writes the node's `value` into `dst` + * at the loop indices, and nests the loops (innermost to outermost) as parallel + * `For` nodes. Returns the outermost `For` loop representing the complete + * multi-dimensional fill kernel. + * + * @return For Outermost parallel `For` loop of the generated nested SIMT loop. + */ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { int ndim = dst->shape.size(); Array loop_vars; @@ -93,6 +149,24 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { return Downcast(body); } +/** + * @brief Lower this Fill operator to a TIR statement for the target. + * + * Lowers the FillNode into a Stmt according to the destination buffer scope: + * - "local.fragment" and shared ("shared", "shared.dyn"): create a parallel + * operation from a SIMT loop, infer its layout, partition the root loop by + * the thread variable, vectorize the resulting thread loop, and, if a + * per-thread predicate exists, guard the vectorized loop with that + * predicate. + * - "local": build a SIMT loop and return its vectorized form. + * - other scopes: fatal error. + * + * The lowering may query layout and thread information from @p T and uses the + * provided analyzer for any required arithmetic/layout analysis. + * + * @param T Lowering arguments (target, thread bounds, thread var, layout map). + * @return Stmt The lowered TIR statement implementing the fill. + */ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { if (dst.scope() == "local.fragment") { auto par_op = ParallelOp(MakeSIMTLoop(analyzer)); @@ -129,6 +203,17 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } } +/** + * @brief Infer memory/layout mapping for the Fill operator. + * + * Returns the layout mapping produced by layout inference for this FillNode. + * Currently no layout inference is performed for Fill and the function returns + * an empty LayoutMap. + * + * @param T Context required for layout inference (unused). + * @param level The inference level requested (unused). + * @return LayoutMap Empty map indicating no inferred layouts for this operator. + */ LayoutMap FillNode::InferLayout(const LayoutInferArgs &T, InferLevel level) const { return {}; diff --git a/src/op/elem.h b/src/op/elem.h index a3efb3f92..902fc4506 100644 --- a/src/op/elem.h +++ b/src/op/elem.h @@ -10,6 +10,63 @@ #include "operator.h" #include "parallel.h" +/** + * Lower the Fill operator into TIR statements. + * + * Produces a TIR Stmt that implements element-wise filling of `dst` over + * `region` with `value`, using information from `T`. + * + * @param T Lowering inputs (buffers, shapes, and iteration info) used to + * generate the IR. + */ + +/** + * Infer the memory layout mapping for the Fill operator. + * + * Returns a LayoutMap that describes how logical iteration axes map to memory + * dimensions for the destination buffer. `level` controls the aggressiveness + * of inference (e.g., relaxed vs. strict constraints). + * + * @param T Layout inference inputs (buffers, shapes, and related metadata). + * @param level Inference level controlling precision of the returned mapping. + */ + +/** + * Return the global operator descriptor for tl.Fill. + * + * The returned Op can be used to look up operator-level metadata and to + * register or query the operator within the TVM operator registry. 
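+ *
+ * Hypothetical lookup (the op name is assumed from the TL registry
+ * convention and is not confirmed by this header):
+ * @code
+ * const Op &fill_op = Op::Get("tl.fill");
+ * @endcode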
+ */ + +/** + * Create a copy of this operator node as a TileOperator reference. + * + * The returned TileOperator is an independent handle representing a clone of + * the underlying FillNode. + */ + +/** + * Build a SIMT-style For loop that implements the fill. + * + * Constructs and returns a TIR `For` loop that iterates over the target region + * in a SIMT-friendly ordering appropriate for `dst` and `region`. + */ + +/** + * Construct a Fill operator from argument expressions and a buffer mapping. + * + * @param args Positional PrimExpr arguments passed to the operator (e.g., + * indices or shape expressions required by the operator's specification). + * @param vmap Mapping from named buffer parameters to concrete tir::Buffer + * instances used by this operator instance. + */ + +/** + * Return the global operator descriptor for the public Fill wrapper. + * + * Mirrors FillNode::Get() and provides the operator descriptor for users of the + * public TileOperator API. + */ namespace tvm { namespace tl { diff --git a/src/op/gemm.cc b/src/op/gemm.cc index c308dc5a1..1142a39b5 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -19,6 +19,16 @@ namespace tl { using namespace tir; +/** + * @brief Compute the prime factorization of an integer. + * + * Returns the prime factors of x in non-decreasing order by repeatedly dividing + * out the smallest possible factor. + * + * @param x Integer to factorize. If x <= 1, an empty vector is returned. + * @return std::vector Prime factors of x (with multiplicity), in + * non-decreasing order. + */ static std::vector toPrimeFactors(int x) { int i = 2; std::vector result; @@ -33,6 +43,34 @@ static std::vector toPrimeFactors(int x) { return result; } +/** + * @brief Construct a Gemm operator from serialized TL arguments and a buffer + * map. + * + * This constructor deserializes operator parameters from `args` and resolves + * buffer references via `vmap`, populating an internal GemmNode with: + * - device pointers for A, B, C and their corresponding Buffer objects, + * - transpose flags for A and B, + * - matrix dimensions M, N, K, + * - warp allocation policy and clear_accum flag, + * - strides and memory offsets for A and B, + * - optional kPack (must be 1 or 2) and optional wg_wait. + * + * The populated GemmNode is stored into the wrapper's internal `data_`. + * + * @param args Positional serialized arguments produced by the TL frontend: + * expected layout is: + * [Aptr, Bptr, Cptr, trans_A (Bool), trans_B (Bool), + * M (Int), N (Int), K (Int), policy (Int), clear_accum (Bool), + * stride_A (Int), stride_B (Int), offset_A (Int), offset_B (Int), + * (optional) kPack (Int), (optional) wg_wait (Int)] + * @param vmap Mapping from access pointer vars to Buffer objects used to + * resolve the Buffer corresponding to each pointer argument. + * + * @note If `kPack` is provided it must be 1 or 2; otherwise the constructor + * fails with an ICHECK (runtime assertion). No other validation is + * performed here. + */ Gemm::Gemm(Array args, BufferMap vmap) { ObjectPtr node = make_object(); @@ -66,11 +104,39 @@ Gemm::Gemm(Array args, BufferMap vmap) { data_ = std::move(node); } +/** + * @brief Create a copy of this GemmNode as a TileOperator. + * + * Constructs a new GemmNode by copying the current node state and returns it + * wrapped in a Gemm TileOperator. + * + * @return TileOperator A Gemm operator that owns a copy of this node. 
+ */
 TileOperator GemmNode::Clone() const {
   auto op = make_object<GemmNode>(*this);
   return Gemm(op);
 }
 
+/**
+ * @brief Select the GEMM implementation variant for a given block size and
+ * target.
+ *
+ * Determines which low-level GEMM instruction to use:
+ * - Returns kWGMMA when running on Hopper-class targets and the operator
+ *   meets the WGMMA constraints (M >= 64, the number of warps is a multiple
+ *   of 4, and CheckWGMMA() returns true).
+ * - Returns kMFMA for CDNA targets.
+ * - Returns kMMA for CUDA targets.
+ *
+ * @param block_size Number of threads in the CUDA/ROCm thread block used for
+ * the GEMM.
+ * @param target Target backend describing the hardware (used to detect the
+ * architecture).
+ * @return GemmInst The chosen GEMM implementation enum value.
+ *
+ * @note If the target is not recognized/supported, the function fails a
+ * runtime ICHECK.
+ */
 GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
   int warp_size = TargetGetWarpSize(target);
   int num_warps = block_size / warp_size;
@@ -375,6 +441,20 @@ bool GemmNode::CheckWGMMA() const {
   }
 }
 
+/**
+ * @brief Parse and return the numeric GPU architecture from a Target's "arch"
+ * attribute.
+ *
+ * Examines the target's "arch" string and, if it matches the pattern
+ * "sm_<N>", returns N as an int. If the attribute is present but does not
+ * match that pattern, returns 0.
+ *
+ * Preconditions: the target must have an "arch" attribute (this is checked
+ * via ICHECK).
+ *
+ * @return int The parsed architecture number (e.g., 80 for "sm_80"), or 0 if
+ * the arch string does not match "sm_<N>".
+ */
 static int GetArchInt(Target target) {
   int arch_int = 0;
   auto s = target->GetAttr<String>("arch");
@@ -388,6 +468,19 @@ static int GetArchInt(Target target) {
   return arch_int;
 }
 
+/**
+ * @brief Lower the GEMM operator to a TL TIR call expression.
+ *
+ * Constructs a tl::gemm call string parameterized by M, N, K, the warp
+ * partition, transpose flags, accumulation clearing, the target-specific
+ * stride/offset/kPack values, and the optional workgroup wait value, then
+ * returns an Evaluate(call) node invoking tl::tl_gemm with the composed
+ * string and the A/B/C buffer handles.
+ *
+ * @param T Lowering context, including thread bounds and target.
+ * @param analyzer Optional arithmetic analyzer used by lowering (may be
+ * nullptr).
+ * @return Stmt A TIR statement representing the evaluated TL GEMM call.
+ */
 Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   auto block_size = *as_const_int(T.thread_bounds->extent);
   GemmInst gemm_inst = GetGemmInst(block_size, T.target);
@@ -426,28 +519,23 @@
 }
 
 /**
- * @brief Infer memory/layout mappings for A, B, and C buffers for this GEMM op.
+ * @brief Infer and bind target-specific memory/layout mappings for A, B, and C.
  *
- * Generates and returns a LayoutMap that binds buffer A, B, and C to
- * target- and architecture-specific fragment or shared-memory layouts based
- * on the current target, thread bounds, warp partitioning, data types, and
- * transpose flags. This performs target dispatch (Volta, Ampere/Turing/SM120,
- * Hopper, CDNA), selects the appropriate fragment or shared layout creators,
- * and binds fragment layouts to the thread range when buffers are local
- * fragments.
+ * Infers per-buffer layouts (fragment or shared-memory layouts) for this GEMM
+ * operator according to the target architecture, thread bounds, warp
+ * partitioning, data types, and transpose flags, then binds fragment layouts
+ * to the thread range when required.
 *
 * Preconditions:
- * - C.scope() must be "local.fragment".
+ * - C.scope() == "local.fragment"
 *
- * Postconditions / side effects:
- * - Marks the operator's layout inference as completed (sets completed_ =
- * true).
+ * Side effects:
+ * - Marks layout inference as completed (sets completed_ = true).
 * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or
 * incompatible shape constraints.
 *
- * @param T Layout inference inputs (thread bounds and target).
- * @param level Inference level (unused for side effects but retained for API).
- * @return LayoutMap mapping each of A, B, and C to their inferred layouts.
+ * @param T Input layout-inference context (provides thread bounds and target).
+ * @return LayoutMap mapping A, B, and C to their inferred layouts.
 */
 LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
                                 InferLevel level) const {
diff --git a/src/op/gemm.h b/src/op/gemm.h
index 15199b2f3..53bde7b12 100644
--- a/src/op/gemm.h
+++ b/src/op/gemm.h
@@ -10,6 +10,74 @@
 #include "operator.h"
 
 namespace tvm {
+/**
+ * Check whether the target and configuration allow using WGMMA (warp-group
+ * MMA) for this GEMM.
+ *
+ * @returns true if WGMMA can be used for the current node configuration and
+ * target; false otherwise.
+ */
+/**
+ * Lower this GEMM operator to a TVM Stmt for the given lowering context.
+ *
+ * @param T Lowering arguments and context (tile mappings, target, etc.).
+ * @param analyzer Arithmetic analyzer used for symbolic simplification and
+ * bounds reasoning.
+ * @returns A lowered Stmt implementing the GEMM.
+ */
+/**
+ * Infer memory/layout mapping for GEMM inputs/outputs at the given inference
+ * level.
+ *
+ * @param T Layout inference inputs (buffers, shapes, constraints).
+ * @param level Inference level that controls how aggressive/specific the
+ * inferred layouts should be.
+ * @returns A LayoutMap describing how logical tensor axes map to storage/layout
+ * axes.
+ */
+/**
+ * Create a copy of this TileOperator node as a TileOperator reference.
+ *
+ * @returns A TileOperator reference that represents a clone of this GemmNode.
+ */
+/**
+ * Determine the specific GEMM instruction variant to use for the given block
+ * size and target.
+ *
+ * @param block_size The thread-block size (in threads) used to select the
+ * instruction variant.
+ * @param target The compilation target describing architecture and instruction
+ * set.
+ * @returns The GemmInst enum value representing the chosen GEMM instruction
+ * family.
+ */
+/**
+ * Compute how to partition work across warps for the given number of warps and
+ * GEMM instruction.
+ *
+ * The returned pair is (warp_rows, warp_cols), describing the per-warp tiling
+ * in row and column dimensions respectively.
+ *
+ * @param num_warps Total number of warps available for the block.
+ * @param gemm_inst The GEMM instruction variant selected for the target.
+ * @param target The compilation target which may constrain or influence
+ * partitioning.
+ * @returns A pair (warp_rows, warp_cols) describing the warp partition.
+ */
+/**
+ * Construct a Gemm operator handle from call arguments and a buffer mapping.
+ *
+ * @param args Array of call-time PrimExpr arguments passed to the operator.
+ * @param vmap Mapping from buffer names/indices to tir::Buffer objects used by
+ * this GEMM.
+ */
+/**
+ * Obtain the registered Op descriptor for the GEMM operator.
+ *
+ * @returns A const reference to the Op representing "tl.Gemm".
+ */
 namespace tl {
 
 using namespace tir;
diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc
index 2b4b1c064..4bc08b846 100644
--- a/src/op/gemm_sp.cc
+++ b/src/op/gemm_sp.cc
@@ -17,6 +17,17 @@ namespace tvm {
 namespace tl {
 
+/**
+ * @brief Decomposes a positive integer into its prime factors.
+ *
+ * Returns the prime factorization of `x` as a vector of prime factors in
+ * non-decreasing order. If `x <= 1` the returned vector is empty.
+ *
+ * @param x Integer to factorize (expected non-negative; returns an empty
+ * vector for values <= 1).
+ * @return std::vector<int> Prime factors of `x` (with repetition), e.g. 12 ->
+ * {2, 2, 3}.
+ */
 static std::vector<int> toPrimeFactors(int x) {
   int i = 2;
   std::vector<int> result;
@@ -31,6 +42,27 @@ static std::vector<int> toPrimeFactors(int x) {
   return result;
 }
 
+/**
+ * @brief Construct a GemmSP operator node from TL call arguments and a buffer
+ * map.
+ *
+ * Parses the expected call argument tuple and fills an internal GemmSPNode:
+ * - Buffers: A (args[0]), E (args[1]), B (args[2]), C (args[3]) are looked up
+ * in vmap.
+ * - Booleans: trans_A (args[4]), trans_B (args[5]).
+ * - Dimensions: M (args[6]), N (args[7]), K (args[8]) as integers.
+ * - Warp policy: policy (args[9]) mapped to GemmWarpPolicy.
+ * - clear_accum: boolean flag (args[10]).
+ * - Optional kPack (args[11]): must be 1 or 2 (checked via ICHECK).
+ * - Optional wg_wait (args[12]): integer workgroup wait parameter.
+ *
+ * The populated GemmSPNode is stored in the instance's internal data_ pointer.
+ *
+ * @param args Positional TL call arguments in the above order.
+ * @param vmap BufferMap mapping access pointers (from args) to Buffer objects.
+ *
+ * @note An ICHECK failure is raised if a provided kPack is not 1 or 2.
+ */
 GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<GemmSPNode> node = make_object<GemmSPNode>();
   node->A = vmap[GetVarFromAccessPtr(args[0])];
@@ -57,11 +89,41 @@ GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
   data_ = std::move(node);
 }
 
+/**
+ * @brief Create a deep copy of this GemmSPNode wrapped as a TileOperator.
+ *
+ * Returns a new TileOperator that owns a copy of this node. The cloned node
+ * duplicates all fields of the original; subsequent modifications to the
+ * clone do not affect the original node.
+ *
+ * @return TileOperator A TileOperator holding a cloned GemmSPNode.
+ */
 TileOperator GemmSPNode::Clone() const {
   auto op = make_object<GemmSPNode>(*this);
   return GemmSP(op);
 }
 
+/**
+ * @brief Compute a partition of warps across the M and N GEMM dimensions.
+ *
+ * Computes (m_warp, n_warp) such that m_warp * n_warp == num_warps and the
+ * warp counts respect element-per-warp granularity and the configured
+ * GemmWarpPolicy. On Hopper targets, when `maybe_hopper_wgmma` is true and
+ * the problem size permits, a warp-group (WGMMA)-aware partitioning is used
+ * (groups of 4 warps).
+ *
+ * @param num_warps Total number of warps available for the block.
+ * @param target Hardware target used to decide target-specific strategies
+ * (e.g., Hopper WGMMA grouping).
+ * @param maybe_hopper_wgmma If true, allows using Hopper WGMMA-specific
+ * partitioning when the target and problem size permit.
+ * @return std::pair<int, int> A pair (m_warp, n_warp) giving the number of warp
+ * partitions along M and N, respectively.
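+ *
+ * Illustrative expectation (a sketch only; the concrete split depends on the
+ * policy and target, and `node` is a hypothetical GemmSPNode):
+ * @code
+ * auto [m_warp, n_warp] = node->ComputeWarpPartition(8, target, false);
+ * ICHECK_EQ(m_warp * n_warp, 8);  // e.g. (2, 4) under a square-ish policy
+ * @endcode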
+ *
+ * @note The function uses ICHECK to enforce invariants (e.g., unknown policy or
+ * invalid m_warp * n_warp), which will terminate on failure.
+ */
 std::pair<int, int>
 GemmSPNode::ComputeWarpPartition(int num_warps, Target target,
                                  bool maybe_hopper_wgmma) const {
@@ -220,6 +282,24 @@ GemmSPNode::ComputeWarpPartition(int num_warps, Target target,
   return {m_warp, n_warp};
 }
 
+/**
+ * @brief Lower this GemmSP node to a TL (TileLang) intrinsic call.
+ *
+ * Constructs and returns an Evaluate statement containing a call to the
+ * TL gemm_sp intrinsic that encodes this GEMM's template parameters
+ * (M, N, K, warp partition, transposition flags, clear_accum, and optional
+ * Hopper/WGMMA and wg_wait modifiers) and the remapped buffer access pointers.
+ *
+ * The function validates that A, B, and E reside in shared (or shared.dyn)
+ * memory (ICHECK failures otherwise), computes the warp partition based on
+ * the launch configuration and target, and emits a single tl::tl_gemm_sp call
+ * with a string template describing the configuration.
+ *
+ * @param T Lowering context containing thread bounds, target, and optional
+ * buffer remapping used to obtain the final buffer AccessPtr
+ * arguments for the TL call.
+ * @return Stmt An Evaluate wrapping the constructed tl::tl_gemm_sp call.
+ */
 Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   int warp_size = 32;
@@ -264,6 +344,34 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   return Evaluate(new_call);
 }
 
+/**
+ * @brief Infers and returns the memory/layout mapping for the GemmSP operator.
+ *
+ * Infers thread-local fragment layout for C and shared-memory layouts for A and
+ * B based on the target (Hopper-only path), block/thread bounds in T,
+ * transposition flags, and matrix dimensions stored in the node. The function
+ * caches its work: if layout inference has already completed (completed_ ==
+ * true) it returns an empty LayoutMap.
+ *
+ * Precondition:
+ * - C.scope() must be "local.fragment".
+ *
+ * Behavior notes:
+ * - Only the Hopper target is supported; non-Hopper targets trigger a fatal
+ * check.
+ * - For Hopper, the function computes a warp partition from block size and may
+ * enable WGMMA-specific fragment creation when conditions on M and block size
+ * are met.
+ * - A and B must reside in "shared" or "shared.dyn"; otherwise the function
+ * aborts with a check failure.
+ * - The method sets completed_ = true before returning to avoid re-entrance.
+ *
+ * @param T LayoutInferArgs containing thread bounds and the target (used to
+ * select Hopper-specific layouts).
+ * @param level Currently unused inference detail level.
+ * @return LayoutMap mapping A, B, and C to their inferred layouts (or empty if
+ * inference was already completed).
+ */
 LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
                                   InferLevel level) const {
   if (completed_)
diff --git a/src/op/gemm_sp.h b/src/op/gemm_sp.h
index e645d0d42..e824acc16 100644
--- a/src/op/gemm_sp.h
+++ b/src/op/gemm_sp.h
@@ -10,6 +10,60 @@
 #include "operator.h"
 
 namespace tvm {
+/**
+ * Lower the GemmSP operator into a TIR statement for the given lowering
+ * context.
+ *
+ * Produces the TIR Stmt that implements this operator using the provided
+ * lowering arguments. The `analyzer` is used for arithmetic simplifications and
+ * may be null.
+ *
+ * @param T Lowering context and arguments.
+ * @returns A TIR `Stmt` implementing the lowered operator.
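+ *
+ * Hypothetical call site (sketch only; `T` and `analyzer` are supplied by the
+ * lowering pass):
+ * @code
+ * Stmt lowered = node->Lower(T, &analyzer);  // Evaluate(tl_gemm_sp(...))
+ * @endcode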
+ */
+/**
+ * Infer memory/layout mapping for operands and outputs of this operator.
+ *
+ * Computes a LayoutMap describing how logical tensor layouts map to physical
+ * buffer layouts for the given inference `level`.
+ *
+ * @param T Layout inference inputs (shapes, buffer info, etc.).
+ * @param level Inference granularity/level.
+ * @returns A LayoutMap describing inferred layouts.
+ */
+/**
+ * Compute a warp-level partitioning (rows, cols) for the given number of warps.
+ *
+ * Returns a pair (warps_per_row, warps_per_col) describing how to tile the GEMM
+ * across warps for the specified `target`. The optional `maybe_hopper_wgmma`
+ * enables target-specific adjustments (e.g., Hopper WGMMA warp-group variants)
+ * when set.
+ *
+ * @param num_warps Total number of warps available for the tile.
+ * @param target Target device/architecture used to guide partitioning choices.
+ * @param maybe_hopper_wgmma Enable target-specific WGMMA adjustments when
+ * true.
+ * @returns Pair of (warps_per_row, warps_per_col).
+ */
+/**
+ * Create a copy of this TileOperator node as a TileOperator reference.
+ *
+ * The returned TileOperator refers to a new node that is a copy of this node.
+ *
+ * @returns A TileOperator that is a clone of this node.
+ */
+/**
+ * Construct a GemmSP TileOperator from call arguments and a buffer map.
+ *
+ * @param args Array of PrimExpr specifying call-site arguments for the
+ * operator.
+ * @param vmap Mapping from buffer names to tir::Buffer objects for
+ * operands/outputs.
+ */
+/**
+ * Return the singleton Op descriptor for the GemmSP operator.
+ *
+ * @returns Reference to the operator's Op registration object.
+ */
 namespace tl {
 
 using namespace tir;
diff --git a/src/op/operator.cc b/src/op/operator.cc
index ffc7cdefc..783950795 100644
--- a/src/op/operator.cc
+++ b/src/op/operator.cc
@@ -15,6 +15,21 @@ namespace tl {
 using namespace tir;
 
+/**
+ * @brief Construct a TileOperator from a TIR Call using a registered builder.
+ *
+ * Looks up a builder function in the "TLOpBuilder" Op attribute map for the
+ * operator referenced by `call` and invokes it to produce a TileOperator. If no
+ * builder is registered for the operator, returns a default-constructed (empty)
+ * TileOperator.
+ *
+ * @param call The TIR Call whose operator and arguments will be used to build
+ * the TileOperator.
+ * @param vmap Buffer mapping passed through to the builder to resolve buffer
+ * references.
+ * @return TileOperator The constructed TileOperator, or a default (empty)
+ * TileOperator if no builder exists.
+ */
 TileOperator ParseOperator(Call call, BufferMap vmap) {
   auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
   Op op = call->op.as<Op>().value();
@@ -26,6 +41,18 @@ TileOperator ParseOperator(Call call, BufferMap vmap) {
   return TileOperator();
 }
 
+/**
+ * @brief Parse a TileOperator from a TIR statement if it contains a call.
+ *
+ * If `stmt` is an Evaluate node whose value is a Call, delegates to
+ * ParseOperator(Call, BufferMap) and returns the resulting TileOperator.
+ * Otherwise returns a default-constructed (empty) TileOperator.
+ *
+ * @param stmt TIR statement to inspect; expected to be an Evaluate of a Call.
+ * @param vmap Mapping of buffer variables used when building the operator.
+ * @return TileOperator Parsed operator on success, or a default (empty)
+ * TileOperator if `stmt` is not an Evaluate(Call).
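+ *
+ * Typical guarded use (sketch; `stmt` and `vmap` are assumed to be in scope):
+ * @code
+ * TileOperator op = ParseOperator(stmt, vmap);
+ * if (op.defined()) {
+ *   // stmt was an Evaluate(Call) with a registered TL builder
+ * }
+ * @endcode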
+ */
 TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
   if (stmt.as<EvaluateNode>() && stmt.as<EvaluateNode>()->value.as<CallNode>()) {
     auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
@@ -34,6 +61,17 @@ TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
   return TileOperator();
 }
 
+/**
+ * @brief Extracts the Var referenced by a `tvm_access_ptr` call expression.
+ *
+ * The function expects `expr` to be a `Call` to the builtin `tvm_access_ptr`
+ * and returns the `Var` found in the call's second argument (`args[1]`). The
+ * function performs runtime checks and will abort if `expr` is not a call, the
+ * call is not `tvm_access_ptr`, or the second argument is not a `Var`.
+ *
+ * @param expr A `PrimExpr` representing a `tvm_access_ptr(...)` call.
+ * @return tvm::Var The `Var` referenced by the `tvm_access_ptr` call.
+ */
 Var GetVarFromAccessPtr(const PrimExpr &expr) {
   auto call = expr.as<CallNode>();
   ICHECK(call);
diff --git a/src/op/operator.h b/src/op/operator.h
index 84692573f..8c0f8d1ea 100644
--- a/src/op/operator.h
+++ b/src/op/operator.h
@@ -11,8 +11,8 @@
 #include
 #include
 #include
-#include
 #include
+#include
 
 #include "../layout/layout.h"
 
@@ -51,32 +51,117 @@ struct LayoutInferArgs {
 class TileOperatorNode;
 class TileOperator;
 
-class TileOperatorNode: public Object {
- public:
+/**
+ * Abstract base class for tile-level operators.
+ *
+ * Implementations must provide lowering to TIR, layout inference, and cloning.
+ */
+
+/**
+ * Lower this tile operator to a TIR statement.
+ *
+ * @param T Lowering context and utilities (target, thread bounds, layout
+ * mappings, buffer remapping, and AddWorkspace callback for requesting
+ * temporary buffers).
+ * @param analyzer Arithmetic analyzer used during lowering.
+ * @return A TIR Stmt representing the lowered operator.
+ */
+
+/**
+ * Infer buffer layouts for this operator.
+ *
+ * The returned LayoutMap associates input/output Buffers with inferred Layouts.
+ * The `level` controls how strictly layouts are determined (kFree, kCommon,
+ * kStrict).
+ *
+ * @param T Layout inference context (target, thread bounds, existing
+ * layout_map, buffer_remap).
+ * @param level Inference strictness level.
+ * @return A LayoutMap mapping Buffers to their inferred Layouts.
+ */
+
+/**
+ * Create a deep copy of this TileOperator.
+ *
+ * @return A TileOperator referencing a cloned operator instance.
+ */
+
+/**
+ * Reference wrapper for TileOperatorNode.
+ *
+ * Use this ObjectRef to hold and pass tile operator instances within the
+ * runtime.
+ */
+
+/**
+ * Extract the underlying Var from an access pointer expression.
+ *
+ * If `expr` represents an access pointer that directly refers to a variable,
+ * returns that Var; otherwise returns a null/default Var.
+ *
+ * @param expr The pointer/access expression to inspect.
+ * @return The extracted Var, or a null Var if none can be found.
+ */
+
+/**
+ * Parse a Call into a TileOperator using the provided buffer mapping.
+ *
+ * @param call The Call node representing a tile operator invocation.
+ * @param vmap Mapping from TIR Vars to Buffers for resolving buffer arguments.
+ * @return A TileOperator constructed from the call and buffer map.
+ */
+
+/**
+ * Parse a Stmt into a TileOperator using the provided buffer mapping.
+ *
+ * @param stmt The Stmt representing a tile operator region or call.
+ * @param vmap Mapping from TIR Vars to Buffers for resolving buffer references.
+ * @return A TileOperator constructed from the statement and buffer map.
+ */
+
+/**
+ * Function type for TL operator builders exposed to the FFI.
+ *
+ * Builder functions take an array of PrimExpr arguments and a BufferMap, and
+ * return a constructed TileOperator.
+ */
+
+/**
+ * Register a TL operator and its builder with TVM's op registry.
+ *
+ * Entry should be a type providing a static `Get()` and a constructor taking
+ * `(Array<PrimExpr>, BufferMap)`. This macro registers the operator under the
+ * name "tl.OpName" and sets an FFI builder attribute that constructs
+ * Entry(args, vmap).
+ *
+ * Usage: TIR_REGISTER_TL_OP(MyOpEntry, MyOp)
+ */
+class TileOperatorNode : public Object {
+public:
   virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const = 0;
-  virtual LayoutMap InferLayout(const LayoutInferArgs& T,
+  virtual LayoutMap InferLayout(const LayoutInferArgs &T,
                                 InferLevel level) const = 0;
   virtual TileOperator Clone() const = 0;
 
-  static constexpr const char* _type_key = "tl.TileOperator";
+  static constexpr const char *_type_key = "tl.TileOperator";
   TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object);
 };
 
 class TileOperator : public ObjectRef {
- public:
-  TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode);
+public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode);
 };
 
-
 Var GetVarFromAccessPtr(const PrimExpr &expr);
 
 TileOperator ParseOperator(Call call, BufferMap vmap);
 TileOperator ParseOperator(Stmt stmt, BufferMap vmap);
 
-using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
+using OpBuilderFunc =
+    ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
 
 #define TIR_REGISTER_TL_OP(Entry, OpName)                             \
   const Op &Entry::Get() {                                            \
@@ -86,11 +171,10 @@ using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap
   TVM_REGISTER_OP("tl." #OpName)                                      \
       .set_attr("TScriptPrinterName", #OpName)                        \
       .set_attr("TLOpBuilder",                                        \
-                [](Array<PrimExpr> args, BufferMap vmap) {            \
-                  return Entry(args, vmap);                           \
+                [](Array<PrimExpr> args, BufferMap vmap) {            \
+                  return Entry(args, vmap);                           \
                 })
-
 } // namespace tl
 } // namespace tvm
diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index 2c347c34f..1ef2e5f80 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -141,6 +141,19 @@ void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
   StmtExprVisitor::VisitStmt_(op);
 }
 
+/**
+ * @brief Visit a BufferLoad node and record/validate index mapping for
+ * fragment-local buffers.
+ *
+ * If the loaded buffer's scope is "local.fragment", this records the load
+ * indices in the visitor's indice_map_ when seen for the first time. If an
+ * entry already exists, the previously recorded indices are asserted
+ * structurally equal to the current indices.
+ *
+ * This ensures all accesses to the same fragment-local buffer within the
+ * parallel loop use a consistent index map. The function then continues
+ * standard expression visitation.
+ */
 void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
   if (op->buffer.scope() == "local.fragment") {
     if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
@@ -154,42 +167,91 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
   StmtExprVisitor::VisitExpr_(op);
 }
 
+/**
+ * @brief Construct a ParallelOpNode from a parallel loop nest root.
+ *
+ * Initializes the node with the given For loop as the root of the parallel
+ * operator and immediately runs the internal ParallelLoopNestVisitor to collect
+ * loop and buffer access information from the nested body.
+ *
+ * @param root The root For node representing the parallel loop nest to be
+ * analyzed.
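+ *
+ * Construction sketch (illustrative only; `loop` is an existing
+ * ForKind::kParallel For node):
+ * @code
+ * ParallelOp op(loop);  // the nested visitor walks the loop body immediately
+ * @endcode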
+ */
 ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
   V.VisitStmt(root);
 }
 
+/**
+ * @brief Create a copy of this ParallelOpNode wrapped as a TileOperator.
+ *
+ * Returns a new TileOperator that holds a deep copy of this ParallelOpNode.
+ *
+ * @return TileOperator A TileOperator owning a copy of this node.
+ */
 TileOperator ParallelOpNode::Clone() const {
   auto op = make_object<ParallelOpNode>(*this);
   return ParallelOp(op);
 }
 
+/**
+ * @brief No-op lowering: return the stored root statement unchanged.
+ *
+ * This implementation does not perform any transformation and returns the
+ * operator's original root For statement as-is.
+ *
+ * @param T Lowering arguments (unused).
+ * @return Stmt The original root statement held by this ParallelOpNode.
+ */
 Stmt ParallelOpNode::Lower(const LowerArgs &T,
                            arith::Analyzer *analyzer) const {
   return root_;
 }
 
+/**
+ * @brief Check whether a buffer is indexed by the loop's canonical (common)
+ * iteration variables.
+ *
+ * Returns true if the recorded index mapping for `buffer` is structurally equal
+ * to the sequence of loop iteration variables for this parallel op (i.e., the
+ * buffer is accessed using the common access indices of the loop nest).
+ *
+ * @param buffer The buffer to check.
+ * @return true if the buffer's index map equals the loop's iteration variables;
+ * false otherwise.
+ */
 bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const {
   auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
   return StructuralEqual()(indice_map_[buffer], common_indice);
 }
 
-/*! \brief Infer the layout for parallel operations based on different inference
- * levels
+/**
+ * @brief Infer buffer layouts for a Parallel operator based on the chosen
+ * inference level.
 *
- * The inference level controls how aggressively we try to infer and optimize
- * layouts:
- * - kStrict (2): Most conservative level. Only allows explicitly defined
- * layouts. Returns empty layout map if loop_layout_ is not already defined.
- * Used when exact layout control is required.
+ * Attempts to compute a consistent LayoutMap for buffers accessed by a parallel
+ * loop (root_) using explicit input layouts (T.layout_map), thread bounds
+ * (T.thread_bounds), and optional buffer remapping/vectorization information in
+ * T. Behavior depends on the supplied InferLevel:
+ * - kStrict: only accept pre-existing loop_layout_ (no inference).
+ * - kCommon: allow inference from explicit buffer fragments when available.
+ * - kFree: attempt more aggressive inference (derive loop partition from
+ * read/write fragments, plan partitioning from vectorization/thread bounds, and
+ * add predicates to constrain replication when necessary).
 *
- * - kCommon (1): Intermediate level between strict and free.
- * Allows common layout patterns while maintaining some
- * constraints.
+ * This method may mutate the node's internal state (sets loop_layout_ when
+ * inferred and registers predicates via AddPredicate) and consults analyzer_
+ * for symbolic proofs.
 *
- * - kFree (0): Most permissive level. Allows maximum optimization freedom.
- * Will attempt layout inference even without source buffers.
- * Can generate new layouts based on vectorization and thread
- * bounds. Used when maximum performance optimization is desired.
+ * @param T Container of auxiliary inputs used for inference (buffer_remap,
+ * layout_map, and thread_bounds). The function uses T.layout_map for source
+ * fragments and T.thread_bounds to bind thread-range information in inferred
+ * fragments.
+ * @param level Controls inference aggressiveness (kStrict, kCommon, kFree).
+ * @return LayoutMap A map of buffers to inferred Fragment layouts for buffers
+ * that did not already have layouts in T.layout_map. Returns an empty map when
+ * no inference was performed.
+ * @throws LayoutConflictException If a computed loop partition conflicts with
+ * an existing buffer fragment (incompatible thread mappings).
 */
 LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
                                       InferLevel level) const {
@@ -368,6 +430,20 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
   return results;
 }
 
+/**
+ * @brief Retrieve the loop's thread predicate with the thread variable
+ * substituted.
+ *
+ * If a predicate is set for this ParallelOpNode, returns a copy of that
+ * predicate where the placeholder input (InputPlaceholder(0)) is replaced by
+ * the provided thread_var. If no predicate is defined, returns an empty
+ * Optional.
+ *
+ * @param thread_var The thread loop variable to substitute for the predicate's
+ * input placeholder.
+ * @return Optional<PrimExpr> The substituted predicate expression, or
+ * std::nullopt if none is defined.
+ */
 Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
   if (predicate_.defined()) {
     return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}});
@@ -376,6 +452,32 @@ Optional<PrimExpr> ParallelOpNode::GetPredicate(Var thread_var) const {
   }
 }
 
+/**
+ * @brief Construct the complete fragment layout for a buffer within the
+ * parallel loop.
+ *
+ * Given a buffer referenced inside the parallel loop, return a Fragment that
+ * maps the buffer's logical indices to the loop's thread space and replication
+ * extent.
+ *
+ * Detailed behavior:
+ * - Precondition: a loop layout (loop_layout_) must be defined.
+ * - If the buffer uses the common access indices of the loop, the loop's
+ * fragment is returned directly.
+ * - Otherwise, the function:
+ *   - Computes the buffer's bijective index by appending the flattened
+ * replication expression for unused iterators.
+ *   - Inverts that bijection to obtain the replication extent of the buffer's
+ * index space and combines it with the loop's replication extent to produce the
+ * destination replication extent.
+ *   - Builds forward index placeholders for the buffer elements and maps them
+ * through the inverted layout and the loop layout to derive the thread binding.
+ *   - Returns a Fragment with the computed thread binding and combined
+ * replication extent, with replicate variables condensed.
+ *
+ * @return Fragment The completed fragment describing thread binding and
+ * replication extent for `buffer`.
+ */
 Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
   ICHECK(loop_layout_.defined());
   if (IsCommonAccessIndice(buffer)) {
diff --git a/src/op/parallel.h b/src/op/parallel.h
index fe514b43d..b03a669c1 100644
--- a/src/op/parallel.h
+++ b/src/op/parallel.h
@@ -12,6 +12,140 @@
 #include "../layout/layout.h"
 #include "operator.h"
 
+/**
+ * Exception representing a layout conflict detected during layout inference.
+ *
+ * Stores an explanatory message retrievable via what().
+ */
+
+/**
+ * Determine whether `small_frag` is guaranteed to be contained within
+ * `large_frag` under the given index mappings and using the provided arithmetic
+ * analyzer.
+ *
+ * @param small_frag The smaller fragment to test for containment.
+ * @param large_frag The larger fragment that may contain `small_frag`.
+ * @param small_frag_indices Index expressions mapping the small fragment into
+ * buffer space.
+ * @param large_frag_indices Index expressions mapping the large fragment into
+ * buffer space.
+ * @param analyzer_ Arithmetic analyzer used to simplify and prove index
+ * relations.
+ * @return true if containment can be proven; false otherwise.
+ */
+
+/**
+ * Visitor that traverses a parallel loop nest to collect buffer access and
+ * loop-structure information for a ParallelOpNode.
+ *
+ * The visitor records loop variables, buffer read/write accesses, and builds
+ * predicates as it encounters BufferLoad/BufferStore and For nodes.
+ */
+
+/**
+ * Represents a parallel for-loop operator in TileLang.
+ *
+ * Holds the root For loop, collects and exposes loop layout and access-index
+ * information, and provides layout inference and lowering to TIR.
+ *
+ * Public methods expose the inferred loop layout, root loop, buffer index
+ * mappings, and any per-thread predicate; Lower and InferLayout perform the
+ * operator's lowering and layout inference respectively.
+ */
+
+/**
+ * Create a ParallelOpNode from a root For loop.
+ *
+ * @param root The root For node representing the parallel loop nest.
+ */
+
+/**
+ * Lower this parallel operator into a TIR statement suitable for codegen.
+ *
+ * @param T Lowering arguments and context.
+ * @param analyzer Arithmetic analyzer for expression simplification during
+ * lowering.
+ * @return A TIR statement representing the lowered parallel loop.
+ */
+
+/**
+ * Infer the layout mapping for this parallel operator at the specified level.
+ *
+ * @param T Arguments and context for layout inference.
+ * @param level Inference granularity level.
+ * @return A LayoutMap describing inferred buffer/layout relationships for the
+ * operator.
+ */
+
+/**
+ * Copy-construct a ParallelOpNode, preserving inferred layout and predicate.
+ */
+
+/**
+ * Get the inferred loop layout fragment.
+ *
+ * @return The Fragment representing the loop's inferred layout (may be lazily
+ * computed).
+ */
+
+/**
+ * Get the root For loop of this operator.
+ *
+ * @return The root For AST node.
+ */
+
+/**
+ * Get the mapping from each buffer to the array of index expressions used to
+ * access it within the loop nest.
+ *
+ * @return A Map from Buffer to Array<PrimExpr> of access indices.
+ */
+
+/**
+ * Retrieve the predicate expression associated with a given thread variable, if
+ * any.
+ *
+ * @param thread_var The thread variable whose predicate is requested.
+ * @return An Optional<PrimExpr> containing the predicate when present.
+ */
+
+/**
+ * Create a deep copy of this operator as a TileOperator handle.
+ *
+ * @return A TileOperator that references a copy of this node.
+ */
+
+/**
+ * Visitor helper: complete the fragment layout for a buffer (internal).
+ *
+ * (Private helper; not part of the public API.)
+ */
+
+/**
+ * Helper to check whether a buffer's access indices are the common loop indices
+ * (internal).
+ *
+ * (Private helper; not part of the public API.)
+ */
+
+/**
+ * Add `expr` to the current predicate by logical AND; sets predicate if none
+ * exists.
+ *
+ * (Private helper; not part of the public API.)
+ */
+
+/**
+ * Thin handle type exposing ParallelOpNode as a TileOperator.
+ *
+ * Construct from a root For loop to create and own a ParallelOpNode instance.
+ */
+
+/**
+ * Construct a ParallelOp handle from a root For loop.
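+ *
+ * Typical use after construction (sketch; `op`, `T`, and `tx` are assumed to
+ * exist at the call site):
+ * @code
+ * op->InferLayout(T, InferLevel::kFree);          // may set layout/predicate
+ * Optional<PrimExpr> pred = op->GetPredicate(tx);
+ * @endcode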
+ *
+ * @param root The root For node representing the parallel loop nest.
+ */
 namespace tvm {
 namespace tl {
diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index 4fcf6c686..c1f9cd05c 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -21,6 +21,25 @@ namespace tl {
 using namespace tir;
 
+/**
+ * @brief Construct a ReduceOp from raw TL arguments and a buffer mapping.
+ *
+ * Interprets `args` and `vmap` to populate an internal ReduceOpNode:
+ * - args[0]: access pointer for the source buffer
+ * - args[1]: access pointer for the destination buffer
+ * - args[2]: string literal specifying the reduce type: "sum", "abssum",
+ * "absmax", "max", or "min"
+ * - args[3]: integer literal for the reduction dimension (axis)
+ * - args[4]: boolean literal indicating whether to clear/init the destination
+ *
+ * The constructor resolves the access pointers via `vmap`, maps the reduce
+ * type string to the ReduceType enum, assigns the reduction dimension and
+ * clear flag, and stores the constructed node in `data_`. An invalid reduce
+ * type triggers a fatal check.
+ *
+ * @param args Array of TL prim-expr arguments as described above.
+ * @param vmap Mapping from variables (from access pointers) to Buffer objects.
+ */
 ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<ReduceOpNode> node = make_object<ReduceOpNode>();
   node->src = vmap[GetVarFromAccessPtr(args[0])];
@@ -43,16 +62,52 @@ ReduceOp::ReduceOp(Array<PrimExpr> args, BufferMap vmap) {
   data_ = std::move(node);
 }
 
+/**
+ * @brief Create a copy of this ReduceOpNode wrapped as a TileOperator.
+ *
+ * Returns a new TileOperator holding a freshly allocated ReduceOpNode
+ * constructed as a copy of this node.
+ *
+ * @return TileOperator A tile operator that owns the cloned ReduceOpNode.
+ */
 TileOperator ReduceOpNode::Clone() const {
   auto op = make_object<ReduceOpNode>(*this);
   return ReduceOp(op);
 }
 
+/**
+ * @brief Create a deep copy of this CumSum op node wrapped as a TileOperator.
+ *
+ * Returns a new TileOperator whose underlying CumSumOpNode is a copy of
+ * the current node. Useful for cloning operators when building or
+ * transforming computation graphs.
+ *
+ * @return TileOperator A TileOperator containing a copy of this node.
+ */
 TileOperator CumSumOpNode::Clone() const {
   auto op = make_object<CumSumOpNode>(*this);
   return CumSumOp(op);
 }
 
+/**
+ * @brief Create the initial accumulator value for the destination buffer based
+ * on reduction type.
+ *
+ * Returns the PrimExpr representing the initial value stored in the destination
+ * accumulator before any source elements are combined. The returned value
+ * depends on the destination dtype and the node's reduction type:
+ * - kSum, kAbsSum: zero of the destination dtype.
+ * - kMax: minimum representable value for signed integers, zero for unsigned
+ * integers, and -INFINITY for floating point.
+ * - kMin: maximum representable value for signed integers, all-ones (max) for
+ * unsigned integers, and +INFINITY for floating point.
+ * - kAbsMax: zero of the destination dtype.
+ *
+ * The function will abort (ICHECK failure) if the reduction type is
+ * unrecognized.
+ *
+ * @return PrimExpr initial value appropriate for `dst->dtype` and `type`.
+ */
 PrimExpr ReduceOpNode::MakeInitValue() const {
   auto dst_dtype = dst->dtype;
   auto is_int = dst_dtype.is_int();
@@ -87,6 +142,24 @@ PrimExpr ReduceOpNode::MakeInitValue() const {
   }
 }
 
+/**
+ * @brief Combine two scalar expressions according to this node's reduction
+ * type.
+ *
+ * Casts the right operand to the left operand's dtype if they differ, then
+ * returns the reduction of `a` and `b` using the operator specified by `type`:
+ * - kSum: `a + b`
+ * - kAbsSum: `a + max(b, -b)`
+ * - kMax: `max(a, b)`
+ * - kMin: `min(a, b)`
+ * - kAbsMax: `max(max(a, b), -min(a, b))`
+ *
+ * @param a Left-hand operand (result dtype drives the output dtype).
+ * @param b Right-hand operand (will be cast to `a`'s dtype if needed).
+ * @return PrimExpr The combined expression with dtype equal to `a.dtype`.
+ *
+ * @note The function ICHECK-fails on an unknown/unsupported reduction type.
+ */
 PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
   PrimExpr lhs = a, rhs = b;
   if (lhs->dtype != rhs->dtype) {
@@ -109,6 +182,20 @@ PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const {
   }
 }
 
+/**
+ * @brief Map the reduction type to the codegen reducer name used by external
+ * AllReduce/CUDA helpers.
+ *
+ * Returns the string identifier of the code-generation reducer corresponding to
+ * this ReduceOpNode's `type`. Mapping:
+ * - kSum, kAbsSum -> "tl::SumOp"
+ * - kMax, kAbsMax -> "tl::MaxOp"
+ * - kMin -> "tl::MinOp"
+ *
+ * The function terminates with a check failure if `type` is unknown.
+ *
+ * @return std::string Reducer name used by codegen extern calls.
+ */
 std::string ReduceOpNode::MakeCodegenReducer() const {
   switch (type) {
   case ReduceType::kSum:
@@ -127,6 +214,32 @@ std::string ReduceOpNode::MakeCodegenReducer() const {
   }
 }
 
+/**
+ * @brief Lower the Reduce operator node to a TIR statement.
+ *
+ * Lowers a ReduceOpNode that targets fragment-local buffers into a sequence of
+ * TIR statements implementing: per-thread local reduction, inter-thread
+ * AllReduce (when needed), and final writeback (with an optional duplicate
+ * clear buffer to avoid in-place conflicts). Supports reduction kinds
+ * (sum/abs-sum/max/min/abs-max) and handles layout-driven index mapping and
+ * loop partitioning to thread axes.
+ *
+ * @param T Lowering context providing buffer remapping, layout map, target and
+ * thread bounds, and workspace allocation helper. Must contain
+ * fragment-local mappings for both src and dst.
+ * @param analyzer Symbolic analyzer used to simplify and compress iterators.
+ * @return Stmt The constructed TIR statement implementing the reduction.
+ *
+ * Preconditions:
+ * - src and dst buffers must be in "local.fragment" scope.
+ * - The layouts must have compatible input/output dimensions for the
+ * specified reduction axis.
+ *
+ * Failure modes:
+ * - The function uses ICHECK to enforce invariants (supported scopes, matching
+ * dimensions, known reduction types, and so on); violations trigger a fatal
+ * check failure.
+ */
 Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   ICHECK(this->src.scope() == "local.fragment" &&
          this->dst.scope() == "local.fragment")
@@ -296,6 +409,38 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   return body;
 }
 
+/**
+ * @brief Infer a layout mapping for the destination buffer of a Reduce
+ * operator.
+ *
+ * When inference level is below `kStrict`, and both source and destination
+ * buffers live in `local.fragment` with a known source fragment layout, this
+ * computes a candidate destination Fragment layout that accounts for
+ * replication over the reduction dimension and binds thread ranges from
+ * `T.thread_bounds`.
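+ *
+ * Worked example (numbers assumed purely for illustration): for a src of shape
+ * [64, 128] with dim == 1 and a source fragment whose ReplicateExtent() is 1,
+ * the candidate dst fragment gets replicate extent 128 * 1 == 128, i.e. the
+ * reduced dimension is folded into replication.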
+ *
+ * Behavior:
+ * - Constructs a destination Fragment whose replicate extent equals
+ * src.shape[dim] * src_fragment.ReplicateExtent(), and whose threading is
+ * derived from the source fragment with the reduction dimension folded out.
+ * - If no layout exists for `dst` in `T.layout_map`, returns a map {dst ->
+ * inferred}.
+ * - If `dst` already has a layout, validates that the existing layout strictly
+ * contains the computed layout (shapes match and fragment containment holds).
+ * If compatible but the computed replicate extent is larger, returns the new
+ * layout.
+ * - In all other cases (strict inference level, unsupported scopes, or no src
+ * layout), returns an empty map.
+ *
+ * @param T Layout inference context containing `layout_map` and
+ * `thread_bounds`.
+ * @param level Inference strictness; no inference is performed at or above
+ * `kStrict`.
+ * @return LayoutMap A mapping for `dst` to an inferred Fragment layout, or
+ * empty.
+ * @throws LayoutConflictException if an existing `dst` layout conflicts with
+ * the computed layout (not containable or incompatible replication extents).
+ */
 LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
                                     InferLevel level) const {
   if (level >= InferLevel::kStrict)
@@ -373,6 +518,22 @@ TIR_REGISTER_TL_OP(ReduceOp, reduce)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+/**
+ * @brief Construct a CumSumOp from a list of arguments and a buffer map.
+ *
+ * Expects args to contain exactly four PrimExprs in this order:
+ *  0: access pointer to source buffer (src),
+ *  1: access pointer to destination buffer (dst),
+ *  2: integer dimension to perform the cumulative sum along (dim),
+ *  3: boolean flag indicating whether to compute the cumsum in reverse
+ * (reverse).
+ *
+ * The constructor resolves src and dst from the provided BufferMap and stores
+ * the parsed dim and reverse values on the node. It verifies that args.size()
+ * == 4 and that dim is a valid axis for the source buffer shape.
+ *
+ * @param args Array of PrimExpr as described above.
+ */
 CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
   /*
   CumSum arguments:
@@ -391,6 +552,28 @@ CumSumOp::CumSumOp(Array<PrimExpr> args, BufferMap vmap) {
   data_ = std::move(node);
 }
 
+/**
+ * @brief Lower the CumSum operator to TIR.
+ *
+ * Produces a TIR statement implementing cumulative sum depending on buffer
+ * scopes:
+ * - For shared/shared.dyn scopes: emits an extern call to
+ * `tl::CumSum2D::run` with arguments [function_name,
+ * src.access_ptr(1), dst.access_ptr(3), src.shape...]. The number of threads is
+ * taken from `T.thread_bounds->extent`. Returns an Evaluate(Call(...))
+ * statement.
+ * - For local.fragment scopes on both src and dst: fatal error (not
+ * implemented).
+ * - For any other scope combinations: fails with an assertion.
+ *
+ * The `analyzer` parameter is accepted for interface compatibility but is not
+ * used by this lowering.
+ *
+ * @param T Lowering arguments (provides thread bounds and other lowering
+ * context).
+ * @return Stmt A TIR statement representing the lowered cumulative-sum
+ * operation.
+ */
 Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   if (this->src.scope() == "local.fragment" &&
       this->dst.scope() == "local.fragment") {
@@ -417,6 +600,17 @@ Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   return Stmt();
 }
 
+/**
+ * @brief Layout inference for CumSum operator.
+ *
+ * CumSum does not perform any layout inference; this function always returns
+ * an empty mapping. The operator's lowering expects shared-memory semantics
+ * and layout decisions are handled elsewhere.
+ *
+ * @param T Layout inference inputs (buffers, existing layouts, etc.).
+ * @param level Inference strictness level (unused).
+ * @return LayoutMap Empty map indicating no inferred layouts.
+ */
 LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
                                     InferLevel level) const {
   return {};
diff --git a/src/op/reduce.h b/src/op/reduce.h
index 2be74cf09..c78ac23d8 100644
--- a/src/op/reduce.h
+++ b/src/op/reduce.h
@@ -10,6 +10,146 @@
 #include "operator.h"
 
 namespace tvm {
+/**
+ * Tile operator node that performs a reduction (sum, max, min, etc.) along a
+ * single tensor dimension.
+ *
+ * Represents a per-instance reduce operator with explicit source/destination
+ * buffers, target dimension, reduction type, and a flag controlling whether the
+ * destination is cleared before reduction.
+ */
+
+/**
+ * Lower this ReduceOpNode into a TIR Stmt suitable for code generation.
+ *
+ * Produces the TIR statement(s) that implement the configured reduction.
+ *
+ * @return A TIR `Stmt` implementing the reduce operation.
+ */
+
+/**
+ * Infer input/output layouts for this reduce operator.
+ *
+ * Returns a LayoutMap describing how input and output buffer layouts relate
+ * for the configured reduction dimension.
+ *
+ * @param level Inference detail level that may affect how aggressively layouts
+ * are inferred.
+ * @return A LayoutMap mapping operator arguments to inferred layouts.
+ */
+
+/**
+ * Retrieve the global operator descriptor for the reduce operator.
+ *
+ * @return A reference to the Op descriptor corresponding to this operator type.
+ */
+
+/**
+ * Create a copy of this reduce operator as a TileOperator handle.
+ *
+ * The returned TileOperator preserves the node's configuration (buffers, dim,
+ * type, clear).
+ *
+ * @return A TileOperator wrapping a cloned ReduceOpNode.
+ */
+
+/**
+ * Construct the initial value used by the reduction (e.g., 0 for sum, -inf for
+ * max).
+ *
+ * @return A PrimExpr representing the reduction's identity/init value.
+ */
+
+/**
+ * Combine two partial values according to the configured reduction.
+ *
+ * Implements the binary reducer (for example, `a + b` for sum or `max(a, b)`
+ * for max).
+ *
+ * @return A PrimExpr representing the reduced result of `a` and `b`.
+ */
+
+/**
+ * Generate a string snippet suitable for code generation of the reducer
+ * expression.
+ *
+ * The returned code fragment should implement the binary reduction operation in
+ * the target backend's code string form.
+ *
+ * @return A std::string containing the codegen expression for the reducer.
+ */
+
+/**
+ * Reference wrapper for ReduceOpNode as a TileOperator.
+ *
+ * Construct a ReduceOp from explicit arguments and a buffer map.
+ */
+
+/**
+ * Construct a ReduceOp TileOperator from operator arguments and a buffer
+ * mapping.
+ *
+ * @param args Operator arguments (typically shapes, axes, or other prim exprs).
+ * @param vmap Mapping from argument names to tir::Buffer instances used by the
+ * operator.
+ */
+
+/**
+ * Tile operator node that computes a cumulative sum along a single tensor
+ * dimension.
+ *
+ * Contains source/destination buffers, the target dimension, and a flag to
+ * compute the cumulative sum in reverse order.
+ */
+
+/**
+ * Lower this CumSumOpNode into a TIR Stmt suitable for code generation.
+ *
+ * Produces the TIR statement(s) that implement the configured cumulative sum.
+ *
+ * @return A TIR `Stmt` implementing the cum-sum operation.
+ */
+
+/**
+ * Infer input/output layouts for this cumulative-sum operator.
+ *
+ * Returns a LayoutMap describing how input and output buffer layouts relate
+ * for the configured cumulative-sum dimension.
+ *
+ * @param level Inference detail level that may affect how aggressively layouts
+ * are inferred.
+ * @return A LayoutMap mapping operator arguments to inferred layouts.
+ */
+
+/**
+ * Retrieve the global operator descriptor for the cumulative-sum operator.
+ *
+ * @return A reference to the Op descriptor corresponding to this operator type.
+ */
+
+/**
+ * Create a copy of this cum-sum operator as a TileOperator handle.
+ *
+ * The returned TileOperator preserves the node's configuration (buffers, dim,
+ * reverse).
+ *
+ * @return A TileOperator wrapping a cloned CumSumOpNode.
+ */
+
+/**
+ * Reference wrapper for CumSumOpNode as a TileOperator.
+ *
+ * Construct a CumSumOp from explicit arguments and a buffer map.
+ */
+
+/**
+ * Construct a CumSumOp TileOperator from operator arguments and a buffer
+ * mapping.
+ *
+ * @param args Operator arguments (typically shapes, axes, or other prim exprs).
+ * @param vmap Mapping from argument names to tir::Buffer instances used by the
+ * operator.
+ */
 namespace tl {
 
 using namespace tir;
diff --git a/src/op/region.cc b/src/op/region.cc
index 0b74ab00f..95a0b4295 100644
--- a/src/op/region.cc
+++ b/src/op/region.cc
@@ -11,6 +11,26 @@ namespace tvm {
 namespace tl {
 using namespace tir;
 
+/**
+ * @brief Construct a RegionOp from TL operator arguments.
+ *
+ * Parses the TL `region` operator call arguments to populate the RegionOpNode:
+ * - Expects args[0] to be a `BufferLoad` whose `indices` are the per-dimension
+ * minima.
+ * - args[1] must be a constant integer used as the access mask.
+ * - args[2 + i] provides the extent for dimension `i`.
+ *
+ * The constructor validates that the number of load indices equals `args.size()
+ * - 2` and will abort via ICHECK on mismatch or if args[0] is not a
+ * `BufferLoad`.
+ *
+ * Parameters:
+ * - args: TL operator call arguments in the form
+ *   [BufferLoad(min_i...), access_mask, extent_0, extent_1, ...,
+ *   extent_{n-1}] where n = number of dimensions.
+ * - vmap: BufferMap passed through by the caller (not documented here as a
+ *   generic utility).
+ */
 RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
   size_t n = args.size();
   size_t ndim = n - 2;
@@ -31,11 +51,26 @@ RegionOp::RegionOp(Array<PrimExpr> args, BufferMap vmap) {
   data_ = std::move(node);
 }
 
+/**
+ * @brief Create a copy of this RegionOpNode and return it as a TileOperator.
+ *
+ * @return TileOperator A new TileOperator that owns a copied RegionOpNode.
+ */
 TileOperator RegionOpNode::Clone() const {
   auto op = make_object<RegionOpNode>(*this);
   return RegionOp(op);
 }
 
+/**
+ * @brief Check whether the region spans the entire underlying buffer.
+ *
+ * Returns true if for every dimension the range minimum is zero and the
+ * range extent is structurally equal to the corresponding buffer shape
+ * dimension. Otherwise returns false.
+ *
+ * @return true if the region covers the full buffer in all dimensions; false
+ * otherwise.
+ */
 bool RegionOpNode::IsFullRegion() const {
   for (size_t i = 0; i < ranges_.size(); i++) {
     if (!is_zero(ranges_[i]->min))
@@ -46,10 +81,33 @@ bool RegionOpNode::IsFullRegion() const {
   return true;
 }
 
+/**
+ * @brief Lower the region operator to a TIR statement.
+ *
+ * Lowers this RegionOpNode into a TIR Stmt by delegating to the operator's
+ * evaluation path (currently `Evaluate(0)`).
+ *
+ * @param T Lowering context (provides buffers, producers/consumers and other
+ * environment required for lowering).
+ * @param analyzer Optional arithmetic analyzer used for simplification during
+ * lowering.
+ * @return Stmt The lowered TIR statement representing this region operation.
+ */
 Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   return Evaluate(0);
 }
 
+/**
+ * @brief Infers data layout for the region operator.
+ *
+ * This operator does not provide any layout inference; the function always
+ * returns an empty LayoutMap regardless of the provided arguments or inference
+ * level.
+ *
+ * @param T Layout inference arguments (ignored).
+ * @param level Inference granularity level (ignored).
+ * @return LayoutMap Empty map indicating no inferred layouts.
+ */
 LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
                                     InferLevel level) const {
   return {};
diff --git a/src/op/region.h b/src/op/region.h
index 1d56ea47b..2e20216ca 100644
--- a/src/op/region.h
+++ b/src/op/region.h
@@ -13,6 +13,62 @@
 #include
 #include
 
+/**
+ * Tile operator representing a memory region (buffer + ranges) used by TL
+ * passes.
+ *
+ * Encapsulates the target tir::Buffer, the region extents as an Array<Range>,
+ * and an access mask that indicates permitted or intended accesses for lowering
+ * and layout inference.
+ */
+
+/**
+ * Lower this RegionOp into a TIR statement representing the region access.
+ *
+ * @param T Lowering-time arguments (e.g., loop/build context and value
+ * mappings).
+ * @param analyzer Arithmetic analyzer used to simplify and reason about
+ * expressions.
+ * @return A tir::Stmt that implements the region access/mutation described by
+ * this operator.
+ */
+
+/**
+ * Infer the layout mapping for this region operator.
+ *
+ * Produces a LayoutMap describing how loop/axis indices map to buffer axes for
+ * layout-aware scheduling and subsequent operators.
+ *
+ * @param T Layout inference arguments (e.g., input layouts and shapes).
+ * @param level The inference detail level to use.
+ * @return A LayoutMap describing inferred mappings for the operator.
+ */
+
+/**
+ * Return true when this RegionOp represents the full buffer region (i.e.,
+ * ranges cover the entire buffer extent).
+ */
+
+/**
+ * Create a shallow copy of this operator as a TileOperator handle.
+ *
+ * @return A TileOperator that references a cloned RegionOpNode.
+ */
+
+/**
+ * Construct a RegionOp from argument expressions and a buffer map.
+ *
+ * @param args Positional expressions used to instantiate the operator
+ * (semantics depend on how RegionOp is invoked in TL pipelines).
+ * @param vmap Mapping from Buffer to replacement Buffer or buffer metadata used
+ * during creation.
+ */
+
+/**
+ * Return the global Op registration for RegionOp.
+ *
+ * @return Reference to the registered tvm::Op describing the RegionOp.
+ */
 namespace tvm {
 namespace tl {
diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc
index 5654044c1..6593a8212 100644
--- a/src/transform/layout_inference.cc
+++ b/src/transform/layout_inference.cc
@@ -63,6 +63,37 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
   BufferUseDefCollector(bool skip_thread_partition)
       : skip_thread_partition_(skip_thread_partition) {}
 
+  /**
+   * @brief Execute a single layout-inference step for the infer node at the
+   * given index.
+   *
+   * Runs InferLayout on the TileOperator at cur_infer_id with the provided
+   * InferLevel and thread bounds, applies returned buffer->layout updates into
+   * layout_map (respecting strict_layout_map constraints for fragment buffers),
+   * and optionally propagates changes to dependent infer nodes by enqueueing
+   * them into q and marking in_queue.
+   *
+   * The function mutates layout_map and, when update_queue is true, may modify
+   * q and in_queue. It performs internal sanity checks via ICHECK and will
+   * LOG(WARNING) for buffers that cannot be propagated; ICHECK failures abort
+   * execution.
+   *
+   * @param cur_infer_id Index of the infer operator in infer_list_ to run (must
+   * be within range).
+   * @param level Inference relaxation level to pass to the operator's
+   * InferLayout.
+   * @param update_queue If true, discovered layout changes will enqueue
+   * dependent infer nodes.
+   * @param layout_map Mutable map of inferred layouts that will be updated with
+   * returned layouts.
+   * @param strict_layout_map Read-only map of layouts produced in the strict
+   * phase; used to enforce containment checks for local.fragment buffers when
+   * relaxing.
+   * @param q BFS queue used to propagate dependent inference indices; new
+   * indices may be pushed.
+   * @param in_queue Parallel boolean vector tracking queued status; entries
+   * corresponding to enqueued indices will be set to true.
+   */
  void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
                    LayoutMap &layout_map, const LayoutMap &strict_layout_map,
                    std::queue<int> &q, std::vector<bool> &in_queue) {
@@ -189,6 +220,30 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
     }
   };
 
+  /**
+   * @brief Run the multi-stage layout inference and return the collected
+   * results.
+   *
+   * Performs layout inference over the collected TileOperator entries in three
+   * phases: (1) strict per-operator inference, (2) common inference via a BFS
+   * propagation queue, and (3) a free-mode relaxation phase that explores
+   * alternative root orderings within connected components to reduce register
+   * footprint. After inference completes, verifies that all local.fragment
+   * buffers have inferred layouts and collects loop (For) -> Fragment layouts
+   * and any per-loop predicates discovered during inference.
+   *
+   * The method consumes/permutes internal inference state (notably moves
+   * entries out of infer_list_) and returns a LayoutInferenceResult containing:
+   * - layout_map: inferred Layout for each Buffer,
+   * - for_map: mapping from For nodes to their inferred Fragment layout,
+   * - predicate_map: optional loop predicates keyed by For nodes.
+   *
+   * The function performs internal consistency checks (ICHECK) on sizes and
+   * required definitions; violations will terminate via ICHECK failure.
+   *
+   * @return LayoutInferenceResult A tuple-like struct with the inferred
+   * layout_map, for_map, and predicate_map.
+   */
   LayoutInferenceResult Run() {
     // Basic consistency check: infer_list_ and thread_var_vec_ should have the
     // same size
@@ -292,6 +347,23 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
   }
 
 private:
+  /**
+   * @brief Visits a Call expression to collect tile-operator-based inference
+   * inputs.
+   *
+   * Processes non-global function calls by parsing them into a TileOperator
+   * (via ParseOperator). If the parse succeeds, records:
+   * - buffers referenced by call arguments into the collector's use lists,
+   * - the call AST node into infer_list_stmt_,
+   * - the parsed TileOperator into infer_list_,
+   * - the current thread IterVar into thread_var_vec_,
+   * - the thread iteration bounds into thread_bounds_vec_ (uses analyzer const
+   * bounds when available; otherwise [0,1]).
+   *
+   * Calls to global functions (where op->op is a GlobalVar) are ignored.
+   *
+   * @param op The Call node being visited.
+   */
  void VisitExpr_(const CallNode *op) final {
     IRVisitorWithAnalyzer::VisitExpr_(op);
     // Do not analysis the call node to the global function.
@@ -344,6 +416,25 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
     use_list_[buffer].push_back(infer_idx);
   }
 
+  /**
+   * @brief Handles For nodes during IR traversal.
+   *
+   * When the loop is a parallel loop (ForKind::kParallel), records it as a
+   * ParallelOp:
+   * - constructs a ParallelOp for the loop and appends it to the internal infer
+   * lists (infer_list_ and infer_list_stmt_),
+   * - registers all buffers referenced by the loop indices with use-list
+   * bookkeeping,
+   * - captures the current thread IterVar context and its compile-time extent
+   * (if available) into thread_var_vec_ and thread_bounds_vec_ (falls back to
+   * range [0,1] when unknown).
+   *
+   * For non-parallel loops, continues recursive traversal into the loop body.
+   *
+   * Side effects:
+   * - Mutates infer_list_, infer_list_stmt_, use_list_ (via addToUseList),
+   * thread_var_vec_, and thread_bounds_vec_.
+   */
   void VisitStmt_(const ForNode *op) final {
     if (op->kind == ForKind::kParallel) {
       auto infer = ParallelOp(GetRef<For>(op));
@@ -414,6 +505,15 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
   LayoutMap annotated_layout_map_;
   bool skip_thread_partition_{false};
 
+  /**
+   * @brief Create a deep copy of the current inference operator list.
+   *
+   * Returns a vector containing clones of each TileOperator in the collector's
+   * internal infer_list_. The returned list is independent of the original so
+   * subsequent modifications to either do not affect the other.
+   *
+   * @return std::vector<TileOperator> Cloned copy of infer_list_.
+   */
   std::vector<TileOperator> BackupInferList() {
     std::vector<TileOperator> back_infer_list;
     back_infer_list.reserve(infer_list_.size());
@@ -423,6 +523,48 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
     return back_infer_list;
   }
 
+  /**
+   * @brief Explore alternative inference orders within connected components to
+   * relax layouts.
+   *
+   * This function performs a "free-mode" exploration that attempts different
+   * root operators within each connected component of the operator-use graph in
+   * order to find a layout assignment with lower register (fragment) usage.
+   *
+   * Detailed behavior:
+   * - Builds connected components of infer_list_ by unioning operators that
+   * share buffer uses (use_list_).
+   * - For each component, iterates each member operator as a candidate root:
+   *   - Backs up the current infer_list_ and uses a temporary copy of
+   * layout_map.
+   *   - Runs RunInferStep and FinishInferQueue in InferLevel::kFree starting
+   * from the candidate root and then (as a fallback) runs the remaining members
+   * to try to cover the whole component.
+   *   - If inference succeeds, computes a coarse register usage metric by
+   * summing the product of OutputShape dimensions for all Fragment layouts
+   * in the temporary layout map.
+   *   - Tracks the candidate that yields the smallest register usage.
+   *   - If a better plan is found for a component, replaces the global
+   * infer_list_ and updates layout_map with the best layout_map found.
+   *
+   * Side effects:
+   * - Mutates layout_map to the best-found free-mode layout assignment when a
+   * better plan is discovered.
+   * - Mutates the member infer_list_ (backed up and restored during attempts;
+   * finally set to the best plan if found).
+   *
+   * Notes:
+   * - LayoutConflictException and NormalizeIterException raised during attempts
+   * are caught and treated as failed attempts; they do not propagate out of
+   * this function.
+   * - The register-usage metric is a heuristic (sum of fragment output element
+   * counts) used to prefer less-replicated layouts.
+   *
+   * @param[in,out] layout_map The current global layout map, updated with a
+   * better free-mode result if one is found.
+   * @param strict_layout_map Read-only map of layouts inferred in strict mode,
+   * used to constrain free-mode inference.
+   */
  void InferInFreeMode(LayoutMap &layout_map,
                       const LayoutMap &strict_layout_map) {
     // Group operators into connected components
diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc
index 0643eff5e..25e3a70f5 100644
--- a/src/transform/lower_tile_op.cc
+++ b/src/transform/lower_tile_op.cc
@@ -464,6 +490,32 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer {
     return var;
   }
 
+  /**
+   * @brief Handle an Evaluate node, lowering a detected tile operator to TIR.
+   *
+   * This visit implementation detects whether the Evaluate node represents a
+   * tile operator invocation (via ParseOperator). If no tile operator is found
+   * or the call targets a global function, the node is delegated to the base
+   * visitor.
+   *
+   * When a tile operator is present, the method:
+   * - Builds a workspace-allocation callback that creates a dynamic shared
+   * buffer named "workspace" (storage scope "shared.dyn") and returns its write
+   * access pointer.
+   * - Determines thread bounds for lowering from the analyzer's constant-int
+   * information for thread_var_; if unavailable, a default range [0,1) is
+   * used.
+   * - Invokes tile_op->Lower(...) with LowerArgs containing target, thread
+   * bounds, thread variable, the workspace callback, layout and buffer remap
+   * maps, and the list of GEMM-involved buffer vars; the analyzer is passed
+   * through for use during lowering.
+   *
+   * The lowered statement returned by the operator is then visited by the base
+   * IRMutatorWithAnalyzer and that result is returned.
+   *
+   * @return Stmt The (possibly transformed) statement after lowering or base
+   * visitor processing.
+   */
   Stmt VisitStmt_(const EvaluateNode *op) final {
     // LOG(INFO) << "evaluate node: " << op->value;
     const CallNode *call = op->value.as<CallNode>();