[ BACKEND ] Enable tl.dot with TF32 precision on tiles with N=8 and K=8 #10234

Merged: ThomasRaoux merged 1 commit into triton-lang:main from mwichro:wgmma.m64n8k8 on May 13, 2026

Conversation

@mwichro (Contributor) commented May 5, 2026

Enable tl.dot with TF32 precision on tiles with N=8 and K=8 (e.g. wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32) via the standard tt.dot → AccelerateMatmul path on sm90+.

Related to #10060 (comment)
I am trying Triton for Finite Elements, and it does wonders!
The matrices used in those computations are usually quite small. With some management it is possible to pack several operations onto the MMA cores, but the smallest tile sizes currently implemented were still too big.

I ran the lit tests and they are passing, so I guess the resulting IR is the same. Added a test for the new functionality.


Changes

lib/Analysis/Utility.cpp

In supportMMA (version 3), relaxed the N-dimension divisibility check from % 16 to % 8:

- retShapePerCTA[rank - 1] % 16 == 0
+ retShapePerCTA[rank - 1] % 8 == 0

The WGMMAv3 op verifier already required only N % 8 == 0, and mmaVersionToInstrShape already listed n=8 as valid. This was the sole gatekeeper preventing N=8 tiles from using WGMMA, causing a silent fallback to MMAv2.
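In Python terms, the relaxed gate amounts to the following (a minimal sketch with hypothetical names, mirroring the two-line diff above):

    # The last dim of the per-CTA result tile (N) now only needs to be a multiple of 8,
    # matching what the WGMMA op verifier and mmaVersionToInstrShape already accept.
    def n_dim_ok_for_wgmma(ret_shape_per_cta: list[int]) -> bool:
        return ret_shape_per_cta[-1] % 8 == 0  # previously required % 16 == 0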

third_party/nvidia/backend/compiler.py

In min_dot_size, added an explicit case for 32-bit types (TF32/FP32):

+ elif lhs_bitwidth == 32:
+     return (1, 1, 8)

The TF32 hardware instruction has K=8. Previously, 32-bit inputs fell through to the else branch, whose minimum K is 16, blocking compilation of K=8 TF32 kernels.
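For context, a sketch of how the branch chain in min_dot_size reads after this change. The 8-bit, 64-bit, and 32-bit returns are copied from the fragments quoted in this PR; the first condition and the else branch are inferred from the surrounding discussion, so treat this as a paraphrase rather than the exact file contents:

    from typing import Tuple
    from triton.backends.compiler import GPUTarget  # import path assumed, as in upstream Triton

    def min_dot_size(target: GPUTarget):

        def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:
            lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
            rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
            assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
            if lhs_bitwidth == 8:         # fp8/int8: native MMA tile has K=32
                return (1, 1, 32)
            elif lhs_bitwidth == 64:      # fp64: K=4
                return (1, 1, 4)
            elif lhs_bitwidth == 32:      # tf32/fp32 (new in this PR): K=8
                return (1, 1, 8)
            else:                         # 16-bit types: K=16
                return (1, 1, 16)

        return check_dot_compatibility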

python/test/unit/language/test_core.py

Added test_dot_wgmma_tf32_n8k8 parametrized over M ∈ {64, 128}, verifying both numerical correctness and that the emitted PTX contains wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32.
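A minimal standalone sketch of what such a check looks like (hypothetical kernel and helper names, not the exact test added in this PR; assumes an sm90+ GPU and that the compiled kernel handle exposes the generated PTX via its asm dict, as in the existing tests):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _dot_tf32_n8k8(a_ptr, b_ptr, c_ptr,
                       M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
        offs_m = tl.arange(0, M)
        offs_n = tl.arange(0, N)
        offs_k = tl.arange(0, K)
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        c = tl.dot(a, b, input_precision="tf32", out_dtype=tl.float32)
        tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

    def check_wgmma_n8k8(M=64, N=8, K=8, num_warps=4):
        a = torch.randn((M, K), device="cuda", dtype=torch.float32)
        b = torch.randn((K, N), device="cuda", dtype=torch.float32)
        c = torch.empty((M, N), device="cuda", dtype=torch.float32)
        # Launching a @triton.jit kernel returns the compiled kernel handle.
        h = _dot_tf32_n8k8[(1, )](a, b, c, M=M, N=N, K=K, num_warps=num_warps)
        # TF32 truncates the mantissa, so compare with a loose tolerance.
        torch.testing.assert_close(c, a @ b, atol=1e-2, rtol=1e-2)
        assert "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32" in h.asm["ptx"]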

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@mwichro mwichro requested review from Jokeren and ptillet as code owners May 5, 2026 19:44
Comment thread third_party/nvidia/backend/compiler.py

    return (1, 1, 32)
elif lhs_bitwidth == 64:
    return (1, 1, 4)
elif lhs_bitwidth == 32:
Contributor:

This is more general than sm90+. Likely you need more code updates to support a smaller dot shape across archs. I'll defer to @ThomasRaoux to determine if this is the right direction.

Collaborator:

Yes I think it would be good to update Blackwell as well the same way.

Contributor Author:

I don't have access to Blackwell GPUs to test it directly, but I've updated MMAv5.

Contributor:

What about sm80 with TF32, are you able to use (1,1,8) to pass all tests?

Contributor Author (@mwichro, May 5, 2026):

A100 is sm80, where MMAv3 is not available, so this should have no effect anyway.

I have an A100, so this one I can test:

python -m pytest python/test/unit/language/test_core.py -k "dot" -q --no-header
736 passed, 1882 skipped, 6728 deselected in 286.06s (0:04:46)

Looks fine to me.

Wait, A100 has some other MMAv2 instructions that could also be used; let me check that.
Checked, it's fine.

Contributor Author:

Ah, btw, on review, Claude is suggesting:

So min_dot_size is reinventing the same formula with a hardcoded if/elif chain, ignoring the target parameter it receives. The fix is to use the formula directly:

def min_dot_size(target: GPUTarget):

    def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:
        lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
        rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
        assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
        # For small M/N we can still use tensor cores with padding.
        # The minimum K is determined by the native MMA tile: 256 / bitwidth.
        return (1, 1, 256 // lhs_bitwidth)

    return check_dot_compatibility

This is identical in behaviour to the current code for all supported types, eliminates the special-casing we added for 32-bit, and directly mirrors the hardware formula used in mmaVersionToInstrShape. It also correctly handles any future type (e.g., fp8 at 8-bit is already covered).
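For what it's worth, the arithmetic checks out against the K values quoted in this thread (the 16-bit case is the else-branch minimum described in the PR summary):

    # 8-bit -> K=32, 16-bit -> K=16, 32-bit -> K=8, 64-bit -> K=4
    for bitwidth, min_k in [(8, 32), (16, 16), (32, 8), (64, 4)]:
        assert 256 // bitwidth == min_k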

Collaborator:

I think I like it better the way it is right now

Contributor:

python -m pytest python/test/unit/language/test_core.py -k "dot" -q --no-header
736 passed, 1882 skipped, 6728 deselected in 286.06s (0:04:46)

We didn't have very small dot tests previously, and your test skips sm80.

@mwichro mwichro changed the title [ BACKEND ] Enable tl.dot with TF32 precision on tiles with **N=8** and **K=8** via the standard tt.dotAccelerateMatmul path on sm90+. [ BACKEND ] Enable tl.dot with TF32 precision on tiles with N=8 and K=8 May 5, 2026
Comment thread python/test/unit/language/test_core.py Outdated


@pytest.mark.parametrize("M, num_warps", [(64, 4), (128, 8)])
def test_dot_wgmma_tf32_n8k8(M, num_warps, device):
Collaborator:

can you add this to the current test_dot instead? There is already get_test_small_dots_cases

Contributor Author (@mwichro, May 6, 2026):

I just tried that: it turns out K has to be constexpr for this to properly emit the wgmma instruction.
So the less invasive way is probably to just add another test.

Comment thread python/test/unit/language/test_core.py Outdated
if not is_cuda():
    pytest.skip("WGMMA is NVIDIA-only")
capability = torch.cuda.get_device_capability()
if capability[0] < 9:
Contributor:

I still think there's something missing. What if you run this test on Ampere by removing this constraint? My point is, the current code may not work on sm80 when the tile size is as small as 8x8x8. Before, we at least emitted an error on the frontend.

Collaborator:

why? mma.sync's K dim is 8 for tf32, so why wouldn't it work?

Collaborator:

that being said, we should definitely run the test on all targets and not do it for sm_90 only. See my comment here: #10234 (comment)

Contributor:

why? mma.sync's K dim is 8 for tf32, so why wouldn't it work?

Yeah, in theory it should work, but I'd like to confirm by enabling tests.

Collaborator:

I agree

Contributor Author:

Done, sm80 is covered.

Comment thread lib/Analysis/Utility.cpp
Comment thread python/test/unit/language/test_core.py Outdated
for M, nw in [(64, 4), (128, 8)]
for dtype, K, prec, sm80, sm90 in _dot_n8_cases],
)
def test_dot_n8(M, num_warps, dtype, K, input_precision, sm80_ptx, sm90_ptx, device):
Collaborator:

why can't we add cases to the existing test_dot?

Contributor Author:

K must be constexpr, otherwise Triton is not able to emit wgmma for K=8 (I tried). I think assuming that K is constexpr is quite reasonable for such a small MMA.

Collaborator:

what K are we talking about here? the block size is always constexpr

Contributor Author:

Line 3520 in python/test/unit/language/test_core.py:

    if (K > 16 or N > 16 or M > 16) and (M * N // (num_warps * 32) >= 4):
        # XXX: skip small sizes because they are not vectorized

With runtime strides, Triton can only prove contiguity along the inner dim when it is at least 16 elements wide, so K >= 16 is required to vectorize loads to v4.

Contributor:

Are you talking about failures in any of the assert operations? Feel free to update the assert conditions

Collaborator:

how is vectorization related to this?

Contributor Author:

I misunderstood how this test works. I updated the tests, and I moved the MMA Mx8x8 tests inside test_dot.

Comment thread python/test/unit/language/test_core.py Outdated

is_tcgen5 = (capability[0] == 10) and (num_warps % 4) == 0 and (M % 64) == 0 and (N % 8) == 0

n_pat = '8' if N == 8 else r'\d+'
Collaborator:

let's remove this. I don't think it helps coverage much, and it makes the test even more complex.

Comment thread python/test/unit/language/test_core.py Outdated
Comment on lines +3297 to +3298
(16, 'ieee', 'float16', 'float32'),
(32, 'ieee', 'float16', 'float32')]]
Collaborator:

are those cases changing at all?

Comment thread python/test/unit/language/test_core.py Outdated
Comment on lines +3290 to +3298
def get_test_dot_n8_cases():
    if not is_cuda():
        return []
    return [(M, 8, K, nw, False, False, 'none', prec, in_dtype, out_dtype, 1, None)
            for M, nw in [(64, 4), (128, 8)]
            for K, prec, in_dtype, out_dtype in [(8, 'tf32', 'float32', 'float32'),
                                                 (16, 'tf32', 'float32', 'float32'),
                                                 (16, 'ieee', 'float16', 'float32'),
                                                 (32, 'ieee', 'float16', 'float32')]]
Collaborator:

feels like you could just add a couple of well-chosen tests in get_test_small_dots_cases to cover the right corner cases

Commit by @mwichro: Enable tl.dot with TF32 precision on tiles with N=8 and K=8 (e.g. wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32) via the standard tt.dot → AccelerateMatmul path on sm90+.
@mwichro (Contributor Author) commented May 11, 2026

Thanks for your patience and approval!

The test failures do not look related to the changes:

RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

They seem to originate from torch...

@mwichro (Contributor Author) commented May 13, 2026

Test passed

@ThomasRaoux ThomasRaoux merged commit 3de9d04 into triton-lang:main May 13, 2026
23 of 27 checks passed