
[Perf] Do FP4 quant before All gather on flashinfer trtllmgen MOE #30014

Merged
pavanimajety merged 6 commits into vllm-project:main from jiahanc:postQuantComm
Dec 16, 2025
Conversation

Contributor

@jiahanc jiahanc commented Dec 4, 2025

Purpose

Move the FP4 quantization before the All-Gather when FlashInfer TRTLLM-Gen MoE is used with DP + EP enabled. This reduces the message size in the All-Gather and thus speeds up the All-Gather kernel. Blocked by #29804; the same optimization will be added to flashinfer_trtllm_fp4_routed_moe after that PR is merged.
Original size

hidden_states (num_tokens, hidden_size), dtype bf16
routing_logits (num_tokens, num_experts), dtype fp32

After change

hidden_states (num_tokens, hidden_size/2), dtype uint8
hidden_states_sf (num_tokens, hidden_size/16), dtype fp8
routing_logits (num_tokens, num_experts), dtype fp32

For DeepSeek-R1, this is roughly a 2.4x reduction in message size.
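The per-token payload can be checked with simple byte counting. A sketch, assuming DeepSeek-R1's hidden_size=7168 and num_experts=256; raw counting gives ~3x, and the reported ~2.4x likely reflects layout/padding overhead in the real scale-factor tensor:

```python
# Per-token all-gather payload, before vs. after quantizing first.
# Assumed DeepSeek-R1 shapes (hypothetical byte counting only; the
# actual scale-factor layout may carry extra padding).
hidden_size, num_experts = 7168, 256

# Before: bf16 hidden states (2 B/elem) + fp32 routing logits (4 B/elem)
before = hidden_size * 2 + num_experts * 4                       # 15360 bytes/token

# After: packed FP4 (two values per uint8 byte) + fp8 scale factors
# (one per 16 elements) + the same fp32 routing logits
after = hidden_size // 2 + hidden_size // 16 + num_experts * 4   # 5056 bytes/token

print(before, after, round(before / after, 2))                   # 15360 5056 3.04
```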

Test Plan

python3 -m vllm.entrypoints.openai.api_server \
  --model nvidia/DeepSeek-R1-0528-FP4-v2 \
  --tokenizer nvidia/DeepSeek-R1-0528-FP4-v2 \
  --dtype auto \
  --kv-cache-dtype fp8 \
  --tensor-parallel-size 1 \
  --pipeline-parallel-size 1 \
  --data-parallel-size 4 \
  --enable-expert-parallel \
  --swap-space 16 \
  --max-num-seqs 1024 \
  --trust-remote-code \
  --max-model-len 4096 \
  --gpu-memory-utilization 0.8 \
  --max-num-batched-tokens 16384 \
  --no-enable-prefix-caching \
  --compilation_config.pass_config.enable_fi_allreduce_fusion true \
  --compilation_config.pass_config.enable_attn_fusion true \
  --compilation_config.pass_config.enable_noop true \
  --compilation_config.custom_ops+=+quant_fp8,+rms_norm \
  --compilation_config.max_cudagraph_capture_size 2048 &


lm_eval --model local-completions --tasks gsm8k --model_args model=nvidia/DeepSeek-R1-0528-FP4-v2,base_url=http://0.0.0.0:8000/v1/completions,max_retries=3,tokenized_requests=False,timeout=1200,max_gen_toks=2048,max_length=8192 --batch_size 2048 --trust_remote_code --limit 0.5

Test Result

local-completions (model=nvidia/DeepSeek-R1-0528-FP4-v2,base_url=http://0.0.0.0:8000/v1/completions,max_retries=3,tokenized_requests=False,timeout=1200,max_gen_toks=2048,max_length=8192,trust_remote_code=True), gen_kwargs: (None), limit: 0.5, num_fewshot: None, batch_size: 2048
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9712|±  |0.0065|
|     |       |strict-match    |     5|exact_match|↑  |0.9712|±  |0.0065|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization by moving FP4 quantization before the All-Gather operation in Flashinfer TRTLLMGEN MOE. This reduces the communication overhead. The changes correctly plumb the necessary extra_tensors through the distributed communication layers. I've identified one critical issue regarding an incorrect return type annotation that should be addressed.

Contributor Author

jiahanc commented Dec 4, 2025

Perf test on 4xGB200, pure prefill test (ISL 2048, OSL 1)
Original perf:

============ Serving Benchmark Result ============
Successful requests:                     512       
Failed requests:                         0         
Maximum request concurrency:             128       
Benchmark duration (s):                  15.15     
Total input tokens:                      1048064   
Total generated tokens:                  512       
Request throughput (req/s):              33.80     
Output token throughput (tok/s):         33.80     
Peak output token throughput (tok/s):    55.00     
Peak concurrent requests:                183.00    
Total Token throughput (tok/s):          69219.00  

With the quant-before-all-gather optimization:

============ Serving Benchmark Result ============
Successful requests:                     512       
Failed requests:                         0         
Maximum request concurrency:             128       
Benchmark duration (s):                  14.22     
Total input tokens:                      1048064   
Total generated tokens:                  512       
Request throughput (req/s):              35.99     
Output token throughput (tok/s):         35.99     
Peak output token throughput (tok/s):    64.00     
Peak concurrent requests:                192.00    
Total Token throughput (tok/s):          73714.89  
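Reading the two runs against each other (simple arithmetic on the numbers above, not part of the PR itself):

```python
# Speedup implied by the two benchmark runs above.
baseline_tok_s, optimized_tok_s = 69219.00, 73714.89
print(round(optimized_tok_s / baseline_tok_s, 3))  # 1.065 -> ~6.5% more throughput

# The wall-clock durations agree: 15.15 s vs. 14.22 s.
print(round(15.15 / 14.22, 3))                     # 1.065
```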


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +332 to +336
return self.all2all_manager.dispatch(
    hidden_states,
    router_logits,
    is_sequence_parallel,
    extra_tensors,  # type: ignore[call-arg]


P1 Badge Dispatch passes extra tensors to backends that cannot accept them

The new pre-quantized path now calls all2all_manager.dispatch(...) with an extra_tensors argument (see call below), but only NaiveAll2AllManager and AgRsAll2AllManager were updated to accept that parameter. Other supported backends (e.g., PPLXAll2AllManager.dispatch at vllm/distributed/device_communicators/all2all.py:229-235 and DeepEPAll2AllManagerBase.dispatch at lines 276-282) still take only (hidden_states, router_logits, is_sequence_parallel). When ModelOpt FP4 MoE runs with those backends and post_quant_allgather is enabled, this call will raise TypeError: dispatch() takes 4 positional arguments but 5 were given, crashing inference for those configurations.


Contributor Author

jiahanc commented Dec 4, 2025

@mgoin @pavanimajety may you help review?

Member

mgoin commented Dec 4, 2025

cc @bnellnm @varun-sundar-rabindranath for the dispatch change


mergify bot commented Dec 8, 2025

Hi @jiahanc, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@jiahanc jiahanc marked this pull request as draft December 8, 2025 23:22
@jiahanc jiahanc marked this pull request as ready for review December 11, 2025 23:46
Contributor Author

jiahanc commented Dec 11, 2025

@bnellnm @pavanimajety may you help re-review the PR? Thanks~


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +1952 to +1954
dispatch_res = get_ep_group().dispatch(
    hidden_states_to_dispatch,
    router_logits,

P1 Badge Dispatch uses undefined hidden_states_to_dispatch

In the DP dispatch path the call to get_ep_group().dispatch(...) uses hidden_states_to_dispatch, but that variable is only assigned inside the if post_quant_allgather branch above; when the optimization is disabled (e.g., any dp_size>1 run that is not ModelOpt TRTLLM), this block is skipped and hidden_states_to_dispatch is undefined, so the forward will crash with an UnboundLocalError before any dispatch occurs.
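A minimal sketch of the fix the review suggests, with hypothetical function and variable names: bind `hidden_states_to_dispatch` on both branches so the fallback path still has a value to dispatch.

```python
# Hypothetical control-flow sketch for the issue above: ensure
# hidden_states_to_dispatch is defined whether or not the
# quant-before-all-gather optimization runs.
def prepare_dispatch(hidden_states, post_quant_allgather, quantize_fp4):
    extra_tensors = None
    if post_quant_allgather:
        # Quantize first so the all-gather moves far fewer bytes.
        hidden_states_to_dispatch, scale_factors = quantize_fp4(hidden_states)
        extra_tensors = [scale_factors]
    else:
        # Fallback path: dispatch the bf16 activations unchanged.
        hidden_states_to_dispatch = hidden_states
    return hidden_states_to_dispatch, extra_tensors
```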


@mgoin mgoin added the performance (Performance-related issues), moe, and ready (ONLY add when PR is ready to merge/full CI is needed) labels Dec 12, 2025
@jiahanc jiahanc force-pushed the postQuantComm branch 2 times, most recently from 6c9de8d to fdfd942 on December 12, 2025 21:18
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
@mgoin mgoin moved this to In review in NVIDIA Dec 16, 2025
@github-project-automation github-project-automation bot moved this from In review to Ready in NVIDIA Dec 16, 2025
@pavanimajety pavanimajety merged commit 254a7f8 into vllm-project:main Dec 16, 2025
62 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Dec 16, 2025
NickLucche pushed a commit to NickLucche/vllm that referenced this pull request Dec 17, 2025
…lm-project#30014)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
…lm-project#30014)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
…lm-project#30014)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…lm-project#30014)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>

Labels

  • nvidia
  • performance (Performance-related issues)
  • ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants