rocm/ck: switch FlashAttention to the CK vendored via aiter and fix current CK API integration #2350
|
Hi @eliasmagn, thanks for this comprehensive work! I noticed this PR covers both the CK API alignment and the migration to Splitting this into smaller PRs also helps with bisecting if regressions come up later — by landing the CK API update first and the Once #2363 lands, it should reduce the diff here significantly and let this PR focus on the |
|
@tridao, I need a clear direction here. This PR predates the later PR sequence, and the migration/correctness work that now overlaps with #2364 is already written and tested in this branch |
|
I'll defer to @rocking5566 on this since he's more familiar with the AMD part of the codebase. @rocking5566 what do you think? |
|
Hi @eliasmagn, thanks for raising this — let me clarify the situation. I think there may be some confusion between this PR (#2350) and #2364. They are doing different things:
As for #2363 (now merged) — that focused on aligning our CK wrapper to the latest CK FMHA API changes using the existing submodule. Your work here on switching the build to use aiter-vendored CK is a logical next step. One more note on the long-term picture: we do want to keep the direct CK API path maintained even after #2364 eventually lands. The aiter API dispatches to either the CK backend or the asm backend, and the asm backend can produce incorrect results in some cases. Having the direct CK codepath as a fallback/reference is important. I'm happy to review this PR. If you could rebase on top of the current main (since #2363 has been merged and may overlap with some of your CK API alignment changes), that would help move things forward. |
This PR updates FlashAttention’s ROCm CK backend to build against the Composable Kernel vendored via aiter instead of the old direct csrc/composable_kernel submodule.
It also updates FlashAttention’s CK wrapper layer to match the newer CK FMHA API used by the vendored CK. The primary validation target for this change is MI300X.
I realize this is a fairly large integration change. I tried to keep the scope limited to FlashAttention’s own CK integration layer and not modify aiter itself. If maintainers would prefer this to be split into smaller PRs, I’m very happy to do that.
Summary
third_party/aiter/3rdparty/composable_kernel
Why
When switching from the old in-tree CK submodule to the CK vendored via aiter, several previously working MI300X paths regressed, mainly:
- kvcache
- varlen_causal
The root cause was in FlashAttention’s own CK integration layer, not in aiter itself:
This PR fixes those issues in FlashAttention’s CK wrapper layer without modifying aiter.
Transitional dual CK API support
To reduce migration risk during the port, the FlashAttention CK wrapper layer currently keeps build-time support for both legacy and current CK FMHA APIs. This is intended as transitional migration support rather than a permanent design requirement.
This is not present in upstream today and is not the primary goal of the PR. It was added as a transition aid so the wrapper layer could be ported and validated more safely while moving from the old in-tree CK integration to the newer CK vendored via aiter.
In practice, this helped reduce migration risk by allowing the wrapper layer to continue compiling against older and newer CK API layouts during the port, and by making it easier to distinguish wrapper-layer migration issues from issues in the vendored integration path itself.
If maintainers prefer, this transitional compatibility layer can be removed or simplified in a follow-up once only the newer CK FMHA API path needs to be supported. I started the port to the new CK before aiter was included, so this feature may no longer be needed.
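As an illustration of what a build-time switch between the two API layouts could look like on the Python build side, here is a minimal sketch. It is not taken from this PR: the environment variable FLASH_ATTN_CK_API and both macro names are hypothetical, and the actual branch may select the API in a different way (for example by detecting headers).

```python
import os

# Hypothetical names: neither the environment variable nor the macros are
# taken from the PR; they only illustrate a build-time switch.
ENV_VAR = "FLASH_ATTN_CK_API"
MACROS = {
    "legacy": "FLASH_ATTN_CK_LEGACY_API",
    "current": "FLASH_ATTN_CK_CURRENT_API",
}


def ck_api_define_flags(environ=None):
    """Return the compiler define that selects one CK FMHA API layout."""
    environ = os.environ if environ is None else environ
    # Default to the current API; the legacy path is opt-in.
    choice = environ.get(ENV_VAR, "current").strip().lower()
    if choice not in MACROS:
        raise ValueError(f"{ENV_VAR} must be one of {sorted(MACROS)}, got {choice!r}")
    return [f"-D{MACROS[choice]}=1"]
```

The returned list could be appended to the extension's compile arguments, and the wrapper sources would then guard the two code paths with the matching macro. Defaulting to the current API keeps the legacy path easy to delete later.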
Validation
I did not run the full test_flash_attn_ck.py; I ran a subset of 46 tests.
Primary validation target: MI300X
Legend:
- PASS: validated and passing
- FAIL: test completed but failed
- CRASH: runtime failure / crash
- not tested: not part of this comparison set
- no support: baseline path not supported on that target
Representative MI300X validation commands and environment:
excerpt:
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_kvcache[1-128-64-False-False-None-0.0-False-True-True-False-False-False-mha-0-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_varlen_causal[256-128-217-False-64-False-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_varlen_causal[512-128-217-False-64-False-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_deterministic[128-217-False-64-False-False-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_varlen_deterministic[128-217-False-64-False-False-dtype1]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_race_condition[0.0-128-128-64-False-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_kvcache[1-128-256-False-False-256-1.0-True-True-True-True-True-True-mha-1-dtype0]' -x
pytest -q -s 'tests/test_flash_attn_ck.py::test_flash_attn_varlen_qkvpacked[0.17-1025-128-True-True-True-True-dtype1]' -x
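For anyone wanting to reproduce a subset like this, the commands above can be wrapped in a small script that runs each test ID in its own process and maps the result onto the PASS / FAIL / CRASH legend. This is a sketch, not the script used for this PR. It assumes pytest's documented exit codes (0 for all passed, 1 for test failures) and treats every other return code, including a process killed by a signal, as CRASH, which is a coarse simplification.

```python
import subprocess
import sys


def classify(returncode):
    """Map a pytest process return code onto the PASS / FAIL / CRASH legend."""
    if returncode == 0:
        return "PASS"
    if returncode == 1:
        return "FAIL"
    # Assumption: anything else (signal, internal or usage error) counts as CRASH.
    return "CRASH"


def run_case(node_id, runner=subprocess.run):
    """Run one test ID in a fresh process so a crash cannot take down the rest."""
    cmd = [sys.executable, "-m", "pytest", "-q", "-s", node_id, "-x"]
    return classify(runner(cmd).returncode)


def summarize(node_ids, runner=subprocess.run):
    """Return a mapping of test ID to legend status."""
    return {node_id: run_case(node_id, runner) for node_id in node_ids}
```

Running each test in a separate process matters here because a GPU-side fault typically aborts the whole interpreter, which would otherwise hide the status of the remaining tests.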
Validation environment (relevant parts):
System tools:
The validation environment on gfx1102 differed substantially from the MI300X setup: it used a theRock / ROCm 7.12 dev stack rather than the ROCm 6.4.1 environment used for MI300X validation.
Relevant gfx1102 environment details:
- torch 2.12.0a0+rocm7.12.0a20260304
- triton 3.6.0+gitadd9159a.rocm7.12.0a20260304
- flash_attn 2.8.4+aiter.ck.fashattn.586a96f.cutlass.71275920.aiter.428e8e761
Because this is a newer dev-stack environment with mixed local/dev/pip installations, gfx1102 results should be interpreted as bring-up / compatibility status rather than as a direct comparison to the MI300X validation environment.
Notes on gfx1102 / gfx11
I have a gfx1102 workstation with eight RX 7600 XT GPUs. The tests were run against theRock 7.12.
Current status on gfx1102:
- kvcache
- varlen_causal
- deterministic
- varlen_deterministic
Additional gfx1102 validation is still pending, and further work on gfx11-specific issues is ongoing.
For this PR, I expect the primary acceptance bar to be MI300X / ROCm CK correctness.
Scope
This PR intentionally does not modify aiter. It only updates FlashAttention’s own CK integration layer so that FlashAttention can correctly build and run against the CK vendored via aiter.
If this is too much to review in one PR, I’d be happy to split it into smaller pieces.