diff --git a/3rdparty/tvm b/3rdparty/tvm index b487ec426..23bce012f 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit b487ec4267d4890fddd7a5417e75b1a1fa553c06 +Subproject commit 23bce012ffd255a24289eea6ceab74a40b94a096 diff --git a/src/layout/layout.h b/src/layout/layout.h index 2e7c8a40c..073718342 100644 --- a/src/layout/layout.h +++ b/src/layout/layout.h @@ -258,6 +258,12 @@ Layout makeQuarterBankSwizzleLayout(int stride, int continuous, namespace attr { // BlockAttr, Containing the layout for all the buffers in the block constexpr const char *kLayoutMap = "layout_map"; +// ForAttr, Containing the parallel loop layout for a parallel for loop +constexpr const char *kParallelLoopLayout = "parallel_loop_layout"; +// ForAttr, Containing the predicate for a parallel for loop +constexpr const char *kParallelLoopPredicate = "parallel_loop_predicate"; +// ForAttr, Width (in elements) for coalesced memory access +constexpr const char *kCoalescedWidth = "coalesced_width"; } // namespace attr } // namespace tl diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 5925d38fd..408895966 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -10,6 +10,7 @@ #include #include +#include "../layout/layout.h" #include "../target/utils.h" #include "../transform/atomicadd_vectorize.h" #include "../transform/common/loop_fusion_utils.h" @@ -23,23 +24,24 @@ namespace tl { using namespace tir; /** - * @brief Construct an AtomicAdd operator from call arguments and a buffer map. + * @brief Construct an AtomicAdd operator from call arguments and annotations. * * Builds the internal AtomicAddNode, extracts the source and destination * regions and their backing Buffers from the first two region-style expressions * in `args` (BufferLoad/BufferRegion), and stores them along with their - * ranges. If a third argument is provided, it is interpreted as an integer - * immediate and stored as the node's coalesced width. + * ranges. Annotations are copied directly from the Call node. * * @param args Call-style PrimExprs where: * - args[0] is the source region call, - * - args[1] is the destination region call, - * - args[2] (optional) is an IntImm specifying coalesced width. + * - args[1] is the destination region call. + * @param annotations Map containing optional keys: + * - "use_tma": whether to use TMA for memory operations + * - "memory_order": memory order for atomic operations * Notes: * - The constructor checks that args[0] and args[1] are region-compatible. * - The constructed node is stored in this->data_. 
*/ -AtomicAdd::AtomicAdd(Array args) { +AtomicAdd::AtomicAdd(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; @@ -50,16 +52,8 @@ AtomicAdd::AtomicAdd(Array args) { } std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); - if (args.size() >= 3) { - node->use_tma = Downcast(args[2]); - } - node->memory_order = IntImm(0); - if (args.size() >= 4) { - node->memory_order = Downcast(args[3]); - } - if (args.size() >= 5) { - node->coalesced_width = Downcast(args[4]); - } + // Copy annotations from the Call node + node->annotations = annotations; data_ = std::move(node); } @@ -284,7 +278,7 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { new_args.push_back(dst_ptr); new_args.push_back(src_value); - new_args.push_back(memory_order); + new_args.push_back(GetMemoryOrder()); Call atomicadd_call = tvm::tir::Call(dst->dtype, atomicadd_elem_op(), new_args); @@ -292,13 +286,14 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { Stmt body = tvm::tir::Evaluate(atomicadd_call); for (int i = loop_vars.size() - 1; i >= 0; i--) { - Map annotations = {}; - if (coalesced_width.defined()) { - annotations.Set("coalesced_width", coalesced_width); + Map loop_annotations; + if (annotations.count(attr::kCoalescedWidth)) { + loop_annotations.Set(attr::kCoalescedWidth, + annotations.Get(attr::kCoalescedWidth).value()); } body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, - ForKind::kParallel, body, std::nullopt, annotations); + ForKind::kParallel, body, std::nullopt, loop_annotations); } return Downcast(body); } @@ -377,7 +372,7 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T, */ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Target target = T.target; - if (use_tma->value != 0) { + if (GetUseTMA()) { Array src_indices, dst_indices; PrimExpr src_size, dst_size; std::tie(src_indices, src_size) = ReturnIndicesAndSize(0); @@ -487,7 +482,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { int sm = GetArchInt(target); auto plan = planner.Plan(loop, sm); int vec = std::max(plan.vector_size, 1); - if (auto cw = loop->annotations.Get("coalesced_width")) { + if (auto cw = loop->annotations.Get(attr::kCoalescedWidth)) { if (const auto *imm = cw->as()) { int expected = imm->value; ICHECK_GT(expected, 0); diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index c6beb70eb..56f48839f 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -19,10 +19,12 @@ class AtomicAddNode : public TileOperatorNode { public: Buffer src, dst; ///< Source and destination buffers Array src_range, - dst_range; ///< Access ranges for source and destination - IntImm use_tma; ///< Whether to use TMA for memory operations - IntImm coalesced_width; ///< Width for memory coalescing optimization - IntImm memory_order; ///< Memory order for atomic operations + dst_range; ///< Access ranges for source and destination + Map annotations; ///< Annotations for the atomic operation + // Supported annotation keys: + // - "use_tma": IntImm, whether to use TMA for memory operations + // - "coalesced_width": IntImm, width for memory coalescing optimization + // - "memory_order": IntImm, memory order for atomic operations mutable ParallelOp par_op_; ///< Associated parallel operation TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.AtomicAdd", AtomicAddNode, @@ -41,9 +43,26 @@ class AtomicAddNode : public 
TileOperatorNode { .def_ro("dst", &AtomicAddNode::dst) .def_ro("src_range", &AtomicAddNode::src_range) .def_ro("dst_range", &AtomicAddNode::dst_range) - .def_ro("use_tma", &AtomicAddNode::use_tma) - .def_ro("coalesced_width", &AtomicAddNode::coalesced_width) - .def_ro("memory_order", &AtomicAddNode::memory_order); + .def_ro("annotations", &AtomicAddNode::annotations); + } + + // Helper methods to get annotation values + bool GetUseTMA() const { + if (auto val = annotations.Get("use_tma")) { + if (auto int_val = val->as()) { + return int_val->value != 0; + } + } + return false; + } + + int GetMemoryOrder() const { + if (auto val = annotations.Get("memory_order")) { + if (auto int_val = val->as()) { + return int_val->value; + } + } + return 0; // default: relaxed } protected: @@ -65,7 +84,9 @@ class AtomicAdd : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(AtomicAdd, TileOperator, AtomicAddNode); - TVM_DLL AtomicAdd(Array args); + TVM_DLL + AtomicAdd(Array args, + Map annotations = Map()); static const Op &Get(); }; diff --git a/src/op/copy.cc b/src/op/copy.cc index bda21e9eb..23770fc95 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -99,11 +99,11 @@ template static Array ReverseArray(Array array) { return Array{array.rbegin(), array.rend()}; } -// Constructs a Copy operator node from call arguments. +// Constructs a Copy operator node from call arguments and annotations. // args[0]: source region, args[1]: destination region -// Optional: args[2] coalesced_width, args[3] disable_tma, args[4] -// eviction_policy -Copy::Copy(Array args) { +// annotations: Map containing coalesced_width, disable_tma, eviction_policy, +// etc. +Copy::Copy(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); Array rgs[2]; Buffer bf[2]; @@ -114,18 +114,8 @@ Copy::Copy(Array args) { } std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); - if (args.size() >= 3) { - auto coalesced_width = Downcast(args[2]); - if (coalesced_width->value > 0) { - node->coalesced_width = coalesced_width; - } - } - if (args.size() >= 4) { - node->disable_tma = Downcast(args[3]); - } - if (args.size() >= 5) { - node->eviction_policy = args[4].as()->value; - } + // Copy annotations from the Call node + node->annotations = annotations; data_ = std::move(node); } @@ -323,12 +313,13 @@ For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { return For(Var("i"), 0, 1, ForKind::kSerial, body); } for (int i = loop_vars.size() - 1; i >= 0; i--) { - Map annotations = {}; - if (coalesced_width.defined()) { - annotations.Set("coalesced_width", coalesced_width); + Map loop_annotations; + if (annotations.count(attr::kCoalescedWidth)) { + loop_annotations.Set(attr::kCoalescedWidth, + annotations.Get(attr::kCoalescedWidth).value()); } body = For(loop_vars[i]->var, 0, loop_vars[i]->dom->extent, - ForKind::kParallel, body, std::nullopt, annotations); + ForKind::kParallel, body, std::nullopt, loop_annotations); } return Downcast(body); } @@ -361,7 +352,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, PassContext pass_ctx = PassContext::Current(); bool disable_tma_lower = pass_ctx->GetConfig(kDisableTMALower, Bool(false)).value(); - auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, + auto copy_inst = GetCopyInst(target, disable_tma_lower || GetDisableTMA(), T.layout_map, T.analyzer, T.buffer_oob); // Handle tensor memory (tmem) layout inference @@ -736,7 +727,7 @@ Stmt CopyNode::Lower(const 
LowerArgs &T, arith::Analyzer *analyzer) const { PassContext pass_ctx = PassContext::Current(); bool disable_tma_lower = pass_ctx->GetConfig(kDisableTMALower, Bool(false)).value(); - auto copy_inst = GetCopyInst(target, disable_tma_lower || disable_tma, + auto copy_inst = GetCopyInst(target, disable_tma_lower || GetDisableTMA(), T.layout_map, analyzer); if (copy_inst == CopyInst::kTMemLoad || copy_inst == CopyInst::kTMemStore) { auto tmem_copy = LowerTmemCopy(T, analyzer); @@ -783,6 +774,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, << "` may cause conflicted write."; } vectorized_thread_loop = VectorizeLoop(transformed_loop); + return vectorized_thread_loop; } else { std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, InferLevel::kFree}; @@ -797,17 +789,11 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, level); } auto loop_layout = par_op->GetLoopLayout(); - auto thread_var = T.thread_var; - auto thread_loop = - PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout); - vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); - } - - if (par_op->GetPredicate(T.thread_var).defined()) { - return IfThenElse(par_op->GetPredicate(T.thread_var).value(), - vectorized_thread_loop); + // Use LowerParallelLoop to handle partitioning, vectorization, and + // predicate + return LowerParallelLoop(par_op->GetRoot(), loop_layout, T.thread_var, + analyzer, par_op->GetPredicate(T.thread_var)); } - return vectorized_thread_loop; } // Lowers copy to LDSM/STSM (warp-level 8x8 matrix) instructions. @@ -1452,7 +1438,7 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, int need_reduce = 0; if (!is_load) args.push_back(need_reduce); - args.push_back(this->eviction_policy); + args.push_back(GetEvictionPolicy()); tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled, Evaluate(Call(DataType::Handle(), op, args))); } else { @@ -1464,7 +1450,7 @@ Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, int need_reduce = 0; if (!is_load) args.push_back(need_reduce); - args.push_back(this->eviction_policy); + args.push_back(GetEvictionPolicy()); tma_copy = Evaluate(Call(DataType::Handle(), op, args)); } tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); @@ -1536,13 +1522,13 @@ Stmt CopyNode::LowerBulkCopy1D(const LowerArgs &T, arith::Analyzer *analyzer, tma_copy = Evaluate( Call(DataType::Handle(), tma_load(), {shared_addr, global_addr, 0, - elements * shared_tensor->dtype.bytes(), this->eviction_policy})); + elements * shared_tensor->dtype.bytes(), GetEvictionPolicy()})); } else { int need_reduce = 0; tma_copy = Evaluate( Call(DataType::Handle(), tma_store(), {global_addr, shared_addr, elements * shared_tensor->dtype.bytes(), - need_reduce, this->eviction_policy})); + need_reduce, GetEvictionPolicy()})); } tma_copy = IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_copy); return tma_copy; @@ -1575,7 +1561,8 @@ Array TMADesc::EncodeCallArgs() const { // Constructs a Conv2DIm2ColOp node from call arguments. 
// args: src, dst, nhw_step, c_step, kernel, stride, dilation, padding, // eviction_policy -Conv2DIm2ColOp::Conv2DIm2ColOp(Array args) { +Conv2DIm2ColOp::Conv2DIm2ColOp(Array args, + Map annotations) { ObjectPtr node = tvm::ffi::make_object(); node->srcRegion_ = NormalizeToBufferRegion(args[0]); diff --git a/src/op/copy.h b/src/op/copy.h index be6331c54..6009c7ce0 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -115,18 +115,15 @@ class CopyNode : public TileOperatorNode { public: Buffer src, dst; // Source and destination buffers Array src_range, dst_range; // Ranges for each dimension in src and dst - IntImm coalesced_width; // Width (in elements) for coalesced memory access - Bool disable_tma = Bool(false); // Whether to disable TMA acceleration + Map annotations; // Annotations for the copy operation + // Supported annotation keys: + // - "coalesced_width": IntImm, width for coalesced memory access + // - "disable_tma": Bool, whether to disable TMA acceleration + // - "eviction_policy": IntImm, cache eviction policy (0=normal, 1=first, + // 2=last) mutable ParallelOp par_op_; // Optional associated parallelization operator - enum class EvictionPolicy : uint8_t { - kEvictNormal = 0, - kEvictFirst = 1, - kEvictLast = 2, - }; - - uint8_t eviction_policy; // Policy for cache eviction TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Copy", CopyNode, TileOperatorNode); static void RegisterReflection() { @@ -136,7 +133,26 @@ class CopyNode : public TileOperatorNode { .def_ro("dst", &CopyNode::dst) .def_ro("src_range", &CopyNode::src_range) .def_ro("dst_range", &CopyNode::dst_range) - .def_ro("coalesced_width", &CopyNode::coalesced_width); + .def_ro("annotations", &CopyNode::annotations); + } + + // Helper methods to get annotation values + bool GetDisableTMA() const { + if (auto val = annotations.Get("disable_tma")) { + if (auto int_val = val->as()) { + return int_val->value != 0; + } + } + return false; + } + + int GetEvictionPolicy() const { + if (auto val = annotations.Get("eviction_policy")) { + if (auto int_val = val->as()) { + return int_val->value; + } + } + return 0; // default: evict_normal } /*! @@ -326,9 +342,10 @@ class Copy : public TileOperator { /*! * \brief Constructor. * \param args Expression arguments for the copy. - * \param vmap Buffer variable mapping. + * \param annotations Annotations map from the Call node. */ - TVM_DLL Copy(Array args); + TVM_DLL Copy(Array args, + Map annotations = Map()); /*! * \brief Get the TVM Op handle corresponding to this Copy op. @@ -394,7 +411,9 @@ class Conv2DIm2ColOp : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Conv2DIm2ColOp, TileOperator, Conv2DIm2ColOpNode); - TVM_DLL Conv2DIm2ColOp(Array args); + TVM_DLL + Conv2DIm2ColOp(Array args, + Map annotations = Map()); static const Op &Get(); }; diff --git a/src/op/fill.cc b/src/op/fill.cc index 17de15445..02962d242 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -58,7 +58,7 @@ using namespace tir; * lanes) and will terminate (via CHECK/ICHECK) if inputs are unsupported or out * of bounds. 
*/ -Fill::Fill(Array args) { +Fill::Fill(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); BufferRegion region = NormalizeToBufferRegion(args[0]); diff --git a/src/op/fill.h b/src/op/fill.h index c10a5cfb1..b5734ad56 100644 --- a/src/op/fill.h +++ b/src/op/fill.h @@ -45,7 +45,8 @@ class FillNode : public TileOperatorNode { class Fill : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fill, TileOperator, FillNode); - TVM_DLL Fill(Array args); + TVM_DLL Fill(Array args, + Map annotations = Map()); static const Op &Get(); }; diff --git a/src/op/finalize_reducer.cc b/src/op/finalize_reducer.cc index 97cdbe81c..e9e2fca54 100644 --- a/src/op/finalize_reducer.cc +++ b/src/op/finalize_reducer.cc @@ -31,7 +31,8 @@ using namespace tir; * `args[0]` is an access pointer identifying the reducer variable * and `args[1]` is an integer encoding a `ReducerOpType` (e.g., Sum/Max/Min). */ -FinalizeReducerOp::FinalizeReducerOp(Array args) { +FinalizeReducerOp::FinalizeReducerOp(Array args, + Map annotations) { auto node = tvm::ffi::make_object(); // Normalize any supported region expression // (BufferRegion/BufferLoad/tl.region) to a BufferRegion, then take the diff --git a/src/op/finalize_reducer.h b/src/op/finalize_reducer.h index 99e1e7cbf..a3903ed14 100644 --- a/src/op/finalize_reducer.h +++ b/src/op/finalize_reducer.h @@ -48,7 +48,9 @@ class FinalizeReducerOp : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(FinalizeReducerOp, TileOperator, FinalizeReducerOpNode); - TVM_DLL FinalizeReducerOp(Array args); + TVM_DLL FinalizeReducerOp( + Array args, + Map annotations = Map()); static const Op &Get(); }; diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 5a8fa3070..7ad8b8c1e 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -50,7 +50,7 @@ using namespace tir; // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} -Gemm::Gemm(Array args) { +Gemm::Gemm(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); node->aRegion_ = NormalizeToBufferRegion(args[0]); diff --git a/src/op/gemm.h b/src/op/gemm.h index fb3d5c0f6..fd2733882 100644 --- a/src/op/gemm.h +++ b/src/op/gemm.h @@ -177,7 +177,8 @@ class GemmNode : public TileOperatorNode { class Gemm : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Gemm, TileOperator, GemmNode); - TVM_DLL Gemm(Array args); + TVM_DLL Gemm(Array args, + Map annotations = Map()); static const Op &Get(); }; diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index 8f6201c64..c68861814 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -50,7 +50,7 @@ using namespace tir; * fails with an ICHECK (runtime assertion). No other validation is * performed here. 
*/ -GemmPy::GemmPy(Array args) { +GemmPy::GemmPy(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); node->aRegion_ = NormalizeToBufferRegion(args[0]); diff --git a/src/op/gemm_py.h b/src/op/gemm_py.h index 2fe47be88..d6468a0bf 100644 --- a/src/op/gemm_py.h +++ b/src/op/gemm_py.h @@ -83,7 +83,8 @@ class GemmPyNode : public TileOperatorNode { class GemmPy : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmPy, TileOperator, GemmPyNode); - TVM_DLL GemmPy(Array args); + TVM_DLL GemmPy(Array args, + Map annotations = Map()); static const Op &Get(); }; diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc index 828953460..acff1ff7b 100644 --- a/src/op/gemm_sp.cc +++ b/src/op/gemm_sp.cc @@ -83,7 +83,7 @@ std::pair GemmSPWarpPolicyNode::computeWarpPartition(int M, int N, * * @note An ICHECK failure is raised if a provided kPack is not 1 or 2. */ -GemmSP::GemmSP(Array args) { +GemmSP::GemmSP(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); node->aRegion_ = NormalizeToBufferRegion(args[0]); node->eRegion_ = NormalizeToBufferRegion(args[1]); diff --git a/src/op/gemm_sp.h b/src/op/gemm_sp.h index a634e922f..a00773801 100644 --- a/src/op/gemm_sp.h +++ b/src/op/gemm_sp.h @@ -109,7 +109,8 @@ class GemmSPNode : public TileOperatorNode { class GemmSP : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSP, TileOperator, GemmSPNode); - TVM_DLL GemmSP(Array args); + TVM_DLL GemmSP(Array args, + Map annotations = Map()); static const Op &Get(); }; diff --git a/src/op/gemm_sp_py.cc b/src/op/gemm_sp_py.cc index 6ad8ca9b5..f66c8506a 100644 --- a/src/op/gemm_sp_py.cc +++ b/src/op/gemm_sp_py.cc @@ -49,7 +49,7 @@ using namespace tir; * fails with an ICHECK (runtime assertion). No other validation is * performed here. */ -GemmSPPy::GemmSPPy(Array args) { +GemmSPPy::GemmSPPy(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); node->aRegion_ = NormalizeToBufferRegion(args[0]); diff --git a/src/op/gemm_sp_py.h b/src/op/gemm_sp_py.h index b23b9fc5c..59c276f16 100644 --- a/src/op/gemm_sp_py.h +++ b/src/op/gemm_sp_py.h @@ -84,7 +84,9 @@ class GemmSPPy : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(GemmSPPy, TileOperator, GemmSPPyNode); - TVM_DLL GemmSPPy(Array args); + TVM_DLL + GemmSPPy(Array args, + Map annotations = Map()); static const Op &Get(); }; diff --git a/src/op/operator.cc b/src/op/operator.cc index 302ee3e37..0a8f6b8b8 100644 --- a/src/op/operator.cc +++ b/src/op/operator.cc @@ -31,7 +31,7 @@ TileOperator ParseOperator(Call call) { auto op_map = Op::GetAttrMap("TLOpBuilder"); Op op = call->op.as().value(); if (op_map.count(op)) { - auto tile_op = op_map[op](call->args); + auto tile_op = op_map[op](call->args, call->annotations); ICHECK(tile_op.defined()); return tile_op; } diff --git a/src/op/operator.h b/src/op/operator.h index 2a508e31a..ddbe1fa6b 100644 --- a/src/op/operator.h +++ b/src/op/operator.h @@ -95,7 +95,8 @@ Var GetVarFromAccessPtr(const PrimExpr &expr); TileOperator ParseOperator(Call call); TileOperator ParseOperator(Stmt stmt); -using OpBuilderFunc = ffi::TypedFunction)>; +using OpBuilderFunc = + ffi::TypedFunction, Map)>; #define TIR_REGISTER_TL_TILE_OP(Entry, OpName) \ const Op &Entry::Get() { \ @@ -105,7 +106,10 @@ using OpBuilderFunc = ffi::TypedFunction)>; TVM_REGISTER_OP("tl.tileop." 
#OpName) \ .set_attr("TScriptPrinterName", #OpName) \ .set_attr( \ - "TLOpBuilder", [](Array args) { return Entry(args); }) + "TLOpBuilder", \ + [](Array args, Map annotations) { \ + return Entry(args, annotations); \ + }) } // namespace tl } // namespace tvm diff --git a/src/op/parallel.cc b/src/op/parallel.cc index fb636eff7..6290c3361 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -19,11 +19,6 @@ namespace tl { using namespace tir; -namespace attr { -/*! \brief Mark that how the loop is vectorized. */ -constexpr const char *coalesced_width = "coalesced_width"; -} // namespace attr - // ProveFragmentContains checks whether the threads that access elements of a // smaller fragment (small_frag) are a subset of the threads that access // elements of a larger fragment (large_frag) for any given loop index. This @@ -517,7 +512,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // Check if coalesced_width is defined if (auto coalesced_width = - root_->annotations.Get(tl::attr::coalesced_width)) { + root_->annotations.Get(attr::kCoalescedWidth)) { if (const auto *imm = coalesced_width->as()) { int expected = imm->value; // Verify that vector_size is divisible by expected diff --git a/src/op/reduce.cc b/src/op/reduce.cc index ecd8860bb..896a28c04 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -28,7 +28,7 @@ using namespace tir; // MakeAccessPtrFromRegion moved to src/op/utils.{h,cc} -ReduceOp::ReduceOp(Array args) { +ReduceOp::ReduceOp(Array args, Map annotations) { ObjectPtr node = tvm::ffi::make_object(); // Accept BufferRegion/BufferLoad for src/dst node->srcRegion_ = NormalizeToBufferRegion(args[0]); @@ -495,7 +495,7 @@ static BufferRegion ConvertBufferToBufferRegion(const Buffer &buf) { return BufferRegion(buf, ranges); } -CumSumOp::CumSumOp(Array args) { +CumSumOp::CumSumOp(Array args, Map annotations) { /// CumSum constructor arguments: /// - src: input buffer /// - dst: output buffer diff --git a/src/op/reduce.h b/src/op/reduce.h index cab3835e1..9d3fd8c4e 100644 --- a/src/op/reduce.h +++ b/src/op/reduce.h @@ -125,7 +125,9 @@ class ReduceOp : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReduceOp, TileOperator, ReduceOpNode); - TVM_DLL ReduceOp(Array args); + TVM_DLL + ReduceOp(Array args, + Map annotations = Map()); static const Op &Get(); }; @@ -163,7 +165,9 @@ class CumSumOp : public TileOperator { public: TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(CumSumOp, TileOperator, CumSumOpNode); - TVM_DLL CumSumOp(Array args); + TVM_DLL + CumSumOp(Array args, + Map annotations = Map()); static const Op &Get(); }; diff --git a/src/op/region.cc b/src/op/region.cc index 25e78eba8..4776edd55 100644 --- a/src/op/region.cc +++ b/src/op/region.cc @@ -18,7 +18,7 @@ namespace tvm { namespace tl { using namespace tir; -RegionOp::RegionOp(Array args) { +RegionOp::RegionOp(Array args, Map annotations) { size_t n = args.size(); size_t ndim = n - 2; auto load = args[0].as(); diff --git a/src/op/region.h b/src/op/region.h index 24399f7ab..5f013eca6 100644 --- a/src/op/region.h +++ b/src/op/region.h @@ -81,7 +81,9 @@ class RegionOp : public TileOperator { * - args[1]: Integer access mask (1=r, 2=w, 3=rw). * - args[2 + i]: Extent of axis i (supports dynamic PrimExpr). 
*/ - TVM_DLL RegionOp(Array args); + TVM_DLL + RegionOp(Array args, + Map annotations = Map()); static const Op &Get(); }; diff --git a/src/transform/inject_pipeline.cc b/src/transform/inject_pipeline.cc index 9df8a5a1a..e106dec61 100644 --- a/src/transform/inject_pipeline.cc +++ b/src/transform/inject_pipeline.cc @@ -219,7 +219,7 @@ class PipelineBodyRewriter : public StmtExprMutator { new_args.Set(i + 1, new_index); } } - return Call(call->dtype, call->op, new_args, call->span); + return Call(call->dtype, call->op, new_args, call->annotations, call->span); } Stmt VisitStmt_(const BlockNode *op) final { diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 0fb9bb931..b9ec1e952 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -29,9 +29,7 @@ #include "common/loop_parallel_transform_utils.h" #include "common/union_find.h" #include "layout_reducer.h" -#include "loop_partition.h" -#include "loop_vectorize.h" -#include "runtime/thread_storage_scope.h" +#include "parallel_loop_layout_validator.h" #include "tir/transforms/ir_utils.h" namespace tvm { @@ -1166,16 +1164,15 @@ class LayoutInferencer : public IRMutatorWithAnalyzer { BufferUseDefCollector collector(skip_thread_partition); collector.Collect(f); auto result = collector.Run(); - LayoutInferencer substituter(result, skip_thread_partition, &analyzer); + LayoutInferencer substituter(result, &analyzer); fptr->body = substituter.VisitStmt(f->body); return f; } private: LayoutInferencer(const LayoutInferenceResult &result, - bool skip_thread_partition, arith::Analyzer *analyzer) - : arith::IRMutatorWithAnalyzer(analyzer), result_(result), - skip_thread_partition_(skip_thread_partition) {}; + arith::Analyzer *analyzer) + : arith::IRMutatorWithAnalyzer(analyzer), result_(result) {}; using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer; @@ -1208,170 +1205,55 @@ class LayoutInferencer : public IRMutatorWithAnalyzer { } /** - * @brief Visit and transform For nodes according to inferred layout - * information. + * @brief Visit and transform For nodes by storing inferred layout information + * as annotations instead of expanding the loop. * - * If the For node is present in result_.for_map, this method applies - * loop-level layout-driven transformations: it optionally partitions the loop - * across the thread index, vectorizes the loop body, and wraps the loop with - * a predicate if one was inferred for the loop root. + * If the For node is present in result_.for_map, this method stores the + * inferred loop layout and predicate as annotations on the For node, rather + * than performing loop partition and vectorization. * - * Detailed behavior: - * - Reads reducer information from the For node's attr::kReducerInfo - * annotation (if present) to detect reduction targets. - * - Detects register-local buffer stores (buffers with scope "local") in the - * original loop body; if only register-local stores are present the loop is - * treated as a register-local scenario and is not partitioned across - * threads. - * - Obtains the loop layout from result_.for_map[root] and, unless the loop - * is register-local or skip_thread_partition_ is set, partitions the loop via - * PartitionLoop using thread_var_ and analyzer_. - * - Scans the transformed loop body to determine whether it accesses any - * non-local buffers (scopes other than "local" or "local.fragment"). - * - Scans the transformed loop body to detect reducers (based on - * reducer_info). 
If a reducer is present the loop is NOT vectorized - * (reduction axes are excluded from vectorization as a conservative - * workaround). - * - If the loop has non-local accesses and no reducer, the loop is vectorized - * via VectorizeLoop. - * - If a predicate exists in result_.predicate_map for the loop root and the - * loop was partitioned, the method returns an IfThenElse surrounding the - * (possibly partitioned/vectorized) loop with that predicate; otherwise it - * returns the transformed For. + * The stored annotations are: + * - attr::kParallelLoopLayout: The Fragment layout for the parallel loop + * - attr::kParallelLoopPredicate: The predicate expression (if any) * - * @return The possibly transformed For statement (or an IfThenElse wrapping - * it) + * @return The For statement with layout annotations attached */ Stmt VisitStmt_(const ForNode *op) final { - Map reducer_info; - if (op->annotations.count(attr::kReducerInfo)) - reducer_info = op->annotations.Get(attr::kReducerInfo) - ->as>() - .value(); if (!result_.for_map.count(tvm::ffi::GetRef(op))) { return IRMutatorWithAnalyzer::VisitStmt_(op); } - // the analyzer will be modified in PartitionLoop and VectorizeLoop - // we need to save its state to prevent conflicted bindings - auto saved_analyzer = analyzer_->Clone(); + For for_node = Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); auto root = tvm::ffi::GetRef(op); - // This check is a workaround to support T.Parallel for local buffers. - // For example: - // for i in T.Parallel(1024): - // A_local[i] = A_global[i] - // Here, A_local is a register-local buffer held independently by each - // thread, so explicit thread binding is not required. - bool store_into_local = false; - PostOrderVisit(root, [&](const ObjectRef &obj) { - if (const auto *store = obj.as()) { - if (IsLocalBuffer(store->buffer)) { - store_into_local = true; - } - // if the case is like: - // for i in T.Parallel(1024): - // A_local[i] = B_global[i] - // A_frag[i] = A_global[i] - // exception will be raise in Parallel::LayoutInference - } - }); - // This check if for the loop that only manuplates "local" buffers, - // for i in T.Parallel(1024): - // A_local[i] = B_local[i] - // Though this might be illegal - // We use PostOrderVisit to detect whether the loop only manuplates - // "local" buffers, which indicates register usage and justifies skipping - // thread binding. - bool local_register_only = true; - PostOrderVisit(root, [&](const ObjectRef &obj) { - if (const auto *store = obj.as()) { - if (!IsLocalBuffer(store->buffer)) { - local_register_only = false; - } - } else if (const auto *load = obj.as()) { - if (!IsLocalBuffer(load->buffer)) { - local_register_only = false; - } - } - }); auto loop_layout = result_.for_map[root]; - // FIXME: tell in-Parallel and out-of-Parallel `local`s apart - // NOTE(lei): a bit ugly, we should rethink about this part in future. 
- bool parallel_loop = - !skip_thread_partition_ && !local_register_only && !store_into_local; - - if (parallel_loop) { - for_node = - PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout); - } - // If none thread bindings are provided, partition the loop - bool has_non_local = false; - PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { - if (const auto *load = obj.as()) { - String scope = load->buffer.scope(); - if (!IsLocalBuffer(load->buffer) && !IsFragmentBuffer(load->buffer)) { - has_non_local = true; - } - } else if (const auto *store = obj.as()) { - String scope = store->buffer.scope(); - if (!IsLocalBuffer(store->buffer) && !IsFragmentBuffer(store->buffer)) { - has_non_local = true; - } - } - }); - // Workaround: if reducer is presented, don't vectorize loop - // Best solution should be isolate reduction axis out of vectorization - bool has_reducer = false; - PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { - if (!has_reducer) - if (const auto *store = obj.as()) { - has_reducer = reducer_info.count(store->buffer->data) != 0; - } - }); - // If a cast operation exists, vectorization may still be required - bool has_cast_operations = false; - PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { - if (const auto *cast = obj.as()) { - // Check if this is a non-reducer store with Cast operation - DataType from_ty = cast->value.dtype(); - DataType target_ty = cast->dtype; - if (IsCudaVectorizableCast(from_ty, target_ty) && - TargetIsCuda(Target::Current())) { - has_cast_operations = true; - } + // Store the loop layout as an annotation on the For node + auto for_ptr = for_node.CopyOnWrite(); + for_ptr->annotations.Set(attr::kParallelLoopLayout, loop_layout); + + // Store the predicate as an annotation if it exists and is not trivially + // true + if (result_.predicate_map.count(root)) { + PrimExpr predicate = analyzer_->Simplify(result_.predicate_map[root]); + // Only store predicate if it's not trivially true + if (!is_const_int(predicate, 1)) { + for_ptr->annotations.Set(attr::kParallelLoopPredicate, predicate); } - }); - - if ((has_non_local || has_cast_operations) && !has_reducer) { - DLOG(INFO) << "Try to vectorize loop"; - for_node = VectorizeLoop(for_node, saved_analyzer.get()); } - if (result_.predicate_map.count(root) && parallel_loop) { - return IfThenElse(result_.predicate_map[root], for_node); - } else { - return for_node; - } + return for_node; } Stmt VisitStmt_(const AttrStmtNode *op) final { if (op->attr_key == tir::attr::thread_extent) { IterVar iv = Downcast(op->node); - ICHECK_NE(iv->thread_tag.length(), 0U); - if (iv->thread_tag == "threadIdx.x") { - thread_var_ = iv; - } } return IRMutatorWithAnalyzer::VisitStmt_(op); } private: const LayoutInferenceResult result_; - IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"), - IterVarType::kDataPar); - bool skip_thread_partition_{false}; }; tvm::transform::Pass LayoutInference() { @@ -1382,7 +1264,10 @@ tvm::transform::Pass LayoutInference() { collector(f->body); bool has_thread_binding = !collector.thread_binding_.empty(); bool skip_thread_partition = !has_thread_binding; - return LayoutInferencer::Substitute(std::move(f), skip_thread_partition); + f = LayoutInferencer::Substitute(std::move(f), skip_thread_partition); + // Validate parallel loop layout annotations + ParallelLoopLayoutValidator::Validate(f->body); + return f; }; return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {}); } diff --git a/src/transform/loop_partition.cc 
b/src/transform/loop_partition.cc index 186af7340..c1117533a 100644 --- a/src/transform/loop_partition.cc +++ b/src/transform/loop_partition.cc @@ -29,6 +29,7 @@ #include #include "../op/utils.h" +#include "loop_vectorize.h" namespace tvm { namespace tl { @@ -269,5 +270,31 @@ For LoopPragmaUnroll(For stmt) { return unrolled; } +Stmt LowerParallelLoop(For loop, const Fragment &loop_layout, Var thread_var, + arith::Analyzer *analyzer, Optional predicate, + bool parallel_loop, bool should_vectorize) { + // Save analyzer state to prevent conflicted bindings during vectorization + auto saved_analyzer = analyzer->Clone(); + + For result_loop = loop; + + // Step 1: Partition the loop based on the layout (if this is a parallel loop) + if (parallel_loop) { + result_loop = PartitionLoop(result_loop, thread_var, analyzer, loop_layout); + } + + // Step 2: Vectorize the loop (if requested) + if (should_vectorize) { + result_loop = VectorizeLoop(result_loop, saved_analyzer.get()); + } + + // Step 3: Wrap with predicate if provided and this is a parallel loop + if (predicate.defined() && parallel_loop) { + return IfThenElse(predicate.value(), result_loop); + } + + return result_loop; +} + } // namespace tl } // namespace tvm diff --git a/src/transform/loop_partition.h b/src/transform/loop_partition.h index 1103e7515..844065ab3 100644 --- a/src/transform/loop_partition.h +++ b/src/transform/loop_partition.h @@ -26,6 +26,7 @@ #define TVM_TL_LOOP_PARTITION_H_ #include +#include #include "../layout/layout.h" @@ -45,6 +46,31 @@ Fragment PlanLoopPartition(const For &op, int vectorize_size, For LoopPragmaUnroll(For stmt); +/*! + * \brief Lower a parallel loop by partitioning and vectorizing it. + * + * This function combines PartitionLoop and VectorizeLoop into a single + * operation, and optionally wraps the result with an IfThenElse if a + * predicate is provided. + * + * \param loop The parallel For loop to lower. + * \param loop_layout The Fragment layout for partitioning. + * \param thread_var The thread variable for partitioning. + * \param analyzer The arithmetic analyzer. + * \param predicate Optional predicate to wrap the loop with IfThenElse. + * \param parallel_loop Whether this is a true parallel loop requiring thread + * partitioning. False for loops that only operate on local/register + * buffers. (default true) + * \param should_vectorize Whether to vectorize the loop. False when reducers + * are present or when there are no non-local buffer accesses. + * (default true) + * \return The lowered statement. 
+ */ +Stmt LowerParallelLoop(For loop, const Fragment &loop_layout, Var thread_var, + arith::Analyzer *analyzer, + Optional predicate = Optional(), + bool parallel_loop = true, bool should_vectorize = true); + } // namespace tl } // namespace tvm diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index c88e05c56..303ac03ee 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -18,8 +18,11 @@ #include "../op/gemm.h" #include "../op/gemm_sp.h" #include "../op/operator.h" +#include "../op/utils.h" +#include "../target/utils.h" #include "arith/ir_mutator_with_analyzer.h" +#include "layout_reducer.h" #include "loop_partition.h" namespace tvm { @@ -304,8 +307,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { if (!load_expr.same_as(access_ptr_call->args[0])) { auto node = access_ptr_call.CopyOnWrite(); node->args.Set(0, load_expr); - access_ptr_call = Call(access_ptr_call->dtype, access_ptr_call->op, - {load_expr}, access_ptr_call->span); + access_ptr_call = + Call(access_ptr_call->dtype, access_ptr_call->op, {load_expr}, + access_ptr_call->annotations, access_ptr_call->span); } BufferLoad load = Downcast(access_ptr_call->args[0]); Array indices = load->indices; @@ -375,7 +379,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { } result.rewritten = true; result.expr = Call(access_ptr_call->dtype, access_ptr_call->op, new_args, - access_ptr_call->span); + access_ptr_call->annotations, access_ptr_call->span); return result; } else { LOG(FATAL) << "Invalid access op for permuted layout: " << access_ptr; @@ -472,7 +476,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { if (!load_expr.same_as(address_of_call->args[0])) { auto call_node = call.CopyOnWrite(); call_node->args.Set(5, Call(address_of_call->dtype, address_of_call->op, - {load_expr}, address_of_call->span)); + {load_expr}, address_of_call->annotations, + address_of_call->span)); address_of_call = Downcast(call->args[5]); access_ptr = call->args[5]; } @@ -664,6 +669,163 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { return arith::IRMutatorWithAnalyzer::VisitStmt_(op); } + /** + * @brief Handle a Parallel For node, lowering it based on the layout + * annotation. + * + * This method checks if the For node has a parallel_loop_layout annotation. + * If the For node is a parallel loop (ForKind::kParallel): + * - It must have the parallel_loop_layout annotation, otherwise an error is + * raised. + * - The loop is partitioned and vectorized based on the annotated layout. + * - If a predicate annotation exists, the loop is wrapped with an IfThenElse. + * + * Special handling for reducers and local buffers: + * - If the loop stores into local buffers, thread partitioning is skipped. + * - If the loop only manipulates local buffers, thread partitioning is + * skipped. + * - If reducers are present, vectorization is skipped. + * - Vectorization is only applied if non-local buffers or vectorizable casts + * are present. + * + * @return Stmt The lowered statement. 
+ */ + Stmt VisitStmt_(const ForNode *op) final { + // Extract reducer info from annotations + Map reducer_info; + if (op->annotations.count(attr::kReducerInfo)) { + reducer_info = op->annotations.Get(attr::kReducerInfo) + ->as>() + .value(); + } + + // First visit the body + For for_node = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); + + // Only process parallel loops + if (op->kind != ForKind::kParallel) { + return for_node; + } + + // For nested parallel loops, the annotation is placed on the outermost + // loop. Inner parallel loops without annotation should be skipped here - + // they will be processed as part of the outer loop's partitioning. + if (!op->annotations.count(attr::kParallelLoopLayout)) { + return for_node; + } + + auto loop_layout = Downcast( + op->annotations.Get(attr::kParallelLoopLayout).value()); + + // Get predicate if it exists + Optional predicate; + if (op->annotations.count(attr::kParallelLoopPredicate)) { + predicate = Downcast( + op->annotations.Get(attr::kParallelLoopPredicate).value()); + } + + auto root = tvm::ffi::GetRef(op); + + // Check if the loop stores into local buffers. + // For example: + // for i in T.Parallel(1024): + // A_local[i] = A_global[i] + // Here, A_local is a register-local buffer held independently by each + // thread, so explicit thread binding is not required. + bool store_into_local = false; + PostOrderVisit(root, [&](const ObjectRef &obj) { + if (const auto *store = obj.as()) { + if (IsLocalBuffer(store->buffer)) { + store_into_local = true; + } + } + }); + + // Check if the loop only manipulates "local" buffers. + // for i in T.Parallel(1024): + // A_local[i] = B_local[i] + // This indicates register usage and justifies skipping thread binding. + bool local_register_only = true; + PostOrderVisit(root, [&](const ObjectRef &obj) { + if (const auto *store = obj.as()) { + if (!IsLocalBuffer(store->buffer)) { + local_register_only = false; + } + } else if (const auto *load = obj.as()) { + if (!IsLocalBuffer(load->buffer)) { + local_register_only = false; + } + } + }); + + // Determine if this is a true parallel loop requiring thread partitioning. + // Skip partitioning for loops that only operate on local/register buffers. + bool parallel_loop = !local_register_only && !store_into_local; + + // Check if there are non-local buffer accesses (for vectorization decision) + bool has_non_local = false; + PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { + if (const auto *load = obj.as()) { + if (!IsLocalBuffer(load->buffer) && !IsFragmentBuffer(load->buffer)) { + has_non_local = true; + } + } else if (const auto *store = obj.as()) { + if (!IsLocalBuffer(store->buffer) && !IsFragmentBuffer(store->buffer)) { + has_non_local = true; + } + } + }); + + // Check if reducers are present in the loop body + // Workaround: if reducer is presented, don't vectorize loop + // Best solution should be isolate reduction axis out of vectorization + // + // Note: reducer_info stores original buffer data vars, but after visiting + // the body, buffers may have been remapped via var_remap_. We need to find + // the original var to check against reducer_info. 
+ bool has_reducer = false; + PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { + if (!has_reducer) { + if (const auto *store = obj.as()) { + Var data_var = store->buffer->data; + // Find the original var if it was remapped + // var_remap_ maps old_var -> new_var, so we need reverse lookup + Var original_var = data_var; + for (const auto &[old_var, new_var] : var_remap_) { + if (new_var.same_as(data_var)) { + original_var = old_var; + break; + } + } + has_reducer = reducer_info.count(original_var) != 0; + } + } + }); + + // Check if vectorizable cast operations exist + bool has_cast_operations = false; + PostOrderVisit(for_node->body, [&](const ObjectRef &obj) { + if (const auto *cast = obj.as()) { + DataType from_ty = cast->value.dtype(); + DataType target_ty = cast->dtype; + if (IsCudaVectorizableCast(from_ty, target_ty) && + TargetIsCuda(Target::Current())) { + has_cast_operations = true; + } + } + }); + + // Decide whether to vectorize: + // - Only if there are non-local buffers or vectorizable casts + // - AND no reducers are present + bool should_vectorize = + (has_non_local || has_cast_operations) && !has_reducer; + + // Lower the parallel loop using the common function + return LowerParallelLoop(for_node, loop_layout, thread_var_->var, analyzer_, + predicate, parallel_loop, should_vectorize); + } + Target target_; Map buffer_data_to_buffer_; Map layout_map_; diff --git a/src/transform/multi_version_buffer_rewriter.cc b/src/transform/multi_version_buffer_rewriter.cc index 7ed9437cf..4075673ec 100644 --- a/src/transform/multi_version_buffer_rewriter.cc +++ b/src/transform/multi_version_buffer_rewriter.cc @@ -469,7 +469,7 @@ class MultiVersionBufferRewriter : public StmtExprMutator { new_args.Set(i + 1, new_index); } } - return Call(call->dtype, call->op, new_args, call->span); + return Call(call->dtype, call->op, new_args, call->annotations, call->span); } PrimExpr version_index_; diff --git a/src/transform/parallel_loop_layout_validator.h b/src/transform/parallel_loop_layout_validator.h new file mode 100644 index 000000000..c4cc2e1fc --- /dev/null +++ b/src/transform/parallel_loop_layout_validator.h @@ -0,0 +1,140 @@ +/*! + * \file parallel_loop_layout_validator.h + * \brief Validator for parallel loop layout annotations. + */ + +#ifndef TVM_TL_TRANSFORM_PARALLEL_LOOP_LAYOUT_VALIDATOR_H_ +#define TVM_TL_TRANSFORM_PARALLEL_LOOP_LAYOUT_VALIDATOR_H_ + +#include + +#include "../layout/layout.h" + +namespace tvm { +namespace tl { + +using namespace tir; + +/*! + * \brief Count the number of consecutive nested parallel loops starting from + * the given For node. + * \param op The outermost For node to start counting from. + * \return The number of consecutive nested parallel loops. + */ +inline int CountNestedParallelLoops(const ForNode *op) { + int count = 0; + const ForNode *current = op; + while (current != nullptr && current->kind == ForKind::kParallel) { + count++; + current = current->body.as(); + } + return count; +} + +/*! + * \brief Validator that checks parallel loop layout annotations. + * + * This validator checks: + * 1. All parallel loops must have layout annotations (either directly or via + * an outer nested parallel loop). + * 2. For nested parallel loops, only the outermost parallel loop should have + * the layout annotation. + * 3. The layout's InputDim must equal the number of consecutive nested + * parallel loops. + */ +class ParallelLoopLayoutValidator : public StmtVisitor { +public: + /*! 
+ * \brief Validate parallel loop layout annotations in the given statement. + * \param stmt The statement to validate. + */ + static void Validate(const Stmt &stmt) { + ParallelLoopLayoutValidator validator; + validator.VisitStmt(stmt); + } + +private: + void VisitStmt_(const ForNode *op) final { + // Only validate parallel loops + if (op->kind != ForKind::kParallel) { + StmtVisitor::VisitStmt_(op); + return; + } + + // Check if this parallel loop has a layout annotation + bool has_layout = op->annotations.count(attr::kParallelLoopLayout) > 0; + + // Count the number of consecutive nested parallel loops + int nested_count = CountNestedParallelLoops(op); + + if (has_layout) { + // This is the outermost parallel loop with layout annotation + auto loop_layout = Downcast( + op->annotations.Get(attr::kParallelLoopLayout).value()); + + // Validate that layout's InputDim matches the number of nested parallel + // loops + int layout_input_dim = static_cast(loop_layout->InputDim()); + ICHECK(layout_input_dim == nested_count) + << "Layout InputDim mismatch for parallel loop.\n" + << "Expected: " << nested_count + << " (number of consecutive nested parallel loops)\n" + << "Got: " << layout_input_dim << " (layout InputDim)\n" + << "Loop: " << tvm::ffi::GetRef(op) << "\n" + << "For nested parallel loops, the layout annotation should be on " + << "the outermost loop, and its InputDim should equal the total " + << "number of nested parallel loops."; + + // Validate that inner parallel loops do NOT have layout annotations + ValidateInnerParallelLoopsNoLayout(op->body, nested_count - 1); + + // Skip visiting inner parallel loops as they are part of this nested + // structure. Visit the body of the innermost parallel loop instead. + const ForNode *innermost = op; + for (int i = 1; i < nested_count; i++) { + innermost = innermost->body.as(); + } + StmtVisitor::VisitStmt(innermost->body); + } else { + // This parallel loop doesn't have a layout annotation + // This is only valid if it's an inner loop of a nested parallel structure + // But since we process from outermost to innermost, if we reach here + // without a layout annotation, it's an error. + LOG(FATAL) + << "Parallel loop missing layout annotation.\n" + << "Loop: " << tvm::ffi::GetRef(op) << "\n" + << "All parallel loops must have a layout annotation after " + << "LayoutInference pass. For nested parallel loops, the annotation " + << "should be on the outermost loop."; + } + } + + /*! + * \brief Validate that inner parallel loops do not have layout annotations. + * \param body The body to check (should be inner parallel loops). + * \param remaining_count Number of remaining inner parallel loops to check. 
+ */ + void ValidateInnerParallelLoopsNoLayout(const Stmt &body, + int remaining_count) { + if (remaining_count <= 0) { + return; + } + + const ForNode *inner_for = body.as(); + ICHECK(inner_for != nullptr && inner_for->kind == ForKind::kParallel) + << "Expected inner parallel loop but found: " << body; + + ICHECK(!inner_for->annotations.count(attr::kParallelLoopLayout)) + << "Inner parallel loop should NOT have layout annotation.\n" + << "Loop: " << tvm::ffi::GetRef(inner_for) << "\n" + << "For nested parallel loops, only the outermost parallel loop " + << "should have the layout annotation."; + + ValidateInnerParallelLoopsNoLayout(inner_for->body, remaining_count - 1); + } +}; + +} // namespace tl +} // namespace tvm + +#endif // TVM_TL_TRANSFORM_PARALLEL_LOOP_LAYOUT_VALIDATOR_H_ diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py index a801f75f4..30b5f533b 100644 --- a/tilelang/language/atomic.py +++ b/tilelang/language/atomic.py @@ -217,10 +217,14 @@ def get_extent(data): if return_prev: raise NotImplementedError("return_prev is not supported for tile-region-based atomic operations") - if memory_order is None: - return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0) - else: - return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, _MEMORY_ORDER_ID_MAP[memory_order]) + # Build annotations dict + ann = {} + if use_tma: + ann["use_tma"] = 1 + if memory_order is not None: + ann["memory_order"] = _MEMORY_ORDER_ID_MAP[memory_order] + + return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, annotations=ann if ann else None) def atomic_addx2(dst: Buffer, value: PrimExpr, return_prev: bool = False) -> PrimExpr: diff --git a/tilelang/language/copy_op.py b/tilelang/language/copy_op.py index 0b55c410c..6401520fb 100644 --- a/tilelang/language/copy_op.py +++ b/tilelang/language/copy_op.py @@ -17,6 +17,7 @@ def copy( coalesced_width: int | None = None, disable_tma: bool = False, eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None, + annotations: dict | None = None, ): """Copy data between memory regions. @@ -24,6 +25,11 @@ def copy( src (Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]): Source memory region dst (Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]): Destination memory region coalesced_width (Optional[int], optional): Width for coalesced memory access. Defaults to None. + disable_tma (bool, optional): Whether to disable TMA acceleration. Defaults to False. + eviction_policy (Optional[str], optional): Cache eviction policy. Defaults to None. + annotations (Optional[dict], optional): Additional annotations dict. If provided, + coalesced_width, disable_tma, and eviction_policy can also be specified here. + Values in annotations take precedence over individual arguments. 
Raises: TypeError: If copy extents cannot be deduced from arguments @@ -86,13 +92,19 @@ def get_extent(data): src = to_buffer_region(src, access_type="r", extents=src_extent) dst = to_buffer_region(dst, access_type="w", extents=dst_extent) - if coalesced_width is None: - coalesced_width = -1 # PrimExpr can not be None - if eviction_policy is None: - eviction_policy = 0 - else: - eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] - return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width, disable_tma, eviction_policy) + # Build annotations dict + ann = annotations.copy() if annotations else {} + + # Individual arguments take lower precedence than annotations + if "coalesced_width" not in ann and coalesced_width is not None: + ann["coalesced_width"] = coalesced_width + if "disable_tma" not in ann and disable_tma: + ann["disable_tma"] = disable_tma + if "eviction_policy" not in ann and eviction_policy is not None: + eviction_policy_map = {"evict_normal": 0, "evict_first": 1, "evict_last": 2} + ann["eviction_policy"] = eviction_policy_map[eviction_policy] + + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, annotations=ann if ann else None) def c2d_im2col( diff --git a/tilelang/language/tir/op.py b/tilelang/language/tir/op.py index d622911df..20876a944 100644 --- a/tilelang/language/tir/op.py +++ b/tilelang/language/tir/op.py @@ -117,7 +117,7 @@ def call_cpacked_lowered(*args, span=None): return _tvm_op.call_cpacked_lowered(*args, span=span) -def call_intrin(dtype, func_name, *args, span=None): +def call_intrin(dtype, func_name, *args, annotations=None, span=None): """Build expression by calling an intrinsic function. Intrinsics can be overloaded with multiple data types via @@ -142,7 +142,7 @@ def call_intrin(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return _tvm_op.call_intrin(dtype, func_name, *args, span=span) + return _tvm_op.call_intrin(dtype, func_name, *args, annotations=annotations, span=span) def call_pure_extern(dtype, func_name, *args, span=None):
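Taken together, this patch moves the per-op knobs (`use_tma`, `memory_order`, `coalesced_width`, `disable_tma`, `eviction_policy`) out of positional `IntImm` call arguments and into the Call node's annotations map, which the C++ operators read back through small accessors (`GetUseTMA`, `GetMemoryOrder`, `GetDisableTMA`, `GetEvictionPolicy`). The sketch below shows how the updated Python front end is expected to be driven after this change; it is illustrative only — the kernel scaffolding (shapes, `T.prim_func`/`T.Kernel`/`T.alloc_shared` usage, buffer names) is assumed from the existing tilelang API and is not part of this diff.

```python
import tilelang.language as T


@T.prim_func
def main(A: T.Tensor((128, 128), "float16"),
         B: T.Tensor((128, 128), "float16")):
    with T.Kernel(1, threads=128) as bx:
        A_s = T.alloc_shared((128, 128), "float16")
        # Keyword arguments are folded into the tl.tileop.copy Call's
        # annotations map ({"eviction_policy": 1, "disable_tma": True})
        # instead of being appended as positional IntImm operands.
        T.copy(A, A_s, eviction_policy="evict_first", disable_tma=True)
        # An explicit annotations dict is also accepted; per copy_op.py above,
        # entries already present in it take precedence over the individual
        # keyword arguments.
        T.copy(A_s, B, annotations={"coalesced_width": 8})
```

On the C++ side these entries surface as `call->annotations`, are copied verbatim into the tile operator node by the updated constructors, and are queried by the lowering paths (`MakeSIMTLoop`, `LowerBulkCopy`, etc.) through the accessor helpers rather than through dedicated fields. The same pattern applies to `T.atomic_add`, whose `use_tma` and `memory_order` arguments are now forwarded as `{"use_tma": 1, "memory_order": ...}` annotations rather than extra call operands.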