[BACKEND] Extend support for small MMAv2 FP64: single 8x8x4 instructions. #10060
Conversation
Jokeren left a comment
- You can update the title to MMAv2 fp64 instead of generic MMA, which might be confusing.
- I guess the previous fp64 instruction size would yield a higher throughput? But since we don't care about fp64 performance, I think it's acceptable to make it more general without considering multiple instruction candidates. @lezcano what do you think?
For these tiny matrices, can't we ask the user to just pad the matrix with zeros? What are we getting from using a slightly smaller tile?
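(For reference, a minimal sketch of what such zero-padding could look like, assuming row-major fp64 inputs. The `MIN_*` tile minima and `padded_matmul_fp64` are illustrative placeholders, and `matmul_fp64` stands for a Triton fp64 matmul wrapper such as the one in the benchmark posted below.)

```python
import torch
import torch.nn.functional as F

def padded_matmul_fp64(a, b, MIN_M=16, MIN_N=16, MIN_K=16):
    # Zero-pad M, N, K up to the (illustrative) minimum tile multiples,
    # run the unchanged Triton kernel, then slice back to the true shape.
    M, K = a.shape
    K2, N = b.shape
    assert K == K2, "Incompatible dimensions"
    pad_m, pad_n, pad_k = (-M) % MIN_M, (-N) % MIN_N, (-K) % MIN_K
    # F.pad pads the last dimension first: (cols_left, cols_right, rows_top, rows_bottom).
    a_p = F.pad(a, (0, pad_k, 0, pad_m))   # (M + pad_m, K + pad_k)
    b_p = F.pad(b, (0, pad_n, 0, pad_k))   # (K + pad_k, N + pad_n)
    c_p = matmul_fp64(a_p, b_p)            # zero padding does not change the product
    return c_p[:M, :N]
```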
I was trying to find the code for the benchmark that generated the data presented in #7310, but I was not able to, so I had to write my own. Below are the results on an A100 GPU.
I really care about FP64. I tried to write this so that performance for larger matrices is unaffected, so I would be really grateful for feedback on anywhere I missed that.

matmul-performance-fp64:
M N K cuBLAS Triton
--------------------------------------------------------
256 256 256 2.730667 2.730667
384 384 384 5.820632 5.529600
512 512 512 8.456258 8.192000
640 640 640 10.666666 12.190476
768 768 768 12.639085 11.796479
896 896 896 12.323930 13.910179
1024 1024 1024 14.873419 13.797053
1152 1152 1152 15.471420 16.140454
1280 1280 1280 14.840579 15.170371
1408 1408 1408 16.129514 15.444124
1536 1536 1536 15.941189 15.254069
1664 1664 1664 15.983858 15.252393
1792 1792 1792 16.241943 15.438769
1920 1920 1920 16.330774 15.691260
2048 2048 2048 16.304388 16.100975
2176 2176 2176 16.294452 16.549053
2304 2304 2304 17.260024 17.026281
2432 2432 2432 16.526155 16.296093
2560 2560 2560 16.692817 16.864642
2688 2688 2688 16.673870 16.428349
2816 2816 2816 17.293500 17.043458
2944 2944 2944 17.032137 16.779809
3072 3072 3072 16.877230 16.619637
3200 3200 3200 16.797901 16.541743
3328 3328 3328 16.785101 16.523135
3456 3456 3456 17.431690 17.157175
3584 3584 3584 16.888691 16.617149
3712 3712 3712 16.986456 16.716423
3840 3840 3840 17.231536 16.840567
3968 3968 3968 17.266724 16.987879
4096 4096 4096 17.433137 17.152425

Here is the benchmark for future use:

"""
FP64 Matrix Multiplication Benchmark
=====================================
Benchmarks fp64 matrix-matrix multiplication performance of a Triton kernel
against cuBLAS (via torch.matmul), reproducing and verifying results from
https://github.com/triton-lang/triton/pull/7310.
"""
import torch
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
def get_fp64_autotune_config():
# Smaller blocks than fp16 because fp64 elements are 8 bytes,
# so shared memory limits force smaller tiles.
return [
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=3, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
num_stages=4, num_warps=4),
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
num_stages=3, num_warps=8),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
num_stages=3, num_warps=8),
]
@triton.autotune(
configs=get_fp64_autotune_config(),
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_fp64(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the fp64 matmul C = A x B."""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
tl.assume(pid_m >= 0)
tl.assume(pid_n >= 0)
tl.assume(stride_am > 0)
tl.assume(stride_ak > 0)
tl.assume(stride_bn > 0)
tl.assume(stride_bk > 0)
tl.assume(stride_cm > 0)
tl.assume(stride_cn > 0)
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# Accumulate in fp64.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float64)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def matmul_fp64(a, b):
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
assert a.dtype == torch.float64 and b.dtype == torch.float64, "Both inputs must be fp64"
M, K = a.shape
K, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=torch.float64)
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel_fp64[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
)
return c
# ── Unit test ──────────────────────────────────────────────────────────────
torch.manual_seed(0)
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float64)
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float64)
triton_output = matmul_fp64(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp64_inputs={triton_output}")
print(f"torch_output_with_fp64_inputs={torch_output}")
if torch.allclose(triton_output, torch_output, atol=1e-8, rtol=1e-5):
print("✅ Triton and Torch match")
else:
max_diff = (triton_output - torch_output).abs().max().item()
print(f"❌ Triton and Torch differ (max diff: {max_diff})")
# ── Benchmark ──────────────────────────────────────────────────────────────
ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS'
# Benchmark loop
sizes = [128 * i for i in range(2, 33)]
results = []
print(f"matmul-performance-fp64:")
print(f"{'M':>10} {'N':>10} {'K':>10} {ref_lib:>12} {'Triton':>12}")
print("-" * 56)
for size in sizes:
M = N = K = size
# Benchmark reference library
a = torch.randn((M, K), device=DEVICE, dtype=torch.float64)
b = torch.randn((K, N), device=DEVICE, dtype=torch.float64)
quantiles = [0.5, 0.2, 0.8]
cublas_ms, _, _ = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
cublas_tflops = 2 * M * N * K * 1e-12 / (cublas_ms * 1e-3)
# Benchmark Triton
triton_ms, _, _ = triton.testing.do_bench(lambda: matmul_fp64(a, b), quantiles=quantiles)
triton_tflops = 2 * M * N * K * 1e-12 / (triton_ms * 1e-3)
results.append({
'M': float(M),
'N': float(N),
'K': float(K),
ref_lib: cublas_tflops,
'Triton': triton_tflops,
})
print(f"{M:>10.0f} {N:>10.0f} {K:>10.0f} {cublas_tflops:>12.6f} {triton_tflops:>12.6f}") |
For context, I am working on Finite Elements. I want to evaluate
@lezcano
Do you have some numbers?
As far as I can see, even for your matrices this shape is too big, so you are going to have to pad them regardless? I'd assume that for these ops you are heavily bandwidth-bound or, depending on the kernel, kernel-launch bound, so I'd expect that padding a bit more would give you exactly the same perf.
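(As a rough back-of-envelope check of the bandwidth-bound argument, here is a sketch of the roofline crossover for square fp64 matmuls. The peak figures are approximate A100 80 GB numbers and are assumptions, not measurements; adjust them for the actual device.)

```python
# Roofline-style estimate for a square fp64 matmul C = A @ B of size n x n.
PEAK_FP64_TFLOPS = 19.5   # assumed fp64 tensor-core peak (A100-ish)
PEAK_BW_TBPS = 2.0        # assumed HBM bandwidth (A100 80 GB-ish)

def fp64_matmul_regime(n):
    flops = 2 * n ** 3                 # multiply-add count
    bytes_moved = 3 * n * n * 8        # read A and B, write C once, 8 B/elem
    t_compute = flops / (PEAK_FP64_TFLOPS * 1e12)
    t_memory = bytes_moved / (PEAK_BW_TBPS * 1e12)
    return "memory-bound" if t_memory > t_compute else "compute-bound"

for n in (8, 64, 128, 512, 4096):
    print(n, fp64_matmul_regime(n))
# Under these assumptions the crossover sits around n ≈ 120.
```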
Not necessarily; if I play with the layout, they might be just the right size. I will need to do 4x
Neither BW nor kernel launch is the limit here:
Can you compare the original instruction size and 8x8x4 on these shapes? If you do care about performance, I think the code shouldn't be optimized just for particular cases.
Do you mean to benchmark the small MMA? I have no idea how reliable those metrics are, but here are the numbers:

Or do you want me to compare performance before/after the changes? I can provide that too.

Before:

After:
Yeah, seems benign to me. @lezcano are you worried about introducing code complexity?
@mwichro can you fix the tests first?
fwiw, after discussing it, we think we'd accept the patch as the diff is small and it doesn't regress
Thanks! How should I proceed? As far as I can see, the test is failing because the output code is slightly different. Should I update the test expectations, or must I fix the code?
yep, ask codex (or claude, pick your poison :P) to edit the lit tests a bit so that they pass with the new codegen
All tests have passed. I squashed the commits into a single one. Status: I resolved merge conflicts, fixed the tests again, and made it a single commit.
All tests passed, commit cleaned up.
…ons. (triton-lang#10060)

The fp64 MMA path now operates at native `m8n8k4` granularity, supporting any shape that is a multiple of 8×8×4, including the minimal 8×8×4 case.

This is an extension of triton-lang#7310 (the implementation was based on that PR).

Tests passed on A100.

## Files Changed

### `lib/Dialect/TritonGPU/IR/Dialect.cpp`
- `getRepForOperand`: Changed `tileBitWidthK` from `2 * 256` to `1 * 256` for fp64 (K-tile = 4). Changed `tileSize[M]` from hardcoded `16` to `8` for fp64.

### `lib/Dialect/TritonGPU/Transforms/Utility.cpp`
- `mmaVersionToInstrShape`: Returns `instrShape[M] = 8` for fp64 (was always 16).

### `lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp`
- `nvidiaDotToLinearLayout`: Uses `instrShape` from the MMA encoding for tile shape computation. K tile multiplier is 4 (not 8) when `instrM == 8`.

### `third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp`
- `getMmaRetType`: fp64 returns `struct{f64, f64}` (2 elements) instead of `struct{f64, f64, f64, f64}` (4 elements).
- `callMmaAmpereFp64`: Extended: it can now also emit a single `m8n8k4` instruction per call (single retArgs(2), aArgs(1), bArgs(1), cArgs(2)).
- `numRegisters`: `{1, 1, 1}` for fp64 (was effectively `{2, 1, 2}`).
- `numMmaRets`: 2 for fp64 (was 4).
- `numCPackedElem`: 1 for fp64 (was incorrectly computed).
- fc indexing formula: Uses `numMmaRets * numCPackedElem` instead of hardcoded `4`.

### `third_party/nvidia/backend/compiler.py`
- `min_dot_size`: Added `elif lhs_bitwidth == 64: return (1, 1, 4)` to allow K=4 for fp64.

### `python/test/unit/language/test_core.py`
- Added small fp64 test cases: `(8,8,4)`, `(8,8,8)`, `(16,8,4)`, `(8,8,16)` with `num_warps=1` (a standalone sketch of such a dot appears after this description).

### `test/Conversion/tritongpu_to_llvm.mlir`
- Updated `f64_mma_cvt` test to use `instrShape = [8, 8]`, matching the new fp64 encoding.

----

# New contributor declaration
- [X] I am not making a trivial change, such as fixing a typo in a comment.
- [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [X] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [X] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [X] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
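For illustration, a minimal fp64 `tl.dot` at the new 8×8×4 minimum could be exercised as below. This is a hedged sketch, not the test added to `test_core.py`; the kernel and variable names are made up, and it assumes an Ampere-or-newer GPU with this patch applied.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def small_fp64_dot_kernel(a_ptr, b_ptr, c_ptr,
                          M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # A single program computes the whole tiny MxN tile; for (8, 8, 4) the dot
    # should map onto one m8n8k4 fp64 MMA instruction with this patch.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    c = tl.dot(a, b, out_dtype=tl.float64)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

M, N, K = 8, 8, 4
a = torch.randn((M, K), device="cuda", dtype=torch.float64)
b = torch.randn((K, N), device="cuda", dtype=torch.float64)
c = torch.empty((M, N), device="cuda", dtype=torch.float64)
small_fp64_dot_kernel[(1,)](a, b, c, M=M, N=N, K=K, num_warps=1)
torch.testing.assert_close(c, a @ b)
```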
… K=8 (#10234)

Enable `tl.dot` with TF32 precision on tiles with **N=8** and **K=8** (e.g. `wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32`) via the standard `tt.dot` → `AccelerateMatmul` path on sm90+.

Related to #10060 (comment)

I am trying Triton for Finite Elements, and it does wonders! The matrices used in those computations are usually quite small. With some management, it is possible to pack several operations into MMA cores, but the tile sizes implemented were too big.

I ran the lit tests, and they are passing, so I guess the resulting IR is the same. Added a test for the new functionality.

---

## Changes

### `lib/Analysis/Utility.cpp`

In `supportMMA` (version 3), relaxed the N-dimension divisibility check from `% 16` to `% 8`:

```cpp
- retShapePerCTA[rank - 1] % 16 == 0
+ retShapePerCTA[rank - 1] % 8 == 0
```

The WGMMAv3 op verifier already required only `N % 8 == 0`, and `mmaVersionToInstrShape` already listed `n=8` as valid. This was the sole gatekeeper preventing N=8 tiles from using WGMMA, causing a silent fallback to MMAv2.

### `third_party/nvidia/backend/compiler.py`

In `min_dot_size`, added an explicit case for 32-bit types (TF32/FP32):

```python
+ elif lhs_bitwidth == 32:
+     return (1, 1, 8)
```

The TF32 hardware instruction has K=8. The previous fallthrough to the `else` branch required `K >= 16`, blocking compilation of K=8 TF32 kernels.

### `python/test/unit/language/test_core.py`

Added `test_dot_wgmma_tf32_n8k8`, parametrized over `M ∈ {64, 128}`, verifying both numerical correctness and that the emitted PTX contains `wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32` (a standalone sketch of such a kernel appears after this description).

# New contributor declaration
- [X] I am not making a trivial change, such as fixing a typo in a comment.
- [X] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [X] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
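For illustration, the shape this unlocks could be exercised with a sketch like the one below. It is not the actual `test_dot_wgmma_tf32_n8k8`; the kernel name is made up, and it assumes an sm90+ GPU with this patch, feeding fp32 inputs to `tl.dot` with `input_precision="tf32"`.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def tf32_dot_n8k8_kernel(a_ptr, b_ptr, c_ptr,
                         M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # fp32 inputs with TF32 precision; with the relaxed N % 8 check this tile
    # can lower to a wgmma m64n8k8 instruction instead of falling back to MMAv2.
    c = tl.dot(a, b, input_precision="tf32")
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

M, N, K = 64, 8, 8
a = torch.randn((M, K), device="cuda", dtype=torch.float32)
b = torch.randn((K, N), device="cuda", dtype=torch.float32)
c = torch.empty((M, N), device="cuda", dtype=torch.float32)
tf32_dot_n8k8_kernel[(1,)](a, b, c, M=M, N=N, K=K)
torch.testing.assert_close(c, a @ b, atol=1e-2, rtol=1e-2)  # TF32 tolerance
```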