Vulkan Flash Attention Coopmat1 Refactor#19075
Conversation
|
I haven't had a chance to look at the shader code in detail yet, but I'm surprised it's the token gen perf that decreases. I think those cases should be using the FA_SCALAR path, and you didn't change that shader, so how did it get slower?
|
It seems N gets set to gqa_ratio = 4 in this case, so the N <= 1 condition that selects the scalar path does not apply. Is that intentional, or should the path be chosen before the gqa conditional? Edit: This is the case that gets worse on Nvidia. Using scalar here makes it even worse, so I guess the choice was intentional.
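For illustration, the interaction described above can be sketched as follows. This is a hypothetical reconstruction of the path-selection logic (function and enum names are made up, not the actual ggml-vulkan code): because N is raised to gqa_ratio before the scalar check, a single-token call with gqa_ratio = 4 no longer satisfies N <= 1 and falls through to the coopmat path.

```cpp
// Hypothetical sketch of the flash-attention path choice discussed above.
enum class FaPath { Scalar, Coopmat1, Coopmat2 };

FaPath choose_fa_path(int n, int gqa_ratio, bool have_coopmat2) {
    // GQA batching happens first: multiple query heads sharing a KV head
    // are processed together, so N becomes gqa_ratio even for token gen.
    if (gqa_ratio > 1) {
        n = gqa_ratio;
    }
    if (n <= 1) {
        return FaPath::Scalar;  // plain single-row token gen only
    }
    return have_coopmat2 ? FaPath::Coopmat2 : FaPath::Coopmat1;
}
```

With gqa_ratio = 4 the scalar branch is skipped even for a single query token, which matches the behaviour observed in this thread.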
|
I enabled large K tile shmem loading on Nvidia again, which fixed the issue. Performance looks pretty good now. RTX 3090:
|
|
Very good results on AMD RX 9060 XT (RDNA4) as well:
|
|
Are you sure that it's getting correctly recompiled? It would be rather odd if I could reproduce your issue and find a fix, yet you consistently keep seeing the same behaviour.
|
@0cc4m
|
Tested this with 3x A770s and CPU offloading with GPT-OSS-120B. Got an extra ~20 t/s on PP, so it's definitely having an impact on performance across the board. Nothing crazy, but still very much worthwhile.
I can reproduce it on Windows, so you are right, thank you. Very odd that I fixed it on Linux, but somehow not on Windows. I'll try to figure it out.
Intel Alchemist does not use coopmat, because we have not found a way to make it perform well that way. You can see this from
|
5060ti: Before
After
With mod
CPU offload case:
After
|
|
I see a speedup for prompt processing, roughly noise for tg, with this change on 5090:
|
Oh, the coopmat2 path has gotten slower. It's due to defining BLOCK_SIZE unconditionally; there are a couple of `#if`s that depend on it.
|
@characharm This time I think I found it; it's now working for me on Windows. Please try again. @jeffbolznv You're right, sorry about that. I missed the other #ifs. Can you check whether I got them all now?
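As a generic illustration of the BLOCK_SIZE issue discussed above (the macro and function names here are hypothetical, not the actual shader code): when a macro is defined unconditionally, every later `#if defined(...)` guard silently switches to the guarded branch, which is how the coopmat2 path picked up the slower shared-memory variant.

```cpp
// Sketch: defining BLOCK_SIZE unconditionally would flip every
// `#if defined(BLOCK_SIZE)` branch below. Guarding the define behind the
// feature macro (here a made-up USE_BLOCK_CACHE) restores the old paths.
#ifdef USE_BLOCK_CACHE
#define BLOCK_SIZE 32
#endif

int tile_elems() {
#if defined(BLOCK_SIZE)
    return BLOCK_SIZE * BLOCK_SIZE;  // shared-memory tiled variant
#else
    return 1;                        // direct-load variant (no tiling)
#endif
}
```

With USE_BLOCK_CACHE left undefined, `tile_elems()` compiles to the direct-load variant; an unconditional `#define BLOCK_SIZE` would have forced the tiled branch everywhere.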
|
@0cc4m much better now :D This is on Linux, btw.
|
|
@maxious Did you disable coopmat2 for your 5080 checks? Otherwise you just caught the issue that Jeff mentioned, but you won't see any of the coopmat1 improvements that are the actual purpose of this PR.
|
@0cc4m I was mainly concerned with flagging the impact on the status quo if the PR got merged. But yes, there's a performance boost for coopmat1:
|
|
Makes sense, thank you for checking.
|
Confirmed, gptoss20 is now working! 🚀
|
That's odd: that run passed in this PR, and now it has a single test going just barely above the error threshold. Did something else change on master in the meantime? I'll look into it.
|
@maxious Thanks for testing the Intel B60. Can you please share your system configuration? The numbers seem very low.
|
I borrowed a Turing system and am able to reproduce this new failure. It has similar characteristics to an internal compiler bug we're working on and I'm fairly confident it's the same bug. As far as we know, the bug only affects Turing, but I don't have any suggestions for a workaround at the moment. The coopmat2 path doesn't appear to be affected, though I think that's just by luck. (Or really, the coopmat1 path is having bad luck to hit this bug because it's obscure enough that we've only encountered it recently). I would suggest to disable the coopmat1 path for Turing (users should be on the coopmat2 path anyway), but I don't know if we have an alternative to run CI on.
|
If you gate it with an env variable, we can update the coopmat1 CI to set a specific value that force-enables the path. By default it will be disabled, so that users don't go through that path.
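A minimal sketch of the env-variable gate proposed above, assuming a hypothetical variable name (GGML_VK_FORCE_COOPMAT1_FA is made up for illustration): the Turing coopmat1 path stays off by default, but CI can force-enable it by exporting the variable.

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical env-variable gate: disabled on Turing by default,
// but CI can set GGML_VK_FORCE_COOPMAT1_FA=1 to force the path on.
bool coopmat1_fa_allowed(bool is_turing) {
    const char *force = std::getenv("GGML_VK_FORCE_COOPMAT1_FA");
    if (force && std::strcmp(force, "1") == 0) {
        return true;        // CI override: run the path anyway
    }
    return !is_turing;      // default: skip the path on Turing
}
```

This keeps users off the affected path while still letting the coopmat1 CI job exercise it deliberately.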
|
But then CI would still fail. Do we have any other coopmat1-capable hardware we can run on in CI? |
Hm, not sure - it would likely have to be self-hosted by someone in the community. I can easily add more Tesla T4 runners in the Azure cloud if this can help, but they are coopmat2. |
* vulkan: use coopmat for flash attention p*v matrix multiplication
* fix P loading issue
* fix barrier position
* remove reduction that is no longer needed
* move max thread reduction into loop
* remove osh padding
* add bounds checks and padding
* remove unused code
* fix shmem sizes, loop duration and accesses
* don't overwrite Qf, add new shared psh buffer instead
* add missing bounds checks
* use subgroup reductions
* optimize
* move bounds check, reduce barriers
* support other Bc values and other subgroup sizes
* remove D_split
* replace Of register array with shared memory Ofsh array
* parallelize HSV across the rowgroups
* go back to Of in registers, not shmem
* vectorize sfsh
* don't store entire K tile in shmem
* fixes
* load large k tiles to shmem on Nvidia
* adapt shared memory host check function to shader changes
* remove Bc 32 case
* remove unused variable
* fix missing mask reduction tmspsh barrier
* fix mask bounds check
* fix rowmax f16 under/overflow to inf
* fix flash_attn_cm2 BLOCK_SIZE preprocessor directives
I finally had the time to go through Jeff's Flash Attention shaders in detail and used the chance to refactor the Coopmat1 shader for AMD. It started out as an attempt to use coopmats for the Softmax * V matrix multiplication as well, and then escalated into a refactor of the whole shader structure.
It now uses coopmats for the Softmax result * V matrix multiplication. I also vectorized some variables and changed how shared memory is used: K and V are loaded directly from global memory where possible, otherwise streamed through a shared-memory cache.
Tests are passing. Performance is up significantly on AMD RX 8060S (Strix Halo). Marked as draft because there is a regression on Nvidia. Let me know if you see anything obvious, @jeffbolznv. More tuning is likely required.
AMD 8060S:
Nvidia 3090:
Claude Code was used for debugging and code analysis, but I wrote the code.
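For reference, the algorithm the shader streams tile by tile is the standard online-softmax attention: keep a running row max and normalizer, and rescale the output accumulator whenever the max grows. The sketch below is a generic single-query-row illustration (one "tile" per score), not the shader code itself:

```cpp
#include <cmath>
#include <vector>

// Online-softmax attention for one query row: scores[j] = q·k_j (pre-scaled),
// v[j] is the value row for key j. Processes keys one at a time, keeping a
// running max m and normalizer l, and rescaling the accumulator o when m grows.
std::vector<float> flash_attn_row(const std::vector<float> &scores,
                                  const std::vector<std::vector<float>> &v) {
    const size_t d = v[0].size();
    float m = -INFINITY, l = 0.0f;
    std::vector<float> o(d, 0.0f);
    for (size_t j = 0; j < scores.size(); ++j) {
        float m_new = std::max(m, scores[j]);
        float corr  = std::exp(m - m_new);       // rescale old state if max grew
        float p     = std::exp(scores[j] - m_new);
        l = l * corr + p;
        for (size_t k = 0; k < d; ++k) {
            o[k] = o[k] * corr + p * v[j][k];
        }
        m = m_new;
    }
    for (size_t k = 0; k < d; ++k) {
        o[k] /= l;                               // final softmax normalization
    }
    return o;
}
```

Subtracting the running max before the exp is also what the "fix rowmax f16 under/overflow to inf" commit is about: without it, f16 exponentials overflow easily.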