Vulkan Scalar Flash Attention Refactor#19625
Conversation
I did some quick runs on my RX 470; the tests are passing and performance seems pretty similar to before. Nothing crazy at least.

PR

Master
Thanks very much, 0cc4m. Huge improvements at 8k and 16k context. I believe that for any meaningful conversation, pp and tg at depth 0 matter least: communicating requires context, and this change brings huge improvements there, making FA worthwhile.
AMD Vega 8 APU.
I see a regression in prompt processing for GPT-OSS 20B and Qwen3MoE on the RX 6800 XT; I'll try to fix it.
@0cc4m can you also kindly look at Qwen3Next? That architecture also has a PP regression, at least in my setup. It is already used by the Qwen team in three models, and I believe they will keep using it going forward, as may other models. The previously shared benchmarks also included some other unmerged Vulkan and Qwen3Next PRs. To isolate this, I merged only this PR into master and removed all the other PRs I had been merging. The regression is still there.

Master branch:

```bash
./llama-bench -m /home/tipu/AI/models/bartowski/Qwen3-Coder-Next/Qwen_Qwen3-Coder-Next-Q4_K_L.gguf -ngl 100 --ubatch-size 1024 --batch-size 2048 --mmap 0 -fa 1 -d 8096,16192 -r 3 -dio 1
```

build: 684b361 (8057)

Master with only this PR merged:

```bash
./llama-bench -m /home/tipu/AI/models/bartowski/Qwen3-Coder-Next/Qwen_Qwen3-Coder-Next-Q4_K_L.gguf -ngl 100 --ubatch-size 1024 --batch-size 2048 --mmap 0 -fa 1 -d 8096,16192 -r 3 -dio 1
```

build: 6cdddc6e0 (8089)
Back to draft while I improve parameter selection to make tuning easier.
Force-pushed d0cf725 to c6ee63e
I fixed all the regressions I could find. @jeffbolznv I refactored the way FA parameters are set, so that they are set in only one place. I tried to port the Coopmat2 behaviour correctly and only slightly tweaked the Coopmat1 parameters. Let me know if you agree with this approach or have a better idea.

Benchmarks

Radeon Pro VII:
AMD RX 6800 XT:
AMD RX 8060S (without coopmat):
AMD RX 8060S (with coopmat):
Intel A770:
Nvidia RTX 3090 (without coopmat):
Nvidia RTX 3090 (coopmat1):
Nvidia RTX 3090 (coopmat2):
Thank you :) The regression is gone and further improvements are observed.

```bash
./llama-bench -m /home/tipu/AI/models/bartowski/Qwen3-Coder-Next/Qwen_Qwen3-Coder-Next-Q4_K_L.gguf -ngl 100 --mmap 0 -fa 1 -d 8096,16192 -r 3 -dio 1
```

build: f91542c05 (8128)
```diff
  barrier();
- vec4 Of[Br][HSV_per_thread / 4];
+ ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
```
FWIW I found recently that the output can be stored in fp16 even when using GGML_PREC_F32, and this can help a lot with register usage for large head sizes (e.g. DeepSeek/GLM-Flash/etc.).
I should add that to scalar and coopmat1, yes.
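A loose NumPy sketch of the idea discussed above (not shader code; the shapes and values here are invented for illustration): accumulate each step in f32 for precision, but store the carried output tile in half precision to halve its footprint.

```python
import numpy as np

# Hypothetical per-thread output accumulator: the math runs in f32...
acc = np.zeros(8, dtype=np.float32)
for _ in range(4):
    acc += np.float32(0.125)   # f32 accumulation step

# ...but the stored/carried output tile can live in f16,
# using 2 bytes per element instead of 4.
out_tile = acc.astype(np.float16)

assert out_tile.itemsize == 2 and acc.itemsize == 4
assert np.allclose(out_tile.astype(np.float32), acc)
```

In the shader this roughly corresponds to keeping the `Of` accumulator in `float16_t` while `ACC_TYPE` stays f32, which is what the "use float16 for Of if available" commit does.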
```glsl
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
    uint32_t d = (idx + tid) % (HSK / 4);
    uint32_t c = (idx + tid) / (HSK / 4);
    if (c < Bc) {
```
Maybe change to:

```diff
- if (c < Bc) {
+ if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
```
this would allow the compiler to optimize out the branch on all except the last iteration.
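The effect of the suggested guard can be checked with a small Python simulation of the strided loop (`WG` stands in for `gl_WorkGroupSize.x`; the sizes are arbitrary examples): both guards touch exactly the same elements, but the added clause is uniform per iteration and true on every full pass, which is what lets the compiler drop the per-thread check everywhere except the last iteration.

```python
# Simulate WG "threads" cooperatively walking Bc * HSK / 4 elements.
Bc, HSK, WG = 16, 72, 128      # example sizes; Bc * HSK / 4 = 288
total = Bc * HSK // 4

def covered(guard):
    hit = set()
    for tid in range(WG):      # one thread at a time
        idx = 0
        while idx < total:
            d = (idx + tid) % (HSK // 4)
            c = (idx + tid) // (HSK // 4)
            if guard(idx, c):
                hit.add((c, d))
            idx += WG
    return hit

original = covered(lambda idx, c: c < Bc)
suggested = covered(lambda idx, c: idx + WG <= total or c < Bc)

assert original == suggested   # same elements processed
assert len(original) == total  # full tile covered exactly
```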
```diff
  // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK).
- if ((HSK % 16) != 0) {
+ if ((HSK % 16) != 0 || (HSV % 16) != 0) {
```
If HSK and HSV are both unaligned and have different values, it seems like the smaller one will fetch stale values from the larger one. I think that would be very uncommon; maybe just disable SHMEM_STAGING for those cases.
My thought here was that since I stage both K and V, I should check both for whether zero-initialization is required, so that it also happens if only HSV is off. Looking at it again, it might not be necessary to zero-pad at the beginning at all, since both SHMEM_STAGING loads run with HS*_pad and so should also write zero values when parts of the tile are out of bounds.
Yeah, it looks like the loads for K/V will pad with zero. I think filling Q with zero is still necessary to avoid inf * 0 = NaN, but for that just the HSK check should be enough.
I removed the HSV check and the kvsh loop.
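The inf * 0 = NaN hazard discussed above is easy to demonstrate in plain Python: stale padding that happens to hold inf turns an otherwise harmless masked-out product into NaN, which then propagates through the rest of the softmax row.

```python
import math

# Stale (uninitialized) shared memory could hold anything, including inf.
stale = float("inf")
# A masked-out contribution multiplies by zero, but inf * 0 is NaN:
assert math.isnan(stale * 0.0)

# With the Q padding zero-initialized, the product stays a harmless zero:
assert 0.0 * 0.0 == 0.0
```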
```diff
- coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+ {
+     const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
+     coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
```
Maybe worth a comment to say that we load V from shared memory even when SHMEM_STAGING is 0.
Similarly to K, if SHMEM_STAGING is not set, we only stage through shmem if the V datatype or bounds checks don't allow direct loading from global memory. But in those cases we load a much smaller tile compared to SHMEM_STAGING. I don't see what is different about how V is handled compared to how K was already handled previously.
I just found the logic here kind of confusing - seeing V get staged even when SHMEM_STAGING is false. I just thought it was worth a comment to clarify.
You're right, I was just wondering if something about V was special here. I added a comment to K and to V to explain the behaviour.
I did perf testing for the cm2 and scalar paths. cm2 was roughly unchanged. For scalar I didn't see any real improvements, and gpt-oss got slower on my system. I don't think this is a big deal since my system isn't really the target, but here are the results anyway:
I've seen that gpt-oss hit on some other tests as well. The old shader worked very well for small head sizes. I haven't nailed down exactly why the new version is a little bit worse there.
For some reason I do not see a regression compared to my old release. Maybe the behavior differs across hardware.
```glsl
shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
```

```diff
+ const uint32_t osh_stride = row_split * MatBc / 4;
+ shared f16vec4 pvsh[MatBc * osh_stride];
```
I don't think this is accounted for in the shared memory calculation. Maybe it could share with kvsh?
I forgot to update that, yes. I want them to share, but if Of is float16_t and ACC_TYPE is f32, they are different types, and GLSL doesn't let me reuse shared memory with a different type.
Using a float16 Of caused another nasty regression on AMD RDNA for smaller head sizes, while improving larger ones. I'll try to find a way to keep good speeds in both cases.
The issue on RDNA was that the register reduction from storing Of in float16 improved occupancy to the point where enough subgroups ran at once to thrash the cache. Performance is restored when occupancy is reduced again, so I forced this with a large unused shmem buffer. This is hacky, but I didn't find a better way. Let me know if you have concerns or suggestions.
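A toy occupancy model of why an unused shared-memory buffer works as an occupancy limiter (the numbers are illustrative, not actual RDNA specs): the number of resident workgroups per compute unit is bounded by how many copies of the workgroup's shared-memory allocation fit in local memory, so padding the allocation directly lowers that bound.

```python
# Hypothetical local-memory capacity per compute unit.
LDS_BYTES = 64 * 1024

def resident_workgroups(shmem_per_wg: int) -> int:
    """Workgroups that can be resident at once, limited only by shmem."""
    return LDS_BYTES // shmem_per_wg

# Without padding: many workgroups run concurrently (thrashing the cache).
assert resident_workgroups(8 * 1024) == 8
# With a dummy buffer doubling the allocation: occupancy is halved.
assert resident_workgroups(16 * 1024) == 4
```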
I'm not sure whether this CM2 CI issue is an error I introduced, or whether I have now triggered the Turing coopmat bug on CM2 as well.
I think it's likely the same Turing bug. I don't know what to do about it other than disabling the coopmat2 flash attention path on Turing.
Force-pushed ea7dfdf to cf28133
Force-pushed cf28133 to 1482e30
Force-pushed 1482e30 to ae849d3
@masamaru-san I disabled fp16 FA on GCN with the proprietary driver, similar to what you did. Can you try it?
@0cc4m Sorry for the slow reply.
* vulkan: allow using fp16 in scalar flash attention shader
* split rows inside of subgroups for faster synchronization
* use row_split when Br >= 4, change reductions to use shared memory if row_split == 1
* use f32 scalar FA if f16 is not supported by device
* fix amd workgroup size issue
* optimize masksh use
* add medium rows FA shader Br size
* fixes
* add padding to mask shmem buffer
* cache q values into registers for KQ
* fuse lf accumulation, pf and v accumulation into a loop
* stage K loads through shmem
* stage V loads through shmem
* only stage through shmem on Nvidia
* default to Bc 32
* also stage V through shmem when this is done for K
* dynamic subgroups for intel
* use vectorized stores
* use float_type for dequantize4 functions
* use smaller scalar rows size for smaller rows count
* relax flash attention split_k condition to allow non-gqa use
* use minimal subgroup size on Intel
* fix shmem support function
* fix rebase issues
* fixes
* Bc 4 for scalar FA is not a valid configuration
* Use wave32 on AMD RDNA for scalar FA
* add Intel shader core count lookup-table
* fix regressions
* device tuning
* tmpsh size fix
* fix editorconfig
* refactor fa tuning logic into a single place
* fix gqa opt logic
* fix block_rows with small n_rows
* amd tuning
* fix hsk=72/80 issue
* tuning
* allow condition skipping for column check
* use float16 for Of if available
* address feedback
* fix bad RDNA performance on head size <= 128 by limiting occupancy
* allow printing pipeline stats
* cleanup and fixes
* limit occupancy for GCN for small batch FA with large HSK
* disable f16 FA for GCN AMD GPUs on the proprietary driver
This started out as an attempt to go through the scalar FA version and add proper float16 support to improve AMD and Intel performance, and it went quite a bit further. @jeffbolznv Sorry about the amount of changes; let me know if there's something I can do to make the review easier. Please also let me know if you have architectural concerns. Flash Attention has so many dimensions, and making it work well across so many hardware configurations and models is pretty hard. I had to spend quite a lot of time figuring out and fixing regressions in specific configurations.
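As a rough illustration of the "set parameters in only one place" refactor (all names, fields, and values below are invented for the sketch, not llama.cpp's actual code): a single selection function returns a complete parameter record, so the scalar/coopmat1/coopmat2 paths read the same tuning decisions and cannot drift apart.

```python
from dataclasses import dataclass

@dataclass
class FAParams:
    """Hypothetical flash-attention tuning record."""
    Br: int             # rows per workgroup tile
    Bc: int             # KV columns per tile
    row_split: int      # subgroup row split factor
    shmem_staging: bool # stage K/V loads through shared memory

def select_fa_params(path: str, n_rows: int) -> FAParams:
    """One central place that every FA code path consults."""
    if path == "coopmat2":
        return FAParams(Br=32, Bc=32, row_split=1, shmem_staging=True)
    # scalar path: smaller row tiles for smaller batches
    Br = 1 if n_rows == 1 else (4 if n_rows <= 4 else 8)
    return FAParams(Br=Br, Bc=32,
                    row_split=1 if Br < 4 else 4,
                    shmem_staging=False)

assert select_fa_params("scalar", 1).Br == 1     # token generation
assert select_fa_params("scalar", 512).Br == 8   # prompt processing
```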
AI-generated summary of changes
Scalar Flash Attention Core Optimizations
Row Size Tiering
Vendor-Specific Optimizations
split_k Enhancements
Device Compatibility
Shared Memory Management
Code Path Selection
Shader Compilation
Benchmarks
AMD Radeon Pro VII
AMD 8060S
AMD 8060S (Without Coopmat)
Intel A770
Nvidia RTX 3090 (Coopmat2)
Nvidia RTX 3090 (Coopmat1)
Nvidia RTX 3090 (Without Coopmat)