Skip to content

[CK] Fix RDNA3 FMHA tile-load paths#7016

Merged
illsilin merged 12 commits into
developfrom
users/jam/gfx12-buffer-load-fallback
May 19, 2026
Merged

[CK] Fix RDNA3 FMHA tile-load paths#7016
illsilin merged 12 commits into
developfrom
users/jam/gfx12-buffer-load-fallback

Conversation

@jammm
Copy link
Copy Markdown
Contributor

@jammm jammm commented May 2, 2026

Summary

Fix CK tile FMHA paths needed for RDNA3/RDNA4 targets.

Details

This PR addresses RDNA-specific issues hit while enabling xFormers CK FMHA on gfx11/gfx12:

  • On RDNA3, update FMHA P tile handling so the layout consumed by the second GEMM matches the WMMA path.

Testing

Validated downstream with xFormers CK/FMHA on gfx1201/gfx1151.

pytest --import-mode=importlib -q \
  tests/test_mem_eff_attention.py::test_forward \
  tests/test_mem_eff_attention.py::test_backward \
  tests/test_mem_eff_attention.py::test_dropout_ck

3844 passed, 5244 skipped, 26 warnings

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 2, 2026

Thanks @jammm for your help and guidance in enabling xformers CK on RDNA/Windows!

Comment thread projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing.hpp Outdated
Comment thread projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing.hpp Outdated
Comment thread projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing.hpp Outdated
Comment thread projects/composablekernel/include/ck_tile/core/arch/amd_buffer_addressing.hpp Outdated
@jammm jammm force-pushed the users/jam/gfx12-buffer-load-fallback branch from acb80cd to 9e55b0f Compare May 5, 2026 06:38
@jammm jammm requested a review from poyenc May 5, 2026 06:41
@hyoon1
Copy link
Copy Markdown
Contributor

hyoon1 commented May 5, 2026

I’m wondering if this fallback is actually safe.

My understanding is that on gfx12, we should not be calling async buffer load from the upper layer in the first place. If async buffer load is not supported or not intended to be used on gfx12, then silently falling back here may hide an incorrect/suboptimal code path.

Wouldn’t the proper fix be to update the kernel/code path that currently emits async buffer load, so that it calls the regular buffer load directly on gfx12 instead?

@jammm
Copy link
Copy Markdown
Contributor Author

jammm commented May 5, 2026

Wouldn’t the proper fix be to update the kernel/code path that currently emits async buffer load, so that it calls the regular buffer load directly on gfx12 instead?

Tried to fix this with the following. PTAL:
a13ca21

@0xDELUXA
Copy link
Copy Markdown

0xDELUXA commented May 5, 2026

Wouldn’t the proper fix be to update the kernel/code path that currently emits async buffer load, so that it calls the regular buffer load directly on gfx12 instead?

Tried to fix this with the following. PTAL: a13ca21

I might be mistaken, but does this mean it’s safe to fall back to gfx103 and gfx11, and only gfx12 needs this change?

@jammm
Copy link
Copy Markdown
Contributor Author

jammm commented May 5, 2026 via email

gfx12 falls back from async global-to-LDS loads to sync VGPR loads plus LDS stores. The async raw API relies on buffer OOB behavior instead of tensor-coordinate validity, so keep the sync fallback aligned with that raw-load contract.
@jammm jammm changed the title Fix gfx12 async buffer load fallback [CK] Fix RDNA3/RDNA4 FMHA tile-load paths May 6, 2026
@jammm
Copy link
Copy Markdown
Contributor Author

jammm commented May 6, 2026

This PR now has CK fixes relevant to xformers for both RDNA3/4.

Copy link
Copy Markdown
Contributor

@poyenc poyenc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The gfx11 P-tile remap (PermuteWarpGemmCToA) and GetSmemKPackV fix are correct — WMMA lane layout differences between GEMM C and A tiles require this permutation for the P×V matmul, and the K-pack needs to match kKPerThread for WMMA. No concerns with those changes.

However, I have concerns about the gfx12/gfx103 synchronous fallbacks added to the core tile infrastructure (load_tile.hpp, tile_window.hpp, tile_window_linear.hpp, amd_buffer_addressing.hpp).

These fallbacks are dead code for all currently-dispatched paths. gfx12 FMHA only dispatches "qr" / "qr_hpad" (fully synchronous, never calls async_load_tile*). gfx11 FMHA similarly dispatches "qr" only. No async pipeline is dispatched on either architecture today.

The deeper problem is that these fallbacks sit in core tile infrastructure shared by GEMM, FlatMM, Fused MoE, Sparse Attention, and FMHA. Any future code that accidentally instantiates an async pipeline on gfx12 will silently compile and run correctly — but with all load/compute overlap removed. Without the fallbacks, the same mistake would produce a compile-time static_assert or runtime illegal instruction — immediate, obvious failure. Silent performance degradation is much harder to catch than a crash or compile error.

Also, __gfx12__ covers gfx1250 (MI450), which has dedicated TENSOR_LOAD_TO_LDS hardware. A blanket gfx12 fallback would prevent future async pipeline work on MI450 from using the correct instruction, forcing everything through the synchronous path instead.

Suggestion: Remove the gfx12 fallbacks from core infrastructure and let unsupported paths fail loudly at compile time. If a specific pipeline needs gfx12 support, the fallback should live in that pipeline — not in the shared tile load layer where it silently affects everything.

Drop the shared RDNA/gfx12 synchronous fallbacks from the core tile-load path so unsupported async pipelines continue to fail loudly instead of silently losing overlap. Keep the gfx11 FMHA-specific WMMA layout and K-pack fixes in the pipeline layer.
@jammm
Copy link
Copy Markdown
Contributor Author

jammm commented May 11, 2026

@poyenc removed the fallbacks. PTAL ^^

@poyenc poyenc requested a review from a team May 12, 2026 03:08
@hyoon1
Copy link
Copy Markdown
Contributor

hyoon1 commented May 12, 2026

Async pipelines and the TRLoad/QS pipelines are neither used nor validated on gfx11 or gfx12, so we should avoid going down those paths. If execution somehow reaches them, there’s a risk that things could fail silently without any obvious warning or error. I’m not sure whether whole_k_prefetch has been properly validated either, but it seems like we should keep the code changes minimal and only touch the parts that are truly necessary.

@jammm
Copy link
Copy Markdown
Contributor Author

jammm commented May 12, 2026

Async pipelines and the TRLoad/QS pipelines are neither used nor validated on gfx11 or gfx12, so we should avoid going down those paths. If execution somehow reaches them, there’s a risk that things could fail silently without any obvious warning or error. I’m not sure whether whole_k_prefetch has been properly validated either, but it seems like we should keep the code changes minimal and only touch the parts that are truly necessary.

I've only touched the parts necessary to get xformers running. Tests are passing there. Is there anything you need on this PR to be done before we can merge merge?

@hyoon1
Copy link
Copy Markdown
Contributor

hyoon1 commented May 12, 2026

Async pipelines and the TRLoad/QS pipelines are neither used nor validated on gfx11 or gfx12, so we should avoid going down those paths. If execution somehow reaches them, there’s a risk that things could fail silently without any obvious warning or error. I’m not sure whether whole_k_prefetch has been properly validated either, but it seems like we should keep the code changes minimal and only touch the parts that are truly necessary.

I've only touched the parts necessary to get xformers running. Tests are passing there. Is there anything you need on this PR to be done before we can merge merge?

The async, trload, qs pipeline files currently contain gfx11-related code, but these pipelines are not used on RDNA. If this unverified path is ever taken, it could lead to potential issues. I'm also not sure whether the performance of these changes has actually been validated, or whether the modified code paths are truly covered by existing tests.

At a minimum, we should make sure RDNA cannot enter these pipelines. To prevent this more reliably, I think it would be better to remove the RDNA-related code from these pipeline files altogether.

I’m not very familiar with xFormers, but if we want to minimize the scope, it seems that only the patches related to qr_ks_vs_whole_k_prefetch should be necessary.

@jammm
Copy link
Copy Markdown
Contributor Author

jammm commented May 12, 2026

@hyoon1 makes sense. Neither gfx11 nor gfx12 supports async load/store, so the app calling CK (xformers in this case) shouldn't be going through the async pipeline. As for gfx12, it does have support for some transpose load instructions, but that's a separate topic.

I've pushed a commit that removes the gfx11 changes in the async/tr pipelines. The xformers PR has been modified to not use async pipelines for gfx11/12 ROCm/xformers@34064c9

@jammm jammm force-pushed the users/jam/gfx12-buffer-load-fallback branch from 1a7b409 to 25012a7 Compare May 12, 2026 18:37
poyenc
poyenc approved these changes May 13, 2026
@jammm jammm enabled auto-merge (squash) May 13, 2026 08:25
@jammm
Copy link
Copy Markdown
Contributor Author

jammm commented May 13, 2026

@poyenc thanks! I enabled auto-merge but it's waiting on "Math CI Summary" which doesn't seem to have triggered yet for many hours.

@shumway shumway disabled auto-merge May 14, 2026 20:43
@poyenc
Copy link
Copy Markdown
Contributor

poyenc commented May 14, 2026

@jammm The Math CI failure on build #2 is unrelated to your changes — AITER's mha_bwd.cu is referencing fmha_bwd_launcher members (workspace_size, prepare_workspace) and initializer fields that were changed upstream. Build #3 is currently running; if it hits the same AITER stage failure, we can skip that test and re-trigger.

@jammm jammm changed the title [CK] Fix RDNA3/RDNA4 FMHA tile-load paths [CK] Fix RDNA3 FMHA tile-load paths May 15, 2026
@jammm jammm enabled auto-merge (squash) May 15, 2026 06:18
@jammm
Copy link
Copy Markdown
Contributor Author

jammm commented May 15, 2026

Math CI is still stuck, and AITER test was still failing. Can we skip those? @poyenc

@poyenc
Copy link
Copy Markdown
Contributor

poyenc commented May 15, 2026

@jammm I have already turn AITER tests off.. now it fails on the Build CK and run Tests on gfx942 stage. And the error log indicates that CI failed to run the cat command

cat: /sys/module/amdgpu/version: No such file or directory

@poyenc
Copy link
Copy Markdown
Contributor

poyenc commented May 18, 2026

The 24 run_sink_init_tests failures on gfx1201 (Build #6) are addressed by #7530.

Root cause: When -init_sink=1 -mask=1 is passed, traits.has_sink doesn't check init_sink_value, so the *_nsink (no-sink) kernel is dispatched. The reference expects sink-initialized output (zeros for masked positions), but the GPU produces standard attention output — resulting in 100% wrong values across all head dims (d=64/128/256), precisions (fp16/bf16), modes (batch/group), and layouts (bshd/bhsd).

Fix in #7530:

  1. Includes init_sink_value != 0 in the has_sink trait check so the sink-enabled kernel is dispatched correctly.
  2. Gates run_sink_init_tests behind an opt-in -g flag in smoke_test_fwd.sh, since sink=true kernels are excluded from CI builds by the *_nsink* CMake filter.

@jammm
Copy link
Copy Markdown
Contributor Author

jammm commented May 18, 2026

@poyenc thanks for the heads up! Given those failing tests are fixed in #7530. Can we skip the check and merge this PR?

@jammm jammm disabled auto-merge May 18, 2026 10:13
@illsilin illsilin merged commit 2b73c00 into develop May 19, 2026
31 checks passed
@illsilin illsilin deleted the users/jam/gfx12-buffer-load-fallback branch May 19, 2026 13:41
assistant-librarian Bot pushed a commit to ROCm/composable_kernel that referenced this pull request May 19, 2026
[CK] Fix RDNA3 FMHA tile-load paths

## Summary

Fix CK tile FMHA paths needed for RDNA3/RDNA4 targets.

## Details

This PR addresses RDNA-specific issues hit while enabling xFormers CK
FMHA on gfx11/gfx12:

- On RDNA3, update FMHA P tile handling so the layout consumed by the
second GEMM matches the WMMA path.

## Testing

Validated downstream with xFormers CK/FMHA on gfx1201/gfx1151.

```text
pytest --import-mode=importlib -q \
  tests/test_mem_eff_attention.py::test_forward \
  tests/test_mem_eff_attention.py::test_backward \
  tests/test_mem_eff_attention.py::test_dropout_ck

3844 passed, 5244 skipped, 26 warnings
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants