[BACKEND] Extend support for small MMAv2 FP64: single 8x8x4 instructions.#10060

Merged
Jokeren merged 1 commit into triton-lang:main from mwichro:smallMMA_PR
Apr 20, 2026

Conversation

@mwichro
Contributor

@mwichro mwichro commented Apr 16, 2026

The fp64 MMA path now operates at native m8n8k4 granularity, supporting any shape that is a multiple of 8×8×4, including the minimal 8×8×4 case.

This is an extension of #7310
(The implementation was based on that PR)

Tests passed on A100.

Files Changed

lib/Dialect/TritonGPU/IR/Dialect.cpp

  • getRepForOperand: Changed tileBitWidthK from 2 * 256 to 1 * 256 for fp64 (K-tile = 4). Changed tileSize[M] from hardcoded 16 to 8 for fp64.

lib/Dialect/TritonGPU/Transforms/Utility.cpp

  • mmaVersionToInstrShape: Returns instrShape[M] = 8 for fp64 (was always 16).

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

  • nvidiaDotToLinearLayout: Uses instrShape from the MMA encoding for tile shape computation. K tile multiplier is 4 (not 8) when instrM == 8.

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv2.cpp

  • getMmaRetType: fp64 returns struct{f64, f64} (2 elements) instead of struct{f64, f64, f64, f64} (4 elements).
  • callMmaAmpereFp64: Extended so that it can also emit a single m8n8k4 instruction per call (retArgs(2), aArgs(1), bArgs(1), cArgs(2)).
  • numRegisters: {1, 1, 1} for fp64 (was effectively {2, 1, 2}).
  • numMmaRets: 2 for fp64 (was 4).
  • numCPackedElem: 1 for fp64 (was incorrectly computed).
  • fc indexing formula: Uses numMmaRets * numCPackedElem instead of hardcoded 4.

third_party/nvidia/backend/compiler.py

  • min_dot_size: Added elif lhs_bitwidth == 64: return (1, 1, 4) to allow K=4 for fp64.

python/test/unit/language/test_core.py

  • Added small fp64 test cases: (8,8,4), (8,8,8), (16,8,4), (8,8,16) with num_warps=1 (a usage sketch of these shapes follows the file list below).

test/Conversion/tritongpu_to_llvm.mlir

  • Updated f64_mma_cvt test to use instrShape = [8, 8] matching the new fp64 encoding.
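For illustration, here is a minimal sketch (mine, not code from this PR) of the kind of fp64 kernel that these changes allow to compile at the new minimum shape (8, 8, 4) with num_warps=1, mirroring the added test cases; the kernel and tensor names are hypothetical:

```python
# Minimal sketch (not from this PR): an fp64 tl.dot at the new minimum
# shape (M, N, K) = (8, 8, 4), which should now lower to a single mma m8n8k4.
import torch
import triton
import triton.language as tl


@triton.jit
def small_fp64_dot_kernel(a_ptr, b_ptr, c_ptr,
                          M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    # Row-major loads of the full (tiny) operands.
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    c = tl.dot(a, b, out_dtype=tl.float64)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)


a = torch.randn((8, 4), device="cuda", dtype=torch.float64)
b = torch.randn((4, 8), device="cuda", dtype=torch.float64)
c = torch.empty((8, 8), device="cuda", dtype=torch.float64)
small_fp64_dot_kernel[(1,)](a, b, c, M=8, N=8, K=4, num_warps=1)
torch.testing.assert_close(c, a @ b)
```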

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@mwichro mwichro requested review from lezcano and ptillet as code owners April 16, 2026 22:50
@ThomasRaoux ThomasRaoux requested a review from Jokeren April 16, 2026 22:54
Contributor

@Jokeren Jokeren left a comment


  1. You can update the title to MMAv2 fp64 instead of generic MMA, which might be confusing.
  2. I guess the previous fp64 instruction size would yield a higher throughput? But since we don't care about fp64 performance, I think it's acceptable to make it more general without considering multiple instruction candidates. @lezcano what do you think?

@lezcano
Contributor

lezcano commented Apr 17, 2026

For these tiny matrices, can't we ask the user to just pad the matrix with zeros? What are we getting from using a slightly lower tile?
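For concreteness, the suggestion amounts to something like the sketch below (illustrative only; MIN_M/MIN_N/MIN_K stand for whatever minimum shape the backend requires, not values taken from this PR):

```python
# Illustrative zero-padding workaround: grow the tiny operands to the
# backend's minimum dot shape, run the dot, then slice the valid part.
import torch

MIN_M, MIN_N, MIN_K = 16, 16, 16  # hypothetical minimum shape, for illustration

def pad_operands(a: torch.Tensor, b: torch.Tensor):
    M, K = a.shape
    _, N = b.shape
    a_pad = torch.zeros((max(M, MIN_M), max(K, MIN_K)), dtype=a.dtype, device=a.device)
    b_pad = torch.zeros((max(K, MIN_K), max(N, MIN_N)), dtype=b.dtype, device=b.device)
    a_pad[:M, :K] = a
    b_pad[:K, :N] = b
    # The caller would run the (padded) dot and keep only result[:M, :N].
    return a_pad, b_pad
```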

@mwichro
Contributor Author

mwichro commented Apr 17, 2026

I guess the previous fp64 instruction size would yield a higher throughput?

I was trying to find the benchmark code that generated the data presented in #7310 , but I was not able to find it, so I wrote my own. Below are the results on an A100 GPU.

But since we don't care about fp64 performance

I really care about FP64. I tried to write it so that performance for larger matrices will be unaffected, so I would be really grateful for feedback on where I missed that.

matmul-performance-fp64:
         M          N          K       cuBLAS       Triton
--------------------------------------------------------
       256        256        256     2.730667     2.730667
       384        384        384     5.820632     5.529600
       512        512        512     8.456258     8.192000
       640        640        640    10.666666    12.190476
       768        768        768    12.639085    11.796479
       896        896        896    12.323930    13.910179
      1024       1024       1024    14.873419    13.797053
      1152       1152       1152    15.471420    16.140454
      1280       1280       1280    14.840579    15.170371
      1408       1408       1408    16.129514    15.444124
      1536       1536       1536    15.941189    15.254069
      1664       1664       1664    15.983858    15.252393
      1792       1792       1792    16.241943    15.438769
      1920       1920       1920    16.330774    15.691260
      2048       2048       2048    16.304388    16.100975
      2176       2176       2176    16.294452    16.549053
      2304       2304       2304    17.260024    17.026281
      2432       2432       2432    16.526155    16.296093
      2560       2560       2560    16.692817    16.864642
      2688       2688       2688    16.673870    16.428349
      2816       2816       2816    17.293500    17.043458
      2944       2944       2944    17.032137    16.779809
      3072       3072       3072    16.877230    16.619637
      3200       3200       3200    16.797901    16.541743
      3328       3328       3328    16.785101    16.523135
      3456       3456       3456    17.431690    17.157175
      3584       3584       3584    16.888691    16.617149
      3712       3712       3712    16.986456    16.716423
      3840       3840       3840    17.231536    16.840567
      3968       3968       3968    17.266724    16.987879
      4096       4096       4096    17.433137    17.152425

Here is the benchmark for future use:

"""
FP64 Matrix Multiplication Benchmark
=====================================
Benchmarks fp64 matrix-matrix multiplication performance of a Triton kernel
against cuBLAS (via torch.matmul), reproducing and verifying results from
https://github.com/triton-lang/triton/pull/7310.
"""

import torch
import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def get_fp64_autotune_config():
    # Smaller blocks than fp16 because fp64 elements are 8 bytes,
    # so shared memory limits force smaller tiles.
    return [
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
                      num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=3, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
                      num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 8},
                      num_stages=3, num_warps=8),
    ]


@triton.autotune(
    configs=get_fp64_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_fp64(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing the fp64 matmul C = A x B."""
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)
    tl.assume(stride_am > 0)
    tl.assume(stride_ak > 0)
    tl.assume(stride_bn > 0)
    tl.assume(stride_bk > 0)
    tl.assume(stride_cm > 0)
    tl.assume(stride_cn > 0)

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Accumulate in fp64.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float64)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def matmul_fp64(a, b):
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert a.dtype == torch.float64 and b.dtype == torch.float64, "Both inputs must be fp64"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float64)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel_fp64[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c


# ── Unit test ──────────────────────────────────────────────────────────────
torch.manual_seed(0)
a = torch.randn((512, 512), device=DEVICE, dtype=torch.float64)
b = torch.randn((512, 512), device=DEVICE, dtype=torch.float64)
triton_output = matmul_fp64(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp64_inputs={triton_output}")
print(f"torch_output_with_fp64_inputs={torch_output}")

if torch.allclose(triton_output, torch_output, atol=1e-8, rtol=1e-5):
    print("✅ Triton and Torch match")
else:
    max_diff = (triton_output - torch_output).abs().max().item()
    print(f"❌ Triton and Torch differ (max diff: {max_diff})")

# ── Benchmark ──────────────────────────────────────────────────────────────

ref_lib = 'cuBLAS' if is_cuda() else 'rocBLAS'

# Benchmark loop
sizes = [128 * i for i in range(2, 33)]
results = []

print(f"matmul-performance-fp64:")
print(f"{'M':>10} {'N':>10} {'K':>10} {ref_lib:>12} {'Triton':>12}")
print("-" * 56)

for size in sizes:
    M = N = K = size

    # Benchmark reference library
    a = torch.randn((M, K), device=DEVICE, dtype=torch.float64)
    b = torch.randn((K, N), device=DEVICE, dtype=torch.float64)
    quantiles = [0.5, 0.2, 0.8]
    cublas_ms, _, _ = triton.testing.do_bench(lambda: torch.matmul(a, b), quantiles=quantiles)
    cublas_tflops = 2 * M * N * K * 1e-12 / (cublas_ms * 1e-3)

    # Benchmark Triton
    triton_ms, _, _ = triton.testing.do_bench(lambda: matmul_fp64(a, b), quantiles=quantiles)
    triton_tflops = 2 * M * N * K * 1e-12 / (triton_ms * 1e-3)

    results.append({
        'M': float(M),
        'N': float(N),
        'K': float(K),
        ref_lib: cublas_tflops,
        'Triton': triton_tflops,
    })

    print(f"{M:>10.0f} {N:>10.0f} {K:>10.0f} {cublas_tflops:>12.6f} {triton_tflops:>12.6f}")

@mwichro
Contributor Author

mwichro commented Apr 17, 2026

For context, I am working on Finite Elements. I want to evaluate
$$(I \otimes I \otimes Q) u$$
where the matrix $Q$ is $4 \times 4$, $I$ is the $4 \times 4$ identity, and $u$ is $4 \times 4 \times 4$. m8n8k4 is the smallest instruction I can use for that.
This is also why FP64 is really important for me.
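As an illustration of why this maps onto such small matmuls (my own sketch, not part of the PR), the contraction along the last axis is just a small matmul over the reshaped cell data:

```python
# Sketch: (I ⊗ I ⊗ Q) u applied to a 4x4x4 cell is equivalent to a small
# matmul of the reshaped cell data with Q^T (row-major vec convention).
import torch

Q = torch.randn(4, 4, dtype=torch.float64)
u = torch.randn(4, 4, 4, dtype=torch.float64)

eye = torch.eye(4, dtype=torch.float64)
full = torch.kron(torch.kron(eye, eye), Q) @ u.reshape(-1)   # Kronecker form
small = (u.reshape(16, 4) @ Q.T).reshape(-1)                 # small-matmul form

torch.testing.assert_close(full, small)
```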

For these tiny matrices, can't we ask the user to just pad the matrix with zeros? What are we getting from using a slightly lower tile?

@lezcano
Padding with zeroes will result in a severe performance hit, right? This is not an optimal solution, especially when avoiding it looks doable (?). Let me know what you think.

@mwichro mwichro changed the title [BACKEND] Extend support for small MMA: single 8x8x4 instructions. [BACKEND] Extend support for small MMAv2 FP64: single 8x8x4 instructions. Apr 17, 2026
@lezcano
Contributor

lezcano commented Apr 17, 2026

@lezcano
Padding with zeroes will result in a severe performance hit, right?

Do you have some numbers?

@lezcano
Contributor

lezcano commented Apr 17, 2026

As far as I can see, even for your matrices this shape is too big, so you are going to have to pad them regardless? I'd assume that for these ops you are heavily BW bound or, depending on the kernel, kernel-launch bound, so I'd expect that padding a bit more should give you exactly the same perf.

@mwichro
Contributor Author

mwichro commented Apr 17, 2026

As far as I can see, even for your matrices, this shape is too big, so you are going to have to pad them regardless?

Not necessarily; if I play with the layout, they might be just the right size. I will need to do 4× 8x16x4 MMAs (which were not supported) and discard half of one of those multiplications, so 87.5% of the results are valid.

I'd assume that for these ops you are heavily BW bound or, depending on the kernel, kernel-launch bound, so I'd expect that padding a bit more should give you exactly the same perf.

Neither BW nor kernel launch is the limit here:

  • $u$ of size $4 \times 4 \times 4$ is just for a single cell; there are millions of cells processed in the same kernel
  • They are not BW bound; that was just an example of a single contraction. For once-loaded cell data $u$ I perform multiple contractions of the form $A \otimes B \otimes C$ (the matrices $A, B, C$ are the same for all of the cells). There are some other operations on the cell, but the arithmetic intensity from those tensor contractions alone is around 5 (absolutely the simplest case) to 15 or more (typical case); a rough count is sketched after this list.
  • However, those kernels, implemented with CUDA cores, are usually shared-memory bound; using tensor cores helps a lot here.
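A rough, back-of-the-envelope count under my own simplifying assumptions (three axis passes per $A \otimes B \otimes C$ contraction, cell data loaded and the result stored exactly once):

```python
# Illustrative arithmetic-intensity estimate for one 4x4x4 fp64 cell
# (simplifying assumptions; not a measured number).
elems = 4 * 4 * 4                 # values per cell
bytes_moved = 2 * elems * 8       # load u once + store result once, fp64
flops_per_pass = 2 * 4 * elems    # one 4x4 matmul applied along one axis
n_passes = 3                      # one pass per tensor direction
intensity = n_passes * flops_per_pass / bytes_moved
print(intensity)                  # 1.5 for a single contraction; several
                                  # contractions per loaded cell push this
                                  # toward the 5-15 range quoted above
```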

@Jokeren
Contributor

Jokeren commented Apr 17, 2026

I really care about FP64. I tried to write it so that performance for larger matrices will be unaffected, so I would be really grateful for feedback on where I missed that.

Can you compare the original instruction size and 8x8x4 on these shapes? If you do care about performance, I think the code shouldn't be optimized just for particular cases

@mwichro
Contributor Author

mwichro commented Apr 17, 2026

Do you mean to benchmark small MMA? I have no idea how reliable those metrics are, but here are the numbers:

matmul-performance-fp64:
         M          N          K       cuBLAS       Triton
--------------------------------------------------------
         8          8          8     0.000125     0.000143
        16         16         16     0.001000     0.001143
        24         24         24     0.003375     0.003375
        32         32         32     0.007111     0.008000
        40         40         40     0.015625     0.015625
        48         48         48     0.024000     0.027000
        56         56         56     0.038111     0.038111
        64         64         64     0.042667     0.056889

Or do you want me to compare performance before/after changes? I can provide that too.
This is on slightly different hardware: an A100 40GB (the results above are on an A100 80GB). No idea why those are slightly faster.

Before:

✅ Triton and Torch match
matmul-performance-fp64:
         M          N          K       cuBLAS       Triton
--------------------------------------------------------
       256        256        256     2.730667     2.730667
       384        384        384     6.144000     5.820632
       512        512        512     8.738134     8.738134
       640        640        640    10.666666    12.800000
       768        768        768    12.822261    12.822261
       896        896        896    12.544000    14.946042
      1024       1024       1024    14.463117    14.266340

After:

matmul-performance-fp64:
         M          N          K       cuBLAS       Triton
--------------------------------------------------------
       256        256        256     2.978909     2.978909
       384        384        384     6.144000     5.820632
       512        512        512     8.738134     8.738134
       640        640        640    10.666666    13.128206
       768        768        768    13.010823    12.461070
       896        896        896    12.544000    14.788715
      1024       1024       1024    14.364054    14.074846

@Jokeren
Contributor

Jokeren commented Apr 17, 2026

Or do you want me to compare performance before/after changes? I can provide that too.

Yeah, seems benign to me. @lezcano are you worried about introducing code complexity?

@Jokeren
Contributor

Jokeren commented Apr 17, 2026

@mwichro can you fix the tests first?

@lezcano
Contributor

lezcano commented Apr 17, 2026

fwiw, after discussing it, we think we'd accept the patch as the diff is small and it doesn't regress

@mwichro
Contributor Author

mwichro commented Apr 17, 2026

Thanks!

How should I proceed? As far as I can see, the test is failing because the output code is slightly different.
Shall we accept that the output is different and update the check?

Or must I fix the code?

@lezcano
Contributor

lezcano commented Apr 17, 2026

yep, ask codex (or claude, pick your poison :P) to edit the lit tests a bit so that they pass with the new codegen

@mwichro
Contributor Author

mwichro commented Apr 17, 2026

All tests have passed. I squashed the commits into a single one.

Status: I resolved merge conflicts, fixed the tests again, and made it a single commit.

@mwichro
Contributor Author

mwichro commented Apr 20, 2026

All tests passed, commit cleaned up.

@Jokeren Jokeren merged commit 9e59ac0 into triton-lang:main Apr 20, 2026
9 checks passed
bingyizh233 pushed a commit to bingyizh233/triton that referenced this pull request Apr 20, 2026
…ons. (triton-lang#10060)

ThomasRaoux pushed a commit that referenced this pull request May 13, 2026
… K=8 (#10234)

Enable `tl.dot` with TF32 precision on tiles with **N=8** and **K=8**
(e.g. `wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32`) via the
standard `tt.dot` → `AccelerateMatmul` path on sm90+.

Related to
#10060 (comment)
I am trying Triton for Finite Elements, and it does wonders! 
The matrices used in those computations are usually quite small. With
some management, it is possible to pack several operations into MMA
cores, but the tile sizes implemented were too big.

I ran the lit tests, and they are passing, so I guess the resulting IR is the same. Added a test for the new functionality.

--- 

## Changes

### `lib/Analysis/Utility.cpp`
In `supportMMA` (version 3), relaxed the N-dimension divisibility check
from `% 16` to `% 8`:
```cpp
- retShapePerCTA[rank - 1] % 16 == 0
+ retShapePerCTA[rank - 1] % 8 == 0
```
The WGMMAv3 op verifier already required only `N % 8 == 0`, and
`mmaVersionToInstrShape` already listed `n=8` as valid. This was the
sole gatekeeper preventing N=8 tiles from using WGMMA, causing a silent
fallback to MMAv2.

### `third_party/nvidia/backend/compiler.py`
In `min_dot_size`, added an explicit case for 32-bit types (TF32/FP32):
```python
+ elif lhs_bitwidth == 32:
+     return (1, 1, 8)
```
The TF32 hardware instruction has K=8. The previous fallthrough to the
`else` branch returned `K >= 16`, blocking compilation of K=8 TF32
kernels.

### `python/test/unit/language/test_core.py`
Added `test_dot_wgmma_tf32_n8k8` parametrized over `M ∈ {64, 128}`,
verifying both numerical correctness and that the emitted PTX contains
`wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32`.


# New contributor declaration
- [X] I am not making a trivial change, such as fixing a typo in a
comment.

- [X] I have written a PR description following these
  [rules](https://cbea.ms/git-commit/#why-not-how).

- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.

- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.

- Select one of the following.
  - [X] I have not added any `lit` tests.
- [ ] The `lit` tests I have added follow these [best
practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices),
including the "tests should be minimal" section. (Usually running Python
code
    and using the instructions it generates is not minimal.)