tune FP8 tile sizes with two-level accumulation #125

Merged
LucasWilkinson merged 3 commits into vllm-project:main from jmkuebler:tune-fp8-tile-sizes
Mar 16, 2026

Conversation

@jmkuebler jmkuebler commented Mar 6, 2026

We added two-level accumulation for FP8 attention in #104. However, back then we missed that, at the existing tile sizes, this severely degrades prefill performance due to heavy register spills.

#122 proposes to mitigate this by calling the extra accumulator only every N steps; I consider N=4 here. However, this is a drastic change to the codebase and also decreases accuracy a little bit.

This PR takes a different approach: it reduces the tile sizes so that register pressure is minimized and spills are less severe. We achieve very good speedups for headdim 64 and 128, even better than the N-steps approach. However, for headdim 256, we did not find a tiling config that meaningfully improves over the default.

Since we reduce the kv_len dimension of the tiling config, we actually use the two-level call more often; so, if anything, this PR improves accuracy over the existing mainline.
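As a toy numerical illustration of why promotion frequency matters (this is my own sketch, not the kernel code; `two_level_sum` and the float16/float64 dtypes are stand-ins for the FP8 MMA accumulator and the FP32 second-level accumulator, and `flush_every` plays the role of kBlockN, or N·kBlockN in the N-steps variant):

```python
import numpy as np

def two_level_sum(x, flush_every, low=np.float16):
    """Toy model of two-level accumulation: a low-precision running sum
    (stand-in for the FP8 MMA accumulator) is promoted into a
    high-precision accumulator every `flush_every` elements."""
    hi = np.float64(0.0)   # second-level (high precision) accumulator
    lo = low(0.0)          # first-level (low precision) accumulator
    for i, v in enumerate(x, 1):
        lo = low(lo + low(v))
        if i % flush_every == 0:
            hi += np.float64(lo)   # promote and reset
            lo = low(0.0)
    return float(hi + np.float64(lo))

# Adversarial input: one large value followed by many tiny ones.
x = [1.0] + [1e-4] * 10000                     # exact sum = 2.0
coarse = two_level_sum(x, flush_every=len(x))  # promote only once, at the end
fine = two_level_sum(x, flush_every=128)       # promote frequently
# In `coarse` the tiny addends are swallowed because the low-precision
# accumulator already holds 1.0; frequent promotion recovers most of them.
```

With frequent flushing the result lands near 2.0, while the single late flush returns 1.0: smaller flush intervals (smaller kBlockN) keep less error in the low-precision accumulator, which is why the tile reduction here can only help accuracy.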

Let me summarize my considerations here:

| Dimension | This PR (tile tuning) | N-steps (#122) |
| --- | --- | --- |
| Code change | ~15 lines in tile_size.h only | Larger change across mainloop HPP, utils.h; divergence from upstream |
| Accuracy | Same or better than main — smaller kBlockN means two-level fires more often | Slightly worse — accumulator only fires every N steps |
| Prefill hd64 | faster | slower |
| Prefill hd128 | faster | slower |
| Prefill hd256 | slower / unusable | slower than bf16 but usable |

As a side note, I am also adding the decode-specialized tiling for headdim 256, but that's not the main discussion here.

Experiments

In the experiments I consider a fixed 8:1 GQA ratio and use a fixed model dim of 2048 to determine the number of Q-heads. I am also adding FP8_main and FP8_no2L as references here.
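The head counts implied by this setup can be sketched as follows (a hypothetical helper; the 2048 model dim and 8:1 ratio are from the text above, while the flooring and the clamp to at least one KV head are my assumptions):

```python
def gqa_heads(head_dim, model_dim=2048, gqa_ratio=8):
    """Derive Q/KV head counts for the benchmark grid: the fixed model
    dim determines the number of Q heads, and the 8:1 GQA ratio
    determines the KV heads (clamped to at least 1)."""
    num_q_heads = model_dim // head_dim
    num_kv_heads = max(1, num_q_heads // gqa_ratio)
    return num_q_heads, num_kv_heads

# e.g. head_dim 64  -> 32 Q heads, 4 KV heads
#      head_dim 256 -> 8 Q heads, 1 KV head
configs = {hd: gqa_heads(hd) for hd in (64, 96, 128, 192, 256)}
```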

[Edit: FP8_nsteps_tile combines both this PR and #122]

headdim batch_size seq_len BF16 FP8_branch FP8_main FP8_no2L FP8_nsteps FP8_nsteps_tile branch_vs_BF16 branch_vs_main branch_vs_no2L branch_vs_nsteps nsteps_tile_vs_BF16 nsteps_tile_vs_main nsteps_tile_vs_branch
64 1 16384 2186 1966.9 2305.4 1715.6 2222.7 1888.3 1.111 1.172 0.872 1.13 1.158 1.221 1.042
64 1 8192 591.3 563.1 636.5 486.6 640.1 529.6 1.05 1.13 0.864 1.137 1.117 1.202 1.063
64 1 4096 195.5 194.2 211.8 175.8 216.9 184.8 1.007 1.091 0.905 1.117 1.058 1.146 1.051
64 2 8192 1156 1066.7 1236.8 931.4 1216.7 1015.1 1.084 1.159 0.873 1.141 1.139 1.218 1.051
64 2 4096 333.9 333.2 371.6 298.8 380.4 317.1 1.002 1.115 0.897 1.142 1.053 1.172 1.051
64 4 4096 631.3 612.6 695.8 543.3 704.1 580.7 1.031 1.136 0.887 1.149 1.087 1.198 1.055
64 4 2048 218.6 219.5 247.1 206.5 250.8 211.1 0.996 1.126 0.941 1.143 1.036 1.171 1.04
64 8 2048 385.3 386.8 440.8 359.3 446.2 370.3 0.996 1.14 0.929 1.154 1.041 1.19 1.045
64 8 1024 160.8 163.8 184.9 158.6 186.2 157.9 0.982 1.129 0.968 1.137 1.018 1.171 1.037
64 16 1024 265.9 270.7 314.9 261.6 312.9 261.6 0.982 1.163 0.966 1.156 1.016 1.204 1.035
64 16 512 129.6 133.7 152.6 133.4 153.4 131.2 0.969 1.141 0.998 1.147 0.988 1.163 1.019
64 32 512 206.2 212.5 252.2 213.3 250.6 208.5 0.97 1.187 1.004 1.179 0.989 1.21 1.019
96 1 16384 1407.2 1366.8 1534 1077.1 1316.5 1290.9 1.03 1.122 0.788 0.963 1.09 1.188 1.059
96 1 8192 393.4 398.6 431.8 312.4 376.1 368.6 0.987 1.083 0.784 0.944 1.067 1.171 1.081
96 1 4096 147.9 145.3 157.7 127.3 151.4 135.4 1.018 1.085 0.876 1.042 1.092 1.165 1.073
96 2 8192 748.3 740.1 821.2 585.4 714.4 692.1 1.011 1.11 0.791 0.965 1.081 1.187 1.069
96 2 4096 230.8 240 259.9 196.6 235.5 226 0.962 1.083 0.819 0.981 1.021 1.15 1.062
96 4 4096 420.6 429.1 463.3 337.7 414.2 398.4 0.98 1.08 0.787 0.965 1.056 1.163 1.077
96 4 2048 158.2 161 169.7 136.7 164.8 151.9 0.983 1.054 0.849 1.024 1.041 1.117 1.06
96 8 2048 263 270.9 281.9 224.2 271.8 255.7 0.971 1.041 0.828 1.003 1.029 1.102 1.059
96 8 1024 142.3 121.1 145.7 130.6 151 115.5 1.175 1.203 1.078 1.247 1.232 1.261 1.048
96 16 1024 226.5 191.1 235.3 204 240.7 183.1 1.185 1.231 1.068 1.26 1.237 1.285 1.044
96 16 512 106.3 103.1 105.8 95.8 114.9 99.4 1.031 1.026 0.929 1.114 1.069 1.064 1.037
96 32 512 159.8 152.3 157.5 137.9 171.6 147.3 1.049 1.034 0.905 1.127 1.085 1.069 1.034
128 1 16384 1720.2 1539.1 2048.6 1400.5 1665.5 1464.7 1.118 1.331 0.91 1.082 1.174 1.399 1.051
128 1 8192 472.4 433.1 571.8 400.2 470.7 413.7 1.091 1.32 0.924 1.087 1.142 1.382 1.047
128 1 4096 157.1 158.4 194.3 147.5 173.3 152.7 0.992 1.227 0.931 1.094 1.029 1.272 1.037
128 2 8192 910.6 825.5 1101.1 760.3 899.6 795.3 1.103 1.334 0.921 1.09 1.145 1.385 1.038
128 2 4096 276.8 266.5 342.1 248.3 293.7 256.7 1.039 1.284 0.932 1.102 1.078 1.333 1.038
128 4 4096 501.6 475.8 627.2 440.3 527.7 459 1.054 1.318 0.925 1.109 1.093 1.366 1.037
128 4 2048 175.2 181.5 222.3 170.4 203.6 176 0.965 1.225 0.939 1.122 0.995 1.263 1.031
128 8 2048 303.9 310.4 396.2 288.9 350.8 301.1 0.979 1.276 0.931 1.13 1.009 1.316 1.031
128 8 1024 130.9 140.9 166.6 132.8 159.3 138 0.929 1.182 0.943 1.131 0.949 1.207 1.021
128 16 1024 208.6 227.1 277.2 215.3 261.5 223.8 0.919 1.221 0.948 1.151 0.932 1.239 1.015
128 16 512 106.9 120.3 137.1 115.9 135.6 117.2 0.889 1.14 0.963 1.127 0.912 1.17 1.026
128 32 512 164 187.7 218.3 179.7 215.4 184.3 0.874 1.163 0.957 1.148 0.89 1.184 1.018
192 1 16384 1200 1376.9 1737 847.5 1145.7 1041.4 0.872 1.262 0.616 0.832 1.152 1.668 1.322
192 1 8192 333 393.4 485.5 252 332.1 296.7 0.846 1.234 0.641 0.844 1.122 1.636 1.326
192 1 4096 119.6 140 161.7 100.8 129.5 113.3 0.854 1.155 0.72 0.925 1.056 1.427 1.236
192 2 8192 630 732.5 913.5 465.9 628.1 563.8 0.86 1.247 0.636 0.857 1.117 1.62 1.299
192 2 4096 195.8 234.1 280.6 159.8 208.8 181.8 0.836 1.199 0.683 0.892 1.077 1.543 1.288
192 4 4096 354.1 413.5 507.9 273.3 362.5 316.3 0.856 1.228 0.661 0.877 1.12 1.606 1.307
192 4 2048 132.2 152.1 181.9 115.3 147.7 125.4 0.869 1.196 0.758 0.971 1.054 1.451 1.213
192 8 2048 216.5 253 313.4 183.8 242.7 204.2 0.856 1.239 0.726 0.959 1.06 1.535 1.239
192 8 1024 126.5 141.7 163.9 118.9 138.8 121.8 0.893 1.157 0.839 0.98 1.039 1.346 1.163
192 16 1024 194.3 219.4 258.7 180.1 214.7 186.2 0.886 1.179 0.821 0.979 1.044 1.389 1.178
192 16 512 110.2 118 111.8 86.8 104.1 106.4 0.934 0.947 0.736 0.882 1.036 1.051 1.109
192 32 512 162.2 174.8 164.1 119.6 150.4 156 0.928 0.939 0.684 0.86 1.04 1.052 1.121
256 1 16384 1578.6 3010.9 3025.1 1084.1 1658.7 1658.4 0.524 1.005 0.36 0.551 0.952 1.824 1.816
256 1 8192 428.1 821.2 819.2 311.9 461.8 461.8 0.521 0.998 0.38 0.562 0.927 1.774 1.778
256 1 4096 140.4 254.7 251 117.3 166.5 165.2 0.551 0.985 0.461 0.654 0.85 1.519 1.542
256 2 8192 822.1 1575.3 1578.9 590.4 892.2 893.4 0.522 1.002 0.375 0.566 0.92 1.767 1.763
256 2 4096 245.2 452.3 450 191.6 280.9 280.9 0.542 0.995 0.424 0.621 0.873 1.602 1.61
256 4 4096 448.2 844.3 844 339 517.9 519.3 0.531 1 0.402 0.613 0.863 1.625 1.626
256 4 2048 155.4 274.3 270.2 133.7 192.7 192.5 0.567 0.985 0.487 0.703 0.807 1.404 1.425
256 8 2048 266.7 489 484.5 218.9 335 333.8 0.545 0.991 0.448 0.685 0.799 1.451 1.465
256 8 1024 142.1 233.7 230.3 127.9 179.8 178.8 0.608 0.985 0.547 0.769 0.795 1.288 1.307
256 16 1024 220.9 387.7 383.2 198.6 293.9 292.7 0.57 0.988 0.512 0.758 0.755 1.309 1.325
256 16 512 119 144 139.5 95.4 130.2 129.3 0.826 0.969 0.663 0.904 0.92 1.079 1.114
256 32 512 179.8 220.5 215 135.5 202.1 200.6 0.815 0.975 0.615 0.917 0.896 1.072 1.099

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
@jmkuebler
Author

cc
@PatrykSaffer

@PatrykSaffer

This looks really promising.
Could you also try smaller tiles + accumulation every 4 steps?

@jmkuebler
Author

jmkuebler commented Mar 6, 2026

This looks really promising. Could you also try smaller tiles + accumulation every 4 steps?

@PatrykSaffer good idea to test that too! I added that as a last option by merging our two branches. It is consistently the fastest option (see the last column of the table; all numbers are >= 1).

Whether it is worth taking the code-debt + accuracy risk for that is to be discussed, I'd say.

@PatrykSaffer

Whether it is worth taking the code-debt + accuracy risk for that is to be discussed, I'd say.

It looks like gains from accumulating every 4 steps for head sizes 64 and 128 are 0-5%. I think there aren't that many models supporting head size 256. Maybe we could replace 256 with 192 used by deepseek.
If doing accumulation every 4 tiles doesn't yield any meaningful gains for 192 I would abandon it and go with your better tiling strategy.
WDYT?

Collaborator

@LucasWilkinson LucasWilkinson left a comment


Thanks for doing this @jmkuebler! I very much appreciate this more surgical approach. Can you please open a vLLM-side PR so we can run it through CI while benchmarking is ongoing?

@jmkuebler
Author

Thanks for doing this @jmkuebler! I very much appreciate this more surgical approach. Can you please open a vLLM-side PR so we can run it through CI while benchmarking is ongoing?

@LucasWilkinson I set up the vLLM PR here: vllm-project/vllm#36265

@jmkuebler
Author

It looks like gains from accumulating every 4 steps for head sizes 64 and 128 are 0-5%. I think there aren't that many models supporting head size 256. Maybe we could replace 256 with 192 used by deepseek. If doing accumulation every 4 tiles doesn't yield any meaningful gains for 192 I would abandon it and go with your better tiling strategy. WDYT?

@PatrykSaffer you mean to add another separate config for headdim 192?

@PatrykSaffer

PatrykSaffer commented Mar 6, 2026

@PatrykSaffer you mean to add another separate config for headdim 192?

I meant to benchmark 192 and focus on this shape too, as I think it's much more often used than 256.

Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
@jmkuebler
Author

@PatrykSaffer for headdim 192 it is also not possible to find tile sizes that prevent spills. I still added tuned configs for it, especially one for decoding, which should give large speedups.

@jmkuebler
Author

@LucasWilkinson @MatthewBonanni @PatrykSaffer

The summary:
For head_dim up to 128, tile tuning brings FP8 prefill performance with two-level accumulation after each tile to at least parity with BF16.

For headdim 192 and 256, we cannot find tile sizes that make FP8 prefill perf match BF16 prefill perf. In those cases the N-step accumulation might still be needed.

My recommendation is:
a/ Merge this PR, as it helps either way; I have implemented it to fall back to the original tile sizes when two-level is disabled.

b/ We can discuss whether to add the N-step accumulation. Here the trade-off is prefill speed for headdim 192 and 256 vs. accuracy and code maintainability.

WDYT?
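The fallback in a/ amounts to a small dispatch on head dim and the two-level flag. A rough Python sketch of that logic (the real code lives in tile_size.h in C++, and every tile-size number below is a placeholder for illustration only, not the actual tuned configs):

```python
def select_tile_sizes(head_dim, two_level_enabled):
    """Sketch of the dispatch: use the tuned (smaller-kBlockN) tiles
    when two-level FP8 accumulation is on and a tuned config exists,
    otherwise fall back to the original sizes.
    All (kBlockM, kBlockN) values are made-up placeholders."""
    original = {64: (128, 176), 96: (128, 144), 128: (128, 128),
                192: (128, 112), 256: (128, 80)}
    tuned = {64: (128, 128), 96: (128, 112), 128: (128, 96)}
    if two_level_enabled and head_dim in tuned:
        # Smaller kBlockN: less register pressure, fewer spills,
        # and the second-level accumulator fires more often.
        return tuned[head_dim]
    # Two-level disabled, or no tuned config found (hd 192/256):
    # keep the original sizes so nothing regresses.
    return original[head_dim]
```

This mirrors why the PR is safe to merge on its own: with two-level disabled, or for the head dims where no better tiling was found, behavior is unchanged.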

@PatrykSaffer

+1 on merging this one

IMO it is still worth merging accumulation every 4 steps to make FP8 headdim 192 and 256 usable.

@MatthewBonanni

@jmkuebler that plan seems reasonable to me. Let's try to get CI green on vllm-project/vllm#36265 and then land this

@LucasWilkinson LucasWilkinson merged commit 2921022 into vllm-project:main Mar 16, 2026
1 check passed