[NVIDIA] Support swizzle 0 TMA + MMA for Hopper and Blackwell #9931
Conversation
wait, why did you revert to the previous code? Generalising the code I proposed to just push back the layout is trivial and much cleaner than the current implementation. You just need to accumulate the layout when traversing back the views and then use it rather than the descriptor layout. I removed that because it was not necessary in that pass.
@lezcano Restored to the code structure you proposed.
ThomasRaoux
left a comment
LGTM but let's wait for @lezcano to finish reviewing
lezcano
left a comment
The current getCompatibleSharedEncoding will only work if the ops we are doing this on are memdesc_{trans,reshape}, because those are the ones that take equivalent layouts. Are there any other cases where this could just break?
Although layout equivalence does not mean that two layouts are interchangeable, I cannot come up with an example where doing so in the context of this PR would lead to functional or performance bugs. This is because the premise of this rewrite is that when there is an immediate …

For example, taking this contrived lit test, since …

That said, I'll try limiting the scope of the new behavior of …
lezcano
left a comment
I think it's alright as a patch.
In a perfect world, we'd generalise the lowering to take any LinearLayout that TMA can lower into, but that's a completely different PR, and a trickier one.
```c++
if (auto tensorTy = dyn_cast<RankedTensorType>(type)) {
  return getCompatibleSharedEncoding(tensorTy.getEncoding(),
                                     tensorTy.getShape(),
                                     tensorTy.getElementType());
}
```
I think so; this is the type of the user (local_alloc), not a tensordesc.
@lezcano thanks, I also pushed, to a separate branch, an alternative solution which carries the "desired encoding" from ODO to ODE and eliminates the need for the layout-equivalence-based propagation: masahi@dc1af07. This is obviously safe, but it looks a bit ad hoc and introduces a weird coupling between ODO and ODE. I'm proceeding with the current, general solution, but let me know if you prefer the alternative.
…-lang#9931)

A `nvmma_shared` layout with swizzle=0 represents a flat, contiguous layout. This is valid for TMA, but it is never the correct layout for Hopper WGMMA and Blackwell tcgen05 MMA instructions, since operand SMEMs with swizzle=0 are required to be in a special layout (the "core matrices" format). Being able to use swizzle=0 for MMA is useful in case other swizzling modes cannot be applied for some reason.

In principle, if the operand in gmem is already in the right format that MMA expects, using TMA with swizzle 0 and directly feeding the result into MMA can work. But currently I'm not aware of a way to express that in Triton. This PR adds a key rewrite pass that enables that. The idea is the same as how we enable blocked-scale load via TMA and `tmem_copy`. See the attached test case for a fully worked-out example.

* Users must prepare the operand in gmem in the correct, swizzle-0 core-matrices format. Conceptually, an operand would have a shape like `(num_blocks_m, num_blocks_k, num_cm_m, num_cm_k, 16, 8)` for fp8 (see the sketch after this list).
* In a kernel, use `tt.trans` and `tt.reshape` to "undo" the core-matrices transformation. The op sequence looks like `desc_load<swizzle=0> -> tt.reshape / tt.trans -> local_alloc -> mma`.
* The new rewrite pass bubbles up `local_alloc` so that it immediately follows `desc_load`. It also lifts those reshape / trans ops into transformations on the memdesc. The op sequence now looks like `desc_load<swizzle=0> -> local_alloc<swizzle=0> -> memdesc reshape / trans -> mma<swizzle=0>`.
* Thanks to the LinearLayout inference machinery on memdesc reshape / trans, the MMA now gets an operand smem with a special `#shared_linear` layout. If the `tt.trans` and `tt.reshape` transformations the user specified are correct, the operand linear layout is correctly identified as representing the swizzle-0 core-matrices format and lowered to the corresponding SMEM descriptor.
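To make the first bullet concrete, here is a minimal host-side sketch (not the PR's lit test) of how an operand could be rearranged into the conceptual `(num_blocks_m, num_blocks_k, num_cm_m, num_cm_k, 16, 8)` shape with plain reshapes and permutes. The block sizes, the exact axis ordering, and the `to_core_matrices` helper are assumptions made for illustration; only the trailing `(16, 8)` core-matrix tile comes from the shape quoted above.

```python
import torch

# Hypothetical tile sizes for illustration; only the trailing (16, 8)
# core-matrix tile comes from the shape quoted in the PR description.
CM_M, CM_K = 16, 8
BLOCK_M, BLOCK_K = 128, 64


def to_core_matrices(a: torch.Tensor) -> torch.Tensor:
    """Rearrange a row-major (M, K) operand into
    (num_blocks_m, num_blocks_k, num_cm_m, num_cm_k, CM_M, CM_K)."""
    M, K = a.shape
    # Split M into (num_blocks_m, num_cm_m, CM_M) and K into
    # (num_blocks_k, num_cm_k, CM_K); reshape preserves row-major order.
    a = a.reshape(M // BLOCK_M, BLOCK_M // CM_M, CM_M,
                  K // BLOCK_K, BLOCK_K // CM_K, CM_K)
    # Interleave the M and K factors into the core-matrices order. Inside the
    # kernel, tt.reshape / tt.trans undo exactly this permutation before
    # local_alloc, per the op sequence described above.
    return a.permute(0, 3, 1, 4, 2, 5).contiguous()


# Example: a (256, 128) operand becomes (2, 2, 8, 8, 16, 8).
a = torch.arange(256 * 128, dtype=torch.float32).reshape(256, 128)
print(to_core_matrices(a).shape)  # torch.Size([2, 2, 8, 8, 16, 8])
```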
triton-lang#10032)

Reverts triton-lang#9931 as it is causing functional regressions.
…#10148)

Compared to #9931:

* When looking for an equivalent LL in the loop https://github.com/masahi/triton/blob/62b3a08c1b205522991148b4d0b6d761e0ecb369/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp#L69-L79, we now guard against an LL creation failure. Previously, I was using `ttg::areLayoutsEquivalent(shape, sharedLinear, candidate)`, but this test can fail with a non-recoverable error due to an incompatible shape and `nvmma_shared` attributes. I added `tryNvmmaSharedToLinearLayout` as a safe way to create an LL and test layout equivalence only if the former succeeds. An alternative would be to decide whether an LL creation is guaranteed to succeed before using `ttg::areLayoutsEquivalent`; I didn't investigate the feasibility and completeness of this path deeply.
* The new rewrite in `OptimizeDotOperands` always updates the operand encoding with `#shared_linear` whenever view operations are present. The premise of this rewrite is supposed to be that it preserves the operand encoding and propagates the `#shared_linear` encoding upward. This rewriting should not fire when the operand encoding is `#nvmma_shared`, in which case it is replaced with an equivalent `#shared_linear` encoding. Although this rewrite is benign in principle, I decided to keep the scope of this rewrite to those MMA ops whose operand encoding is already `#shared_linear`, since this was the original use case the rewriting was intended for. This change is unrelated to the regression, but I added it for additional safety.