tune FP8 tile sizes with two-level accumulation #125
LucasWilkinson merged 3 commits into vllm-project:main
Conversation
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
This looks really promising.
@PatrykSaffer good idea to test that too! I added that as a last option by merging our two branches. It is consistently the fastest option (see the last column of the table; all numbers are >= 1). Whether it is worth taking the code-debt + accuracy risk for that is to be discussed, I'd say.
It looks like the gains from accumulating every 4 steps for head sizes 64 and 128 are 0-5%. I think there aren't that many models supporting head size 256. Maybe we could replace 256 with the 192 used by DeepSeek.
Thanks for doing this @jmkuebler! I do very much appreciate this more surgical approach. Can you please open a vLLM-side PR so we can run it through CI while benchmarking is ongoing?
@LucasWilkinson I set up the vLLM PR here: vllm-project/vllm#36265
@PatrykSaffer you mean to add another separate config for headdim 192?
I meant to benchmark 192 and focus on this shape too, as I think it's used much more often than 256.
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
@PatrykSaffer for headdim 192 it is also not possible to find tile sizes that prevent spills. I still added tuned configs for it, and especially one for decoding that should give large speedups.
@LucasWilkinson @MatthewBonanni @PatrykSaffer The summary: for headdim 192 and 256, we cannot find tiles that make FP8 prefill perf match bf16 prefill perf. In those cases the n-step accumulation might still be needed. My recommendation is: b/ We can discuss whether to add the n-step accumulation. Here the trade-off is prefill speed for headdim 192 and 256 vs. accuracy and code maintainability. WDYT?
+1 on merging this one. IMO it is still worth merging accumulation every 4 steps to make FP8 headdim 192 and 256 usable.
@jmkuebler that plan seems reasonable to me. Let's try to get CI green on vllm-project/vllm#36265 and then land this.
We added two-level accumulation for FP8 attention in #104. However, back then we missed that this severely degrades prefill performance at the existing tile sizes due to heavy register spills.
#122 proposes to mitigate this by calling the extra accumulator only every N steps; I am considering N=4 here. However, this is a drastic change to the codebase and also decreases accuracy slightly.
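To make the trade-off concrete, here is a minimal pure-Python sketch of N-step two-level accumulation. This is illustrative only and not the kernel code: `quantize` is a hypothetical helper that crudely mimics a limited-precision accumulator by truncating the mantissa, standing in for the FP8-MMA accumulator level.

```python
import math

def quantize(x, bits=10):
    # Crude stand-in for a limited-precision register: keep only
    # `bits` mantissa bits (illustrative helper, not real FP8 math).
    if x == 0.0:
        return 0.0
    e = math.floor(math.log2(abs(x)))
    scale = 2.0 ** (e - bits)
    return round(x / scale) * scale

def two_level_sum(terms, n_steps=4):
    # acc_hi: full-precision accumulator (the register-hungry second level)
    # acc_lo: limited-precision inner accumulator (FP8-MMA stand-in)
    acc_hi = 0.0
    acc_lo = 0.0
    for i, t in enumerate(terms):
        acc_lo = quantize(acc_lo + t)
        if (i + 1) % n_steps == 0:   # promote to full precision every N steps
            acc_hi += acc_lo
            acc_lo = 0.0
    return acc_hi + acc_lo           # flush the remainder
```

Flushing into `acc_hi` only every `n_steps` iterations is what reduces how often the full-precision accumulator is touched, at the cost of letting low-precision rounding error grow for N steps between flushes.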
This PR takes a different approach: it reduces the tile sizes such that register pressure is minimized and spills are less severe. We achieve very good speedups for headdim 64 and 128, even better than with the 4-steps approach. However, for headdim 256 we did not find a tiling config that meaningfully improves over the default.
Since we reduce the kv_len dimension of the tiling config, we actually invoke the two-level call more often; so, if anything, this PR improves accuracy over the existing mainline.
I am trying to summarize my considerations here:
As a side note, I am also adding the decode-specialized tiling for headdim 256, but that's not the main discussion here.
Experiments
In the experiments I consider a fixed 8:1 GQA ratio and use a fixed model dim of 2048 to determine the number of Q heads. I am also adding FP8_main and FP8_no2L as references here.
[Edit: FP8_nsteps_tile combines both this PR and https://github.com//pull/122]
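For reference, the head counts implied by the benchmark setup above can be derived with a few lines. This is my reading of the setup, not the actual benchmark script (which is not shown here); `gqa_heads` is a hypothetical helper, and note that for headdim 192 the integer division makes the effective ratio 10:1 rather than exactly 8:1.

```python
def gqa_heads(headdim, model_dim=2048, gqa_ratio=8):
    # Number of Q heads is fixed by the model dim; KV heads follow
    # from the 8:1 GQA ratio. (Illustrative reading of the setup.)
    n_q = model_dim // headdim
    n_kv = n_q // gqa_ratio
    return n_q, n_kv

for hd in (64, 128, 192, 256):
    n_q, n_kv = gqa_heads(hd)
    print(f"headdim={hd}: Q heads={n_q}, KV heads={n_kv}")
```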