
[NVIDIA] Add flashinfer all-to-all MOE dispatcher#14668

Merged
Fridge003 merged 15 commits into sgl-project:main from trevor-m:a2a on Jan 24, 2026

Conversation

trevor-m (Collaborator) commented Dec 8, 2025

This PR was originally a draft because flashinfer-ai/flashinfer#2102 had not yet been merged into flashinfer; that PR has since been merged.

Motivation

This PR integrates the latest TRT-LLM MoE all-to-all kernels into sglang (also known as the NVLink one-sided all-to-all, or MNNVL throughput all-to-all):

NVLINK one-sided comm AllToAll strategy for throughput scenarios.

This implementation uses symmetric memory to enable peer-to-peer access between GPUs over NVLink.
The kernels take the role of only one side of the communication: the dispatch kernel puts data
into peer ranks' symmetric memory from the local buffer, while the combine kernel gets data from peer
ranks' symmetric memory and reduces it into the local buffer. It is currently the most efficient
implementation, but it requires symmetric memory proportional to max_num_tokens * n_ranks, which may
not scale well for very large-scale parallelization.
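The sizing constraint above can be made concrete with a small back-of-the-envelope helper. This is a hypothetical sketch; the function name and layout are illustrative, not the actual flashinfer/TRT-LLM API:

```python
# Hypothetical sketch of the symmetric-memory sizing described above.
# Each rank must reserve room for the worst case where every peer rank
# dispatches all of its max_num_tokens to this rank.
def symmetric_workspace_bytes(max_num_tokens: int, n_ranks: int,
                              hidden_size: int, bytes_per_elem: float) -> int:
    return int(max_num_tokens * n_ranks * hidden_size * bytes_per_elem)

# e.g. 512 tokens per rank, 4 ranks, hidden=7168, FP4 payload (0.5 byte/elem):
print(symmetric_workspace_bytes(512, 4, 7168, 0.5))  # 7340032 bytes (~7 MiB)
```

The linear factor of `n_ranks` is why the quoted description warns about very large-scale parallelization: doubling the rank count doubles each rank's symmetric buffer.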

Currently I have tested it with the flashinfer_cutlass MoE runner backend, using FP4 quantization before communication. It also allows flashinfer_cutlass MoE to write directly into the workspace buffer.

Remaining issues:

  1. Improve max_num_tokens handling: the per-rank value should be multiplied by ep_size, or calculated automatically from the server args.
  2. flashinfer a2a does not yet support ranks with 0 tokens, so we have to allocate a dummy token in those cases. That tensor allocation prevents us from using streams to overlap with the shared experts.
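The zero-token workaround can be illustrated with a toy padding helper. This is a plain-Python sketch under stated assumptions; the real code operates on CUDA tensors, and the helper name is hypothetical:

```python
# Hypothetical sketch of the zero-token workaround described above.
def pad_empty_rank(tokens: list, hidden_size: int):
    """If this rank has no tokens to dispatch, insert one dummy token so
    the flashinfer a2a kernel always sees a non-empty input. The caller
    drops the dummy row after combine using the returned flag."""
    if tokens:
        return tokens, False
    return [[0.0] * hidden_size], True
```

Because the dummy row is allocated on the fly, this allocation sits on the critical path, which is what blocks the stream-based overlap with shared experts mentioned above.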

Modifications

  • Add --moe-a2a-backend=flashinfer

Accuracy Tests

SGLANG_MOE_NVFP4_DISPATCH=1 python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4-v2 --trust-remote-code --quantization modelopt_fp4 --tp 4 --moe-runner-backend flashinfer_cutlass --ep-size 4 --dp 4 --enable-dp-attention --mem-fraction-static 0.85 --max-running-requests 2048 --stream-interval 5 --enable-dp-lm-head --attention-backend triton --cuda-graph-bs 1 2 4 8 16 32 64 128 256 512 --disable-radix-cache --moe-a2a-backend flashinfer
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 2048 --random-input 1024 --random-output 1024 --random-range-ratio 1 --max-concurrency 2048
python3 benchmark/gsm8k/bench_sglang.py --num-shots 8 --num-questions 1319 --parallel 1319 --port=30000
Accuracy: 0.955
Invalid: 0.000
Latency: 166.179 s
Output throughput: 874.407 token/s

Benchmarking and Profiling

Single-node results are below; multinode benchmarking is in progress.

Single node 4xGB200 results and profiles

SGLANG_MOE_NVFP4_DISPATCH=1 python3 -m sglang.launch_server --model-path nvidia/DeepSeek-R1-0528-FP4-v2 --trust-remote-code --quantization modelopt_fp4 --tp 4 --moe-runner-backend flashinfer_cutlass --ep-size 4 --dp 4 --enable-dp-attention --mem-fraction-static 0.85 --max-running-requests 2048 --stream-interval 5 --enable-dp-lm-head --cuda-graph-bs 1 2 4 8 16 32 64 128 256 512 --disable-radix-cache --moe-a2a-backend flashinfer --disable-shared-experts-fusion
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 2048 --random-input 1024 --random-output 1024 --random-range-ratio 1 --max-concurrency 2048
============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 2048      
Successful requests:                     2048      
Benchmark duration (s):                  105.78    
Total input tokens:                      2097152   
Total input text tokens:                 2097152   
Total input vision tokens:               0         
Total generated tokens:                  2097152   
Total generated tokens (retokenized):    2093741   
Request throughput (req/s):              19.36     
Input token throughput (tok/s):          19826.37  
Output token throughput (tok/s):         19826.37  
Peak output token throughput (tok/s):    41273.00  
Peak concurrent requests:                2048      
Total token throughput (tok/s):          39652.75  
Concurrency:                             2033.67   
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   105036.00 
Median E2E Latency (ms):                 105036.65 
---------------Time to First Token----------------
Mean TTFT (ms):                          19192.10  
Median TTFT (ms):                        18945.35  
P99 TTFT (ms):                           35753.33  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          83.91     
Median TPOT (ms):                        84.31     
P99 TPOT (ms):                           96.69     
---------------Inter-Token Latency----------------
Mean ITL (ms):                           83.91     
Median ITL (ms):                         64.86     
P95 ITL (ms):                            134.52    
P99 ITL (ms):                            187.70    
Max ITL (ms):                            6041.95   
==================================================

With SGLANG_MOE_NVFP4_DISPATCH=0 (BF16 dispatch)
============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 2048      
Successful requests:                     2048      
Benchmark duration (s):                  113.61    
Total input tokens:                      2097152   
Total input text tokens:                 2097152   
Total input vision tokens:               0         
Total generated tokens:                  2097152   
Total generated tokens (retokenized):    2094097   
Request throughput (req/s):              18.03     
Input token throughput (tok/s):          18459.37  
Output token throughput (tok/s):         18459.37  
Peak output token throughput (tok/s):    39440.00  
Peak concurrent requests:                2048      
Total token throughput (tok/s):          36918.74  
Concurrency:                             2032.17   
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   112730.76 
Median E2E Latency (ms):                 112745.83 
---------------Time to First Token----------------
Mean TTFT (ms):                          22471.12  
Median TTFT (ms):                        22192.53  
P99 TTFT (ms):                           42519.44  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          88.23     
Median TPOT (ms):                        88.62     
P99 TPOT (ms):                           104.26    
---------------Inter-Token Latency----------------
Mean ITL (ms):                           88.23     
Median ITL (ms):                         66.35     
P95 ITL (ms):                            139.80    
P99 ITL (ms):                            192.80    
Max ITL (ms):                            7388.84   
==================================================

With --moe-a2a-backend=none
============ Serving Benchmark Result ============
Backend:                                 sglang    
Traffic request rate:                    inf       
Max request concurrency:                 2048      
Successful requests:                     2048      
Benchmark duration (s):                  107.98    
Total input tokens:                      2097152   
Total input text tokens:                 2097152   
Total input vision tokens:               0         
Total generated tokens:                  2097152   
Total generated tokens (retokenized):    2093493   
Request throughput (req/s):              18.97     
Input token throughput (tok/s):          19420.79  
Output token throughput (tok/s):         19420.79  
Peak output token throughput (tok/s):    43370.00  
Peak concurrent requests:                2048      
Total token throughput (tok/s):          38841.58  
Concurrency:                             2036.25   
----------------End-to-End Latency----------------
Mean E2E Latency (ms):                   107365.61 
Median E2E Latency (ms):                 107376.61 
---------------Time to First Token----------------
Mean TTFT (ms):                          19164.99  
Median TTFT (ms):                        18753.23  
P99 TTFT (ms):                           35803.22  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          86.22     
Median TPOT (ms):                        86.69     
P99 TPOT (ms):                           98.54     
---------------Inter-Token Latency----------------
Mean ITL (ms):                           86.22     
Median ITL (ms):                         66.88     
P95 ITL (ms):                            149.64    
P99 ITL (ms):                            224.95    
Max ITL (ms):                            5954.79   
==================================================

Dispatch (bs=512) with --moe-a2a-backend=flashinfer

(profile screenshot)

Dispatch (bs=512) with --moe-a2a-backend=none (FP4 allgather)

(profile screenshot)

Combine (bs=512) with --moe-a2a-backend=flashinfer

(profile screenshot)

Combine (bs=512) with --moe-a2a-backend=none (reduce-scatter)

(profile screenshot)



github-actions bot added the quant (LLM Quantization) and deepseek labels on Dec 8, 2025
ch-wan (Collaborator) left a comment

@trevor-m trevor-m changed the title Draft: Add flashinfer all-to-all MOE dispatcher [NVIDIA] Add flashinfer all-to-all MOE dispatcher Dec 15, 2025
trevor-m force-pushed the a2a branch 2 times, most recently from eafb94e to 61a262e on December 17, 2025
fzyzcjy (Collaborator) commented Dec 18, 2025

Hi, I am wondering whether this will be helpful for flashinfer_trtllm moe, and whether there will be support for it? Thanks!

trevor-m (Collaborator, Author) replied, quoting the above:
Hi, I am wondering whether this will be helpful for flashinfer_trtllm moe, and whether there will be support for it? Thanks!

@fzyzcjy It looks like it should work based on NVIDIA/TensorRT-LLM@e4bf29b
I can give it a try

ch-wan (Collaborator) left a comment

@Fridge003 Fridge003 mentioned this pull request Dec 21, 2025
6 tasks
Fridge003 (Collaborator) commented:

/tag-and-rerun-ci

@trevor-m trevor-m enabled auto-merge (squash) January 6, 2026 23:25
@trevor-m trevor-m disabled auto-merge January 7, 2026 00:46
fzyzcjy (Collaborator) commented Jan 12, 2026

@trevor-m: I can give it a try

Hi, are there any updates on that? If I understand correctly, trtllm moe should be used for decode, so this feature is most useful when combined with it.

@@ -0,0 +1,322 @@
import unittest
A collaborator commented on the diff:
Can we move this test to test/srt/ep and register it in the nightly tests? This can be done in a follow-up PR.

trevor-m (Collaborator, Author) commented Jan 22, 2026

@trevor-m: I can give it a try

Hi, is there any updates about that? if I understand correctly trtllm moe should be used for decode, and thus this feature is most useful when combined with that

@fzyzcjy I tried this out. The problem is that flashinfer currently doesn't support running trtllm moe without fused routing. For all-to-all, we need to do the routing separately first, then do the communication, then run the MoE. Here is my WIP branch with the sglang changes that would otherwise allow it: https://github.com/trevor-m/sglang/tree/trtllm-a2a-wip

Edit: I found that flashinfer has a separate API for this, trtllm_fp4_block_scale_routed_moe. I will try it out.

I looked at our WideEP configs, and it turns out flashinfer_cutedsl moe is what we actually use for decode (flashinfer_cutlass is used for prefill). Let me see if that can be enabled with this all-to-all.
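The route-then-communicate-then-compute ordering described above can be sketched as a toy single-process simulation. All names and the one-expert-per-rank layout here are illustrative, not the flashinfer API:

```python
# Toy simulation of routed all-to-all: route tokens to expert ranks,
# dispatch into per-rank buffers, run the "expert", then combine results
# back into the original token order.
def route(tokens, n_ranks):
    # Illustrative routing rule: token value modulo rank count.
    return [t % n_ranks for t in tokens]

def dispatch(tokens, dests, n_ranks):
    # One-sided "put": each token lands in its destination rank's buffer,
    # tagged with its original position for the combine step.
    bufs = [[] for _ in range(n_ranks)]
    for i, (t, d) in enumerate(zip(tokens, dests)):
        bufs[d].append((i, t))
    return bufs

def combine(expert_bufs, n_tokens):
    # Scatter expert outputs back to their original token positions.
    out = [0] * n_tokens
    for buf in expert_bufs:
        for i, v in buf:
            out[i] = v
    return out

tokens = [7, 8, 9, 10]
dests = route(tokens, 2)
bufs = dispatch(tokens, dests, 2)
expert_bufs = [[(i, t * 2) for i, t in buf] for buf in bufs]  # each "expert" doubles
print(combine(expert_bufs, len(tokens)))  # [14, 16, 18, 20]
```

The point of the sketch is that routing decisions must exist before dispatch can bucket tokens by destination rank, which is why a fused-routing MoE kernel cannot be dropped in unchanged.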

@Fridge003 Fridge003 merged commit 2c2c4e4 into sgl-project:main Jan 24, 2026
279 of 292 checks passed
Johnsonms pushed a commit to Johnsonms/sglang that referenced this pull request Feb 14, 2026

Labels

deepseek, documentation (Improvements or additions to documentation), Grace Blackwell, nvidia, quant (LLM Quantization), run-ci

6 participants