[ROCM] Add support with Infinity Cache (LLC) awareness for performance improvement #2483
Conversation
🏷️ CI Guide: runs automatically on every PR; extended tests are opt-in via labels.
Migrating from Dao-AILab/flash-attention#2217.
Based on my local testing, LLC awareness actually benefits FA. CK appears to have it too (not relevant on Windows): ROCm/rocm-libraries#5018

Thanks @0xDELUXA ! We will always try our best to help!
Minor note: this implementation is broader than the description suggests. The only gfx1100-specific detail is the 96 MB fallback default used for RDNA architectures not listed in the cache table.
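As a rough sketch of the lookup-plus-fallback behavior described above (the table values below are the real Infinity Cache sizes for these Navi dies, but the function and constant names are illustrative, not the ones used in the PR):

```python
# Illustrative LLC lookup table; names are hypothetical, sizes are the
# documented Infinity Cache capacities of these RDNA3 parts.
LLC_SIZE_MB = {
    "gfx1100": 96,  # Navi 31 (e.g. RX 7900 XTX/XT)
    "gfx1101": 64,  # Navi 32
    "gfx1102": 32,  # Navi 33
}

# Fallback for RDNA targets missing from the table (the gfx1100 value).
DEFAULT_RDNA_LLC_MB = 96


def llc_bytes(gfx_arch: str) -> int:
    """Return the Infinity Cache size in bytes for a gfx target."""
    mb = LLC_SIZE_MB.get(gfx_arch, DEFAULT_RDNA_LLC_MB)
    return mb * 1024 * 1024
```

Under this scheme an unrecognized target such as a future gfx11xx part silently inherits the 96 MB default, which is the gfx1100-specific detail being pointed out.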
@tianwyan Can you run the formatter with black==26.1.0 and ruff==0.11.11? This will allow the full CI to run.

Formatted.
@micmelesse all checks have passed.
micmelesse left a comment:
I benchmarked this on gfx1100 against main across 162 configs (MHA, GQA, decode, prefill, causal True/False, SK up to 16384). I see gains at SQ=4096 (+5-9%) but significant regressions on decode (-55% to -79%) and shorter prefills (-6% to -31%). See #2542 for the bench setup. I will report the results from the CI node once it finishes. You can cherry-pick the commit to reproduce.
Details
| B | HQ | HK | SQ | SK | causal | Baseline (TFLOPS) | LLC (TFLOPS) | Delta |
|---|---|---|---|---|---|---|---|---|
| 1 | 16 | 16 | 1 | 163 | ✗ | 0.0287 | 0.029 | +1.0% |
| 1 | 16 | 16 | 1 | 163 | ✓ | 0.0355 | 0.0357 | +0.3% |
| 1 | 16 | 16 | 1 | 8192 | ✗ | 0.1081 | 0.1082 | +0.1% |
| 1 | 16 | 16 | 1 | 8192 | ✓ | 0.1035 | 0.1037 | +0.2% |
| 1 | 16 | 16 | 1 | 16384 | ✗ | 0.1137 | 0.1123 | -1.3% |
| 1 | 16 | 16 | 1 | 16384 | ✓ | 0.1086 | 0.1081 | -0.5% |
| 1 | 16 | 16 | 1024 | 163 | ✗ | 11.7 | 11.5 | -1.8% |
| 1 | 16 | 16 | 1024 | 163 | ✓ | 2.3 | 2.3 | -0.4% |
| 1 | 16 | 16 | 1024 | 8192 | ✗ | 59.6 | 57.3 | -3.9% |
| 1 | 16 | 16 | 1024 | 8192 | ✓ | 55.2 | 53.5 | -3.1% |
| 1 | 16 | 16 | 1024 | 16384 | ✗ | 59.4 | 56.8 | -4.4% |
| 1 | 16 | 16 | 1024 | 16384 | ✓ | 56.1 | 54.2 | -3.3% |
| 1 | 16 | 16 | 4096 | 163 | ✗ | 19.7 | 19.8 | +0.5% |
| 1 | 16 | 16 | 4096 | 163 | ✓ | 1.2 | 1.2 | +1.0% |
| 1 | 16 | 16 | 4096 | 8192 | ✗ | 64.5 | 61.7 | -4.3% |
| 1 | 16 | 16 | 4096 | 8192 | ✓ | 58.7 | 56.5 | -3.8% |
| 1 | 16 | 16 | 4096 | 16384 | ✗ | 65.1 | 62.3 | -4.4% |
| 1 | 16 | 16 | 4096 | 16384 | ✓ | 60.6 | 58.1 | -4.1% |
| 1 | 48 | 48 | 1 | 163 | ✗ | 0.069 | 0.0695 | +0.7% |
| 1 | 48 | 48 | 1 | 163 | ✓ | 0.0929 | 0.0923 | -0.6% |
| 1 | 48 | 48 | 1 | 8192 | ✗ | 0.291 | 0.0992 | -65.9% |
| 1 | 48 | 48 | 1 | 8192 | ✓ | 0.2708 | 0.094 | -65.3% |
| 1 | 48 | 48 | 1 | 16384 | ✗ | 0.3061 | 0.0645 | -78.9% |
| 1 | 48 | 48 | 1 | 16384 | ✓ | 0.2894 | 0.0619 | -78.6% |
| 1 | 48 | 48 | 1024 | 163 | ✗ | 17.6 | 17.5 | -0.5% |
| 1 | 48 | 48 | 1024 | 163 | ✓ | 3.7 | 3.6 | -1.9% |
| 1 | 48 | 48 | 1024 | 8192 | ✗ | 63.4 | 43.5 | -31.4% |
| 1 | 48 | 48 | 1024 | 8192 | ✓ | 58.9 | 42.7 | -27.6% |
| 1 | 48 | 48 | 1024 | 16384 | ✗ | 63.8 | 47.1 | -26.3% |
| 1 | 48 | 48 | 1024 | 16384 | ✓ | 59.9 | 43.9 | -26.7% |
| 1 | 48 | 48 | 4096 | 163 | ✗ | 23.2 | 23.1 | -0.8% |
| 1 | 48 | 48 | 4096 | 163 | ✓ | 1.6 | 1.6 | -0.1% |
| 1 | 48 | 48 | 4096 | 8192 | ✗ | 62.6 | 60 | -4.2% |
| 1 | 48 | 48 | 4096 | 8192 | ✓ | 57.3 | 54.5 | -4.8% |
| 1 | 48 | 48 | 4096 | 16384 | ✗ | 57.4 | 61.1 | +6.5% |
| 1 | 48 | 48 | 4096 | 16384 | ✓ | 54.7 | 56.2 | +2.6% |
| 1 | 32 | 8 | 1 | 163 | ✗ | 0.0556 | 0.0546 | -1.8% |
| 1 | 32 | 8 | 1 | 163 | ✓ | 0.0742 | 0.0737 | -0.6% |
| 1 | 32 | 8 | 1 | 8192 | ✗ | 0.2178 | 0.2164 | -0.6% |
| 1 | 32 | 8 | 1 | 8192 | ✓ | 0.208 | 0.208 | -0.0% |
| 1 | 32 | 8 | 1 | 16384 | ✗ | 0.2233 | 0.2155 | -3.5% |
| 1 | 32 | 8 | 1 | 16384 | ✓ | 0.215 | 0.2104 | -2.2% |
| 1 | 32 | 8 | 1024 | 163 | ✗ | 15.8 | 15.8 | +0.2% |
| 1 | 32 | 8 | 1024 | 163 | ✓ | 3 | 3 | -0.1% |
| 1 | 32 | 8 | 1024 | 8192 | ✗ | 59.2 | 57.3 | -3.2% |
| 1 | 32 | 8 | 1024 | 8192 | ✓ | 55.4 | 53.7 | -3.0% |
| 1 | 32 | 8 | 1024 | 16384 | ✗ | 59.4 | 57.4 | -3.3% |
| 1 | 32 | 8 | 1024 | 16384 | ✓ | 56 | 54.1 | -3.5% |
| 1 | 32 | 8 | 4096 | 163 | ✗ | 23.9 | 23.7 | -0.8% |
| 1 | 32 | 8 | 4096 | 163 | ✓ | 1.6 | 1.6 | -0.4% |
| 1 | 32 | 8 | 4096 | 8192 | ✗ | 64 | 61.6 | -3.7% |
| 1 | 32 | 8 | 4096 | 8192 | ✓ | 59.7 | 57 | -4.4% |
| 1 | 32 | 8 | 4096 | 16384 | ✗ | 65.3 | 62.7 | -4.0% |
| 1 | 32 | 8 | 4096 | 16384 | ✓ | 61.6 | 58.3 | -5.5% |
| 4 | 16 | 16 | 1 | 163 | ✗ | 0.0836 | 0.0835 | -0.2% |
| 4 | 16 | 16 | 1 | 163 | ✓ | 0.1196 | 0.1186 | -0.9% |
| 4 | 16 | 16 | 1 | 8192 | ✗ | 0.3045 | 0.3001 | -1.4% |
| 4 | 16 | 16 | 1 | 8192 | ✓ | 0.2806 | 0.2788 | -0.6% |
| 4 | 16 | 16 | 1 | 16384 | ✗ | 0.3154 | 0.309 | -2.0% |
| 4 | 16 | 16 | 1 | 16384 | ✓ | 0.2962 | 0.2927 | -1.2% |
| 4 | 16 | 16 | 1024 | 163 | ✗ | 18.8 | 18.7 | -0.8% |
| 4 | 16 | 16 | 1024 | 163 | ✓ | 4.2 | 4.1 | -1.5% |
| 4 | 16 | 16 | 1024 | 8192 | ✗ | 60.1 | 59.4 | -1.2% |
| 4 | 16 | 16 | 1024 | 8192 | ✓ | 55.2 | 54.6 | -1.1% |
| 4 | 16 | 16 | 1024 | 16384 | ✗ | 59.9 | 59.1 | -1.3% |
| 4 | 16 | 16 | 1024 | 16384 | ✓ | 56 | 55.5 | -0.9% |
| 4 | 16 | 16 | 4096 | 163 | ✗ | 23.4 | 23.2 | -1.0% |
| 4 | 16 | 16 | 4096 | 163 | ✓ | 1.7 | 1.7 | -0.5% |
| 4 | 16 | 16 | 4096 | 8192 | ✗ | 62.5 | 63.2 | +1.2% |
| 4 | 16 | 16 | 4096 | 8192 | ✓ | 56.7 | 56 | -1.3% |
| 4 | 16 | 16 | 4096 | 16384 | ✗ | 61.3 | 59.8 | -2.5% |
| 4 | 16 | 16 | 4096 | 16384 | ✓ | 56.8 | 56.1 | -1.2% |
| 4 | 48 | 48 | 1 | 163 | ✗ | 0.109 | 0.1083 | -0.7% |
| 4 | 48 | 48 | 1 | 163 | ✓ | 0.1867 | 0.1827 | -2.2% |
| 4 | 48 | 48 | 1 | 8192 | ✗ | 0.3883 | 0.1649 | -57.5% |
| 4 | 48 | 48 | 1 | 8192 | ✓ | 0.3593 | 0.1601 | -55.4% |
| 4 | 48 | 48 | 1 | 16384 | ✗ | 0.4013 | 0.1333 | -66.8% |
| 4 | 48 | 48 | 1 | 16384 | ✓ | 0.3759 | 0.1282 | -65.9% |
| 4 | 48 | 48 | 1024 | 163 | ✗ | 21.5 | 21.4 | -0.8% |
| 4 | 48 | 48 | 1024 | 163 | ✓ | 4.7 | 4.7 | +0.1% |
| 4 | 48 | 48 | 1024 | 8192 | ✗ | 54.2 | 50.1 | -7.5% |
| 4 | 48 | 48 | 1024 | 8192 | ✓ | 51.6 | 46.6 | -9.7% |
| 4 | 48 | 48 | 1024 | 16384 | ✗ | 54.5 | 49.4 | -9.4% |
| 4 | 48 | 48 | 1024 | 16384 | ✓ | 51.6 | 45.6 | -11.7% |
| 4 | 48 | 48 | 4096 | 163 | ✗ | 22.5 | 22.2 | -1.3% |
| 4 | 48 | 48 | 4096 | 163 | ✓ | 1.3 | 1.2 | -0.5% |
| 4 | 48 | 48 | 4096 | 8192 | ✗ | 54.1 | 58.2 | +7.6% |
| 4 | 48 | 48 | 4096 | 8192 | ✓ | 50.3 | 52.7 | +4.9% |
| 4 | 48 | 48 | 4096 | 16384 | ✗ | 54.2 | 58.8 | +8.5% |
| 4 | 48 | 48 | 4096 | 16384 | ✓ | 50.9 | 53.2 | +4.6% |
| 4 | 32 | 8 | 1 | 163 | ✗ | 0.128 | 0.1265 | -1.1% |
| 4 | 32 | 8 | 1 | 163 | ✓ | 0.2233 | 0.2237 | +0.2% |
| 4 | 32 | 8 | 1 | 8192 | ✗ | 0.4497 | 0.4432 | -1.4% |
| 4 | 32 | 8 | 1 | 8192 | ✓ | 0.4242 | 0.4166 | -1.8% |
| 4 | 32 | 8 | 1 | 16384 | ✗ | 0.4515 | 0.4455 | -1.3% |
| 4 | 32 | 8 | 1 | 16384 | ✓ | 0.4346 | 0.4325 | -0.5% |
| 4 | 32 | 8 | 1024 | 163 | ✗ | 22.2 | 21.9 | -1.4% |
| 4 | 32 | 8 | 1024 | 163 | ✓ | 5 | 4.9 | -1.8% |
| 4 | 32 | 8 | 1024 | 8192 | ✗ | 62.6 | 59.8 | -4.5% |
| 4 | 32 | 8 | 1024 | 8192 | ✓ | 57.4 | 58.4 | +1.8% |
| 4 | 32 | 8 | 1024 | 16384 | ✗ | 61.9 | 61.5 | -0.7% |
| 4 | 32 | 8 | 1024 | 16384 | ✓ | 58 | 57.1 | -1.7% |
| 4 | 32 | 8 | 4096 | 163 | ✗ | 23.1 | 22.8 | -1.3% |
| 4 | 32 | 8 | 4096 | 163 | ✓ | 1.3 | 1.3 | -0.2% |
| 4 | 32 | 8 | 4096 | 8192 | ✗ | 62.9 | 63.1 | +0.4% |
| 4 | 32 | 8 | 4096 | 8192 | ✓ | 57.8 | 57.1 | -1.2% |
| 4 | 32 | 8 | 4096 | 16384 | ✗ | 63.3 | 62.3 | -1.6% |
| 4 | 32 | 8 | 4096 | 16384 | ✓ | 59.2 | 58.2 | -1.7% |
| 16 | 16 | 16 | 1 | 163 | ✗ | 0.1316 | 0.1303 | -1.0% |
| 16 | 16 | 16 | 1 | 163 | ✓ | 0.2144 | 0.2135 | -0.5% |
| 16 | 16 | 16 | 1 | 8192 | ✗ | 0.3929 | 0.3873 | -1.4% |
| 16 | 16 | 16 | 1 | 8192 | ✓ | 0.3718 | 0.3636 | -2.2% |
| 16 | 16 | 16 | 1 | 16384 | ✗ | 0.4072 | 0.3953 | -2.9% |
| 16 | 16 | 16 | 1 | 16384 | ✓ | 0.3809 | 0.3771 | -1.0% |
| 16 | 16 | 16 | 1024 | 163 | ✗ | 23.1 | 22.6 | -1.8% |
| 16 | 16 | 16 | 1024 | 163 | ✓ | 5.1 | 5.1 | -0.7% |
| 16 | 16 | 16 | 1024 | 8192 | ✗ | 60.6 | 60.3 | -0.6% |
| 16 | 16 | 16 | 1024 | 8192 | ✓ | 56 | 55.4 | -1.1% |
| 16 | 16 | 16 | 1024 | 16384 | ✗ | 59 | 58.6 | -0.6% |
| 16 | 16 | 16 | 1024 | 16384 | ✓ | 55.4 | 55.2 | -0.4% |
| 16 | 16 | 16 | 4096 | 163 | ✗ | 23.2 | 22.6 | -2.7% |
| 16 | 16 | 16 | 4096 | 163 | ✓ | 1.3 | 1.3 | -0.5% |
| 16 | 16 | 16 | 4096 | 8192 | ✗ | 61.6 | 61.1 | -0.8% |
| 16 | 16 | 16 | 4096 | 8192 | ✓ | 56.4 | 55.9 | -0.9% |
| 16 | 16 | 16 | 4096 | 16384 | ✗ | 59.8 | 59 | -1.4% |
| 16 | 16 | 16 | 4096 | 16384 | ✓ | 55.7 | 55.3 | -0.8% |
| 16 | 48 | 48 | 1 | 163 | ✗ | 0.1758 | 0.1735 | -1.3% |
| 16 | 48 | 48 | 1 | 163 | ✓ | 0.265 | 0.2634 | -0.6% |
| 16 | 48 | 48 | 1 | 8192 | ✗ | 0.4223 | 0.1782 | -57.8% |
| 16 | 48 | 48 | 1 | 8192 | ✓ | 0.4077 | 0.1729 | -57.6% |
| 16 | 48 | 48 | 1 | 16384 | ✗ | 0.4315 | 0.1522 | -64.7% |
| 16 | 48 | 48 | 1 | 16384 | ✓ | 0.4136 | 0.1484 | -64.1% |
| 16 | 48 | 48 | 1024 | 163 | ✗ | 21.9 | 21.4 | -2.3% |
| 16 | 48 | 48 | 1024 | 163 | ✓ | 4 | 4 | -0.0% |
| 16 | 48 | 48 | 1024 | 8192 | ✗ | 52.9 | 49.6 | -6.2% |
| 16 | 48 | 48 | 1024 | 8192 | ✓ | 49.5 | 46.5 | -6.0% |
| 16 | 48 | 48 | 1024 | 16384 | ✗ | 53.6 | 48.8 | -9.0% |
| 16 | 48 | 48 | 1024 | 16384 | ✓ | 50.3 | 45.5 | -9.6% |
| 16 | 48 | 48 | 4096 | 163 | ✗ | 22.7 | 21.9 | -3.8% |
| 16 | 48 | 48 | 4096 | 163 | ✓ | 1.3 | 1.3 | -0.2% |
| 16 | 48 | 48 | 4096 | 8192 | ✗ | 52.4 | 57.2 | +9.2% |
| 16 | 48 | 48 | 4096 | 8192 | ✓ | 49.1 | 52.3 | +6.6% |
| 16 | 48 | 48 | 4096 | 16384 | ✗ | 52.8 | 57.3 | +8.4% |
| 16 | 48 | 48 | 4096 | 16384 | ✓ | 49.8 | 52.7 | +5.8% |
| 16 | 32 | 8 | 1 | 163 | ✗ | 0.2147 | 0.2121 | -1.2% |
| 16 | 32 | 8 | 1 | 163 | ✓ | 0.3383 | 0.3344 | -1.2% |
| 16 | 32 | 8 | 1 | 8192 | ✗ | 0.4707 | 0.4696 | -0.2% |
| 16 | 32 | 8 | 1 | 8192 | ✓ | 0.4545 | 0.4524 | -0.5% |
| 16 | 32 | 8 | 1 | 16384 | ✗ | 0.475 | 0.4721 | -0.6% |
| 16 | 32 | 8 | 1 | 16384 | ✓ | 0.458 | 0.4567 | -0.3% |
| 16 | 32 | 8 | 1024 | 163 | ✗ | 22.6 | 22.5 | -0.3% |
| 16 | 32 | 8 | 1024 | 163 | ✓ | 4.6 | 4.5 | -0.4% |
| 16 | 32 | 8 | 1024 | 8192 | ✗ | 63.6 | 62.3 | -2.0% |
| 16 | 32 | 8 | 1024 | 8192 | ✓ | 57.2 | 57.5 | +0.4% |
| 16 | 32 | 8 | 1024 | 16384 | ✗ | 61.9 | 61.5 | -0.7% |
| 16 | 32 | 8 | 1024 | 16384 | ✓ | 58 | 57.7 | -0.5% |
| 16 | 32 | 8 | 4096 | 163 | ✗ | 22.9 | 22 | -3.9% |
| 16 | 32 | 8 | 4096 | 163 | ✓ | 1.4 | 1.4 | -0.0% |
| 16 | 32 | 8 | 4096 | 8192 | ✗ | 62.4 | 62 | -0.6% |
| 16 | 32 | 8 | 4096 | 8192 | ✓ | 57.2 | 56.6 | -1.0% |
| 16 | 32 | 8 | 4096 | 16384 | ✗ | 62.5 | 61.8 | -1.1% |
| 16 | 32 | 8 | 4096 | 16384 | ✓ | 58.2 | 58.2 | -0.1% |
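For reference, the Delta column above is the relative change of the LLC run versus baseline TFLOPS. A minimal sketch of that computation (the function name is illustrative, not from the bench script):

```python
def tflops_delta(baseline: float, llc: float) -> str:
    """Relative TFLOPS change of the LLC run vs. baseline,
    formatted like the Delta column (signed, one decimal)."""
    pct = (llc - baseline) / baseline * 100.0
    return f"{pct:+.1f}%"
```

For example, the B=1, HQ=HK=48, SQ=1, SK=8192 row gives `tflops_delta(0.291, 0.0992)` → `-65.9%`, matching the table.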
No decode path was changed in my PR at all.
Heads-grouping result from Dao-AILab/flash-attention#2400.
Hi @tianwyan, I have updated the bench step in CI to run model configs (Llama3 8B/70B/405B, Mixtral 7B/22B, DeepSeek-V3) over various sequence lengths. I have posted the baseline results in #2542; see #2542 (comment). You can rebase your branch on main and CI should bench it automatically. For local benching, you can run
Hi @tianwyan, I have some results from comparing the CI benchmarks. I am going to focus on RDNA3 results. There is some variance on the MI35x bench but it is probably noise since the LLC feature is gated to RDNA. I compared the RDNA3 benchmark results using the data from https://github.com/ROCm/aiter/actions/runs/24348351017/job/71096344801 and the baseline data that I posted about from https://github.com/ROCm/aiter/actions/runs/24152934159/job/70485243064. I have attached a comparison figure and table below for the RDNA3 forward results.
RDNA3 - fwd

Motivation
This PR enables Flash Attention Triton support for AMD RDNA3 (Navi) GPUs, specifically targeting the gfx1100 architecture. The goal is to bring Flash Attention performance optimizations to consumer-grade AMD GPUs while leveraging the unique Infinity Cache (LLC) architecture for improved memory throughput.
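As a rough illustration of what "LLC awareness" can mean for an attention kernel (a sketch under assumptions; none of these names come from the PR): a scheduler can check whether the K/V working set of one (batch, head) pair fits in the Infinity Cache and, if so, choose a tile traversal order that keeps K/V resident and reuses it across query blocks.

```python
# Illustrative sketch only. Decides whether the K and V tiles of one
# (batch, head) pair fit in the LLC, which an LLC-aware schedule could use
# to pick a KV-reusing tile order. The names and the 96 MB default (the
# gfx1100 Infinity Cache size) are assumptions, not the PR's actual code.
def kv_fits_in_llc(seqlen_k: int, head_dim: int, dtype_bytes: int = 2,
                   llc_bytes: int = 96 * 1024 * 1024) -> bool:
    """True if the K and V tensors for one head fit in the LLC."""
    kv_bytes = 2 * seqlen_k * head_dim * dtype_bytes  # K plus V
    return kv_bytes <= llc_bytes
```

For fp16 with head_dim=128 and seqlen_k=16384, K+V for one head is about 8 MB, comfortably inside a 96 MB LLC; the benchmark regressions above suggest the heuristic's benefit depends strongly on the shape.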
Technical Details
New Architecture Support:
Performance Optimizations:
Code Cleanup:
Test Plan
Test Result