From ee8aca159b3bf7f1ea542fe65f788e74d005c1c9 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Mon, 15 Sep 2025 13:29:14 +0800 Subject: [PATCH 1/3] [Feature] Introduce custom warp specialization attribute and enhance warp group register allocation - Added a new attribute `kCustomWarpSpecialization` to support custom warp specialization in the TileLang framework. - Updated the `Collect` method in `SetMaxNRegCollector` to return an empty array when the function body carries the `kCustomWarpSpecialization` attribute. - Changed `SetMaxNRegInjector::Inject` to return the function unchanged when `Collect` yields an empty array, replacing its own `WarpSpecializedDetector::Detect` check. - Modified the `WarpSpecialized` pass so that, when the function is already warp specialized and the rewriter is therefore skipped, the function body is wrapped in an `AttrStmt` keyed by the new attribute. --- src/op/builtin.h | 1 + .../annotate_warp_group_reg_alloc.cc | 19 ++++++++++++++----- src/transform/warp_specialized_rewriter.cc | 5 ++++- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/op/builtin.h b/src/op/builtin.h index 0dea72230..6ea217c91 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -25,6 +25,7 @@ namespace attr { static constexpr const char *kPaddingMap = "padding_map"; static constexpr const char *kWarpSpecializationScope = "kWarpSpecializationScope"; +static constexpr const char *kCustomWarpSpecialization = "kCustomWarpSpecialization"; } // namespace attr static constexpr const char *kDebugMergeSharedMemoryAllocations = diff --git a/src/transform/annotate_warp_group_reg_alloc.cc b/src/transform/annotate_warp_group_reg_alloc.cc index 5d0f5b0af..d9ae0f2e3 100644 --- a/src/transform/annotate_warp_group_reg_alloc.cc +++ b/src/transform/annotate_warp_group_reg_alloc.cc @@ -17,6 +17,9 @@ class SetMaxNRegCollector : public StmtExprVisitor { static Array Collect(const PrimFunc &f) { SetMaxNRegCollector collector; collector(f->body); + if (collector.warp_specialized_) { + return Array({}); + } return collector.has_no_set_max_nreg_ ? 
Array({IntImm(DataType::Int(32), -1), IntImm(DataType::Int(32), -1)}) @@ -43,21 +46,27 @@ class SetMaxNRegCollector : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == attr::kCustomWarpSpecialization) { + warp_specialized_ = true; + } + StmtExprVisitor::VisitStmt_(op); + } + Array nreg_{IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)}; bool has_no_set_max_nreg_ = false; + bool warp_specialized_ = false; }; class SetMaxNRegInjector : public StmtExprMutator { public: static PrimFunc Inject(PrimFunc f) { - bool warp_specialized = WarpSpecializedDetector::Detect(f->body); - if (warp_specialized) { - // Should handle set_max_nreg when using hand-written warp specialized - return f; - } auto T = SetMaxNRegInjector(); T.nreg_ = SetMaxNRegCollector::Collect(f); + if (T.nreg_.size() == 0) { + return f; + } f.CopyOnWrite()->body = T(f->body); return f; } diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 9d4892879..625c39ddd 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -1283,8 +1283,11 @@ tvm::transform::Pass WarpSpecialized() { if (!warp_specialized) { return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized, disable_shuffle_elect); + } else { + ObjectRef node = String("default"); + f.CopyOnWrite()->body = AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body); + return f; } - return f; }; return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {}); } From 76d6f4783e31cccae2ad3f5729014eac5473dd50 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Mon, 15 Sep 2025 15:36:01 +0800 Subject: [PATCH 2/3] lint --- src/op/builtin.h | 3 ++- src/transform/warp_specialized_rewriter.cc | 3 ++- tilelang/language/builtin.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/op/builtin.h b/src/op/builtin.h index 6ea217c91..6a84a190e 100644 
--- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -25,7 +25,8 @@ namespace attr { static constexpr const char *kPaddingMap = "padding_map"; static constexpr const char *kWarpSpecializationScope = "kWarpSpecializationScope"; -static constexpr const char *kCustomWarpSpecialization = "kCustomWarpSpecialization"; +static constexpr const char *kCustomWarpSpecialization = + "kCustomWarpSpecialization"; } // namespace attr static constexpr const char *kDebugMergeSharedMemoryAllocations = diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 625c39ddd..41a778d07 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -1285,7 +1285,8 @@ tvm::transform::Pass WarpSpecialized() { disable_shuffle_elect); } else { ObjectRef node = String("default"); - f.CopyOnWrite()->body = AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body); + f.CopyOnWrite()->body = + AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body); return f; } }; diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 7646d0805..e1ea0c34f 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -355,4 +355,4 @@ def sync_grid(): def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]): """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. 
""" - return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) \ No newline at end of file + return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) From e3692ef64a0926ad5008ba680bdd9d0fa47dcdc2 Mon Sep 17 00:00:00 2001 From: Yu Cheng Date: Mon, 15 Sep 2025 17:18:12 +0800 Subject: [PATCH 3/3] lint --- src/transform/annotate_warp_group_reg_alloc.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/annotate_warp_group_reg_alloc.cc b/src/transform/annotate_warp_group_reg_alloc.cc index d9ae0f2e3..dd6922390 100644 --- a/src/transform/annotate_warp_group_reg_alloc.cc +++ b/src/transform/annotate_warp_group_reg_alloc.cc @@ -64,7 +64,7 @@ class SetMaxNRegInjector : public StmtExprMutator { static PrimFunc Inject(PrimFunc f) { auto T = SetMaxNRegInjector(); T.nreg_ = SetMaxNRegCollector::Collect(f); - if (T.nreg_.size() == 0) { + if (T.nreg_.empty()) { return f; } f.CopyOnWrite()->body = T(f->body);