[AMD] Guard MFMA store layout for small N dimensions#10305
Merged
antiagainst merged 1 commit intoMay 15, 2026
Conversation
There was a problem hiding this comment.
Pull request overview
Fixes an AMD-specific crash in chooseMfmaLikeStoreLayout where building a wide-store linear layout would index past the end of dimNBases when the N dimension is too small. The helper now bails out early (returning {}) so the existing fallback store path is used instead.
Changes:
- Compute
destIdxInBasesearlier and guard against N-dimension log2 sizes that are too small before constructing the swap layout. - Remove the now-redundant later declaration of
destIdxInBases. - Add a lit test for
mfma32with N=8 verifying the fallback store path is taken.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp | Early-exit guard when dimN log2 size is ≤ target basis bit index; relocates destIdxInBases definition. |
| test/TritonGPU/amd/amd-optimize-epilogue.mlir | Adds lit test ensuring 32x8 MFMA32 store uses the #mma-layout fallback rather than the wide-store linear layout. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
zhanglx13
approved these changes
May 15, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR fixes a bug where the wide-store epilogue path builds a derived linear layout by swapping N-dimension basis bits. For some valid shapes, the N dimension is too small for the target basis bit to exist, so layout construction indexed past the end of the basis vector list and the compiler would crash. This change makes the store-layout helper bail out in that case, allowing the existing fallback path to handle the store instead of crashing.
This issue can be reproduced on gfx950 with the following small program:
New contributor declaration
I am not making a trivial change, such as fixing a typo in a comment.
I have written a PR description following these
rules.
I have run
pre-commit run --from-ref origin/main --to-ref HEAD.Select one of the following.
/testforlittests/unittestfor C++ tests/python/testfor end-to-end testsFILL THIS IN.Select one of the following.
littests.littests I have added follow these best practices,including the "tests should be minimal" section. (Usually running Python code
and using the instructions it generates is not minimal.)