
[AMD ROCm] Update CK and add RDNA 3/4 support#2400

Merged
tridao merged 11 commits into Dao-AILab:main from ROCm:ck_improve_v0.1.10
Mar 26, 2026

Conversation

@rocking5566 (Contributor)

Summary

  • Bump CK submodule to 859acb5 and align FMHA argument wiring
  • Add gfx11 (RDNA 3) and gfx12 (RDNA 4) build support in setup.py
  • Guard backward: disabled on gfx11, deterministic-only disabled on gfx12
  • Add shared LLC head-grouping forward helper (mha_fwd_head_grouping_utils.hpp)
  • Update tests with arch-aware skip logic and dynamic skip messages

Co-developed with @hyoon1 (author of #2054). This is a more complete version that supersedes #2054. We'd like to merge it before #2350 (the aiter migration) to avoid conflicts.

Test Plan

  • pytest tests/test_flash_attn_ck.py on gfx11 and gfx12, no failures
  • Existing CDNA (MI300/MI350) paths unaffected
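The "arch-aware skip logic and dynamic skip messages" mentioned in the summary can be sketched roughly as below. The helper name and message strings are illustrative, not the PR's actual code; on ROCm builds of PyTorch the arch string typically comes from `torch.cuda.get_device_properties(0).gcnArchName`.

```python
def skip_reason(arch: str, deterministic: bool = False):
    """Return a dynamic skip message for backward-pass tests on
    unsupported archs, or None if the test should run.
    Mirrors the PR's guards: backward disabled on gfx11 (RDNA 3),
    deterministic-only backward disabled on gfx12 (RDNA 4)."""
    base = arch.split(":")[0]          # "gfx1100:xnack-" -> "gfx1100"
    if base.startswith("gfx11"):
        return f"backward is disabled on {base} (RDNA 3)"
    if deterministic and base.startswith("gfx12"):
        return f"deterministic backward is disabled on {base} (RDNA 4)"
    return None
```

A pytest marker can then wrap this, e.g. `pytest.mark.skipif(skip_reason(arch) is not None, reason=skip_reason(arch) or "")`, which yields per-arch skip messages in the test report.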

tridao (Member) commented Mar 26, 2026

kk just let me know the order we should merge. Does this PR go first?

rocking5566 (Contributor, Author) commented Mar 26, 2026

Yes, this PR should go first. #2350 (aiter migration) still uses the older CK version and hasn't been updated yet. Merging this first avoids conflicts since #2350 will need to rebase on top of the updated CK submodule.

tridao merged commit 5301a35 into Dao-AILab:main on Mar 26, 2026
0xDELUXA (Contributor) commented Mar 28, 2026

This is great!

I thought we couldn’t use FlashAttention's CK backend on RDNA/Windows, but this PR enabled it.

I’m currently building FA from source without setting $env:FLASH_ATTENTION_TRITON_AMD_ENABLE = "TRUE", and it’s actually compiling CK for gfx1200. It’ll take a while and might fail at any point - I’ll report back later.

Here’s what I ran if anyone wants to try:

git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/

cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } }

.\venv\Scripts\Activate.ps1

$ROCM_ROOT = (rocm-sdk path --root).Trim()
$ROCM_BIN = (rocm-sdk path --bin).Trim()
$env:ROCM_HOME = $ROCM_ROOT
$env:PATH = "$ROCM_ROOT\lib\llvm\bin;$ROCM_BIN;$env:PATH"

$env:CC = "clang-cl"
$env:CXX = "clang-cl"
$env:DISTUTILS_USE_SDK = "1"

$env:MAX_JOBS = "8"
pip install --no-build-isolation -v .

0xDELUXA (Contributor) commented Mar 29, 2026

FlashAttention 2 Composable Kernel backend on Windows

With the help of @astrelsky, after some debugging for Windows support, we successfully built and ran FlashAttention 2 on Windows using the CK backend.

Verification

python -c "from flash_attn import flash_attn_func; print('OK:', flash_attn_func); import flash_attn_2_cuda; print(flash_attn_2_cuda); print(dir(flash_attn_2_cuda))"

OK: <function flash_attn_func at 0x00000237FD015760>
<module 'flash_attn_2_cuda' from 'C:\\venv\\Lib\\site-packages\\flash_attn_2_cuda.cp312-win_amd64.pyd'>
['__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', 'bwd', 'fwd', 'fwd_kvcache', 'varlen_bwd', 'varlen_fwd']

Benchmark

(Using this script, which includes SDPA Flash via AOTriton and SageAttention 1.0.6.)

FA2 Triton backend: [aiter] Windows: CK and HIP ops are not available. Triton ops only.
FA2 CK backend: C:\venv\Lib\site-packages\flash_attn_2_cuda.cp312-win_amd64.pyd

SeqLen     | Method                   | Latency      | TFLOPS     | vs SDPA
------------------------------------------------------------------------------
4096       | PyTorch SDPA             |     5.09 ms |    27.02 | 1.00x (ref)
           | SageAttn V1              |     8.51 ms |    16.16 |  0.60x
           | FA2 Triton               |     4.18 ms |    32.87 |  1.22x
           | FA2 CK                   |     2.62 ms |    52.36 |  1.94x
------------------------------------------------------------------------------
8192       | PyTorch SDPA             |    18.89 ms |    29.10 | 1.00x (ref)
           | SageAttn V1              |    35.01 ms |    15.70 |  0.54x
           | FA2 Triton               |    15.79 ms |    34.82 |  1.20x
           | FA2 CK                   |    14.94 ms |    36.79 |  1.26x
------------------------------------------------------------------------------
16384      | PyTorch SDPA             |    74.44 ms |    29.54 | 1.00x (ref)
           | SageAttn V1              |   136.18 ms |    16.15 |  0.55x
           | FA2 Triton               |    61.57 ms |    35.72 |  1.21x
           | FA2 CK                   |    42.80 ms |    51.37 |  1.74x
------------------------------------------------------------------------------
32768      | PyTorch SDPA             |   299.98 ms |    29.32 | 1.00x (ref)
           | SageAttn V1              |   548.98 ms |    16.02 |  0.55x
           | FA2 Triton               |   246.00 ms |    35.76 |  1.22x
           | FA2 CK                   |   169.07 ms |    52.02 |  1.77x
------------------------------------------------------------------------------
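For context on how TFLOPS columns like the one above are usually derived: non-causal attention is conventionally counted as 4·B·H·S²·D FLOPs (the QKᵀ and P·V matmuls, forward only). A sketch, assuming the batch=1, heads=16, headdim=128 shapes that appear in a fwd log later in this thread:

```python
def attn_tflops(latency_ms: float, batch: int, heads: int,
                seqlen: int, headdim: int, causal: bool = False) -> float:
    """Forward-only attention throughput from the usual 4*B*H*S^2*D FLOP count."""
    flops = 4 * batch * heads * seqlen ** 2 * headdim   # QK^T and P@V matmuls
    if causal:
        flops /= 2                                      # only half the score matrix
    return flops / (latency_ms * 1e-3) / 1e12

# The 4096-row FA2 CK entry above: 2.62 ms -> ~52.5 TFLOPS,
# consistent with the table's 52.36 (the script averages more samples)
print(round(attn_tflops(2.62, 1, 16, 4096, 128), 2))    # prints 52.46
```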

Verification

I also verified CK Flash by running:

python -m pytest tests/test_flash_attn_ck.py::test_flash_attn_output -v --tb=short

All tests completed successfully:

===================================== 42240 passed in 1627.55s (0:27:07) =====================================

Environment

OS: Windows 11
Python: 3.12.10
ROCm: 7.13.0a20260328 (TheRock)
PyTorch: 2.10.0+rocm7.13.0a20260328
GPU: AMD Radeon RX 9060 XT (gfx1200)
Triton: 3.6.0+gitae9d5a54.post27 (triton-windows)

Current issue preventing CK Flash from building on Windows

jnolck commented Mar 31, 2026

I'm running your script on Linux with a gfx1100 card and my numbers are very different. I do have this merged in with the Triton backend: ROCm/aiter#2483. Maybe that's why.
[Screenshot From 2026-03-31 12-27-26]

hyoon1 (Contributor) commented Mar 31, 2026

> I'm running your script on Linux with a gfx1100 card and my numbers are very different. I do have this merged in with the triton backend ROCm/aiter#2483. Maybe that's why. [Screenshot From 2026-03-31 12-27-26]

@0xDELUXA @jnolck
The script above passes the CK fwd(...) arguments in the wrong order, which produces buggy results.

Current code:
flash_attn_2_cuda.fwd(..., SOFTMAX_SCALE, 0.0, ...)

But CK expects:
flash_attn_2_cuda.fwd(..., 0.0, SOFTMAX_SCALE, ...)

After fixing that argument-order bug, I now get above 70 TFLOPS for those cases on gfx1100.
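For anyone patching a benchmark script by hand, the failure mode is easy to reproduce with a stub. `ck_fwd_stub` below is illustrative only, not the real `flash_attn_2_cuda.fwd` signature (the real call takes many more positional parameters, elided here); it just demonstrates how swapping two adjacent positional floats silently zeroes the softmax scale:

```python
def ck_fwd_stub(p_dropout: float, softmax_scale: float) -> dict:
    """Stand-in for the CK binding's (dropout, scale) positional order."""
    return {"p_dropout": p_dropout, "softmax_scale": softmax_scale}

scale = 0.08838834764831845          # 1/sqrt(128), matching the fwd log in this thread

buggy = ck_fwd_stub(scale, 0.0)      # wrong order: scale lands in the dropout slot
fixed = ck_fwd_stub(0.0, scale)      # the order CK expects

assert buggy["softmax_scale"] == 0.0     # kernel silently runs with scale 0
assert fixed["softmax_scale"] == scale
```

Because both parameters are plain floats, no type error fires; the kernel just runs with a wrong scale, which is why the only symptom is degraded throughput and wrong outputs.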

tianwyan commented Apr 1, 2026

> FlashAttention 2 Composable Kernel backend on Windows
>
> With the help of @astrelsky, after some debugging for Windows support, we successfully built and ran FlashAttention 2 on Windows using the CK backend.
>
> Verification
>
> python -c "from flash_attn import flash_attn_func; print('OK:', flash_attn_func); import flash_attn_2_cuda; print(flash_attn_2_cuda); print(dir(flash_attn_2_cuda))"
>
> OK: <function flash_attn_func at 0x00000237FD015760>
> <module 'flash_attn_2_cuda' from 'C:\\venv\\Lib\\site-packages\\flash_attn_2_cuda.cp312-win_amd64.pyd'>
> ['__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', 'bwd', 'fwd', 'fwd_kvcache', 'varlen_bwd', 'varlen_fwd']
>
> Benchmark
>
> (Using this script, which includes SDPA Flash via AOTriton and SageAttention 1.0.6.)
>
> FA2 Triton backend: [aiter] Windows: CK and HIP ops are not available. Triton ops only.
> FA2 CK backend: C:\venv\Lib\site-packages\flash_attn_2_cuda.cp312-win_amd64.pyd
>
> SeqLen     | Method                   | Latency      | TFLOPS     | vs SDPA
> ------------------------------------------------------------------------------
> 4096       | PyTorch SDPA             |     5.29 ms |    25.98 | 1.00x (ref)
>            | SageAttn V1              |     8.62 ms |    15.95 |  0.61x
>            | FA2 Triton               |     4.33 ms |    31.71 |  1.22x
>            | FA2 CK                   |     3.69 ms |    37.20 |  1.43x
> ------------------------------------------------------------------------------
> 8192       | PyTorch SDPA             |    18.96 ms |    29.00 | 1.00x (ref)
>            | SageAttn V1              |    34.66 ms |    15.86 |  0.55x
>            | FA2 Triton               |    15.97 ms |    34.42 |  1.19x
>            | FA2 CK                   |    15.38 ms |    35.75 |  1.23x
> ------------------------------------------------------------------------------
> 16384      | PyTorch SDPA             |    76.44 ms |    28.77 | 1.00x (ref)
>            | SageAttn V1              |   139.62 ms |    15.75 |  0.55x
>            | FA2 Triton               |    62.69 ms |    35.08 |  1.22x
>            | FA2 CK                   |    56.46 ms |    38.95 |  1.35x
> ------------------------------------------------------------------------------
> 32768      | PyTorch SDPA             |   308.65 ms |    28.50 | 1.00x (ref)
>            | SageAttn V1              |   563.65 ms |    15.61 |  0.55x
>            | FA2 Triton               |   251.44 ms |    34.98 |  1.23x
>            | FA2 CK                   |   224.83 ms |    39.12 |  1.37x
> ------------------------------------------------------------------------------
>
> Results Overview
>
> [FA2_results screenshot]
>
> Verification
>
> I also verified CK Flash by running:
>
> python -m pytest tests/test_flash_attn_ck.py::test_flash_attn_output -v --tb=short
>
> All tests completed successfully:
>
> ===================================== 42240 passed in 1627.55s (0:27:07) =====================================
>
> Environment
>
> OS: Windows 11
> Python: 3.12.10
> ROCm: 7.13.0a20260328 (TheRock)
> PyTorch: 2.10.0+rocm7.13.0a20260328
> GPU: AMD Radeon RX 9060 XT (gfx1200)
> Triton: 3.6.0+gitae9d5a54.post27 (triton-windows)

Current issues

The Triton FA is without head grouping (#2217) or ROCm/aiter#2483, right?

jnolck commented Apr 1, 2026

> > I'm running your script on Linux with a gfx1100 card and my numbers are very different. I do have this merged in with the triton backend ROCm/aiter#2483. Maybe that's why. [Screenshot From 2026-03-31 12-27-26]
>
> @0xDELUXA @jnolck Above script passes the CK fwd(...) arguments in the wrong order and makes buggy results
>
> Current code: flash_attn_2_cuda.fwd(..., SOFTMAX_SCALE, 0.0, ...)
>
> But CK expects: flash_attn_2_cuda.fwd(..., 0.0, SOFTMAX_SCALE, ...)
>
> After fixing that argument order bug, now I get above 70 TFLOPs for those cases on gfx1100.

Yup! Fixed it and ran it again. tianwyan, I double-checked to make sure your PR was baked in.
flash_attn_triton_amd.py::fwd inputs
q: torch.Size([1, 32768, 16, 128])
k: torch.Size([1, 32768, 16, 128])
v: torch.Size([1, 32768, 16, 128])
out: None
alibi_slopes: None
dropout_p: 0.0
softmax_scale: 0.08838834764831845
causal: False
window_size_left: -1
window_size_right: -1
softcap: 0.0
return_softmax: False
Using Triton implementation
[LLC Head Grouping fwd_prefill] Processing 16 heads in groups of 5
flash_attn_triton_amd.py::fwd outputs
out: torch.Size([1, 32768, 16, 128])
softmax_lse: torch.Size([1, 16, 32768])
sd_mask: None

Reran: [Screenshot From 2026-03-31 21-25-12]. Both are flying!

0xDELUXA (Contributor) commented Apr 1, 2026

> @0xDELUXA @jnolck
> Above script passes the CK fwd(...) arguments in the wrong order and makes buggy results
>
> Current code:
> flash_attn_2_cuda.fwd(..., SOFTMAX_SCALE, 0.0, ...)
>
> But CK expects:
> flash_attn_2_cuda.fwd(..., 0.0, SOFTMAX_SCALE, ...)
>
> After fixing that argument order bug, now I get above 70 TFLOPs for those cases on gfx1100.

Got it, thanks for the explanation! I didn't even know what CK was, since I've never used Linux.

I fixed the argument order, but I still can't get anywhere near 70 TFLOPS on gfx1200 (Windows); the highest I see is 52.36.

0xDELUXA (Contributor) commented Apr 1, 2026

> The Triton FA is without heads grouping (#2217) or ROCm/aiter#2483 right?

Yes.

0xDELUXA (Contributor) commented Apr 1, 2026

> [Screenshot From 2026-03-31 21-25-12]

Looks like the argument order makes a much bigger difference on Linux than on Windows.

@astrelsky

> [Screenshot From 2026-03-31 21-25-12]
>
> Looks like the argument order makes a much bigger difference on Linux than on Windows.

Yea, I went from at best 1.74 TFLOP/s (0.17x SDPA) to 3.47 TFLOP/s (0.34x SDPA). I'm pretty sure there's just something wrong on gfx1150, since all the CK tests start failing at the 50% mark (when backward starts, I think).

0xDELUXA (Contributor) commented Apr 1, 2026

> Yea, I went from at best 1.74TFLOP/s (0.17x sdpa) to 3.47TFLOP/s (0.34x sdpa). I'm pretty sure there's just something wrong for gfx1150 since all the ck tests start failing at 50% (when backward starts I think).

What about this?

astrelsky commented Apr 1, 2026

> Yea, I went from at best 1.74TFLOP/s (0.17x sdpa) to 3.47TFLOP/s (0.34x sdpa). I'm pretty sure there's just something wrong for gfx1150 since all the ck tests start failing at 50% (when backward starts I think).
>
> What about this?

Oh yeah, that would explain why they failed. I cancelled it because it was taking forever, so I never saw the failure reasons.

@msembinelli

Adding my results on gfx1100 (7900 XTX): ~70-80 TFLOPS under FA2 CK.

BASE_IMAGE=rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1

[Screenshot from 2026-04-02 00-26-50]

@astrelsky

> Adding my results, gfx1100 (7900XTX). ~70-80 TFLOPS under FA2 CK
>
> BASE_IMAGE=rocm/pytorch:rocm7.2_ubuntu22.04_py3.10_pytorch_release_2.9.1
>
> [Screenshot from 2026-04-02 00-26-50]

gfx1150, after fixing the CK arguments. I had to do auto-tuning for SageAttention or it would have been just as bad as CK.

SeqLen     | Method                   | Latency      | TFLOPS     | vs SDPA
------------------------------------------------------------------------------
4096       | PyTorch SDPA             |    14.71 ms |     9.34 | 1.00x (ref)
           | SageAttn V1              |    15.35 ms |     8.95 |  0.96x
           | FA2 Triton               |    18.75 ms |     7.33 |  0.78x
           | FA2 CK                   |    40.88 ms |     3.36 |  0.36x
------------------------------------------------------------------------------
8192       | PyTorch SDPA             |    59.21 ms |     9.29 | 1.00x (ref)
           | SageAttn V1              |    47.96 ms |    11.46 |  1.23x
           | FA2 Triton               |    68.31 ms |     8.05 |  0.87x
           | FA2 CK                   |   162.09 ms |     3.39 |  0.37x
------------------------------------------------------------------------------
16384      | PyTorch SDPA             |   214.76 ms |    10.24 | 1.00x (ref)
           | SageAttn V1              |   191.69 ms |    11.47 |  1.12x
           | FA2 Triton               |   276.85 ms |     7.94 |  0.78x
           | FA2 CK                   |   650.47 ms |     3.38 |  0.33x
------------------------------------------------------------------------------
32768      | PyTorch SDPA             |   863.99 ms |    10.18 | 1.00x (ref)
           | SageAttn V1              |  1261.91 ms |     6.97 |  0.68x
           | FA2 Triton               |  1131.45 ms |     7.77 |  0.76x
           | FA2 CK                   |  2533.70 ms |     3.47 |  0.34x
------------------------------------------------------------------------------

0xDELUXA (Contributor) commented Apr 2, 2026

> gfx1150, this is after fixing the ck arguments.

I've fixed the script, so I assume @msembinelli used this updated version.

0xDELUXA (Contributor) commented Apr 3, 2026

@micmelesse This PR enables us to build Flash Attention's CK backend on Windows (see #2400 (comment)). What do you think: will we be able to maintain Windows support after ROCm/aiter#2264 and #2364 are eventually merged? CK is already supported in MIOpen on Windows, so I think FA could support it as well.
Also, any thoughts on #2350?

Labels: none yet. Projects: none yet. 8 participants.