Skip to content

Support mnnvl all2allv from Flashinfer#21003

Merged
mgoin merged 9 commits intovllm-project:mainfrom
wenscarl:dev_a2a
Sep 24, 2025
Merged

Support mnnvl all2allv from Flashinfer#21003
mgoin merged 9 commits intovllm-project:mainfrom
wenscarl:dev_a2a

Conversation

@wenscarl
Copy link
Contributor

@wenscarl wenscarl commented Jul 15, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Needs flashinfer-ai/flashinfer#1245

Purpose

Test Plan

VLLM_ALL2ALL_BACKEND="flashinfer_all2allv" \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_FLASHINFER_MOE_BACKEND="throughput" \
  /home/shuw/.local/bin/vllm serve nvidia/DeepSeek-R1-FP4 \
    --quantization="modelopt_fp4" \
    --trust-remote-code \
    --max-model-len=2048 \
    --block-size=128 \
    --max-num-seqs=256 \
    --enable-expert-parallel \
    --gpu_memory_utilization=0.8 \
    --tensor-parallel-size 1 \
    --data-parallel-size 4
    
python benchmarks/benchmark_serving.py \
  --model nvidia/DeepSeek-R1-FP4 \
  --dataset-name random \
  --ignore-eos \
  --num-prompts 256 \
  --max-concurrency 256 \
  --random-input-len 128 \
  --random-output-len 1024 

vs.
VLLM_ALL2ALL_BACKEND="naive" \

...

Test Result

accuracy:

VLLM_USE_FLASHINFER_MOE_FP4=1 VLLM_FLASHINFER_MOE_BACKEND="throughput" VLLM_WORKER_MULTIPROC_METHOD=spawn VLLM_ALL2ALL_BACKEND="flashinfer_all2allv" \
lm_eval --model vllm --model_args pretrained=nvidia/DeepSeek-R1-FP4,quantization=modelopt_fp4,data_parallel_size=4,enable_expert_parallel=True,tensor_parallel_size=1,enforce_eager=True,max_model_len=2048 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9371|±  |0.0067|
|     |       |strict-match    |     5|exact_match|↑  |0.9340|±  |0.0068|

perf:
Alltoallv(this PR):

Successful requests:                     256       
Benchmark duration (s):                  111.43    
Total input tokens:                      32512     
Total generated tokens:                  262144    
Request throughput (req/s):              2.30      
Output token throughput (tok/s):         2352.63   
Total Token throughput (tok/s):          2644.42   
---------------Time to First Token----------------
Mean TTFT (ms):                          9734.68   
Median TTFT (ms):                        9769.39   
P99 TTFT (ms):                           9794.07   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          99.37     
Median TPOT (ms):                        99.34     
P99 TPOT (ms):                           99.35     
---------------Inter-token Latency----------------
Mean ITL (ms):                           99.37     
Median ITL (ms):                         90.51     
P99 ITL (ms):                            143.30    

allgather-reducescatter

Successful requests:                     256       
Benchmark duration (s):                  134.83    
Total input tokens:                      32512     
Total generated tokens:                  262144    
Request throughput (req/s):              1.90      
Output token throughput (tok/s):         1944.28   
Total Token throughput (tok/s):          2185.42   
---------------Time to First Token----------------
Mean TTFT (ms):                          9614.91   
Median TTFT (ms):                        9648.26   
P99 TTFT (ms):                           9673.86   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          122.36    
Median TPOT (ms):                        122.33    
P99 TPOT (ms):                           122.34    
---------------Inter-token Latency----------------
Mean ITL (ms):                           122.36    
Median ITL (ms):                         122.26    
P99 ITL (ms):                            129.50 

(Optional) Documentation Update

@github-actions
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify
Copy link

mergify bot commented Jul 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 15, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for Flashinfer's mnnvl all2allv for Mixture-of-Experts (MoE) layers, which is a significant performance enhancement for distributed inference. The changes are comprehensive, touching custom ops, distributed communicators, the MoE layer implementation, and quantization methods.

The core of the change is the new FlashInferAllToAllManager and its integration into the MoE forward pass. The review focuses on potential issues like hardcoded values, code duplication, and correctness of the communication logic to ensure the new feature is robust and maintainable.

Comment on lines +300 to +301
gpus_per_node: int = 4, #TODO(shuw): remove hardcode
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The gpus_per_node parameter is hardcoded. This limits the flexibility of the implementation for different hardware configurations. Consider making this configurable.

Comment on lines +117 to +119
print("xxxx"*100)
print(all2all_manager)
print(f"ep_size:{self.ep_size}, {self.ep_rank}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

These print statements appear to be for debugging purposes and should be removed before merging.

Comment on lines +117 to +133
print("xxxx"*100)
print(all2all_manager)
print(f"ep_size:{self.ep_size}, {self.ep_rank}")
assert all2all_manager is not None
# TODO(shuw): need to consider chunking for global_num_tokens_cpu
x1, topk_ids1, topk_weights1, alltoall_info = all2all_manager.dispatch(
get_dp_group().device_communicator,
global_num_tokens_cpu,
a1,
topk_ids,
topk_weights,
top_k,
num_experts,
self.ep_rank,
self.ep_size,
)
self.alltoall_info = alltoall_info
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

It appears that the all2all_info variable is assigned a value but never used. If this is the case, the assignment should be removed to avoid confusion.

Comment on lines +144 to +149
if enable_flashinfer_fp4_allgather:
topk_weights, topk_ids, a1q, a1q_scale = \
get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale],
dim=0,
sizes=get_local_sizes(local_tokens))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The code block if enable_flashinfer_fp4_allgather: seems to perform a redundant communication. The all2all_manager.dispatch call on line 122 already performs a gather operation. This is then followed by another get_dp_group().all_gatherv here. This appears to be redundant and could impact performance. Please verify if both are necessary. If not, the redundant call should be removed.

Comment on lines +150 to +157
# if enable_flashinfer_alltoall:
# print("all2allcalling"*100)
# a1q = MnnvlMoe.mnnvl_moe_alltoallv(a1q, self.alltoall_info,
# self.alltoall_workspace,
# self.ep_rank, self.ep_size)
# a1q_scale = MnnvlMoe.mnnvl_moe_alltoallv(
# a1q_scale, alltoall_info, self.alltoall_workspace,
# self.ep_rank, self.ep_size)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There are large blocks of commented-out code related to mnnvl_moe_alltoallv. This code should be either implemented or removed to avoid confusion and keep the codebase clean.

@mergify
Copy link

mergify bot commented Aug 29, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify
Copy link

mergify bot commented Sep 4, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 4, 2025
@tlrmchlsmth tlrmchlsmth self-assigned this Sep 9, 2025
@mergify mergify bot removed the needs-rebase label Sep 9, 2025
Signed-off-by: Shu Wang <shuw@nvidia.com>
@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 23, 2025
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) September 23, 2025 16:00
@mergify
Copy link

mergify bot commented Sep 23, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 23, 2025
@mgoin
Copy link
Member

mgoin commented Sep 23, 2025

@wenscarl can you fix the latest merge conflict?

Signed-off-by: Shu Wang <shuw@nvidia.com>
auto-merge was automatically disabled September 24, 2025 16:05

Head branch was pushed to by a user without write access

@mergify mergify bot removed the needs-rebase label Sep 24, 2025
@mgoin mgoin merged commit 54e42b7 into vllm-project:main Sep 24, 2025
56 checks passed
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
rtourgeman pushed a commit to rtourgeman/vllm that referenced this pull request Nov 10, 2025
Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants