diff --git a/src/op/gemm.cc b/src/op/gemm.cc index d67317dad..45df6c2c9 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -241,6 +241,10 @@ std::pair Gemm::ComputeWarpPartition(int block_size, } bool Gemm::CheckWGMMA() const { + if (B.scope() != "shared.dyn" && B.scope() != "shared") { + return false; + } + if (C->dtype == DataType::Float(16)) { if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16)) return K % 16 == 0; @@ -443,7 +447,9 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) { B->dtype.bits(), trans_B ? 2 : 1); results.Set(B, ABLayout); } else { - ICHECK(0) << "WGMMA only support B in shared."; + auto fragment = + makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B); + results.Set(B, fragment->BindThreadRange(thread_range)); } } else if (TargetIsCDNA(T.target)) { auto fragment = @@ -490,4 +496,4 @@ TIR_REGISTER_TL_OP(Gemm, gemm) Integer(CallEffectKind::kOpaque)); } // namespace tl -} // namespace tvm \ No newline at end of file +} // namespace tvm diff --git a/src/tl_templates/cuda/gemm_sm90.h b/src/tl_templates/cuda/gemm_sm90.h index f2579a7d4..2f855d307 100644 --- a/src/tl_templates/cuda/gemm_sm90.h +++ b/src/tl_templates/cuda/gemm_sm90.h @@ -624,6 +624,19 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) { } } +template +TL_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) { + static_assert(!use_wgmma, "wgmma doesn't support gemm_sr"); + using MMA = + cute::tl_mma::GemmTensorOp; + MMA::body_sr(pA, pB, accum); +} + template TL_DEVICE void wait_wgmma() { cute::warpgroup_wait(); } diff --git a/src/transform/warp_specialized_rewriter.cc b/src/transform/warp_specialized_rewriter.cc index 1ea14ad5b..2353f7fc0 100644 --- a/src/transform/warp_specialized_rewriter.cc +++ b/src/transform/warp_specialized_rewriter.cc @@ -572,12 +572,11 @@ class WSCodeEmitter : public StmtMutator { WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv, Map buffer_data_to_buffer, const WarpSpecializedRoleMarker &marker, - bool mbarrier_only = false) + bool mbarrier_only = false, bool only_has_wgmma = false) : is_emitting_producer_(is_emitting_producer), buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker), - thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {} - - bool onlyHasWgMMA() const { return only_has_wgmma_; } + thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only), + only_has_wgmma_(only_has_wgmma) {} bool hasSimtCopy() const { return has_simt_copy_; } @@ -617,8 +616,6 @@ class WSCodeEmitter : public StmtMutator { auto map = ExtractSyncPattern(op->seq); - only_has_wgmma_ = WgMMACollector::HasWgMMA(SeqStmt(op->seq)); - /* std::cout << "Print ExtractSyncPattern" << std::endl; for (int i = 0; i < static_cast(op->seq.size()); i++) { @@ -1212,11 +1209,12 @@ class WarpSpecializedRewriter : public StmtExprMutator { block_realize.CopyOnWrite()->block = block; return block_realize; } + only_has_wgmma_ = WgMMACollector::HasWgMMA(block->body); WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker); - WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker); + WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker, + false, only_has_wgmma_); Stmt producer_code = producer(block->body); Stmt consumer_code = consumer(block->body); - bool only_has_wgmma = consumer.onlyHasWgMMA(); PrimExpr consumer_thread_extent = thread_iv_->dom->extent; PrimExpr producer_thread_extent = thread_iv_->dom->extent; // Need one warp-group for bulk-copy only case @@ -1259,8 +1257,8 @@ class WarpSpecializedRewriter : public StmtExprMutator { PrimExpr arrive_thread_count = producer.released_barrier_.count(i) ? (producer.hasSimtCopy() ? producer_thread_extent : 1) - : (only_has_wgmma ? FloorDiv(consumer_thread_extent, 128) - : consumer_thread_extent); + : (only_has_wgmma_ ? FloorDiv(consumer_thread_extent, 128) + : consumer_thread_extent); barrier_num_threads.push_back(arrive_thread_count); } @@ -1289,6 +1287,7 @@ class WarpSpecializedRewriter : public StmtExprMutator { bool disable_warp_specialized_ = false; bool disable_shuffle_elect_ = false; Array nreg_; + bool only_has_wgmma_ = false; }; class WarpSpecializedDetector : public IRVisitorWithAnalyzer {