[AMD ROCm] Update ROCm/CK backend to align with latest ComposableKernel API changes#2363
[AMD ROCm] Update ROCm/CK backend to align with latest ComposableKernel API changes#2363tridao merged 15 commits intoDao-AILab:mainfrom
Conversation
|
Hi @rocking5566, nice to see we converged on a very similar direction here. When I saw the recent aiter-related merges, I decided to base this on the CK copy already pulled in through aiter, rather than keeping two overlapping CK copies in the project. The main motivation on my side was to remove that duplication and rely on the CK path that is already exercised as part of the aiter project. In any case, I understand the preference for the split you described. I’m happy to split my PR (#2350) accordingly when the time comes, if that helps with review and integration. |
Thanks @eliasmagn! That makes total sense — consolidating on the Once this PR gets merged, you should be able to rebase #2350 on top of it, which should significantly reduce the diff and let your PR focus on the @tridao would you mind taking a look and merging this one when you get a chance? It should also unblock #2350 for a cleaner integration. Thanks! |
…el API changes (Dao-AILab#2363) * update ck * update ck * before gpt-oss sink * gpt-oss sink * Add missing parameter * Fix typo * Update to ROCm/composable_kernel@b09112b * add -Wno-unknown-warning-option * Update to ROCm/rocm-libraries#4368 (ROCm/rocm-libraries@17f7dfc) * Update to ROCm/rocm-libraries@a358a21 --------- Co-authored-by: Ding, Yi <yi.ding@amd.com> Co-authored-by: Yi DING <andy-ding@outlook.com>
Summary
Update the AMD ROCm ComposableKernel (CK) backend to be compatible with the latest CK FMHA API changes, including new fields for MX (microscaling) FP8 support, attention sink, and improved backward pass
dq_accummemory layout.Changes
CK backend API alignment (
csrc/flash_attn_ck/)fmha_fwd_args,fmha_fwd_splitkv_traitsandfmha_bwd_traitsto align with upstream CKdq_accumtensor layoutinmha_bwd.cppandmha_varlen_bwd.cppto aligns with upstream CKnsplitsfromfmha_bwd_launcherinstead of hardcoding split count logic on the host sideComposableKernel submodule update
csrc/composable_kernelfrom13f6d635to574c1c12Build system (
setup.py)-Wno-unknown-warning-optionand-fbracket-depth=1024compiler flags for ROCm CK buildsTesting
pytest tests/test_flash_attn_ck.pyon MI300 and MI350