Add support for Mistral Large 3 inference with Flashinfer MoE #33174

vllm-bot merged 10 commits into vllm-project:main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a small subset of tests runs automatically. You can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request updates the flashinfer-python dependency to version 0.6.2 and introduces new JSON configuration files for FusedMoE layers across various NVIDIA GPUs (B200, GB200, H200) and data types (fp8_w8a8), defining optimized kernel parameters.

The changes also expand FP8 quantization support in FlashInfer's _supports_routing_method to include the Renormalize and RenormalizeNaive routing types, in addition to Llama4. The apply_fi_trtllm_fp8_per_tensor_moe function was refactored to pass the routing_method_type to the FlashInfer kernel dynamically and to make the Llama4-specific assertions conditional rather than universal.

Furthermore, the DeepSeekV2 model's MoE layer initialization was updated to handle grouped_topk parameters more flexibly, setting them to None when n_group and topk_group are both 1. A minor refactoring of the flashinfer.mm_fp4 import path was also included.

Review comments highlight the importance of ensuring that the conditional Llama4 assertion block aligns with the actual routing method, and of validating that the dynamically passed routing_method_type is always correctly configured to prevent unexpected kernel behavior.
```diff
@@ -147,7 +147,7 @@ def apply_fi_trtllm_fp8_per_tensor_moe(
         local_expert_offset=layer.ep_rank * layer.local_num_experts,
         local_num_experts=layer.local_num_experts,
         use_routing_scales_on_input=apply_router_weight_on_input,
-        routing_method_type=RoutingMethodType.Llama4,
+        routing_method_type=layer.routing_method_type,
```
The routing_method_type is being directly assigned from layer.routing_method_type. It's important to validate that layer.routing_method_type is correctly set and corresponds to the expected routing method for the given layer configuration. An incorrect routing_method_type could lead to unexpected behavior or errors during the FlashInfer kernel execution.
@reviewers: Do you think we should do something about this? The layer is a FusedMoE at all call sites and always has the routing_method_type, so I don't know why Gemini produced a high priority warning. I would leave it as is, but am open to suggestions.
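If one did want the defensive check Gemini asks for, it could be a fail-fast validation before the kernel call. The sketch below is hypothetical: the enum member values and the `check_routing_method` helper are illustrative assumptions, not vLLM or FlashInfer code.

```python
# Hypothetical defensive validation of the routing method before handing it
# to the FlashInfer kernel. Enum names mirror those discussed in this PR;
# the numeric values and the helper itself are assumptions for illustration.
from enum import Enum


class RoutingMethodType(Enum):
    Llama4 = 1
    Renormalize = 2
    RenormalizeNaive = 3


# Routing types this PR enables for the FP8 path, per the review summary.
SUPPORTED_FP8_ROUTING = {
    RoutingMethodType.Llama4,
    RoutingMethodType.Renormalize,
    RoutingMethodType.RenormalizeNaive,
}


def check_routing_method(routing_method_type):
    """Fail fast instead of letting the kernel misbehave on a bad value."""
    if routing_method_type not in SUPPORTED_FP8_ROUTING:
        raise ValueError(f"Unsupported routing method: {routing_method_type}")
    return routing_method_type


print(check_routing_method(RoutingMethodType.Renormalize).name)  # Renormalize
```

Since `layer` is always a FusedMoE with `routing_method_type` set at construction time, this check would be cheap but arguably redundant, as the author notes.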
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Force-pushed from a7edd1f to 487151d
mgoin left a comment:
Looks reasonable to me! Could you share an eval to validate it works e2e? Also a performance result would be nice to have
Additionally:
- Fixed a serialization bug when calling Ray
- Fixed the selection of LLM architecture for some other models
```diff
@@ -0,0 +1,147 @@
+{
+  "triton_version": "3.4.0",
```
I will note that these Triton versions do seem out of date for modern torch+triton. It should be `triton==3.5.1` for what we use on main.
Good catch, this was generated a while ago. I'll see if anything changes if I run the benchmark in the current environment.
Would it be ok to leave these as they are? It would take quite a bit of time to regenerate. Also, please keep in mind that the older ones are for the per-tensor FP8 quantization, which is only used in the Eagle draft model. The main model uses blockwise quantization and those configurations are newer (3.5.x).
I added both to the summary.
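The FusedMoE JSON files under discussion map batch sizes to Triton kernel parameters. A rough sketch of their shape and of a nearest-batch-size lookup follows; the concrete values and the `pick_config` helper are illustrative assumptions, not vLLM's exact logic.

```python
# Sketch of a FusedMoE tuning config: top-level metadata plus batch-size
# keys mapping to Triton kernel parameters. Values here are made up.
import json

config_json = """
{
  "triton_version": "3.4.0",
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 8, "num_warps": 8, "num_stages": 4}
}
"""


def pick_config(configs, m):
    """Pick the tuned config whose batch-size key is closest to m,
    skipping non-numeric metadata keys like triton_version."""
    sizes = [int(k) for k in configs if k.isdigit()]
    best = min(sizes, key=lambda s: abs(s - m))
    return configs[str(best)]


configs = json.loads(config_json)
print(pick_config(configs, 48))  # closest tuned batch size to 48 is 64
```

This is why regenerating the files is expensive: each batch-size entry is the result of a benchmark sweep over kernel parameters on the target GPU.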
Head branch was pushed to by a user without write access
The PR is ready to merge; however, there are still failing tests that, as far as I can tell, are unrelated to it. Could someone please take a look and/or restart the failing CI steps? FYI @mgoin
Merged
@robertgshaw2-redhat it was due to the use of per-tensor FP8 by the Eagle model, see here.
…roject#33174)
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Signed-off-by: Pai <416932041@qq.com>
```python
if (n_group, topk_group) == (1, 1):
    n_group = None
    topk_group = None
    use_grouped_topk = False
```
This change broke accuracy for Kimi-K2 models with MXINT4. Why are we setting n_group and topk_group to None here?
Because having one group is equivalent to not using grouped top-k. However, issue #33792 caused Kimi-K2 to select the wrong routing and produce incorrect output. I'll submit fixes for all related issues.
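The equivalence claimed above can be shown with a toy router: with `n_group=1` the group-filtering step keeps every expert, so grouped top-k selects the same experts as plain top-k. This is an illustrative sketch, not vLLM's kernel.

```python
# Toy demonstration that grouped top-k with a single group collapses to
# plain top-k over all experts.


def topk(scores, k):
    """Indices of the k highest-scoring experts."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def grouped_topk(scores, k, n_group, topk_group):
    """Split experts into n_group groups, keep the topk_group best groups
    (scored by their max expert), then run top-k over surviving experts."""
    group_size = len(scores) // n_group
    group_best = [max(scores[g * group_size:(g + 1) * group_size])
                  for g in range(n_group)]
    kept = set(topk(group_best, topk_group))
    masked = [s if i // group_size in kept else float("-inf")
              for i, s in enumerate(scores)]
    return topk(masked, k)


scores = [0.1, 0.9, 0.4, 0.8, 0.3, 0.7, 0.2, 0.6]
# With one group, nothing is masked, so the selections match.
assert grouped_topk(scores, k=2, n_group=1, topk_group=1) == topk(scores, k=2)
print(grouped_topk(scores, 2, 1, 1))  # [1, 3]
```

The accuracy regression therefore came not from this equivalence itself but from the downstream routing selection (#33792) reacting to the `None` values.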
Purpose
Allow inference of Mistral Large 3 on Blackwell with the Flashinfer TRTLLM (`latency`) backend for better performance.

This PR updates Flashinfer to 0.6.2, which includes fixed kernels for blockwise-quantized FP8 MoE, and makes small changes to the vLLM code to allow calling this code. It also allows calling Flashinfer for per-tensor quantized models (`fp8` quantization).

Optimized Triton configurations are also added for Mistral Large 3 FP8 TP8 with and without EP, as well as for the Eagle draft model.
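The per-tensor vs. blockwise FP8 distinction matters throughout this PR (the Eagle draft model is per-tensor, the main model blockwise). A toy sketch of the two scaling schemes, showing only the scale bookkeeping; real FP8 rounding and the actual vLLM quantization code are omitted.

```python
# Per-tensor vs. blockwise FP8 scaling, illustrated in pure Python.
# FP8 E4M3 has a maximum representable magnitude of 448.
FP8_MAX = 448.0


def per_tensor_scale(weights):
    """One scale for the whole tensor: simple, but one outlier
    coarsens the quantization grid for every weight."""
    return max(abs(w) for w in weights) / FP8_MAX


def blockwise_scales(weights, block_size):
    """One scale per block: an outlier only affects its own block."""
    return [max(abs(w) for w in weights[i:i + block_size]) / FP8_MAX
            for i in range(0, len(weights), block_size)]


w = [0.5, -0.25, 0.1, 100.0]              # one outlier in the last block
print(per_tensor_scale(w))                # dominated by the outlier
print(blockwise_scales(w, block_size=2))  # first block keeps a tight scale
```

This is also why the two paths use separately tuned kernels and separately generated Triton configs, as discussed in the review thread above.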
This PR supersedes #29884. It also needs the changes of #33008 to be merged in order for Mistral Large 3 to be loaded.
Work on `benchmark_moe.py` to generate Triton configurations is ~~in progress and any changes will follow in a next PR~~ completed and integrated into this PR.

Test result
Accuracy check with `gsm8k`:

Performance on 8xB200, using Flashinfer MLA and FP8 KV cache, varying only the MoE backend:
Contributors
@dbari, @DanBlanaru, @evezhier, @hypdeb