Add support for Mistral Large 3 inference with Flashinfer MoE#33174

Merged
vllm-bot merged 10 commits intovllm-project:mainfrom
dbari:mistral_large_3_blackwell_blockwise
Jan 31, 2026

Conversation

@dbari
Contributor

@dbari dbari commented Jan 27, 2026

Purpose

Allow inference of Mistral Large 3 on Blackwell with Flashinfer TRTLLM (latency) backend for better performance.

This PR updates FlashInfer to 0.6.2, which includes fixed kernels for blockwise-quantized FP8 MoE, and makes small changes to the vLLM code so this path can be called. It also enables calling FlashInfer for per-tensor quantized models (fp8 quantization).

Also, optimized Triton configurations are added for Mistral Large 3 FP8 TP8 with and without EP, as well as for the Eagle draft model.

This PR supersedes #29884. It also requires #33008 to be merged before Mistral Large 3 can be loaded.

Work on benchmark_moe.py to generate Triton configurations has been completed and integrated into this PR.
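For reference, a minimal sketch of how one might launch this configuration. The environment variable names here are taken from vLLM's envs module as I understand it, not from this PR; verify them against your vLLM version before relying on them:

```shell
# Sketch only: serving Mistral Large 3 with the FlashInfer MoE path on
# 8x Blackwell. Env var names are assumptions based on vLLM's envs module.
export VLLM_USE_FLASHINFER_MOE_FP8=1
export VLLM_FLASHINFER_MOE_BACKEND=latency   # TRTLLM latency kernels
vllm serve mistralai/Mistral-Large-3-675B-Instruct-2512 \
    --tensor-parallel-size 8
```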

Test result

Accuracy check with gsm8k:

local-completions (base_url=http://localhost:3825,model=mistralai/Mistral-Large-3-675B-Instruct-2512,tokenized_requests=False,tokenizer_backend=None,num_concurrent=8,timeout=240,max_retries=3,stream=False), gen_kwargs: (temperature=0.01,top_p=0.9999999,max_gen_toks=8192), limit: 1000.0, num_fewshot: None, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.940|±  |0.0075|
|     |       |strict-match    |     5|exact_match|↑  |0.907|±  |0.0092|

Performance on 8xB200, using Flashinfer MLA and FP8 KV-cache, varying only the MoE backend:

|MoE Backend|Concurrency|TTFT (ms)|Output Token Throughput (t/s)|
|-----------|----------:|--------:|----------------------------:|
|Triton     |         16|   600.48|                       847.87|
|Flashinfer |         16|   515.33|                       978.03|
|Triton     |         64|   858.36|                      2166.99|
|Flashinfer |         64|   700.76|                      2526.54|
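From the table, the FlashInfer backend improves both metrics; a quick back-of-the-envelope check of the relative gains (TTFT: lower is better; throughput: higher is better) puts them at roughly 14-18% lower TTFT and 15-17% higher throughput:

```python
# Quick check of the relative gains from the benchmark table above.
def gains(triton_ttft, fi_ttft, triton_tput, fi_tput):
    ttft_cut = (triton_ttft - fi_ttft) / triton_ttft * 100
    tput_gain = (fi_tput - triton_tput) / triton_tput * 100
    return ttft_cut, tput_gain

for conc, row in {16: (600.48, 515.33, 847.87, 978.03),
                  64: (858.36, 700.76, 2166.99, 2526.54)}.items():
    ttft_cut, tput_gain = gains(*row)
    print(f"concurrency={conc}: TTFT -{ttft_cut:.1f}%, throughput +{tput_gain:.1f}%")
```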

Contributors

@dbari, @DanBlanaru, @evezhier, @hypdeb

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added ci/build deepseek Related to DeepSeek models nvidia labels Jan 27, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request:

- updates the flashinfer-python dependency to version 0.6.2;
- introduces new JSON configuration files for FusedMoE layers across various NVIDIA GPUs (B200, GB200, H200) and data types (fp8_w8a8), defining optimized kernel parameters;
- expands FP8 support in FlashInfer's _supports_routing_method to include the Renormalize and RenormalizeNaive routing types, in addition to Llama4;
- refactors apply_fi_trtllm_fp8_per_tensor_moe to dynamically pass the routing_method_type to the FlashInfer kernel and to make Llama4-specific assertions conditional rather than universal;
- updates the DeepSeekV2 model's MoE layer initialization to handle grouped_topk parameters more flexibly, setting them to None if n_group and topk_group are both 1;
- includes a minor refactoring of the flashinfer.mm_fp4 import path.

Review comments highlight the importance of ensuring the conditional Llama4 assertion block correctly aligns with the actual routing method, and of validating that the dynamically passed routing_method_type is always correctly configured to prevent unexpected kernel behavior.
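The expanded routing gate can be sketched roughly as follows. The enum members are taken from the discussion above, but the class, set, and function names here are hypothetical, not vLLM's actual code:

```python
from enum import Enum

# Hypothetical mirror of FlashInfer's RoutingMethodType enum;
# member names from the PR discussion, values illustrative only.
class RoutingMethodType(Enum):
    Llama4 = 1
    Renormalize = 2
    RenormalizeNaive = 3
    DeepSeekV3 = 4

# Before this PR the FP8 per-tensor path accepted (illustratively) only Llama4.
SUPPORTED_FP8_ROUTING = {
    RoutingMethodType.Llama4,
    RoutingMethodType.Renormalize,       # added by this PR
    RoutingMethodType.RenormalizeNaive,  # added by this PR
}

def supports_routing_method(routing_method_type: RoutingMethodType) -> bool:
    """Sketch of the expanded _supports_routing_method gate."""
    return routing_method_type in SUPPORTED_FP8_ROUTING
```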

@@ -147,7 +147,7 @@ def apply_fi_trtllm_fp8_per_tensor_moe(
local_expert_offset=layer.ep_rank * layer.local_num_experts,
local_num_experts=layer.local_num_experts,
use_routing_scales_on_input=apply_router_weight_on_input,
routing_method_type=RoutingMethodType.Llama4,
routing_method_type=layer.routing_method_type,
Contributor

Severity: high

The routing_method_type is being directly assigned from layer.routing_method_type. It's important to validate that layer.routing_method_type is correctly set and corresponds to the expected routing method for the given layer configuration. An incorrect routing_method_type could lead to unexpected behavior or errors during the FlashInfer kernel execution.

Contributor Author

@reviewers: Do you think we should do something about this? The layer is a FusedMoE at all call sites and always has the routing_method_type, so I don't know why Gemini produced a high priority warning. I would leave it as is, but am open to suggestions.

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
@dbari dbari force-pushed the mistral_large_3_blackwell_blockwise branch from a7edd1f to 487151d Compare January 28, 2026 13:58
Member

@mgoin mgoin left a comment

Looks reasonable to me! Could you share an eval to validate it works e2e? Also a performance result would be nice to have

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
dbari added 5 commits January 29, 2026 05:17
Additionally:
- Fixed a serialization bug when calling Ray
- Fixed the selection of LLM architecture for some other models

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
@mergify mergify bot added the performance Performance-related issues label Jan 29, 2026
@@ -0,0 +1,147 @@
{
"triton_version": "3.4.0",
Member
I will note that these triton versions do seem out of date for modern torch+triton. It should be triton==3.5.1 for what we use on main

Contributor Author

Good catch, this was generated a while ago. I'll see if anything changes if I run the benchmark in the current environment.

Contributor Author

Would it be ok to leave these as they are? It would take quite a bit of time to regenerate. Also, please keep in mind that the older ones are for the per-tensor FP8 quantization, which is only used in the Eagle draft model. The main model uses blockwise quantization and those configurations are newer (3.5.x).
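For context, the tuned JSON files are looked up by a filename derived from the layer shape, GPU, and dtype. A sketch of that naming convention, inferred from the files this PR adds; the helper name and exact signature here are assumptions, not vLLM's actual helper:

```python
from typing import Optional

def config_file_name(num_experts: int, shard_size: int,
                     device_name: str, dtype: Optional[str]) -> str:
    """Sketch of the fused-MoE tuning-file naming convention (inferred from
    the JSON config files added in this PR; helper name is hypothetical)."""
    dtype_part = f",dtype={dtype}" if dtype else ""
    return f"E={num_experts},N={shard_size},device_name={device_name}{dtype_part}.json"
```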

@dbari
Contributor Author

dbari commented Jan 29, 2026

Looks reasonable to me! Could you share an eval to validate it works e2e? Also a performance result would be nice to have

I added both to the summary.

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Jan 29, 2026
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 29, 2026
@DarkLight1337 DarkLight1337 enabled auto-merge (squash) January 30, 2026 06:26
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
auto-merge was automatically disabled January 30, 2026 09:15

Head branch was pushed to by a user without write access

@dbari dbari requested a review from WoosukKwon as a code owner January 30, 2026 09:15
@dbari
Contributor Author

dbari commented Jan 30, 2026

The PR is ready to merge; however, there are still failing tests that, as far as I can tell, are unrelated to it. Could someone please take a look and/or restart the failing CI steps?

FYI @mgoin

@DarkLight1337
Member

Merged

@vllm-bot vllm-bot merged commit f0bca83 into vllm-project:main Jan 31, 2026
92 of 94 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Jan 31, 2026
@robertgshaw2-redhat
Collaborator

I don't understand why this PR touches the static quantization path; the Mistral model uses block FP8:

"quantization_config": {
"config_groups": {
"FP8_BLOCK": {
"format": "float-quantized",
"input_activations": {
"actorder": null,
"block_structure": null,
"dynamic": true,
"group_size": 128,
"num_bits": 8,
"observer": null,
"observer_kwargs": {},
"strategy": "group",
"symmetric": true,
"type": "float"
},
"output_activations": null,
"targets": [
"Linear"
],
"weights": {
"actorder": null,
"block_structure": [
128,
128
],
"dynamic": false,
"group_size": null,
"num_bits": 8,
"observer": "static_minmax",
"observer_kwargs": {},
"strategy": "block",
"symmetric": true,
"type": "float"
}
}
},


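For readers comparing the two FP8 flavors mentioned in this thread, a minimal sketch of how they can be told apart from a compressed-tensors-style config like the one quoted above (keys abbreviated; this is not vLLM's actual config-parsing code):

```python
# Classify the FP8 scheme from a quantization config group.
# "blockwise" corresponds to the main Mistral Large 3 model,
# "per-tensor" to the Eagle draft model discussed below.
def fp8_scheme(config_group: dict) -> str:
    weights = config_group["weights"]
    if weights["strategy"] == "block" and weights.get("block_structure"):
        return "blockwise"
    if weights["strategy"] == "tensor":
        return "per-tensor"
    return weights["strategy"]
```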
@dbari
Contributor Author

dbari commented Feb 3, 2026

@robertgshaw2-redhat it was due to the use of per-tensor FP8 by the Eagle model, see here.

PiratePai pushed a commit to PiratePai/epd_shm that referenced this pull request Feb 3, 2026
…roject#33174)

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: Pai <416932041@qq.com>
Comment on lines +301 to +304
if (n_group, topk_group) == (1, 1):
n_group = None
topk_group = None
use_grouped_topk = False
Collaborator

This change broke accuracy for Kimi-K2 models with MXINT4. Why are we setting n_group and topk_group to None here?

Contributor Author

Because having one group is equivalent to not using grouped top-k. However, there was an issue, #33792, that caused Kimi-K2 to select the wrong routing and resulted in wrong output. I'll submit fixes for all related issues.
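The equivalence can be checked with a tiny pure-Python model of grouped top-k (illustrative only, not vLLM's implementation): with a single group, the group-selection step keeps the only group, so just the final expert top-k remains.

```python
# Plain top-k: indices of the k highest scores.
def topk(scores, k):
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

# Grouped top-k: score each group by its best expert, keep the topk_group
# best groups, mask out the rest, then take the final expert top-k.
def grouped_topk(scores, k, n_group, topk_group):
    group_size = len(scores) // n_group
    group_scores = [max(scores[g * group_size:(g + 1) * group_size])
                    for g in range(n_group)]
    kept = set(topk(group_scores, topk_group))
    masked = [s if i // group_size in kept else float("-inf")
              for i, s in enumerate(scores)]
    return topk(masked, k)
```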

ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
…roject#33174)

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>

Labels

ci/build deepseek Related to DeepSeek models nvidia performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done


7 participants