[AMD ROCm] Update CK and add RDNA 3/4 support #2400
Conversation
[CK_TILE] Update CK and add RDNA build support
---
kk just let me know the order we should merge. Does this PR go first?
---
This is great! I thought we couldn’t use FlashAttention's CK backend on RDNA/Windows, but this PR enables it. I’m currently building FA from source (without setting …). Here’s what I ran if anyone wants to try:

```powershell
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention/
cmd /c '"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat" >nul 2>&1 && set' | ForEach-Object { if ($_ -match '^([^=]+)=(.*)$') { [System.Environment]::SetEnvironmentVariable($matches[1], $matches[2], 'Process') } }
.\venv\Scripts\Activate.ps1
$ROCM_ROOT = (rocm-sdk path --root).Trim()
$ROCM_BIN = (rocm-sdk path --bin).Trim()
$env:ROCM_HOME = $ROCM_ROOT
$env:PATH = "$ROCM_ROOT\lib\llvm\bin;$ROCM_BIN;$env:PATH"
$env:CC = "clang-cl"
$env:CXX = "clang-cl"
$env:DISTUTILS_USE_SDK = "1"
$env:MAX_JOBS = "8"
pip install --no-build-isolation -v .
```
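The `cmd /c … && set` one-liner above runs `vcvars64.bat` and imports the resulting MSVC environment into the current PowerShell session. For anyone scripting the build in Python instead, here is a rough equivalent sketch; the `vcvars64.bat` path is the one from the snippet above, and the helper names are hypothetical:

```python
# Sketch: import the MSVC build environment into the current Python process,
# mirroring the PowerShell one-liner above. Paths/names are assumptions.
import os
import subprocess

VCVARS = r"C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Auxiliary\Build\vcvars64.bat"

def parse_env_dump(dump: str) -> dict:
    """Parse `set`-style KEY=VALUE lines into a dict, skipping malformed lines."""
    env = {}
    for line in dump.splitlines():
        key, sep, value = line.partition("=")
        if sep and key:
            env[key] = value
    return env

def import_msvc_env() -> None:
    """Run vcvars64.bat, dump the environment with `set`, and merge it in."""
    out = subprocess.check_output(
        f'"{VCVARS}" >nul 2>&1 && set', shell=True, text=True
    )
    os.environ.update(parse_env_dump(out))
```

After calling `import_msvc_env()`, child processes (e.g. `pip install`) inherit the compiler environment, which is what `DISTUTILS_USE_SDK=1` relies on.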
**FlashAttention 2 Composable Kernel backend on Windows**

With the help of @astrelsky, after some debugging for Windows support, we successfully built and ran FlashAttention 2 on Windows using the CK backend.

**Verification**

```powershell
python -c "from flash_attn import flash_attn_func; print('OK:', flash_attn_func); import flash_attn_2_cuda; print(flash_attn_2_cuda); print(dir(flash_attn_2_cuda))"
```

```
OK: <function flash_attn_func at 0x00000237FD015760>
<module 'flash_attn_2_cuda' from 'C:\\venv\\Lib\\site-packages\\flash_attn_2_cuda.cp312-win_amd64.pyd'>
['__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', 'bwd', 'fwd', 'fwd_kvcache', 'varlen_bwd', 'varlen_fwd']
```

**Benchmark**

(Using this script; it includes SDPA Flash via AOTriton and SageAttention 1.0.6.)

```
FA2 Triton backend: [aiter] Windows: CK and HIP ops are not available. Triton ops only.
```
```
FA2 CK backend: C:\venv\Lib\site-packages\flash_attn_2_cuda.cp312-win_amd64.pyd

SeqLen | Method       | Latency   | TFLOPS | vs SDPA
------------------------------------------------------------------------------
 4096  | PyTorch SDPA |   5.09 ms |  27.02 | 1.00x (ref)
       | SageAttn V1  |   8.51 ms |  16.16 | 0.60x
       | FA2 Triton   |   4.18 ms |  32.87 | 1.22x
       | FA2 CK       |   2.62 ms |  52.36 | 1.94x
------------------------------------------------------------------------------
 8192  | PyTorch SDPA |  18.89 ms |  29.10 | 1.00x (ref)
       | SageAttn V1  |  35.01 ms |  15.70 | 0.54x
       | FA2 Triton   |  15.79 ms |  34.82 | 1.20x
       | FA2 CK       |  14.94 ms |  36.79 | 1.26x
------------------------------------------------------------------------------
 16384 | PyTorch SDPA |  74.44 ms |  29.54 | 1.00x (ref)
       | SageAttn V1  | 136.18 ms |  16.15 | 0.55x
       | FA2 Triton   |  61.57 ms |  35.72 | 1.21x
       | FA2 CK       |  42.80 ms |  51.37 | 1.74x
------------------------------------------------------------------------------
 32768 | PyTorch SDPA | 299.98 ms |  29.32 | 1.00x (ref)
       | SageAttn V1  | 548.98 ms |  16.02 | 0.55x
       | FA2 Triton   | 246.00 ms |  35.76 | 1.22x
       | FA2 CK       | 169.07 ms |  52.02 | 1.77x
------------------------------------------------------------------------------
```

**Verification**

I also verified CK Flash by running:

```
python -m pytest tests/test_flash_attn_ck.py::test_flash_attn_output -v --tb=short
```

All tests completed successfully.

**Environment**

OS: Windows 11

Current issue preventing CK Flash from building on Windows
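For readers sanity-checking the table: the TFLOPS column is consistent with the usual forward-attention FLOP count, 4·B·H·S²·D, divided by latency. The batch/heads/head-dim defaults below (1, 16, 128) are assumptions that happen to reproduce the table's numbers; the benchmark script's actual configuration isn't shown in this thread.

```python
def attn_tflops(seqlen, latency_ms, batch=1, nheads=16, headdim=128):
    """Forward attention TFLOPS from latency.

    FLOPs = 4 * B * H * S^2 * D: two matmuls (Q@K^T and P@V),
    each 2 * S^2 * D multiply-adds per head.
    batch/nheads/headdim defaults are assumptions, not the script's
    confirmed config.
    """
    flops = 4 * batch * nheads * seqlen ** 2 * headdim
    return flops / (latency_ms * 1e-3) / 1e12
```

For example, `attn_tflops(8192, 18.89)` lands at roughly the 29.10 TFLOPS the SDPA row reports for SeqLen 8192.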
---
I'm running your script on Linux with a gfx1100 card and my numbers are very different. I do have this merged in with the triton backend: ROCm/aiter#2483. Maybe that's why.
@0xDELUXA @jnolck

Current code: …

But CK expects: …

After fixing that argument order bug, I now get above 70 TFLOPS for those cases on gfx1100.
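A cheap way to catch wiring bugs like this argument-order one is to compare the kernel's output against a plain reference attention. A NumPy sketch (non-causal, no dropout; shapes follow flash-attn's `(batch, seqlen, nheads, headdim)` convention — this is illustrative, not the CK code from the PR):

```python
import numpy as np

def attention_reference(q, k, v):
    """Reference scaled dot-product attention, non-causal, no dropout.

    q, k, v: (batch, seqlen, nheads, headdim), the layout flash_attn_func
    expects. Returns an array of the same shape as q.
    """
    b, s, h, d = q.shape
    # Move heads next to batch: (b, h, s, d)
    qt = q.transpose(0, 2, 1, 3)
    kt = k.transpose(0, 2, 1, 3)
    vt = v.transpose(0, 2, 1, 3)
    scores = qt @ kt.transpose(0, 1, 3, 2) / np.sqrt(d)  # (b, h, s, s)
    scores -= scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    p = np.exp(scores)
    p /= p.sum(axis=-1, keepdims=True)
    out = p @ vt  # (b, h, s, d)
    return out.transpose(0, 2, 1, 3)  # back to (b, s, h, d)
```

Comparing a kernel's output to this with `np.allclose` (at fp16-appropriate tolerances) flags swapped arguments immediately, since attention is not symmetric in q/k/v.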
The Triton FA is without heads grouping (#2217) or ROCm/aiter#2483, right?
and both are flying!
---
Got it, thanks for the explanation! I didn’t even know what CK is since I’ve never used Linux. I fixed the argument order, but I still can't get anywhere near …
Yes. |
What about this? |
Oh yeah, that would explain why they failed. I cancelled it because it was taking forever so I never saw the failure reasons. |
I've fixed the script, so I assume @msembinelli used this updated version. |
---
@micmelesse This PR enables us to build Flash Attention's CK backend on Windows (see #2400 (comment)). What do you think - will we be able to maintain Windows support after ROCm/aiter#2264 and #2364 are eventually merged? CK is already supported in MIOpen on Windows, so I think FA could support it as well.
Summary
- Update CK to `859acb5` and align FMHA argument wiring (`setup.py`, `mha_fwd_head_grouping_utils.hpp`)

Co-developed with @hyoon1 (author of #2054). This is a more complete version that supersedes #2054. Would like to merge before #2350 (aiter migration) to avoid conflicts.
Test Plan
`pytest tests/test_flash_attn_ck.py` on gfx11 and gfx12, no failures