diff --git a/src/op/builtin.h b/src/op/builtin.h index 0dea72230..6a84a190e 100644 --- a/src/op/builtin.h +++ b/src/op/builtin.h @@ -25,6 +25,8 @@ namespace attr { static constexpr const char *kPaddingMap = "padding_map"; static constexpr const char *kWarpSpecializationScope = "kWarpSpecializationScope"; +static constexpr const char *kCustomWarpSpecialization = + "kCustomWarpSpecialization"; } // namespace attr static constexpr const char *kDebugMergeSharedMemoryAllocations = diff --git a/src/transform/annotate_warp_group_reg_alloc.cc b/src/transform/annotate_warp_group_reg_alloc.cc index 5d0f5b0af..dd6922390 100644 --- a/src/transform/annotate_warp_group_reg_alloc.cc +++ b/src/transform/annotate_warp_group_reg_alloc.cc @@ -17,6 +17,9 @@ class SetMaxNRegCollector : public StmtExprVisitor { static Array Collect(const PrimFunc &f) { SetMaxNRegCollector collector; collector(f->body); + if (collector.warp_specialized_) { + return Array({}); + } return collector.has_no_set_max_nreg_ ? Array({IntImm(DataType::Int(32), -1), IntImm(DataType::Int(32), -1)}) @@ -43,21 +46,27 @@ class SetMaxNRegCollector : public StmtExprVisitor { StmtExprVisitor::VisitStmt_(op); } + void VisitStmt_(const AttrStmtNode *op) final { + if (op->attr_key == attr::kCustomWarpSpecialization) { + warp_specialized_ = true; + } + StmtExprVisitor::VisitStmt_(op); + } + Array nreg_{IntImm(DataType::Int(32), 0), IntImm(DataType::Int(32), 0)}; bool has_no_set_max_nreg_ = false; + bool warp_specialized_ = false; }; class SetMaxNRegInjector : public StmtExprMutator { public: static PrimFunc Inject(PrimFunc f) { - bool warp_specialized = WarpSpecializedDetector::Detect(f->body); - if (warp_specialized) { - // Should handle set_max_nreg when using hand-written warp specialized - return f; - } auto T = SetMaxNRegInjector(); T.nreg_ = SetMaxNRegCollector::Collect(f); + if (T.nreg_.empty()) { + return f; + } f.CopyOnWrite()->body = T(f->body); return f; } diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 9d4892879..41a778d07 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -1283,8 +1283,12 @@ tvm::transform::Pass WarpSpecialized() { if (!warp_specialized) { return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized, disable_shuffle_elect); + } else { + ObjectRef node = String("default"); + f.CopyOnWrite()->body = + AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body); + return f; } - return f; }; return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {}); } diff --git a/tilelang/language/builtin.py b/tilelang/language/builtin.py index 7646d0805..e1ea0c34f 100644 --- a/tilelang/language/builtin.py +++ b/tilelang/language/builtin.py @@ -355,4 +355,4 @@ def sync_grid(): def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]): """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc. """ - return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id) \ No newline at end of file + return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)