Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1191,7 +1191,7 @@ bool supportMMA(triton::DotOp op, int version) {
if (k < 256 / aElemTy.getIntOrFloatBitWidth())
return false;
if (!(retShapePerCTA[rank - 2] % 64 == 0 &&
retShapePerCTA[rank - 1] % 16 == 0))
Comment thread
ThomasRaoux marked this conversation as resolved.
retShapePerCTA[rank - 1] % 8 == 0))
return false;
if (aElemTy.isF64() || bElemTy.isF64() ||
retType.getElementType().isF64()) {
Expand All @@ -1216,7 +1216,7 @@ bool supportMMA(triton::DotOp op, int version) {
if (rank == 3)
return false;
if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
retShapePerCTA[rank - 1] % 16 == 0 &&
retShapePerCTA[rank - 1] % 8 == 0 &&
(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
aElemTy.isF32()))) {
Expand Down
15 changes: 11 additions & 4 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3284,7 +3284,10 @@ def get_test_small_dots_cases():
if not is_cuda():
return []
return [(2, 4, 32, 1, False, False, 'None', 'ieee', 'float16', 'float32', 1, None),
(1, 2, 32, 1, False, False, 'None', 'ieee', 'float8e5', 'float32', 1, None)]
(1, 2, 32, 1, False, False, 'None', 'ieee', 'float8e5', 'float32', 1, None),
# N=8: TF32 K=8 (wgmma.m64n8k8, sm90+) and FP16 K=16 (wgmma.m64n8k16)
(64, 8, 8, 4, False, False, 'None', 'tf32', 'float32', 'float32', 1, None),
(64, 8, 16, 4, False, False, 'None', 'ieee', 'float16', 'float32', 1, None)]


@pytest.mark.interpreter
Expand Down Expand Up @@ -3315,7 +3318,8 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
pytest.skip(f"input_precision {input_precision} is not supported in the interpreter")
else:
if not is_hip() and K < 16:
if in_dtype != 'float64':
tf32_n8 = (in_dtype == 'float32' and N == 8 and K == 8 and input_precision == 'tf32')
if in_dtype != 'float64' and not tf32_n8:
pytest.skip("small dots are supported only on HIP at the moment")
if is_cuda():
capability = torch.cuda.get_device_capability()
Expand Down Expand Up @@ -3517,12 +3521,15 @@ def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, strid
# make sure ld/st are vectorized
ptx = pgm.asm['ptx']

if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
# XXX: skip small sizes because they are not vectorized
# XXX: skip small sizes because they are not vectorized; with runtime
# strides, v4 needs the contiguous dim >= 16 (K for loads, N for stores).
enough_work = (M * N // (num_warps * 32) >= 4) and (K > 16 or N > 16 or M > 16)
if enough_work and K >= 16:
if 'float64' in in_dtype:
assert 'ld.global.v2.b64' in ptx
else:
assert 'ld.global.v4' in ptx
if enough_work and N >= 16:
if 'float8' in in_dtype:
assert 'st.global.v2' in ptx
elif 'float64' in in_dtype:
Expand Down
34 changes: 34 additions & 0 deletions test/TritonGPU/accelerate-matmul.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -866,3 +866,37 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
tt.return %d : tensor<128x256xf64, #blocked>
}
}

// -----

// Verify TF32 dot with N=8, K=8 (native WGMMA tile) selects MMAv3 on sm90.
// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8, 8]}>
#blocked_tf32_n8 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: tf32_n8k8_sm90
tt.func public @tf32_n8k8_sm90(
%a: tensor<64x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked_tf32_n8}>>,
%b: tensor<8x8xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked_tf32_n8}>>) -> tensor<64x8xf32, #blocked_tf32_n8> {
%cst = arith.constant dense<0.000000e+00> : tensor<64x8xf32, #blocked_tf32_n8>
// CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x8xf32, #[[$MMA]]>
%d = tt.dot %a, %b, %cst, inputPrecision = tf32 : tensor<64x8xf32, #ttg.dot_op<{opIdx = 0, parent = #blocked_tf32_n8}>> * tensor<8x8xf32, #ttg.dot_op<{opIdx = 1, parent = #blocked_tf32_n8}>> -> tensor<64x8xf32, #blocked_tf32_n8>
tt.return %d : tensor<64x8xf32, #blocked_tf32_n8>
}
}

// -----

// Verify FP16 dot with N=8, K=16 (native WGMMA tile) selects MMAv3 on sm90.
// CHECK: #[[$MMA:.+]] = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 8, 16]}>
#blocked_fp16_n8 = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: fp16_n8k16_sm90
tt.func public @fp16_n8k16_sm90(
%a: tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked_fp16_n8}>>,
%b: tensor<16x8xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked_fp16_n8}>>) -> tensor<64x8xf32, #blocked_fp16_n8> {
%cst = arith.constant dense<0.000000e+00> : tensor<64x8xf32, #blocked_fp16_n8>
// CHECK: ttng.warp_group_dot {{.*}} -> tensor<64x8xf32, #[[$MMA]]>
%d = tt.dot %a, %b, %cst : tensor<64x16xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked_fp16_n8}>> * tensor<16x8xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked_fp16_n8}>> -> tensor<64x8xf32, #blocked_fp16_n8>
tt.return %d : tensor<64x8xf32, #blocked_fp16_n8>
}
}
2 changes: 2 additions & 0 deletions third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m,
return (1, 1, 32)
elif lhs_bitwidth == 64:
return (1, 1, 4)
elif lhs_bitwidth == 32:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more general than sm90+. Likely you need more code updates to support a smaller dot shape across archs. I'll defer it to @ThomasRaoux to determine if this is right direction

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think it would be good to update Blackwell as well the same way.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have access to Blackwell GPUs to test it directly, but I've updated MMAv5

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about sm80 with TF32, are you able to use (1,1,8) to pass all tests?

Copy link
Copy Markdown
Contributor Author

@mwichro mwichro May 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A100 is sm80, so MMAv3 is not available, so it should have no effect anyway.

I have A100, this one I can test:

python -m pytest python/test/unit/language/test_core.py -k "dot" -q --no-header 
736 passed, 1882 skipped, 6728 deselected in 286.06s (0:04:46)

Looks fine to me.

Wait, A100 has some other instructions in MMAv2 that also could be used, let me check that.
Checked it's fine.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, btw, on review, Claude is suggesting:

So min_dot_size is reinventing the same formula with a hardcoded if/elif chain, ignoring the target parameter it receives. The fix is to use the formula directly:

def min_dot_size(target: GPUTarget):

    def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:
        lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
        rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
        assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
        # For small M/N we can still use tensor cores with padding.
        # The minimum K is determined by the native MMA tile: 256 / bitwidth.
        return (1, 1, 256 // lhs_bitwidth)

    return check_dot_compatibility

This is identical in behaviour to the current code for all supported types, eliminates the special-casing we added for 32-bit, and directly mirrors the hardware formula used in mmaVersionToInstrShape. It also correctly handles any future type (e.g., fp8 at 8-bit is already covered).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I like it better the way it is right now

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

python -m pytest python/test/unit/language/test_core.py -k "dot" -q --no-header
736 passed, 1882 skipped, 6728 deselected in 286.06s (0:04:46)

We don't have very small dot tests previously and your test skipped sm80

return (1, 1, 8)
else:
return (1, 1, 16)

Expand Down
Loading