[NVIDIA] Support swizzle 0 TMA + MMA for Hopper and Blackwell#9931

Merged

masahi merged 58 commits into triton-lang:main from masahi:tma-mma-swizzle-0 on Apr 10, 2026

Conversation

masahi (Collaborator) commented Apr 6, 2026

A nvmma_shared layout with swizzle=0 represents a flat, contiguous layout. This is valid for TMA but it is never the correct layout for Hopper WGMMA and Blackwell tcgen05 MMA instructions, since operand SMEMs with swizzle=0 are required to be in the special layout ("core matrices" format). Being able to use swizzle=0 for MMA is useful in case other swizzling modes cannot be applied for some reason.

In principle, if the operand in gmem is already in the right format that MMA expects, using TMA with swizzle 0 and directly feeding the result into MMA can work. But currently I'm not aware of a way to express that in Triton. This PR adds a key rewrite pass that enables that. The idea is the same as how we enable blocked-scale load via TMA and tmem_copy. See the attached test case for a fully worked-out example.

  • Users must prepare the operand in gmem in the correct, swizzle-0 core-matrices format. Conceptually, an operand would have a shape like (num_blocks_m, num_blocks_k, num_cm_m, num_cm_k, 16, 8) for fp8. A NumPy sketch of this packing follows the list.
  • In a kernel, use tt.trans and tt.reshape to "undo" the core-matrices transformation. The op sequence looks like desc_load<swizzle=0> -> tt.reshape / tt.trans -> local_alloc -> mma
  • The new rewrite pass bubbles up local_alloc so that it immediately follows desc_load. It also lifts those reshape / trans into transformations on memdesc. The op sequence now looks like desc_load<swizzle=0> -> local_alloc<swizzle=0> -> memdesc reshape / trans -> mma<swizzle=0>.
  • Thanks to the LinearLayout inference machinery on memdesc reshape / trans, the MMA now gets an operand smem with a special #shared_linear layout. If the tt.trans and tt.reshape transformations the user specified are correct, the operand linear layout is correctly identified as representing the swizzle-0 core-matrices format and lowered to the corresponding SMEM descriptor.
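
To make the gmem preparation step concrete, below is a minimal NumPy sketch of the packing round trip. The `BLOCK_M` / `BLOCK_K` tile sizes are made-up assumptions; only the `(16, 8)` innermost core-matrix dims come from the fp8 shape above, and the exact axis order is illustrative - the attached test case remains the authoritative reference.

```python
import numpy as np

# Illustrative tile sizes; these are assumptions, not values from this PR.
BLOCK_M, BLOCK_K = 64, 32   # one MMA block of the operand
CM_M, CM_K = 16, 8          # innermost core-matrix dims from the fp8 shape above

def pack_core_matrices(a):
    """Re-lay-out an (M, K) operand so that each (CM_M, CM_K) core matrix is
    contiguous, yielding the conceptual shape
    (num_blocks_m, num_blocks_k, num_cm_m, num_cm_k, CM_M, CM_K)."""
    M, K = a.shape
    nbm, nbk = M // BLOCK_M, K // BLOCK_K
    t = a.reshape(nbm, BLOCK_M // CM_M, CM_M, nbk, BLOCK_K // CM_K, CM_K)
    return np.ascontiguousarray(t.transpose(0, 3, 1, 4, 2, 5))

def unpack_core_matrices(p):
    """Inverse of pack_core_matrices; this is the transformation the
    in-kernel tt.reshape / tt.trans sequence conceptually performs."""
    nbm, nbk, ncm, nck, cm, ck = p.shape
    t = p.transpose(0, 2, 4, 1, 3, 5)
    return t.reshape(nbm * ncm * cm, nbk * nck * ck)

# Round trip on stand-in data (int32 here purely so the check is exact).
a = np.arange(128 * 64, dtype=np.int32).reshape(128, 64)
assert np.array_equal(unpack_core_matrices(pack_core_matrices(a)), a)
```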

lezcano (Contributor) commented Apr 9, 2026

Wait, why did you revert to the previous code? Generalising the code I proposed to just push back the layout is trivial and much cleaner than the current implementation. You just need to accumulate the layout while traversing back through the views, then use it rather than the descriptor layout. I removed that part only because it was not necessary in that pass.

masahi (Collaborator, Author) commented Apr 9, 2026

@lezcano Restored the code structure you proposed.

masahi requested review from ThomasRaoux and lezcano and removed the request for peterbell10 on April 10, 2026 05:56
ThomasRaoux (Collaborator) left a comment

LGTM but let's wait for @lezcano to finish reviewing

Two outdated comment threads on lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
lezcano (Contributor) left a comment

The current getCompatibleSharedEncoding will only work if the ops we are doing this on are memdesc_{trans,reshape}, because those are the ones that take equivalent layouts. Are there any other cases where this could just break?

masahi (Collaborator, Author) commented Apr 10, 2026

> The current getCompatibleSharedEncoding will only work if the ops we are doing this on are memdesc_{trans,reshape}, because those are the ones that take equivalent layouts. Are there any other cases where this could just break?

Although layout equivalence does not mean that two layouts are interchangeable, I cannot come up with an example where doing so in the context of this PR would lead to functional or performance bugs. This is because the premise of this rewrite is that when there is an immediate local_alloc or local_store user, it must have the #shared_linear layout attached: https://github.com/masahi/triton/blob/3130b82d07e4d1e6bf69759ff599d9fb901cf90d/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp#L40-L42. This is a very strong premise - I don't see how choosing an equivalent nvmma_shared encoding for the producing descriptor_load could be incorrect.

For example, taking this contrived lit test,
https://github.com/masahi/triton/blob/3130b82d07e4d1e6bf69759ff599d9fb901cf90d/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir#L164-L165

since replaceUsesWithLocalLoad tests memdesc equivalence via exact encoding equality, TMALowering would inject a local_load / local_alloc round trip even in the absence of the new change. So using the #shared_linear encoding attached to local_alloc to decide the descriptor encoding doesn't hurt anything.

That said, I'll try limiting the scope of the new behavior of getCompatibleSharedEncoding to ops with the tt.desired_encoding attribute - I'll set this attribute during ODO to signal the cases where propagating the local_alloc encoding to descriptor_load is obviously safe.

lezcano (Contributor) left a comment

I think it's alright as a patch.

In a perfect world, we'd generalise the lowering to take any LinearLayout that TMA can lower, but that's a completely different PR, and a trickier one.

Comment on lines +259 to +263
if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
return getCompatibleSharedEncoding(tensorTy.getEncoding(),
tensorTy.getShape(),
tensorTy.getElementType());
}
lezcano (Contributor) commented:

Is this necessary after #9851?

masahi (Collaborator, Author) replied:

I think so; this is the type of the user (local_alloc), not a tensordesc.

masahi (Collaborator, Author) commented Apr 10, 2026

@lezcano thanks. I also pushed, to a separate branch, an alternative solution that carries the "desired encoding" from ODO to ODE, which eliminates the need for the layout-equivalence-based propagation: masahi@dc1af07

This is obviously safe, but it looks a bit ad hoc and introduces a weird coupling between ODO and ODE. I'm proceeding with the current, general solution, but let me know if you prefer the alternative.

masahi merged commit d2d63f2 into triton-lang:main on Apr 10, 2026
9 checks passed
masahi deleted the tma-mma-swizzle-0 branch on Apr 10, 2026 09:01
plognjen pushed a commit to plognjen/triton that referenced this pull request Apr 14, 2026
ThomasRaoux added a commit that referenced this pull request Apr 14, 2026
raymondtay pushed a commit to raymondtay/triton that referenced this pull request Apr 18, 2026
masahi added a commit that referenced this pull request Apr 30, 2026
…#10148)

Compared to #9931:
* When looking for an equivalent LL in the loop
https://github.com/masahi/triton/blob/62b3a08c1b205522991148b4d0b6d761e0ecb369/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp#L69-L79,
we are now guarding against an LL creation failure. Previously, I was
using `ttg::areLayoutsEquivalent(shape, sharedLinear, candidate)`, but
this test can fail with a non-recoverable error due to an incompatible
shape and `nvmma_shared` attributes. I added
`tryNvmmaSharedToLinearLayout` as a safe way to create an LL and test
layout equivalence if the former succeeds. An alternative would be to
decide if an LL creation is guaranteed to succeed before using
`ttg::areLayoutsEquivalent`. I didn't investigate the feasibility and
the completeness of this path deeply

* The new rewrite in `OptimizeDotOperands` always updates the operand
encoding with `#shared_linear` whenever view operations are present. The
premise of this rewrite is that it preserves the operand encoding and
only propagates a `#shared_linear` encoding upward, so it should not
fire when the operand encoding is `#nvmma_shared`; in that case the
encoding ends up replaced with an equivalent `#shared_linear` one.
Although this rewrite is benign in principle, I decided to limit its
scope to MMA ops whose operand encoding is already `#shared_linear`,
since this was the original use case the rewrite was intended for. This
change is unrelated to the regression, but I added it for additional
safety.