Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/op/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int block_size,
}

bool Gemm::CheckWGMMA() const {
if (B.scope() != "shared.dyn" && B.scope() != "shared") {
return false;
}

if (C->dtype == DataType::Float(16)) {
if (A->dtype == DataType::Float(16) && B->dtype == DataType::Float(16))
return K % 16 == 0;
Expand Down Expand Up @@ -442,7 +446,10 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
B->dtype.bits(), trans_B ? 2 : 1);
results.Set(B, ABLayout);
} else {
ICHECK(0) << "WGMMA only support B in shared.";
// ICHECK(0) << "WGMMA only support B in shared.";
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
}
} else if (TargetIsCDNA(T.target)) {
auto fragment =
Expand Down
13 changes: 13 additions & 0 deletions src/tl_templates/cuda/gemm_sm90.h
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,19 @@ TL_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
}
}

template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The wg_wait template parameter is unused in this function. Consider removing it to simplify the function signature. Since this function implements a non-WGMMA path, the warp-group wait parameter is not applicable here.

TL_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
trans_B, clear_accum, lda, ldb, offset_a,
offset_b, A_type, B_type, C_type>;
MMA::body_sr(pA, pB, accum);
}

template <int num_mma> TL_DEVICE void wait_wgmma() {
cute::warpgroup_wait<num_mma>();
}
Expand Down
19 changes: 9 additions & 10 deletions src/transform/warp_specialized_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,12 +572,11 @@ class WSCodeEmitter : public StmtMutator {
WSCodeEmitter(bool is_emitting_producer, IterVar thread_iv,
Map<Var, Buffer> buffer_data_to_buffer,
const WarpSpecializedRoleMarker &marker,
bool mbarrier_only = false)
bool mbarrier_only = false, bool only_has_wgmma = false)
: is_emitting_producer_(is_emitting_producer),
buffer_data_to_buffer_(buffer_data_to_buffer), marker_(marker),
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only) {}

bool onlyHasWgMMA() const { return only_has_wgmma_; }
thread_var_(thread_iv->var), mbarrier_only_(mbarrier_only),
only_has_wgmma_(only_has_wgmma) {}

bool hasSimtCopy() const { return has_simt_copy_; }

Expand Down Expand Up @@ -617,8 +616,6 @@ class WSCodeEmitter : public StmtMutator {

auto map = ExtractSyncPattern(op->seq);

only_has_wgmma_ = WgMMACollector::HasWgMMA(SeqStmt(op->seq));

/*
std::cout << "Print ExtractSyncPattern" << std::endl;
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
Expand Down Expand Up @@ -1212,11 +1209,12 @@ class WarpSpecializedRewriter : public StmtExprMutator {
block_realize.CopyOnWrite()->block = block;
return block_realize;
}
only_has_wgmma_ = WgMMACollector::HasWgMMA(block->body);
WSCodeEmitter producer(true, thread_iv_, buffer_data_to_buffer_, marker);
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker);
WSCodeEmitter consumer(false, thread_iv_, buffer_data_to_buffer_, marker,
false, only_has_wgmma_);
Stmt producer_code = producer(block->body);
Stmt consumer_code = consumer(block->body);
bool only_has_wgmma = consumer.onlyHasWgMMA();
PrimExpr consumer_thread_extent = thread_iv_->dom->extent;
PrimExpr producer_thread_extent = thread_iv_->dom->extent;
// Need one warp-group for bulk-copy only case
Expand Down Expand Up @@ -1259,8 +1257,8 @@ class WarpSpecializedRewriter : public StmtExprMutator {
PrimExpr arrive_thread_count =
producer.released_barrier_.count(i)
? (producer.hasSimtCopy() ? producer_thread_extent : 1)
: (only_has_wgmma ? FloorDiv(consumer_thread_extent, 128)
: consumer_thread_extent);
: (only_has_wgmma_ ? FloorDiv(consumer_thread_extent, 128)
: consumer_thread_extent);
barrier_num_threads.push_back(arrive_thread_count);
}

Expand Down Expand Up @@ -1289,6 +1287,7 @@ class WarpSpecializedRewriter : public StmtExprMutator {
bool disable_warp_specialized_ = false;
bool disable_shuffle_elect_ = false;
Array<IntImm> nreg_;
bool only_has_wgmma_ = false;
};

class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
Expand Down
Loading