
rocm/ck: switch FlashAttention to the CK vendored via aiter and fix current CK API integration#2350

Open
eliasmagn wants to merge 4 commits into Dao-AILab:main from eliasmagn:aiter_ck_fashattn_mi300x_validated

Conversation


@eliasmagn eliasmagn commented Mar 14, 2026

This PR updates FlashAttention’s ROCm CK backend to build against the Composable Kernel vendored via aiter instead of the old direct csrc/composable_kernel submodule.

It also updates FlashAttention’s CK wrapper layer to match the newer CK FMHA API used by the vendored CK. The primary validation target for this change is MI300X.

I realize this is a fairly large integration change. I tried to keep the scope limited to FlashAttention’s own CK integration layer and not modify aiter itself. If maintainers would prefer this to be split into smaller PRs, I’m very happy to do that.

Summary

  • remove the direct FlashAttention CK submodule dependency
  • build the ROCm CK backend against third_party/aiter/3rdparty/composable_kernel
  • adapt FlashAttention’s CK wrapper layer to the newer CK FMHA API
  • keep temporary dual CK FMHA API support in the FlashAttention wrapper layer during the migration
  • detect the current CK FMHA API from the actual vendored CK headers
  • fix splitKV / appendKV argument initialization for the current CK API
  • keep the FlashAttention-side gfx11 build/codegen support carried in this branch

Why

When switching from the old in-tree CK submodule to the CK vendored via aiter, several previously working MI300X paths regressed, mainly:

  • kvcache
  • paged varlen_causal
  • deterministic / varlen deterministic paths

The root cause was in FlashAttention’s own CK integration layer, not in aiter itself:

  • incomplete adaptation to the current CK FMHA API
  • stale or partially initialized forward argument structs
  • incorrect current-vs-legacy CK API selection

This PR fixes those issues in FlashAttention’s CK wrapper layer without modifying aiter.

Transitional dual CK API support

To reduce migration risk during the port, the FlashAttention CK wrapper layer currently keeps build-time support for both legacy and current CK FMHA APIs. This is intended as transitional migration support rather than a permanent design requirement.

This is not present in upstream today and is not the primary goal of the PR. It was added as a transition aid so the wrapper layer could be ported and validated more safely while moving from the old in-tree CK integration to the newer CK vendored via aiter.

In practice, this helped reduce migration risk by allowing the wrapper layer to continue compiling against older and newer CK API layouts during the port, and by making it easier to distinguish wrapper-layer migration issues from issues in the vendored integration path itself.

If maintainers prefer, this transitional compatibility layer can be removed or simplified in a follow-up once only the newer CK FMHA API path needs to be supported. I started the port to the new CK API before the aiter vendoring landed, so this compatibility layer may no longer be needed.

Validation

I did not run the full test_flash_attn_ck.py suite; I ran a targeted set of 46 tests.

Primary validation target: MI300X

Legend:

  • PASS: validated and passing
  • FAIL: test completed but failed
  • CRASH: runtime failure / crash
  • not tested: not part of this comparison set
  • no support: baseline path not supported on that target
| case | MI300X main | MI300X aiter | gfx1102 main | gfx1102 aiter |
| --- | --- | --- | --- | --- |
| qkvpacked fp16 | PASS | PASS | no support | PASS |
| qkvpacked bf16 causal | PASS | PASS | no support | PASS |
| qkvpacked dropout | PASS | PASS | no support | PASS |
| qkvpacked local | PASS | PASS | no support | PASS |
| qkvpacked alibi | PASS | PASS | no support | PASS |
| varlen_qkvpacked fp16 | PASS | PASS | no support | PASS |
| varlen_qkvpacked bf16 causal | PASS | PASS | no support | PASS |
| varlen_qkvpacked dropout | PASS | PASS | no support | PASS |
| varlen_qkvpacked stress combo | not tested | PASS | no support | CRASH |
| output mha | PASS | PASS | no support | PASS |
| output mqa | PASS | PASS | no support | PASS |
| output gqa kvpacked | PASS | PASS | no support | PASS |
| output stress combo | not tested | PASS | no support | CRASH |
| varlen_output | PASS | PASS | no support | PASS |
| varlen_output stress combo | not tested | PASS | no support | CRASH |
| causal | PASS | PASS | no support | PASS |
| causal local bf16 | not tested | PASS | no support | PASS |
| varlen_causal None | PASS | PASS | no support | PASS |
| varlen_causal 256 | PASS | PASS | no support | PASS |
| varlen_causal 512 | PASS | PASS | no support | PASS |
| kvcache | PASS | PASS | no support | PASS |
| kvcache paged | not tested | PASS | no support | PASS |
| kvcache has_batch_idx | not tested | PASS | no support | PASS |
| kvcache rotary/new_kv | not tested | PASS | no support | PASS |
| kvcache mqa/gqa paged | not tested | PASS | no support | PASS |
| kvcache stress combo | not tested | PASS | no support | PASS |
| bwd_transpose | PASS | PASS | no support | PASS |
| bwd_varlen_overflow | not tested | PASS | no support | PASS |
| deterministic | PASS | PASS | no support | FAIL |
| varlen_deterministic | PASS | PASS | no support | FAIL |
| race_condition | not tested | PASS | no support | PASS |
| bwd_overflow | PASS | PASS | no support | PASS |

Representative MI300X validation commands and environment:

excerpt:

```shell
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_kvcache[1-128-64-False-False-None-0.0-False-True-True-False-False-False-mha-0-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_varlen_causal[256-128-217-False-64-False-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_varlen_causal[512-128-217-False-64-False-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_deterministic[128-217-False-64-False-False-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_varlen_deterministic[128-217-False-64-False-False-dtype1]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_race_condition[0.0-128-128-64-False-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_kvcache[1-128-256-False-False-256-1.0-True-True-True-True-True-True-mha-1-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_varlen_qkvpacked[0.17-1025-128-True-True-True-True-dtype1]' -x
```

Validation environment (relevant parts):

  • tested on an AMD droplet
  • Python 3.12
  • ROCm 6.4.1
  • torch 2.7.0+gitf717b2a
  • triton 3.2.0+gite5be006a
  • aiter 0.1.5.dev65+g4822e675
  • flash_attn 2.7.4.post1
  • tested on MI300X and gfx1102

System tools:

  • AMDSMI Tool: 25.4.2+aca1101
  • AMDSMI Library: 25.4.0
  • ROCM-SMI: 3.0.0+e68c0d1
  • ROCM-SMI-LIB: 7.5.0
  • amdgpu: 6.12.12

Validation environment on gfx1102 differed substantially from the MI300X setup and used a theRock / ROCm 7.12 dev stack rather than the ROCm 6.4.1 environment used for MI300X validation.

Relevant gfx1102 environment details:

  • Python 3.12
  • theRock / ROCm 7.12.0a20260304
  • torch 2.12.0a0+rocm7.12.0a20260304
  • triton 3.6.0+gitadd9159a.rocm7.12.0a20260304
  • flash_attn 2.8.4+aiter.ck.fashattn.586a96f.cutlass.71275920.aiter.428e8e761

Because this is a newer dev-stack environment with mixed local/dev/pip installations, gfx1102 results should be interpreted as bring-up / compatibility status rather than as a direct comparison to the MI300X validation environment.

Notes on gfx1102 / gfx11

I have a gfx1102 workstation with 8 RX 7600 XT GPUs. The tests were made against theRock 7.12.

Current status on gfx1102:

  • several core paths already build and run
  • some gfx11-specific failures are still open:
    • kvcache
    • paged varlen_causal
    • deterministic
    • varlen_deterministic

Additional gfx1102 validation is still pending, and further work on gfx11-specific issues is ongoing.

For this PR, I expect the primary acceptance bar to be MI300X / ROCm CK correctness.

Scope

This PR intentionally does not modify aiter. It only updates FlashAttention’s own CK integration layer so that FlashAttention can correctly build and run against the CK vendored via aiter.

If this is too much to review in one PR, I’d be happy to split it into smaller pieces.

@rocking5566
Contributor

Hi @eliasmagn, thanks for this comprehensive work!

I noticed this PR covers both the CK API alignment and the migration to aiter-vendored CK. To make things easier to review and merge incrementally, I've opened #2363 which focuses specifically on updating the CK submodule and aligning the FlashAttention CK wrapper layer to the latest CK FMHA API — without switching to aiter yet.

Splitting this into smaller PRs also helps with bisecting if regressions come up later — by landing the CK API update first and the aiter migration separately, it's much easier to pinpoint which change introduced an issue.

Once #2363 lands, it should reduce the diff here significantly and let this PR focus on the aiter migration and gfx11 support parts.

@eliasmagn
Author

@tridao, I need a clear direction here.

This PR predates the later PR sequence, and the migration/correctness work that now overlaps with #2364 is already written and tested in this branch.
I’ve already offered to split this further if that would make review easier, but I do not want to keep reshaping already validated work without knowing whether this PR is still intended to be reviewed as the vehicle for it.
If it is, I’m happy to continue. If not, please say so directly.

@tridao
Member

tridao commented Mar 22, 2026

I'll defer to @rocking5566 on this since he's more familiar with the AMD part of the codebase. @rocking5566 what do you think?

@rocking5566
Contributor

rocking5566 commented Mar 24, 2026

Hi @eliasmagn, thanks for raising this — let me clarify the situation.

I think there may be some confusion between this PR (#2350) and #2364. They are doing different things.

As for #2363 (now merged) — that focused on aligning our CK wrapper to the latest CK FMHA API changes using the existing submodule. Your work here on switching the build to use aiter-vendored CK is a logical next step.

One more note on the long-term picture: we do want to keep the direct CK API path maintained even after #2364 eventually lands. The aiter API dispatches to either the CK backend or the asm backend, and the asm backend can produce incorrect results in some cases. Having the direct CK codepath as a fallback/reference is important.

I'm happy to review this PR. If you could rebase on top of the current main (since #2363 has been merged and may overlap with some of your CK API alignment changes), that would help move things forward.

