Skip to content

Fix RoutingMethodType logic#33919

Merged
vllm-bot merged 5 commits intovllm-project:mainfrom
dbari:dbariamis/fix-non-dsv3-routing-with-1-group
Feb 6, 2026
Merged

Fix RoutingMethodType logic#33919
vllm-bot merged 5 commits intovllm-project:mainfrom
dbari:dbariamis/fix-non-dsv3-routing-with-1-group

Conversation

@dbari
Copy link
Contributor

@dbari dbari commented Feb 5, 2026

Purpose

This PR contains two fixes for #33792:

  • Fix the selection logic for RoutingMethodType in fused_topk_bias_router.py and fused_topk_router.py
  • When use_grouped_topk=True, the GroupedTopKRouter did not find any valid routing methods and there is only one group, fall back to the non-grouped routers

The latter point covers Mistral Large 3, which has n_group=1 and topk_group=1 but uses Renormalize instead of DeepSeekV3 routing, handled by the FusedTopKRouter.

After #33858 was merged, Mistral Large 3 is broken. This PR fixes the vLLM part, however Flashinfer PR#2502 is also needed [edit: has been merged and added to 0.6.3].

Test Plan

Will test the accuracy of the models affected by recent relevant PRs: Mistral Large 3, Kimi-K2 and Deepseek V2-Lite.

Test Result

Results for RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8, using the default MoE backend (FLASHINFER_CUTLASS):

GSM8K Results for RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8:
  Measured metric: 0.7832
  Expected metric: 0.7200
  Tolerance: 0.0800
  Questions: 1319
  Invalid rate: 0.001
  Latency: 95.6s
  QPS: 13.8
✅ GSM8K test passed for RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8

Results for mistralai/Mistral-Large-3-675B-Instruct-2512, using the default MoE backend (FLASHINFER_TRTLLM):

local-completions (base_url=http://localhost:3825,model=mistralai/Mistral-Large-3-675B-Instruct-2512,tokenized_requests=False,,tokenizer_backend=None,num_concurrent=8,timeout=240,max_retries=3,stream=False), gen_kwargs: (temperature=0.01,top_p=0.9999999,max_gen_toks=8192), limit: 1000.0, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.942|±  |0.0074|
|     |       |strict-match    |     5|exact_match|↑  |0.909|±  |0.0091|

Results for moonshotai/Kimi-K2-Instruct, using the default MoE backend (FLASHINFER_TRTLLM):

local-completions (base_url=http://localhost:3825,model=/models/Kimi-K2-Instruct,tokenized_requests=False,,tokenizer_backend=None,num_concurrent=8,timeout=240,max_retries=3,stream=False), gen_kwargs: (temperature=0.01,top_p=0.9999999,max_gen_toks=8192), limit: 1000.0, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.947|±  |0.0071|
|     |       |strict-match    |     5|exact_match|↑  |0.946|±  |0.0072|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly fixes the logic for determining RoutingMethodType and adds a necessary fallback for grouped top-k routing. The changes appear to be correct and well-implemented. My main feedback is regarding code duplication for the new routing_method_type property, which has been added to two separate files with identical implementations. I've suggested refactoring this into a shared utility to improve long-term maintainability.

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
@dbari
Copy link
Contributor Author

dbari commented Feb 5, 2026

@robertgshaw2-redhat feel free to add yourself as reviewer if interested, I don't seem to be able to do that.

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
@dbari dbari marked this pull request as ready for review February 6, 2026 12:10
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
@dbari
Copy link
Contributor Author

dbari commented Feb 6, 2026

@mgoin I'm confident that the code is now correct, however it would be good if we could run the tests on all tested models before merging to prevent any side effects. Would it be possible to start such a CI run?

@mgoin mgoin added ready ONLY add when PR is ready to merge/full CI is needed ready-run-all-tests Trigger CI with all tests for wide-ranging PRs labels Feb 6, 2026
@mgoin
Copy link
Member

mgoin commented Feb 6, 2026

@dbari this error seems related in Fusion E2E TP2 (B200), Kernels (B200), and other B200 tests

(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]   File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py", line 307, in fi_trtllm_fp8_per_tensor_moe
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]     return flashinfer_trtllm_fp8_per_tensor_scale_moe(
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]   File "/usr/local/lib/python3.12/dist-packages/vllm/utils/flashinfer.py", line 102, in wrapper
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]     return impl(*args, **kwargs)
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]   File "/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py", line 2258, in trtllm_fp8_per_tensor_scale_moe
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]     return get_trtllm_moe_sm100_module().trtllm_fp8_per_tensor_scale_moe(
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]   File "/usr/local/lib/python3.12/dist-packages/flashinfer/fused_moe/core.py", line 1485, in trtllm_fp8_per_tensor_scale_moe_op
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]     activation_type=activation_type.value,
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852]                     ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=9432) ERROR 02-06 15:03:45 [multiproc_executor.py:852] AttributeError: 'int' object has no attribute 'value'

@dbari
Copy link
Contributor Author

dbari commented Feb 6, 2026

That's a bug that snuck in the 0.6.3 release of Flashinfer unrelated to this PR. There's an open issue and PR for it. I'll let you know when there's a fix we can use.

By the way, the previous version 0.6.2 was ok, so it did not affect vLLM. However this PR will not run with 0.6.2, so we have to postpone this PR until a post-release or a new release with the fix is out.

@mgoin
Copy link
Member

mgoin commented Feb 6, 2026

It seems maybe we can get past this if we manually specify the argument from the vLLM side, let me try that

@dbari
Copy link
Contributor Author

dbari commented Feb 6, 2026

You could call it with activation_type=ActivationType.Swiglu, but when the fix comes, it will need to be an integer like the documentation and type annotation say. If a temporary fix is ok for you, we could add it now and remove it on the next Flashinfer update.

FYI it only affects the FP8 per-tensor code path, the blockwise and FP4 are not affected.

Signed-off-by: mgoin <mgoin64@gmail.com>
@vllm-bot vllm-bot merged commit 207c3a0 into vllm-project:main Feb 6, 2026
139 of 143 checks passed
@github-project-automation github-project-automation bot moved this to Done in NVIDIA Feb 6, 2026
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci/build nvidia ready ONLY add when PR is ready to merge/full CI is needed ready-run-all-tests Trigger CI with all tests for wide-ranging PRs

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

3 participants