Commit d867739

[Feature] Enhance GEMM operations with MMA and WGMMA support
- Added a mapping for GEMM instruction prefixes in `gemm.h`.
- Renamed GEMM functions to carry an instruction prefix (`mma_`, `wmma_`, `wgmma_`) for clarity in `gemm_mma.h`, `gemm_sm70.h`, `gemm_sm90.h`, and `gemm_sp_sm80.h`.
- Updated function signatures to improve consistency and readability.
- Introduced separate functions for the MMA and WGMMA paths in `gemm_sm90.h` and `gemm_sp_sm90.h`.
- Parsed the lowered op name in `inject_fence_proxy.cc` to detect async WGMMA calls.
1 parent 4efd2d2 commit d867739

11 files changed: +98 −84 lines

examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py renamed to examples/sparse_tensorcore/example_sparse_tensorcore.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -2,7 +2,7 @@
 import tilelang
 from tilelang.utils.sparse import compress_sm90
 from tilelang.layout import make_metadata_layout
-import tilelang.testing
+import tilelang.language as T


 @tilelang.jit(out_idx=[-1])
@@ -24,7 +24,6 @@ def matmul_sp(
     A_shared_shape = (block_M, block_K // 2)
     B_shared_shape = (block_K, block_N)

-    import tilelang.language as T

     @T.prim_func
     def main(
```

src/op/gemm.cc

Lines changed: 5 additions & 5 deletions
```diff
@@ -578,14 +578,16 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {

   if (A.scope() == "local.fragment") {
     ICHECK(B.scope() != "local.fragment");
-    op_name = "tl::gemm_rs";
+    op_name = "gemm_rs";
   } else if (B.scope() == "local.fragment") {
-    op_name = "tl::gemm_sr";
+    op_name = "gemm_sr";
   } else {
-    op_name = "tl::gemm_ss";
+    op_name = "gemm_ss";
   }
   ICHECK(C.scope() == "local.fragment");

+  op_name = "tl::" + GemmInstPrefixMap.at(gemm_inst) + "_" + op_name;
+
   ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
   ss << warp_m << ", " << warp_n << ", ";
   ss << trans_A << ", " << trans_B;
@@ -600,8 +602,6 @@ Stmt GemmNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   if (TargetIsCDNA(T.target)) {
     // for cdna gemm, we need to specify kPack
     ss << ", " << kPack;
-  } else if (TargetIsHopper(T.target)) {
-    ss << ", " << (gemm_inst == GemmInst::kWGMMA ? "true" : "false");
   }

   // Emit wg_wait if necessary
```
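After this change the final template name is assembled from the instruction prefix plus the operand-scope suffix, so the lowering emits names like `tl::mma_gemm_ss` or `tl::wgmma_gemm_rs`. A minimal host-side sketch of the composition (the enum and map are copied from the `gemm.h` hunk below; the `gemm_inst` and `op_name` values are illustrative):

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

// Copied from src/op/gemm.h in this commit.
enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA };
const std::unordered_map<GemmInst, std::string> GemmInstPrefixMap = {
    {GemmInst::kMMA, "mma"},
    {GemmInst::kWGMMA, "wgmma"},
    {GemmInst::kTCGEN5MMA, "tcgen5"},
    {GemmInst::kMFMA, "mfma"}};

int main() {
  // Illustrative inputs; in GemmNode::Lower these come from the buffer
  // scopes and the instruction selected for the target.
  std::string op_name = "gemm_ss";
  GemmInst gemm_inst = GemmInst::kWGMMA;

  // The composition step added by this commit.
  op_name = "tl::" + GemmInstPrefixMap.at(gemm_inst) + "_" + op_name;
  std::cout << op_name << "\n";  // prints: tl::wgmma_gemm_ss
}
```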

src/op/gemm.h

Lines changed: 8 additions & 0 deletions
```diff
@@ -24,6 +24,14 @@ enum class GemmWarpPolicyType : uint8_t {

 // Target GEMM instruction
 enum class GemmInst : uint8_t { kMMA, kWGMMA, kTCGEN5MMA, kMFMA };
+const std::unordered_map<GemmInst, std::string> GemmInstPrefixMap = {
+    {GemmInst::kMMA, "mma"},
+    {GemmInst::kWGMMA, "wgmma"},
+    {GemmInst::kTCGEN5MMA, "tcgen5"},
+    {GemmInst::kMFMA, "mfma"}
+};
+
+
 class GemmWarpPolicyNode : public Object {
 public:
   mutable int m_warp{0};
```

src/op/gemm_sp.cc

Lines changed: 3 additions & 4 deletions
```diff
@@ -146,12 +146,13 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   auto block_size = *as_const_int(T.thread_bounds->extent);
   bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
                      (block_size / warp_size % 4 == 0);
+  GemmInst gemm_inst = maybe_wgmma ? GemmInst::kWGMMA : GemmInst::kMMA;

   auto [warp_m, warp_n] = policy->ComputeWarpPartition(
       M, N, block_size, T.target, maybe_wgmma, A->dtype.bits());

   std::stringstream ss;
-  std::string op_name = "tl::gemm_sp_ss";
+  std::string op_name = "tl::" + GemmInstPrefixMap.at(gemm_inst) + "_gemm_sp_ss";
   ICHECK((A.scope() == "shared" || A.scope() == "shared.dyn") &&
          (B.scope() == "shared" || B.scope() == "shared.dyn"))
       << "Only support shared.dyn scope for A and B, but received " << A.scope()
@@ -160,13 +161,11 @@ Stmt GemmSPNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
       << "Only support shared.dyn scope for E as copy from smem to rmem are "
          "delegated to cute implementation, found "
       << E.scope();
+
   ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
   ss << warp_m << ", " << warp_n << ", ";
   ss << trans_A << ", " << trans_B;
   ss << ", " << clear_accum;
-  if (TargetIsHopper(T.target)) {
-    ss << ", " << (maybe_wgmma ? "true" : "false");
-  }
   if (wg_wait != 0) {
     ss << ", " << wg_wait;
   }
```
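The sparse path still derives the instruction from the local `maybe_wgmma` predicate before consulting the prefix map. A compile-time sketch of that predicate under assumed launch parameters (the values below are hypothetical; `warp_size` is 32 on NVIDIA targets, and the real predicate additionally requires a Hopper target):

```cpp
// Hypothetical tile/launch configuration for illustration.
constexpr int M = 128;           // this->M in GemmSPNode::Lower
constexpr int block_size = 128;  // threads per block
constexpr int warp_size = 32;

// 128 / 32 = 4 warps and 4 % 4 == 0, so with M >= 64 on Hopper the
// predicate holds and op_name becomes "tl::wgmma_gemm_sp_ss";
// otherwise it falls back to "tl::mma_gemm_sp_ss".
constexpr bool maybe_wgmma = (M >= 64) && (block_size / warp_size % 4 == 0);
static_assert(maybe_wgmma, "four warps form a full warpgroup");

int main() { return 0; }
```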

src/tl_templates/cuda/gemm_mma.h

Lines changed: 3 additions & 3 deletions
```diff
@@ -449,7 +449,7 @@ namespace tl::tl_mma {
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
           int offset_b, typename A_type, typename B_type, typename C_type>
-CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
+CUTLASS_DEVICE void mma_gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
   using MMA =
       cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                  trans_B, clear_accum, lda, ldb, offset_a,
@@ -460,7 +460,7 @@ CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
           int offset_b, typename A_type, typename B_type, typename C_type>
-CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
+CUTLASS_DEVICE void mma_gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   using MMA =
       cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                  trans_B, clear_accum, lda, ldb, offset_a,
@@ -471,7 +471,7 @@ CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum, int lda, int ldb, int offset_a,
           int offset_b, typename A_type, typename B_type, typename C_type>
-CUTLASS_DEVICE void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
+CUTLASS_DEVICE void mma_gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
   using MMA =
       cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                  trans_B, clear_accum, lda, ldb, offset_a,
```

src/tl_templates/cuda/gemm_sm70.h

Lines changed: 2 additions & 2 deletions
```diff
@@ -161,7 +161,7 @@ namespace tl {
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum, typename A_type, typename B_type,
           typename C_type>
-CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
+CUTLASS_DEVICE void wmma_gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
   using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                            trans_B, clear_accum, A_type, B_type, C_type>;
   using FragmentC = typename MMA::FragmentC;
@@ -174,7 +174,7 @@ CUTLASS_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum, typename A_type, typename B_type,
           typename C_type>
-CUTLASS_DEVICE void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
+CUTLASS_DEVICE void wmma_gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   using MMA = GemmTensorOp<GemmShape<M, N, K>, num_warp_m, num_warp_n, trans_A,
                            trans_B, clear_accum, A_type, B_type, C_type>;
   using FragmentA = typename MMA::FragmentA;
```

src/tl_templates/cuda/gemm_sm90.h

Lines changed: 60 additions & 43 deletions
```diff
@@ -232,43 +232,45 @@ namespace tl {

 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
-          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
+          int offset_a = 0, int offset_b = 0,
           int wg_wait = 0, typename A_type, typename B_type, typename C_type>
-TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
-  if constexpr (use_wgmma) {
-    static_assert((trans_A && lda == M) || (!trans_A && lda == K),
-                  "Hopper wgmma doesn't support custom stride for A");
-    static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
-                  "Hopper wgmma doesn't support custom stride for B");
-    static_assert(offset_a == 0 && offset_b == 0,
-                  "offset_a and offset_b must be zero for wgmma");
-    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
-                                             trans_A, trans_B, clear_accum,
-                                             A_type, B_type, C_type>;
-    MMA::body<wg_wait>(pA, pB, accum);
-  } else {
-    using MMA =
-        cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
+TL_DEVICE void wgmma_gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
+  static_assert((trans_A && lda == M) || (!trans_A && lda == K),
+                "Hopper wgmma doesn't support custom stride for A");
+  static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
+                "Hopper wgmma doesn't support custom stride for B");
+  static_assert(offset_a == 0 && offset_b == 0,
+                "offset_a and offset_b must be zero for wgmma");
+  using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
+                                           trans_A, trans_B, clear_accum,
+                                           A_type, B_type, C_type>;
+  MMA::body<wg_wait>(pA, pB, accum);
+}
+
+template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
+          bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
+          int offset_a = 0, int offset_b = 0,
+          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
+TL_DEVICE void mma_gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
+  using MMA =
+      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                  trans_B, clear_accum, lda, ldb, offset_a,
                                  offset_b, A_type, B_type, C_type>;
-    MMA::body(pA, pB, accum);
-  }
+  MMA::body(pA, pB, accum);
 }

 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
-          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
+          int offset_a = 0, int offset_b = 0,
           int wg_wait = 0, typename A_type, typename B_type, typename C_type>
 TL_DEVICE /**
           * Perform a read-share (B in shared memory, A in global) tiled GEMM
           * and accumulate into `accum`.
           *
-          * Dispatches at compile time to either the Hopper wgmma
-          * implementation or the fallback MMA implementation depending on
-          * `use_wgmma`. The selected GemmTensorOp::body_rs performs the
+          * Dispatches at compile time to the Hopper wgmma
+          * implementation. The selected GemmTensorOp::body_rs performs the
           * region-tiled GEMM loop and updates the accumulator in-place.
           *
-          * When `use_wgmma == true`, this function enforces wgmma constraints
+          * This function enforces wgmma constraints
           * at compile time:
           * - A's leading dimension must equal (trans_A ? M : K)
           * - B's leading dimension must equal (trans_B ? K : N)
@@ -281,40 +283,57 @@ TL_DEVICE /**
           * @param accum Pointer to the accumulator/output C buffer updated
           * in-place.
           */
-    void
-    gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
-  if constexpr (use_wgmma) {
+    void wgmma_gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   static_assert((trans_A && lda == M) || (!trans_A && lda == K),
                 "Hopper wgmma doesn't support custom stride for A");
   static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
                 "Hopper wgmma doesn't support custom stride for B");
   static_assert(offset_a == 0 && offset_b == 0,
                 "offset_a and offset_b must be zero for wgmma");
   using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
-                                             trans_A, trans_B, clear_accum,
-                                             A_type, B_type, C_type>;
+                                           trans_A, trans_B, clear_accum,
+                                           A_type, B_type, C_type>;
   MMA::body_rs<wg_wait>(pA, pB, accum);
-  } else {
+}
+
+template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
+          bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
+          int offset_a = 0, int offset_b = 0,
+          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
+TL_DEVICE /**
+          * Perform a read-share (B in shared memory, A in global) tiled GEMM
+          * and accumulate into `accum`.
+          *
+          * Dispatches at compile time to the fallback mma
+          * implementation. The selected GemmTensorOp::body_rs performs the
+          * region-tiled GEMM loop and updates the accumulator in-place.
+          *
+          * @param pA Pointer to operand A (global memory). Layout/stride
+          * expectations depend on template parameters.
+          * @param pB Pointer to operand B (base for shared-memory staging).
+          * Layout/stride expectations depend on template parameters.
+          * @param accum Pointer to the accumulator/output C buffer updated
+          * in-place.
+          */
+void mma_gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   using MMA =
       cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                   trans_B, clear_accum, lda, ldb, offset_a,
-                                   offset_b, A_type, B_type, C_type>;
+                                 trans_B, clear_accum, lda, ldb, offset_a,
+                                 offset_b, A_type, B_type, C_type>;
   MMA::body_rs(pA, pB, accum);
-  }
 }

 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
-          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
+          int offset_a = 0, int offset_b = 0,
           int wg_wait = 0, typename A_type, typename B_type, typename C_type>
 TL_DEVICE /**
           * Perform a non-wgmma tiled GEMM where A regions are staged into
           * shared memory and B is read directly from global memory,
           * accumulating into `accum`.
           *
           * This overload dispatches to the tl_mma::GemmTensorOp::body_sr
-          * implementation. Must be instantiated with `use_wgmma = false`
-          * (enforced via static_assert).
+          * implementation.
           *
           * @param pA Pointer to the A operand in global memory (source that
           * will be staged to shared memory).
@@ -323,14 +342,12 @@ TL_DEVICE /**
           * @param accum Pointer to the output accumulator matrix in global
           * memory.
           */
-    void
-    gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
-  static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
-  using MMA =
-      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, clear_accum, lda, ldb, offset_a,
-                                 offset_b, A_type, B_type, C_type>;
-  MMA::body_sr(pA, pB, accum);
+void mma_gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
+  using MMA =
+      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
+                                 trans_B, clear_accum, lda, ldb, offset_a,
+                                 offset_b, A_type, B_type, C_type>;
+  MMA::body_sr(pA, pB, accum);
 }

 template <int num_mma>
```
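Splitting the `use_wgmma` branch into separate `wgmma_*` and `mma_*` entry points means the wgmma layout rules are now enforced unconditionally by the wgmma variants' static_asserts. A host-compilable sketch of the same checks (the helper name is hypothetical, not part of the header):

```cpp
// Hypothetical mirror of the checks in wgmma_gemm_ss / wgmma_gemm_rs:
// a transposed A needs lda == M (otherwise lda == K), a transposed B
// needs ldb == K (otherwise ldb == N), and both offsets must be zero.
template <int M, int N, int K, bool trans_A, bool trans_B, int lda, int ldb,
          int offset_a = 0, int offset_b = 0>
constexpr bool check_wgmma_layout() {
  static_assert((trans_A && lda == M) || (!trans_A && lda == K),
                "Hopper wgmma doesn't support custom stride for A");
  static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
                "Hopper wgmma doesn't support custom stride for B");
  static_assert(offset_a == 0 && offset_b == 0,
                "offset_a and offset_b must be zero for wgmma");
  return true;
}

// OK: non-transposed A with lda == K, non-transposed B with ldb == N.
static_assert(check_wgmma_layout<64, 128, 32, false, false, 32, 128>(),
              "layout accepted by wgmma");
// check_wgmma_layout<64, 128, 32, false, false, 64, 128>() would fail:
// lda != K for a non-transposed A.

int main() { return 0; }
```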

src/tl_templates/cuda/gemm_sp_sm80.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -255,7 +255,7 @@ class GemmTensorOp {
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum = false, typename A_type,
           typename B_type, typename C_type, typename E_type>
-TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) {
+TL_DEVICE void mma_gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) {
   using MMA =
       GemmTensorOp<cutlass::gemm::GemmShape<M, N, K>, num_warp_m, num_warp_n,
                    trans_A, trans_B, clear_accum, A_type, B_type, C_type>;
```

src/tl_templates/cuda/gemm_sp_sm90.h

Lines changed: 3 additions & 8 deletions
```diff
@@ -215,18 +215,13 @@ class GemmTensorOp {

 namespace tl {
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
-          bool trans_B, bool clear_accum = false, bool use_wgmma = true,
+          bool trans_B, bool clear_accum = false,
           int wg_wait = 0, typename A_type, typename B_type, typename C_type,
           typename GMMA = cute::tl_wgmma_sp::GemmTensorOp<
               M, N, K, num_warp_m, num_warp_n, trans_A, trans_B, clear_accum,
               A_type, B_type, C_type>,
           typename E_type = typename GMMA::ElementEMma::raw_type>
-TL_DEVICE void gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) {
-  static_assert(use_wgmma, "only wgmma is supported for now");
-  if constexpr (use_wgmma) {
-    GMMA::body<wg_wait>(pA, pB, accum, pE);
-  } else {
-    CUTE_GCC_UNREACHABLE;
-  }
+TL_DEVICE void wgmma_gemm_sp_ss(A_type *pA, B_type *pB, C_type *accum, E_type *pE) {
+  GMMA::body<wg_wait>(pA, pB, accum, pE);
 }
 } // namespace tl
```

src/transform/inject_fence_proxy.cc

Lines changed: 10 additions & 14 deletions
```diff
@@ -79,24 +79,18 @@ bool IsAsyncIntrinsic(const CallNode *call) {
   }

   // TileLang async intrinsics
-  if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col()) ||
-      call->op.same_as(tma_store()) || call->op.same_as(tma_store_arrive()) ||
-      call->op.same_as(tma_store_wait()) ||
-      call->op.same_as(ptx_cp_async_barrier_noinc()) ||
+  // NOTE(wt): We only need to inject fences before tma_store and WGMMA,
+  // since tma_load and WGMMA contain implicit proxy fence after them
+  if (call->op.same_as(tma_store()) ||
       call->op.same_as(ptx_wgmma_ss()) || call->op.same_as(ptx_wgmma_rs())) {
     return true;
   }

-  // PTX async copy intrinsics
-  if (call->op.same_as(builtin::ptx_cp_async()) ||
-      call->op.same_as(builtin::ptx_cp_async_barrier()) ||
-      call->op.same_as(builtin::ptx_cp_async_bulk())) {
-    return true;
-  }
-
-  // wgmma async intrinsics
   if (call->op.same_as(tl_gemm()) || call->op.same_as(tl_gemm_sp())) {
-    return true;
+    // determine whether async wgmma is utilized
+    std::ostringstream oss;
+    oss << call->args[0].as<StringImmNode>()->value;
+    return oss.str().find("wgmma") != std::string::npos;
   }

   return false;
@@ -174,6 +168,7 @@ class ProxyFenceInjector : public StmtMutator {

 private:
   Stmt VisitStmt_(const SeqStmtNode *op) final {
+    // FIXME: 1st stmt cannot know the previous proxy kind
     Array<Stmt> seq;
     seq.reserve(op->seq.size());

@@ -213,7 +208,8 @@ class ProxyFenceInjector : public StmtMutator {
     } else if (IsKnownGeneric(call)) {
       kind = ProxyKind::kGeneric;
     } else {
-      // We can now treat extern as Generic, since gemm and gemm_sp are never
+      // Remaining intrinsic and extern are marked as Generic.
+      // We can now treat all extern as Generic, since gemm and gemm_sp are never
       // represented as call_extern nodes. They are call_intrin nodes and will
       // be handled by IsAsyncIntrinsic above.
       kind = ProxyKind::kGeneric;
```
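With the rename in place, `IsAsyncIntrinsic` no longer needs a separate wgmma whitelist for `tl_gemm`/`tl_gemm_sp`: it inspects the first call argument, which carries the op name emitted by the lowering above, and tests for the `wgmma` substring. A standalone sketch of that test against the names this commit produces (template arguments are illustrative):

```cpp
#include <iostream>
#include <string>

// Mirrors the substring test in IsAsyncIntrinsic: a lowered GEMM call is
// treated as async exactly when it dispatches to a wgmma template.
bool UsesWgmma(const std::string &op_name) {
  return op_name.find("wgmma") != std::string::npos;
}

int main() {
  std::cout << UsesWgmma("tl::wgmma_gemm_ss<64, 64, 32, 1, 1, 0, 0>") << "\n";    // 1
  std::cout << UsesWgmma("tl::mma_gemm_rs<64, 64, 32, 1, 1, 0, 0>") << "\n";      // 0
  std::cout << UsesWgmma("tl::wgmma_gemm_sp_ss<64, 64, 64, 1, 1, 0, 0>") << "\n"; // 1
}
```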
