For batch size 1 (BS=1), the gen phase flash_attn_vec_ext_f32 kernel is launched with a constant parallel_blocks value of 4 (see the code).
However, parallel_blocks = 4 results in poor GPU occupancy.
Consider the following models. In each case, the current occupancy is far below what is achievable if the parallel_blocks value is increased.
| Model | num_heads | head_dim | Occupancy with PB=4 on RTX 4090 | Achievable occupancy with optimal PB value on RTX 4090 |
| --- | --- | --- | --- | --- |
| Llama 3B | 24 | 128 | 0.06 | 0.25 |
| Llama 8B | 32 | 128 | 0.08 | 0.25 |
| Qwen 1.5B | 12 | 128 | 0.03 | 0.25 |
| Qwen 7B | 28 | 128 | 0.07 | 0.25 |
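
The PB=4 column can be approximated with a simple back-of-the-envelope model. The sketch below is only an illustration, not the code from this change. It assumes that at BS=1 the kernel launches num_heads × parallel_blocks thread blocks of 128 threads (4 warps) each, and that the RTX 4090 has 128 SMs with at most 48 resident warps per SM. The 0.25 ceiling is taken from the table, not derived. `GpuModel`, `estimate_occupancy`, and `min_parallel_blocks_for_cap` are hypothetical names introduced here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Hypothetical description of the GPU and kernel launch shape.
struct GpuModel {
    int    num_sms;           // streaming multiprocessors on the device
    int    max_warps_per_sm;  // hardware limit on resident warps per SM
    int    warps_per_block;   // ASSUMPTION: 4 warps (128 threads) per block
    double occupancy_cap;     // ceiling taken from the table (0.25)
};

// ASSUMED values for an RTX 4090 (compute capability 8.9).
constexpr GpuModel kRtx4090Assumed = {128, 48, 4, 0.25};

// Fraction of the device's warp slots filled at BS=1, assuming the grid
// holds num_heads * parallel_blocks thread blocks.
double estimate_occupancy(const GpuModel& gpu, int num_heads, int parallel_blocks) {
    assert(num_heads > 0 && parallel_blocks > 0);
    const double active_warps =
        static_cast<double>(num_heads) * parallel_blocks * gpu.warps_per_block;
    const double warp_slots =
        static_cast<double>(gpu.num_sms) * gpu.max_warps_per_sm;
    return std::min(active_warps / warp_slots, gpu.occupancy_cap);
}

// Smallest parallel_blocks value that reaches the ceiling under this model.
// This is not necessarily the value the actual change selects.
int min_parallel_blocks_for_cap(const GpuModel& gpu, int num_heads) {
    assert(num_heads > 0);
    const double warp_slots =
        static_cast<double>(gpu.num_sms) * gpu.max_warps_per_sm;
    const double blocks_needed = gpu.occupancy_cap * warp_slots / gpu.warps_per_block;
    return static_cast<int>(std::ceil(blocks_needed / num_heads));
}
```

Under these assumptions, `estimate_occupancy(kRtx4090Assumed, 24, 4)` evaluates to 0.0625 and the 32-head case to about 0.083, in line with the 0.06 and 0.08 reported for Llama 3B and Llama 8B. The model ignores register and shared-memory limits, so it is only a rough guide to how many blocks are needed to approach the ceiling.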
I have a change that addresses this issue; it improves gen phase performance by up to 14%.