Vulkan Flash Attention Coopmat1 Refactor #19075

Merged
0cc4m merged 30 commits into master from 0cc4m/vulkan-fa-cm1-pv on Jan 28, 2026
Conversation

@0cc4m
Contributor

0cc4m commented Jan 24, 2026

I finally had the time to go through Jeff's Flash Attention shaders in detail and used the chance to refactor the Coopmat1 path for AMD. It started out as an attempt to use Coopmats for the Softmax * V matrix multiplication as well and then escalated into a refactor of the whole shader structure.

It now uses coopmats for the Softmax result * V matrix multiplication. I also vectorized some variables, changed how shared memory is used, and load K and V directly from global memory where possible, otherwise streaming them through a shared-memory cache.

Tests are passing. Performance is up significantly on AMD RX 8060S (Strix Halo). Marked as draft because there is a regression on Nvidia. Let me know if you see anything obvious @jeffbolznv. More tuning is likely required.

AMD 8060S:

| model | size | params | ngl | fa | test | t/s (ROCm) | t/s (before) | t/s (after) | diff |
| ----- | ---: | -----: | --: | -: | ---- | ---------: | -----------: | ----------: | ---: |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | pp512 | 1087.56 ± 41.81 | 1004.49 ± 44.19 | 1020.70 ± 46.87 | +1.6% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | tg128 | 40.38 ± 0.28 | 43.02 ± 0.25 | 43.91 ± 0.44 | +2.1% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | pp512 @ d8192 | 604.32 ± 4.11 | 418.94 ± 1.64 | 556.71 ± 3.25 | +32.9% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | tg128 @ d8192 | 33.96 ± 0.04 | 33.34 ± 0.08 | 35.88 ± 0.14 | +7.6% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | pp512 | 1037.80 ± 33.87 | 1316.13 ± 23.00 | 1285.66 ± 30.56 | -2.3% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | tg128 | 70.82 ± 0.06 | 76.86 ± 0.12 | 76.17 ± 1.36 | -0.9% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | pp512 @ d8192 | 734.15 ± 3.78 | 854.11 ± 2.40 | 967.79 ± 4.56 | +13.3% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | tg128 @ d8192 | 65.34 ± 0.09 | 67.96 ± 0.36 | 70.87 ± 0.30 | +4.3% |

Nvidia 3090:

| model | size | params | ngl | fa | test | t/s (before) | t/s (after) | diff |
| ----- | ---: | -----: | --: | -: | ---- | -----------: | ----------: | ---: |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | pp512 | 3482.02 ± 7.00 | 3497.43 ± 10.06 | +0.4% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | tg128 | 126.31 ± 0.36 | 119.71 ± 1.53 | -5.2% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | pp512 @ d8192 | 1528.77 ± 10.40 | 1664.99 ± 7.69 | +8.9% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | tg128 @ d8192 | 103.77 ± 0.83 | 96.24 ± 0.76 | -7.3% |
| gpt-oss 20B Q8_0 | 11.27 GiB | 20.91 B | 99 | 1 | pp512 | 3114.63 ± 62.70 | 3151.79 ± 22.70 | +1.2% |
| gpt-oss 20B Q8_0 | 11.27 GiB | 20.91 B | 99 | 1 | tg128 | 182.60 ± 1.21 | 172.98 ± 1.47 | -5.3% |
| gpt-oss 20B Q8_0 | 11.27 GiB | 20.91 B | 99 | 1 | pp512 @ d8192 | 2143.15 ± 26.05 | 2272.91 ± 10.81 | +6.1% |
| gpt-oss 20B Q8_0 | 11.27 GiB | 20.91 B | 99 | 1 | tg128 @ d8192 | 159.97 ± 0.65 | 145.73 ± 1.31 | -8.9% |

Claude Code was used for debugging and code analysis, but I wrote the code.

@github-actions bot added labels: Vulkan (Issues specific to the Vulkan backend), ggml (changes relating to the ggml tensor library for machine learning) on Jan 24, 2026
@jeffbolznv
Contributor

I haven't had a chance to look at the shader code in detail yet, but I'm surprised it's the token gen perf that decreases. I think those should be using the FA_SCALAR path and you didn't change that shader, so how did it get slower?

@0cc4m
Contributor Author

0cc4m commented Jan 25, 2026

It seems N gets set to gqa_ratio = 4 in this case, so the N <= 1 condition to set the path to scalar does not apply. Is that intentional or should the path be chosen before the gqa conditional?

Edit: This is the case that gets worse on Nvidia: FLASH_ATTN_EXT dst(128,32,1,1), q(128,1,32,1), k(128,8448,8,1), v(128,8448,8,1), m(8448,1,1,1): 20480 x 63.947 us = 1.30965e+06 us

Using scalar here seems to make it worse, so I guess the choice was intentional.

@0cc4m
Contributor Author

0cc4m commented Jan 25, 2026

I enabled large K tile shmem loading on Nvidia again, which fixed the issue. Performance looks pretty good now:

RTX 3090

| model | size | params | ngl | fa | test | t/s (Coopmat2) | t/s (before) | t/s (after) | diff |
| ----- | ---: | -----: | --: | -: | ---- | -------------: | -----------: | ----------: | ---: |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | pp512 | 4670.31 ± 50.79 | 3488.29 ± 24.23 | 3525.73 ± 6.64 | +1.1% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | tg128 | 129.48 ± 0.31 | 127.66 ± 0.54 | 129.08 ± 0.63 | +1.1% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | pp512 @ d8192 | 2143.89 ± 6.51 | 1511.25 ± 4.94 | 1922.56 ± 11.74 | +27.2% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | tg128 @ d8192 | 98.50 ± 0.50 | 103.98 ± 0.17 | 108.84 ± 0.09 | +4.7% |
| gpt-oss 20B Q8_0 | 11.27 GiB | 20.91 B | 99 | 1 | pp512 | 4527.83 ± 29.20 | 3107.29 ± 23.96 | 3128.15 ± 26.23 | +0.7% |
| gpt-oss 20B Q8_0 | 11.27 GiB | 20.91 B | 99 | 1 | tg128 | 182.22 ± 0.58 | 181.79 ± 0.67 | 180.50 ± 0.48 | -0.7% |
| gpt-oss 20B Q8_0 | 11.27 GiB | 20.91 B | 99 | 1 | pp512 @ d8192 | 3226.14 ± 8.27 | 2126.74 ± 25.03 | 2371.08 ± 27.36 | +11.5% |
| gpt-oss 20B Q8_0 | 11.27 GiB | 20.91 B | 99 | 1 | tg128 @ d8192 | 155.60 ± 0.86 | 158.68 ± 0.59 | 159.03 ± 0.55 | +0.2% |

@0cc4m
Contributor Author

0cc4m commented Jan 25, 2026

Very good results on AMD RX 9060 XT (RDNA4) as well:

| model | size | params | ngl | fa | test | t/s (ROCm) | t/s (before) | t/s (after) | diff |
| ----- | ---: | -----: | --: | -: | ---- | ---------: | -----------: | ----------: | ---: |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | pp512 | 2206.08 ± 18.35 | 1862.80 ± 5.52 | 1943.51 ± 4.64 | +4.3% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | tg128 | 46.99 ± 0.22 | 62.34 ± 0.11 | 63.81 ± 0.20 | +2.4% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | pp512 @ d8192 | 882.63 ± 1.24 | 612.12 ± 3.75 | 1286.64 ± 5.34 | +110.2% |
| llama 8B Q4_K - Small | 4.36 GiB | 8.03 B | 99 | 1 | tg128 @ d8192 | 40.35 ± 0.09 | 49.37 ± 0.16 | 52.40 ± 0.09 | +6.1% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | pp512 | 2917.11 ± 17.84 | 2165.96 ± 33.70 | 2163.10 ± 24.91 | -0.1% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | tg128 | 78.66 ± 0.16 | 103.09 ± 0.53 | 103.35 ± 0.09 | +0.3% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | pp512 @ d8192 | 2064.34 ± 10.68 | 1238.09 ± 5.88 | 1740.71 ± 13.51 | +40.6% |
| gpt-oss 20B MXFP4 MoE | 11.27 GiB | 20.91 B | 99 | 1 | tg128 @ d8192 | 72.97 ± 0.15 | 91.78 ± 0.59 | 95.49 ± 0.45 | +4.0% |

@0cc4m
Contributor Author

0cc4m commented Jan 27, 2026

Are you sure that it's getting correctly recompiled? It would be rather odd if I could reproduce your issue and find a fix, while you consistently keep seeing the same behaviour.

@characharm
Contributor

@0cc4m
I delete the build directory between builds. I’ve now tried doing this in a fresh directory by running git clone https://github.com/ggml-org/llama.cpp -b 0cc4m/vulkan-fa-cm1-pv.

@HumerousGorgon

Tested this with 3 x A770s and CPU offloading with GPT-OSS-120B. Got an extra ~20 t/s on PP, so it's definitely having an impact on performance across the board. Nothing crazy, but still very much worthwhile.
Compiling bench so that I can post actual numbers.

@0cc4m
Contributor Author

0cc4m commented Jan 27, 2026

@0cc4m I delete the build directory between builds. I’ve now tried doing this in a fresh directory by running git clone https://github.com/ggml-org/llama.cpp -b 0cc4m/vulkan-fa-cm1-pv.

I can reproduce it on Windows, so you are right, thank you. Very odd that I fixed it on Linux, but somehow not on Windows. I'll try to figure it out.

Tested this with 3 x A770's and CPU offloading with GPT-OSS-120B. Got an extra ~20tps on PP, so definitely having an impact on performance across the board. Nothing crazy, but still very much worthwhile. Compiling bench so that I can post actual numbers.

Intel Alchemist does not use coopmat, because we have not found a way to make it perform well that way. You can see this from matrix cores: none. This PR only makes a difference if you see matrix cores: KHR_coopmat.

@lovedheart
Contributor

lovedheart commented Jan 27, 2026

5060ti:
PP around -13%, TG around -4.6%

Before

| model | size | params | backend | ngl | fa | dev | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | --- | ---- | --: |
| qwen3next 80B.A3B Q4_K - Medium | 13.24 GiB | 21.65 B | CUDA,Vulkan | 99 | 1 | Vulkan1 | pp512 @ d8192 | 1674.94 ± 20.28 |
| qwen3next 80B.A3B Q4_K - Medium | 13.24 GiB | 21.65 B | CUDA,Vulkan | 99 | 1 | Vulkan1 | tg128 @ d8192 | 63.90 ± 0.06 |

After

| model | size | params | backend | ngl | fa | dev | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | --- | ---- | --: |
| qwen3next 80B.A3B Q4_K - Medium | 13.24 GiB | 21.65 B | Vulkan | 99 | 1 | Vulkan1 | pp512 @ d8192 | 1458.11 ± 19.23 |
| qwen3next 80B.A3B Q4_K - Medium | 13.24 GiB | 21.65 B | Vulkan | 99 | 1 | Vulkan1 | tg128 @ d8192 | 60.94 ± 0.29 |

With the mod `const uint32_t k_load_shmem = 0;`:

| model | size | params | backend | ngl | fa | dev | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | --- | ---- | --: |
| qwen3next 80B.A3B Q4_K - Medium | 13.24 GiB | 21.65 B | Vulkan | 99 | 1 | Vulkan1 | pp512 @ d8192 | 1458.42 ± 15.79 |
| qwen3next 80B.A3B Q4_K - Medium | 13.24 GiB | 21.65 B | Vulkan | 99 | 1 | Vulkan1 | tg128 @ d8192 | 60.86 ± 0.13 |

CPU offload case:
Before

| model | size | params | backend | ngl | fa | dev | ot | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | --- | -- | ---- | --: |
| qwen3next 80B.A3B Q4_K - Medium | 13.24 GiB | 21.65 B | CUDA,Vulkan | 99 | 1 | Vulkan1 | (4[0-7]).ffn_.*_exps=CPU | pp512 @ d131072 | 508.83 ± 4.05 |
| qwen3next 80B.A3B Q4_K - Medium | 13.24 GiB | 21.65 B | CUDA,Vulkan | 99 | 1 | Vulkan1 | (4[0-7]).ffn_.*_exps=CPU | tg128 @ d131072 | 33.44 ± 0.05 |

After

| model | size | params | backend | ngl | fa | dev | ot | test | t/s |
| ----- | ---: | -----: | ------- | --: | -: | --- | -- | ---- | --: |
| qwen3next 80B.A3B Q4_K - Medium | 13.24 GiB | 21.65 B | Vulkan | 99 | 1 | Vulkan1 | (4[0-7]).ffn_.*_exps=CPU | pp512 @ d131072 | 307.27 ± 2.05 |
| qwen3next 80B.A3B Q4_K - Medium | 13.24 GiB | 21.65 B | Vulkan | 99 | 1 | Vulkan1 | (4[0-7]).ffn_.*_exps=CPU | tg128 @ d131072 | 24.59 ± 0.02 |

@jeffbolznv
Contributor

I see a speedup for prompt processing, roughly noise for tg, with this change on 5090:

before:
Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m c:\models\gpt-oss-20b-mxfp4.gguf -fa 1 -p 512 -n 128 -d 0,8192
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: KHR_coopmat
load_backend: loaded Vulkan backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-vulkan.dll
load_backend: loaded CPU backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-cpu.dll
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |      7329.55 ± 82.12 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        360.21 ± 1.32 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |   pp512 @ d8192 |      4989.28 ± 46.03 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |   tg128 @ d8192 |        323.58 ± 1.18 |

build: 2b4cbd283 (7849)

after:
Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo>llama-bench.exe -m c:\models\gpt-oss-20b-mxfp4.gguf -fa 1 -p 512 -n 128 -d 0,8192
ggml_vulkan: Found 1 Vulkan devices:
ggml_vulkan: 0 = NVIDIA GeForce RTX 5090 (NVIDIA) | uma: 0 | fp16: 1 | bf16: 1 | warp size: 32 | shared memory: 49152 | int dot: 1 | matrix cores: KHR_coopmat
load_backend: loaded Vulkan backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-vulkan.dll
load_backend: loaded CPU backend from Z:\github\jeffbolznv\llama.cpp\build\bin\RelWithDebInfo\ggml-cpu.dll
| model                          |       size |     params | backend    | ngl | fa |            test |                  t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | -: | --------------: | -------------------: |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           pp512 |     7411.47 ± 123.85 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |           tg128 |        356.75 ± 2.87 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |   pp512 @ d8192 |      5624.12 ± 56.26 |
| gpt-oss 20B MXFP4 MoE          |  11.27 GiB |    20.91 B | Vulkan     |  99 |  1 |   tg128 @ d8192 |        328.25 ± 1.18 |

build: 92051ec92 (7849)

@jeffbolznv
Contributor

Oh, the coopmat2 path has gotten slower. It's due to defining BLOCK_SIZE unconditionally, there are a couple defined(BLOCK_SIZE) in flash_attn_cm2.comp that should be updated.

@0cc4m
Contributor Author

0cc4m commented Jan 28, 2026

@characharm This time I think I found it, now it's working for me on Windows. Please try again.

@jeffbolznv You're right, sorry about that. I missed the other #ifs. Can you check if I got them now?

@maxious

maxious commented Jan 28, 2026

@0cc4m much better now :D This is on Linux, btw.

| Device | Test | Master (293a156) | PR (92051ec) | PR (7056c66) | vs Master |
| ------ | ---- | ---------------: | -----------: | -----------: | --------: |
| NVIDIA RTX 5080 | pp512 | 6885 | 6379 | 7073 | +2.7% ✅ |
| NVIDIA RTX 5080 | pp512 @ d8192 | 4391 | 2565 | 4561 | +3.9% ✅ |
| NVIDIA RTX 5080 | tg128 | 142 | 139 | 154 | +8.6% ✅ |
| NVIDIA RTX 5080 | tg128 @ d8192 | 117 | 111 | 128 | +9.8% ✅ |
| Intel B60 | pp512 | 441 | 695 | 695 | +57.6% ✅ |
| Intel B60 | pp512 @ d8192 | 22 | 77 | 77 | +248% ✅ |
| Intel B60 | tg128 | 46 | 51 | 51 | +8.9% ✅ |
| Intel B60 | tg128 @ d8192 | 6.2 | 12.2 | 12.2 | +97% ✅ |

@0cc4m
Contributor Author

0cc4m commented Jan 28, 2026

@maxious Did you disable coopmat2 for your 5080 checks? Otherwise you just caught the issue that Jeff mentioned, but you won't see any of the coopmat1 improvements that are the actual purpose of this PR.

@maxious

maxious commented Jan 28, 2026

@0cc4m I was mostly trying to flag what the impact on the status quo would be if the PR were merged. But yes, there's a performance boost for coopmat1:

| Device | Test | Master (293a156) coopmat2 | Master (293a156) coopmat1 | PR (7056c66) coopmat2 | PR (7056c66) coopmat1 | PR vs Master (coopmat1) | PR vs Master (coopmat2) |
| ------ | ---- | ------------------------: | ------------------------: | --------------------: | --------------------: | ----------------------: | ----------------------: |
| NVIDIA RTX 5080 | pp512 | 6885 | 5881 | 7073 | 5932 | +0.9% | +2.7% |
| NVIDIA RTX 5080 | pp512 @ d8192 | 4391 | 2339 | 4561 | 2754 | +17.7% | +3.9% |
| NVIDIA RTX 5080 | tg128 | 142 | 151 | 154 | 152 | +0.7% | +8.5% |
| NVIDIA RTX 5080 | tg128 @ d8192 | 117 | 125 | 128 | 127 | +1.6% | +9.4% |

@0cc4m
Contributor Author

0cc4m commented Jan 28, 2026

Makes sense, thank you for checking.

@characharm
Contributor

@0cc4m

Confirmed, gptoss20 is now working! 🚀

@0cc4m 0cc4m merged commit f6b533d into master Jan 28, 2026
75 of 78 checks passed
@0cc4m 0cc4m deleted the 0cc4m/vulkan-fa-cm1-pv branch January 28, 2026 17:52
@CISC
Member

CISC commented Jan 29, 2026

@0cc4m
Contributor Author

0cc4m commented Jan 29, 2026

That's odd: that run passed in this PR, and now a single test goes just barely above the error threshold. Did something else change in master in the meantime? I'll look into it.

@bbharti

bbharti commented Jan 29, 2026

@maxious thanks for testing Intel B60. Can you please share system configurations, the numbers seem very low.

@jeffbolznv
Contributor

I borrowed a Turing system and am able to reproduce this new failure. It has similar characteristics to an internal compiler bug we're working on and I'm fairly confident it's the same bug. As far as we know, the bug only affects Turing, but I don't have any suggestions for a workaround at the moment.

The coopmat2 path doesn't appear to be affected, though I think that's just by luck. (Or really, the coopmat1 path is having bad luck to hit this bug because it's obscure enough that we've only encountered it recently). I would suggest to disable the coopmat1 path for Turing (users should be on the coopmat2 path anyway), but I don't know if we have an alternative to run CI on.

@ggerganov
Member

If you gate it with an env variable, we can update the coopmat1 CI to set a specific value that force-enables the path. By default it will be disabled, so that users don't go through that path.

@jeffbolznv
Contributor

But then CI would still fail. Do we have any other coopmat1-capable hardware we can run on in CI?

@ggerganov
Member

Do we have any other coopmat1-capable hardware we can run on in CI?

Hm, not sure - it would likely have to be self-hosted by someone in the community.

I can easily add more Tesla T4 runners in the Azure cloud if this can help, but they are coopmat2.

4b1tQu4ntN3k0 pushed a commit to 4b1tQu4ntN3k0/llama.cpp that referenced this pull request Feb 2, 2026
* vulkan: use coopmat for flash attention p*v matrix multiplication

* fix P loading issue

* fix barrier position

* remove reduction that is no longer needed

* move max thread reduction into loop

* remove osh padding

* add bounds checks and padding

* remove unused code

* fix shmem sizes, loop duration and accesses

* don't overwrite Qf, add new shared psh buffer instead

* add missing bounds checks

* use subgroup reductions

* optimize

* move bounds check, reduce barriers

* support other Bc values and other subgroup sizes

* remove D_split

* replace Of register array with shared memory Ofsh array

* parallelize HSV across the rowgroups

* go back to Of in registers, not shmem

* vectorize sfsh

* don't store entire K tile in shmem

* fixes

* load large k tiles to shmem on Nvidia

* adapt shared memory host check function to shader changes

* remove Bc 32 case

* remove unused variable

* fix missing mask reduction tmspsh barrier

* fix mask bounds check

* fix rowmax f16 under/overflow to inf

* fix flash_attn_cm2 BLOCK_SIZE preprocessor directives
shaofeiqi pushed a commit to qualcomm/llama.cpp that referenced this pull request Feb 6, 2026