From 5992f59f56e4d0a75bb0f755205620eeb726146f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 14:18:38 +0800 Subject: [PATCH 01/13] [Feature] Implement atomic reduction operations and enhance atomic add functionality - Added support for atomic max and min operations with corresponding classes and methods. - Introduced vectorization capabilities for atomic add operations based on data type and target architecture. - Refactored atomic add intrinsic calls to improve clarity and consistency. - Enhanced layout inference for atomic operations and integrated new utility functions for target architecture checks. - Updated tests to validate the new atomic operations and their vectorization behavior. --- src/op/atomic_add.cc | 294 ++------------ src/op/atomic_add.h | 59 ++- src/op/atomic_reduce.cc | 287 +++++++++++++ src/op/atomic_reduce.h | 145 +++++++ src/op/builtin.cc | 47 ++- src/op/builtin.h | 72 +++- src/op/parallel.cc | 14 +- src/target/codegen_cuda.cc | 89 +++++ src/target/utils.cc | 7 + src/target/utils.h | 1 + src/transform/atomicadd_vectorize.cc | 376 ++++++------------ src/transform/atomicadd_vectorize.h | 58 +-- src/transform/loop_partition.cc | 6 +- src/transform/loop_vectorize.cc | 41 +- ...dd.py => test_tilelang_language_atomic.py} | 82 +++- tilelang/language/atomic.py | 130 ++++-- 16 files changed, 1062 insertions(+), 646 deletions(-) create mode 100644 src/op/atomic_reduce.cc create mode 100644 src/op/atomic_reduce.h rename testing/python/language/{test_tilelang_language_atomic_add.py => test_tilelang_language_atomic.py} (83%) diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 538f59fa9..b34e71425 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -15,7 +15,6 @@ #include "../target/utils.h" #include "../transform/atomicadd_vectorize.h" #include "../transform/common/loop_fusion_utils.h" -#include "../transform/common/loop_parallel_transform_utils.h" #include "../transform/loop_partition.h" #include "builtin.h" @@ -75,73 +74,32 @@ TileOperator AtomicAddNode::Clone() const { return AtomicAdd(op); } -/** - * @brief Create data-parallel iteration variables for non-singleton dimensions - * of the source. - * - * Constructs an Array of IterVar corresponding to each dimension in `src_range` - * whose extent is not equal to 1. Each IterVar has domain Range(0, extent), a - * Var named sequentially ("i", "j", "k", ...) with the same dtype as the - * extent, and type IterVarType::kDataPar. The ordering of returned itervars - * matches the order of dimensions in `src_range`. - * - * @return Array Iteration variables for all non-singleton extents in - * `src_range`. - */ -Array AtomicAddNode::MakeIterVars() const { - Array loop_vars; - size_t idx = 0; - for (size_t i = 0; i < src_range.size(); i++) { - if (is_one(src_range[i]->extent)) - continue; - Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype); - idx++; - loop_vars.push_back( - {Range(0, src_range[i]->extent), var, IterVarType::kDataPar}); - } - return loop_vars; -} +const Op &AtomicAddNode::GetElemOpStatic() { return atomic_add_elem_op(); } -// ivs: itervars returned by MakeIterVars() /** - * @brief Build index expressions for either source or destination from loop - * iter vars. + * @brief Get vectorization length based on dst dtype and target SM version. 
* - * Given a list of iteration variables that correspond to the non-singleton - * extents of the selected region (source when src_dst == 0, destination when - * src_dst == 1), return an array of index expressions matching the full rank of - * that region. For dimensions with extent == 1, the corresponding index is the - * range's minimum; otherwise the index is `min + ivar`. + * Returns: + * - 2 for float16/bfloat16 + * - 4 for float32 on SM >= 90 + * - 1 for all other cases * - * @param ivs Iteration variables in order for all non-singleton dimensions of - * the chosen region. - * @param src_dst Selects which region to index: 0 for source (src_range), 1 for - * destination (dst_range). - * @return Array Index expressions for every dimension of the selected - * region, in original dimension order. - * - * @note The function checks that the number of provided iter vars equals the - * number of non-singleton extents; it will abort (ICHECK) if they differ. + * @param target The target architecture to check SM version. + * @return int The vectorization length. */ -Array AtomicAddNode::MakeIndices(const Array &ivs, - int src_dst) const { - Array indices; - Array ranges = src_dst == 0 ? src_range : dst_range; - size_t idx = 0; - for (size_t i = 0; i < ranges.size(); i++) { - if (is_one(ranges[i]->extent)) - indices.push_back(ranges[i]->min); - else { - indices.push_back(ranges[i]->min + ivs[idx]->var); - idx++; - } +int AtomicAddNode::GetVectorizeLength(Target target) const { + DataType dtype = dst->dtype; + if (dtype.is_float16() || dtype.is_bfloat16()) { + return 2; + } + if (dtype.is_float() && dtype.bits() == 32 && + TargetHasSMVersionGE(target, 90)) { + return 4; } - ICHECK(idx == ivs.size()) - << "idx = " << idx << ", ivs.size() = " << ivs.size() - << "src name = " << src->name << ", dst name = " << dst->name; - return indices; + return 1; } + std::pair, PrimExpr> AtomicAddNode::ReturnIndicesAndSize(int src_dst) const { Array indices; @@ -154,61 +112,6 @@ AtomicAddNode::ReturnIndicesAndSize(int src_dst) const { return {indices, size}; } -/** - * @brief Build a combined bound-check predicate for indexed access. - * - * Constructs an AND'd predicate ensuring each non-singleton index (derived from - * `ivs`) stays within [0, extent) for the selected operand (source when - * `src_dst==0`, destination otherwise). For each non-unit Range in the chosen - * range list this produces two conditions: - * - range.min + iv >= 0 - * - range.min + iv < extent - * - * Conditions that the analyzer can prove (with symbolic bounds) are omitted. - * If no uncertain conditions remain, an empty PrimExpr is returned. - * - * Note: the function ICHECKs that `extents.size()` equals the number of ranges - * for the selected operand. - * - * @param ivs Iteration variables corresponding to non-singleton extents (order - * matches the non-unit ranges of the chosen operand). - * @param extents Per-dimension upper bounds to check against; must have the - * same size as the selected range list. - * @param src_dst Selects which ranges to validate: 0 => `src_range`, else - * `dst_range`. - * @return PrimExpr A conjunction of remaining (non-provable) bounds checks, or - * an empty PrimExpr when no checks are required. - */ -PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer, - const Array &ivs, - Array extents, - int src_dst) const { - Array ranges = src_dst == 0 ? 
src_range : dst_range; - Array cond_list; - ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; - size_t idx = 0; - for (size_t i = 0; i < ranges.size(); i++) { - if (is_one(ranges[i]->extent)) - continue; - PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i]; - if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) { - cond_list.push_back(cond); - } - cond = ranges[i]->min + ivs[idx]->var >= 0; - if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) { - cond_list.push_back(cond); - } - idx++; - } - if (cond_list.empty()) - return {}; - else { - PrimExpr cond = cond_list[0]; - for (size_t i = 1; i < cond_list.size(); i++) - cond = And(cond, cond_list[i]); - return cond; - } -} /** * @brief Build a SIMT-style loop nest that performs element-wise atomic @@ -226,8 +129,9 @@ PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer, * - Validates loop variable counts against src/dst ranges (ICHECK on mismatch). * - Computes indexed accesses and emits optional bound predicates; * out-of-bounds accesses are masked to zero when predicates are uncertain. - * - Emits an extern `call_extern("AtomicAdd", address_of(dst_value), - * src_value)` call wrapped in an Evaluate statement. + * - Emits an extern `call_intrin(op.Op.get("tl.atomic_add_elem_op"), + * address_of(dst_value), src_value), annotations)` call wrapped in an Evaluate + * statement. * - Wraps the body with a parallel For at each loop level. If `coalesced_width` * is defined it is attached as the "coalesced_width" annotation on each loop. * @@ -285,7 +189,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { auto annotations = this->annotations; annotations.erase("use_tma"); Call atomicadd_call = - tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args, annotations); + tvm::tir::Call(dst->dtype, atomic_add_elem_op(), new_args, annotations); Stmt body = tvm::tir::Evaluate(atomicadd_call); @@ -658,142 +562,24 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } auto simt_loop = MakeSIMTLoop(analyzer); auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); - auto transformed_loop = - Downcast(ParallelLoopTransformer::Substitute(fused_loop)); - - auto GetArchInt = [&](const Target &tgt) -> int { - int arch_int = 0; - if (auto s = tgt->GetAttr("arch")) { - std::string arch = s.value(); - if (arch.rfind("sm_", 0) == 0) - arch_int = std::stoi(arch.substr(3)); - } - return arch_int; - }; - - struct AtomicLoopNestCollector : tir::StmtExprVisitor { - Array loop_vars; - Map> indice_map; - std::unordered_set writes; - arith::Analyzer analyzer; - - void Run(const Stmt &s) { StmtExprVisitor::VisitStmt(s); } - - void VisitStmt_(const ForNode *op) final { - if (op->kind == ForKind::kParallel) { - loop_vars.push_back(IterVar(Range(op->min, op->extent), op->loop_var, - IterVarType::kDataPar)); - } - analyzer.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); - StmtExprVisitor::VisitStmt_(op); - } - void VisitStmt_(const BufferStoreNode *op) final { - if (IsFragmentBuffer(op->buffer)) { - indice_map.Set(op->buffer, op->indices); - writes.insert(op->buffer); - } - StmtExprVisitor::VisitStmt_(op); - } - void VisitExpr_(const BufferLoadNode *op) final { - if (IsFragmentBuffer(op->buffer)) { - indice_map.Set(op->buffer, op->indices); - } - StmtExprVisitor::VisitExpr_(op); - } - }; - - auto ComputeLoopLayoutFromBuffer = - [&](const Buffer &buf, const Array &indices, - const LayoutMap &layout_map, const Range &thread_bounds, - const 
Array &loop_vars) -> Fragment { - Fragment src = layout_map[buf].as().value(); - Var rep; - auto rep_iter = - IterVar(Range(0, src->ReplicateExtent()), rep, IterVarType::kDataPar); - PrimExpr fth = src->ForwardThread(indices, rep); - fth = analyzer->Simplify(fth); - Fragment out = Fragment(loop_vars, /*forward_index=*/{}, fth, rep_iter) - ->BindThreadRange(thread_bounds); - return out; - }; - - struct AtomicInferResult { - Fragment loop_layout; - Optional predicate; - }; - - auto AtomicAddInferLayout = - [&](const For &loop, const LayoutInferArgs &args) -> AtomicInferResult { - AtomicLoopNestCollector C; - C.Run(loop); - Optional read_src; - int best_rank = -1; - for (auto kv : C.indice_map) { - const Buffer &buf = kv.first; - if (!IsFragmentBuffer(buf)) - continue; - if (!args.layout_map.count(buf)) - continue; - int rank = static_cast(kv.second.size()); - if (rank > best_rank) { - best_rank = rank; - read_src = buf; - } - } - AtomicAddVectorizePlanner planner; - int sm = GetArchInt(target); - auto plan = planner.Plan(loop, sm); - int vec = std::max(plan.vector_size, 1); - if (auto cw = loop->annotations.Get(attr::kCoalescedWidth)) { - if (const auto *imm = cw->as()) { - int expected = imm->value; - ICHECK_GT(expected, 0); - ICHECK(vec % expected == 0) - << "vector_size " << vec << " not divisible by coalesced_width " - << expected; - vec = expected; - } else { - LOG(FATAL) << "coalesced_width should be IntImmNode."; - } - } - PrimExpr total = 1; - for (Stmt s = loop; s.as().has_value(); s = s.as().value()->body) - total = total * s.as().value()->extent; - PrimExpr denom = args.thread_bounds->extent * vec; - while (!analyzer->CanProve(floormod(total, denom) == 0) && vec > 1) { - vec >>= 1; - denom = args.thread_bounds->extent * vec; - } - if (vec < 1) - vec = 1; - Fragment loop_layout; - if (read_src) { - loop_layout = ComputeLoopLayoutFromBuffer( - read_src.value(), C.indice_map[read_src.value()], args.layout_map, - args.thread_bounds, C.loop_vars); - } else { - const For &remapped = loop; - loop_layout = PlanLoopPartition(remapped, vec, args.thread_bounds); - } - - Optional pred; - if (plan.dynamic && plan.condition.defined()) { - pred = plan.condition; - } - DLOG(INFO) << "[AtomicAddInferLayout] vec=" << vec - << " loop_layout=" << loop_layout->DebugOutput(); - return {loop_layout, pred}; - }; - - auto ret = AtomicAddInferLayout(transformed_loop, - {T.target, T.thread_bounds, T.layout_map, - analyzer, false, T.buffer_remap}); - Fragment loop_layout = ret.loop_layout; - auto thread_loop = - PartitionLoop(transformed_loop, T.thread_var, analyzer, loop_layout); - auto vectorized_thread_loop = - VectorizeAtomicAdd(thread_loop, GetArchInt(target)); - return vectorized_thread_loop; + auto par_op = ParallelOp(fused_loop); + std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, + InferLevel::kFree}; + // 1.give par_op a recommended vectorize size. (only works for free layout inference). 
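+  // Layout inference is run at each level (kCommon, kStrict, kFree); the
+  // inferred loop layout and per-thread predicate are then consumed by
+  // LowerParallelLoop below.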
+ for (auto level : levels) { + par_op->InferLayout({T.target, + T.thread_bounds, + T.layout_map, + analyzer, + false, + T.buffer_remap, + {}}, + level); + } + auto loop_layout = par_op->GetLoopLayout(); + auto lowered_loop = LowerParallelLoop(fused_loop, loop_layout, T.thread_var, + analyzer, par_op->GetPredicate(T.thread_var)); + return lowered_loop; } TIR_REGISTER_TL_TILE_OP(AtomicAdd, atomicadd) diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index f13e827a5..1fd1b8c8d 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -6,34 +6,31 @@ #ifndef TVM_TL_OP_ATOMIC_ADD_H_ #define TVM_TL_OP_ATOMIC_ADD_H_ -#include "operator.h" -#include "parallel.h" +#include "atomic_reduce.h" namespace tvm { namespace tl { using namespace tir; -/// Node class for atomic addition operations -class AtomicAddNode : public TileOperatorNode { +/*! + * \brief Node class for atomic addition operations. + * + * Inherits from AtomicOpBaseNode and adds TMA support and vectorization. + */ +class AtomicAddNode : public AtomicOpBaseNode { public: - Buffer src, dst; ///< Source and destination buffers - Array src_range, - dst_range; ///< Access ranges for source and destination - Map annotations; ///< Annotations for the atomic operation - // Supported annotation keys: - // - "use_tma": IntImm, whether to use TMA for memory operations - // - "coalesced_width": IntImm, width for memory coalescing optimization - // - "memory_order": IntImm, memory order for atomic operations - - mutable ParallelOp par_op_; ///< Associated parallel operation TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicAdd", AtomicAddNode, TileOperatorNode); + /// Override Lower to add TMA support Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; + + /// Override InferLayout to add TMA layout inference LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; static const Op &Get(); + static const Op &GetElemOpStatic(); TileOperator Clone() const; static void RegisterReflection() { @@ -46,7 +43,7 @@ class AtomicAddNode : public TileOperatorNode { .def_ro("annotations", &AtomicAddNode::annotations); } - // Helper methods to get annotation values + /// Check if TMA should be used bool GetUseTMA() const { if (auto val = annotations.Get("use_tma")) { if (auto int_val = val->as()) { @@ -56,29 +53,21 @@ class AtomicAddNode : public TileOperatorNode { return false; } - int GetMemoryOrder() const { - if (auto val = annotations.Get("memory_order")) { - if (auto int_val = val->as()) { - return int_val->value; - } - } - return 0; // default: relaxed - } + /// Get vectorization length based on dst dtype and target SM version + int GetVectorizeLength(Target target) const; protected: - /// Create SIMT-style parallel loop structure + /// Override MakeSIMTLoop to handle AtomicAdd-specific logic For MakeSIMTLoop(arith::Analyzer *analyzer) const; - /// Generate iteration variables for loop nest - Array MakeIterVars() const; - /// Generate buffer indices from iteration variables - Array MakeIndices(const Array &ivs, int src_dst) const; - /// Return buffer indices and size + + /// Return buffer indices and total size std::pair, PrimExpr> ReturnIndicesAndSize(int src_dst) const; - /// Create boundary predicate for memory safety - PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array &ivs, - Array extents, int src_dst) const; + /// Compute linear layout for shared tensor (used in TMA atomic add) Layout ComputeLinearLayout(const Buffer &shared_tensor) const; + + /// Lower TMA-based atomic add + Stmt LowerTMA(const LowerArgs &T, 
arith::Analyzer *analyzer) const; }; /// Wrapper class for atomic addition operations @@ -92,7 +81,7 @@ class AtomicAdd : public TileOperator { static const Op &Get(); }; -} // namespace tl -} // namespace tvm +} // namespace tl +} // namespace tvm -#endif // TVM_TL_OP_ATOMIC_ADD_H_ +#endif // TVM_TL_OP_ATOMIC_ADD_H_ diff --git a/src/op/atomic_reduce.cc b/src/op/atomic_reduce.cc new file mode 100644 index 000000000..e585aa03d --- /dev/null +++ b/src/op/atomic_reduce.cc @@ -0,0 +1,287 @@ +/*! + * \file tl/op/atomic_reduce.cc + * + * Define atomic reduction operators (max/min). + */ + +#include "./atomic_reduce.h" +#include "./atomic_add.h" +#include "utils.h" +#include +#include +#include + +#include "../layout/layout.h" +#include "../target/utils.h" +#include "../transform/atomicreduce_lower.h" +#include "../transform/common/loop_fusion_utils.h" +#include "../transform/loop_partition.h" +#include "builtin.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +// ============================================================================ +// AtomicMax Implementation +// ============================================================================ + +AtomicMax::AtomicMax(Array args, Map annotations) { + ObjectPtr node = tvm::ffi::make_object(); + Array rgs[2]; + Buffer bf[2]; + for (int i = 0; i < 2; i++) { + auto region = NormalizeToBufferRegion(args[i]); + rgs[i] = region->region; + bf[i] = region->buffer; + } + std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); + std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + node->annotations = annotations; + data_ = std::move(node); +} + +TileOperator AtomicMaxNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + if (par_op_.defined()) { + op->par_op_ = Downcast(par_op_->Clone()); + } + return AtomicMax(op); +} + +const Op &AtomicMaxNode::GetElemOpStatic() { return atomic_max_elem_op(); } + +// ============================================================================ +// AtomicMin Implementation +// ============================================================================ + +AtomicMin::AtomicMin(Array args, Map annotations) { + ObjectPtr node = tvm::ffi::make_object(); + Array rgs[2]; + Buffer bf[2]; + for (int i = 0; i < 2; i++) { + auto region = NormalizeToBufferRegion(args[i]); + rgs[i] = region->region; + bf[i] = region->buffer; + } + std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); + std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); + node->annotations = annotations; + data_ = std::move(node); +} + +TileOperator AtomicMinNode::Clone() const { + auto op = tvm::ffi::make_object(*this); + if (par_op_.defined()) { + op->par_op_ = Downcast(par_op_->Clone()); + } + return AtomicMin(op); +} + +const Op &AtomicMinNode::GetElemOpStatic() { return atomic_min_elem_op(); } + +// ============================================================================ +// Common AtomicOpBaseNode Implementation +// ============================================================================ + +template +Array AtomicOpBaseNode::MakeIterVars() const { + Array loop_vars; + size_t idx = 0; + for (size_t i = 0; i < src_range.size(); i++) { + if (is_one(src_range[i]->extent)) + continue; + Var var = Var(std::string{char('i' + idx)}, src_range[i]->extent->dtype); + idx++; + loop_vars.push_back( + {Range(0, src_range[i]->extent), var, IterVarType::kDataPar}); + } + return loop_vars; +} + +template +Array +AtomicOpBaseNode::MakeIndices(const Array &ivs, + int src_dst) const { + Array indices; + 
Array ranges = src_dst == 0 ? src_range : dst_range; + size_t idx = 0; + for (size_t i = 0; i < ranges.size(); i++) { + if (is_one(ranges[i]->extent)) + indices.push_back(ranges[i]->min); + else { + indices.push_back(ranges[i]->min + ivs[idx]->var); + idx++; + } + } + ICHECK(idx == ivs.size()) + << "idx = " << idx << ", ivs.size() = " << ivs.size() + << "src name = " << src->name << ", dst name = " << dst->name; + return indices; +} + +template +PrimExpr AtomicOpBaseNode::MakePredicate(arith::Analyzer *analyzer, + const Array &ivs, + Array extents, + int src_dst) const { + Array ranges = src_dst == 0 ? src_range : dst_range; + Array cond_list; + ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; + size_t idx = 0; + for (size_t i = 0; i < ranges.size(); i++) { + if (is_one(ranges[i]->extent)) + continue; + PrimExpr cond = ranges[i]->min + ivs[idx]->var < extents[i]; + if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) { + cond_list.push_back(cond); + } + cond = ranges[i]->min + ivs[idx]->var >= 0; + if (!analyzer->CanProve(cond, arith::ProofStrength::kSymbolicBound)) { + cond_list.push_back(cond); + } + idx++; + } + if (cond_list.empty()) + return {}; + else { + PrimExpr cond = cond_list[0]; + for (size_t i = 1; i < cond_list.size(); i++) + cond = And(cond, cond_list[i]); + return cond; + } +} + +template +For AtomicOpBaseNode::MakeSIMTLoop( + arith::Analyzer *analyzer) const { + Array loop_vars = MakeIterVars(); + bool is_scalar = loop_vars.empty(); + if (is_scalar) { + return For(Var("i"), 0, 1, ForKind::kSerial, + BufferStore(dst, BufferLoad(src, {0}), {0})); + } + + for (const auto &iv : loop_vars) + analyzer->Bind(iv->var, iv->dom); + + ICHECK(loop_vars.size() <= src_range.size()) + << "loop_vars.size() = " << loop_vars.size() + << ", src_range.size() = " << src_range.size() << ", src = " << src->name + << ", dst = " << dst->name; + + ICHECK(loop_vars.size() <= dst_range.size()) + << "loop_vars.size() = " << loop_vars.size() + << ", dst_range.size() = " << dst_range.size() << ", src = " << src->name + << ", dst = " << dst->name; + + Array src_indices = MakeIndices(loop_vars, 0); + Array dst_indices = MakeIndices(loop_vars, 1); + + Array new_args; + + // Load source value and cast to dst dtype if needed + PrimExpr src_value = BufferLoad(src, src_indices); + if (src->dtype != dst->dtype) + src_value = Cast(dst->dtype, src_value); + + // Build a pointer to destination element using tvm_access_ptr + PrimExpr dst_ptr = Call(DataType::Handle(), builtin::address_of(), + {BufferLoad(dst, dst_indices)}); + + new_args.push_back(dst_ptr); + new_args.push_back(src_value); + new_args.push_back(static_cast(this)->GetMemoryOrder()); + + // Use the appropriate elem_op based on the derived type (via CRTP) + Call atomic_call = tvm::tir::Call(dst->dtype, GetElemOp(), new_args, annotations); + + Stmt body = tvm::tir::Evaluate(atomic_call); + + for (int i = loop_vars.size() - 1; i >= 0; i--) { + Map loop_annotations; + if (i == 0) { + if (annotations.count(attr::kCoalescedWidth)) { + loop_annotations.Set(attr::kCoalescedWidth, + annotations.Get(attr::kCoalescedWidth).value()); + } + } + + body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, + ForKind::kParallel, body, std::nullopt, loop_annotations); + } + return Downcast(body); +} + +template +LayoutMap AtomicOpBaseNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + // For atomic reduce operations, check that src and dst have the same layout + // if both are fragments + if 
(IsFragmentBuffer(src) && IsFragmentBuffer(dst)) { + if (T.layout_map.count(src) && T.layout_map.count(dst)) { + Layout src_layout = T.layout_map.at(src); + Layout dst_layout = T.layout_map.at(dst); + ICHECK(StructuralEqual()(src_layout, dst_layout)) + << "Atomic reduce requires src and dst to have the same layout, but " + "got " + << "src layout: " << src_layout << ", dst layout: " << dst_layout + << " for src buffer: " << src->name << ", dst buffer: " << dst->name; + } + } + return {}; +} + +template +Stmt AtomicOpBaseNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { + Target target = T.target; + + auto simt_loop = MakeSIMTLoop(analyzer); + auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); + auto par_op = ParallelOp(fused_loop); + std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, + InferLevel::kFree}; + for (auto level : levels) { + par_op->InferLayout({T.target, + T.thread_bounds, + T.layout_map, + analyzer, + false, + T.buffer_remap, + {}}, + level); + } + auto loop_layout = par_op->GetLoopLayout(); + auto lowered_loop = LowerParallelLoop(fused_loop, loop_layout, T.thread_var, + analyzer, par_op->GetPredicate(T.thread_var)); + return lowered_loop; +} + +// Explicit template instantiations +template class AtomicOpBaseNode; +template class AtomicOpBaseNode; +template class AtomicOpBaseNode; + +// ============================================================================ +// Operator Registration +// ============================================================================ + +TIR_REGISTER_TL_TILE_OP(AtomicMax, atomicmax) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_REGISTER_TL_TILE_OP(AtomicMin, atomicmin) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TVM_FFI_STATIC_INIT_BLOCK() { + AtomicMaxNode::RegisterReflection(); + AtomicMinNode::RegisterReflection(); +} + +} // namespace tl +} // namespace tvm diff --git a/src/op/atomic_reduce.h b/src/op/atomic_reduce.h new file mode 100644 index 000000000..f4c3e0eed --- /dev/null +++ b/src/op/atomic_reduce.h @@ -0,0 +1,145 @@ +/*! + * \file tl/op/atomic_reduce.h + * \brief Atomic operations base class and reduction operations (max/min) + */ + +#ifndef TVM_TL_OP_ATOMIC_REDUCE_H_ +#define TVM_TL_OP_ATOMIC_REDUCE_H_ + +#include "operator.h" +#include "parallel.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/*! + * \brief Base node class for atomic operations (add/max/min). + * + * This template base class provides common functionality for all atomic + * operations including buffer management, loop generation, and layout inference. 
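+ *
+ * Derived classes plug in their element-wise intrinsic through a static
+ * GetElemOpStatic() method (e.g. atomic_max_elem_op for AtomicMaxNode);
+ * MakeSIMTLoop reaches it via GetElemOp() when emitting the per-element call.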
+ * + * \tparam Derived The derived class type (CRTP pattern) + */ +template +class AtomicOpBaseNode : public TileOperatorNode { +public: + Buffer src, dst; ///< Source and destination buffers + Array src_range, dst_range; ///< Access ranges for source and destination + Map annotations; ///< Annotations for the atomic operation + // Supported annotation keys: + // - "coalesced_width": IntImm, width for memory coalescing optimization + // - "memory_order": IntImm, memory order for atomic operations + + mutable ParallelOp par_op_; ///< Associated parallel operation + + /// Default Lower implementation for non-TMA atomic ops + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; + + /// Default InferLayout implementation + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; + + /// Get memory order from annotations (default: relaxed = 0) + int GetMemoryOrder() const { + if (auto val = annotations.Get("memory_order")) { + if (auto int_val = val->as()) { + return int_val->value; + } + } + return 0; + } + +protected: + /// Create SIMT-style parallel loop structure + For MakeSIMTLoop(arith::Analyzer *analyzer) const; + + /// Generate iteration variables for loop nest + Array MakeIterVars() const; + + /// Generate buffer indices from iteration variables + Array MakeIndices(const Array &ivs, int src_dst) const; + + /// Create boundary predicate for memory safety + PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array &ivs, + Array extents, int src_dst) const; + + /// Get the element-wise operation Op (to be implemented by derived class) + /// This uses CRTP to call the derived class's static method + const Op &GetElemOp() const { + return Derived::GetElemOpStatic(); + } +}; + +// Backward compatibility alias +template +using AtomicReduceBaseNode = AtomicOpBaseNode; + +/// Node class for atomic maximum operations +class AtomicMaxNode : public AtomicOpBaseNode { +public: + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicMax", AtomicMaxNode, + TileOperatorNode); + + static const Op &Get(); + static const Op &GetElemOpStatic(); + TileOperator Clone() const; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &AtomicMaxNode::src) + .def_ro("dst", &AtomicMaxNode::dst) + .def_ro("src_range", &AtomicMaxNode::src_range) + .def_ro("dst_range", &AtomicMaxNode::dst_range) + .def_ro("annotations", &AtomicMaxNode::annotations); + } +}; + +/// Wrapper class for atomic maximum operations +class AtomicMax : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicMax, TileOperator, + AtomicMaxNode); + TVM_DLL + AtomicMax(Array args, + Map annotations = Map()); + static const Op &Get(); +}; + +/// Node class for atomic minimum operations +class AtomicMinNode : public AtomicOpBaseNode { +public: + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicMin", AtomicMinNode, + TileOperatorNode); + + static const Op &Get(); + static const Op &GetElemOpStatic(); + TileOperator Clone() const; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &AtomicMinNode::src) + .def_ro("dst", &AtomicMinNode::dst) + .def_ro("src_range", &AtomicMinNode::src_range) + .def_ro("dst_range", &AtomicMinNode::dst_range) + .def_ro("annotations", &AtomicMinNode::annotations); + } +}; + +/// Wrapper class for atomic minimum operations +class AtomicMin : public TileOperator { +public: + TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicMin, TileOperator, + 
AtomicMinNode); + TVM_DLL + AtomicMin(Array args, + Map annotations = Map()); + static const Op &Get(); +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_OP_ATOMIC_REDUCE_H_ diff --git a/src/op/builtin.cc b/src/op/builtin.cc index 9e8bf25fb..cca6b6a37 100644 --- a/src/op/builtin.cc +++ b/src/op/builtin.cc @@ -351,7 +351,52 @@ TIR_DEFINE_TL_BUILTIN(increase_descriptor_offset) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_TL_BUILTIN(atomicadd_elem_op) +TIR_DEFINE_TL_BUILTIN(atomic_add_elem_op) + .set_num_inputs(3) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(atomic_add_ret_elem_op) + .set_num_inputs(3) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(atomic_addx2_elem_op) + .set_num_inputs(3) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(atomic_addx4_elem_op) + .set_num_inputs(3) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(atomic_load_elem_op) + .set_num_inputs(2) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(atomic_store_elem_op) + .set_num_inputs(3) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(atomic_max_elem_op) + .set_num_inputs(3) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(atomic_max_ret_elem_op) + .set_num_inputs(3) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(atomic_min_elem_op) + .set_num_inputs(3) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kOpaque)); + +TIR_DEFINE_TL_BUILTIN(atomic_min_ret_elem_op) .set_num_inputs(3) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/builtin.h b/src/op/builtin.h index e35059106..1225c5ff8 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -585,7 +585,77 @@ TVM_DLL const Op &increase_descriptor_offset(); * This op is used to represent an element-wise atomic add operation in * tilelang. */ -TVM_DLL const Op &atomicadd_elem_op(); +TVM_DLL const Op &atomic_add_elem_op(); + +/*! + * \brief tilelang intrinsic for element-wise atomic addition with return value. + * + * This op is used to represent an element-wise atomic add operation in + * tilelang that returns the previous value. + */ +TVM_DLL const Op &atomic_add_ret_elem_op(); + +/*! + * \brief tilelang intrinsic for vectorized (x2) atomic addition. + * + * This op is used to represent a vectorized atomic add operation (2 elements) + * in tilelang. + */ +TVM_DLL const Op &atomic_addx2_elem_op(); + +/*! + * \brief tilelang intrinsic for vectorized (x4) atomic addition. + * + * This op is used to represent a vectorized atomic add operation (4 elements) + * in tilelang. + */ +TVM_DLL const Op &atomic_addx4_elem_op(); + +/*! + * \brief tilelang intrinsic for atomic load. + * + * This op is used to represent an atomic load operation in tilelang. + */ +TVM_DLL const Op &atomic_load_elem_op(); + +/*! + * \brief tilelang intrinsic for atomic store. + * + * This op is used to represent an atomic store operation in tilelang. + */ +TVM_DLL const Op &atomic_store_elem_op(); + +/*! + * \brief tilelang intrinsic for element-wise atomic maximum. + * + * This op is used to represent an element-wise atomic max operation in + * tilelang. + */ +TVM_DLL const Op &atomic_max_elem_op(); + +/*! 
+ * \brief tilelang intrinsic for element-wise atomic maximum with return value. + * + * This op is used to represent an element-wise atomic max operation in + * tilelang that returns the previous value. + */ +TVM_DLL const Op &atomic_max_ret_elem_op(); + +/*! + * \brief tilelang intrinsic for element-wise atomic minimum. + * + * This op is used to represent an element-wise atomic min operation in + * tilelang. + */ +TVM_DLL const Op &atomic_min_elem_op(); + +/*! + * \brief tilelang intrinsic for element-wise atomic minimum with return value. + * + * This op is used to represent an element-wise atomic min operation in + * tilelang that returns the previous value. + */ +TVM_DLL const Op &atomic_min_ret_elem_op(); /*! * \brief tilelang intrinsic for assert on device. diff --git a/src/op/parallel.cc b/src/op/parallel.cc index a35f98e15..3b3012469 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -9,6 +9,7 @@ #include #include "../layout/layout.h" +#include "arith/int_operator.h" #include "../layout/utils.h" #include "../target/utils.h" @@ -453,17 +454,11 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, bool has_cross_thread_access = false; PostOrderVisit(root_, [&](const ObjectRef &obj) { if (const auto *store = obj.as()) { - // check if scope is shared or global - if (store->buffer.scope() == "shared" || - store->buffer.scope() == "shared.dyn" || - store->buffer.scope() == "global") { + if (IsSharedBuffer(store->buffer) || IsGlobalBuffer(store->buffer)) { has_cross_thread_access = true; } } else if (const auto *load = obj.as()) { - // check if scope is shared or global - if (load->buffer.scope() == "shared" || - load->buffer.scope() == "shared.dyn" || - load->buffer.scope() == "global") { + if (IsSharedBuffer(load->buffer) || IsGlobalBuffer(load->buffer)) { has_cross_thread_access = true; } } @@ -478,8 +473,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, PostOrderVisit(root_, [&](const ObjectRef &obj) { if (const auto *store = obj.as()) { auto buffer = store->buffer; - if (buffer.scope() == "shared" || buffer.scope() == "shared.dyn" || - buffer.scope() == "global") { + if (IsSharedBuffer(buffer) || IsGlobalBuffer(buffer)) { store_shared_global_buffers.emplace_back(buffer); } else if (IsFragmentBuffer(buffer)) { store_fragment_buffers.emplace_back(buffer); diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 0dfb85341..d629cf1aa 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2888,6 +2888,95 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { os << "tl::warp_reduce_bitand(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(tl::warp_reduce_bitor())) { os << "tl::warp_reduce_bitor(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::atomic_add_elem_op())) { + // atomic_add_elem_op(dst_ptr, src_value[, memory_order]) + std::string dst_ptr = PrintExpr(op->args[0]); + std::string src_value = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << "AtomicAdd(" << dst_ptr << ", " << src_value; + if (op->args.size() > 2) { + this->stream << ", " << PrintExpr(op->args[2]); + } + this->stream << ");\n"; + } else if (op->op.same_as(tl::atomic_add_ret_elem_op())) { + // atomic_add_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns prev + // value + os << "AtomicAddRet(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]); + if (op->args.size() > 2) { + os << ", " << PrintExpr(op->args[2]); + } + os << ")"; + } else if 
(op->op.same_as(tl::atomic_addx2_elem_op())) { + // atomic_addx2_elem_op(dst_ptr, src_ptr[, memory_order]) + std::string dst_ptr = PrintExpr(op->args[0]); + std::string src_ptr = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << "AtomicAddx2(" << dst_ptr << ", " << src_ptr; + if (op->args.size() > 2) { + this->stream << ", " << PrintExpr(op->args[2]); + } + this->stream << ");\n"; + } else if (op->op.same_as(tl::atomic_addx4_elem_op())) { + // atomic_addx4_elem_op(dst_ptr, src_ptr[, memory_order]) + std::string dst_ptr = PrintExpr(op->args[0]); + std::string src_ptr = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << "AtomicAddx4(" << dst_ptr << ", " << src_ptr; + if (op->args.size() > 2) { + this->stream << ", " << PrintExpr(op->args[2]); + } + this->stream << ");\n"; + } else if (op->op.same_as(tl::atomic_load_elem_op())) { + // atomic_load_elem_op(src_ptr, memory_order) -> returns loaded value + os << "AtomicLoad(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::atomic_store_elem_op())) { + // atomic_store_elem_op(dst_ptr, value, memory_order) + std::string dst_ptr = PrintExpr(op->args[0]); + std::string value = PrintExpr(op->args[1]); + std::string memory_order = PrintExpr(op->args[2]); + this->PrintIndent(); + this->stream << "AtomicStore(" << dst_ptr << ", " << value << ", " + << memory_order << ");\n"; + } else if (op->op.same_as(tl::atomic_max_elem_op())) { + // atomic_max_elem_op(dst_ptr, src_value[, memory_order]) + std::string dst_ptr = PrintExpr(op->args[0]); + std::string src_value = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << "AtomicMax(" << dst_ptr << ", " << src_value; + if (op->args.size() > 2) { + this->stream << ", " << PrintExpr(op->args[2]); + } + this->stream << ");\n"; + } else if (op->op.same_as(tl::atomic_max_ret_elem_op())) { + // atomic_max_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns prev + // value + os << "AtomicMaxRet(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]); + if (op->args.size() > 2) { + os << ", " << PrintExpr(op->args[2]); + } + os << ")"; + } else if (op->op.same_as(tl::atomic_min_elem_op())) { + // atomic_min_elem_op(dst_ptr, src_value[, memory_order]) + std::string dst_ptr = PrintExpr(op->args[0]); + std::string src_value = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << "AtomicMin(" << dst_ptr << ", " << src_value; + if (op->args.size() > 2) { + this->stream << ", " << PrintExpr(op->args[2]); + } + this->stream << ");\n"; + } else if (op->op.same_as(tl::atomic_min_ret_elem_op())) { + // atomic_min_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns prev + // value + os << "AtomicMinRet(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]); + if (op->args.size() > 2) { + os << ", " << PrintExpr(op->args[2]); + } + os << ")"; } else { CodeGenC::VisitExpr_(op, os); } diff --git a/src/target/utils.cc b/src/target/utils.cc index 993590ffb..cdbcb8e45 100644 --- a/src/target/utils.cc +++ b/src/target/utils.cc @@ -127,6 +127,13 @@ bool TargetHasBulkCopy(Target target) { return arch >= 90; } +bool TargetHasSMVersionGE(Target target, int version) { + if (!TargetIsCuda(target)) + return false; + int arch = GetArchInt(target); + return arch >= version; +} + int TargetGetWarpSize(Target target) { int res = 32; if (TargetIsCDNA(target)) diff --git a/src/target/utils.h b/src/target/utils.h index 9de2d4d4f..9db147818 100644 --- a/src/target/utils.h +++ b/src/target/utils.h @@ -29,6 +29,7 @@ 
bool TargetHasStmatrix(Target target); bool TargetHasTmem(Target target); bool TargetHasBulkCopy(Target target); int TargetGetWarpSize(Target target); +bool TargetHasSMVersionGE(Target target, int version); bool IsCudaVectorizableFP8(DataType dtype); bool IsCudaVectorizableCast(DataType from_ty, DataType target_ty); diff --git a/src/transform/atomicadd_vectorize.cc b/src/transform/atomicadd_vectorize.cc index d66a538db..d09d48618 100644 --- a/src/transform/atomicadd_vectorize.cc +++ b/src/transform/atomicadd_vectorize.cc @@ -1,6 +1,9 @@ /*! * \file atomicadd_vectorize.cc - * \brief A tool to automatically vectorize atomic add + * \brief Automatic vectorization pass for atomic add operations. + * + * This pass detects atomic_add_elem_op inside vectorized loops and converts + * them to vectorized versions (atomic_addx2_elem_op or atomic_addx4_elem_op). */ #include "atomicadd_vectorize.h" @@ -9,300 +12,145 @@ namespace tvm { namespace tl { using namespace tir; -using arith::IRMutatorWithAnalyzer; -using arith::IRVisitorWithAnalyzer; -AtomicAddVectorizePlanner::AtomicAddVectorizePlanner() = default; +namespace { -AtomicAddVectorizePlanResult -AtomicAddVectorizePlanner::Plan(const For &node, int compute_capability) { - int vectorize_size_max = 1; - this->vector_size_ = 4; - this->dynamic_ = false; - this->condition_ = PrimExpr(); - - PostOrderVisit(node, [&](const ObjectRef &obj) { - if (const auto *call = obj.as()) { - if (call->op == atomicadd_elem_op()) { - if (call->args.size() < 2) { - // Fallback: unexpected arity - vectorize_size_max = 1; - DLOG(WARNING) << "[AtomicAddVectorizePlanner] atomicadd_elem_op " - "expects 2 args, got " - << call->args.size() << "; Fallback to no vectorize"; - return; - } - DataType dtype; - if (const auto *load = call->args[0].as()) { - dtype = load->dtype; - vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype); - } else if (const auto *ite = call->args[0].as()) { - if (const auto *then_load = ite->then_case.as()) { - dtype = then_load->dtype; - vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype); - } else if (const auto *else_load = - ite->else_case.as()) { - dtype = else_load->dtype; - vectorize_size_max = GetVectorizeSizeMax(compute_capability, dtype); - } else { - // fallback - vectorize_size_max = 1; - DLOG(WARNING) << "[AtomicAddVectorizePlanner] IfThenElse case " - "has no BufferLoad; Fallback to no vectorize"; - } - } else { - // fallback - vectorize_size_max = 1; - DLOG(WARNING) << "[AtomicAddVectorizePlanner] Unexpected arg1 type " - << call->args[1]->GetTypeKey() - << "; Fallback to no vectorize"; - } +/*! + * \brief Extract BufferLoad from an expression that may be wrapped in address_of. + */ +Optional ExtractBufferLoad(const PrimExpr &expr) { + if (const auto *load = expr.as()) { + return tvm::ffi::GetRef(load); + } + if (const auto *call = expr.as()) { + if (call->op.same_as(builtin::address_of()) && !call->args.empty()) { + if (const auto *load = call->args[0].as()) { + return tvm::ffi::GetRef(load); } } - }); - - if (vectorize_size_max <= 1) { - return {1, dynamic_, condition_}; } - - this->max_vector_size = vectorize_size_max; - this->operator()(node); - return {vector_size_, dynamic_, condition_}; + return Optional(); } -void AtomicAddVectorizePlanner::VisitStmt_(const ForNode *node) { - inner_for_ = node; - arith::IRVisitorWithAnalyzer::VisitStmt_(node); +/*! + * \brief Get the vectorized atomic add op based on vector size. 
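+ *
+ * Maps a width of 4 to atomic_addx4_elem_op and 2 to atomic_addx2_elem_op,
+ * falling back to the scalar atomic_add_elem_op for any other width.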
+ */ +Op GetVectorizedAtomicOp(int vector_size) { + switch (vector_size) { + case 4: + return atomic_addx4_elem_op(); + case 2: + return atomic_addx2_elem_op(); + default: + return atomic_add_elem_op(); + } } -void AtomicAddVectorizePlanner::VisitExpr_(const CallNode *node) { - if (node->op == atomicadd_elem_op() && !node->args.empty()) { - if (node->args.size() < 2) { - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); - } - const BufferLoadNode *buffer_load_dst = node->args[0].as(); - const BufferLoadNode *buffer_load_src = node->args[1].as(); - if (buffer_load_src && buffer_load_src->buffer.defined() && - buffer_load_dst && buffer_load_dst->buffer.defined()) { - Buffer dst_buffer = buffer_load_dst->buffer; - UpdateVectorSize(buffer_load_dst->indices, dst_buffer); +/*! + * \brief Rewriter that transforms atomic_add_elem_op inside vectorized loops. + * + * Strategy: Detect ForKind::kVectorized loops, use their extent as vector size, + * and convert atomic_add_elem_op to the corresponding vectorized version. + */ +class AtomicAddVectorizeRewriter : public StmtExprMutator { +public: + explicit AtomicAddVectorizeRewriter(Target target) : target_(target) {} - Buffer src_buffer = buffer_load_src->buffer; - UpdateVectorSize(buffer_load_src->indices, src_buffer); +private: + /*! + * \brief Get the max vector size supported by the given dtype. + */ + int GetMaxVectorSize(DataType dtype) const { + if (dtype.is_float16() || dtype.is_bfloat16()) { + return 2; } + if (dtype.is_float() && dtype.bits() == 32 && TargetHasSMVersionGE(target_, 90)) { + return 4; + } + return 1; } - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); -} - -int AtomicAddVectorizePlanner::GetVectorizeSizeMax(int compute_capability, - DataType dtype) { - if (dtype == DataType::Float(16)) { - return 2; - } - if (dtype == DataType::BFloat(16)) { - return compute_capability > 75 ? 2 : 1; - } - if (dtype == DataType::Float(32)) { - return compute_capability >= 90 ? 
4 : 1; - } - return 1; -} - -void AtomicAddVectorizePlanner::UpdateVectorSize(const Array &indices, - const Buffer &buffer) { - if (!inner_for_) - return; - auto extent_ptr = inner_for_->extent.as(); - if (!extent_ptr) - return; - const DataType &access_type = buffer->dtype; - max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value); - - auto last_dim = buffer->shape.back(); - auto mod_set = analyzer_.modular_set(last_dim); + Stmt VisitStmt_(const ForNode *node) final { + // Check if this is a vectorized loop + if (node->kind == ForKind::kVectorized) { + auto extent_ptr = as_const_int(node->extent); + if (extent_ptr) { + int vec_size = static_cast(*extent_ptr); + // Push vectorized context + vectorized_loop_ = node; + vector_size_ = vec_size; + + Stmt body = VisitStmt(node->body); + + // If we successfully vectorized atomic ops, transform the loop + if (has_vectorized_atomic_) { + has_vectorized_atomic_ = false; + vectorized_loop_ = nullptr; + vector_size_ = 1; + + // Change loop extent to 1 since atomic op now handles all elements + return For(node->loop_var, node->min, Integer(1), node->kind, body, + node->thread_binding, node->annotations, node->step, node->span); + } - if (buffer->shape.back().as()) { - max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff); - auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base); + vectorized_loop_ = nullptr; + vector_size_ = 1; - if (gcd_base < Downcast(last_dim)->value) { - max_vector_size = gcd_base; + if (body.same_as(node->body)) { + return tvm::ffi::GetRef(node); + } + return For(node->loop_var, node->min, node->extent, node->kind, body, + node->thread_binding, node->annotations, node->step, node->span); + } } + return StmtExprMutator::VisitStmt_(node); + } - vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_); - - PrimExpr elem_offset = 0; - PrimExpr stride = 1; - for (int i = indices.size() - 1; i >= 0; --i) { - elem_offset = elem_offset + indices[i] * stride; - stride = stride * buffer->shape[i]; + PrimExpr VisitExpr_(const CallNode *node) final { + if (node->op != atomic_add_elem_op() || node->args.size() < 2) { + return StmtExprMutator::VisitExpr_(node); } - while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, - inner_for_->extent, vector_size_, &analyzer_)) { - vector_size_ /= 2; + // Must be inside a vectorized loop + if (!vectorized_loop_ || vector_size_ <= 1) { + return StmtExprMutator::VisitExpr_(node); } - } else if (vector_size_ <= 4) { - dynamic_ = true; - PrimExpr offset = buffer.OffsetOf(indices).back(); - condition_ = (truncmod(offset, vector_size_) == 0); - } -} -class AtomicAddVectorizeRewriter : public StmtExprMutator { -public: - AtomicAddVectorizeRewriter(const AtomicAddVectorizePlanResult &plan) - : vector_size_(plan.vector_size), dynamic_(plan.dynamic), - condition_(plan.condition) {} + auto dst_load = ExtractBufferLoad(node->args[0]); + auto src_load = ExtractBufferLoad(node->args[1]); -private: - /** - * @brief Visits a For node and rewrites the innermost loop for atomic-add - * vectorization. - * - * If the visited For node is the recorded innermost loop, this method - * validates that the loop extent is a constant, divisible by the planned - * vector size, and has a zero minimum. 
When vectorization is enabled - * (dynamic_ == false) it: - * - locates the thread index variable named "tx" inside the loop body, - * - creates a new outer loop variable named "_outer", - * - substitutes occurrences of `tx` with `tx * vector_size_` and the old - * loop var with `outer_var * vector_size_` so each outer iteration maps to a - * contiguous vector-sized chunk, - * - returns a new For with extent divided by vector_size_ and the - * transformed body. - * - * If dynamic_ is true, the method returns the (possibly mutated) inner For - * unchanged. - * - * Side effects: - * - updates inner_for_ to point to the current For node during visitation. - * - performs runtime checks (ICHECK) to enforce: constant extent, extent % - * vector_size_ == 0, and zero loop minimum; violations terminate execution. - * - * @return The original or transformed For statement as a Stmt. - */ - Stmt VisitStmt_(const ForNode *node) final { - inner_for_ = node; - auto ret = StmtExprMutator::VisitStmt_(node); - if (vector_size_ == 1) - return ret; - if (inner_for_ == node) { - For fnode = ret.as().value(); - auto old_var = fnode->loop_var; - auto new_var = Var(old_var->name_hint); - auto extent_ptr = as_const_int(fnode->extent); - ICHECK(extent_ptr) << fnode->extent; - int extent = *extent_ptr; - ICHECK(extent % vector_size_ == 0) - << "extent: " << extent << " vector_size_: " << vector_size_; - ICHECK(is_zero(fnode->min)); - if (!dynamic_) { - Map vmap; - vmap.Set(old_var, new_var * vector_size_); - Stmt body = Substitute(fnode->body, vmap); - return For(new_var, 0, extent / vector_size_, fnode->kind, body, - fnode->thread_binding, fnode->annotations, fnode->step, - fnode->span); - } + if (!dst_load.defined() || !src_load.defined()) { + return StmtExprMutator::VisitExpr_(node); } - return ret; - } - PrimExpr VisitExpr_(const CallNode *node) final { - bool legal_vectorize = true; - if (dynamic_) - legal_vectorize = false; - if (!(node->op == atomicadd_elem_op())) - legal_vectorize = false; - if (node->args.size() < 2) - legal_vectorize = false; - if (legal_vectorize) { - const BufferLoadNode *temp_dst_node = node->args[0].as(); - const BufferLoadNode *temp_value_node = - node->args[1].as(); - if (!temp_dst_node || !temp_value_node) - legal_vectorize = false; + // Check if dtype supports this vector size + DataType dtype = dst_load.value()->buffer->dtype; + if (vector_size_ > GetMaxVectorSize(dtype)) { + return StmtExprMutator::VisitExpr_(node); } - if (legal_vectorize) { - const BufferLoad dst_node = Downcast(node->args[0]); - const BufferLoad value_node = Downcast(node->args[1]); - // The default memory order is relaxed - // Ref: src/tl_templates/cuda/atomic.h::AtomicAdd - const IntImm memory_order = - node->args.size() >= 3 ? Downcast(node->args[2]) : IntImm(0); - Array new_args; - Call address_of_dst = - Call(DataType::Handle(), builtin::address_of(), {dst_node}); - Call address_of_value = - Call(DataType::Handle(), builtin::address_of(), {value_node}); - if (vector_size_ == 4) { - new_args.push_back(StringImm("AtomicAddx4")); - new_args.push_back(address_of_dst); - new_args.push_back(address_of_value); - } else if (vector_size_ == 2) { - new_args.push_back(StringImm("AtomicAddx2")); - new_args.push_back(address_of_dst); - new_args.push_back(address_of_value); - } else { - // Scalar case: AtomicAdd now expects a pointer to destination. 
- new_args.push_back(StringImm("AtomicAdd")); - new_args.push_back(address_of_dst); - new_args.push_back(value_node); - } - new_args.push_back(memory_order); - Call new_call = - tvm::tir::Call(node->dtype, builtin::call_extern(), new_args); + // Mark that we have vectorized an atomic op + has_vectorized_atomic_ = true; - return new_call; - } else { - Array new_args; - new_args.push_back(StringImm("AtomicAdd")); - // Ensure first argument is an address; keep value as-is. - if (!node->args.empty()) { - if (const auto *bl = node->args[0].as()) { - Call address_of_dst = Call(DataType::Handle(), builtin::address_of(), - {Downcast(node->args[0])}); - new_args.push_back(address_of_dst); - } else if (const auto *call = node->args[0].as()) { - // If it's already an address_of, forward it; otherwise, keep - // original. - if (call->op.same_as(builtin::address_of())) { - new_args.push_back(node->args[0]); - } else { - new_args.push_back(node->args[0]); - } - } else { - new_args.push_back(node->args[0]); - } - // Push remaining args unchanged (value, optional memory_order, ...) - for (size_t i = 1; i < node->args.size(); ++i) { - new_args.push_back(node->args[i]); - } - } + // Create vectorized atomic op + Call addr_dst(DataType::Handle(), builtin::address_of(), {dst_load.value()}); + Call addr_src(DataType::Handle(), builtin::address_of(), {src_load.value()}); - Call new_call = - tvm::tir::Call(node->dtype, builtin::call_extern(), new_args); - - return new_call; - } + return Call(node->dtype, GetVectorizedAtomicOp(vector_size_), {addr_dst, addr_src}); } - const ForNode *inner_for_; - const int vector_size_; - const PrimExpr condition_; - const bool dynamic_; + Target target_; + const ForNode *vectorized_loop_ = nullptr; + int vector_size_ = 1; + bool has_vectorized_atomic_ = false; }; -For VectorizeAtomicAdd(const For &for_node, int compute_capability) { - AtomicAddVectorizePlanResult res = {1, false, 0}; - AtomicAddVectorizePlanner planner; - res = planner.Plan(for_node, compute_capability); - auto rewriter = AtomicAddVectorizeRewriter(res); - return Downcast(rewriter(for_node)); +} // namespace + +For VectorizeAtomicAdd(const For &for_node) { + Target target = Target::Current(false); + return Downcast(AtomicAddVectorizeRewriter(target)(for_node)); } -} // namespace tl -} // namespace tvm +} // namespace tl +} // namespace tvm diff --git a/src/transform/atomicadd_vectorize.h b/src/transform/atomicadd_vectorize.h index 6bd3309ae..905c8aaa3 100644 --- a/src/transform/atomicadd_vectorize.h +++ b/src/transform/atomicadd_vectorize.h @@ -1,60 +1,34 @@ /*! * \file atomicadd_vectorize.h - * \brief A tool to automatically vectorize a for atomicadd + * \brief Vectorization pass for atomic add operations. 
*/ #ifndef TVM_TL_ATOMICADD_VECTORIZE_H_ #define TVM_TL_ATOMICADD_VECTORIZE_H_ -#include "../layout/layout.h" -#include "../layout/utils.h" #include "../op/builtin.h" -#include "arith/int_operator.h" -#include "arith/ir_visitor_with_analyzer.h" -#include "common/loop_vectorization_utils.h" -#include -#include -#include +#include "../target/utils.h" #include -#include #include -#include namespace tvm { namespace tl { using namespace tir; -For VectorizeAtomicAdd(const For &for_node, int compute_capability); - -struct AtomicAddVectorizePlanResult { - int vector_size; - bool dynamic; - PrimExpr condition; -}; - -class AtomicAddVectorizePlanner : public arith::IRVisitorWithAnalyzer { -public: - AtomicAddVectorizePlanner(); - - AtomicAddVectorizePlanResult Plan(const For &node, int compute_capability); - -private: - void VisitStmt_(const ForNode *node) final; - void VisitExpr_(const CallNode *node) final; - - int GetVectorizeSizeMax(int compute_capability, DataType dtype); - void UpdateVectorSize(const Array &indices, const Buffer &buffer); - - const ForNode *inner_for_ = nullptr; - bool has_nonlocal_memory_access_ = false; - int vector_size_ = 4; - int max_vector_size = 1; - bool dynamic_ = false; - PrimExpr condition_; -}; +/*! + * \brief Vectorize atomic add operations inside vectorized loops. + * + * This function detects atomic_add_elem_op inside ForKind::kVectorized loops + * and converts them to vectorized versions (atomic_addx2_elem_op or + * atomic_addx4_elem_op) based on the loop extent and data type. + * + * \param for_node The For loop to process. + * \return The transformed For loop. + */ +For VectorizeAtomicAdd(const For &for_node); -} // namespace tl -} // namespace tvm +} // namespace tl +} // namespace tvm -#endif // TVM_TL_ATOMICADD_VECTORIZE_H_ +#endif // TVM_TL_ATOMICADD_VECTORIZE_H_ diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index 201a4e2b9..61c97d20c 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -29,6 +29,7 @@ #include #include "../op/utils.h" +#include "atomicadd_vectorize.h" #include "loop_vectorize.h" namespace tvm { @@ -296,7 +297,10 @@ Stmt LowerParallelLoop(For loop, const Fragment &loop_layout, Var thread_var, result_loop = VectorizeLoop(result_loop, saved_analyzer.get()); } - // Step 3: Wrap with predicate if provided and this is a parallel loop + // Step 3: Vectorize atomic add operations + result_loop = VectorizeAtomicAdd(result_loop); + + // Step 4: Wrap with predicate if provided and this is a parallel loop if (predicate.defined() && parallel_loop) { return IfThenElse(predicate.value(), result_loop); } diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 5bb7c8f80..7400d3155 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -24,6 +24,7 @@ #include "loop_vectorize.h" #include "../op/builtin.h" +#include "../op/utils.h" #include "../target/utils.h" #include "arith/int_operator.h" #include "arith/ir_visitor_with_analyzer.h" @@ -118,8 +119,7 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { } PrimExpr VisitExpr_(const BufferLoadNode *node) final { - if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" || - node->buffer.scope() == "shared.dyn") + if (IsSharedBuffer(node->buffer) || IsGlobalBuffer(node->buffer)) has_nonlocal_memory_access_ = true; if (node->buffer->shape.size() == 1) { // TODO(lei): This should be improved as @@ -134,8 +134,7 @@ class VectorizePlanner : public 
arith::IRMutatorWithAnalyzer { } Stmt VisitStmt_(const BufferStoreNode *node) final { - if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" || - node->buffer.scope() == "shared.dyn") + if (IsSharedBuffer(node->buffer) || IsGlobalBuffer(node->buffer)) has_nonlocal_memory_access_ = true; UpdateVectorSize(node->indices, node->buffer, true); return arith::IRMutatorWithAnalyzer::VisitStmt_(node); @@ -149,12 +148,36 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CallNode *node) final { if (node->op == builtin::if_then_else()) { CheckConditionVectorized(node->args[0]); - } else if (node->op == builtin::call_extern()) { - // do not vectorize extern calls - vector_size_ = 1; - } else if (node->op.same_as(tl::rng_init())) { - // do not vectorize random operation + } else if (node->op == tl::atomic_add_elem_op()) { + // Assert at least 2 args (dst_ptr and src) + ICHECK(node->args.size() >= 2) + << "atomic_add_elem_op requires at least 2 args (dst and src)"; + + // Get dst dtype from args[0] (address_of call containing BufferLoad) + auto address_of_call = node->args[0].as(); + ICHECK(address_of_call && address_of_call->op == builtin::address_of()) + << "atomic_add_elem_op first arg must be address_of call"; + + auto buffer_load = address_of_call->args[0].as(); + ICHECK(buffer_load) << "address_of arg must be BufferLoad"; + + DataType dtype = buffer_load->buffer->dtype; + int vectorize_length = 1; + if (dtype.is_float16() || dtype.is_bfloat16()) { + vectorize_length = 2; + } else if (dtype.is_float() && dtype.bits() == 32 && + TargetHasSMVersionGE(Target::Current(false), 90)) { + vectorize_length = 4; + } + + vector_size_ = arith::ZeroAwareGCD(vector_size_, vectorize_length); + // Do not visit the args of atomic_add_elem_op, because pointer type + // is impossible to vectorize + return Downcast(node); + } else { + // Other calls should not be vectorized vector_size_ = 1; + return Downcast(node); } return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } diff --git a/testing/python/language/test_tilelang_language_atomic_add.py b/testing/python/language/test_tilelang_language_atomic.py similarity index 83% rename from testing/python/language/test_tilelang_language_atomic_add.py rename to testing/python/language/test_tilelang_language_atomic.py index 8b3253b95..fc8d5278e 100644 --- a/testing/python/language/test_tilelang_language_atomic_add.py +++ b/testing/python/language/test_tilelang_language_atomic.py @@ -4,7 +4,12 @@ import torch -@tilelang.jit +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + } +) def atomic_add_program(K, M, N, block_M, block_N, dtype=T.float32): @T.prim_func def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): @@ -21,6 +26,7 @@ def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): def run_atomic_add(K, M, N, block_M, block_N, dtype=T.float32): kernel = atomic_add_program(K, M, N, block_M, block_N, dtype=dtype) + print(kernel.get_kernel_source()) import torch def ref_program(A, B): @@ -386,6 +392,80 @@ def test_tile_atomic_add(): run_tile_atomic_add(8, 128, 128, 32, 32) +# ======================= Tile-level atomic max ======================= +@tilelang.jit +def tile_atomic_max_program(K, M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def tile_atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, 
block_N), K, threads=32) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) + + T.atomic_max(B[bx * block_M, by * block_N], A_shared) + + return tile_atomic_max + + +def run_tile_atomic_max(K, M, N, block_M, block_N, dtype=T.float32): + kernel = tile_atomic_max_program(K, M, N, block_M, block_N, dtype=dtype) + print(kernel.get_kernel_source()) + + def ref_program(A, B): + for k in range(K): + for i in range(M): + for j in range(N): + B[i, j] = max(B[i, j], A[k, i, j]) + + A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.full((M, N), float("-inf"), dtype=getattr(torch, dtype)).cuda() + ref_B = B.clone() + ref_program(A, ref_B) + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +def test_tile_atomic_max(): + run_tile_atomic_max(8, 128, 128, 32, 32) + + +# ======================= Tile-level atomic min ======================= +@tilelang.jit +def tile_atomic_min_program(K, M, N, block_M, block_N, dtype=T.float32): + @T.prim_func + def tile_atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz): + A_shared = T.alloc_shared((block_M, block_N), dtype) + + T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared) + + T.atomic_min(B[bx * block_M, by * block_N], A_shared) + + return tile_atomic_min + + +def run_tile_atomic_min(K, M, N, block_M, block_N, dtype=T.float32): + kernel = tile_atomic_min_program(K, M, N, block_M, block_N, dtype=dtype) + print(kernel.get_kernel_source()) + + def ref_program(A, B): + for k in range(K): + for i in range(M): + for j in range(N): + B[i, j] = min(B[i, j], A[k, i, j]) + + A = torch.randn(K, M, N, dtype=getattr(torch, dtype)).cuda() + B = torch.full((M, N), float("inf"), dtype=getattr(torch, dtype)).cuda() + ref_B = B.clone() + ref_program(A, ref_B) + kernel(A, B) + torch.testing.assert_close(B, ref_B, atol=1e-3, rtol=1e-3) + + +def test_tile_atomic_min(): + run_tile_atomic_min(8, 128, 128, 32, 32) + + @tilelang.testing.requires_cuda def test_tma_atomic_add(): out = torch.zeros((16, 16), dtype=torch.float32, device="cuda") diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py index 30b5f533b..12533928e 100644 --- a/tilelang/language/atomic.py +++ b/tilelang/language/atomic.py @@ -21,7 +21,7 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None, re """ Perform an atomic maximum on the value stored at dst with an optional memory-order. - If memory_order is None the runtime extern "AtomicMax" is called without an explicit memory-order id; otherwise the provided memory_order string is mapped to a numeric id using the module's memory-order map and passed to the extern. + Supports scalar/addressed extern atomic max when neither argument exposes extents, or tile-region-based atomic max for Buffer/BufferRegion/BufferLoad inputs. If both arguments are plain Buffers their shapes must be structurally equal. If at least one side exposes extents, extents are aligned (missing dimensions are treated as size 1); an assertion is raised if extents cannot be deduced. The optional `memory_order` (one of "relaxed","consume","acquire","release","acq_rel","seq_cst") is used only for the direct extern `AtomicMax` path when no extents are available — otherwise the tile-region path ignores `memory_order`. 
Parameters: dst (Buffer): Destination buffer/address to apply the atomic max. @@ -50,29 +50,65 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None, re >>> def find_max(data: T.Buffer, result: T.Buffer): >>> for i in T.thread_binding(128, "threadIdx.x"): >>> atomic_max(result, data[i]) + + >>> # Tensor-to-tensor atomic max (tile-region based) + >>> src_tensor = T.Tensor([128, 64], "float32", name="src") + >>> dst_tensor = T.Tensor([128, 64], "float32", name="dst") + >>> atomic_max(dst_tensor, src_tensor) # Max entire tensors atomically """ - func_name = "AtomicMaxRet" if return_prev else "AtomicMax" - return_type = dst.dtype if return_prev else "handle" - if memory_order is None: - return T.call_extern(return_type, func_name, T.address_of(dst), value) - else: - return T.call_extern( + def get_extent(data): + if isinstance(data, Var) and T.has_let_value(data): + data = T.get_let_value(data) + if isinstance(data, Buffer): + return data.shape + elif isinstance(data, BufferRegion): + return [x.extent for x in data.region] + else: + return None + + src_extent = get_extent(value) + dst_extent = get_extent(dst) + + if dst_extent is None and src_extent is None: + # Scalar path: use atomicmax_elem_op intrinsic + return_type = dst.dtype if return_prev else "handle" + memory_order_id = _MEMORY_ORDER_ID_MAP[memory_order] if memory_order else 0 + + return T.call_intrin( return_type, - func_name, + op.Op.get("tl.atomic_max_elem_op"), T.address_of(dst), value, - _MEMORY_ORDER_ID_MAP[memory_order], + memory_order_id, ) + if isinstance(dst, Buffer) and isinstance(value, Buffer): + ir.assert_structural_equal(dst.shape, value.shape) + + assert src_extent or dst_extent, "Can't deduce atomicmax extents from args" + src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) + dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) + src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) + + value = to_buffer_region(value, access_type="r", extents=src_extent) + dst = to_buffer_region(dst, access_type="w", extents=dst_extent) + + if return_prev: + raise NotImplementedError("return_prev is not supported for tile-region-based atomic operations") + + ann = {} + if memory_order is not None: + ann["memory_order"] = _MEMORY_ORDER_ID_MAP[memory_order] + + return T.call_intrin("handle", op.Op.get("tl.tileop.atomicmax"), value, dst, annotations=ann if ann else None) + def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False) -> PrimExpr: """ Atomically update the value at dst to the minimum of its current value and value. - If memory_order is provided, it selects the memory-order semantic used by the underlying extern call; - allowed names are "relaxed", "consume", "acquire", "release", "acq_rel", and "seq_cst" (mapped internally - to integer IDs). If memory_order is None, the extern is invoked without an explicit memory-order argument. + Supports scalar/addressed extern atomic min when neither argument exposes extents, or tile-region-based atomic min for Buffer/BufferRegion/BufferLoad inputs. If both arguments are plain Buffers their shapes must be structurally equal. If at least one side exposes extents, extents are aligned (missing dimensions are treated as size 1); an assertion is raised if extents cannot be deduced. 
The optional `memory_order` (one of "relaxed","consume","acquire","release","acq_rel","seq_cst") is used only for the direct extern `AtomicMin` path when no extents are available — otherwise the tile-region path ignores `memory_order`. Parameters: dst (Buffer): Destination buffer/address to apply the atomic min. @@ -101,21 +137,59 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None, re >>> # With relaxed memory ordering for performance >>> atomic_min(min_val, 5, memory_order="relaxed") + + >>> # Tensor-to-tensor atomic min (tile-region based) + >>> src_tensor = T.Tensor([128, 64], "float32", name="src") + >>> dst_tensor = T.Tensor([128, 64], "float32", name="dst") + >>> atomic_min(dst_tensor, src_tensor) # Min entire tensors atomically """ - func_name = "AtomicMinRet" if return_prev else "AtomicMin" - return_type = dst.dtype if return_prev else "handle" - if memory_order is None: - return T.call_extern(return_type, func_name, T.address_of(dst), value) - else: - return T.call_extern( + def get_extent(data): + if isinstance(data, Var) and T.has_let_value(data): + data = T.get_let_value(data) + if isinstance(data, Buffer): + return data.shape + elif isinstance(data, BufferRegion): + return [x.extent for x in data.region] + else: + return None + + src_extent = get_extent(value) + dst_extent = get_extent(dst) + + if dst_extent is None and src_extent is None: + # Scalar path: use atomicmin_elem_op intrinsic + return_type = dst.dtype if return_prev else "handle" + memory_order_id = _MEMORY_ORDER_ID_MAP[memory_order] if memory_order else 0 + + return T.call_intrin( return_type, - func_name, + op.Op.get("tl.atomic_min_elem_op"), T.address_of(dst), value, - _MEMORY_ORDER_ID_MAP[memory_order], + memory_order_id, ) + if isinstance(dst, Buffer) and isinstance(value, Buffer): + ir.assert_structural_equal(dst.shape, value.shape) + + assert src_extent or dst_extent, "Can't deduce atomicmin extents from args" + src_extent = list(src_extent) if src_extent else [1] * len(dst_extent) + dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent) + src_extent, dst_extent = legalize_pairwise_extents(src_extent, dst_extent) + + value = to_buffer_region(value, access_type="r", extents=src_extent) + dst = to_buffer_region(dst, access_type="w", extents=dst_extent) + + if return_prev: + raise NotImplementedError("return_prev is not supported for tile-region-based atomic operations") + + ann = {} + if memory_order is not None: + ann["memory_order"] = _MEMORY_ORDER_ID_MAP[memory_order] + + return T.call_intrin("handle", op.Op.get("tl.tileop.atomicmin"), value, dst, annotations=ann if ann else None) + def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None, return_prev: bool = False, use_tma: bool = False) -> PrimExpr: """ @@ -186,16 +260,16 @@ def get_extent(data): dst_extent = get_extent(dst) if dst_extent is None and src_extent is None: - func_name = "AtomicAddRet" if return_prev else "AtomicAdd" + atomic_add_op = op.Op.get("tl.atomic_add_ret_elem_op") if return_prev else op.Op.get("tl.atomic_add_elem_op") return_type = dst.dtype if return_prev else "handle" # Pass destination by pointer to match device signature if memory_order is None: - return T.call_extern(return_type, func_name, T.address_of(dst), value) + return T.call_intrin(return_type, atomic_add_op, T.address_of(dst), value) else: - return T.call_extern( + return T.call_intrin( return_type, - func_name, + atomic_add_op, T.address_of(dst), value, _MEMORY_ORDER_ID_MAP[memory_order], @@ -262,7 +336,7 
@@ def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> Pri """ func_name = "AtomicAddx2Ret" if return_prev else "AtomicAddx2" return_type = dst.dtype if return_prev else "handle" - return T.call_extern(return_type, func_name, T.address_of(dst), T.address_of(value)) + return T.call_intrin(return_type, op.Op.get("tl.atomic_addx2_elem_op"), T.address_of(dst), T.address_of(value)) def atomic_addx4(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr: @@ -300,7 +374,7 @@ def atomic_addx4(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> Pri """ func_name = "AtomicAddx4Ret" if return_prev else "AtomicAddx4" return_type = "float4" if "float" in str(dst.dtype).lower() else "handle" - return T.call_extern(return_type, func_name, T.address_of(dst), T.address_of(value)) + return T.call_intrin(return_type, op.Op.get("tl.atomic_addx4_elem_op"), T.address_of(dst), T.address_of(value)) def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr: @@ -339,7 +413,7 @@ def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr: >>> counter = T.Tensor([1], "int64", name="counter") >>> current_count = atomic_load(counter, memory_order="relaxed") """ - return T.call_extern(src.dtype, "AtomicLoad", T.address_of(src), _MEMORY_ORDER_ID_MAP[memory_order]) + return T.call_intrin(src.dtype, op.Op.get("tl.atomic_load_elem_op"), T.address_of(src), _MEMORY_ORDER_ID_MAP[memory_order]) def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> PrimExpr: @@ -392,4 +466,4 @@ def atomic_store(dst: Buffer, src: PrimExpr, memory_order: str = "seq_cst") -> P >>> log_counter = T.Tensor([1], "int64", name="log_counter") >>> atomic_store(log_counter, 0) # Reset counter atomically """ - return T.call_extern("handle", "AtomicStore", T.address_of(dst), src, _MEMORY_ORDER_ID_MAP[memory_order]) + return T.call_intrin("handle", op.Op.get("tl.atomic_store_elem_op"), T.address_of(dst), src, _MEMORY_ORDER_ID_MAP[memory_order]) From 79f349f0729a3867a5530a035eebd3a4181b8f02 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 14:21:29 +0800 Subject: [PATCH 02/13] Refactor atomic operation code for improved readability and consistency - Cleaned up whitespace and formatting in atomic add and reduce implementations for better code clarity. - Enhanced comments in CUDA code generation for atomic operations to improve understanding. - Updated function signatures and layout definitions for consistency across atomic operation files. - Ensured proper alignment of code for better maintainability and readability. 
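A minimal usage sketch of the tile-region atomic path exercised by the tests above. Everything here is illustrative only: the function name `tile_atomic_add_sketch`, the shapes, block sizes, dtype, and the `import tilelang.language as T` spelling are assumptions and not part of this series.

    import tilelang
    import tilelang.language as T


    @tilelang.jit
    def tile_atomic_add_sketch(M=128, N=128, block_M=32, block_N=32, dtype=T.float16):
        @T.prim_func
        def kernel(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
            with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
                A_shared = T.alloc_shared((block_M, block_N), dtype)
                T.copy(A[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N], A_shared)
                # Both operands expose extents, so this takes the tile-region
                # AtomicAdd path rather than the scalar intrinsic; for float16
                # destinations the lowered per-element adds can be packed into
                # AtomicAddx2 calls by the atomic-add vectorizer.
                T.atomic_add(B[bx * block_M, by * block_N], A_shared)

        return kernel

On targets below SM90 a float32 destination keeps the scalar per-element AtomicAdd form (vectorize length 1).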
--- src/op/atomic_add.cc | 10 +++++----- src/op/atomic_add.h | 6 +++--- src/op/atomic_reduce.cc | 24 ++++++++++++------------ src/op/atomic_reduce.h | 21 ++++++++++----------- src/target/codegen_cuda.cc | 12 ++++++------ src/transform/atomicadd_vectorize.cc | 27 +++++++++++++++++---------- src/transform/atomicadd_vectorize.h | 6 +++--- tilelang/language/atomic.py | 10 +++++----- 8 files changed, 61 insertions(+), 55 deletions(-) diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index b34e71425..25cf29257 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -99,7 +99,6 @@ int AtomicAddNode::GetVectorizeLength(Target target) const { return 1; } - std::pair, PrimExpr> AtomicAddNode::ReturnIndicesAndSize(int src_dst) const { Array indices; @@ -112,7 +111,6 @@ AtomicAddNode::ReturnIndicesAndSize(int src_dst) const { return {indices, size}; } - /** * @brief Build a SIMT-style loop nest that performs element-wise atomic * additions from src to dst. @@ -565,7 +563,8 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto par_op = ParallelOp(fused_loop); std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, InferLevel::kFree}; - // 1.give par_op a recommended vectorize size. (only works for free layout inference). + // 1.give par_op a recommended vectorize size. (only works for free layout + // inference). for (auto level : levels) { par_op->InferLayout({T.target, T.thread_bounds, @@ -577,8 +576,9 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { level); } auto loop_layout = par_op->GetLoopLayout(); - auto lowered_loop = LowerParallelLoop(fused_loop, loop_layout, T.thread_var, - analyzer, par_op->GetPredicate(T.thread_var)); + auto lowered_loop = + LowerParallelLoop(fused_loop, loop_layout, T.thread_var, analyzer, + par_op->GetPredicate(T.thread_var)); return lowered_loop; } diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index 1fd1b8c8d..524868e1b 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -81,7 +81,7 @@ class AtomicAdd : public TileOperator { static const Op &Get(); }; -} // namespace tl -} // namespace tvm +} // namespace tl +} // namespace tvm -#endif // TVM_TL_OP_ATOMIC_ADD_H_ +#endif // TVM_TL_OP_ATOMIC_ADD_H_ diff --git a/src/op/atomic_reduce.cc b/src/op/atomic_reduce.cc index e585aa03d..8fab54809 100644 --- a/src/op/atomic_reduce.cc +++ b/src/op/atomic_reduce.cc @@ -13,7 +13,6 @@ #include "../layout/layout.h" #include "../target/utils.h" -#include "../transform/atomicreduce_lower.h" #include "../transform/common/loop_fusion_utils.h" #include "../transform/loop_partition.h" #include "builtin.h" @@ -103,7 +102,7 @@ Array AtomicOpBaseNode::MakeIterVars() const { template Array AtomicOpBaseNode::MakeIndices(const Array &ivs, - int src_dst) const { + int src_dst) const { Array indices; Array ranges = src_dst == 0 ? src_range : dst_range; size_t idx = 0; @@ -123,9 +122,9 @@ AtomicOpBaseNode::MakeIndices(const Array &ivs, template PrimExpr AtomicOpBaseNode::MakePredicate(arith::Analyzer *analyzer, - const Array &ivs, - Array extents, - int src_dst) const { + const Array &ivs, + Array extents, + int src_dst) const { Array ranges = src_dst == 0 ? 
src_range : dst_range; Array cond_list; ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; @@ -154,8 +153,7 @@ PrimExpr AtomicOpBaseNode::MakePredicate(arith::Analyzer *analyzer, } template -For AtomicOpBaseNode::MakeSIMTLoop( - arith::Analyzer *analyzer) const { +For AtomicOpBaseNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { Array loop_vars = MakeIterVars(); bool is_scalar = loop_vars.empty(); if (is_scalar) { @@ -195,7 +193,8 @@ For AtomicOpBaseNode::MakeSIMTLoop( new_args.push_back(static_cast(this)->GetMemoryOrder()); // Use the appropriate elem_op based on the derived type (via CRTP) - Call atomic_call = tvm::tir::Call(dst->dtype, GetElemOp(), new_args, annotations); + Call atomic_call = + tvm::tir::Call(dst->dtype, GetElemOp(), new_args, annotations); Stmt body = tvm::tir::Evaluate(atomic_call); @@ -216,7 +215,7 @@ For AtomicOpBaseNode::MakeSIMTLoop( template LayoutMap AtomicOpBaseNode::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { + InferLevel level) const { // For atomic reduce operations, check that src and dst have the same layout // if both are fragments if (IsFragmentBuffer(src) && IsFragmentBuffer(dst)) { @@ -235,7 +234,7 @@ LayoutMap AtomicOpBaseNode::InferLayout(const LayoutInferArgs &T, template Stmt AtomicOpBaseNode::Lower(const LowerArgs &T, - arith::Analyzer *analyzer) const { + arith::Analyzer *analyzer) const { Target target = T.target; auto simt_loop = MakeSIMTLoop(analyzer); @@ -254,8 +253,9 @@ Stmt AtomicOpBaseNode::Lower(const LowerArgs &T, level); } auto loop_layout = par_op->GetLoopLayout(); - auto lowered_loop = LowerParallelLoop(fused_loop, loop_layout, T.thread_var, - analyzer, par_op->GetPredicate(T.thread_var)); + auto lowered_loop = + LowerParallelLoop(fused_loop, loop_layout, T.thread_var, analyzer, + par_op->GetPredicate(T.thread_var)); return lowered_loop; } diff --git a/src/op/atomic_reduce.h b/src/op/atomic_reduce.h index f4c3e0eed..144cc7914 100644 --- a/src/op/atomic_reduce.h +++ b/src/op/atomic_reduce.h @@ -18,15 +18,16 @@ using namespace tir; * \brief Base node class for atomic operations (add/max/min). * * This template base class provides common functionality for all atomic - * operations including buffer management, loop generation, and layout inference. + * operations including buffer management, loop generation, and layout + * inference. 
* * \tparam Derived The derived class type (CRTP pattern) */ -template -class AtomicOpBaseNode : public TileOperatorNode { +template class AtomicOpBaseNode : public TileOperatorNode { public: - Buffer src, dst; ///< Source and destination buffers - Array src_range, dst_range; ///< Access ranges for source and destination + Buffer src, dst; ///< Source and destination buffers + Array src_range, + dst_range; ///< Access ranges for source and destination Map annotations; ///< Annotations for the atomic operation // Supported annotation keys: // - "coalesced_width": IntImm, width for memory coalescing optimization @@ -66,9 +67,7 @@ class AtomicOpBaseNode : public TileOperatorNode { /// Get the element-wise operation Op (to be implemented by derived class) /// This uses CRTP to call the derived class's static method - const Op &GetElemOp() const { - return Derived::GetElemOpStatic(); - } + const Op &GetElemOp() const { return Derived::GetElemOpStatic(); } }; // Backward compatibility alias @@ -139,7 +138,7 @@ class AtomicMin : public TileOperator { static const Op &Get(); }; -} // namespace tl -} // namespace tvm +} // namespace tl +} // namespace tvm -#endif // TVM_TL_OP_ATOMIC_REDUCE_H_ +#endif // TVM_TL_OP_ATOMIC_REDUCE_H_ diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index d629cf1aa..dd59fc48a 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -2899,8 +2899,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } this->stream << ");\n"; } else if (op->op.same_as(tl::atomic_add_ret_elem_op())) { - // atomic_add_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns prev - // value + // atomic_add_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns + // prev value os << "AtomicAddRet(" << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]); if (op->args.size() > 2) { @@ -2950,8 +2950,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } this->stream << ");\n"; } else if (op->op.same_as(tl::atomic_max_ret_elem_op())) { - // atomic_max_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns prev - // value + // atomic_max_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns + // prev value os << "AtomicMaxRet(" << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]); if (op->args.size() > 2) { @@ -2969,8 +2969,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { } this->stream << ");\n"; } else if (op->op.same_as(tl::atomic_min_ret_elem_op())) { - // atomic_min_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns prev - // value + // atomic_min_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns + // prev value os << "AtomicMinRet(" << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]); if (op->args.size() > 2) { diff --git a/src/transform/atomicadd_vectorize.cc b/src/transform/atomicadd_vectorize.cc index d09d48618..3e9ba673c 100644 --- a/src/transform/atomicadd_vectorize.cc +++ b/src/transform/atomicadd_vectorize.cc @@ -16,7 +16,8 @@ using namespace tir; namespace { /*! - * \brief Extract BufferLoad from an expression that may be wrapped in address_of. + * \brief Extract BufferLoad from an expression that may be wrapped in + * address_of. 
*/ Optional ExtractBufferLoad(const PrimExpr &expr) { if (const auto *load = expr.as()) { @@ -64,7 +65,8 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator { if (dtype.is_float16() || dtype.is_bfloat16()) { return 2; } - if (dtype.is_float() && dtype.bits() == 32 && TargetHasSMVersionGE(target_, 90)) { + if (dtype.is_float() && dtype.bits() == 32 && + TargetHasSMVersionGE(target_, 90)) { return 4; } return 1; @@ -90,7 +92,8 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator { // Change loop extent to 1 since atomic op now handles all elements return For(node->loop_var, node->min, Integer(1), node->kind, body, - node->thread_binding, node->annotations, node->step, node->span); + node->thread_binding, node->annotations, node->step, + node->span); } vectorized_loop_ = nullptr; @@ -100,7 +103,8 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator { return tvm::ffi::GetRef(node); } return For(node->loop_var, node->min, node->extent, node->kind, body, - node->thread_binding, node->annotations, node->step, node->span); + node->thread_binding, node->annotations, node->step, + node->span); } } return StmtExprMutator::VisitStmt_(node); @@ -133,10 +137,13 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator { has_vectorized_atomic_ = true; // Create vectorized atomic op - Call addr_dst(DataType::Handle(), builtin::address_of(), {dst_load.value()}); - Call addr_src(DataType::Handle(), builtin::address_of(), {src_load.value()}); + Call addr_dst(DataType::Handle(), builtin::address_of(), + {dst_load.value()}); + Call addr_src(DataType::Handle(), builtin::address_of(), + {src_load.value()}); - return Call(node->dtype, GetVectorizedAtomicOp(vector_size_), {addr_dst, addr_src}); + return Call(node->dtype, GetVectorizedAtomicOp(vector_size_), + {addr_dst, addr_src}); } Target target_; @@ -145,12 +152,12 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator { bool has_vectorized_atomic_ = false; }; -} // namespace +} // namespace For VectorizeAtomicAdd(const For &for_node) { Target target = Target::Current(false); return Downcast(AtomicAddVectorizeRewriter(target)(for_node)); } -} // namespace tl -} // namespace tvm +} // namespace tl +} // namespace tvm diff --git a/src/transform/atomicadd_vectorize.h b/src/transform/atomicadd_vectorize.h index 905c8aaa3..470814a92 100644 --- a/src/transform/atomicadd_vectorize.h +++ b/src/transform/atomicadd_vectorize.h @@ -28,7 +28,7 @@ using namespace tir; */ For VectorizeAtomicAdd(const For &for_node); -} // namespace tl -} // namespace tvm +} // namespace tl +} // namespace tvm -#endif // TVM_TL_ATOMICADD_VECTORIZE_H_ +#endif // TVM_TL_ATOMICADD_VECTORIZE_H_ diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py index 12533928e..9144b36c1 100644 --- a/tilelang/language/atomic.py +++ b/tilelang/language/atomic.py @@ -260,7 +260,7 @@ def get_extent(data): dst_extent = get_extent(dst) if dst_extent is None and src_extent is None: - atomic_add_op = op.Op.get("tl.atomic_add_ret_elem_op") if return_prev else op.Op.get("tl.atomic_add_elem_op") + atomic_add_op = op.Op.get("tl.atomic_add_ret_elem_op") if return_prev else op.Op.get("tl.atomic_add_elem_op") return_type = dst.dtype if return_prev else "handle" # Pass destination by pointer to match device signature @@ -334,9 +334,9 @@ def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> Pri >>> for j in range(0, grads.shape[1], 2): # Process in pairs >>> atomic_addx2(global_grads[i, j:j+2], grads[i, j:j+2]) """ - func_name = "AtomicAddx2Ret" if 
return_prev else "AtomicAddx2" + atomic_addx2_op = op.Op.get("tl.atomic_addx2_elem_op") if return_prev else op.Op.get("tl.atomic_addx2_elem_op") return_type = dst.dtype if return_prev else "handle" - return T.call_intrin(return_type, op.Op.get("tl.atomic_addx2_elem_op"), T.address_of(dst), T.address_of(value)) + return T.call_intrin(return_type, atomic_addx2_op, T.address_of(dst), T.address_of(value)) def atomic_addx4(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr: @@ -372,9 +372,9 @@ def atomic_addx4(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> Pri >>> rgba_add = T.Tensor([4], "float32", name="rgba_add") >>> atomic_addx4(rgba_dst, rgba_add) # Atomic blend of all 4 channels """ - func_name = "AtomicAddx4Ret" if return_prev else "AtomicAddx4" + atomic_addx4_op = op.Op.get("tl.atomic_addx4_elem_op") if return_prev else op.Op.get("tl.atomic_addx4_elem_op") return_type = "float4" if "float" in str(dst.dtype).lower() else "handle" - return T.call_intrin(return_type, op.Op.get("tl.atomic_addx4_elem_op"), T.address_of(dst), T.address_of(value)) + return T.call_intrin(return_type, atomic_addx4_op, T.address_of(dst), T.address_of(value)) def atomic_load(src: Buffer, memory_order: str = "seq_cst") -> PrimExpr: From 4596e430ca25411f47bb24c51113b29166740981 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 16:20:53 +0800 Subject: [PATCH 03/13] lint fix --- src/transform/atomicadd_vectorize.cc | 47 ++++++++++++++-------------- src/transform/loop_vectorize.cc | 6 +++- 2 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/transform/atomicadd_vectorize.cc b/src/transform/atomicadd_vectorize.cc index 3e9ba673c..8b4826e2f 100644 --- a/src/transform/atomicadd_vectorize.cc +++ b/src/transform/atomicadd_vectorize.cc @@ -76,36 +76,35 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator { // Check if this is a vectorized loop if (node->kind == ForKind::kVectorized) { auto extent_ptr = as_const_int(node->extent); - if (extent_ptr) { - int vec_size = static_cast(*extent_ptr); - // Push vectorized context - vectorized_loop_ = node; - vector_size_ = vec_size; - - Stmt body = VisitStmt(node->body); - - // If we successfully vectorized atomic ops, transform the loop - if (has_vectorized_atomic_) { - has_vectorized_atomic_ = false; - vectorized_loop_ = nullptr; - vector_size_ = 1; - - // Change loop extent to 1 since atomic op now handles all elements - return For(node->loop_var, node->min, Integer(1), node->kind, body, - node->thread_binding, node->annotations, node->step, - node->span); - } + if (!extent_ptr) { + return StmtExprMutator::VisitStmt_(node); + } + int vec_size = static_cast(*extent_ptr); + // Push vectorized context + vectorized_loop_ = node; + vector_size_ = vec_size; + Stmt body = VisitStmt(node->body); + // If we successfully vectorized atomic ops, transform the loop + if (has_vectorized_atomic_) { + has_vectorized_atomic_ = false; vectorized_loop_ = nullptr; vector_size_ = 1; - - if (body.same_as(node->body)) { - return tvm::ffi::GetRef(node); - } - return For(node->loop_var, node->min, node->extent, node->kind, body, + // Change loop extent to 1 since atomic op now handles all elements + return For(node->loop_var, node->min, Integer(1), node->kind, body, node->thread_binding, node->annotations, node->step, node->span); } + + vectorized_loop_ = nullptr; + vector_size_ = 1; + + if (body.same_as(node->body)) { + return tvm::ffi::GetRef(node); + } + return For(node->loop_var, node->min, node->extent, node->kind, body, + 
node->thread_binding, node->annotations, node->step, + node->span); } return StmtExprMutator::VisitStmt_(node); } diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 7400d3155..2a95d77bd 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -173,7 +173,11 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { vector_size_ = arith::ZeroAwareGCD(vector_size_, vectorize_length); // Do not visit the args of atomic_add_elem_op, because pointer type // is impossible to vectorize - return Downcast(node); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); + } else if (node->op == builtin::address_of()) { + // address_of have buffer load value so we should analysis the buffer load + // node to update vector_size_. + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } else { // Other calls should not be vectorized vector_size_ = 1; From a2e63aece5cc14297161ecb30e85177c25a4d699 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 18:30:51 +0800 Subject: [PATCH 04/13] Enhance loop vectorization and layout handling across multiple operations - Updated function signatures in atomic add, atomic reduce, copy, and fill operations to include layout map parameters for improved layout handling. - Refactored vectorization logic to utilize layout maps, ensuring better performance and compatibility with various buffer layouts. - Enhanced the LowerParallelLoop function to accept layout maps, facilitating more efficient loop transformations. - Added checks for buffer contiguity in vectorization processes to ensure correctness when using layout maps. - Updated tests to validate the new layout handling and vectorization behavior. --- src/op/atomic_add.cc | 2 +- src/op/atomic_reduce.cc | 2 +- src/op/copy.cc | 5 +- src/op/fill.cc | 9 +- src/op/parallel.cc | 3 +- src/transform/loop_partition.cc | 7 +- src/transform/loop_partition.h | 2 + src/transform/loop_vectorize.cc | 93 ++++++++++++++++--- src/transform/loop_vectorize.h | 11 ++- src/transform/lower_tile_op.cc | 3 +- ...g_transform_legalize_safe_memory_access.py | 2 +- 11 files changed, 109 insertions(+), 30 deletions(-) diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 25cf29257..bd8b52d1e 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -578,7 +578,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto loop_layout = par_op->GetLoopLayout(); auto lowered_loop = LowerParallelLoop(fused_loop, loop_layout, T.thread_var, analyzer, - par_op->GetPredicate(T.thread_var)); + T.layout_map, par_op->GetPredicate(T.thread_var)); return lowered_loop; } diff --git a/src/op/atomic_reduce.cc b/src/op/atomic_reduce.cc index 8fab54809..370509bfe 100644 --- a/src/op/atomic_reduce.cc +++ b/src/op/atomic_reduce.cc @@ -255,7 +255,7 @@ Stmt AtomicOpBaseNode::Lower(const LowerArgs &T, auto loop_layout = par_op->GetLoopLayout(); auto lowered_loop = LowerParallelLoop(fused_loop, loop_layout, T.thread_var, analyzer, - par_op->GetPredicate(T.thread_var)); + T.layout_map, par_op->GetPredicate(T.thread_var)); return lowered_loop; } diff --git a/src/op/copy.cc b/src/op/copy.cc index 7f91d4c38..05577e7e9 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -728,7 +728,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, << dst.scope() << " buffer `" << dst->name << "` may cause conflicted write."; } - vectorized_thread_loop = VectorizeLoop(transformed_loop); + vectorized_thread_loop = VectorizeLoop(transformed_loop, T.layout_map); 
return vectorized_thread_loop; } else { std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, @@ -747,7 +747,8 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, // Use LowerParallelLoop to handle partitioning, vectorization, and // predicate return LowerParallelLoop(par_op->GetRoot(), loop_layout, T.thread_var, - analyzer, par_op->GetPredicate(T.thread_var)); + analyzer, T.layout_map, + par_op->GetPredicate(T.thread_var)); } } diff --git a/src/op/fill.cc b/src/op/fill.cc index 02962d242..9e036db37 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -168,7 +168,8 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { InferLevel::kFree); auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout()); - auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); + auto vectorized_thread_loop = + VectorizeLoop(thread_loop, analyzer, T.layout_map); if (par_op->GetPredicate(T.thread_var).defined()) { return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop); @@ -176,7 +177,8 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return vectorized_thread_loop; } else if (IsLocalBuffer(dst) || IsLocalVarBuffer(dst)) { auto init_loop = MakeSIMTLoop(analyzer); - auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer); + auto vectorized_thread_loop = + VectorizeLoop(init_loop, analyzer, T.layout_map); return vectorized_thread_loop; } else if (IsSharedBuffer(dst) || IsGlobalBuffer(dst)) { auto par_op = ParallelOp(MakeSIMTLoop(analyzer)); @@ -190,7 +192,8 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { InferLevel::kFree); auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout()); - auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); + auto vectorized_thread_loop = + VectorizeLoop(thread_loop, analyzer, T.layout_map); if (par_op->GetPredicate(T.thread_var).defined()) { return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop); diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 3b3012469..0e5084f8b 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -708,7 +708,8 @@ Fragment ParallelOpNode::ComputePlanCandidate(const LayoutInferArgs &T) const { // As the pass will do post processing to the layout auto maybe_remapped_root_ = IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); - int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer); + int vector_size = + GetVectorizeSize(maybe_remapped_root_, T.analyzer, T.layout_map); DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n'; PrimExpr loop_total_size = 1; diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc index 61c97d20c..b7b15e4b5 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -272,8 +272,9 @@ For LoopPragmaUnroll(For stmt) { } Stmt LowerParallelLoop(For loop, const Fragment &loop_layout, Var thread_var, - arith::Analyzer *analyzer, Optional predicate, - bool parallel_loop, bool should_vectorize) { + arith::Analyzer *analyzer, const LayoutMap &layout_map, + Optional predicate, bool parallel_loop, + bool should_vectorize) { // Save analyzer state to prevent conflicted bindings during vectorization auto saved_analyzer = analyzer->Clone(); @@ -294,7 +295,7 @@ Stmt LowerParallelLoop(For loop, const Fragment &loop_layout, Var thread_var, // Step 2: Vectorize the loop (if 
requested) if (should_vectorize) { - result_loop = VectorizeLoop(result_loop, saved_analyzer.get()); + result_loop = VectorizeLoop(result_loop, saved_analyzer.get(), layout_map); } // Step 3: Vectorize atomic add operations diff --git a/src/transform/loop_partition.h b/src/transform/loop_partition.h index 844065ab3..ffc32ec45 100644 --- a/src/transform/loop_partition.h +++ b/src/transform/loop_partition.h @@ -29,6 +29,7 @@ #include #include "../layout/layout.h" +#include "../op/operator.h" namespace tvm { namespace tl { @@ -68,6 +69,7 @@ For LoopPragmaUnroll(For stmt); */ Stmt LowerParallelLoop(For loop, const Fragment &loop_layout, Var thread_var, arith::Analyzer *analyzer, + const LayoutMap &layout_map = {}, Optional predicate = Optional(), bool parallel_loop = true, bool should_vectorize = true); diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 2a95d77bd..4c941181a 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -40,6 +40,36 @@ namespace tl { using namespace tir; +/*! + * \brief Check if buffer strides represent a contiguous (row-major) layout. + * \param buffer The buffer to check. + * \param analyzer The analyzer for symbolic comparison. + * \return True if strides are empty (implicitly contiguous) or match row-major + * layout. + */ +bool IsBufferContiguous(const Buffer &buffer, arith::Analyzer *analyzer) { + if (buffer->strides.empty()) { + return true; + } + if (buffer->strides.size() != buffer->shape.size()) { + return false; + } + // For row-major layout: + // strides[n-1] = 1 + // strides[i] = strides[i+1] * shape[i+1] + int n = buffer->shape.size(); + PrimExpr expected_stride = make_const(buffer->shape[0].dtype(), 1); + for (int i = n - 1; i >= 0; --i) { + if (!analyzer->CanProveEqual(buffer->strides[i], expected_stride)) { + return false; + } + if (i > 0) { + expected_stride = expected_stride * buffer->shape[i]; + } + } + return true; +} + struct VectorizePlanResult { int vector_size; bool dynamic; @@ -73,8 +103,9 @@ class VectorizeFindGlobalAccess : public StmtExprVisitor { class VectorizePlanner : public arith::IRMutatorWithAnalyzer { public: - explicit VectorizePlanner(arith::Analyzer *analyzer) - : arith::IRMutatorWithAnalyzer(analyzer) {} + explicit VectorizePlanner(arith::Analyzer *analyzer, + const LayoutMap &layout_map = {}) + : arith::IRMutatorWithAnalyzer(analyzer), layout_map_(layout_map) {} int Plan(const For &node) { tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); @@ -200,19 +231,52 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { bool is_store) { if (!inner_for_) return; + auto transformed_indices = indices; + if (layout_map_.defined() && layout_map_.count(buffer)) { + ICHECK(IsBufferContiguous(buffer, analyzer_)) + << buffer + << " has non-contiguous strides, but layout map is provided."; + // forward indices + auto layout = layout_map_[buffer]; + transformed_indices = layout->Forward(indices); + + // Reshape transformed_indices to match buffer->shape dimensions if needed + if (transformed_indices.size() != buffer->shape.size()) { + // Step 1: Compute linear offset using layout->OutputShape() + auto output_shape = layout->OutputShape(); + ICHECK_EQ(transformed_indices.size(), output_shape.size()) + << "Forward indices size " << transformed_indices.size() + << " != OutputShape size " << output_shape.size(); + PrimExpr linear_offset = 0; + PrimExpr stride = 1; + for (int i = output_shape.size() - 1; i >= 0; --i) { + linear_offset = linear_offset + 
transformed_indices[i] * stride; + stride = stride * output_shape[i]; + } + // Step 2: Decompose linear_offset into buffer->shape dimensions + Array new_indices; + for (int i = buffer->shape.size() - 1; i >= 0; --i) { + new_indices.push_back(FloorMod(linear_offset, buffer->shape[i])); + linear_offset = FloorDiv(linear_offset, buffer->shape[i]); + } + transformed_indices = + Array{new_indices.rbegin(), new_indices.rend()}; + } + } + // 1. Compute raw element offset auto strides = buffer->strides; if (buffer->strides.empty()) { PrimExpr stride = 1; - for (int i = indices.size() - 1; i >= 0; --i) { + for (int i = transformed_indices.size() - 1; i >= 0; --i) { strides.push_back(stride); stride = stride * buffer->shape[i]; } strides = Array{strides.rbegin(), strides.rend()}; } PrimExpr elem_offset = 0; - for (int i = 0; i < indices.size(); ++i) { - elem_offset += indices[i] * strides[i]; + for (int i = 0; i < transformed_indices.size(); ++i) { + elem_offset += transformed_indices[i] * strides[i]; } // 2. If element offset is independent with loop_var, ignore it. if (CanProveIndependent(elem_offset, inner_for_->loop_var, analyzer_)) { @@ -271,6 +335,7 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { const ForNode *inner_for_{}; bool has_nonlocal_memory_access_ = false; int vector_size_ = 128; + LayoutMap layout_map_; }; class VectorizeRewriter : public StmtExprMutator { @@ -314,13 +379,14 @@ class VectorizeRewriter : public StmtExprMutator { const int vector_size_; }; -int GetVectorizeSize(const For &loop) { +int GetVectorizeSize(const For &loop, const LayoutMap &layout_map) { arith::Analyzer analyzer; - return VectorizePlanner(&analyzer).Plan(loop); + return VectorizePlanner(&analyzer, layout_map).Plan(loop); } -int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer) { - return VectorizePlanner(analyzer).Plan(loop); +int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer, + const LayoutMap &layout_map) { + return VectorizePlanner(analyzer, layout_map).Plan(loop); } bool CanProveIndependent(const PrimExpr &expr, Var var, @@ -420,10 +486,11 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, } } -For VectorizeLoop(const For &loop, int vectorize_hint) { +For VectorizeLoop(const For &loop, const LayoutMap &layout_map, + int vectorize_hint) { if (vectorize_hint <= 0) { arith::Analyzer analyzer; - VectorizePlanner planner(&analyzer); + VectorizePlanner planner(&analyzer, layout_map); vectorize_hint = planner.Plan(loop); } if (vectorize_hint == 1) @@ -433,9 +500,9 @@ For VectorizeLoop(const For &loop, int vectorize_hint) { } For VectorizeLoop(const For &loop, arith::Analyzer *analyzer, - int vectorize_hint) { + const LayoutMap &layout_map, int vectorize_hint) { if (vectorize_hint <= 0) { - VectorizePlanner planner(analyzer); + VectorizePlanner planner(analyzer, layout_map); vectorize_hint = planner.Plan(loop); } if (vectorize_hint == 1) diff --git a/src/transform/loop_vectorize.h b/src/transform/loop_vectorize.h index 92a756228..591f047e0 100644 --- a/src/transform/loop_vectorize.h +++ b/src/transform/loop_vectorize.h @@ -25,6 +25,7 @@ #ifndef TVM_TL_LOOP_VECTORIZE_H_ #define TVM_TL_LOOP_VECTORIZE_H_ +#include "../op/operator.h" #include #include @@ -33,14 +34,16 @@ namespace tl { using namespace tir; -int GetVectorizeSize(const For &loop); +int GetVectorizeSize(const For &loop, const LayoutMap &layout_map = {}); -int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer); +int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer, + const LayoutMap 
&layout_map = {}); -For VectorizeLoop(const For &loop, int vectorize_hint = -1); +For VectorizeLoop(const For &loop, const LayoutMap &layout_map = {}, + int vectorize_hint = -1); For VectorizeLoop(const For &loop, arith::Analyzer *analyzer, - int vectorize_hint = -1); + const LayoutMap &layout_map = {}, int vectorize_hint = -1); // Can prove expr is independent with var, i.e. the value of expr doesn't change // when var changes diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 01fa84586..5f2cf1a4c 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -827,7 +827,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { // Lower the parallel loop using the common function return LowerParallelLoop(for_node, loop_layout, thread_var_->var, analyzer_, - predicate, parallel_loop, should_vectorize); + layout_map_, predicate, parallel_loop, + should_vectorize); } Target target_; diff --git a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py index 4f75fa05d..37eb3482f 100644 --- a/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py +++ b/testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py @@ -71,7 +71,7 @@ def expected( # Nest if-then-else is expected, do not flatten it to pass structural equal check if j + N_offset < N: # noqa: SIM102 if tid + M_offset < M: - T.call_extern("handle", "AtomicAdd", T.address_of(A[tid + M_offset, j + N_offset]), 1) + T.atomic_add(A[tid + M_offset, j + N_offset], 1) return main, expected From c1ef4c9ce4aa17cf806c670c9083d49e6f5ce2bd Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 18:55:53 +0800 Subject: [PATCH 05/13] Add atomic operation checks in legalize_safe_memory_access - Introduced a new method to identify atomic operations within the legalize_safe_memory_access transformation. - Updated the VisitStmt_ function to handle both CallExtern and atomic operations, ensuring recursive condition collection for these cases. - Enhanced comments for clarity on the handling of atomic operations in the context of memory access legality. --- src/transform/legalize_safe_memory_access.cc | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/transform/legalize_safe_memory_access.cc b/src/transform/legalize_safe_memory_access.cc index a6f31da7d..2838c5b8b 100644 --- a/src/transform/legalize_safe_memory_access.cc +++ b/src/transform/legalize_safe_memory_access.cc @@ -215,15 +215,25 @@ class SafeMemorysRewriter : public IRMutatorWithAnalyzer { // current statement. The current solution adopts a simplified approach: // directly applying the boundary constraints of all parameters to the // statement. While not entirely precise, it addresses most common scenarios. 
+ // Check if the call is an atomic operation + bool IsAtomicOp(const Op &op) { + return op == atomic_add_elem_op() || op == atomic_add_ret_elem_op() || + op == atomic_addx2_elem_op() || op == atomic_addx4_elem_op() || + op == atomic_load_elem_op() || op == atomic_store_elem_op() || + op == atomic_max_elem_op() || op == atomic_max_ret_elem_op() || + op == atomic_min_elem_op() || op == atomic_min_ret_elem_op(); + } + Stmt VisitStmt_(const EvaluateNode *op) final { auto evaluate = Downcast(op); if (const CallNode *call_op = op->value.as()) { auto call = Downcast(op->value); - if (call->op == builtin::call_extern()) { - // For CallExtern, we recursively collect conditions from all children. - // Since we cannot rewrite any BufferLoad in its children (Rewrite will - // cause potential Nullptr exception). + if (call->op == builtin::call_extern() || + (call->op.as() && IsAtomicOp(Downcast(call->op)))) { + // For CallExtern and atomic ops, we recursively collect conditions + // from all children. Since we cannot rewrite any BufferLoad in its + // children (Rewrite will cause potential Nullptr exception). GlobalMemChecker checker(analyzer_, /*recursively_collect_conds=*/true); checker(call); Array conditions = checker.GetConditions(); From 8eb2a2964ad22dd8af95a860f0de932e2860b232 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 19:22:54 +0800 Subject: [PATCH 06/13] Add atomic operations for HIP code generation - Implemented support for various atomic operations including atomic add, atomic load, atomic store, atomic max, and atomic min in the HIP code generation. - Enhanced the handling of atomic operations to include optional memory order parameters. - Improved code readability with added comments explaining the purpose of each atomic operation. - Ensured consistency with existing atomic operation implementations in the codebase. 
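A short sketch of the element-wise forms these codegen branches cover, adapted from the docstring examples in tilelang/language/atomic.py. The buffer names, extents, and the `import tilelang.language as T` spelling are assumptions; only the atomic calls, memory-order strings, and the device helpers they lower to come from this series.

    import tilelang.language as T


    @T.prim_func
    def scalar_atomics(data: T.Tensor((128,), "float32"),
                       running_max: T.Tensor((1,), "float32"),
                       counter: T.Tensor((1,), "int32")):
        for i in T.thread_binding(128, "threadIdx.x"):
            # Scalar path (no extents on either side): lowers to an
            # AtomicMax(...) device call, emitted by the CUDA backend and,
            # with this patch, by the HIP backend as well.
            T.atomic_max(running_max[0], data[i])
            # The optional memory_order string is mapped to an id and passed
            # as the third argument of the emitted AtomicAdd(...) call.
            T.atomic_add(counter[0], 1, memory_order="relaxed")

The returning variants (return_prev=True) map to the corresponding Ret helpers (AtomicAddRet, AtomicMaxRet, AtomicMinRet) in the same way.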
--- src/op/atomic_reduce.cc | 3 +- src/target/codegen_hip.cc | 89 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/src/op/atomic_reduce.cc b/src/op/atomic_reduce.cc index 370509bfe..f722e32e0 100644 --- a/src/op/atomic_reduce.cc +++ b/src/op/atomic_reduce.cc @@ -13,6 +13,7 @@ #include "../layout/layout.h" #include "../target/utils.h" + #include "../transform/common/loop_fusion_utils.h" #include "../transform/loop_partition.h" #include "builtin.h" @@ -255,7 +256,7 @@ Stmt AtomicOpBaseNode::Lower(const LowerArgs &T, auto loop_layout = par_op->GetLoopLayout(); auto lowered_loop = LowerParallelLoop(fused_loop, loop_layout, T.thread_var, analyzer, - T.layout_map, par_op->GetPredicate(T.thread_var)); + par_op->GetPredicate(T.thread_var)); return lowered_loop; } diff --git a/src/target/codegen_hip.cc b/src/target/codegen_hip.cc index ce904307a..4e1a6e58e 100644 --- a/src/target/codegen_hip.cc +++ b/src/target/codegen_hip.cc @@ -994,6 +994,95 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) { os << "tl::warp_reduce_bitand(" << PrintExpr(op->args[0]) << ")"; } else if (op->op.same_as(tl::warp_reduce_bitor())) { os << "tl::warp_reduce_bitor(" << PrintExpr(op->args[0]) << ")"; + } else if (op->op.same_as(tl::atomic_add_elem_op())) { + // atomic_add_elem_op(dst_ptr, src_value[, memory_order]) + std::string dst_ptr = PrintExpr(op->args[0]); + std::string src_value = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << "AtomicAdd(" << dst_ptr << ", " << src_value; + if (op->args.size() > 2) { + this->stream << ", " << PrintExpr(op->args[2]); + } + this->stream << ");\n"; + } else if (op->op.same_as(tl::atomic_add_ret_elem_op())) { + // atomic_add_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns + // prev value + os << "AtomicAddRet(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]); + if (op->args.size() > 2) { + os << ", " << PrintExpr(op->args[2]); + } + os << ")"; + } else if (op->op.same_as(tl::atomic_addx2_elem_op())) { + // atomic_addx2_elem_op(dst_ptr, src_ptr[, memory_order]) + std::string dst_ptr = PrintExpr(op->args[0]); + std::string src_ptr = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << "AtomicAddx2(" << dst_ptr << ", " << src_ptr; + if (op->args.size() > 2) { + this->stream << ", " << PrintExpr(op->args[2]); + } + this->stream << ");\n"; + } else if (op->op.same_as(tl::atomic_addx4_elem_op())) { + // atomic_addx4_elem_op(dst_ptr, src_ptr[, memory_order]) + std::string dst_ptr = PrintExpr(op->args[0]); + std::string src_ptr = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << "AtomicAddx4(" << dst_ptr << ", " << src_ptr; + if (op->args.size() > 2) { + this->stream << ", " << PrintExpr(op->args[2]); + } + this->stream << ");\n"; + } else if (op->op.same_as(tl::atomic_load_elem_op())) { + // atomic_load_elem_op(src_ptr, memory_order) -> returns loaded value + os << "AtomicLoad(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]) << ")"; + } else if (op->op.same_as(tl::atomic_store_elem_op())) { + // atomic_store_elem_op(dst_ptr, value, memory_order) + std::string dst_ptr = PrintExpr(op->args[0]); + std::string value = PrintExpr(op->args[1]); + std::string memory_order = PrintExpr(op->args[2]); + this->PrintIndent(); + this->stream << "AtomicStore(" << dst_ptr << ", " << value << ", " + << memory_order << ");\n"; + } else if (op->op.same_as(tl::atomic_max_elem_op())) { + // atomic_max_elem_op(dst_ptr, src_value[, memory_order]) + 
std::string dst_ptr = PrintExpr(op->args[0]); + std::string src_value = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << "AtomicMax(" << dst_ptr << ", " << src_value; + if (op->args.size() > 2) { + this->stream << ", " << PrintExpr(op->args[2]); + } + this->stream << ");\n"; + } else if (op->op.same_as(tl::atomic_max_ret_elem_op())) { + // atomic_max_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns + // prev value + os << "AtomicMaxRet(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]); + if (op->args.size() > 2) { + os << ", " << PrintExpr(op->args[2]); + } + os << ")"; + } else if (op->op.same_as(tl::atomic_min_elem_op())) { + // atomic_min_elem_op(dst_ptr, src_value[, memory_order]) + std::string dst_ptr = PrintExpr(op->args[0]); + std::string src_value = PrintExpr(op->args[1]); + this->PrintIndent(); + this->stream << "AtomicMin(" << dst_ptr << ", " << src_value; + if (op->args.size() > 2) { + this->stream << ", " << PrintExpr(op->args[2]); + } + this->stream << ");\n"; + } else if (op->op.same_as(tl::atomic_min_ret_elem_op())) { + // atomic_min_ret_elem_op(dst_ptr, src_value[, memory_order]) -> returns + // prev value + os << "AtomicMinRet(" << PrintExpr(op->args[0]) << ", " + << PrintExpr(op->args[1]); + if (op->args.size() > 2) { + os << ", " << PrintExpr(op->args[2]); + } + os << ")"; } else { CodeGenC::VisitExpr_(op, os); } From 70d47130fd702052a1dbab17380de30490bd4033 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 19:36:49 +0800 Subject: [PATCH 07/13] Refactor AtomicOpBaseNode and related classes for improved clarity - Changed the GetElemOpStatic method to a virtual GetElemOp method in AtomicOpBaseNode, enhancing polymorphism. - Updated AtomicAddNode, AtomicMaxNode, and AtomicMinNode to override the new GetElemOp method. - Removed unnecessary template parameters from AtomicOpBaseNode, simplifying the class structure. - Cleaned up includes in atomic_reduce.cc to remove unused dependencies, improving code organization. --- src/op/atomic_add.cc | 2 +- src/op/atomic_add.h | 4 ++-- src/op/atomic_reduce.cc | 47 +++++++++++++++-------------------------- src/op/atomic_reduce.h | 25 ++++++++-------------- 4 files changed, 29 insertions(+), 49 deletions(-) diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index bd8b52d1e..1879c21db 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -74,7 +74,7 @@ TileOperator AtomicAddNode::Clone() const { return AtomicAdd(op); } -const Op &AtomicAddNode::GetElemOpStatic() { return atomic_add_elem_op(); } +const Op &AtomicAddNode::GetElemOp() const { return atomic_add_elem_op(); } /** * @brief Get vectorization length based on dst dtype and target SM version. diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index 524868e1b..9cff9f278 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -18,7 +18,7 @@ using namespace tir; * * Inherits from AtomicOpBaseNode and adds TMA support and vectorization. 
*/ -class AtomicAddNode : public AtomicOpBaseNode { +class AtomicAddNode : public AtomicOpBaseNode { public: TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicAdd", AtomicAddNode, TileOperatorNode); @@ -30,7 +30,7 @@ class AtomicAddNode : public AtomicOpBaseNode { LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; static const Op &Get(); - static const Op &GetElemOpStatic(); + const Op &GetElemOp() const override; TileOperator Clone() const; static void RegisterReflection() { diff --git a/src/op/atomic_reduce.cc b/src/op/atomic_reduce.cc index f722e32e0..ee18b1b2d 100644 --- a/src/op/atomic_reduce.cc +++ b/src/op/atomic_reduce.cc @@ -5,7 +5,6 @@ */ #include "./atomic_reduce.h" -#include "./atomic_add.h" #include "utils.h" #include #include @@ -50,7 +49,7 @@ TileOperator AtomicMaxNode::Clone() const { return AtomicMax(op); } -const Op &AtomicMaxNode::GetElemOpStatic() { return atomic_max_elem_op(); } +const Op &AtomicMaxNode::GetElemOp() const { return atomic_max_elem_op(); } // ============================================================================ // AtomicMin Implementation @@ -79,14 +78,13 @@ TileOperator AtomicMinNode::Clone() const { return AtomicMin(op); } -const Op &AtomicMinNode::GetElemOpStatic() { return atomic_min_elem_op(); } +const Op &AtomicMinNode::GetElemOp() const { return atomic_min_elem_op(); } // ============================================================================ // Common AtomicOpBaseNode Implementation // ============================================================================ -template -Array AtomicOpBaseNode::MakeIterVars() const { +Array AtomicOpBaseNode::MakeIterVars() const { Array loop_vars; size_t idx = 0; for (size_t i = 0; i < src_range.size(); i++) { @@ -100,10 +98,8 @@ Array AtomicOpBaseNode::MakeIterVars() const { return loop_vars; } -template -Array -AtomicOpBaseNode::MakeIndices(const Array &ivs, - int src_dst) const { +Array AtomicOpBaseNode::MakeIndices(const Array &ivs, + int src_dst) const { Array indices; Array ranges = src_dst == 0 ? src_range : dst_range; size_t idx = 0; @@ -121,11 +117,10 @@ AtomicOpBaseNode::MakeIndices(const Array &ivs, return indices; } -template -PrimExpr AtomicOpBaseNode::MakePredicate(arith::Analyzer *analyzer, - const Array &ivs, - Array extents, - int src_dst) const { +PrimExpr AtomicOpBaseNode::MakePredicate(arith::Analyzer *analyzer, + const Array &ivs, + Array extents, + int src_dst) const { Array ranges = src_dst == 0 ? 
src_range : dst_range; Array cond_list; ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; @@ -153,8 +148,7 @@ PrimExpr AtomicOpBaseNode::MakePredicate(arith::Analyzer *analyzer, } } -template -For AtomicOpBaseNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { +For AtomicOpBaseNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { Array loop_vars = MakeIterVars(); bool is_scalar = loop_vars.empty(); if (is_scalar) { @@ -191,9 +185,9 @@ For AtomicOpBaseNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { new_args.push_back(dst_ptr); new_args.push_back(src_value); - new_args.push_back(static_cast(this)->GetMemoryOrder()); + new_args.push_back(GetMemoryOrder()); - // Use the appropriate elem_op based on the derived type (via CRTP) + // Use the appropriate elem_op based on the derived type (via virtual call) Call atomic_call = tvm::tir::Call(dst->dtype, GetElemOp(), new_args, annotations); @@ -214,9 +208,8 @@ For AtomicOpBaseNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { return Downcast(body); } -template -LayoutMap AtomicOpBaseNode::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { +LayoutMap AtomicOpBaseNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { // For atomic reduce operations, check that src and dst have the same layout // if both are fragments if (IsFragmentBuffer(src) && IsFragmentBuffer(dst)) { @@ -233,9 +226,8 @@ LayoutMap AtomicOpBaseNode::InferLayout(const LayoutInferArgs &T, return {}; } -template -Stmt AtomicOpBaseNode::Lower(const LowerArgs &T, - arith::Analyzer *analyzer) const { +Stmt AtomicOpBaseNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { Target target = T.target; auto simt_loop = MakeSIMTLoop(analyzer); @@ -256,15 +248,10 @@ Stmt AtomicOpBaseNode::Lower(const LowerArgs &T, auto loop_layout = par_op->GetLoopLayout(); auto lowered_loop = LowerParallelLoop(fused_loop, loop_layout, T.thread_var, analyzer, - par_op->GetPredicate(T.thread_var)); + T.layout_map, par_op->GetPredicate(T.thread_var)); return lowered_loop; } -// Explicit template instantiations -template class AtomicOpBaseNode; -template class AtomicOpBaseNode; -template class AtomicOpBaseNode; - // ============================================================================ // Operator Registration // ============================================================================ diff --git a/src/op/atomic_reduce.h b/src/op/atomic_reduce.h index 144cc7914..bdfb12ca8 100644 --- a/src/op/atomic_reduce.h +++ b/src/op/atomic_reduce.h @@ -17,13 +17,11 @@ using namespace tir; /*! * \brief Base node class for atomic operations (add/max/min). * - * This template base class provides common functionality for all atomic + * This base class provides common functionality for all atomic * operations including buffer management, loop generation, and layout * inference. 
- * - * \tparam Derived The derived class type (CRTP pattern) */ -template class AtomicOpBaseNode : public TileOperatorNode { +class AtomicOpBaseNode : public TileOperatorNode { public: Buffer src, dst; ///< Source and destination buffers Array src_range, @@ -51,6 +49,9 @@ template class AtomicOpBaseNode : public TileOperatorNode { return 0; } + /// Get the element-wise operation Op (pure virtual, implemented by derived) + virtual const Op &GetElemOp() const = 0; + protected: /// Create SIMT-style parallel loop structure For MakeSIMTLoop(arith::Analyzer *analyzer) const; @@ -64,24 +65,16 @@ template class AtomicOpBaseNode : public TileOperatorNode { /// Create boundary predicate for memory safety PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array &ivs, Array extents, int src_dst) const; - - /// Get the element-wise operation Op (to be implemented by derived class) - /// This uses CRTP to call the derived class's static method - const Op &GetElemOp() const { return Derived::GetElemOpStatic(); } }; -// Backward compatibility alias -template -using AtomicReduceBaseNode = AtomicOpBaseNode; - /// Node class for atomic maximum operations -class AtomicMaxNode : public AtomicOpBaseNode { +class AtomicMaxNode : public AtomicOpBaseNode { public: TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicMax", AtomicMaxNode, TileOperatorNode); static const Op &Get(); - static const Op &GetElemOpStatic(); + const Op &GetElemOp() const override; TileOperator Clone() const; static void RegisterReflection() { @@ -107,13 +100,13 @@ class AtomicMax : public TileOperator { }; /// Node class for atomic minimum operations -class AtomicMinNode : public AtomicOpBaseNode { +class AtomicMinNode : public AtomicOpBaseNode { public: TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicMin", AtomicMinNode, TileOperatorNode); static const Op &Get(); - static const Op &GetElemOpStatic(); + const Op &GetElemOp() const override; TileOperator Clone() const; static void RegisterReflection() { From 3f1aa47c32f0fa5ab3250086fceb9d7c02a1bcee Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 20:28:14 +0800 Subject: [PATCH 08/13] Enhance argument validation for atomic operations - Added checks in AtomicAdd, AtomicMax, and AtomicMin constructors to ensure at least two arguments are provided, improving error handling and user feedback. - Removed the unused LowerTMA method declaration from the AtomicAdd class, streamlining the codebase. --- src/op/atomic_add.cc | 1 + src/op/atomic_add.h | 2 -- src/op/atomic_reduce.cc | 2 ++ src/transform/loop_vectorize.cc | 2 +- 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 1879c21db..7aa0fb6b6 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -42,6 +42,7 @@ using namespace tir; * - The constructed node is stored in this->data_. 
*/ AtomicAdd::AtomicAdd(Array args, Map annotations) { + ICHECK(args.size() >= 2) << "AtomicAdd expects at least 2 arguments (src, dst), got " << args.size(); ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index 9cff9f278..824f6936a 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -66,8 +66,6 @@ class AtomicAddNode : public AtomicOpBaseNode { /// Compute linear layout for shared tensor (used in TMA atomic add) Layout ComputeLinearLayout(const Buffer &shared_tensor) const; - /// Lower TMA-based atomic add - Stmt LowerTMA(const LowerArgs &T, arith::Analyzer *analyzer) const; }; /// Wrapper class for atomic addition operations diff --git a/src/op/atomic_reduce.cc b/src/op/atomic_reduce.cc index ee18b1b2d..df4cc6996 100644 --- a/src/op/atomic_reduce.cc +++ b/src/op/atomic_reduce.cc @@ -27,6 +27,7 @@ using namespace tir; // ============================================================================ AtomicMax::AtomicMax(Array args, Map annotations) { + ICHECK(args.size() >= 2) << "AtomicMax expects at least 2 arguments (src, dst), got " << args.size(); ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; @@ -56,6 +57,7 @@ const Op &AtomicMaxNode::GetElemOp() const { return atomic_max_elem_op(); } // ============================================================================ AtomicMin::AtomicMin(Array args, Map annotations) { + ICHECK(args.size() >= 2) << "AtomicMin expects at least 2 arguments (src, dst), got " << args.size(); ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 4c941181a..bdbc51513 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -212,7 +212,7 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { } else { // Other calls should not be vectorized vector_size_ = 1; - return Downcast(node); + return ffi::GetRef(node); } return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } From 511ddef14fc71ab2baf7864fe186708bd13eb5cc Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 20:28:50 +0800 Subject: [PATCH 09/13] lint fix --- src/op/atomic_add.cc | 4 +++- src/op/atomic_add.h | 1 - src/op/atomic_reduce.cc | 8 ++++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 7aa0fb6b6..02d93cad1 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -42,7 +42,9 @@ using namespace tir; * - The constructed node is stored in this->data_. 
*/ AtomicAdd::AtomicAdd(Array args, Map annotations) { - ICHECK(args.size() >= 2) << "AtomicAdd expects at least 2 arguments (src, dst), got " << args.size(); + ICHECK(args.size() >= 2) + << "AtomicAdd expects at least 2 arguments (src, dst), got " + << args.size(); ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index 824f6936a..1b8752828 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -65,7 +65,6 @@ class AtomicAddNode : public AtomicOpBaseNode { /// Compute linear layout for shared tensor (used in TMA atomic add) Layout ComputeLinearLayout(const Buffer &shared_tensor) const; - }; /// Wrapper class for atomic addition operations diff --git a/src/op/atomic_reduce.cc b/src/op/atomic_reduce.cc index df4cc6996..1bafa6f0a 100644 --- a/src/op/atomic_reduce.cc +++ b/src/op/atomic_reduce.cc @@ -27,7 +27,9 @@ using namespace tir; // ============================================================================ AtomicMax::AtomicMax(Array args, Map annotations) { - ICHECK(args.size() >= 2) << "AtomicMax expects at least 2 arguments (src, dst), got " << args.size(); + ICHECK(args.size() >= 2) + << "AtomicMax expects at least 2 arguments (src, dst), got " + << args.size(); ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; @@ -57,7 +59,9 @@ const Op &AtomicMaxNode::GetElemOp() const { return atomic_max_elem_op(); } // ============================================================================ AtomicMin::AtomicMin(Array args, Map annotations) { - ICHECK(args.size() >= 2) << "AtomicMin expects at least 2 arguments (src, dst), got " << args.size(); + ICHECK(args.size() >= 2) + << "AtomicMin expects at least 2 arguments (src, dst), got " + << args.size(); ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; From 65cfad7153f110a7b909be901763d9bb038b10a5 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 20:44:43 +0800 Subject: [PATCH 10/13] fix --- src/op/copy.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/op/copy.cc b/src/op/copy.cc index fd85bbb77..711d87afc 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -724,7 +724,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, << dst.scope() << " buffer `" << dst->name << "` may cause conflicted write."; } - vectorized_thread_loop = VectorizeLoop(transformed_loop, T.layout_map); + vectorized_thread_loop = VectorizeLoop(fused_loop, T.layout_map); return vectorized_thread_loop; } else { std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, From 351003a648d047b578d4f88f10c0f5651ab656ee Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 21:22:24 +0800 Subject: [PATCH 11/13] test fix --- testing/python/language/test_tilelang_language_atomic.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/testing/python/language/test_tilelang_language_atomic.py b/testing/python/language/test_tilelang_language_atomic.py index fc8d5278e..a96d0ac6f 100644 --- a/testing/python/language/test_tilelang_language_atomic.py +++ b/testing/python/language/test_tilelang_language_atomic.py @@ -4,12 +4,7 @@ import torch -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } -) +@tilelang.jit def atomic_add_program(K, M, N, block_M, block_N, dtype=T.float32): @T.prim_func def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)): From 
116919991e91b00faeaa3faa39cad5076c0c7003 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Thu, 15 Jan 2026 23:27:39 +0800 Subject: [PATCH 12/13] performance fix --- src/transform/loop_vectorize.cc | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index bdbc51513..e87b912ab 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -209,6 +209,14 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { // address_of have buffer load value so we should analysis the buffer load // node to update vector_size_. return arith::IRMutatorWithAnalyzer::VisitExpr_(node); + } else if (node->op.same_as(tir::builtin::bitwise_and()) || + node->op.same_as(tir::builtin::bitwise_or()) || + node->op.same_as(tir::builtin::bitwise_xor()) || + node->op.same_as(tir::builtin::bitwise_not()) || + node->op.same_as(tir::builtin::shift_left()) || + node->op.same_as(tir::builtin::shift_right())) { + // Bitwise operations can be vectorized + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } else { // Other calls should not be vectorized vector_size_ = 1; @@ -239,7 +247,6 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { // forward indices auto layout = layout_map_[buffer]; transformed_indices = layout->Forward(indices); - // Reshape transformed_indices to match buffer->shape dimensions if needed if (transformed_indices.size() != buffer->shape.size()) { // Step 1: Compute linear offset using layout->OutputShape() From ec37e0dd2d2eabeb5f82a588f47c2010f5520520 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 16 Jan 2026 00:01:36 +0800 Subject: [PATCH 13/13] remove useless comments --- src/transform/loop_vectorize.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index e87b912ab..8b29282c6 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -202,8 +202,6 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { } vector_size_ = arith::ZeroAwareGCD(vector_size_, vectorize_length); - // Do not visit the args of atomic_add_elem_op, because pointer type - // is impossible to vectorize return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } else if (node->op == builtin::address_of()) { // address_of have buffer load value so we should analysis the buffer load
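
The loop_vectorize.cc change in patch 12 above whitelists the element-wise bitwise and shift builtins so that meeting one of them no longer collapses vector_size_ to 1; only genuinely unknown calls still force scalar code. The standalone sketch below illustrates just that whitelist decision; KeepsVectorLength and the string op names are invented stand-ins, not the actual VectorizePlanner interface.

#include <iostream>
#include <set>
#include <string>

// Hypothetical stand-in for the planner's decision on a CallNode's op:
// bitwise/shift builtins keep the current vector length, anything else
// falls back to scalar code.
bool KeepsVectorLength(const std::string &op_name) {
  static const std::set<std::string> kElementwiseBitOps = {
      "bitwise_and", "bitwise_or",  "bitwise_xor",
      "bitwise_not", "shift_left",  "shift_right"};
  return kElementwiseBitOps.count(op_name) > 0;
}

int main() {
  int vector_size = 4;  // length chosen so far by the planner
  for (const char *op : {"bitwise_and", "call_extern"}) {
    if (!KeepsVectorLength(op)) {
      // Unknown calls still force scalar code, as before the patch.
      vector_size = 1;
    }
    std::cout << op << " -> vector_size = " << vector_size << "\n";
  }
  return 0;
}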
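
Patch 07's move from the CRTP GetElemOpStatic hook to a virtual GetElemOp lets the shared base implementation pick each operation's element-wise intrinsic without templating the base class or instantiating it explicitly. A minimal compilable sketch of that pattern, with std::string standing in for TVM's Op handle and simplified class names that are not the real node types:

#include <iostream>
#include <memory>
#include <string>

// Simplified stand-ins for the untemplated base and two derived atomic ops.
class AtomicOpBase {
 public:
  virtual ~AtomicOpBase() = default;
  // Pure virtual hook: each concrete op names its element-wise intrinsic.
  virtual std::string GetElemOp() const = 0;
  // Shared lowering logic calls the hook without knowing the derived type.
  std::string MakeCall(const std::string &dst, const std::string &src) const {
    return GetElemOp() + "(" + dst + ", " + src + ")";
  }
};

class AtomicAddOp : public AtomicOpBase {
 public:
  std::string GetElemOp() const override { return "atomic_add_elem_op"; }
};

class AtomicMaxOp : public AtomicOpBase {
 public:
  std::string GetElemOp() const override { return "atomic_max_elem_op"; }
};

int main() {
  std::unique_ptr<AtomicOpBase> op = std::make_unique<AtomicMaxOp>();
  std::cout << op->MakeCall("dst_ptr", "src_value") << "\n";
  return 0;
}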
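
The codegen_hip.cc branches added in patch 06 split the new intrinsics into statement-style ops (AtomicAdd, AtomicAddx2/x4, AtomicStore, AtomicMax, AtomicMin), printed via PrintIndent() into this->stream with a trailing semicolon, and value-returning ops (AtomicAddRet, AtomicLoad, AtomicMaxRet, AtomicMinRet), printed inline into the expression stream. The sketch below only mimics that emission split; FormatCall and the hard-coded argument strings are hypothetical and not part of CodeGenTileLangHIP.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper: joins "Name(arg0, arg1[, arg2])".
std::string FormatCall(const std::string &name,
                       const std::vector<std::string> &args) {
  std::ostringstream os;
  os << name << "(";
  for (size_t i = 0; i < args.size(); ++i) {
    if (i) os << ", ";
    os << args[i];
  }
  os << ")";
  return os.str();
}

int main() {
  std::ostringstream stream;  // plays the role of this->stream
  // Statement-style intrinsic: emitted on its own indented line with ';'.
  stream << "  " << FormatCall("AtomicAdd", {"dst_ptr", "src_value"}) << ";\n";
  // Value-returning intrinsic: emitted inline as part of an expression,
  // optionally carrying the memory_order argument.
  std::string expr =
      FormatCall("AtomicAddRet", {"dst_ptr", "src_value", "memory_order"});
  stream << "  prev = " << expr << ";\n";
  std::cout << stream.str();
  return 0;
}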