feat: MLA prefill enable FA4 fp8 output#43050
Code Review
This pull request integrates support for fused FP8 output in Multi-Head Latent Attention (MLA) using FlashAttention-4 (FA4) on Blackwell GPUs. It updates the vllm-flash-attn dependency to a specific commit, modifies the attention backend interfaces to accept an output_scale parameter, and implements logic in mla_attention.py to bypass post-quantization when the attention kernel writes directly to the quantized output. I have no feedback to provide.
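The bypass described above can be sketched in plain Python. This is a minimal illustration, not the code in `mla_attention.py`: the function names, the list-based "tensor", and the convention that quantization divides by the scale are all assumptions, and rounding onto the FP8 grid is omitted.

```python
FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in float8 e4m3fn


def post_quantize(values, scale):
    """Static per-tensor quantization sketch: divide by one scale, then clamp.

    Assumption: the scale is a dequantization scale, so quantizing divides by it.
    Rounding to the actual FP8 grid is left out for brevity.
    """
    return [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in values]


def finish_attention_output(values, output_scale, kernel_wrote_fp8):
    """Hypothetical helper: decide whether a post-quantization pass is needed."""
    if output_scale is None:
        # No quantized output was requested.
        return values
    if kernel_wrote_fp8:
        # The attention kernel already wrote quantized values; skip the extra pass.
        return values
    return post_quantize(values, output_scale)
```

When the kernel accepts `output_scale` directly, the second branch is taken and the separate quantization pass over the output disappears.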
Benchmark numbers seem unstable; will rerun later.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e708e3996c
5a3122b to 722d066
MatthewBonanni
left a comment
LGTM, thanks for the contribution!
Actually, are these numbers up to date? The accuracy regression / 0.011 invalid is concerning. Also these perf numbers contradict your later comment (which shows a speedup that does seem unrealistic?)
| Run | Accuracy | Invalid | Questions | Latency (s) | Q/s | Tokens | Tok/s |
| ---------- | -------- | ------- | --------- | ----------- | ---- | ------ | ----- |
| main_gsm8k | 0.785 | 0.000 | 1319 | 587.03 | 2.25 | 168785 | 287.5 |
| feat_gsm8k | 0.769 | 0.011 | 1319 | 600.43 | 2.20 | 167201 | 278.5 |
Let me rerun both to confirm.
@MatthewBonanni here are the new results. The benchmark (using vllm bench) has large variance in this case. I did more runs, and each showed a different result:
@carlyou thanks! What are A, B, C, and D here? Different runs of the same benchmark?
@MatthewBonanni sorry for the confusion, they are different runs of the same benchmark setup.
@carlyou thanks for providing the details. Maybe the results would be clearer with a batch size 1 / concurrency 1 benchmark where we focus on TTFT? Alternatively, an even better option would be the microbenchmark in
722d066 to 50fd06a
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG bce29425653ec0fbc579d329883030e832d15ada
GIT_TAG d0a0e2bf2113fcfd0336e5dd201a5fd89b297a8f
Now that we've landed vllm-project/flash-attention#141 and #44065, this change is no longer necessary.
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Carl Y <4531192+carlyou@users.noreply.github.com>
1477e2f to 1548033
@MatthewBonanni added the MLA benchmark per your suggestion and rebased onto the latest main. Please check the benchmark result above.
MatthewBonanni
left a comment
Great! just a few more tweaks
elif output_scale is not None:
    # FA4 already wrote results in-place into `output`.
    assert isinstance(output_prefill, torch.Tensor)
else:
    assert isinstance(output_prefill, torch.Tensor)
These are identical branches
# FA4 can write native fused FP8 (e4m3fn) output on Blackwell
# SM100/SM110 only (see flash-attention#135); FA4 natively handles
# MLA's mismatched qk/v head dims so no V padding is involved.
# Only static per-tensor FP8 is wired today; per-group FP8 / NVFP4
# still go through the post-quant path. get_device_capability() is
# @cache'd, so this stays cheap on the hot path.
nit: remove unnecessary comment
# FA4 can write native fused FP8 (e4m3fn) output on Blackwell
# SM100/SM110 only (see flash-attention#135); FA4 natively handles
# MLA's mismatched qk/v head dims so no V padding is involved.
# Only static per-tensor FP8 is wired today; per-group FP8 / NVFP4
# still go through the post-quant path. get_device_capability() is
# @cache'd, so this stays cheap on the hot path.
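The gating that the code comment describes can be sketched as a small predicate. This is an illustration under stated assumptions, not the check used in the PR: the function name, the `(major, minor)` capability tuple, and the scheme strings are hypothetical, and SM100/SM110 are taken to mean compute capability major versions 10 and 11.

```python
def can_fuse_fp8_output(capability, quant_scheme):
    """Hypothetical predicate for taking the fused FP8 output path.

    capability: (major, minor) compute capability, e.g. (10, 0) for SM100.
    quant_scheme: illustrative label for the requested output quantization.
    """
    major, _minor = capability
    # Assumption: SM100 and SM110 correspond to major versions 10 and 11.
    on_supported_gpu = major in (10, 11)
    # Only static per-tensor FP8 is fused; other schemes keep the post-quant path.
    return on_supported_gpu and quant_scheme == "static_per_tensor_fp8"
```

Any combination that returns `False` here would fall back to running attention in the unquantized dtype and quantizing afterwards.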
if self._is_vllm_fa:
    return self._flash_attn_varlen_diff_headdims(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=self._prefill_metadata.query_start_loc,
        cu_seqlens_k=self._prefill_metadata.query_start_loc,
        max_seqlen_q=self._prefill_metadata.max_query_len,
        max_seqlen_k=self._prefill_metadata.max_query_len,
        softmax_scale=self.scale,
        causal=True,
        return_softmax_lse=return_softmax_lse,
        out=out,
        output_scale=output_scale,
    )
assert out is None and output_scale is None
Instead of adding this conditional, please modify self._flash_attn_varlen_diff_headdims to add out and output_scale to kwargs (see line 106)
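A rough sketch of that suggestion, assuming the helper already builds a `kwargs` dict for the kernel call. The real signature of `_flash_attn_varlen_diff_headdims` is not shown in this thread, so the class, the stub kernel, and the parameter handling below are hypothetical.

```python
def fake_kernel(q, k, v, **kwargs):
    """Stand-in for the attention kernel; it only echoes its keyword arguments."""
    return kwargs


class PrefillSketch:
    """Hypothetical container for the backend flag used in the review snippet."""

    def __init__(self, is_vllm_fa):
        self._is_vllm_fa = is_vllm_fa

    def _flash_attn_varlen_diff_headdims(
        self, q, k, v, out=None, output_scale=None, **kwargs
    ):
        if self._is_vllm_fa:
            # Only this backend understands the fused-output arguments,
            # so they are forwarded through kwargs here.
            kwargs["out"] = out
            kwargs["output_scale"] = output_scale
        else:
            # Other backends must not be asked for fused FP8 output.
            assert out is None and output_scale is None
        return fake_kernel(q, k, v, **kwargs)
```

With the backend check inside the helper, the caller can pass `out` and `output_scale` unconditionally and needs no `if self._is_vllm_fa` branch of its own.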
Completes FlashAttn x Static FP8 in #35792
Purpose
Test Plan
Test Result
Eval