Skip to content

[AMD] Guard MFMA store layout for small N dimensions#10305

Merged
antiagainst merged 1 commit into
triton-lang:mainfrom
justinrosner:justinr-guard-mfma-store-layout
May 15, 2026
Merged

[AMD] Guard MFMA store layout for small N dimensions#10305
antiagainst merged 1 commit into
triton-lang:mainfrom
justinrosner:justinr-guard-mfma-store-layout

Conversation

@justinrosner
Copy link
Copy Markdown
Contributor

This PR fixes a bug where the wide-store epilogue path builds a derived linear layout by swapping N-dimension basis bits. For some valid shapes, the N dimension is too small for the target basis bit to exist, so layout construction indexed past the end of the basis vector list and the compiler would crash. This change makes the store-layout helper bail out in that case, allowing the existing fallback path to handle the store instead of crashing.

This issue can be reproduced on gfx950 with the following small program:

@triton.jit
def repro(A, B, C):
    m = tl.arange(0, 128)
    k = tl.arange(0, 16)
    n = tl.arange(0, 8)
    a = tl.load(A + m[:, None] * 16 + k[None, :])
    b = tl.load(B + k[:, None] * 8 + n[None, :])
    acc = tl.dot(a, b, out_dtype=tl.float32)
    out = acc.to(tl.bfloat16)
    tl.store(C + m[:, None] * 8 + n[None, :], out)

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@justinrosner justinrosner marked this pull request as ready for review May 14, 2026 19:34
@justinrosner justinrosner requested a review from lezcano as a code owner May 14, 2026 19:34
Copilot AI review requested due to automatic review settings May 14, 2026 19:34
@justinrosner justinrosner requested a review from ptillet as a code owner May 14, 2026 19:34
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes an AMD-specific crash in chooseMfmaLikeStoreLayout where building a wide-store linear layout would index past the end of dimNBases when the N dimension is too small. The helper now bails out early (returning {}) so the existing fallback store path is used instead.

Changes:

  • Compute destIdxInBases earlier and guard against N-dimension log2 sizes that are too small before constructing the swap layout.
  • Remove the now-redundant later declaration of destIdxInBases.
  • Add a lit test for mfma32 with N=8 verifying the fallback store path is taken.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

File Description
lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp Early-exit guard when dimN log2 size is ≤ target basis bit index; relocates destIdxInBases definition.
test/TritonGPU/amd/amd-optimize-epilogue.mlir Adds lit test ensuring 32x8 MFMA32 store uses the #mma-layout fallback rather than the wide-store linear layout.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@antiagainst antiagainst merged commit 7de28f0 into triton-lang:main May 15, 2026
12 of 13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants