
[BugFix] Fix TRT-LLM NVFP4 DP/EP #32349

Merged
robertgshaw2-redhat merged 7 commits into vllm-project:main from jiahanc:fix-trtllm-moe
Jan 19, 2026

Conversation

@jiahanc
Contributor

@jiahanc jiahanc commented Jan 14, 2026

Purpose

The FlashInfer TRT-LLM NVFP4 MoE path was broken by #31692:

(EngineCore_DP2 pid=545)   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 841, in __call__
(EngineCore_DP2 pid=545)     return self._op(*args, **kwargs)
(EngineCore_DP2 pid=545)            ^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP2 pid=545)   File "/scratch/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 2146, in moe_forward_shared
(EngineCore_DP2 pid=545)     return self.forward_impl(hidden_states, router_logits)
(EngineCore_DP2 pid=545)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP2 pid=545)   File "/scratch/vllm/vllm/model_executor/layers/fused_moe/layer.py", line 1950, in forward_impl
(EngineCore_DP2 pid=545)     self.quant_method.prepare_dp_allgather_tensor(
(EngineCore_DP2 pid=545)   File "/scratch/vllm/vllm/model_executor/layers/quantization/modelopt.py", line 1576, in prepare_dp_allgather_tensor
(EngineCore_DP2 pid=545)     assert self.moe_quant_config is not None
(EngineCore_DP2 pid=545)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
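For context, the failure mode can be reproduced in miniature: on the TRT-LLM NVFP4 path, `moe_quant_config` is legitimately left as `None`, so the helper that asserts on it before the DP allgather fails. The following is a minimal, hypothetical sketch of that shape (class and method names are illustrative, not vLLM's actual API):

```python
class QuantMethod:
    """Illustrative stand-in for a MoE quantization method."""

    def __init__(self, moe_quant_config=None):
        # On the TRT-LLM NVFP4 path this config may legitimately be None.
        self.moe_quant_config = moe_quant_config

    def prepare_dp_allgather_tensor(self, hidden_states):
        # Mirrors the failing assertion from the traceback above.
        assert self.moe_quant_config is not None
        return hidden_states


method = QuantMethod(moe_quant_config=None)
try:
    method.prepare_dp_allgather_tensor([1.0, 2.0])
    crashed = False
except AssertionError:
    crashed = True  # this is the crash the PR fixes

print(crashed)  # → True
```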

Test Plan

VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1 \
VLLM_ATTENTION_BACKEND=FLASHINFER_MLA \
VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL=1 \
VLLM_FLASHINFER_MOE_BACKEND=latency \
python3 -m vllm.entrypoints.openai.api_server \
  --model nvidia/DeepSeek-R1-0528-FP4-v2 \
  --tokenizer nvidia/DeepSeek-R1-0528-FP4-v2 \
  --kv-cache-dtype fp8 \
  --tensor-parallel-size 1 \
  --pipeline-parallel-size 1 \
  --data-parallel-size 4 \
  --enable-expert-parallel \
  --max-model-len 34816 \
  --disable-uvicorn-access-log \
  --no-enable-prefix-caching \
  --async-scheduling \
  --all2all-backend allgather_reducescatter \
  --compilation-config.cudagraph_mode FULL_DECODE_ONLY \
  --compilation_config.custom_ops+=+rms_norm,+rotary_embedding \
  --gpu-memory-utilization 0.9 \
  --stream-interval 50 \
  --max-num-seqs 1024 \
  --max-num-batched-tokens 8192 \
  --max-cudagraph-capture-size 1024 &
lm_eval --model local-completions --tasks gsm8k --model_args model=nvidia/DeepSeek-R1-0528-FP4-v2,base_url=http://0.0.0.0:8000/v1/completions,max_retries=3,tokenized_requests=False,timeout=1200,max_gen_toks=2048,max_length=8192 --batch_size 2048 --trust_remote_code --limit 0.5

Test Result

local-completions (model=nvidia/DeepSeek-R1-0528-FP4-v2,base_url=http://0.0.0.0:8000/v1/completions,max_retries=3,tokenized_requests=False,timeout=1200,max_gen_toks=2048,max_length=8192,trust_remote_code=True), gen_kwargs: (None), limit: 0.5, num_fewshot: None, batch_size: 2048
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9621|±  |0.0074|
|     |       |strict-match    |     5|exact_match|↑  |0.9621|±  |0.0074|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@jiahanc
Contributor Author

jiahanc commented Jan 14, 2026

@pavanimajety @mgoin can you help review? Thanks!

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix a crash that occurs when using FlashInfer with NVFP4 quantization for MoE layers, specifically with the TRT-LLM backend. The issue stems from an assertion failure where moe_quant_config was unexpectedly None. The proposed change correctly addresses the crash for the TRT-LLM backend by sourcing the a1_gscale from the layer object directly. However, my review indicates that this fix, while effective for the intended case, may introduce an AttributeError for other backends like FlashInfer-CUTLASS, where layer.a1_gscale is not set. I've provided a critical comment with a suggested code change to ensure the fix is robust across all supported backends.

     hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
         hidden_states,
-        a1_gscale,
+        layer.a1_gscale,
Contributor


critical

While this change correctly fixes the issue for the FLASHINFER_TRTLLM backend, it could introduce a bug for other backends like FLASHINFER_CUTLASS. For non-TRTLLM backends, self.moe_quant_config is created, but layer.a1_gscale is not set, which would lead to an AttributeError here.

A more robust approach would be to conditionally access a1_gscale from either self.moe_quant_config or layer.

Suggested change:
-        layer.a1_gscale,
+        self.moe_quant_config.a1_gscale if self.moe_quant_config else layer.a1_gscale,
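The suggested fallback can be expressed as a small, self-contained pattern: prefer the scale from the quant config when the config exists, otherwise read it from the layer. This is a hypothetical sketch of that logic (`SimpleNamespace` stands in for the real config and layer objects):

```python
from types import SimpleNamespace


def resolve_a1_gscale(quant_method, layer):
    """Prefer the quant config's scale; fall back to the layer attribute.

    Mirrors the reviewer's suggestion: non-TRTLLM backends populate the
    config but may not set layer.a1_gscale, while the TRT-LLM path does
    the opposite, so neither source alone is safe.
    """
    cfg = quant_method.moe_quant_config
    return cfg.a1_gscale if cfg is not None else layer.a1_gscale


# Non-TRTLLM backend: config populated, layer has no a1_gscale attribute.
cutlass = SimpleNamespace(moe_quant_config=SimpleNamespace(a1_gscale=0.5))
assert resolve_a1_gscale(cutlass, SimpleNamespace()) == 0.5

# TRT-LLM backend: config is None, scale lives on the layer.
trtllm = SimpleNamespace(moe_quant_config=None)
assert resolve_a1_gscale(trtllm, SimpleNamespace(a1_gscale=0.25)) == 0.25
```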

@robertgshaw2-redhat robertgshaw2-redhat changed the title [fix] fix FI NVFP4 trtllm moe break [BugFix] Fix TRT-LLM NVFP4 DP/EP Jan 14, 2026
@robertgshaw2-redhat
Collaborator

thanks for the fix! sorry I did not catch this one in my testing.

@mergify mergify bot added the bug Something isn't working label Jan 14, 2026
@robertgshaw2-redhat
Collaborator

Can you post your deployment?

@robertgshaw2-redhat
Collaborator

add your launch command. I want to verify the result.

@robertgshaw2-redhat
Collaborator

Could you please add this case to the B200 Temporary MoE Refactor job?


@cursor cursor bot left a comment


Comment @cursor review or bugbot run to trigger another review on this PR

@robertgshaw2-redhat
Collaborator

The bot is correct. This maybe_gather_dp is a hack. Can you please add an assert that this should only be called with the TRT-LLM backend? And add a TODO for Rob to refactor away the need for maybe_gather_dp.
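The requested guard could look like the following minimal sketch. The class name, attribute, and backend string are illustrative, not vLLM's actual code:

```python
class TrtllmOnlyQuantMethod:
    """Hypothetical sketch of the backend guard the reviewer asked for."""

    def __init__(self, backend):
        self.backend = backend

    def maybe_gather_dp(self, hidden_states):
        # TODO(rob): refactor away the need for maybe_gather_dp.
        assert self.backend == "FLASHINFER_TRTLLM", (
            "maybe_gather_dp is only valid for the TRT-LLM backend"
        )
        return hidden_states


ok = TrtllmOnlyQuantMethod("FLASHINFER_TRTLLM").maybe_gather_dp([1.0])
assert ok == [1.0]
```

Any other backend would now fail loudly at the assert instead of crashing later inside the allgather path.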

@jiahanc
Contributor Author

jiahanc commented Jan 14, 2026

> add your launch command. I want to verify the result.

Added in the PR test plan

@jiahanc
Contributor Author

jiahanc commented Jan 14, 2026

> Could you please add this case to the B200 Temporary MoE Refactor job?

Could you elaborate on this? Is there a GitHub issue I should add it to?

@robertgshaw2-redhat
Collaborator

> Could you please add this case to the B200 Temporary MoE Refactor job?
>
> Could you elaborate on this? Is there a GitHub issue I should add it to?

https://github.com/vllm-project/vllm/blob/main/.buildkite/test-pipeline.yaml#L1441-L1446

@jiahanc
Contributor Author

jiahanc commented Jan 14, 2026

@robertgshaw2-redhat updated the PR per your comments; could you help review? Thanks!

@robertgshaw2-redhat robertgshaw2-redhat added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 14, 2026
@robertgshaw2-redhat
Collaborator

I made some improvements to the PR

@robertgshaw2-redhat
Collaborator

Thanks for the fix! I triggered the tests. Please ping me on Slack if you need any more help getting this landed.

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Robert Shaw and others added 4 commits January 15, 2026 20:20
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
@hjjq
Contributor

hjjq commented Jan 15, 2026

Hi @robertgshaw2-redhat , the only failing test seems to be due to container startup errors, can you rerun it and merge if it passes? Thanks!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Jan 19, 2026
@robertgshaw2-redhat robertgshaw2-redhat merged commit 7350331 into vllm-project:main Jan 19, 2026
57 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Jan 19, 2026
gopalsarda pushed a commit to gopalsarda/vllm that referenced this pull request Jan 20, 2026
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
lapy pushed a commit to lapy/vllm that referenced this pull request Jan 27, 2026
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>

Labels

bug Something isn't working nvidia ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

4 participants