
[Kernel] Add fused grouped_topk kernel for MoE#23274

Merged
simon-mo merged 10 commits into vllm-project:main from xyang16:kernel
Aug 25, 2025

Conversation

@xyang16
Contributor

@xyang16 xyang16 commented Aug 20, 2025

Purpose

This PR adds a fused grouped_topk kernel for MoE.

  • grouped_topk_kernels.cu: grouped topk kernel
    • Added a renormalize argument instead of always renormalizing
  • Introduced a routed_scaling_factor parameter to both grouped_topk and fused_grouped_topk, which is needed for the grouped topk computation.
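For context, the (unfused) grouped top-k routing that this kernel fuses can be sketched in plain PyTorch. This is an illustrative reference of the algorithm, not the fused CUDA kernel, and the function name and argument layout here are hypothetical:

```python
import torch


def grouped_topk_reference(scores: torch.Tensor, n_group: int, topk_group: int,
                           topk: int, renormalize: bool = True,
                           routed_scaling_factor: float = 1.0):
    """Reference grouped top-k: pick the best expert groups per token,
    then the top-k experts within those groups."""
    num_tokens, num_experts = scores.shape
    # Score each group by its best expert, then keep only topk_group groups.
    group_scores = scores.view(num_tokens, n_group, -1).max(dim=-1).values
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    # Expand the group mask back to per-expert granularity.
    score_mask = group_mask.unsqueeze(-1).expand(
        num_tokens, n_group, num_experts // n_group).reshape(num_tokens, -1)
    masked_scores = scores.masked_fill(score_mask == 0, float("-inf"))
    # Top-k experts restricted to the surviving groups.
    topk_weights, topk_ids = torch.topk(masked_scores, k=topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights * routed_scaling_factor, topk_ids
```

The fused kernel performs these steps (group scoring, masking, expert top-k, renormalization, and scaling) in a single launch instead of the several kernel launches this eager version implies.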

Test Plan

Added unit tests.

pytest -s -v tests/kernels/moe/test_grouped_topk.py

Test Result

Unit tests passed.

Accuracy Testing

lm_eval --model local-completions \
  --model_args model=deepseek-ai/DeepSeek-R1,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=8 \
  --tasks gsm8k

Baseline:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9538|±  |0.0058|
|     |       |strict-match    |     5|exact_match|↑  |0.9538|±  |0.0058|

This PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.956|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.956|±  |0.0056|

Benchmarking

Serve on a p5e.48xlarge (8 H200) instance:

vllm serve deepseek-ai/DeepSeek-R1 \
    --tensor-parallel-size 8 \
    --pipeline-parallel-size 1 \
    --max-model-len 16384 \
    --max-num-seqs 8 \
    --trust-remote-code \
    --compilation_config '{"compile_sizes": [1, 2, 4, 8]}'

Benchmark cmd:

python3 benchmarks/benchmark_serving.py \
  --model deepseek-ai/DeepSeek-R1 \
  --dataset-name sharegpt \
  --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
  --max-concurrency 8

Results show around a 9.7% improvement in output token throughput.

Baseline:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  610.08    
Total input tokens:                      219171    
Total generated tokens:                  173884    
Request throughput (req/s):              1.64      
Output token throughput (tok/s):         285.02    
Total Token throughput (tok/s):          644.27    
---------------Time to First Token----------------
Mean TTFT (ms):                          147.14    
Median TTFT (ms):                        151.12    
P99 TTFT (ms):                           256.96    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          26.93     
Median TPOT (ms):                        26.85     
P99 TPOT (ms):                           36.60     
---------------Inter-token Latency----------------
Mean ITL (ms):                           27.22     
Median ITL (ms):                         23.49     
P99 ITL (ms):                            129.81    
==================================================

This PR:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  540.57    
Total input tokens:                      219171    
Total generated tokens:                  169096    
Request throughput (req/s):              1.85      
Output token throughput (tok/s):         312.81    
Total Token throughput (tok/s):          718.26    
---------------Time to First Token----------------
Mean TTFT (ms):                          281563.90 
Median TTFT (ms):                        291295.61 
P99 TTFT (ms):                           529872.79 
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          24.52     
Median TPOT (ms):                        24.07     
P99 TPOT (ms):                           40.78     
---------------Inter-token Latency----------------
Mean ITL (ms):                           24.77     
Median ITL (ms):                         20.35     
P99 ITL (ms):                            131.90    
==================================================

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added ci/build rocm Related to AMD ROCm labels Aug 20, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a fused grouped_topk kernel for MoE, which shows a nice performance improvement based on the benchmarks. The implementation looks solid and the integration across the codebase is well-handled. I've found a critical issue related to const correctness in the new CUDA kernel that should be addressed. Otherwise, great work!

@xyang16 xyang16 force-pushed the kernel branch 6 times, most recently from 48cae44 to 8d77bdd, on August 20, 2025 19:02
Member

@yewentao256 yewentao256 left a comment


Thanks for the work!
Could you also test the acc for R1?
lm_eval --model local-completions --model_args "base_url=http://127.0.0.1:9256/v1/completions,model=deepseek-ai/DeepSeek-R1-0528,num_concurrent=256" --tasks gsm8k

Member


What is this hardcoded value?

Contributor Author

@xyang16 xyang16 Aug 22, 2025


Thanks for your review!

In my other PR #23123, I have added the routed_scaling_factor to grouped_topk, which is coming from config.routed_scaling_factor. This PR is dependent on that PR.

If that PR (#23123) can get merged first, I will change this hardcoded value to the routed_scaling_factor passed in.

Comment on lines 954 to 958
Member


What decides this - is it a heuristic or what is implemented?

Contributor Author

@xyang16 xyang16 Aug 22, 2025


Thanks for your review!

This is based on what is implemented, see in grouped_topk_kernels.cu:

  TORCH_CHECK(n_group <= 32,
              "n_group should be smaller than or equal to 32 for now");
  TORCH_CHECK(topk <= 32, "topk should be smaller than or equal to 32 for now");

@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 22, 2025
@xyang16 xyang16 force-pushed the kernel branch 2 times, most recently from 0333a29 to 10895ce, on August 22, 2025 06:49
@xyang16
Contributor Author

xyang16 commented Aug 22, 2025

Thanks for the work! Could you also test the acc for R1? lm_eval --model local-completions --model_args "base_url=http://127.0.0.1:9256/v1/completions,model=deepseek-ai/DeepSeek-R1-0528,num_concurrent=256" --tasks gsm8k

@yewentao256 Thanks for the review!

I have run the accuracy testing. There's no accuracy drop:

lm_eval --model local-completions \
  --model_args model=deepseek-ai/DeepSeek-R1,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=8 \
  --tasks gsm8k

Baseline:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9538|±  |0.0058|
|     |       |strict-match    |     5|exact_match|↑  |0.9538|±  |0.0058|

This PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.956|±  |0.0056|
|     |       |strict-match    |     5|exact_match|↑  |0.956|±  |0.0056|

Member

@mgoin mgoin left a comment


Hey @xyang16 I think everything checks out to me, thanks for working on this!

One thing we must fix before landing: I think this kernel does not work on AMD and broke that build https://buildkite.com/vllm/ci/builds/28032/steps/canvas?jid=0198d277-447a-409d-857d-047cd38164d2#0198d277-447a-409d-857d-047cd38164d2/30-2862

#12 85.24 FAILED: [code=1] CMakeFiles/_moe_C.dir/csrc/moe/grouped_topk_kernels.hip.o
#12 85.24 /opt/rocm/lib/llvm/bin/clang++  -DPy_LIMITED_API=3 -DTORCH_EXTENSION_NAME=_moe_C -DUSE_C10D_GLOO -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_PROF_API=1 -DUSE_RPC -DUSE_TENSORPIPE -D__HIP_PLATFORM_AMD__ -D__HIP_PLATFORM_AMD__=1 -D__HIP_ROCclr__=1 -D_moe_C_EXPORTS -I/app/vllm/build/temp.linux-x86_64-cpython-312/csrc -isystem /usr/include/python3.12 -isystem /usr/local/lib/python3.12/dist-packages/torch/include -isystem /usr/local/lib/python3.12/dist-packages/torch/include/torch/csrc/api/include -isystem /opt/rocm-6.4.1/include/hiprand -isystem /opt/rocm-6.4.1/include/rocrand -Wno-unused-result -O2 -g -DNDEBUG --offload-arch=gfx90a --offload-arch=gfx942 -fPIC -fPIC -D__HIP_PLATFORM_AMD__=1 -DUSE_ROCM=1 -DHIPBLAS_V2 -DCUDA_HAS_FP16=1 -D__HIP_NO_HALF_OPERATORS__=1 -D__HIP_NO_HALF_CONVERSIONS__=1 -DHIP_ENABLE_WARP_SYNC_BUILTINS=1 -DUSE_ROCM -DENABLE_FP8 -U__HIP_NO_HALF_CONVERSIONS__ -U__HIP_NO_HALF_OPERATORS__ -Werror=unused-variable -fno-gpu-rdc -D_GLIBCXX_USE_CXX11_ABI=1 -DTORCH_HIP_VERSION=604 -Wno-shift-count-negative -Wno-shift-count-overflow -Wno-duplicate-decl-specifier -DCAFFE2_USE_MIOPEN -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP -std=c++17 -DHIPBLASLT_VEC_EXT -DHIP_ENABLE_WARP_SYNC_BUILTINS -MD -MT CMakeFiles/_moe_C.dir/csrc/moe/grouped_topk_kernels.hip.o -MF CMakeFiles/_moe_C.dir/csrc/moe/grouped_topk_kernels.hip.o.d -o CMakeFiles/_moe_C.dir/csrc/moe/grouped_topk_kernels.hip.o -x hip -c /app/vllm/build/temp.linux-x86_64-cpython-312/csrc/moe/grouped_topk_kernels.hip
#12 85.24 /app/vllm/build/temp.linux-x86_64-cpython-312/csrc/moe/grouped_topk_kernels.hip:27:10: fatal error: 'cooperative_groups/reduce.h' file not found
#12 85.24    27 | #include <cooperative_groups/reduce.h>
#12 85.24       |          ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#12 85.24 1 error generated when compiling for gfx90a.

CMakeLists.txt Outdated
Member


Move this kernel to the CUDA section below

Contributor Author


Good catch! I have moved this. Thanks.

Member


nit/future work: maybe we should default routed_scaling_factor to None so we can skip the multiply if it isn't set
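A minimal sketch of that suggestion, with a hypothetical helper name: default the factor to None and perform the multiply only when a factor is actually configured.

```python
from typing import Optional

import torch


def apply_routed_scaling(topk_weights: torch.Tensor,
                         routed_scaling_factor: Optional[float] = None
                         ) -> torch.Tensor:
    # Skip the extra elementwise multiply when no scaling is requested.
    if routed_scaling_factor is None:
        return topk_weights
    return topk_weights * routed_scaling_factor
```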

Comment on lines 956 to 957
Member


This should include current_platform.is_cuda()?

Contributor Author


I have included current_platform.is_cuda(). Thanks.

Member


When is enable_fused set? I don't see any changes other than the test. Do we need this?

Contributor Author


Yes, it's for testing.

Contributor Author

@xyang16 xyang16 Aug 22, 2025


Initially I wanted to keep grouped_topk and fused_grouped_topk separate, but ROCm calls only the single grouped_topk (https://github.com/vllm-project/vllm/blob/v0.10.1.1/vllm/model_executor/layers/fused_moe/layer.py#L55-L61), so I put fused_grouped_topk inside grouped_topk to keep the change minimal.

Thus I need a flag to distinguish grouped_topk from fused_grouped_topk in the test.

Member


If it is just for testing, I would personally just use an environment variable. We have quite a few TEST env vars

"VLLM_TEST_FORCE_FP8_MARLIN":

Contributor Author


I have added the VLLM_USE_FUSED_MOE_GROUPED_TOPK env var. Thanks.
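As a sketch, such a test-only switch could gate the fused path like this; the parsing details here are illustrative and not necessarily how vllm.envs implements it:

```python
import os


def use_fused_grouped_topk() -> bool:
    # Test-only escape hatch: opt into the fused kernel path explicitly.
    value = os.environ.get("VLLM_USE_FUSED_MOE_GROUPED_TOPK", "0")
    return value.lower() in ("1", "true")
```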

Comment on lines 9 to 12
Member


Move to the #ifndef USE_ROCM section

Contributor Author


Moved. Thanks.

Comment on lines 11 to 16
Member


Move to the #ifndef USE_ROCM section

Contributor Author


Moved. Thanks.

Member

@yewentao256 yewentao256 left a comment


Nice work! Some additional thoughts

@mgoin
Member

mgoin commented Aug 22, 2025

@yewentao256 I don't think it is worthwhile to review/change the implementation copied from trtllm. Usually when we copy vendor code, we tend to leave it as-is so if we need to resync changes it is easy. From that perspective, I basically ignored that section of code

Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: Xin Yang <xyangx@amazon.com>
Signed-off-by: Xin Yang <xyangx@amazon.com>
@xyang16
Contributor Author

xyang16 commented Aug 23, 2025

@yewentao256 I don't think it is worthwhile to review/change the implementation copied from trtllm. Usually when we copy vendor code, we tend to leave it as-is so if we need to resync changes it is easy. From that perspective, I basically ignored that section of code

Makes sense, I thought this was implemented from scratch.

@xyang16 Could you add a description in the PR saying this was synced from TRTLLM? And also add comments in the .cu file as well.

@yewentao256 I have added the description in the PR. Thanks.

@xyang16
Contributor Author

xyang16 commented Aug 25, 2025

FYI the CI failed, but I think it's unrelated to this PR.

Member

@yewentao256 yewentao256 left a comment


LGTM, thanks for the work!

Member

@mgoin mgoin left a comment


LGTM, thanks for the iterations here. Will carefully look over CI to see if all failures are known.

@simon-mo simon-mo merged commit 8a3cd90 into vllm-project:main Aug 25, 2025
65 of 71 checks passed
@xyang16
Contributor Author

xyang16 commented Aug 25, 2025

@mgoin @yewentao256 Thanks a lot! Could you please also review #23123? This is to read routed_scaling_factor from model config and pass to grouped_topk.

@huydhn huydhn mentioned this pull request Aug 26, 2025
10 tasks
tc-mb pushed a commit to tc-mb/vllm that referenced this pull request Aug 27, 2025
Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: tc-mb <caitianchi@modelbest.cn>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
@xyang16 xyang16 deleted the kernel branch August 31, 2025 01:54
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Sep 3, 2025
Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: Xin Yang <xyangx@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>

Labels

ci/build ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants