
feat: FA4 flash att supports fused fp8 output #135

Merged
MatthewBonanni merged 3 commits into vllm-project:main from carlyou:feat--fa4-fwd-fused-fp8 on May 20, 2026
@carlyou carlyou commented Apr 30, 2026

Purpose

Part of vLLM issue #35792 (MLA + quant fusion). Specifically:

  • FA4 is the preferred MLA prefill backend.
  • This PR adds support for fused static per-tensor FP8 output (SM100).
  • It explores the pattern for the external-facing interface and the internal forward kernel, in preparation for per-group FP8 and NVFP4.
  • It reuses the quant_key and quant_kwargs concepts from vLLM.
  • The quant scaling happens in one of two places:
    • no KV split: directly in the forward kernel
    • with KV split: in the forward combine step, after the kernels
  • SM90 is not included in this PR because its forward class lacks o_dtype and needs bigger changes.
    • It is skipped to limit the scope of this change; a fast follow-up PR will cover it.
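
For readers unfamiliar with the two scaling paths, here is a minimal pure-Python sketch. It is not the FA4 kernel code. It assumes the convention implied later in this thread (the output is multiplied by 1/output_scale, then rounded to e4m3 and saturated at ±448). The helper names `round_to_e4m3`, `quantize_static_fp8`, `attention_row`, and `combine_splits` are hypothetical and used only for illustration. The point it illustrates: with a KV split, the partial outputs still have to be reweighted by their log-sum-exp, so quantization can only be applied once, after the combine.

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value of the e4m3fn format


def round_to_e4m3(x):
    """Round x to the nearest e4m3fn-representable value, saturating at +-448."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), FP8_E4M3_MAX)
    _, exp = math.frexp(a)       # a = m * 2**exp with 0.5 <= m < 1
    exp = max(exp - 1, -6)       # clamp to the smallest normal exponent
    step = 2.0 ** (exp - 3)      # 3 mantissa bits -> 8 steps per binade
    return sign * min(round(a / step) * step, FP8_E4M3_MAX)


def quantize_static_fp8(values, output_scale):
    """Static per-tensor quantization: one scale shared by every element."""
    inv = 1.0 / output_scale
    return [round_to_e4m3(v * inv) for v in values]


def attention_row(scores, values):
    """Softmax-weighted sum of value rows for one query, plus its log-sum-exp."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    dim = len(values[0])
    out = [sum(e * v[i] for e, v in zip(exps, values)) / denom for i in range(dim)]
    return out, m + math.log(denom)


def combine_splits(partials):
    """Merge per-split (output, lse) pairs into the full attention output."""
    lse_max = max(lse for _, lse in partials)
    weights = [math.exp(lse - lse_max) for _, lse in partials]
    total = sum(weights)
    dim = len(partials[0][0])
    return [
        sum(w * out[i] for w, (out, _) in zip(weights, partials)) / total
        for i in range(dim)
    ]


scores = [0.0, 0.0, math.log(2.0), math.log(2.0)]
values = [[6.0, -3.0], [0.0, 3.0], [3.0, 0.0], [-3.0, 6.0]]
output_scale = 0.5

# No KV split: quantize the attention output directly.
full_out, _ = attention_row(scores, values)
direct = quantize_static_fp8(full_out, output_scale)

# KV split in two halves: combine the partials first, then quantize once.
partials = [attention_row(scores[:2], values[:2]), attention_row(scores[2:], values[2:])]
combined = quantize_static_fp8(combine_splits(partials), output_scale)

print(direct, combined)
```

In this toy case both paths produce the same quantized row, which is the property the split-KV unit test in this PR is presumably checking on real tensors.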

Test Plan

  • new unit tests pass on B200
  • benchmark on B200
  • Note: more benchmarks can be run in vLLM once the PR is merged.

Test Results

benchmark result

| shape | bf16 attn | bf16+quant (unfused) | fused fp8 | saved | speedup |
| --- | --- | --- | --- | --- | --- |
| prefill_mla_4k (h=16/1, d=192-128) | 119.2us / 1441 TF | 152.3us / 1128 TF | 118.0us / 1456 TF | +34.4us | 1.29x |
| prefill_mha_4k (h=32/32, d=128) | 223.0us / 1233 TF | 273.6us / 1005 TF | 217.8us / 1262 TF | +55.8us | 1.26x |
| prefill_gqa_8k (h=32/4, d=128) | 762.6us / 1442 TF | 850.8us / 1292 TF | 751.5us / 1463 TF | +99.3us | 1.13x |
| decode_gqa_8k (b=16, sq=1, h=16/1, d=128) | 83.5us | 97.8us | 83.2us | +14.7us | 1.18x |
| decode_mha_8k (b=16, sq=1, h=16/16, d=128) | 187.0us | 201.4us | 186.8us | +14.6us | 1.08x |

Note: "fused fp8" is slightly quicker than "bf16 attn", likely because it writes less memory.
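
As a sanity check on how the "saved" and "speedup" columns relate to the timings, the short sketch below recomputes them from the rounded numbers in the benchmark table. It assumes speedup = unfused time / fused time and saved = unfused time - fused time. The variable names are illustrative only.

```python
# (shape, unfused bf16+quant time in us, fused fp8 time in us), copied from the table
rows = [
    ("prefill_mla_4k", 152.3, 118.0),
    ("prefill_mha_4k", 273.6, 217.8),
    ("prefill_gqa_8k", 850.8, 751.5),
    ("decode_gqa_8k", 97.8, 83.2),
    ("decode_mha_8k", 201.4, 186.8),
]

# Speedup of the fused kernel over attention followed by a separate quant kernel.
speedups = {name: round(unfused / fused, 2) for name, unfused, fused in rows}
# Time saved per call, in microseconds.
saved_us = {name: round(unfused - fused, 1) for name, unfused, fused in rows}

for name, _, _ in rows:
    print(f"{name}: saved {saved_us[name]}us, speedup {speedups[name]}x")
```

The recomputed speedups match the table. The recomputed savings differ by 0.1us on two rows (34.3 vs +34.4 and 14.6 vs +14.7), presumably because the table was produced from unrounded timings.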

pytest result

cute/test_flash_attn_fp8_output.py: 49 passed, 1117 warnings in 84.39s (0:01:24)

============================= test session starts ==============================
platform linux -- Python 3.12.3, pytest-9.0.3, pluggy-1.6.0 -- /root/flash-attention/.venv/bin/python
cachedir: .pytest_cache
rootdir: /root/flash-attention/tests
configfile: pyproject.toml
collecting ... collected 49 items

cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-64-64-False-dtype0] PASSED [ 2%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-64-64-False-dtype1] PASSED [ 4%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-64-64-True-dtype0] PASSED [ 6%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-64-64-True-dtype1] PASSED [ 8%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-128-128-False-dtype0] PASSED [ 10%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-128-128-False-dtype1] PASSED [ 12%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-128-128-True-dtype0] PASSED [ 14%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-128-128-True-dtype1] PASSED [ 16%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-192-128-False-dtype0] PASSED [ 18%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-192-128-False-dtype1] PASSED [ 20%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-192-128-True-dtype0] PASSED [ 22%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mha-192-128-True-dtype1] PASSED [ 24%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-64-64-False-dtype0] PASSED [ 26%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-64-64-False-dtype1] PASSED [ 28%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-64-64-True-dtype0] PASSED [ 30%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-64-64-True-dtype1] PASSED [ 32%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-128-128-False-dtype0] PASSED [ 34%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-128-128-False-dtype1] PASSED [ 36%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-128-128-True-dtype0] PASSED [ 38%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-128-128-True-dtype1] PASSED [ 40%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-192-128-False-dtype0] PASSED [ 42%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-192-128-False-dtype1] PASSED [ 44%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-192-128-True-dtype0] PASSED [ 46%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[mqa-192-128-True-dtype1] PASSED [ 48%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-64-64-False-dtype0] PASSED [ 51%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-64-64-False-dtype1] PASSED [ 53%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-64-64-True-dtype0] PASSED [ 55%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-64-64-True-dtype1] PASSED [ 57%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-128-128-False-dtype0] PASSED [ 59%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-128-128-False-dtype1] PASSED [ 61%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-128-128-True-dtype0] PASSED [ 63%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-128-128-True-dtype1] PASSED [ 65%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-192-128-False-dtype0] PASSED [ 67%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-192-128-False-dtype1] PASSED [ 69%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-192-128-True-dtype0] PASSED [ 71%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_matches_post_quant[gqa-192-128-True-dtype1] PASSED [ 73%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_varlen_deepseek_mla PASSED [ 75%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_auto_allocate PASSED [ 77%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_scale_as_tensor PASSED [ 79%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_sliding_window[causal_local_left] PASSED [ 81%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_sliding_window[symmetric_local] PASSED [ 83%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_sliding_window[causal_full] PASSED [ 85%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_softcap[15.0] PASSED [ 87%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_softcap[30.0] PASSED [ 89%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_scale_extremes[scale_underuses_range] PASSED [ 91%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_scale_extremes[scale_matches_peak] PASSED [ 93%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_scale_extremes[scale_overuses_range] PASSED [ 95%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_split_kv PASSED [ 97%]
cute/test_flash_attn_fp8_output.py::test_fp8_output_validation_errors PASSED [100%]

=============================== warnings summary ===============================
cute/test_flash_attn_fp8_output.py: 1117 warnings
/root/flash-attention/.venv/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/_mlir_helpers/op.py:63: DeprecationWarning: make_fragment is deprecated, use make_rmem_tensor instead
res_or_list = opFunc(*args, **kwargs, loc=loc)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================= 49 passed, 1117 warnings in 84.39s (0:01:24) =================


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: bd91f6e4dc

Comment thread flash_attn/cute/interface.py Outdated
Comment thread flash_attn/cute/interface.py Outdated
Author

carlyou commented May 1, 2026

Hi @LucasWilkinson @MatthewBonanni, I would like your review of this PR.
This is an initial attempt to fuse quantized output in FA4, starting with static FP8; the plan is to add NVFP4 and group FP8 next.
cc @ProExpertProg

@ProExpertProg ProExpertProg left a comment

Nice work! A few high-level points:

  • We shouldn't pass the scale around as a float32, it will cause an unnecessary GPU-CPU sync.
  • Can we benchmark what happens if we read the scale from global memory and don't invert it ahead of time? If it's the same, that might be better as ahead-of-time inversion costs more if done in the forward pass.
  • I'm not sure adding quant_kwargs is the best idea, we're tying flash-attn source to vLLM structure. I'm actually wondering if we can just communicate the necessary parameters via scale metadata; if we pass the scale tensor output as a parameter, its dimensions and strides should be enough to infer the other params: group size from shape, column major and tma alignment from strides, which only leaves ue8m0, so just a single boolean. That way we don't need to pass quant_key around either
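
To make the last suggestion concrete, here is a hypothetical sketch of how quantization parameters might be inferred from the scale tensor's metadata alone. Nothing here is the flash-attn or vLLM API: `ScaleLayout` and `infer_scale_layout` are made-up names, the output is assumed to be a 2-D (rows, head_dim) view, and a grouped scale is assumed to have shape (rows, num_groups). TMA alignment and ue8m0 are not modelled.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass
class ScaleLayout:
    per_tensor: bool
    group_size: Optional[int]  # elements of head_dim covered by one scale
    column_major: bool


def infer_scale_layout(out_shape, scale_shape, scale_strides):
    """Infer quantization layout from the scale tensor's shape and strides.

    Hypothetical convention: a single-element scale means per-tensor
    quantization; otherwise the scale is 2-D with shape (rows, num_groups).
    """
    if math.prod(scale_shape) == 1:
        return ScaleLayout(per_tensor=True, group_size=None, column_major=False)

    rows, head_dim = out_shape
    scale_rows, num_groups = scale_shape
    if scale_rows != rows or head_dim % num_groups != 0:
        raise ValueError("scale shape is incompatible with the output shape")

    # Row-major strides are (num_groups, 1); column-major strides are (1, rows).
    column_major = scale_strides[0] < scale_strides[1]
    return ScaleLayout(
        per_tensor=False,
        group_size=head_dim // num_groups,
        column_major=column_major,
    )


print(infer_scale_layout((8, 128), (), ()))          # per-tensor scale
print(infer_scale_layout((8, 128), (8, 4), (4, 1)))  # groups of 32, row-major
print(infer_scale_layout((8, 128), (8, 4), (1, 8)))  # groups of 32, column-major
```

Under this scheme the only parameter that cannot be read off the tensor is the ue8m0 flag, which matches the reviewer's observation that a single extra boolean would suffice.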

@carlyou carlyou force-pushed the feat--fa4-fwd-fused-fp8 branch from 61a1d0b to f96d04f Compare May 4, 2026 14:54
Author

carlyou commented May 4, 2026

@ProExpertProg thanks for the comments, updated!
Here is the benchmark comparing where the scale is inverted. It turns out that doing the inversion in the forward pass is faster:

| shape                                | preinvert | rcp_approx  | fdiv (1/output_scale)   |
| ---                                  | ---       | ---         | ---     |
| prefill_mla_4k (h=16/1, d=192-128)   | 129.4us   |   120.3us   | 120.1us |
| prefill_mha_4k (h=32/32, d=128)      | 226.2us   |   218.2us   | 218.6us |
| prefill_gqa_8k (h=32/4, d=128)       | 745.2us   |   740.7us   | 744.8us |
| decode_gqa_8k  (b=16, sq=1, h=16/1)  | 90.9us    |   83.0us    | 83.0us  |
| decode_mha_8k  (b=16, sq=1, h=16/16) | 197.6us   |   189.2us   | 189.3us |

I also compared against rcp_approx and saw no perf difference.
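
To quantify the comparison, the sketch below recomputes the per-shape differences from the rounded timings in the table above. The column meanings are an interpretation of the table headers (preinvert = scale inverted ahead of time, fdiv = 1/output_scale computed in the forward pass, rcp_approx = approximate reciprocal), and the variable names are illustrative.

```python
# (shape, preinvert, rcp_approx, fdiv) timings in microseconds, copied from the table
rows = [
    ("prefill_mla_4k", 129.4, 120.3, 120.1),
    ("prefill_mha_4k", 226.2, 218.2, 218.6),
    ("prefill_gqa_8k", 745.2, 740.7, 744.8),
    ("decode_gqa_8k", 90.9, 83.0, 83.0),
    ("decode_mha_8k", 197.6, 189.2, 189.3),
]

# Time saved by inverting in the forward pass (fdiv) instead of ahead of time.
saved_vs_preinvert = {name: round(pre - fdiv, 1) for name, pre, _, fdiv in rows}
# Difference between the approximate reciprocal and the exact division.
rcp_vs_fdiv = {name: round(rcp - fdiv, 1) for name, _, rcp, fdiv in rows}

for name, *_ in rows:
    print(f"{name}: fdiv saves {saved_vs_preinvert[name]}us, "
          f"rcp_approx - fdiv = {rcp_vs_fdiv[name]}us")
```

On these numbers, in-kernel inversion saves roughly 8 to 9us on four of the five shapes, while rcp_approx and fdiv stay within about 0.5us of each other except on prefill_gqa_8k.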

Comment thread flash_attn/cute/flash_fwd_sm100.py Outdated
@carlyou carlyou force-pushed the feat--fa4-fwd-fused-fp8 branch from 6e973a9 to 631e1f2 Compare May 5, 2026 04:06
Member

@MatthewBonanni MatthewBonanni left a comment

This looks very clean to me, thanks! Please make the corresponding vllm-side PR so we can use vLLM CI to verify this

Comment thread tests/cute/test_flash_attn_fp8_output.py Outdated
Squashed from 10 commits of PR vllm-project#135 for rebase onto main.

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
@carlyou carlyou force-pushed the feat--fa4-fwd-fused-fp8 branch from 91fb309 to 96e8957 Compare May 19, 2026 05:46
Comment on lines +866 to +869
assert qv is None, "fused FP8 output + MLA (qv) not supported yet"
assert not use_dedicated_hd256_kernel, (
    "fused FP8 output + head_dim=256 kernel not supported yet"
)
Author

There are two new SM100 kernels after the rebase. I'm leaving them out of scope for now; they will be a fast follow-up once the current PR nails down the overall structure.

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
@carlyou carlyou force-pushed the feat--fa4-fwd-fused-fp8 branch from 28b41f6 to e577a7a Compare May 19, 2026 06:07
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Member

@MatthewBonanni MatthewBonanni left a comment

LGTM

@MatthewBonanni MatthewBonanni merged commit d0a0e2b into vllm-project:main May 20, 2026
1 check passed