[Reland][NVIDIA] Support swizzle 0 TMA + MMA for Hopper and Blackwell #10148
I asked codex to put together the diff:
// Restrict this rewrite to an operand which already uses a shared-linear
// encoding. Backward propagation through tensor reshape/trans is not
// encoding-stable for NVMMAShared.
if (!isa<SharedLinearEncodingAttr>(operandTy.getEncoding()))
  return failure();
MMAs accept both shared-linear layouts and nvmma_shared layouts, so this change is really benign. Do you have any particular concerns?
I don't have an actual example that would break if an MMA operand with #nvmma_shared gets replaced by #shared_linear. I added this skip as an extra safety measure, since I wouldn't be surprised if some pass depends on having #nvmma_shared in an MMA operand.
I think this is also more in line with the original purpose of this rewrite: some earlier pass like AccelerateMatmul fixes the operand encoding to be #shared_linear when it identifies a need to do so, and this rewrite is supposed to propagate that encoding over a preceding chain of view ops. So it doesn't make sense for the operand encoding to differ before and after this pass.
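
To make the intent of the guard concrete, here is a minimal standalone sketch of the decision it encodes. The names EncodingKind, MmaOperand, and rewriteApplies are hypothetical stand-ins invented for this illustration and are not Triton APIs; the real check is the isa<SharedLinearEncodingAttr> test in the diff above.

```cpp
// Hypothetical stand-in for the two shared-memory encodings discussed above.
enum class EncodingKind { NvmmaShared, SharedLinear };

// Hypothetical model of an MMA operand: its current encoding and whether it
// is produced by a chain of view ops (reshape/trans).
struct MmaOperand {
  EncodingKind encoding;
  bool fedByViewOps;
};

// The rewrite only propagates an existing #shared_linear encoding upward
// through view ops. An #nvmma_shared operand is left untouched, so the
// operand encoding is the same before and after the pass.
bool rewriteApplies(const MmaOperand &operand) {
  if (operand.encoding != EncodingKind::SharedLinear)
    return false; // mirrors `return failure();` in the diff
  return operand.fedByViewOps;
}
```

Under this model, an #nvmma_shared operand fed by view ops is skipped, which is exactly the case the extra guard is meant to exclude.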
// This condition can fail if a layout is speculatively constructed for
// equivalence checking.
if (layout.getTotalOutDimSize() != product(maybeTransposedTmaShape))
  return failure();
Where can this happen exactly? It feels like quite a big issue.
A concrete test case that fails this condition is this one: https://github.com/masahi/triton/blob/58c3b956958f572e1f6bfa3ddbd865c9cac40763/test/TritonNvidiaGPU/optimize_descriptor_encoding.mlir#L180-L188
We call buildNvmmaSharedLinearLayout with shape [1, 16, 1, 16] and various candidate nvmma_shared encodings. For some candidates, it seems ensureLayoutNotSmallerThan can return a layout that covers more than the output elements of [1, 16, 1, 16].
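
As a rough illustration of what the size check rejects, the sketch below compares the element count of the requested shape with the total output size of a candidate layout. CandidateLayout and coversExactly are hypothetical stand-ins written for this example, and the padded candidate used in the usage note is invented; which dimension actually grows in the failing test is not stated here.

```cpp
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Number of elements in a tensor of the given shape.
int64_t product(const std::vector<int64_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}

// Hypothetical stand-in for a linear layout: only the sizes of its output
// dimensions matter for this check.
struct CandidateLayout {
  std::vector<int64_t> outDimSizes;
  int64_t getTotalOutDimSize() const { return product(outDimSizes); }
};

// True when the layout covers exactly as many elements as the shape. A layout
// that was padded up to a larger size fails this test.
bool coversExactly(const CandidateLayout &layout,
                   const std::vector<int64_t> &shape) {
  return layout.getTotalOutDimSize() == product(shape);
}
```

For shape [1, 16, 1, 16] the product is 256, so a candidate with output sizes {1, 16, 1, 16} passes, while an invented padded candidate such as {1, 16, 8, 16} (2048 elements) is rejected.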
I still get the feeling that there's a better place to catch this one than this late, but sure.
if (failed(layout))
  llvm::report_fatal_error("Illegal shared layout");
I think it is fine to keep that along with the emitError
I hope the current code after 58c3b95 has addressed this comment
…aSharedToLinearLayout
Also removed the |
lezcano left a comment
New diff for review: 3130b82...compare/pr10148-vs-pr9931
Compared to #9931:
When looking for an equivalent LL in the loop https://github.com/masahi/triton/blob/62b3a08c1b205522991148b4d0b6d761e0ecb369/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp#L69-L79, we now guard against an LL creation failure. Previously, I was using ttg::areLayoutsEquivalent(shape, sharedLinear, candidate), but this test can fail with a non-recoverable error when the shape and the nvmma_shared attributes are incompatible. I added tryNvmmaSharedToLinearLayout as a safe way to create an LL; layout equivalence is tested only if the creation succeeds. An alternative would be to decide whether an LL creation is guaranteed to succeed before using ttg::areLayoutsEquivalent. I didn't investigate the feasibility and completeness of this path deeply.
The new rewrite in OptimizeDotOperands always updates the operand encoding with #shared_linear whenever view operations are present. The premise of this rewrite is that it preserves the operand encoding and propagates the #shared_linear encoding upward. It should not fire when the operand encoding is #nvmma_shared, since in that case the operand encoding is replaced with an equivalent #shared_linear encoding. Although this replacement is benign in principle, I decided to limit the scope of this rewrite to those MMA ops whose operand encoding is already #shared_linear, since this was the original use case the rewrite was intended for. This change is unrelated to the regression, but I added it for additional safety.
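
The "try to build, then compare" pattern from the first point can be sketched in isolation as follows. This is a toy model, not the Triton implementation: Candidate, ToyLayout, tryBuildLayout, and findEquivalent are hypothetical names, the divisibility rule that makes construction fail is invented for illustration, and std::optional stands in for a fallible result type.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a candidate nvmma_shared encoding.
struct Candidate {
  int64_t tileElems; // toy parameter: required tile size in elements
  int64_t stride;    // toy parameter: offset step between elements
};

// Hypothetical stand-in for a linear layout: element index -> offset.
struct ToyLayout {
  std::vector<int64_t> offsets;
};

// Fallible construction: returns std::nullopt instead of aborting when the
// shape and the candidate are incompatible (toy rule: size must be a
// multiple of the tile size).
std::optional<ToyLayout> tryBuildLayout(int64_t numElems, const Candidate &c) {
  if (c.tileElems <= 0 || numElems % c.tileElems != 0)
    return std::nullopt;
  ToyLayout layout;
  for (int64_t i = 0; i < numElems; ++i)
    layout.offsets.push_back(i * c.stride);
  return layout;
}

// Returns the index of the first candidate whose layout equals the target.
// Candidates that cannot be built are skipped rather than treated as errors.
std::optional<std::size_t>
findEquivalent(const ToyLayout &target, int64_t numElems,
               const std::vector<Candidate> &candidates) {
  for (std::size_t i = 0; i < candidates.size(); ++i) {
    std::optional<ToyLayout> built = tryBuildLayout(numElems, candidates[i]);
    if (built && built->offsets == target.offsets)
      return i;
  }
  return std::nullopt;
}
```

The design point is that a failed construction is an ordinary outcome of the search loop, so the equivalence test only ever runs on layouts that were built successfully.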