
[Bug] Fix fp8 trtllm MoE modular kernel supported routing methods#37346

Merged
mgoin merged 2 commits into vllm-project:main from wzhao18:wzhao/fix-fp8-trtllm-moe on Mar 19, 2026

Conversation

wzhao18 (Contributor) commented Mar 17, 2026

Purpose

#35448 introduced a regression by adding _supports_routing_method to TrtLlmFp8ExpertsBase. However, this check is only needed for the monolithic kernel, not the modular kernel, which should support any routing method because routing is performed outside the kernel. This PR fixes that regression.
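The shape of the fix can be sketched roughly as follows. The class names come from the PR discussion; the method bodies and the routing-method names are hypothetical placeholders, not vLLM's actual implementation:

```python
# Minimal sketch of the refactor described above (assumptions: method name
# matches the PR discussion; routing-method names are invented examples).

class TrtLlmFp8ExpertsBase:
    def _supports_routing_method(self, routing_method: str) -> bool:
        # Permissive default: modular kernels receive pre-computed
        # topk_ids/topk_weights, so any routing method is acceptable.
        return True


class TrtLlmFp8ExpertsModular(TrtLlmFp8ExpertsBase):
    # Inherits the permissive default: routing happens outside the kernel.
    pass


class TrtLlmFp8ExpertsMonolithic(TrtLlmFp8ExpertsBase):
    # The monolithic kernel performs routing internally, so it must
    # restrict itself to the routing methods it actually implements.
    _SUPPORTED_ROUTING = {"renormalize", "deepseek"}  # hypothetical names

    def _supports_routing_method(self, routing_method: str) -> bool:
        return routing_method in self._SUPPORTED_ROUTING
```

Under this structure, only the monolithic subclass rejects unsupported routing methods, while the modular kernel accepts whatever routing the caller has already computed.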

Test Plan

Test Result

# Using TRTLLM FP8 Modular MoE
vllm serve MiniMaxAI/MiniMax-M2.5 \
    --tensor-parallel-size=2 \
    --enable-expert-parallel  \
    --trust-remote-code \
    --tool-call-parser minimax_m2 \
    --reasoning-parser minimax_m2_append_think \
    --enable-auto-tool-choice

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9287|±  |0.0071|
|     |       |strict-match    |     5|exact_match|↑  |0.9234|±  |0.0073|
# Using TRTLLM FP8 Monolithic MoE
vllm serve deepseek-ai/DeepSeek-R1 --trust-remote-code --tensor-parallel-size 8

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9575|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.9560|±  |0.0056|
# mxfp8
VLLM_USE_DEEP_GEMM=0 vllm serve Qwen/Qwen3-30B-A3B --quantization mxfp8

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8976|±  |0.0083|
|     |       |strict-match    |     5|exact_match|↑  |0.8931|±  |0.0085|


Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
gemini-code-assist (bot) left a comment

Code Review

This pull request addresses a regression by moving _supports_routing_method out of the TrtLlmFp8ExpertsBase class and into the TrtLlmFp8ExpertsMonolithic subclass. This change allows the TrtLlmFp8ExpertsModular kernel to correctly support all routing methods by inheriting the default permissive behavior. The _supports_quant_scheme method has also been moved into the subclasses, tailoring the supported quantization schemes for both the monolithic and modular implementations. The logic appears sound and I have not identified any issues with the implementation.

mgoin added the bug, ready, and quantization labels Mar 17, 2026
mgoin (Member) commented Mar 17, 2026

cc @EdalatiAli

wzhao18 (Contributor, Author) commented Mar 17, 2026

Testing the command in #35448 throws an error (on main as well):

vllm serve Qwen/Qwen3-30B-A3B --quantization mxfp8 

 [core.py:1099]   File "/vllm/vllm/v1/engine/core.py", line 1073, in run_engine_core
 [core.py:1099]     engine_core = EngineCoreProc(*args, engine_index=dp_rank, **kwargs)
 [core.py:1099]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1099]   File "/vllm/vllm/tracing/otel.py", line 178, in sync_wrapper
 [core.py:1099]     return func(*args, **kwargs)
 [core.py:1099]            ^^^^^^^^^^^^^^^^^^^^^
 [core.py:1099]   File "/vllm/vllm/v1/engine/core.py", line 839, in __init__
 [core.py:1099]     super().__init__(
 [core.py:1099]   File "/vllm/vllm/v1/engine/core.py", line 122, in __init__
 [core.py:1099]     kv_cache_config = self._initialize_kv_caches(vllm_config)
 [core.py:1099]                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1099]   File "/vllm/vllm/tracing/otel.py", line 178, in sync_wrapper
 [core.py:1099]     return func(*args, **kwargs)
 [core.py:1099]            ^^^^^^^^^^^^^^^^^^^^^
 [core.py:1099]   File "/vllm/vllm/v1/engine/core.py", line 278, in _initialize_kv_caches
 [core.py:1099]     self.model_executor.initialize_from_config(kv_cache_configs)
 [core.py:1099]   File "/vllm/vllm/v1/executor/abstract.py", line 118, in initialize_from_config
 [core.py:1099]     compilation_times: list[float] = self.collective_rpc("compile_or_warm_up_model")
 [core.py:1099]                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1099]   File "/vllm/vllm/v1/executor/uniproc_executor.py", line 78, in collective_rpc
 [core.py:1099]     result = run_method(self.driver_worker, method, args, kwargs)
 [core.py:1099]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1099]   File "/vllm/vllm/v1/serial_utils.py", line 459, in run_method
 [core.py:1099]     return func(*args, **kwargs)
 [core.py:1099]            ^^^^^^^^^^^^^^^^^^^^^
 [core.py:1099]   File "/vllm/vllm/tracing/otel.py", line 178, in sync_wrapper
 [core.py:1099]     return func(*args, **kwargs)
 [core.py:1099]            ^^^^^^^^^^^^^^^^^^^^^
 [core.py:1099]   File "/vllm/vllm/v1/worker/gpu_worker.py", line 601, in compile_or_warm_up_model
 [core.py:1099]     kernel_warmup(self)
 [core.py:1099]   File "/vllm/vllm/model_executor/warmup/kernel_warmup.py", line 37, in kernel_warmup
 [core.py:1099]     deep_gemm_warmup(model, max_tokens)
 [core.py:1099]   File "/vllm/vllm/tracing/otel.py", line 178, in sync_wrapper
 [core.py:1099]     return func(*args, **kwargs)
 [core.py:1099]            ^^^^^^^^^^^^^^^^^^^^^
 [core.py:1099]   File "/vllm/vllm/model_executor/warmup/deep_gemm_warmup.py", line 364, in deep_gemm_warmup
 [core.py:1099]     total = _count_warmup_iterations(model, max_tokens)
 [core.py:1099]             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1099]   File "/vllm/vllm/model_executor/warmup/deep_gemm_warmup.py", line 342, in _count_warmup_iterations
 [core.py:1099]     if _fp8_linear_may_use_deep_gemm(m):
 [core.py:1099]        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 [core.py:1099]   File "/vllm/vllm/model_executor/warmup/deep_gemm_warmup.py", line 139, in _fp8_linear_may_use_deep_gemm
 [core.py:1099]     and module.quant_method.block_quant
 [core.py:1099]   AttributeError: 'Mxfp8OnlineLinearMethod' object has no attribute 'block_quant'

@EdalatiAli Can you check if I am doing anything wrong?

wzhao18 (Contributor, Author) commented Mar 17, 2026

@mgoin Actually, I just realized that _supports_router_logits_dtype should also only apply to the monolithic kernel. For some reason I was under the impression that it mattered for the modular kernel as well. Can you confirm?

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
EdalatiAli (Contributor) commented

@wzhao18 Thanks for catching and fixing this. #35448 introduced the regression by adding _supports_routing_method and _supports_quant_scheme to TrtLlmFp8ExpertsBase when they should exist in the subclasses. The fix looks correct to me.
Since TrtLlmFp8ExpertsModular receives pre-computed topk_ids/topk_weights, _supports_router_logits_dtype could also be moved to TrtLlmFp8ExpertsMonolithic, IMO.

EdalatiAli (Contributor) commented Mar 17, 2026

@EdalatiAli Can you check if I am doing anything wrong?

I'm currently investigating the issue.

EdalatiAli (Contributor) commented

@wzhao18 This error is raised because Mxfp8OnlineLinearMethod is incompatible with DeepGEMM, and when DeepGEMM is installed this logic does not handle skipping MXFP8 layers. I raised #37358 to fix the issue. Meanwhile, you can work around it by setting VLLM_USE_DEEP_GEMM=0 with --quantization mxfp8.

wzhao18 (Contributor, Author) commented Mar 18, 2026

@EdalatiAli Got it. Thanks! Adding VLLM_USE_DEEP_GEMM=0 makes it work now and gsm8k looks good.

wzhao18 (Contributor, Author) commented Mar 18, 2026

@mgoin could you help review this one?

mgoin (Member) left a comment

LGTM, makes sense!

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 19, 2026
@mgoin mgoin merged commit e27b8ba into vllm-project:main Mar 19, 2026
64 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 19, 2026
chooper26 pushed a commit to intellistream/vllm-hust that referenced this pull request Mar 21, 2026
SouthWest7 pushed a commit to SouthWest7/vllm that referenced this pull request Mar 27, 2026
khairulkabir1661 pushed a commit to khairulkabir1661/vllm that referenced this pull request Mar 27, 2026
Monishver11 pushed a commit to Monishver11/vllm that referenced this pull request Mar 27, 2026
…lm-project#37346)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Signed-off-by: Monishver Chandrasekaran <monishverchandrasekaran@gmail.com>
JiantaoXu pushed a commit to JiantaoXu/vllm that referenced this pull request Mar 28, 2026

Labels

bug, nvidia, quantization, ready

Projects

Status: Done


3 participants