From bd39d3cdb06549f104552c27416d31a75ab5adb2 Mon Sep 17 00:00:00 2001 From: Michal Wichrowski Date: Thu, 7 May 2026 14:02:51 +0200 Subject: [PATCH] =?UTF-8?q?Enable=20`tl.dot`=20with=20TF32=20precision=20o?= =?UTF-8?q?n=20tiles=20with=20N=3D8=20and=20K=3D8=20(e.g.=20`wgmma.mma=5Fa?= =?UTF-8?q?sync.sync.aligned.m64n8k8.f32.tf32.tf32`)=20via=20the=20standar?= =?UTF-8?q?d=20`tt.dot`=20=E2=86=92=20`AccelerateMatmul`=20path=20on=20sm9?= =?UTF-8?q?0+.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lib/Analysis/Utility.cpp | 4 +-- python/test/unit/language/test_core.py | 15 +++++++++--- test/TritonGPU/accelerate-matmul.mlir | 34 ++++++++++++++++++++++++++ third_party/nvidia/backend/compiler.py | 2 ++ 4 files changed, 49 insertions(+), 6 deletions(-) diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 99cc3682c6be..3929e50fe60f 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -1191,7 +1191,7 @@ bool supportMMA(triton::DotOp op, int version) { if (k < 256 / aElemTy.getIntOrFloatBitWidth()) return false; if (!(retShapePerCTA[rank - 2] % 64 == 0 && - retShapePerCTA[rank - 1] % 16 == 0)) + retShapePerCTA[rank - 1] % 8 == 0)) return false; if (aElemTy.isF64() || bElemTy.isF64() || retType.getElementType().isF64()) { @@ -1216,7 +1216,7 @@ bool supportMMA(triton::DotOp op, int version) { if (rank == 3) return false; if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 && - retShapePerCTA[rank - 1] % 16 == 0 && + retShapePerCTA[rank - 1] % 8 == 0 && (llvm::isa(aElemTy) || aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32()))) { diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 455df6a50772..07b75dd5371a 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -3284,7 +3284,10 @@ def get_test_small_dots_cases(): if not is_cuda(): return [] return [(2, 4, 32, 1, False, False, 'None', 'ieee', 'float16', 'float32', 1, None), - (1, 2, 32, 1, False, False, 'None', 'ieee', 'float8e5', 'float32', 1, None)] + (1, 2, 32, 1, False, False, 'None', 'ieee', 'float8e5', 'float32', 1, None), + # N=8: TF32 K=8 (wgmma.m64n8k8, sm90+) and FP16 K=16 (wgmma.m64n8k16) + (64, 8, 8, 4, False, False, 'None', 'tf32', 'float32', 'float32', 1, None), + (64, 8, 16, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None)] @pytest.mark.interpreter @@ -3315,7 +3318,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty pytest.skip(f"input_precision {input_precision} is not supported in the interpreter") else: if not is_hip() and K < 16: - if in_dtype != 'float64': + tf32_n8 = (in_dtype == 'float32' and N == 8 and K == 8 and input_precision == 'tf32') + if in_dtype != 'float64' and not tf32_n8: pytest.skip("small dots are supported only on HIP at the moment") if is_cuda(): capability = torch.cuda.get_device_capability() @@ -3517,12 +3521,15 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid # make sure ld/st are vectorized ptx = pgm.asm['ptx'] - if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4): - # XXX: skip small sizes because they are not vectorized + # XXX: skip small sizes because they are not vectorized; with runtime + # strides, v4 needs the contiguous dim >= 16 (K for loads, N for stores). + enough_work = (M * N // (num_warps * 32) >= 4) and (K > 16 or N > 16 or M > 16) + if enough_work and K >= 16: if 'float64' in in_dtype: assert 'ld.global.v2.b64' in ptx else: assert 'ld.global.v4' in ptx + if enough_work and N >= 16: if 'float8' in in_dtype: assert 'st.global.v2' in ptx elif 'float64' in in_dtype: diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index e2a78a0e5925..4a2756cb5110 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -866,3 +866,37 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ tt.return %d : tensor<128x256xf64, #blocked> } } + +// ----- + +// Verify TF32 dot with N=8, K=8 (native WGMMA tile) selects MMAv3 on sm90. +// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8, 8]}> +#blocked_tf32_n8 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: tf32_n8k8_sm90 + tt.func public @tf32_n8k8_sm90( + %a: tensor<64x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked_tf32_n8}>>, + %b: tensor<8x8xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked_tf32_n8}>>) -> tensor<64x8xf32, #blocked_tf32_n8> { + %cst = arith.constant dense<0.000000e+00> : tensor<64x8xf32, #blocked_tf32_n8> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x8xf32, #[[$MMA]]> + %d = tt.dot %a, %b, %cst, inputPrecision = tf32 : tensor<64x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked_tf32_n8}>> * tensor<8x8xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked_tf32_n8}>> -> tensor<64x8xf32, #blocked_tf32_n8> + tt.return %d : tensor<64x8xf32, #blocked_tf32_n8> + } +} + +// ----- + +// Verify FP16 dot with N=8, K=16 (native WGMMA tile) selects MMAv3 on sm90. +// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8, 16]}> +#blocked_fp16_n8 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { + // CHECK-LABEL: fp16_n8k16_sm90 + tt.func public @fp16_n8k16_sm90( + %a: tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked_fp16_n8}>>, + %b: tensor<16x8xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked_fp16_n8}>>) -> tensor<64x8xf32, #blocked_fp16_n8> { + %cst = arith.constant dense<0.000000e+00> : tensor<64x8xf32, #blocked_fp16_n8> + // CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x8xf32, #[[$MMA]]> + %d = tt.dot %a, %b, %cst : tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked_fp16_n8}>> * tensor<16x8xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked_fp16_n8}>> -> tensor<64x8xf32, #blocked_fp16_n8> + tt.return %d : tensor<64x8xf32, #blocked_fp16_n8> + } +} diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 132ed663b20c..aeb87b4ad980 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -27,6 +27,8 @@ def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m, return (1, 1, 32) elif lhs_bitwidth == 64: return (1, 1, 4) + elif lhs_bitwidth == 32: + return (1, 1, 8) else: return (1, 1, 16)