Skip to content

feat: MLA prefill enable FA4 fp8 output#43050

Open
carlyou wants to merge 9 commits into
vllm-project:mainfrom
carlyou:feat--mla-fa4-native-fp8-output
Open

feat: MLA prefill enable FA4 fp8 output#43050
carlyou wants to merge 9 commits into
vllm-project:mainfrom
carlyou:feat--mla-fa4-native-fp8-output

Conversation

@carlyou
Copy link
Copy Markdown
Contributor

@carlyou carlyou commented May 19, 2026

Completes FlashAttn x Static FP8 in #35792

Purpose

Test Plan

  • unit tests
  • eval FP8 on H100 and B200
  • benchmark FP8 on H100 and B200

Test Result

Eval

============================================
  EVAL SUMMARY
============================================

Config:      mla_fa4_fp8_output/b200_dscoder_v2_lite_fp8_eval
Model:       RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8
Repo:        https://github.com/carlyou/vllm.git

Hardware/Runtime:
  GPU:       NVIDIA B200
  GPU mem:   183359 MiB
  CUDA:      13.0
  PyTorch:   2.11.0+cu130
  Python:    3.12.3
  Platform:  x86_64

Runs:
  - main_gsm8k:
      branch:     main
      server:     $ vllm serve RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --tensor-parallel-size 1 --max-model-len 4096 --trust-remote-code --port 8000 --gpu-memory-utilization 0.85 -cc {"cudagraph_mode": "NONE", "custom_ops": ["+quant_fp8"], "pass_config": {"fuse_attn_quant": true}}
      eval:       $ python tests/evals/gsm8k/gsm8k_eval.py --num-shots 5 --num-questions 1319 --max-tokens 256 --temperature 0.0 --seed 42
  - feat_gsm8k:
      branch:     test--mla-fa4-native-fp8-output
      server:     $ vllm serve RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --tensor-parallel-size 1 --max-model-len 4096 --trust-remote-code --port 8000 --gpu-memory-utilization 0.85 -cc {"cudagraph_mode": "NONE", "custom_ops": ["+quant_fp8"], "pass_config": {"fuse_attn_quant": true}}
      eval:       $ python tests/evals/gsm8k/gsm8k_eval.py --num-shots 5 --num-questions 1319 --max-tokens 256 --temperature 0.0 --seed 42

| Run        | Accuracy | Invalid | Questions | Latency (s) | Q/s  | Tokens | Tok/s |
| ---------- | -------- | ------- | --------- | ----------- | ---- | ------ | ----- |
| main_gsm8k |    0.785 |   0.000 |      1319 |      587.03 | 2.25 | 168785 | 287.5 |
| feat_gsm8k |    0.769 |   0.011 |      1319 |      600.43 | 2.20 | 167201 | 278.5 |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@mergify mergify Bot added ci/build nvidia rocm Related to AMD ROCm v1 labels May 19, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD May 19, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request integrates support for fused FP8 output in Multi-Head Latent Attention (MLA) using FlashAttention-4 (FA4) on Blackwell GPUs. It updates the vllm-flash-attn dependency to a specific commit, modifies the attention backend interfaces to accept an output_scale parameter, and implements logic in mla_attention.py to bypass post-quantization when the attention kernel writes directly to the quantized output. I have no feedback to provide.

@carlyou
Copy link
Copy Markdown
Contributor Author

carlyou commented May 19, 2026

Benchmark, numbers seem unstable, will rerun later.

============================================
  BENCHMARK SUMMARY
============================================

Config:      mla_fa4_fp8_output/b200_dscoder_v2_lite_fp8_bench
Model:       RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8
Repo:        https://github.com/carlyou/vllm.git
Work dir:    /tmp/vllm-bench

Hardware/Runtime:
  GPU:       NVIDIA B200
  GPU mem:   183359 MiB
  CUDA:      13.0
  PyTorch:   2.11.0+cu130
  Python:    3.12.3
  Platform:  x86_64

Runs:
  - feat_warmup:
      branch:     test--mla-fa4-native-fp8-output
      server:     $ vllm serve RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --tensor-parallel-size 1 --max-model-len 16384 --trust-remote-code --port 8000 --gpu-memory-utilization 0.85 -cc {"cudagraph_mode": "NONE", "custom_ops": ["+quant_fp8"], "pass_config": {"fuse_attn_quant": true}}
      bench:      $ vllm bench serve --model RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --num-prompts 20 --request-rate inf --random-input-len 4096 --random-output-len 128 --num-warmups 5 --ignore-eos
  - feat_baseline:
      branch:     test--mla-fa4-native-fp8-output
      server:     $ vllm serve RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --tensor-parallel-size 1 --max-model-len 16384 --trust-remote-code --port 8000 --gpu-memory-utilization 0.85 -cc {"cudagraph_mode": "NONE", "custom_ops": ["+quant_fp8"], "pass_config": {"fuse_attn_quant": true}}
      bench:      $ vllm bench serve --model RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --num-prompts 1000 --request-rate inf --random-input-len 4096 --random-output-len 128 --num-warmups 50 --ignore-eos
  - main_warmup:
      branch:     main
      server:     $ vllm serve RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --tensor-parallel-size 1 --max-model-len 16384 --trust-remote-code --port 8000 --gpu-memory-utilization 0.85 -cc {"cudagraph_mode": "NONE", "custom_ops": ["+quant_fp8"], "pass_config": {"fuse_attn_quant": true}}
      bench:      $ vllm bench serve --model RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --num-prompts 20 --request-rate inf --random-input-len 4096 --random-output-len 128 --num-warmups 5 --ignore-eos
  - main_baseline:
      branch:     main
      server:     $ vllm serve RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --tensor-parallel-size 1 --max-model-len 16384 --trust-remote-code --port 8000 --gpu-memory-utilization 0.85 -cc {"cudagraph_mode": "NONE", "custom_ops": ["+quant_fp8"], "pass_config": {"fuse_attn_quant": true}}
      bench:      $ vllm bench serve --model RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --num-prompts 1000 --request-rate inf --random-input-len 4096 --random-output-len 128 --num-warmups 50 --ignore-eos

| Metric                     | feat_warmup | feat_baseline | main_warmup | main_baseline |
| -------------------------- | -----------: | -------------: | -----------: | -------------: |
| Successful requests        |          20 |          1000 |          20 |          1000 |
| Failed requests            |           0 |             0 |           0 |             0 |
| Request rate (RPS)         |         N/A |           N/A |         N/A |           N/A |
| Benchmark duration (s)     |       18.11 |        180.72 |       48.15 |        234.52 |
| Total input tokens         |       81920 |       4096000 |       81920 |       4096000 |
| Total generated tokens     |        2560 |        128000 |        2560 |        128000 |
| Request throughput (req/s) |        1.10 |          5.53 |        0.42 |          4.26 |
| Output token tput (tok/s)  |      141.39 |        708.27 |       53.17 |        545.79 |
| Peak output tput (tok/s)   |      260.00 |       2542.00 |      260.00 |       2553.00 |
| Total token tput (tok/s)   |     4665.88 |      23372.88 |     1754.66 |      18011.17 |
| Peak concurrent requests   |       20.00 |       1000.00 |       20.00 |       1000.00 |
| **Time to First Token** |             |               |             |               |
| Mean TTFT (ms)             |     1030.18 |      40027.97 |    11673.50 |      38276.86 |
| Median TTFT (ms)           |      818.93 |      40016.72 |    11191.26 |      38265.54 |
| P99 TTFT (ms)              |     3087.98 |      67248.10 |    29057.46 |      64787.02 |
| **Time per Output Token** |             |               |             |               |
| Mean TPOT (ms)             |      123.87 |        214.17 |      260.21 |        296.99 |
| Median TPOT (ms)           |      127.86 |        109.88 |      273.99 |        106.34 |
| P99 TPOT (ms)              |      134.13 |        871.12 |      315.00 |       1314.57 |
| **Inter-token Latency** |             |               |             |               |
| Mean ITL (ms)              |      123.87 |        214.17 |      260.21 |        296.99 |
| Median ITL (ms)            |       78.29 |        108.16 |       80.01 |        104.80 |
| P99 ITL (ms)               |     1582.71 |       1046.98 |     1588.46 |       1989.05 |

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e708e3996c

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread vllm/v1/attention/backends/mla/prefill/flash_attn.py Outdated
@MatthewBonanni MatthewBonanni added the ready ONLY add when PR is ready to merge/full CI is needed label May 19, 2026
Comment thread cmake/external_projects/vllm_flash_attn.cmake
Comment thread cmake/external_projects/vllm_flash_attn.cmake Outdated
@carlyou carlyou force-pushed the feat--mla-fa4-native-fp8-output branch from 5a3122b to 722d066 Compare May 27, 2026 21:02
Copy link
Copy Markdown
Member

@MatthewBonanni MatthewBonanni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the contribution!

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA May 28, 2026
@MatthewBonanni MatthewBonanni enabled auto-merge (squash) May 28, 2026 15:41
@MatthewBonanni MatthewBonanni disabled auto-merge May 28, 2026 15:44
Copy link
Copy Markdown
Member

@MatthewBonanni MatthewBonanni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, are these numbers up to date? The accuracy regression / 0.011 invalid is concerning. Also these perf numbers contradict your later comment (which shows a speedup that does seem unrealistic?)

| Run        | Accuracy | Invalid | Questions | Latency (s) | Q/s  | Tokens | Tok/s |
| ---------- | -------- | ------- | --------- | ----------- | ---- | ------ | ----- |
| main_gsm8k |    0.785 |   0.000 |      1319 |      587.03 | 2.25 | 168785 | 287.5 |
| feat_gsm8k |    0.769 |   0.011 |      1319 |      600.43 | 2.20 | 167201 | 278.5 |

@github-project-automation github-project-automation Bot moved this from Ready to In review in NVIDIA May 28, 2026
@carlyou
Copy link
Copy Markdown
Contributor Author

carlyou commented May 28, 2026

Actually, are these numbers up to date? The accuracy regression / 0.011 invalid is concerning. Also these perf numbers contradict your later comment (which shows a speedup that does seem unrealistic?)

| Run        | Accuracy | Invalid | Questions | Latency (s) | Q/s  | Tokens | Tok/s |
| ---------- | -------- | ------- | --------- | ----------- | ---- | ------ | ----- |
| main_gsm8k |    0.785 |   0.000 |      1319 |      587.03 | 2.25 | 168785 | 287.5 |
| feat_gsm8k |    0.769 |   0.011 |      1319 |      600.43 | 2.20 | 167201 | 278.5 |

let me rerun both to confirm.

@carlyou
Copy link
Copy Markdown
Contributor Author

carlyou commented May 29, 2026

@MatthewBonanni here's the new result:
Eval:

| Run        | Accuracy | Invalid | Questions | Latency (s) | Q/s     | Tokens | Tok/s |
| ---------- | -------- | ------- | --------- | ----------- | ------- | ------ | ----- |
| main_gsm8k |    0.787 |   0.000 |      1319 |      528.66 |    2.49 | 168370 | 318.5 |
| feat_gsm8k |    0.778 |   0.000 |      1319 |      525.19 |    2.51 | 168267 | 320.4 |

Benchmark (using vllm bench) has large variance in this case. I did more runs, and each showed different result:

  ┌────────────────────┬─────────────────┬─────────────────┬─────────────────┬─────────────────┐
  │ Metric (baseline)  │  A feat / main  │  B feat / main  │  C feat / main  │  D feat / main  │
  ├────────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┤
  │ feat base          │ old c6c6f4f3a   │ 0173ea046       │ 0173ea046       │ 0173ea046       │
  ├────────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┤
  │ Duration (s)       │ 184.3 / 205.5   │ 228.6 / 186.1   │ 223.9 / 228.8   │ 202.1 / 180.8   │
  ├────────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┤
  │ Throughput (req/s) │ 5.43 / 4.87     │ 4.38 / 5.37     │ 4.47 / 4.37     │ 4.95 / 5.53     │
  ├────────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┤
  │ Output tok/s       │ 694.5 / 622.8   │ 560.1 / 687.8   │ 571.6 / 559.5   │ 633.3 / 707.8   │
  ├────────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┤
  │ Mean TTFT (ms)     │ 37,846 / 35,882 │ 39,297 / 36,724 │ 37,389 / 36,829 │ 28,498 / 28,147 │
  ├────────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┤
  │ Mean TPOT (ms)     │ 219.2 / 228.5   │ 280.6 / 214.8   │ 276.2 / 276.9   │ 250.7 / 232.2   │
  ├────────────────────┼─────────────────┼─────────────────┼─────────────────┼─────────────────┤
  │ within-run verdict │ feat +11%       │ main +23%       │ feat +2%        │ main +12%       │
  └────────────────────┴─────────────────┴─────────────────┴─────────────────┴─────────────────┘

@MatthewBonanni
Copy link
Copy Markdown
Member

MatthewBonanni commented May 29, 2026

@carlyou thanks! What are A, B, C, and D here? Different runs of the same benchmark?

@carlyou
Copy link
Copy Markdown
Contributor Author

carlyou commented May 29, 2026

@carlyou thanks! What are A, B, C, and D here? Different runs of the same benchmark?

@MatthewBonanni sorry, they are different runs of the same bench setup.
Each run compares feat vs main branch

Runs:
  - feat_baseline:
      branch:     test--mla-fa4-native-fp8-output
      server:     $ vllm serve RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --tensor-parallel-size 1 --max-model-len 16384 --trust-remote-code --port 8000 --gpu-memory-utilization 0.85 -cc {"cudagraph_mode": "NONE", "custom_ops": ["+quant_fp8"], "pass_config": {"fuse_attn_quant": true}}
      bench:      $ vllm bench serve --model RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --num-prompts 1000 --request-rate inf --random-input-len 4096 --random-output-len 128 --num-warmups 50 --ignore-eos
  - main_baseline:
      branch:     main
      server:     $ vllm serve RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --tensor-parallel-size 1 --max-model-len 16384 --trust-remote-code --port 8000 --gpu-memory-utilization 0.85 -cc {"cudagraph_mode": "NONE", "custom_ops": ["+quant_fp8"], "pass_config": {"fuse_attn_quant": true}}
      bench:      $ vllm bench serve --model RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8 --num-prompts 1000 --request-rate inf --random-input-len 4096 --random-output-len 128 --num-warmups 50 --ignore-eos

@MatthewBonanni
Copy link
Copy Markdown
Member

@carlyou thanks for providing the details. Maybe the results would be clearer with a batch size 1 / concurrency 1 benchmark where we focus on TTFT?

Alternatively, an even better option would be the microbenchmark in benchmarks/attention_benchmarks/benchmark.py (which you'll need to update to pass output_scale, and add a post-quant op to mla_runner.py for the baseline run)

@carlyou carlyou force-pushed the feat--mla-fa4-native-fp8-output branch from 722d066 to 50fd06a Compare June 2, 2026 07:44
@carlyou carlyou requested a review from AndreasKaratzas as a code owner June 2, 2026 07:44
@mergify mergify Bot added the performance Performance-related issues label Jun 2, 2026
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG bce29425653ec0fbc579d329883030e832d15ada
GIT_TAG d0a0e2bf2113fcfd0336e5dd201a5fd89b297a8f
Copy link
Copy Markdown
Member

@MatthewBonanni MatthewBonanni Jun 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we've landed vllm-project/flash-attention#141 and #44065, this change is no longer necessary.

carlyou and others added 9 commits June 2, 2026 14:31
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Carl Y <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
@carlyou carlyou force-pushed the feat--mla-fa4-native-fp8-output branch from 1477e2f to 1548033 Compare June 2, 2026 21:34
@carlyou
Copy link
Copy Markdown
Contributor Author

carlyou commented Jun 2, 2026

@MatthewBonanni added mla benchmark per suggestion. and rebased from latest main.

FP8 Output Results:
                       Attention Benchmark Results                        
┏━━━━━━━┳━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━┓
┃ Batch ┃         ┃ Batch ┃ post_quant ┃ post_quant ┃    fused ┃   fused ┃
┃ Spec  ┃ Type    ┃  Size ┃   Time (s) ┃    vs Best ┃ Time (s) ┃ vs Best ┃
┡━━━━━━━╇━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━┩
│ q512  │ prefill │     1 │   0.000037 │     111.2% │ 0.000033 │  100.0% │
│ q1k   │ prefill │     1 │   0.000052 │     108.6% │ 0.000047 │  100.0% │
│ q2k   │ prefill │     1 │   0.000080 │     110.6% │ 0.000073 │  100.0% │
│ q4k   │ prefill │     1 │   0.000150 │     113.3% │ 0.000133 │  100.0% │
│ q8k   │ prefill │     1 │   0.000396 │     107.7% │ 0.000368 │  100.0% │
│ 2q4k  │ prefill │     2 │   0.000283 │     112.4% │ 0.000252 │  100.0% │
│ 4q4k  │ prefill │     4 │   0.000569 │     112.5% │ 0.000506 │  100.0% │
│ 8q4k  │ prefill │     8 │   0.001132 │     113.7% │ 0.000996 │  100.0% │
└───────┴─────────┴───────┴────────────┴────────────┴──────────┴─────────┘

please check the benchmark result above^
ps. the vllm bench run still has variance with bs=1 and output_len=1 (whichever runs last is faster...).

Copy link
Copy Markdown
Member

@MatthewBonanni MatthewBonanni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! just a few more tweaks

Comment on lines +2342 to 2346
elif output_scale is not None:
# FA4 already wrote results in-place into `output`.
assert isinstance(output_prefill, torch.Tensor)
else:
assert isinstance(output_prefill, torch.Tensor)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are identical branches

Comment on lines +95 to +100
# FA4 can write native fused FP8 (e4m3fn) output on Blackwell
# SM100/SM110 only (see flash-attention#135); FA4 natively handles
# MLA's mismatched qk/v head dims so no V padding is involved.
# Only static per-tensor FP8 is wired today; per-group FP8 / NVFP4
# still go through the post-quant path. get_device_capability() is
# @cache'd, so this stays cheap on the hot path.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove unnecessary comment

Suggested change
# FA4 can write native fused FP8 (e4m3fn) output on Blackwell
# SM100/SM110 only (see flash-attention#135); FA4 natively handles
# MLA's mismatched qk/v head dims so no V padding is involved.
# Only static per-tensor FP8 is wired today; per-group FP8 / NVFP4
# still go through the post-quant path. get_device_capability() is
# @cache'd, so this stays cheap on the hot path.

Comment on lines +169 to +184
if self._is_vllm_fa:
return self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=self._prefill_metadata.query_start_loc,
cu_seqlens_k=self._prefill_metadata.query_start_loc,
max_seqlen_q=self._prefill_metadata.max_query_len,
max_seqlen_k=self._prefill_metadata.max_query_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=return_softmax_lse,
out=out,
output_scale=output_scale,
)
assert out is None and output_scale is None
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding this conditional, please modify self._flash_attn_varlen_diff_headdims to add out and output_scale to kwargs (see line 106)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build nvidia performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm v1

Projects

Status: Todo
Status: In review

Development

Successfully merging this pull request may close these issues.

2 participants