@@ -232,43 +232,45 @@ namespace tl {
 
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
-          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
+          int offset_a = 0, int offset_b = 0,
           int wg_wait = 0, typename A_type, typename B_type, typename C_type>
-TL_DEVICE void gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
-  if constexpr (use_wgmma) {
-    static_assert((trans_A && lda == M) || (!trans_A && lda == K),
-                  "Hopper wgmma doesn't support custom stride for A");
-    static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
-                  "Hopper wgmma doesn't support custom stride for B");
-    static_assert(offset_a == 0 && offset_b == 0,
-                  "offset_a and offset_b must be zero for wgmma");
-    using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
-                                             trans_A, trans_B, clear_accum,
-                                             A_type, B_type, C_type>;
-    MMA::body<wg_wait>(pA, pB, accum);
-  } else {
-    using MMA =
-        cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
+TL_DEVICE void wgmma_gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
+  static_assert((trans_A && lda == M) || (!trans_A && lda == K),
+                "Hopper wgmma doesn't support custom stride for A");
+  static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
+                "Hopper wgmma doesn't support custom stride for B");
+  static_assert(offset_a == 0 && offset_b == 0,
+                "offset_a and offset_b must be zero for wgmma");
+  using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
+                                           trans_A, trans_B, clear_accum,
+                                           A_type, B_type, C_type>;
+  MMA::body<wg_wait>(pA, pB, accum);
+}
+
+template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
+          bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
+          int offset_a = 0, int offset_b = 0,
+          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
+TL_DEVICE void mma_gemm_ss(A_type *pA, B_type *pB, C_type *accum) {
+  using MMA = cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
                                    trans_B, clear_accum, lda, ldb, offset_a,
                                    offset_b, A_type, B_type, C_type>;
-    MMA::body(pA, pB, accum);
-  }
+  MMA::body(pA, pB, accum);
 }
 
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
-          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
+          int offset_a = 0, int offset_b = 0,
           int wg_wait = 0, typename A_type, typename B_type, typename C_type>
 TL_DEVICE /**
  * Perform a read-share (B in shared memory, A in global) tiled GEMM
  * and accumulate into `accum`.
  *
- * Dispatches at compile time to either the Hopper wgmma
- * implementation or the fallback MMA implementation depending on
- * `use_wgmma`. The selected GemmTensorOp::body_rs performs the
+ * Forwards to the Hopper wgmma implementation, whose
+ * GemmTensorOp::body_rs performs the
  * region-tiled GEMM loop and updates the accumulator in-place.
  *
- * When `use_wgmma == true`, this function enforces wgmma constraints
+ * This function enforces wgmma constraints
  * at compile time:
  * - A's leading dimension must equal (trans_A ? M : K)
  * - B's leading dimension must equal (trans_B ? K : N)
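Review note: with the `use_wgmma` flag removed, the backend choice moves from a template parameter to the call site. Below is a minimal migration sketch, not taken from this commit: `WM`, `WN`, `pA`, `pB`, and `acc` are illustrative caller-side names, and the trailing type parameters are deduced from the pointer arguments.

```cpp
// Hypothetical Hopper call site; M, N, K, WM, WN are the caller's tile
// sizes and warp counts.
// Before: tl::gemm_ss<M, N, K, WM, WN, false, false, false,
//                     K, N, 0, 0, /*use_wgmma=*/true>(pA, pB, acc);
// After, the caller names the backend explicitly. The wgmma path needs
// lda/ldb spelled out because the asserts require lda == K, ldb == N here:
tl::wgmma_gemm_ss<M, N, K, WM, WN, /*trans_A=*/false, /*trans_B=*/false,
                  /*clear_accum=*/false, /*lda=*/K, /*ldb=*/N>(pA, pB, acc);
// ...or the non-wgmma fallback, where the stride/offset defaults stay legal:
tl::mma_gemm_ss<M, N, K, WM, WN, false, false>(pA, pB, acc);
```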
@@ -281,40 +283,57 @@ TL_DEVICE /**
  * @param accum Pointer to the accumulator/output C buffer updated
  *        in-place.
  */
-void
-gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
-  if constexpr (use_wgmma) {
+void wgmma_gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   static_assert((trans_A && lda == M) || (!trans_A && lda == K),
                 "Hopper wgmma doesn't support custom stride for A");
   static_assert((trans_B && ldb == K) || (!trans_B && ldb == N),
                 "Hopper wgmma doesn't support custom stride for B");
   static_assert(offset_a == 0 && offset_b == 0,
                 "offset_a and offset_b must be zero for wgmma");
   using MMA = cute::tl_wgmma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n,
-                                             trans_A, trans_B, clear_accum,
-                                             A_type, B_type, C_type>;
+                                            trans_A, trans_B, clear_accum,
+                                            A_type, B_type, C_type>;
   MMA::body_rs<wg_wait>(pA, pB, accum);
-  } else {
+}
+
+template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
+          bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
+          int offset_a = 0, int offset_b = 0,
+          int wg_wait = 0, typename A_type, typename B_type, typename C_type>
+TL_DEVICE /**
+ * Perform a read-share (B in shared memory, A in global) tiled GEMM
+ * and accumulate into `accum`.
+ *
+ * Forwards to the fallback tl_mma implementation, whose
+ * GemmTensorOp::body_rs performs the
+ * region-tiled GEMM loop and updates the accumulator in-place.
+ *
+ * @param pA Pointer to operand A (global memory). Layout/stride
+ *        expectations depend on template parameters.
+ * @param pB Pointer to operand B (base for shared-memory staging).
+ *        Layout/stride expectations depend on template parameters.
+ * @param accum Pointer to the accumulator/output C buffer updated
+ *        in-place.
+ */
+void mma_gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   using MMA =
       cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                    trans_B, clear_accum, lda, ldb, offset_a,
-                                    offset_b, A_type, B_type, C_type>;
+                                   trans_B, clear_accum, lda, ldb, offset_a,
+                                   offset_b, A_type, B_type, C_type>;
   MMA::body_rs(pA, pB, accum);
-  }
 }
 
 template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum = false, int lda = 0, int ldb = 0,
-          int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
+          int offset_a = 0, int offset_b = 0,
           int wg_wait = 0, typename A_type, typename B_type, typename C_type>
 TL_DEVICE /**
  * Perform a non-wgmma tiled GEMM where A regions are staged into
  * shared memory and B is read directly from global memory,
  * accumulating into `accum`.
  *
  * This overload dispatches to the tl_mma::GemmTensorOp::body_sr
- * implementation. Must be instantiated with `use_wgmma = false`
- * (enforced via static_assert).
+ * implementation.
  *
  * @param pA Pointer to the A operand in global memory (source that
  *        will be staged to shared memory).
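Review note: the three `static_assert`s turn an unsupported wgmma configuration into a compile error instead of a silent miscomputation. A sketch under the same illustrative caller-side names as above:

```cpp
// Compiles: with trans_A = false and trans_B = false, the asserts demand
// lda == K, ldb == N, and both offsets zero.
tl::wgmma_gemm_rs<M, N, K, WM, WN, false, false,
                  /*clear_accum=*/false, /*lda=*/K, /*ldb=*/N>(pA, pB, acc);
// Rejected at compile time with "Hopper wgmma doesn't support custom
// stride for A":
// tl::wgmma_gemm_rs<M, N, K, WM, WN, false, false,
//                   false, /*lda=*/K + 8, /*ldb=*/N>(pA, pB, acc);
```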
@@ -323,14 +342,12 @@ TL_DEVICE /**
  * @param accum Pointer to the output accumulator matrix in global
  *        memory.
  */
-void
-gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
-  static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
-  using MMA =
-      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
-                                 trans_B, clear_accum, lda, ldb, offset_a,
-                                 offset_b, A_type, B_type, C_type>;
-  MMA::body_sr(pA, pB, accum);
+void mma_gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
+  using MMA =
+      cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
+                                 trans_B, clear_accum, lda, ldb, offset_a,
+                                 offset_b, A_type, B_type, C_type>;
+  MMA::body_sr(pA, pB, accum);
 }
 
 template <int num_mma>
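Review note: once the sr variant is tl_mma-only, the old `static_assert(!use_wgmma, ...)` guard is dead weight, which is why it disappears rather than moving. Unlike the wgmma entry points, this path keeps `lda`/`ldb`/`offset_a`/`offset_b` live; the sketch below is hypothetical (same illustrative names as above) and assumes the offsets index elements of the staging buffer:

```cpp
// Illustrative only: A staged from a buffer with leading dimension 2*K,
// starting offset_a = K elements in; the wgmma paths static_assert this away.
tl::mma_gemm_sr<M, N, K, WM, WN, false, false,
                /*clear_accum=*/true, /*lda=*/2 * K, /*ldb=*/N,
                /*offset_a=*/K, /*offset_b=*/0>(pA, pB, acc);
```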