
[NVIDIA] Enable fp8 flashinfer_trtllm_routed MoE for MiniMax-M2.5#20394

Open
trevor-m wants to merge 1 commit into sgl-project:main from trevor-m:trtllm-moe-routed

Conversation

@trevor-m
Collaborator

@trevor-m trevor-m commented Mar 12, 2026

Modifications

  • The kernel always outputs bf16, so the result may need to be cast back to the hidden states dtype.
  • Enable align_fp8_moe_weights_for_flashinfer_trtllm for the routed version.
  • Use the same fused_func as the non-routed path and use a topk-output check to determine which kernel to run.
  • Fix an issue where getattr does not fall back to the default value when the attribute is explicitly set to None.
  • Add comments marking where to enable autotuning and remove the extra copy once the FlashInfer bugs are fixed.
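The getattr fix above is worth illustrating, since the pitfall is easy to miss: getattr's default only applies when the attribute is absent, not when it exists but is set to None. This is a minimal sketch; the class and attribute names (QuantConfig, weight_block_size) are hypothetical, not SGLang's actual code.

```python
# Hypothetical config object where the attribute exists but is None.
class QuantConfig:
    weight_block_size = None


cfg = QuantConfig()

# Buggy pattern: the default [128, 128] is silently ignored because
# the attribute is present (just None), so getattr returns None.
buggy = getattr(cfg, "weight_block_size", [128, 128])
assert buggy is None

# Fixed pattern: treat None the same as "missing" and fall back.
fixed = getattr(cfg, "weight_block_size", None)
if fixed is None:
    fixed = [128, 128]
assert fixed == [128, 128]
```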

Issues

Accuracy Tests

TP4

Accuracy: 0.946
Invalid: 0.000
Latency: 39.622 s
Output throughput: 3120.237 token/s

TEP4

Accuracy: 0.952
Invalid: 0.000
Latency: 30.512 s
Output throughput: 4069.051 token/s

Benchmarking and Profiling

4xGB200 Results - TP4: 9.04% speedup over default (triton) moe, TEP4: 5.48% speedup

TP4 + flashinfer_trtllm moe

python -m sglang.launch_server --model-path MiniMaxAI/MiniMax-M2.5 --reasoning-parser minimax --tool-call-parser minimax-m2 --trust-remote-code --tp 4 --mem-fraction-static 0.9 --kv-cache-dtype fp8_e4m3 --attention-backend flashinfer --moe-runner-backend flashinfer_trtllm_routed --quantization fp8 --enforce-piecewise-cuda-graph
python3 /trevor/bench_serving/benchmark_serving.py --backend openai --host 0.0.0.0 --port 30000 --model MiniMaxAI/MiniMax-M2.5 --num-prompts 1280 --trust-remote-code --ignore-eos --max-concurrency 128 --random-input-len 1024 --random-output-len 1024 --random-range-ratio 0.8 --use-chat-template --dataset-name random
============ Serving Benchmark Result ============
Successful requests:                     1280      
Benchmark duration (s):                  368.20    
Total input tokens:                      1185656   
Total generated tokens:                  1177625   
Request throughput (req/s):              3.48      
Output token throughput (tok/s):         3198.31   
Total Token throughput (tok/s):          6418.44   
---------------Time to First Token----------------
Mean TTFT (ms):                          320.41    
Median TTFT (ms):                        140.95    
P99 TTFT (ms):                           2911.07   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          38.75     
Median TPOT (ms):                        39.49     
P99 TPOT (ms):                           41.10     
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.78     
Median ITL (ms):                         31.50     
P99 ITL (ms):                            139.07    
==================================================

TP4 + default moe

python -m sglang.launch_server --model-path MiniMaxAI/MiniMax-M2.5 --reasoning-parser minimax --tool-call-parser minimax-m2 --trust-remote-code --tp 4 --mem-fraction-static 0.9 --kv-cache-dtype fp8_e4m3 --attention-backend flashinfer
python3 /trevor/bench_serving/benchmark_serving.py --backend openai --host 0.0.0.0 --port 30000 --model MiniMaxAI/MiniMax-M2.5 --num-prompts 1280 --trust-remote-code --ignore-eos --max-concurrency 128 --random-input-len 1024 --random-output-len 1024 --random-range-ratio 0.8 --use-chat-template --dataset-name random
============ Serving Benchmark Result ============
Successful requests:                     1280      
Benchmark duration (s):                  401.48    
Total input tokens:                      1185656   
Total generated tokens:                  1177625   
Request throughput (req/s):              3.19      
Output token throughput (tok/s):         2933.23   
Total Token throughput (tok/s):          5886.46   
---------------Time to First Token----------------
Mean TTFT (ms):                          335.99    
Median TTFT (ms):                        144.27    
P99 TTFT (ms):                           3050.40   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          42.27     
Median TPOT (ms):                        42.98     
P99 TPOT (ms):                           44.62     
---------------Inter-token Latency----------------
Mean ITL (ms):                           42.30     
Median ITL (ms):                         35.23     
P99 ITL (ms):                            138.13    
==================================================

TEP4 + flashinfer_trtllm_routed moe

============ Serving Benchmark Result ============
Successful requests:                     1280      
Benchmark duration (s):                  366.20    
Total input tokens:                      1185656   
Total generated tokens:                  1177625   
Request throughput (req/s):              3.50      
Output token throughput (tok/s):         3215.76   
Total Token throughput (tok/s):          6453.45   
---------------Time to First Token----------------
Mean TTFT (ms):                          347.28    
Median TTFT (ms):                        134.17    
P99 TTFT (ms):                           3416.14   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          38.50     
Median TPOT (ms):                        39.17     
P99 TPOT (ms):                           40.71     
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.53     
Median ITL (ms):                         31.76     
P99 ITL (ms):                            130.63    
==================================================

TEP4 + default moe

============ Serving Benchmark Result ============
Successful requests:                     1280      
Benchmark duration (s):                  386.27    
Total input tokens:                      1185656   
Total generated tokens:                  1177625   
Request throughput (req/s):              3.31      
Output token throughput (tok/s):         3048.70   
Total Token throughput (tok/s):          6118.18   
---------------Time to First Token----------------
Mean TTFT (ms):                          350.67    
Median TTFT (ms):                        141.24    
P99 TTFT (ms):                           3391.56   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          40.61     
Median TPOT (ms):                        41.29     
P99 TPOT (ms):                           42.93     
---------------Inter-token Latency----------------
Mean ITL (ms):                           40.64     
Median ITL (ms):                         33.47     
P99 ITL (ms):                            135.83    
==================================================

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.


@github-actions github-actions bot added the quant LLM Quantization label Mar 12, 2026
@trevor-m trevor-m force-pushed the trtllm-moe-routed branch from d49c17a to f8b9aae Compare March 13, 2026 01:28
@trevor-m trevor-m changed the title [NVIDIA] Add modular flashinfer trtllm fp8 moe and enable for MiniMax-M2.5 [NVIDIA] Modular flashinfer trtllm fp8 moe improvements, +fp4 support, +enable for MiniMax-M2.5 Mar 13, 2026
@trevor-m trevor-m changed the title [NVIDIA] Modular flashinfer trtllm fp8 moe improvements, +fp4 support, +enable for MiniMax-M2.5 [NVIDIA] Modular flashinfer trtllm fp8 moe improvements, +enable for MiniMax-M2.5 Mar 16, 2026
@trevor-m trevor-m force-pushed the trtllm-moe-routed branch 3 times, most recently from 4f0c9c8 to 1563b11 Compare March 16, 2026 18:41
@github-actions github-actions bot added the documentation Improvements or additions to documentation label Mar 16, 2026
@trevor-m trevor-m force-pushed the trtllm-moe-routed branch 2 times, most recently from 2231a2c to b652fca Compare March 16, 2026 19:00
):
backend = "flashinfer_trtllm_routed"



Can we keep routed unit tests since routed and fused are 2 separate code paths?

@zianglih
Contributor

zianglih commented Mar 17, 2026

Hi @trevor-m, the key motivation for the original explicit --moe-runner-backend=flashinfer_trtllm_routed flag is to support MoE expert routing replay for RL use cases. In the RL use case, even if the topk + MoE routing method is supported by the fused kernel, we still want to explicitly use the routed kernel, which triggers the SGLang standard topk code path so that expert selections can be captured.

I believe it is better to have 2 explicit backends, as in my previous design:

  • flashinfer_trtllm:
    • For general purpose inference and serving.
    • FlashInfer handles topk and routing.
    • SGLang dispatches to FlashInfer trtllm_*_block_scale_moe api.
  • flashinfer_trtllm_routed
    • Specifically for RL use case which requires expert routing replay.
    • SGLang handles topk and routing and expert selection capturing if --enable-return-routed-experts.
    • SGLang dispatches to FlashInfer trtllm_*_block_scale_routed_moe api.
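The two-backend split above can be sketched as a simple dispatch function. This is an illustrative sketch only, assuming a select_moe_backend helper that does not exist in SGLang; it just encodes the mapping zianglih describes from flag value to code path.

```python
# Hypothetical dispatch mirroring the proposed two-backend design.
def select_moe_backend(moe_runner_backend: str) -> str:
    if moe_runner_backend == "flashinfer_trtllm":
        # General-purpose serving: FlashInfer performs topk + routing
        # internally (trtllm_*_block_scale_moe).
        return "fused"
    if moe_runner_backend == "flashinfer_trtllm_routed":
        # RL routing replay: SGLang performs topk/routing, so expert
        # selections can be captured when --enable-return-routed-experts
        # is set (trtllm_*_block_scale_routed_moe).
        return "routed"
    # Anything else falls back to the default (triton) MoE runner.
    return "default"


assert select_moe_backend("flashinfer_trtllm") == "fused"
assert select_moe_backend("flashinfer_trtllm_routed") == "routed"
assert select_moe_backend("triton") == "default"
```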

@trevor-m trevor-m force-pushed the trtllm-moe-routed branch from 0c705ef to 2ba0522 Compare March 17, 2026 17:44
@trevor-m trevor-m force-pushed the trtllm-moe-routed branch 4 times, most recently from 75eab54 to 619fcb8 Compare March 20, 2026 20:11
@trevor-m trevor-m changed the title [NVIDIA] Modular flashinfer trtllm fp8 moe improvements, +enable for MiniMax-M2.5 [NVIDIA] Enable flashinfer_trtllm_routed MoE for MiniMax-M2.5 Mar 20, 2026
@trevor-m trevor-m changed the title [NVIDIA] Enable flashinfer_trtllm_routed MoE for MiniMax-M2.5 [NVIDIA] Enable fp8 flashinfer_trtllm_routed MoE for MiniMax-M2.5 Mar 20, 2026
@trevor-m trevor-m force-pushed the trtllm-moe-routed branch from 619fcb8 to 3daf5fe Compare March 23, 2026 23:49
Fix circular import

Fix

skip autotune

fix routing method

cast back to fp16

fixes

fixes

comments

restore routed flag

fixes

fixes
@trevor-m trevor-m force-pushed the trtllm-moe-routed branch from 3daf5fe to ef30d42 Compare March 23, 2026 23:51

Labels

deepseek; documentation (Improvements or additions to documentation); quant (LLM Quantization)
