From c872bcb546849960f68a806ca48690ec50979130 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Tue, 26 Aug 2025 21:30:17 +0800
Subject: [PATCH 1/5] Refactor operator classes to inherit from TileOperator
 and update layout inference methods

- Changed base class of several operator classes (AtomicAdd, Copy, Gemm, etc.) from Operator to TileOperator for better alignment with tile operations.
- Updated InferLayout and Lower methods to use the 'override' specifier for clarity and consistency.
- Adjusted header inclusions to replace "op.h" with "operator.h" across multiple files for improved organization.
- Added missing layout inference implementations for Fill and Conv2DIm2ColOp.
- Removed deprecated op.cc and op.h files to streamline the codebase.
---
 src/op/atomic_add.cc              |  7 +--
 src/op/atomic_add.h               | 13 ++---
 src/op/builtin.h                  |  2 +-
 src/op/copy.cc                    |  8 ++-
 src/op/copy.h                     | 25 +++++----
 src/op/elem.cc                    |  4 ++
 src/op/elem.h                     | 10 ++--
 src/op/gemm.cc                    | 23 +++++---
 src/op/gemm.h                     | 13 ++---
 src/op/gemm_sp.cc                 |  3 +-
 src/op/gemm_sp.h                  | 13 ++---
 src/op/op.cc                      | 87 ------------------------
 src/op/operator.cc                | 48 +++++++++++++++++
 src/op/{op.h => operator.h}       | 44 +++++++---------
 src/op/parallel.cc                |  9 +++-
 src/op/parallel.h                 | 24 +++++----
 src/op/reduce.cc                  |  6 ++-
 src/op/reduce.h                   | 20 +++----
 src/op/region.cc                  | 55 +++++++++++++++++++
 src/op/region.h                   | 51 ++++++++++++++++++
 src/transform/layout_inference.cc | 11 ++--
 src/transform/lower_tile_op.cc    |  2 +-
 22 files changed, 291 insertions(+), 187 deletions(-)
 delete mode 100644 src/op/op.cc
 create mode 100644 src/op/operator.cc
 rename src/op/{op.h => operator.h} (67%)
 create mode 100644 src/op/region.cc
 create mode 100644 src/op/region.h

diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc
index e68cf41db..079971f98 100644
--- a/src/op/atomic_add.cc
+++ b/src/op/atomic_add.cc
@@ -4,8 +4,8 @@
 * Define element-wise operators.
*/ -#include "atomic_add.h" - +#include "./atomic_add.h" +#include "./region.h" #include #include #include @@ -210,7 +210,8 @@ Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return vectorized_thread_loop; } -LayoutMap AtomicAdd::InferLayout(const LayoutInferArgs &T, InferLevel level) { +LayoutMap AtomicAdd::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { if (par_op_ == nullptr) { arith::Analyzer analyzer; par_op_ = std::make_unique(MakeSIMTLoop(&analyzer)); diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index b8bb0dd97..684cd4239 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -7,7 +7,7 @@ #ifndef TVM_TL_OP_ATOMIC_ADD_H_ #define TVM_TL_OP_ATOMIC_ADD_H_ -#include "op.h" +#include "operator.h" #include "parallel.h" namespace tvm { @@ -15,11 +15,12 @@ namespace tl { using namespace tir; -class AtomicAdd : public Operator { +class AtomicAdd : public TileOperator { public: AtomicAdd(Array args, BufferMap vmap); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; - LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; static const Op &Get(); @@ -32,7 +33,7 @@ class AtomicAdd : public Operator { par_op_ = std::unique_ptr( static_cast(other.par_op_->Clone().release())); } - std::unique_ptr Clone() const final { + std::unique_ptr Clone() const override { return std::make_unique(*this); } @@ -53,7 +54,7 @@ class AtomicAdd : public Operator { Array src_range, dst_range; IntImm coalesced_width; - std::unique_ptr par_op_; + mutable std::unique_ptr par_op_; }; } // namespace tl diff --git a/src/op/builtin.h b/src/op/builtin.h index f48cd9851..59dc55901 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -7,7 +7,7 @@ #ifndef TVM_TL_OP_BUILTIN_H_ #define TVM_TL_OP_BUILTIN_H_ -#include "op.h" +#include "operator.h" #include namespace tvm { diff --git a/src/op/copy.cc b/src/op/copy.cc index 908f5f90c..6bd04c773 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -15,6 +15,7 @@ #include "../transform/common/loop_parallel_transform_utils.h" #include "../transform/loop_partition.h" #include "../transform/loop_vectorize.h" +#include "region.h" #include "../target/cuda.h" #include "../target/utils.h" @@ -316,7 +317,7 @@ Layout Copy::ComputeLinearLayout(const Buffer &shared_tensor) const { * indicating the level of layout inference. \return LayoutMap containing the * inferred layout. 
*/ -LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) { +LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) const { auto target = T.target; using namespace tvm::transform; PassContext pass_ctx = PassContext::Current(); @@ -1228,6 +1229,11 @@ TIR_REGISTER_TL_OP(Copy, copy) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +LayoutMap Conv2DIm2ColOp::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + return {}; +} + // Register the Conv2DIm2Col operation with TVM's TIR system // This operation performs im2col transformation for 2D convolutions using TMA // - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride, diff --git a/src/op/copy.h b/src/op/copy.h index b4482e206..33581b7d0 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -11,7 +11,7 @@ #ifndef TVM_TL_OP_COPY_H_ #define TVM_TL_OP_COPY_H_ -#include "op.h" +#include "operator.h" #include "parallel.h" namespace tvm { @@ -83,7 +83,7 @@ struct TMAIm2ColDesc { * block-wise or element-wise data transfer, possibly optimized with * parallelization or TMA hardware acceleration. */ -class Copy : public Operator { +class Copy : public TileOperator { public: /*! * \brief Constructor. @@ -97,14 +97,15 @@ class Copy : public Operator { * \param T Arguments for lowering. * \param analyzer Analyzer for simplification and bounds checks. */ - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; /*! * \brief Infer buffer layouts after applying this operator. * \param T Arguments for layout inference. * \param level Level of inference (basic or detailed). */ - LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; /*! * \brief Get the TVM Op handle corresponding to this Copy op. @@ -163,7 +164,7 @@ class Copy : public Operator { /*! * \brief Clone this copy operator. */ - std::unique_ptr Clone() const final { + std::unique_ptr Clone() const override { return std::make_unique(*this); } @@ -225,7 +226,7 @@ class Copy : public Operator { IntImm coalesced_width; // Width (in elements) for coalesced memory access Bool disable_tma = Bool(false); // Whether to disable TMA acceleration - std::unique_ptr + mutable std::unique_ptr par_op_; // Optional associated parallelization operator enum class EvictionPolicy { @@ -243,7 +244,7 @@ class Copy : public Operator { * This operator converts input image layout into columnar format suitable * for matrix multiplication-based convolution lowering. */ -class Conv2DIm2ColOp : public Operator { +class Conv2DIm2ColOp : public TileOperator { public: /*! * \brief Constructor. @@ -255,7 +256,13 @@ class Conv2DIm2ColOp : public Operator { /*! * \brief Lower to TIR statement. */ - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + + /*! + * \brief Infer layout for this operator. + */ + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; /*! * \brief Get TVM Op handle. @@ -265,7 +272,7 @@ class Conv2DIm2ColOp : public Operator { /*! * \brief Clone this operator. 
*/ - std::unique_ptr Clone() const final { + std::unique_ptr Clone() const override { return std::make_unique(*this); } diff --git a/src/op/elem.cc b/src/op/elem.cc index d3d7290ed..228f05d24 100644 --- a/src/op/elem.cc +++ b/src/op/elem.cc @@ -122,6 +122,10 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } } +LayoutMap Fill::InferLayout(const LayoutInferArgs &T, InferLevel level) const { + return {}; +} + TIR_REGISTER_TL_OP(Fill, fill) .set_num_inputs(2) .set_attr("TCallEffectKind", diff --git a/src/op/elem.h b/src/op/elem.h index b3d682398..fcb16547f 100644 --- a/src/op/elem.h +++ b/src/op/elem.h @@ -7,7 +7,7 @@ #ifndef TVM_TL_OP_ELEM_H_ #define TVM_TL_OP_ELEM_H_ -#include "op.h" +#include "operator.h" #include "parallel.h" namespace tvm { @@ -15,13 +15,15 @@ namespace tl { using namespace tir; -class Fill : public Operator { +class Fill : public TileOperator { public: Fill(Array args, BufferMap vmap); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; static const Op &Get(); - std::unique_ptr Clone() const final { + std::unique_ptr Clone() const override { return std::make_unique(*this); } diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 065e664e5..6bfc1b733 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -87,10 +87,13 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const { * per-warp tile sizes) and adapts the partition according to the configured * GemmWarpPolicy (FullRow, FullCol, Square). * - * @param block_size Total number of threads in the block (used to derive num_warps). + * @param block_size Total number of threads in the block (used to derive + * num_warps). * @param gemm_inst The chosen GEMM implementation (e.g., kWGMMA, kMFMA, kMMA). - * @param target Target device information (used for warp size and target-specific rules). - * @return std::pair {m_warp, n_warp} where m_warp * n_warp == num_warps. + * @param target Target device information (used for warp size and + * target-specific rules). + * @return std::pair {m_warp, n_warp} where m_warp * n_warp == + * num_warps. * * Constraints and behavior: * - Each warp is assumed to cover 16 rows (M) and 8 columns (N). The function @@ -100,7 +103,8 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const { * - num_warps must be a multiple of 4 (warp-groups of 4). * - m_warp is always a multiple of 4. * - The warp partition respects the GemmWarpPolicy: - * - FullRow: maximize warps on M (in multiples of 4) while keeping divisibility. + * - FullRow: maximize warps on M (in multiples of 4) while keeping + * divisibility. * - FullCol: maximize warps on N, but if N is not evenly divisible, move * whole warp-groups to M to achieve feasibility. 
* - Square: choose a multiple-of-4 m_warp that best balances per-warp work @@ -296,14 +300,16 @@ std::pair Gemm::ComputeWarpPartition(int block_size, * Supported combinations and constraints: * - C=float16: * - A=float16, B=float16: K % 16 == 0 - * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % 32 == 0 + * - Various float8 mixes (e4m3/e5m2): require (!trans_A && trans_B) and K % + * 32 == 0 * - C=float32: * - A=float16, B=float16: K % 16 == 0 * - A=bfloat16, B=bfloat16: K % 16 == 0 * - A=float32, B=float32: require (!trans_A && trans_B) and K % 8 == 0 * - Various float8 mixes: require (!trans_A && trans_B) and K % 32 == 0 * - C=int32: - * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) and K % 32 == 0 + * - 8-bit integer combinations (Int8/UInt8): require (!trans_A && trans_B) + * and K % 32 == 0 * * @return true if WGMMA is supported for the current buffers, dtypes, and * transpose/shape constraints; false otherwise. @@ -425,7 +431,8 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { * - C.scope() must be "local.fragment". * * Postconditions / side effects: - * - Marks the operator's layout inference as completed (sets completed_ = true). + * - Marks the operator's layout inference as completed (sets completed_ = + * true). * - May abort via ICHECK on unsupported targets, invalid buffer scopes, or * incompatible shape constraints. * @@ -433,7 +440,7 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { * @param level Inference level (unused for side effects but retained for API). * @return LayoutMap mapping each of A, B, and C to their inferred layouts. */ -LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { +LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) const { if (completed_) return {}; LayoutMap results; diff --git a/src/op/gemm.h b/src/op/gemm.h index 55e42b771..3cb3cf0d5 100644 --- a/src/op/gemm.h +++ b/src/op/gemm.h @@ -7,18 +7,19 @@ #ifndef TVM_TL_OP_GEMM_H_ #define TVM_TL_OP_GEMM_H_ -#include "op.h" +#include "operator.h" namespace tvm { namespace tl { using namespace tir; -class Gemm : public Operator { +class Gemm : public TileOperator { public: Gemm(Array args, BufferMap vmap); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; - LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; static const Op &Get(); enum class GemmWarpPolicy { kSquare = 0, @@ -26,7 +27,7 @@ class Gemm : public Operator { kFullCol = 2, } policy; - std::unique_ptr Clone() const final { + std::unique_ptr Clone() const override { return std::make_unique(*this); } @@ -52,7 +53,7 @@ class Gemm : public Operator { // only will be enabled under cdna mfma instructions int kPack = 1; int wg_wait = 0; - bool completed_ = false; + mutable bool completed_ = false; }; } // namespace tl diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc index 9405c8631..b642a8cbe 100644 --- a/src/op/gemm_sp.cc +++ b/src/op/gemm_sp.cc @@ -256,7 +256,8 @@ Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return Evaluate(new_call); } -LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, InferLevel level) { +LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { if (completed_) return {}; LayoutMap results; diff --git a/src/op/gemm_sp.h 
b/src/op/gemm_sp.h index 4488e4612..9a14f17e9 100644 --- a/src/op/gemm_sp.h +++ b/src/op/gemm_sp.h @@ -7,18 +7,19 @@ #ifndef TVM_TL_OP_GEMM_SP_H_ #define TVM_TL_OP_GEMM_SP_H_ -#include "op.h" +#include "operator.h" namespace tvm { namespace tl { using namespace tir; -class GemmSP : public Operator { +class GemmSP : public TileOperator { public: GemmSP(Array args, BufferMap vmap); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; - LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; static const Op &Get(); enum class GemmWarpPolicy { kSquare = 0, @@ -26,7 +27,7 @@ class GemmSP : public Operator { kFullCol = 2, } policy; - std::unique_ptr Clone() const final { + std::unique_ptr Clone() const override { return std::make_unique(*this); } @@ -44,7 +45,7 @@ class GemmSP : public Operator { // only will be enabled under cdna mfma instructions int kPack = 1; int wg_wait = 0; - bool completed_ = false; + mutable bool completed_ = false; }; } // namespace tl diff --git a/src/op/op.cc b/src/op/op.cc deleted file mode 100644 index 69cd59227..000000000 --- a/src/op/op.cc +++ /dev/null @@ -1,87 +0,0 @@ -/*! - * \file tl/op/op.cc - * - * Define operators usd in tile library. - */ - -#include "op.h" - -#include -#include -#include - -namespace tvm { -namespace tl { - -using namespace tir; - -TIR_REGISTER_TL_OP(RegionOp, region) - .set_num_inputs(-1) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kPure)); - -std::unique_ptr ParseOperator(Call call, BufferMap vmap) { - auto op_map = Op::GetAttrMap("TLOpBuilder"); - Op op = call->op.as().value(); - if (op_map.count(op)) { - Operator *ptr = static_cast(op_map[op](call->args, vmap)); - ICHECK(ptr != nullptr); - return std::unique_ptr(ptr); - } - return nullptr; -} - -std::unique_ptr ParseOperator(Stmt stmt, BufferMap vmap) { - if (stmt.as() && stmt.as()->value.as()) { - auto call = stmt.as()->value.as(); - return ParseOperator(GetRef(call), vmap); - } - return nullptr; -} - -Var GetVarFromAccessPtr(const PrimExpr &expr) { - auto call = expr.as(); - ICHECK(call); - ICHECK(call->op.same_as(builtin::tvm_access_ptr())); - auto var = call->args[1].as(); - ICHECK(var); - return GetRef(var); -} - -RegionOp::RegionOp(Array args, BufferMap vmap) { - size_t n = args.size(); - size_t ndim = n - 2; - auto load = args[0].as(); - ICHECK(load); - ICHECK(load->indices.size() == ndim) - << "load->indices.size() = " << load->indices << " ndim = " << ndim; - buffer_ = load->buffer; - access_mask_ = static_cast(*as_const_int(args[1])); - for (size_t i = 0; i < ndim; i++) { - PrimExpr min = load->indices[i]; - PrimExpr extent = args[2 + i]; - ranges_.push_back(Range::FromMinExtent(min, extent)); - } -} - -bool RegionOp::IsFullRegion() const { - for (size_t i = 0; i < ranges_.size(); i++) { - if (!is_zero(ranges_[i]->min)) - return false; - if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i])) - return false; - } - return true; -} - -Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { - ICHECK(0) << "Not Implemented Lower method."; - return Evaluate(0); -} - -LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) { - return {}; -} - -} // namespace tl -} // namespace tvm diff --git a/src/op/operator.cc b/src/op/operator.cc new file mode 100644 index 000000000..80be1589c --- /dev/null +++ b/src/op/operator.cc @@ -0,0 
+1,48 @@
+/*!
+ * \file tl/op/operator.cc
+ *
+ * Define operators used in the tile library.
+ */
+
+#include "operator.h"
+
+#include
+#include
+#include
+
+namespace tvm {
+namespace tl {
+
+using namespace tir;
+
+
+std::unique_ptr ParseOperator(Call call, BufferMap vmap) {
+  auto op_map = Op::GetAttrMap("TLOpBuilder");
+  Op op = call->op.as().value();
+  if (op_map.count(op)) {
+    TileOperator *ptr = static_cast(op_map[op](call->args, vmap));
+    ICHECK(ptr != nullptr);
+    return std::unique_ptr(ptr);
+  }
+  return nullptr;
+}
+
+std::unique_ptr ParseOperator(Stmt stmt, BufferMap vmap) {
+  if (stmt.as() && stmt.as()->value.as()) {
+    auto call = stmt.as()->value.as();
+    return ParseOperator(GetRef(call), vmap);
+  }
+  return nullptr;
+}
+
+Var GetVarFromAccessPtr(const PrimExpr &expr) {
+  auto call = expr.as();
+  ICHECK(call);
+  ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
+  auto var = call->args[1].as();
+  ICHECK(var);
+  return GetRef(var);
+}
+
+} // namespace tl
+} // namespace tvm
diff --git a/src/op/op.h b/src/op/operator.h
similarity index 67%
rename from src/op/op.h
rename to src/op/operator.h
index 1dc21c2bc..306b829fc 100644
--- a/src/op/op.h
+++ b/src/op/operator.h
@@ -11,6 +11,8 @@
 #include
 #include
 #include
+#include
+#include
 
 #include "../layout/layout.h"
 
@@ -58,38 +60,30 @@ struct LayoutInferArgs {
   Map buffer_remap;
 };
 
-class Operator {
-public:
-  virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
-  virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
-  virtual ~Operator() = default;
-  virtual std::unique_ptr Clone() const = 0;
-};
-
-class RegionOp : public Operator {
-public:
-  RegionOp(Array args, BufferMap vmap);
-  static const Op &Get();
-
-  std::unique_ptr Clone() const final {
-    return std::make_unique(*this);
+class TileOperator {
+  public:
+  // Lower interface
+  virtual Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
+    ICHECK(0) << "Not Implemented Lower method.";
+    return Evaluate(0);
   }
 
-  const Buffer &GetBuffer() const { return buffer_; }
-  const Array &GetRanges() const { return ranges_; }
-  int GetAccessMask() const { return access_mask_; }
-  bool IsFullRegion() const;
+  // InferLayout interface
+  virtual LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) const {
+    return {};
+  }
 
-private:
-  Buffer buffer_;
-  Array ranges_;
-  int access_mask_;
+  // Clone interface
+  virtual std::unique_ptr Clone() const = 0;
+
+  // Virtual destructor
+  virtual ~TileOperator() = default;
 };
 
 Var GetVarFromAccessPtr(const PrimExpr &expr);
 
-std::unique_ptr ParseOperator(Call call, BufferMap vmap);
-std::unique_ptr ParseOperator(Stmt stmt, BufferMap vmap);
+std::unique_ptr ParseOperator(Call call, BufferMap vmap);
+std::unique_ptr ParseOperator(Stmt stmt, BufferMap vmap);
 
 } // namespace tl
 } // namespace tvm
diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index 33ceb7de8..a2d622f58 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -156,6 +156,10 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
 
 ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); }
 
+Stmt ParallelOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
+  return root_;
+}
+
 bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
   auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
   return StructuralEqual()(indice_map_[buffer], common_indice);
 }
@@ -179,7 +183,8 @@ bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
 * Can generate new layouts based on vectorization and
thread * bounds. Used when maximum performance optimization is desired. */ -LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { +LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { if (loop_layout_.defined()) return {}; if (level == InferLevel::kStrict) @@ -363,7 +368,7 @@ Optional ParallelOp::GetPredicate(Var thread_var) const { } } -Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) { +Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) const { ICHECK(loop_layout_.defined()); if (IsCommonAccessIndice(buffer)) { return loop_layout_; diff --git a/src/op/parallel.h b/src/op/parallel.h index fd49acfe9..5a80b8b7a 100644 --- a/src/op/parallel.h +++ b/src/op/parallel.h @@ -10,7 +10,7 @@ #include #include "../layout/layout.h" -#include "op.h" +#include "operator.h" namespace tvm { namespace tl { @@ -36,25 +36,27 @@ class ParallelOp; class ParallelLoopNestVisitor : public StmtExprVisitor { private: ParallelLoopNestVisitor(ParallelOp *op) : p(op){}; - void VisitStmt_(const ForNode *op) final; - void VisitStmt_(const BufferStoreNode *op) final; - void VisitExpr_(const BufferLoadNode *op) final; + void VisitStmt_(const ForNode *op) override; + void VisitStmt_(const BufferStoreNode *op) override; + void VisitExpr_(const BufferLoadNode *op) override; ParallelOp *p; friend class ParallelOp; }; -class ParallelOp : public Operator { +class ParallelOp : public TileOperator { public: ParallelOp(For root); - LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; ParallelOp(const ParallelOp &other) : ParallelOp(other.root_) { loop_layout_ = other.loop_layout_; predicate_ = other.predicate_; } - std::unique_ptr Clone() const final { + std::unique_ptr Clone() const override { return std::make_unique(*this); } @@ -64,9 +66,9 @@ class ParallelOp : public Operator { Optional GetPredicate(Var thread_var) const; private: - Fragment CompleteBufferFragment(const Buffer &buffer); + Fragment CompleteBufferFragment(const Buffer &buffer) const; bool IsCommonAccessIndice(const Buffer &buffer) const; - void AddPredicate(PrimExpr expr) { + void AddPredicate(PrimExpr expr) const { predicate_ = predicate_.defined() ? 
And(expr, predicate_.value()) : expr; } @@ -78,9 +80,9 @@ class ParallelOp : public Operator { std::unordered_set buffer_is_write_; Array loop_vars_; - Fragment loop_layout_; + mutable Fragment loop_layout_; mutable arith::Analyzer analyzer_; - Optional predicate_; + mutable Optional predicate_; friend class ParallelLoopNestVisitor; }; diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 79ce193ba..a8fa09e5f 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -284,7 +284,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return body; } -LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { +LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { if (level >= InferLevel::kStrict) return {}; if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" && @@ -402,7 +403,8 @@ Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return Stmt(); } -LayoutMap CumSumOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { +LayoutMap CumSumOp::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { return {}; } diff --git a/src/op/reduce.h b/src/op/reduce.h index 64954ea43..5303cb147 100644 --- a/src/op/reduce.h +++ b/src/op/reduce.h @@ -7,21 +7,22 @@ #ifndef TVM_TL_OP_REDUCE_H_ #define TVM_TL_OP_REDUCE_H_ -#include "op.h" +#include "operator.h" namespace tvm { namespace tl { using namespace tir; -class ReduceOp : public Operator { +class ReduceOp : public TileOperator { public: ReduceOp(Array args, BufferMap vmap); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; - LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; static const Op &Get(); - std::unique_ptr Clone() const final { + std::unique_ptr Clone() const override { return std::make_unique(*this); } @@ -42,14 +43,15 @@ class ReduceOp : public Operator { std::string MakeCodegenReducer() const; }; -class CumSumOp : public Operator { +class CumSumOp : public TileOperator { public: CumSumOp(Array args, BufferMap vmap); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final; - LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final; + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; static const Op &Get(); - std::unique_ptr Clone() const final { + std::unique_ptr Clone() const override { return std::make_unique(*this); } diff --git a/src/op/region.cc b/src/op/region.cc new file mode 100644 index 000000000..9f899cd34 --- /dev/null +++ b/src/op/region.cc @@ -0,0 +1,55 @@ +/*! + * \file tl/op/region.cc + * \brief Define region operator. 
+ *
+ */
+
+#include "region.h"
+#include
+
+namespace tvm {
+namespace tl {
+using namespace tir;
+
+TIR_REGISTER_TL_OP(RegionOp, region)
+    .set_num_inputs(-1)
+    .set_attr("TCallEffectKind",
+              Integer(CallEffectKind::kPure));
+
+
+RegionOp::RegionOp(Array args, BufferMap vmap) {
+  size_t n = args.size();
+  size_t ndim = n - 2;
+  auto load = args[0].as();
+  ICHECK(load);
+  ICHECK(load->indices.size() == ndim)
+      << "load->indices.size() = " << load->indices << " ndim = " << ndim;
+  buffer_ = load->buffer;
+  access_mask_ = static_cast(*as_const_int(args[1]));
+  for (size_t i = 0; i < ndim; i++) {
+    PrimExpr min = load->indices[i];
+    PrimExpr extent = args[2 + i];
+    ranges_.push_back(Range::FromMinExtent(min, extent));
+  }
+}
+
+bool RegionOp::IsFullRegion() const {
+  for (size_t i = 0; i < ranges_.size(); i++) {
+    if (!is_zero(ranges_[i]->min))
+      return false;
+    if (!StructuralEqual()(ranges_[i]->extent, buffer_->shape[i]))
+      return false;
+  }
+  return true;
+}
+
+Stmt RegionOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
+  return Evaluate(0);
+}
+
+LayoutMap RegionOp::InferLayout(const LayoutInferArgs &T, InferLevel level) const {
+  return {};
+}
+
+} // namespace tl
+} // namespace tvm
diff --git a/src/op/region.h b/src/op/region.h
new file mode 100644
index 000000000..59e86ea91
--- /dev/null
+++ b/src/op/region.h
@@ -0,0 +1,51 @@
+/*!
+ * \file tl/op/region.h
+ * \brief Define region operator.
+ *
+ */
+
+#ifndef TVM_TL_OP_REGION_H_
+#define TVM_TL_OP_REGION_H_
+
+#include
+#include
+#include
+#include
+#include "./operator.h"
+
+namespace tvm {
+namespace tl {
+
+using namespace tir;
+
+class RegionOp : public TileOperator {
+public:
+  RegionOp(Array args, BufferMap vmap);
+  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
+  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override;
+  static const Op &Get();
+
+  std::unique_ptr Clone() const override {
+    return std::make_unique(*this);
+  }
+
+  const Buffer &GetBuffer() const { return buffer_; }
+  const Array &GetRanges() const { return ranges_; }
+  int GetAccessMask() const { return access_mask_; }
+  bool IsFullRegion() const;
+
+private:
+  Buffer buffer_;
+  Array ranges_;
+  int access_mask_;
+};
+
+Var GetVarFromAccessPtr(const PrimExpr &expr);
+
+std::unique_ptr ParseOperator(Call call, BufferMap vmap);
+std::unique_ptr ParseOperator(Stmt stmt, BufferMap vmap);
+
+} // namespace tl
+} // namespace tvm
+
+#endif // TVM_TL_OP_REGION_H_
diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc
index fdbe6b861..138240e83 100644
--- a/src/transform/layout_inference.cc
+++ b/src/transform/layout_inference.cc
@@ -15,6 +15,7 @@
 
 #include "../layout/utils.h"
 #include "../op/parallel.h"
+#include "../op/region.h"
 #include "arith/ir_mutator_with_analyzer.h"
 #include "arith/ir_visitor_with_analyzer.h"
 #include "common/loop_fusion_utils.h"
@@ -112,7 +113,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
           level != InferLevel::kStrict && !strict_layout_map.count(buffer)) {
         // Actually this test has been done in ParallelOp::InferLayout
         // already. Just do it again to avoid missing implementations in other
-        // `Operator`s.
+        // `TileOperator`s.
auto dst_layout = layout.as().value(); auto src_layout = layout_map[buffer].as().value(); ICHECK(dst_layout->InputDim() == src_layout->InputDim()); @@ -253,7 +254,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { ICHECK(infer_list_.size() == thread_var_vec_.size()) << "infer_list_ and thread_var_vec_ size mismatch"; for (int i = 0; i < infer_list_.size(); i++) { - std::unique_ptr base_infer = std::move(infer_list_[i]); + std::unique_ptr base_infer = std::move(infer_list_[i]); auto thread_var = thread_var_vec_[i]; // Check if base_infer is valid @@ -399,7 +400,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { Map buffer_data_to_buffer_; std::vector infer_list_stmt_; - std::vector> infer_list_; + std::vector> infer_list_; std::unordered_map, ObjectPtrHash, ObjectPtrEqual> use_list_; // This is a workaround for cpu backend, @@ -412,8 +413,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { LayoutMap annotated_layout_map_; bool skip_thread_partition_{false}; - std::vector> BackupInferList() { - std::vector> back_infer_list; + std::vector> BackupInferList() { + std::vector> back_infer_list; back_infer_list.reserve(infer_list_.size()); for (auto &&p : infer_list_) { back_infer_list.push_back(p->Clone()); diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 76da0ff61..1e4ab6cdc 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -12,7 +12,7 @@ #include "../layout/layout.h" #include "../layout/utils.h" #include "../op/builtin.h" -#include "../op/op.h" +#include "../op/operator.h" #include "arith/ir_mutator_with_analyzer.h" #include "loop_partition.h" From 2f1c5ea5f86ae73948c7af3e636a55b9a3f04f60 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Tue, 26 Aug 2025 21:30:32 +0800 Subject: [PATCH 2/5] lint fix --- src/op/operator.cc | 4 ++-- src/op/region.cc | 4 ++-- src/op/region.h | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/op/operator.cc b/src/op/operator.cc index 80be1589c..ecdc01ea3 100644 --- a/src/op/operator.cc +++ b/src/op/operator.cc @@ -15,12 +15,12 @@ namespace tl { using namespace tir; - std::unique_ptr ParseOperator(Call call, BufferMap vmap) { auto op_map = Op::GetAttrMap("TLOpBuilder"); Op op = call->op.as().value(); if (op_map.count(op)) { - TileOperator *ptr = static_cast(op_map[op](call->args, vmap)); + TileOperator *ptr = + static_cast(op_map[op](call->args, vmap)); ICHECK(ptr != nullptr); return std::unique_ptr(ptr); } diff --git a/src/op/region.cc b/src/op/region.cc index 9f899cd34..6d61c78b6 100644 --- a/src/op/region.cc +++ b/src/op/region.cc @@ -16,7 +16,6 @@ TIR_REGISTER_TL_OP(RegionOp, region) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); - RegionOp::RegionOp(Array args, BufferMap vmap) { size_t n = args.size(); size_t ndim = n - 2; @@ -47,7 +46,8 @@ Stmt RegionOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return Evaluate(0); } -LayoutMap RegionOp::InferLayout(const LayoutInferArgs &T, InferLevel level) const { +LayoutMap RegionOp::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { return {}; } diff --git a/src/op/region.h b/src/op/region.h index 59e86ea91..9d62b2641 100644 --- a/src/op/region.h +++ b/src/op/region.h @@ -7,11 +7,11 @@ #ifndef TVM_TL_OP_REGION_H_ #define TVM_TL_OP_REGION_H_ +#include "./operator.h" #include #include #include #include -#include "./operator.h" namespace tvm { namespace tl { @@ -22,7 +22,8 @@ class RegionOp : public TileOperator { public: RegionOp(Array args, 
BufferMap vmap); Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; - LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; static const Op &Get(); std::unique_ptr Clone() const override { From 75abbfc221e43ec4ac45d6131273e97943668f29 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 27 Aug 2025 18:33:05 +0800 Subject: [PATCH 3/5] Refactor operator classes to use Node pattern and improve memory management - Updated several operator classes (AtomicAdd, Copy, Gemm, etc.) to utilize the Node pattern for better memory management and encapsulation. - Changed constructors to initialize member variables through a node object, enhancing clarity and reducing direct member access. - Updated Clone methods to return TileOperator instances instead of unique pointers, aligning with the new design. - Refactored InferLayout and Lower methods to ensure consistency across operator implementations. - Adjusted header files to reflect the new class structure and removed deprecated code for a cleaner codebase. --- src/op/atomic_add.cc | 53 ++++++----- src/op/atomic_add.h | 45 ++++----- src/op/copy.cc | 105 ++++++++++++--------- src/op/copy.h | 148 ++++++++++++++---------------- src/op/elem.cc | 60 +++++++----- src/op/elem.h | 28 +++--- src/op/gemm.cc | 64 +++++++------ src/op/gemm.h | 57 +++++++----- src/op/gemm_sp.cc | 47 ++++++---- src/op/gemm_sp.h | 25 ++--- src/op/operator.cc | 15 ++- src/op/operator.h | 69 +++++++------- src/op/parallel.cc | 22 +++-- src/op/parallel.h | 32 +++++-- src/op/reduce.cc | 62 ++++++++----- src/op/reduce.h | 66 +++++++------ src/op/region.cc | 33 ++++--- src/op/region.h | 29 +++--- src/transform/layout_inference.cc | 24 ++--- src/transform/lower_tile_op.cc | 2 +- 20 files changed, 548 insertions(+), 438 deletions(-) diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 079971f98..a2d622a11 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -34,7 +34,8 @@ static int GetArchInt(Target target) { return arch_int; } -AtomicAdd::AtomicAdd(Array args, BufferMap vmap) : args_(args) { +AtomicAdd::AtomicAdd(Array args, BufferMap vmap) { + ObjectPtr node = make_object(); Array rgs[2]; Buffer bf[2]; for (int i = 0; i < 2; i++) { @@ -42,17 +43,23 @@ AtomicAdd::AtomicAdd(Array args, BufferMap vmap) : args_(args) { auto call = expr.as(); ICHECK(call); auto region = RegionOp(call->args, vmap); - rgs[i] = region.GetRanges(); - bf[i] = region.GetBuffer(); + rgs[i] = region->GetRanges(); + bf[i] = region->GetBuffer(); } - std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]); - std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]); + std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); + std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); if (args.size() >= 3) { - coalesced_width = Downcast(args[2]); + node->coalesced_width = Downcast(args[2]); } + data_ = std::move(node); } -Array AtomicAdd::MakeIterVars() const { +TileOperator AtomicAddNode::Clone() const { + auto op = make_object(*this); + return AtomicAdd(op); +} + +Array AtomicAddNode::MakeIterVars() const { Array loop_vars; size_t idx = 0; for (size_t i = 0; i < src_range.size(); i++) { @@ -68,8 +75,8 @@ Array AtomicAdd::MakeIterVars() const { // ivs: itervars returned by MakeIterVars() // src_dst: 0 for src_indices, 1 for dst_indices -Array AtomicAdd::MakeIndices(const Array &ivs, - int src_dst) const { +Array AtomicAddNode::MakeIndices(const 
Array &ivs, + int src_dst) const { Array indices; Array ranges = src_dst == 0 ? src_range : dst_range; size_t idx = 0; @@ -87,9 +94,10 @@ Array AtomicAdd::MakeIndices(const Array &ivs, return indices; } -PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer, - const Array &ivs, - Array extents, int src_dst) const { +PrimExpr AtomicAddNode::MakePredicate(arith::Analyzer *analyzer, + const Array &ivs, + Array extents, + int src_dst) const { Array ranges = src_dst == 0 ? src_range : dst_range; Array cond_list; ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; @@ -117,7 +125,7 @@ PrimExpr AtomicAdd::MakePredicate(arith::Analyzer *analyzer, } } -For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const { +For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { Array loop_vars = MakeIterVars(); bool is_scalar = loop_vars.size() == 0; if (is_scalar) { @@ -180,16 +188,16 @@ For AtomicAdd::MakeSIMTLoop(arith::Analyzer *analyzer) const { return Downcast(body); } -Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { +Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Target target = T.target; auto simt_loop = MakeSIMTLoop(analyzer); auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); - auto par_op = std::make_unique(fused_loop); + auto par_op = ParallelOp(fused_loop); std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, InferLevel::kFree}; for (auto level : levels) { - par_op->InferLayout( + (par_op)->InferLayout( {T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level); } auto loop_layout = par_op->GetLoopLayout(); @@ -210,11 +218,11 @@ Stmt AtomicAdd::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return vectorized_thread_loop; } -LayoutMap AtomicAdd::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { - if (par_op_ == nullptr) { +LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { + if (!par_op_.defined()) { arith::Analyzer analyzer; - par_op_ = std::make_unique(MakeSIMTLoop(&analyzer)); + par_op_ = ParallelOp(MakeSIMTLoop(&analyzer)); } if (T.layout_map.count(src) && T.layout_map.count(dst)) { if (src.scope() == "local.fragment" && dst.scope() == "local.fragment") { @@ -237,10 +245,5 @@ TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -// TVM_REGISTER_OP("tl.atomicadd") -// .set_num_inputs(2) -// .add_argument("ref", "Buffer", "The destination buffer") -// .add_argument("val", "Expr", "The value to be added atomically"); - } // namespace tl } // namespace tvm \ No newline at end of file diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h index 684cd4239..678d62e55 100644 --- a/src/op/atomic_add.h +++ b/src/op/atomic_add.h @@ -15,27 +15,23 @@ namespace tl { using namespace tir; -class AtomicAdd : public TileOperator { +class AtomicAddNode : public TileOperatorNode { public: - AtomicAdd(Array args, BufferMap vmap); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; - LayoutMap InferLayout(const LayoutInferArgs &T, - InferLevel level) const override; + Array args_; - static const Op &Get(); + Buffer src, dst; + Array src_range, dst_range; + IntImm coalesced_width; + + mutable ParallelOp par_op_; + static constexpr const char *_type_key = "tl.AtomicAdd"; + TVM_DECLARE_FINAL_OBJECT_INFO(AtomicAddNode, TileOperatorNode); - AtomicAdd(const AtomicAdd &other) - : args_(other.args_), src(other.src), dst(other.dst), - 
src_range(other.src_range), dst_range(other.dst_range), - coalesced_width(other.coalesced_width) { - // No clone nullptr - if (other.par_op_) - par_op_ = std::unique_ptr( - static_cast(other.par_op_->Clone().release())); - } - std::unique_ptr Clone() const override { - return std::make_unique(*this); - } + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; + + static const Op &Get(); + TileOperator Clone() const; protected: For MakeSIMTLoop(arith::Analyzer *analyzer) const; @@ -47,14 +43,13 @@ class AtomicAdd : public TileOperator { PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array &ivs, Array extents, int src_dst) const; +}; - Array args_; - - Buffer src, dst; - Array src_range, dst_range; - IntImm coalesced_width; - - mutable std::unique_ptr par_op_; +class AtomicAdd : public TileOperator { +public: + TVM_DEFINE_OBJECT_REF_METHODS(AtomicAdd, TileOperator, AtomicAddNode); + TVM_DLL AtomicAdd(Array args, BufferMap vmap); + static const Op &Get(); }; } // namespace tl diff --git a/src/op/copy.cc b/src/op/copy.cc index 6bd04c773..cf6a2778d 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -112,7 +112,8 @@ template static Array ReverseArray(Array array) { * operation. \param vmap BufferMap mapping original buffer names to new buffer * names. */ -Copy::Copy(Array args, BufferMap vmap) : args_(args) { +Copy::Copy(Array args, BufferMap vmap) { + ObjectPtr node = make_object(); Array rgs[2]; Buffer bf[2]; for (int i = 0; i < 2; i++) { @@ -120,23 +121,29 @@ Copy::Copy(Array args, BufferMap vmap) : args_(args) { auto call = expr.as(); ICHECK(call); auto region = RegionOp(call->args, vmap); - rgs[i] = region.GetRanges(); - bf[i] = region.GetBuffer(); + rgs[i] = region->GetRanges(); + bf[i] = region->GetBuffer(); } - std::tie(this->src, this->dst) = std::tie(bf[0], bf[1]); - std::tie(this->src_range, this->dst_range) = std::tie(rgs[0], rgs[1]); + std::tie(node->src, node->dst) = std::tie(bf[0], bf[1]); + std::tie(node->src_range, node->dst_range) = std::tie(rgs[0], rgs[1]); if (args.size() >= 3) { auto coalesced_width = Downcast(args[2]); if (coalesced_width->value > 0) { - this->coalesced_width = coalesced_width; + node->coalesced_width = coalesced_width; } } if (args.size() >= 4) { - this->disable_tma = Downcast(args[3]); + node->disable_tma = Downcast(args[3]); } if (args.size() >= 5) { - this->eviction_policy = args[4].as()->value; + node->eviction_policy = args[4].as()->value; } + data_ = std::move(node); +} + +TileOperator CopyNode::Clone() const { + auto op = make_object(*this); + return Copy(op); } /*! @@ -145,7 +152,7 @@ Copy::Copy(Array args, BufferMap vmap) : args_(args) { * > 1. \return Array of IterVar representing the iterator variables for the * copy operation. */ -Array Copy::MakeIterVars() const { +Array CopyNode::MakeIterVars() const { Array loop_vars; size_t idx = 0; for (size_t i = 0; i < src_range.size(); i++) { @@ -168,8 +175,8 @@ Array Copy::MakeIterVars() const { * dst_indices. \return Array of PrimExpr representing the indices for the copy * operation. */ -Array Copy::MakeIndices(const Array &ivs, - int src_dst) const { +Array CopyNode::MakeIndices(const Array &ivs, + int src_dst) const { Array indices; Array ranges = src_dst == 0 ? src_range : dst_range; size_t idx = 0; @@ -196,9 +203,9 @@ Array Copy::MakeIndices(const Array &ivs, * of the copy operation. \param src_dst 0 for src_indices, 1 for dst_indices. 
* \return PrimExpr representing the predicate for the copy operation. */ -PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer, - const Array &ivs, Array extents, - int src_dst) const { +PrimExpr CopyNode::MakePredicate(arith::Analyzer *analyzer, + const Array &ivs, + Array extents, int src_dst) const { Array ranges = src_dst == 0 ? src_range : dst_range; Array cond_list; ICHECK(extents.size() == ranges.size()) << extents << " " << ranges; @@ -234,7 +241,7 @@ PrimExpr Copy::MakePredicate(arith::Analyzer *analyzer, * simplification. \return For representing the SIMT loop for the copy * operation. */ -For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const { +For CopyNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { Array loop_vars = MakeIterVars(); bool is_scalar = loop_vars.size() == 0; if (is_scalar) { @@ -290,7 +297,7 @@ For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const { * shared tensor. \return Layout representing the linear layout for the TMA * copy. */ -Layout Copy::ComputeLinearLayout(const Buffer &shared_tensor) const { +Layout CopyNode::ComputeLinearLayout(const Buffer &shared_tensor) const { Array input_size = shared_tensor->shape; Array forward_vars; for (size_t i = 0; i < input_size.size(); i++) { @@ -317,7 +324,8 @@ Layout Copy::ComputeLinearLayout(const Buffer &shared_tensor) const { * indicating the level of layout inference. \return LayoutMap containing the * inferred layout. */ -LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) const { +LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { auto target = T.target; using namespace tvm::transform; PassContext pass_ctx = PassContext::Current(); @@ -345,9 +353,9 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) const { // for LDSM/STSM, the layout was deduced from register layout // so we can directly apply the layout of normal copy // Use parallel op to infer the layout - if (!par_op_) { + if (!par_op_.defined()) { arith::Analyzer analyzer; - par_op_ = std::make_unique(MakeSIMTLoop(&analyzer)); + par_op_ = ParallelOp((MakeSIMTLoop(&analyzer))); } return par_op_->InferLayout(T, level); } @@ -360,7 +368,7 @@ LayoutMap Copy::InferLayout(const LayoutInferArgs &T, InferLevel level) const { * same data type. \param target Target device. \return True if the copy * operation is a bulk load, false otherwise. */ -bool Copy::CheckBulkLoad(Target target) const { +bool CopyNode::CheckBulkLoad(Target target) const { // 1. arch must have bulk copy support if (!TargetHasBulkCopy(target)) return false; @@ -388,7 +396,7 @@ bool Copy::CheckBulkLoad(Target target) const { * same data type. \param target Target device. \return True if the copy * operation is a bulk store, false otherwise. */ -bool Copy::CheckBulkStore(Target target) const { +bool CopyNode::CheckBulkStore(Target target) const { // 1. arch must have bulk copy support if (!TargetHasBulkCopy(target)) return false; @@ -416,7 +424,7 @@ bool Copy::CheckBulkStore(Target target) const { * Target device. \return True if the copy operation is a LDSM copy, false * otherwise. */ -bool Copy::CheckLDSMCopy(Target target) const { +bool CopyNode::CheckLDSMCopy(Target target) const { return TargetHasLdmatrix(target) && (src.scope() == "shared.dyn" || src.scope() == "shared") && dst.scope() == "local.fragment"; @@ -430,7 +438,7 @@ bool Copy::CheckLDSMCopy(Target target) const { * Target device. \return True if the copy operation is a STSM copy, false * otherwise. 
*/ -bool Copy::CheckSTSMCopy(Target target) const { +bool CopyNode::CheckSTSMCopy(Target target) const { return TargetHasStmatrix(target) && src.scope() == "local.fragment" && (dst.scope() == "shared.dyn" || dst.scope() == "shared"); } @@ -443,7 +451,7 @@ bool Copy::CheckSTSMCopy(Target target) const { * copy if no specialized instruction is applicable. \param target Target * device. \return CopyInst representing the copy instruction type. */ -Copy::CopyInst Copy::GetCopyInst(Target target, bool disable_tma_lower) const { +CopyInst CopyNode::GetCopyInst(Target target, bool disable_tma_lower) const { // disable_tma_lower is from pass_configs // when tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER is True, // we will not use tma for bulk load/store @@ -472,7 +480,7 @@ Copy::CopyInst Copy::GetCopyInst(Target target, bool disable_tma_lower) const { * \param analyzer Arithmetic analyzer for simplification. * \return Stmt representing the PTX code for the copy operation. */ -Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { +Stmt CopyNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { Target target = T.target; using namespace tvm::transform; PassContext pass_ctx = PassContext::Current(); @@ -503,8 +511,8 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { * map. \param analyzer Arithmetic analyzer for simplification. \return Stmt * representing the normal copy code. */ -Stmt Copy::LowerNormalCopy(const LowerArgs &T, - arith::Analyzer *analyzer) const { +Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, + arith::Analyzer *analyzer) const { bool is_cpu_target = T.target->GetTargetDeviceType() == kDLCPU; auto simt_loop = MakeSIMTLoop(analyzer); auto fused_loop = Downcast(ParallelLoopFuser::Fuse(simt_loop)); @@ -513,7 +521,7 @@ Stmt Copy::LowerNormalCopy(const LowerArgs &T, Downcast(ParallelLoopTransformer::Substitute(fused_loop)); For vectorized_thread_loop; - auto par_op = std::make_unique(transformed_loop); + auto par_op = ParallelOp(transformed_loop); if (is_cpu_target) { vectorized_thread_loop = VectorizeLoop(transformed_loop); @@ -549,8 +557,8 @@ Stmt Copy::LowerNormalCopy(const LowerArgs &T, * \param copy_inst CopyInst representing the copy instruction type. * \return Stmt representing the LDSM/STSM copy code. */ -Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, - CopyInst copy_inst) const { +Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const { ICHECK(copy_inst == CopyInst::kLDSM || copy_inst == CopyInst::kSTSM) << "Invalid copy inst " << static_cast(copy_inst); bool is_ldmatrix = copy_inst == CopyInst::kLDSM; @@ -742,8 +750,8 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer, * copy_inst CopyInst representing the copy instruction type. \return Stmt * representing the bulk copy code. */ -Stmt Copy::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, - CopyInst copy_inst) const { +Stmt CopyNode::LowerBulkCopy(const LowerArgs &T, arith::Analyzer *analyzer, + CopyInst copy_inst) const { ICHECK(copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkStore) << "Invalid copy inst " << static_cast(copy_inst); bool is_load = copy_inst == CopyInst::kBulkLoad; @@ -1040,15 +1048,22 @@ Array TMADesc::EncodeCallArgs() const { * buffer names to new buffer names. 
*/ Conv2DIm2ColOp::Conv2DIm2ColOp(Array args, BufferMap vmap) { - src = vmap[GetVarFromAccessPtr(args[0])]; - dst = vmap[GetVarFromAccessPtr(args[1])]; - nhw_step = args[2]; - c_step = args[3]; - kernel = args[4].as().value()->value; - stride = args[5].as().value()->value; - dilation = args[6].as().value()->value; - padding = args[7].as().value()->value; - eviction_policy = args[8].as().value()->value; + ObjectPtr node = make_object(); + node->src = vmap[GetVarFromAccessPtr(args[0])]; + node->dst = vmap[GetVarFromAccessPtr(args[1])]; + node->nhw_step = args[2]; + node->c_step = args[3]; + node->kernel = args[4].as().value()->value; + node->stride = args[5].as().value()->value; + node->dilation = args[6].as().value()->value; + node->padding = args[7].as().value()->value; + node->eviction_policy = args[8].as().value()->value; + data_ = std::move(node); +} + +TileOperator Conv2DIm2ColOpNode::Clone() const { + auto op = make_object(*this); + return Conv2DIm2ColOp(op); } /*! @@ -1061,8 +1076,8 @@ Conv2DIm2ColOp::Conv2DIm2ColOp(Array args, BufferMap vmap) { * \param analyzer Arithmetic analyzer for simplification. * \return Stmt representing the PTX code for the Conv2DIm2ColOp. */ -Stmt Conv2DIm2ColOp::Lower(const LowerArgs &T, - arith::Analyzer *analyzer) const { +Stmt Conv2DIm2ColOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { ICHECK(TargetIsHopper(T.target)); ICHECK(src.scope() == "global" && (dst.scope() == "shared.dyn" || dst.scope() == "shared")); @@ -1229,8 +1244,8 @@ TIR_REGISTER_TL_OP(Copy, copy) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -LayoutMap Conv2DIm2ColOp::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { +LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { return {}; } diff --git a/src/op/copy.h b/src/op/copy.h index 33581b7d0..2b9f2d855 100644 --- a/src/op/copy.h +++ b/src/op/copy.h @@ -18,6 +18,17 @@ namespace tvm { namespace tl { using namespace tir; +/*! + * \brief Copy instruction type. + */ +enum class CopyInst { + kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy + kLDSM = 1, // ldmatrix memory copy + kSTSM = 2, // stmatrix memory copy + kBulkLoad = 3, // utilize tma load + kBulkStore = 4, // utilize tma store +}; + /*! * \brief Descriptor for Tensor Memory Access (TMA) copy operations. * @@ -83,14 +94,26 @@ struct TMAIm2ColDesc { * block-wise or element-wise data transfer, possibly optimized with * parallelization or TMA hardware acceleration. */ -class Copy : public TileOperator { +class CopyNode : public TileOperatorNode { public: - /*! - * \brief Constructor. - * \param args Expression arguments for the copy. - * \param vmap Buffer variable mapping. - */ - Copy(Array args, BufferMap vmap); + Array args_; // Copy parameters (indices, sizes, etc.) + + Buffer src, dst; // Source and destination buffers + Array src_range, dst_range; // Ranges for each dimension in src and dst + IntImm coalesced_width; // Width (in elements) for coalesced memory access + Bool disable_tma = Bool(false); // Whether to disable TMA acceleration + + mutable ParallelOp par_op_; // Optional associated parallelization operator + + enum class EvictionPolicy { + kEvictNormal = 0, + kEvictFirst = 1, + kEvictLast = 2, + }; + + int eviction_policy; // Policy for cache eviction + static constexpr const char *_type_key = "tl.Copy"; + TVM_DECLARE_FINAL_OBJECT_INFO(CopyNode, TileOperatorNode); /*! * \brief Lower the copy operator to a TIR statement. 
@@ -104,24 +127,7 @@ class Copy : public TileOperator { * \param T Arguments for layout inference. * \param level Level of inference (basic or detailed). */ - LayoutMap InferLayout(const LayoutInferArgs &T, - InferLevel level) const override; - - /*! - * \brief Get the TVM Op handle corresponding to this Copy op. - */ - static const Op &Get(); - - /*! - * \brief Copy instruction type. - */ - enum class CopyInst { - kNormal = 0, // utilize ldg/stg or cpasync or any buffer copy - kLDSM = 1, // ldmatrix memory copy - kSTSM = 2, // stmatrix memory copy - kBulkLoad = 3, // utilize tma load - kBulkStore = 4, // utilize tma store - }; + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; /*! * \brief Check if bulk copy is supported. @@ -148,26 +154,9 @@ class Copy : public TileOperator { */ CopyInst GetCopyInst(Target target, bool disable_tma_lower) const; - /*! - * \brief Copy constructor (deep clones ParallelOp if present). - */ - Copy(const Copy &other) - : args_(other.args_), src(other.src), dst(other.dst), - src_range(other.src_range), dst_range(other.dst_range), - coalesced_width(other.coalesced_width), disable_tma(other.disable_tma) { - // Deep copy ParallelOp if it exists - if (other.par_op_) - par_op_ = std::unique_ptr( - static_cast(other.par_op_->Clone().release())); - } - /*! * \brief Clone this copy operator. */ - std::unique_ptr Clone() const override { - return std::make_unique(*this); - } - protected: /*! * \brief Generate lowering for bulk/global-to-shared copy. @@ -219,23 +208,24 @@ class Copy : public TileOperator { PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array &ivs, Array extents, int src_dst) const; - Array args_; // Copy parameters (indices, sizes, etc.) - - Buffer src, dst; // Source and destination buffers - Array src_range, dst_range; // Ranges for each dimension in src and dst - IntImm coalesced_width; // Width (in elements) for coalesced memory access - Bool disable_tma = Bool(false); // Whether to disable TMA acceleration + TileOperator Clone() const; +}; - mutable std::unique_ptr - par_op_; // Optional associated parallelization operator +class Copy : public TileOperator { +public: + TVM_DEFINE_OBJECT_REF_METHODS(Copy, TileOperator, CopyNode); - enum class EvictionPolicy { - kEvictNormal = 0, - kEvictFirst = 1, - kEvictLast = 2, - }; + /*! + * \brief Constructor. + * \param args Expression arguments for the copy. + * \param vmap Buffer variable mapping. + */ + TVM_DLL Copy(Array args, BufferMap vmap); - int eviction_policy; // Policy for cache eviction + /*! + * \brief Get the TVM Op handle corresponding to this Copy op. + */ + static const Op &Get(); }; /*! @@ -244,14 +234,19 @@ class Copy : public TileOperator { * This operator converts input image layout into columnar format suitable * for matrix multiplication-based convolution lowering. */ -class Conv2DIm2ColOp : public TileOperator { +class Conv2DIm2ColOpNode : public TileOperatorNode { public: - /*! - * \brief Constructor. - * \param args Op arguments (convolution parameters, shapes, etc.) - * \param vmap Variable buffer mapping. 
- */ - Conv2DIm2ColOp(Array args, BufferMap vmap); + Buffer src, dst; // Source (input feature map) and destination (im2col matrix) + int stride; // Stride for convolution + int padding; // Padding amount + int dilation; // Dilation factor + int kernel; // Kernel size + int eviction_policy; // Cache eviction policy + PrimExpr nhw_step; // Step size in NHW dimensions + PrimExpr c_step; // Step size in channel dimension + + static constexpr const char *_type_key = "tl.Conv2DIm2Col"; + TVM_DECLARE_FINAL_OBJECT_INFO(Conv2DIm2ColOpNode, TileOperatorNode); /*! * \brief Lower to TIR statement. @@ -261,30 +256,21 @@ class Conv2DIm2ColOp : public TileOperator { /*! * \brief Infer layout for this operator. */ - LayoutMap InferLayout(const LayoutInferArgs &T, - InferLevel level) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; /*! * \brief Get TVM Op handle. */ static const Op &Get(); + TileOperator Clone() const; +}; - /*! - * \brief Clone this operator. - */ - std::unique_ptr Clone() const override { - return std::make_unique(*this); - } - -private: - Buffer src, dst; // Source (input feature map) and destination (im2col matrix) - int stride; // Stride for convolution - int padding; // Padding amount - int dilation; // Dilation factor - int kernel; // Kernel size - int eviction_policy; // Cache eviction policy - PrimExpr nhw_step; // Step size in NHW dimensions - PrimExpr c_step; // Step size in channel dimension +class Conv2DIm2ColOp : public TileOperator { +public: + TVM_DEFINE_OBJECT_REF_METHODS(Conv2DIm2ColOp, TileOperator, + Conv2DIm2ColOpNode); + TVM_DLL Conv2DIm2ColOp(Array args, BufferMap vmap); + static const Op &Get(); }; } // namespace tl diff --git a/src/op/elem.cc b/src/op/elem.cc index 228f05d24..7aec6c3d8 100644 --- a/src/op/elem.cc +++ b/src/op/elem.cc @@ -23,6 +23,7 @@ namespace tl { using namespace tir; Fill::Fill(Array args, BufferMap vmap) { + ObjectPtr node = make_object(); if (args[0]->IsInstance()) { auto buffer_load = Downcast(args[0]); @@ -33,42 +34,49 @@ Fill::Fill(Array args, BufferMap vmap) { const auto *lanes = ramp->lanes.as(); CHECK(lanes) << "Scalable vectors not supported in BufferRegion conversion"; - region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); + node->region.push_back(Range::FromMinExtent(ramp->base, ramp->lanes)); } else { - region.push_back(Range::FromMinExtent(index, 1)); + node->region.push_back(Range::FromMinExtent(index, 1)); } } - dst = buffer_load->buffer; + node->dst = buffer_load->buffer; } else { - dst = vmap[GetVarFromAccessPtr(args[0])]; - for (int i = 0; i < dst->shape.size(); i++) { - region.push_back(Range(0, dst->shape[i])); + node->dst = vmap[GetVarFromAccessPtr(args[0])]; + for (int i = 0; i < node->dst->shape.size(); i++) { + node->region.push_back(Range(0, node->dst->shape[i])); } } - if (args[1]->dtype != dst->dtype) { - value = Cast(dst->dtype, args[1]); + if (args[1]->dtype != node->dst->dtype) { + node->value = Cast(node->dst->dtype, args[1]); } else { - value = args[1]; + node->value = args[1]; } - ICHECK(region.size() == dst->shape.size()) - << "region size = " << region.size() << " != " << dst->shape.size(); - for (int i = 0; i < region.size(); i++) { + ICHECK(node->region.size() == node->dst->shape.size()) + << "region size = " << node->region.size() + << " != " << node->dst->shape.size(); + for (int i = 0; i < node->region.size(); i++) { // bound check if region is static - if (region[i]->min.as()) { - int64_t min = Downcast(region[i]->min)->value; + if 
(node->region[i]->min.as<IntImm>()) {
+      int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
       ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
     }
-    if (region[i]->extent.as<IntImm>()) {
-      int64_t extent = Downcast<IntImm>(region[i]->extent)->value;
-      ICHECK_LE(extent, Downcast<IntImm>(dst->shape[i])->value)
-          << "region[" << i << "] = " << extent << " > " << dst->shape[i];
+    if (node->region[i]->extent.as<IntImm>()) {
+      int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
+      ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
+          << "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
     }
   }
+  data_ = std::move(node);
 }
 
-For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
+TileOperator FillNode::Clone() const {
+  auto op = make_object<FillNode>(*this);
+  return Fill(op);
+}
+
+For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   int ndim = dst->shape.size();
   Array<IterVar> loop_vars;
   Array<PrimExpr> dst_indices;
@@ -85,10 +93,15 @@ For Fill::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   return Downcast<For>(body);
 }
 
-Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
+Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
+  // get global func "tl.fill.lower"
+  if (const auto f = ffi::Function::GetGlobal("tl.fill.lower")) {
+    auto stmt = (*f)(dst, value);
+    return Downcast<Stmt>(stmt);
+  }
+
   if (dst.scope() == "local.fragment") {
-    auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
+    auto par_op = std::make_unique<ParallelOpNode>(MakeSIMTLoop(analyzer));
     par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
                         InferLevel::kFree);
     par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
@@ -106,7 +119,7 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     auto vectorized_thread_loop = VectorizeLoop(init_loop);
     return vectorized_thread_loop;
   } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared") {
-    auto par_op = std::make_unique<ParallelOp>(MakeSIMTLoop(analyzer));
+    auto par_op = std::make_unique<ParallelOpNode>(MakeSIMTLoop(analyzer));
     par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},
                         InferLevel::kFree);
     auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
@@ -122,7 +135,8 @@ Stmt Fill::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   }
 }
 
-LayoutMap Fill::InferLayout(const LayoutInferArgs &T, InferLevel level) const {
+LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
+                                InferLevel level) const {
   return {};
 }
 
diff --git a/src/op/elem.h b/src/op/elem.h
index fcb16547f..a3efb3f92 100644
--- a/src/op/elem.h
+++ b/src/op/elem.h
@@ -15,23 +15,29 @@ namespace tl {
 
 using namespace tir;
 
-class Fill : public TileOperator {
+class FillNode : public TileOperatorNode {
 public:
-  Fill(Array<PrimExpr> args, BufferMap vmap);
-  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override;
-  LayoutMap InferLayout(const LayoutInferArgs &T,
-                        InferLevel level) const override;
+  tir::Buffer dst;
+  PrimExpr value;
+  Array<Range> region;
+  static constexpr const char *_type_key = "tl.Fill";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FillNode, TileOperatorNode);
+
+  Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
+  LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const;
 
   static const Op &Get();
 
-  std::unique_ptr<TileOperator> Clone() const override {
-    return std::make_unique<Fill>(*this);
-  }
+  TileOperator Clone() const;
 
 private:
   For MakeSIMTLoop(arith::Analyzer *analyzer) const;
-  tir::Buffer dst;
-  PrimExpr value;
-  Array<Range> region;
+};
+
+class Fill : public TileOperator {
+public:
+  TVM_DEFINE_OBJECT_REF_METHODS(Fill, TileOperator, FillNode);
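+  // Note: TVM_DEFINE_OBJECT_REF_METHODS supplies the default constructor,
+  // operator->() and get(), so Fill acts as a nullable reference sharing the
+  // FillNode payload declared above.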
+ TVM_DLL Fill(Array args, BufferMap vmap); + static const Op &Get(); }; } // namespace tl diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 6bfc1b733..c308dc5a1 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -34,35 +34,44 @@ static std::vector toPrimeFactors(int x) { } Gemm::Gemm(Array args, BufferMap vmap) { - Aptr = args[0]; - Bptr = args[1]; - Cptr = args[2]; - A = vmap[GetVarFromAccessPtr(Aptr)]; - B = vmap[GetVarFromAccessPtr(Bptr)]; - C = vmap[GetVarFromAccessPtr(Cptr)]; - trans_A = args[3].as().value(); - trans_B = args[4].as().value(); - M = args[5].as().value()->value; - N = args[6].as().value()->value; - K = args[7].as().value()->value; - policy = static_cast(args[8].as().value()->value); - clear_accum = args[9].as().value(); - stride_A = args[10].as().value()->value; - stride_B = args[11].as().value()->value; - offset_A = args[12].as().value()->value; - offset_B = args[13].as().value()->value; + ObjectPtr node = make_object(); + + node->Aptr = args[0]; + node->Bptr = args[1]; + node->Cptr = args[2]; + node->A = vmap[GetVarFromAccessPtr(node->Aptr)]; + node->B = vmap[GetVarFromAccessPtr(node->Bptr)]; + node->C = vmap[GetVarFromAccessPtr(node->Cptr)]; + node->trans_A = args[3].as().value(); + node->trans_B = args[4].as().value(); + node->M = args[5].as().value()->value; + node->N = args[6].as().value()->value; + node->K = args[7].as().value()->value; + node->policy = + static_cast(args[8].as().value()->value); + node->clear_accum = args[9].as().value(); + node->stride_A = args[10].as().value()->value; + node->stride_B = args[11].as().value()->value; + node->offset_A = args[12].as().value()->value; + node->offset_B = args[13].as().value()->value; if (args.size() > 14) { - kPack = args[14].as().value()->value; - if (kPack != 1 && kPack != 2) { + node->kPack = args[14].as().value()->value; + if (node->kPack != 1 && node->kPack != 2) { ICHECK(false) << "kPack must be 1 or 2"; } } if (args.size() > 15) { - wg_wait = args[15].as().value()->value; + node->wg_wait = args[15].as().value()->value; } + data_ = std::move(node); } -Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const { +TileOperator GemmNode::Clone() const { + auto op = make_object(*this); + return Gemm(op); +} + +GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const { int warp_size = TargetGetWarpSize(target); int num_warps = block_size / warp_size; bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) && @@ -122,9 +131,9 @@ Gemm::GemmInst Gemm::GetGemmInst(int block_size, Target target) const { * divisibility or policy conditions are not met (e.g., M/N tile divisibility, * invalid policy, or WGMMA-specific warp-group requirements). */ -std::pair Gemm::ComputeWarpPartition(int block_size, - GemmInst gemm_inst, - Target target) const { +std::pair GemmNode::ComputeWarpPartition(int block_size, + GemmInst gemm_inst, + Target target) const { int num_warps = block_size / TargetGetWarpSize(target); int m_warp = 1, n_warp = 1; constexpr int kMPerWarp = 16; // Rows processed by a single warp @@ -314,7 +323,7 @@ std::pair Gemm::ComputeWarpPartition(int block_size, * @return true if WGMMA is supported for the current buffers, dtypes, and * transpose/shape constraints; false otherwise. 
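 * For example, a B operand kept in "local.fragment" rather than shared
 * memory fails the scope check at the top of this function, so the GEMM
 * falls back to non-WGMMA lowering.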
*/ -bool Gemm::CheckWGMMA() const { +bool GemmNode::CheckWGMMA() const { if (B.scope() != "shared.dyn" && B.scope() != "shared") { return false; } @@ -379,7 +388,7 @@ static int GetArchInt(Target target) { return arch_int; } -Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { +Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto block_size = *as_const_int(T.thread_bounds->extent); GemmInst gemm_inst = GetGemmInst(block_size, T.target); auto [warp_m, warp_n] = ComputeWarpPartition(block_size, gemm_inst, T.target); @@ -440,7 +449,8 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { * @param level Inference level (unused for side effects but retained for API). * @return LayoutMap mapping each of A, B, and C to their inferred layouts. */ -LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) const { +LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { if (completed_) return {}; LayoutMap results; diff --git a/src/op/gemm.h b/src/op/gemm.h index 3cb3cf0d5..15199b2f3 100644 --- a/src/op/gemm.h +++ b/src/op/gemm.h @@ -14,31 +14,14 @@ namespace tl { using namespace tir; -class Gemm : public TileOperator { -public: - Gemm(Array args, BufferMap vmap); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; - LayoutMap InferLayout(const LayoutInferArgs &T, - InferLevel level) const override; - static const Op &Get(); - enum class GemmWarpPolicy { - kSquare = 0, - kFullRow = 1, - kFullCol = 2, - } policy; - - std::unique_ptr Clone() const override { - return std::make_unique(*this); - } - -private: - // Target GEMM instruction - enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA }; - GemmInst GetGemmInst(int block_size, Target target) const; - - std::pair ComputeWarpPartition(int num_warps, GemmInst gemm_inst, - Target target) const; +enum class GemmWarpPolicy { + kSquare = 0, + kFullRow = 1, + kFullCol = 2, +}; +class GemmNode : public TileOperatorNode { +public: bool CheckWGMMA() const; Array call_args; tir::Buffer A, B, C; @@ -53,9 +36,35 @@ class Gemm : public TileOperator { // only will be enabled under cdna mfma instructions int kPack = 1; int wg_wait = 0; + GemmWarpPolicy policy; + + static constexpr const char *_type_key = "tl.Gemm"; + TVM_DECLARE_FINAL_OBJECT_INFO(GemmNode, TileOperatorNode); + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + + TileOperator Clone() const; + +private: + // Target GEMM instruction + enum class GemmInst { kMMA, kWGMMA, kUTCMMA, kMFMA }; + GemmInst GetGemmInst(int block_size, Target target) const; + + std::pair ComputeWarpPartition(int num_warps, GemmInst gemm_inst, + Target target) const; + mutable bool completed_ = false; }; +class Gemm : public TileOperator { +public: + TVM_DEFINE_OBJECT_REF_METHODS(Gemm, TileOperator, GemmNode); + TVM_DLL Gemm(Array args, BufferMap vmap); + static const Op &Get(); +}; + } // namespace tl } // namespace tvm diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc index b642a8cbe..2b4b1c064 100644 --- a/src/op/gemm_sp.cc +++ b/src/op/gemm_sp.cc @@ -32,31 +32,39 @@ static std::vector toPrimeFactors(int x) { } GemmSP::GemmSP(Array args, BufferMap vmap) { - A = vmap[GetVarFromAccessPtr(args[0])]; - E = vmap[GetVarFromAccessPtr(args[1])]; - B = vmap[GetVarFromAccessPtr(args[2])]; - C = vmap[GetVarFromAccessPtr(args[3])]; - trans_A = args[4].as().value(); - trans_B = 
args[5].as().value(); - M = args[6].as().value()->value; - N = args[7].as().value()->value; - K = args[8].as().value()->value; - policy = static_cast(args[9].as().value()->value); - clear_accum = args[10].as().value(); + ObjectPtr node = make_object(); + node->A = vmap[GetVarFromAccessPtr(args[0])]; + node->E = vmap[GetVarFromAccessPtr(args[1])]; + node->B = vmap[GetVarFromAccessPtr(args[2])]; + node->C = vmap[GetVarFromAccessPtr(args[3])]; + node->trans_A = args[4].as().value(); + node->trans_B = args[5].as().value(); + node->M = args[6].as().value()->value; + node->N = args[7].as().value()->value; + node->K = args[8].as().value()->value; + node->policy = static_cast( + args[9].as().value()->value); + node->clear_accum = args[10].as().value(); if (args.size() > 11) { - kPack = args[11].as().value()->value; - if (kPack != 1 && kPack != 2) { + node->kPack = args[11].as().value()->value; + if (node->kPack != 1 && node->kPack != 2) { ICHECK(false) << "kPack must be 1 or 2"; } } if (args.size() > 12) { - wg_wait = args[12].as().value()->value; + node->wg_wait = args[12].as().value()->value; } + data_ = std::move(node); +} + +TileOperator GemmSPNode::Clone() const { + auto op = make_object(*this); + return GemmSP(op); } std::pair -GemmSP::ComputeWarpPartition(int num_warps, Target target, - bool maybe_hopper_wgmma) const { +GemmSPNode::ComputeWarpPartition(int num_warps, Target target, + bool maybe_hopper_wgmma) const { int m_warp = 1, n_warp = 1; constexpr int kMPerWarp = 16; // Rows processed by a single warp constexpr int kNPerWarp = 8; // Columns processed by a single warp @@ -212,7 +220,7 @@ GemmSP::ComputeWarpPartition(int num_warps, Target target, return {m_warp, n_warp}; } -Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { +Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { int warp_size = 32; auto block_size = *as_const_int(T.thread_bounds->extent); @@ -256,8 +264,8 @@ Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return Evaluate(new_call); } -LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { +LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { if (completed_) return {}; LayoutMap results; @@ -309,6 +317,7 @@ LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, completed_ = true; return results; } + TIR_REGISTER_TL_OP(GemmSP, gemm_sp) .set_num_inputs(5) .set_attr("TCallEffectKind", diff --git a/src/op/gemm_sp.h b/src/op/gemm_sp.h index 9a14f17e9..e645d0d42 100644 --- a/src/op/gemm_sp.h +++ b/src/op/gemm_sp.h @@ -14,24 +14,16 @@ namespace tl { using namespace tir; -class GemmSP : public TileOperator { +class GemmSPNode : public TileOperatorNode { public: - GemmSP(Array args, BufferMap vmap); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; - LayoutMap InferLayout(const LayoutInferArgs &T, - InferLevel level) const override; - static const Op &Get(); + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const; + LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const; enum class GemmWarpPolicy { kSquare = 0, kFullRow = 1, kFullCol = 2, } policy; - std::unique_ptr Clone() const override { - return std::make_unique(*this); - } - -private: std::pair ComputeWarpPartition(int num_warps, Target target, bool maybe_hopper_wgmma = true) const; @@ -45,9 +37,20 @@ class GemmSP : public TileOperator { // only will be enabled under cdna mfma instructions int kPack = 1; int wg_wait = 0; + + TileOperator 
Clone() const;
+
+private:
   mutable bool completed_ = false;
 };
 
+class GemmSP : public TileOperator {
+public:
+  TVM_DEFINE_OBJECT_REF_METHODS(GemmSP, TileOperator, GemmSPNode);
+  TVM_DLL GemmSP(Array<PrimExpr> args, BufferMap vmap);
+  static const Op &Get();
+};
+
 } // namespace tl
 } // namespace tvm
diff --git a/src/op/operator.cc b/src/op/operator.cc
index ecdc01ea3..ffc7cdefc 100644
--- a/src/op/operator.cc
+++ b/src/op/operator.cc
@@ -15,24 +15,23 @@ namespace tl {
 
 using namespace tir;
 
-std::unique_ptr<TileOperator> ParseOperator(Call call, BufferMap vmap) {
+TileOperator ParseOperator(Call call, BufferMap vmap) {
   auto op_map = Op::GetAttrMap<OpBuilderFunc>("TLOpBuilder");
   Op op = call->op.as<Op>().value();
   if (op_map.count(op)) {
-    TileOperator *ptr =
-        static_cast<TileOperator *>(op_map[op](call->args, vmap));
-    ICHECK(ptr != nullptr);
-    return std::unique_ptr<TileOperator>(ptr);
+    auto tile_op = op_map[op](call->args, vmap);
+    ICHECK(tile_op.defined());
+    return tile_op;
   }
-  return nullptr;
+  return TileOperator();
 }
 
-std::unique_ptr<TileOperator> ParseOperator(Stmt stmt, BufferMap vmap) {
+TileOperator ParseOperator(Stmt stmt, BufferMap vmap) {
   if (stmt.as<EvaluateNode>() &&
       stmt.as<EvaluateNode>()->value.as<CallNode>()) {
     auto call = stmt.as<EvaluateNode>()->value.as<CallNode>();
     return ParseOperator(GetRef<Call>(call), vmap);
   }
-  return nullptr;
+  return TileOperator();
 }
 
 Var GetVarFromAccessPtr(const PrimExpr &expr) {
diff --git a/src/op/operator.h b/src/op/operator.h
index 306b829fc..91f7c6bec 100644
--- a/src/op/operator.h
+++ b/src/op/operator.h
@@ -24,19 +24,6 @@ using namespace tir;
 
 using AddWorkspaceCallback = std::function<PrimExpr(int, DataType)>;
 using LayoutMap = Map<Buffer, Layout>;
 using BufferMap = Map<Var, Buffer>;
-using OpBuilderFunc = ffi::TypedFunction<void *(Array<PrimExpr>, BufferMap)>;
-
-#define TIR_REGISTER_TL_OP(Entry, OpName) \
-  const Op &Entry::Get() { \
-    static const Op &op = Op::Get("tl." #OpName); \
-    return op; \
-  } \
-  TVM_REGISTER_OP("tl." #OpName) \
-      .set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
-      .set_attr<OpBuilderFunc>("TLOpBuilder", \
-                               [](Array<PrimExpr> a, BufferMap b) { \
-                                 return (void *)(new Entry(a, b)); \
-                               })
 
 enum class InferLevel {
   kFree = 0,
@@ -60,30 +47,48 @@ struct LayoutInferArgs {
   Map<Buffer, Buffer> buffer_remap;
 };
 
-class TileOperator {
+class TileOperatorNode;
+class TileOperator;
+
+class TileOperatorNode : public Object {
 public:
-  // Lower interface
-  virtual Stmt Lower(const LowerArgs& T, arith::Analyzer* analyzer) const {
-    ICHECK(0) << "Not Implemented Lower method.";
-    return Evaluate(0);
-  }
-
-  // InferLayout interface
-  virtual LayoutMap InferLayout(const LayoutInferArgs& T, InferLevel level) const {
-    return {};
-  }
-
-  // Clone interface
-  virtual std::unique_ptr<TileOperator> Clone() const = 0;
-
-  // Virtual destructor
-  virtual ~TileOperator() = default;
+  virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const = 0;
+
+  virtual LayoutMap InferLayout(const LayoutInferArgs &T,
+                                InferLevel level) const = 0;
+
+  virtual TileOperator Clone() const = 0;
+
+  static constexpr const char *_type_key = "tl.TileOperator";
+
+  TVM_DECLARE_BASE_OBJECT_INFO(TileOperatorNode, Object);
 };
 
+class TileOperator : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TileOperator, ObjectRef, TileOperatorNode);
+};
+
 
 Var GetVarFromAccessPtr(const PrimExpr &expr);
 
-std::unique_ptr<TileOperator> ParseOperator(Call call, BufferMap vmap);
-std::unique_ptr<TileOperator> ParseOperator(Stmt stmt, BufferMap vmap);
+TileOperator ParseOperator(Call call, BufferMap vmap);
+TileOperator ParseOperator(Stmt stmt, BufferMap vmap);
+
+using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>, BufferMap)>;
+
+#define TIR_REGISTER_TL_OP(Entry, OpName) \
+  const Op &Entry::Get() { \
+    static const Op &op = Op::Get("tl." #OpName); \
+    return op; \
+  } \
+  TVM_REGISTER_OP("tl." \
#OpName) \ + .set_attr("TScriptPrinterName", #OpName) \ + .set_attr("TLOpBuilder", \ + [](Array args, BufferMap vmap) { \ + return Entry(args, vmap); \ + }) + } // namespace tl } // namespace tvm diff --git a/src/op/parallel.cc b/src/op/parallel.cc index a2d622f58..2c347c34f 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -154,13 +154,21 @@ void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) { StmtExprVisitor::VisitExpr_(op); } -ParallelOp::ParallelOp(For root) : root_(root), V(this) { V.VisitStmt(root); } +ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) { + V.VisitStmt(root); +} + +TileOperator ParallelOpNode::Clone() const { + auto op = make_object(*this); + return ParallelOp(op); +} -Stmt ParallelOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { +Stmt ParallelOpNode::Lower(const LowerArgs &T, + arith::Analyzer *analyzer) const { return root_; } -bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const { +bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const { auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; }); return StructuralEqual()(indice_map_[buffer], common_indice); } @@ -183,8 +191,8 @@ bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const { * Can generate new layouts based on vectorization and thread * bounds. Used when maximum performance optimization is desired. */ -LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { +LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { if (loop_layout_.defined()) return {}; if (level == InferLevel::kStrict) @@ -360,7 +368,7 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, return results; } -Optional ParallelOp::GetPredicate(Var thread_var) const { +Optional ParallelOpNode::GetPredicate(Var thread_var) const { if (predicate_.defined()) { return Substitute(predicate_.value(), {{InputPlaceholder(0), thread_var}}); } else { @@ -368,7 +376,7 @@ Optional ParallelOp::GetPredicate(Var thread_var) const { } } -Fragment ParallelOp::CompleteBufferFragment(const Buffer &buffer) const { +Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const { ICHECK(loop_layout_.defined()); if (IsCommonAccessIndice(buffer)) { return loop_layout_; diff --git a/src/op/parallel.h b/src/op/parallel.h index 5a80b8b7a..addbd49d8 100644 --- a/src/op/parallel.h +++ b/src/op/parallel.h @@ -31,40 +31,42 @@ bool ProveFragmentContains(Fragment small_frag, Fragment large_frag, Array large_frag_indices, arith::Analyzer &analyzer_); -class ParallelOp; +class ParallelOpNode; class ParallelLoopNestVisitor : public StmtExprVisitor { private: - ParallelLoopNestVisitor(ParallelOp *op) : p(op){}; + ParallelLoopNestVisitor(ParallelOpNode *op) : p(op){}; void VisitStmt_(const ForNode *op) override; void VisitStmt_(const BufferStoreNode *op) override; void VisitExpr_(const BufferLoadNode *op) override; - ParallelOp *p; + ParallelOpNode *p; - friend class ParallelOp; + friend class ParallelOpNode; }; -class ParallelOp : public TileOperator { +class ParallelOpNode : public TileOperatorNode { public: - ParallelOp(For root); + static constexpr const char *_type_key = "tl.ParallelOp"; + TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode); + + ParallelOpNode(For root); Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; - ParallelOp(const ParallelOp &other) : 
ParallelOp(other.root_) { + ParallelOpNode(const ParallelOpNode &other) : ParallelOpNode(other.root_) { loop_layout_ = other.loop_layout_; predicate_ = other.predicate_; } - std::unique_ptr Clone() const override { - return std::make_unique(*this); - } Fragment GetLoopLayout() const { return loop_layout_; } For GetRoot() const { return root_; } Map> GetIndiceMap() const { return indice_map_; } Optional GetPredicate(Var thread_var) const; + TileOperator Clone() const; + private: Fragment CompleteBufferFragment(const Buffer &buffer) const; bool IsCommonAccessIndice(const Buffer &buffer) const; @@ -87,6 +89,16 @@ class ParallelOp : public TileOperator { friend class ParallelLoopNestVisitor; }; +class ParallelOp : public TileOperator { +public: + TVM_DEFINE_OBJECT_REF_METHODS(ParallelOp, TileOperator, ParallelOpNode); + + ParallelOp(For root) { + auto op = make_object(root); + data_ = std::move(op); + } +}; + } // namespace tl } // namespace tvm diff --git a/src/op/reduce.cc b/src/op/reduce.cc index a8fa09e5f..4fcf6c686 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -22,26 +22,38 @@ namespace tl { using namespace tir; ReduceOp::ReduceOp(Array args, BufferMap vmap) { - src = vmap[GetVarFromAccessPtr(args[0])]; - dst = vmap[GetVarFromAccessPtr(args[1])]; - String reduce_type = args[2].as().value()->value; - dim = args[3].as().value()->value; + ObjectPtr node = make_object(); + node->src = vmap[GetVarFromAccessPtr(args[0])]; + node->dst = vmap[GetVarFromAccessPtr(args[1])]; + std::string reduce_type = args[2].as().value()->value; + node->dim = args[3].as().value()->value; if (reduce_type == "sum") - type = ReduceType::kSum; + node->type = ReduceType::kSum; else if (reduce_type == "abssum") - type = ReduceType::kAbsSum; + node->type = ReduceType::kAbsSum; else if (reduce_type == "absmax") - type = ReduceType::kAbsMax; + node->type = ReduceType::kAbsMax; else if (reduce_type == "max") - type = ReduceType::kMax; + node->type = ReduceType::kMax; else if (reduce_type == "min") - type = ReduceType::kMin; + node->type = ReduceType::kMin; else ICHECK(0) << "Unknown reduce type: " << reduce_type; - clear = args[4].as().value(); + node->clear = args[4].as().value(); + data_ = std::move(node); } -PrimExpr ReduceOp::MakeInitValue() const { +TileOperator ReduceOpNode::Clone() const { + auto op = make_object(*this); + return ReduceOp(op); +} + +TileOperator CumSumOpNode::Clone() const { + auto op = make_object(*this); + return CumSumOp(op); +} + +PrimExpr ReduceOpNode::MakeInitValue() const { auto dst_dtype = dst->dtype; auto is_int = dst_dtype.is_int(); bool is_uint = dst_dtype.is_uint(); @@ -75,7 +87,7 @@ PrimExpr ReduceOp::MakeInitValue() const { } } -PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const { +PrimExpr ReduceOpNode::MakeReduce(const PrimExpr &a, const PrimExpr &b) const { PrimExpr lhs = a, rhs = b; if (lhs->dtype != rhs->dtype) { rhs = Cast(lhs->dtype, rhs); @@ -97,7 +109,7 @@ PrimExpr ReduceOp::MakeReduce(const PrimExpr &a, const PrimExpr &b) const { } } -std::string ReduceOp::MakeCodegenReducer() const { +std::string ReduceOpNode::MakeCodegenReducer() const { switch (type) { case ReduceType::kSum: return "tl::SumOp"; @@ -115,7 +127,7 @@ std::string ReduceOp::MakeCodegenReducer() const { } } -Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { +Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ICHECK(this->src.scope() == "local.fragment" && this->dst.scope() == "local.fragment") << "Reduce for shared memory not 
implemented."; @@ -284,8 +296,8 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return body; } -LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { +LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { if (level >= InferLevel::kStrict) return {}; if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" && @@ -370,14 +382,16 @@ CumSumOp::CumSumOp(Array args, BufferMap vmap) { reverse: whether to cumsum in reverse order */ CHECK_EQ(args.size(), 4); - src = vmap[GetVarFromAccessPtr(args[0])]; - dst = vmap[GetVarFromAccessPtr(args[1])]; - dim = args[2].as().value()->value; - reverse = args[3].as().value(); - CHECK_LT(dim, static_cast(src->shape.size())); + ObjectPtr node = make_object(); + node->src = vmap[GetVarFromAccessPtr(args[0])]; + node->dst = vmap[GetVarFromAccessPtr(args[1])]; + node->dim = args[2].as().value()->value; + node->reverse = args[3].as().value(); + CHECK_LT(node->dim, static_cast(node->src->shape.size())); + data_ = std::move(node); } -Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { +Stmt CumSumOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { if (this->src.scope() == "local.fragment" && this->dst.scope() == "local.fragment") { LOG(FATAL) << "CumSum for fragment not implemented, please raise an issue " @@ -403,8 +417,8 @@ Stmt CumSumOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return Stmt(); } -LayoutMap CumSumOp::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { +LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { return {}; } diff --git a/src/op/reduce.h b/src/op/reduce.h index 5303cb147..2be74cf09 100644 --- a/src/op/reduce.h +++ b/src/op/reduce.h @@ -14,51 +14,63 @@ namespace tl { using namespace tir; -class ReduceOp : public TileOperator { +enum class ReduceType { + kSum, + kAbsSum, + kMax, + kMin, + kAbsMax, +}; + +class ReduceOpNode : public TileOperatorNode { public: - ReduceOp(Array args, BufferMap vmap); + tir::Buffer src, dst; + int dim; + ReduceType type; + bool clear; + + static constexpr const char *_type_key = "tl.ReduceOp"; + TVM_DECLARE_FINAL_OBJECT_INFO(ReduceOpNode, TileOperatorNode); + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; static const Op &Get(); - - std::unique_ptr Clone() const override { - return std::make_unique(*this); - } + TileOperator Clone() const; private: - tir::Buffer src, dst; - int dim; - enum class ReduceType { - kSum, - kAbsSum, - kMax, - kMin, - kAbsMax, - } type; - bool clear; - PrimExpr MakeInitValue() const; PrimExpr MakeReduce(const PrimExpr &a, const PrimExpr &b) const; std::string MakeCodegenReducer() const; }; -class CumSumOp : public TileOperator { +class ReduceOp : public TileOperator { public: - CumSumOp(Array args, BufferMap vmap); - Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; - LayoutMap InferLayout(const LayoutInferArgs &T, - InferLevel level) const override; + TVM_DEFINE_OBJECT_REF_METHODS(ReduceOp, TileOperator, ReduceOpNode); + TVM_DLL ReduceOp(Array args, BufferMap vmap); static const Op &Get(); +}; - std::unique_ptr Clone() const override { - return std::make_unique(*this); - } - -private: +class CumSumOpNode : public TileOperatorNode { +public: tir::Buffer src, dst; int dim; bool reverse; + static constexpr const char *_type_key = 
"tl.CumSumOp"; + TVM_DECLARE_FINAL_OBJECT_INFO(CumSumOpNode, TileOperatorNode); + + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + LayoutMap InferLayout(const LayoutInferArgs &T, + InferLevel level) const override; + static const Op &Get(); + TileOperator Clone() const; +}; + +class CumSumOp : public TileOperator { +public: + TVM_DEFINE_OBJECT_REF_METHODS(CumSumOp, TileOperator, CumSumOpNode); + TVM_DLL CumSumOp(Array args, BufferMap vmap); + static const Op &Get(); }; } // namespace tl diff --git a/src/op/region.cc b/src/op/region.cc index 6d61c78b6..0b74ab00f 100644 --- a/src/op/region.cc +++ b/src/op/region.cc @@ -11,11 +11,6 @@ namespace tvm { namespace tl { using namespace tir; -TIR_REGISTER_TL_OP(RegionOp, region) - .set_num_inputs(-1) - .set_attr("TCallEffectKind", - Integer(CallEffectKind::kPure)); - RegionOp::RegionOp(Array args, BufferMap vmap) { size_t n = args.size(); size_t ndim = n - 2; @@ -23,16 +18,25 @@ RegionOp::RegionOp(Array args, BufferMap vmap) { ICHECK(load); ICHECK(load->indices.size() == ndim) << "load->indices.size() = " << load->indices << " ndim = " << ndim; - buffer_ = load->buffer; - access_mask_ = static_cast(*as_const_int(args[1])); + Array ranges; for (size_t i = 0; i < ndim; i++) { PrimExpr min = load->indices[i]; PrimExpr extent = args[2 + i]; - ranges_.push_back(Range::FromMinExtent(min, extent)); + ranges.push_back(Range::FromMinExtent(min, extent)); } + ObjectPtr node = make_object(); + node->buffer_ = load->buffer; + node->access_mask_ = static_cast(*as_const_int(args[1])); + node->ranges_ = ranges; + data_ = std::move(node); +} + +TileOperator RegionOpNode::Clone() const { + auto op = make_object(*this); + return RegionOp(op); } -bool RegionOp::IsFullRegion() const { +bool RegionOpNode::IsFullRegion() const { for (size_t i = 0; i < ranges_.size(); i++) { if (!is_zero(ranges_[i]->min)) return false; @@ -42,14 +46,19 @@ bool RegionOp::IsFullRegion() const { return true; } -Stmt RegionOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { +Stmt RegionOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return Evaluate(0); } -LayoutMap RegionOp::InferLayout(const LayoutInferArgs &T, - InferLevel level) const { +LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T, + InferLevel level) const { return {}; } +TIR_REGISTER_TL_OP(RegionOp, region) + .set_num_inputs(-1) + .set_attr("TCallEffectKind", + Integer(CallEffectKind::kPure)); + } // namespace tl } // namespace tvm diff --git a/src/op/region.h b/src/op/region.h index 9d62b2641..1d56ea47b 100644 --- a/src/op/region.h +++ b/src/op/region.h @@ -18,33 +18,34 @@ namespace tl { using namespace tir; -class RegionOp : public TileOperator { +class RegionOpNode : public TileOperatorNode { public: - RegionOp(Array args, BufferMap vmap); + Buffer buffer_; + Array ranges_; + int access_mask_; + + static constexpr const char *_type_key = "tl.RegionOp"; + TVM_DECLARE_FINAL_OBJECT_INFO(RegionOpNode, TileOperatorNode); + Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; - static const Op &Get(); - - std::unique_ptr Clone() const override { - return std::make_unique(*this); - } const Buffer &GetBuffer() const { return buffer_; } const Array &GetRanges() const { return ranges_; } int GetAccessMask() const { return access_mask_; } bool IsFullRegion() const; -private: - Buffer buffer_; - Array ranges_; - int access_mask_; + TileOperator Clone() const; }; 
-Var GetVarFromAccessPtr(const PrimExpr &expr); +class RegionOp : public TileOperator { +public: + TVM_DEFINE_OBJECT_REF_METHODS(RegionOp, TileOperator, RegionOpNode); + TVM_DLL RegionOp(Array args, BufferMap vmap); -std::unique_ptr ParseOperator(Call call, BufferMap vmap); -std::unique_ptr ParseOperator(Stmt stmt, BufferMap vmap); + static const Op &Get(); +}; } // namespace tl } // namespace tvm diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 138240e83..4e5dfa40b 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -80,8 +80,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { auto iter_var = thread_var_vec_[cur_infer_id]; auto thread_bounds = thread_bounds_vec_[cur_infer_id]; // Double-check that 'next' is valid - ICHECK(next != nullptr) - << "infer_list_[" << cur_infer_id << "] is null inside run_infer_step."; + ICHECK(next.defined()) << "infer_list_[" << cur_infer_id + << "] is null inside run_infer_step."; // Check iter_var->dom and dom->extent ICHECK(iter_var.defined()) @@ -211,7 +211,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { std::vector in_queue(num_infer, true); for (int i = 0; i < num_infer; i++) { // Check that each infer_list_ entry is valid - ICHECK(infer_list_[i] != nullptr) + ICHECK(infer_list_[i].defined()) << "infer_list_[" << i << "] is null. The inference object is not allocated properly."; @@ -254,13 +254,13 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { ICHECK(infer_list_.size() == thread_var_vec_.size()) << "infer_list_ and thread_var_vec_ size mismatch"; for (int i = 0; i < infer_list_.size(); i++) { - std::unique_ptr base_infer = std::move(infer_list_[i]); + TileOperator base_infer = std::move(infer_list_[i]); auto thread_var = thread_var_vec_[i]; // Check if base_infer is valid - ICHECK(base_infer != nullptr) << "Null pointer encountered in " - "infer_list_ while collecting for_map."; - if (auto for_infer = dynamic_cast(base_infer.get())) { + ICHECK(base_infer.defined()) << "Null pointer encountered in " + "infer_list_ while collecting for_map."; + if (auto for_infer = base_infer.as()) { // Check that the loop layout is defined ICHECK(for_infer->GetLoopLayout().defined()) << "The Layout for Parallel for cannot be inferred correctly:\n" @@ -298,7 +298,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { return; auto p = ParseOperator(GetRef(op), buffer_data_to_buffer_); - if (p != nullptr) { + if (p.defined()) { for (const auto &arg : op->args) { if (auto buffer = getBufferFromAccessPtr(arg)) { addToUseList(buffer.value()); @@ -345,7 +345,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { void VisitStmt_(const ForNode *op) final { if (op->kind == ForKind::kParallel) { - auto infer = std::make_unique(GetRef(op)); + auto infer = ParallelOp(GetRef(op)); for (const auto &[buffer, _] : infer->GetIndiceMap()) { addToUseList(buffer); } @@ -400,7 +400,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { Map buffer_data_to_buffer_; std::vector infer_list_stmt_; - std::vector> infer_list_; + std::vector infer_list_; std::unordered_map, ObjectPtrHash, ObjectPtrEqual> use_list_; // This is a workaround for cpu backend, @@ -413,8 +413,8 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { LayoutMap annotated_layout_map_; bool skip_thread_partition_{false}; - std::vector> BackupInferList() { - std::vector> back_infer_list; + std::vector BackupInferList() { + std::vector back_infer_list; 
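    // Clone() deep-copies each operator so that a failed speculative
    // layout-inference attempt below can restore this snapshot unchanged.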
back_infer_list.reserve(infer_list_.size()); for (auto &&p : infer_list_) { back_infer_list.push_back(p->Clone()); diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 1e4ab6cdc..4f8437fba 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -421,7 +421,7 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { return Downcast(IRMutatorWithAnalyzer::VisitStmt_(op)); auto tile_op = ParseOperator(GetRef(op), buffer_data_to_buffer_); - if (tile_op == nullptr) + if (!tile_op.defined()) return IRMutatorWithAnalyzer::VisitStmt_(op); AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) { auto workspace = From 67e124140567829ebc044ebdf16b5d7991ef4089 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 29 Aug 2025 22:14:53 +0800 Subject: [PATCH 4/5] Enhance Clone methods in AtomicAdd and Copy classes to support parallel operation cloning - Updated the Clone methods in AtomicAddNode and CopyNode to ensure that the parallel operation (par_op_) is properly cloned when defined, improving the integrity of cloned objects. - Refactored the FillNode class to use ParallelOp directly instead of std::make_unique, streamlining the creation of parallel operations. - Made minor adjustments in layout inference and other related methods for consistency and clarity. --- src/op/atomic_add.cc | 3 +++ src/op/copy.cc | 5 ++-- src/op/elem.cc | 4 ++-- src/op/parallel.h | 39 +++++++++++++++++++++++++------ src/transform/layout_inference.cc | 7 +++++- 5 files changed, 46 insertions(+), 12 deletions(-) diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index a2d622a11..acc54e9e0 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -56,6 +56,9 @@ AtomicAdd::AtomicAdd(Array args, BufferMap vmap) { TileOperator AtomicAddNode::Clone() const { auto op = make_object(*this); + if (par_op_.defined()) { + op->par_op_ = Downcast(par_op_->Clone()); + } return AtomicAdd(op); } diff --git a/src/op/copy.cc b/src/op/copy.cc index cf6a2778d..f2d7cf3f0 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -143,6 +143,9 @@ Copy::Copy(Array args, BufferMap vmap) { TileOperator CopyNode::Clone() const { auto op = make_object(*this); + if (par_op_.defined()) { + op->par_op_ = Downcast(par_op_->Clone()); + } return Copy(op); } @@ -349,7 +352,6 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, return Map({{shared_tensor, linear_layout}}); } } - // for LDSM/STSM, the layout was deduced from register layout // so we can directly apply the layout of normal copy // Use parallel op to infer the layout @@ -359,7 +361,6 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T, } return par_op_->InferLayout(T, level); } - /*! * \brief Check if the copy operation is a bulk load. 
* This function verifies if the copy operation can be implemented using CUDA's diff --git a/src/op/elem.cc b/src/op/elem.cc index 7aec6c3d8..ccc92595f 100644 --- a/src/op/elem.cc +++ b/src/op/elem.cc @@ -101,7 +101,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } if (dst.scope() == "local.fragment") { - auto par_op = std::make_unique(MakeSIMTLoop(analyzer)); + auto par_op = ParallelOp(MakeSIMTLoop(analyzer)); par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, InferLevel::kFree); par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, @@ -119,7 +119,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { auto vectorized_thread_loop = VectorizeLoop(init_loop); return vectorized_thread_loop; } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared") { - auto par_op = std::make_unique(MakeSIMTLoop(analyzer)); + auto par_op = ParallelOp(MakeSIMTLoop(analyzer)); par_op->InferLayout({T.target, T.thread_bounds, T.layout_map}, InferLevel::kFree); auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, diff --git a/src/op/parallel.h b/src/op/parallel.h index addbd49d8..fe514b43d 100644 --- a/src/op/parallel.h +++ b/src/op/parallel.h @@ -45,48 +45,73 @@ class ParallelLoopNestVisitor : public StmtExprVisitor { friend class ParallelOpNode; }; +// ParallelOpNode represents a parallel for loop operator in TileLang. +// It is responsible for inferring layouts, holding loop structure, and managing +// predicates. class ParallelOpNode : public TileOperatorNode { public: + // The inferred layout for the loop, mutable to allow lazy inference. + mutable Fragment loop_layout_; + // The predicate expression for the loop, if any, mutable for lazy + // construction. + mutable Optional predicate_; + + // Type key for TVM object system. static constexpr const char *_type_key = "tl.ParallelOp"; TVM_DECLARE_FINAL_OBJECT_INFO(ParallelOpNode, TileOperatorNode); + // Construct from a root For loop. ParallelOpNode(For root); + + // Lower the operator to a TIR statement. Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const override; + + // Infer the layout for this parallel operator. LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) const override; + // Copy constructor for ParallelOpNode. ParallelOpNode(const ParallelOpNode &other) : ParallelOpNode(other.root_) { loop_layout_ = other.loop_layout_; predicate_ = other.predicate_; } + // Get the inferred loop layout. Fragment GetLoopLayout() const { return loop_layout_; } + // Get the root For loop. For GetRoot() const { return root_; } + // Get the mapping from buffer to access indices. Map> GetIndiceMap() const { return indice_map_; } + // Get the predicate for a given thread variable. Optional GetPredicate(Var thread_var) const; + // Clone this operator. TileOperator Clone() const; private: + // Complete the fragment layout for a given buffer. Fragment CompleteBufferFragment(const Buffer &buffer) const; + // Check if the buffer is accessed with common indices (i.e., loop variables). bool IsCommonAccessIndice(const Buffer &buffer) const; + // Add a predicate to the current predicate expression. void AddPredicate(PrimExpr expr) const { predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr; } + // Allow ParallelLoopNestVisitor to access private members. + friend class ParallelLoopNestVisitor; + // The root For loop node. For root_; - + // Visitor for collecting loop nest information. 
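+  // (The visitor runs once in the constructor via V.VisitStmt(root),
+  // populating indice_map_ and buffer_is_write_ below.)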
ParallelLoopNestVisitor V; - + // Mapping from buffer to their access indices in the loop. Map> indice_map_; + // Set of buffers that are written to in the loop. std::unordered_set buffer_is_write_; + // The loop variables for the parallel loop nest. Array loop_vars_; - - mutable Fragment loop_layout_; + // Analyzer for simplifying and analyzing expressions, mutable for lazy use. mutable arith::Analyzer analyzer_; - mutable Optional predicate_; - - friend class ParallelLoopNestVisitor; }; class ParallelOp : public TileOperator { diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index 4e5dfa40b..5654044c1 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -101,6 +101,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { // Run InferLayout auto updates = next->InferLayout( LayoutInferArgs{target_, thread_bounds, layout_map}, level); + // Process the returned updates for (const auto &[buffer, layout] : updates) { // Basic validity checks @@ -444,20 +445,25 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { int root = uf.Find(i); components[root].push_back(i); } + // Create a map from root to buffers std::unordered_map> components_buffers; for (const auto &[buffer, infer_indices] : use_list_) { int root = uf.Find(infer_indices[0]); components_buffers[root].push_back(buffer); } + // Keep components_buffers for debug purpose + (void)components_buffers; // For each component, try each op as root, and determine the least // replicated one std::queue q; std::vector in_queue(infer_list_.size(), false); + for (auto &&[root, members] : components) { decltype(infer_list_) best_infer_list; LayoutMap best_layout_map; int64_t min_reg_num = INT64_MAX; + for (int attempt_infer_root : members) { // backup infer_list_ in class member auto back_infer_list = BackupInferList(); @@ -471,7 +477,6 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { tmp_layout_map, strict_layout_map, q, in_queue); FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map, q, in_queue); - // Silly workaround: we have no clue if single root will iterate over // the entire component, since the InferLayout implementations have // complicated conditioning inside and we know nothing about it. From d02ec6908e6d98ea580da74b2e92efdea2edfdf7 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 29 Aug 2025 22:20:15 +0800 Subject: [PATCH 5/5] Refactor FillNode::Lower method to remove unused global function call - Eliminated the call to the global function "tl.fill.lower" in the FillNode::Lower method, streamlining the code and improving clarity. - Retained the core functionality of the method while enhancing maintainability by reducing unnecessary dependencies. --- src/op/elem.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/op/elem.cc b/src/op/elem.cc index ccc92595f..a3b5b469e 100644 --- a/src/op/elem.cc +++ b/src/op/elem.cc @@ -94,12 +94,6 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const { } Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { - // get global func "tl.fill.lower" - if (const auto f = ffi::Function::GetGlobal("tl.fill.lower")) { - auto stmt = (*f)(dst, value); - return Downcast(stmt); - } - if (dst.scope() == "local.fragment") { auto par_op = ParallelOp(MakeSIMTLoop(analyzer)); par_op->InferLayout({T.target, T.thread_bounds, T.layout_map},