Skip to content

[Perf] Fuse stride preparation for NVFP4 cutlass_moe#31837

Merged
robertgshaw2-redhat merged 1 commit intovllm-project:mainfrom
neuralmagic:fuse-stride-prep-nvfp4-cutlass-moe
Jan 7, 2026
Merged

[Perf] Fuse stride preparation for NVFP4 cutlass_moe#31837
robertgshaw2-redhat merged 1 commit intovllm-project:mainfrom
neuralmagic:fuse-stride-prep-nvfp4-cutlass-moe

Conversation

@mgoin
Copy link
Member

@mgoin mgoin commented Jan 6, 2026

Purpose

When profiling another PR, I noticed that there were always 3 vectorized_elementwise_kernel kernels when calling the ops.cutlass_fp4_moe_mm function and found these torch::full constructors. The kernel __get_group_gemm_starts already runs once per-expert to set up pointers. We can add stride initialization there instead of launching 3 separate torch::full kernels. This improves latency by ~5% and affects throughput by a smaller amount.

Before (based on #31832 already fusing silu_and_mul):
Screenshot 2026-01-06 at 4 00 48 PM

After:
Screenshot 2026-01-06 at 4 37 21 PM

Test Plan

Test Result

Latency Benchmark

vllm serve nvidia/Qwen3-30B-A3B-NVFP4
vllm bench serve --input-len 100 --output-len 100 --num-prompts 8

# MAIN
============ Serving Benchmark Result ============
Successful requests:                     8         
Failed requests:                         0         
Benchmark duration (s):                  0.78      
Total input tokens:                      800       
Total generated tokens:                  800       
Request throughput (req/s):              10.25     
Output token throughput (tok/s):         1025.21   
Peak output token throughput (tok/s):    800.00    
Peak concurrent requests:                8.00      
Total token throughput (tok/s):          2050.42   
---------------Time to First Token----------------
Mean TTFT (ms):                          31.20     
Median TTFT (ms):                        32.42     
P99 TTFT (ms):                           34.46     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.53      
Median TPOT (ms):                        7.53      
P99 TPOT (ms):                           7.53      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.53      
Median ITL (ms):                         7.55      
P99 ITL (ms):                            8.30      
==================================================

# PR
============ Serving Benchmark Result ============
Successful requests:                     8         
Failed requests:                         0         
Benchmark duration (s):                  0.74      
Total input tokens:                      800       
Total generated tokens:                  800       
Request throughput (req/s):              10.79     
Output token throughput (tok/s):         1079.23   
Peak output token throughput (tok/s):    800.00    
Peak concurrent requests:                8.00      
Total token throughput (tok/s):          2158.45   
---------------Time to First Token----------------
Mean TTFT (ms):                          31.98     
Median TTFT (ms):                        32.33     
P99 TTFT (ms):                           33.39     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.14      
Median TPOT (ms):                        7.15      
P99 TPOT (ms):                           7.15      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.14      
Median ITL (ms):                         7.14      
P99 ITL (ms):                            7.85      
==================================================

Throughput Benchmark

vllm serve nvidia/Qwen3-30B-A3B-NVFP4
vllm bench serve --input-len 100 --output-len 100 --num-prompts 512

# MAIN
============ Serving Benchmark Result ============
Successful requests:                     512       
Failed requests:                         0         
Benchmark duration (s):                  2.55      
Total input tokens:                      51200     
Total generated tokens:                  51200     
Request throughput (req/s):              200.45    
Output token throughput (tok/s):         20044.52  
Peak output token throughput (tok/s):    32996.00  
Peak concurrent requests:                512.00    
Total token throughput (tok/s):          40089.04  
---------------Time to First Token----------------
Mean TTFT (ms):                          776.44    
Median TTFT (ms):                        759.85    
P99 TTFT (ms):                           1005.40   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          17.20     
Median TPOT (ms):                        17.63     
P99 TPOT (ms):                           19.88     
---------------Inter-token Latency----------------
Mean ITL (ms):                           17.24     
Median ITL (ms):                         15.43     
P99 ITL (ms):                            100.69    
==================================================

# PR
============ Serving Benchmark Result ============
Successful requests:                     512       
Failed requests:                         0         
Benchmark duration (s):                  2.54      
Total input tokens:                      51200     
Total generated tokens:                  51200     
Request throughput (req/s):              201.89    
Output token throughput (tok/s):         20188.84  
Peak output token throughput (tok/s):    33901.00  
Peak concurrent requests:                512.00    
Total token throughput (tok/s):          40377.68  
---------------Time to First Token----------------
Mean TTFT (ms):                          768.59    
Median TTFT (ms):                        806.81    
P99 TTFT (ms):                           975.17    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          17.12     
Median TPOT (ms):                        16.75     
P99 TPOT (ms):                           19.28     
---------------Inter-token Latency----------------
Mean ITL (ms):                           17.13     
Median ITL (ms):                         15.21     
P99 ITL (ms):                            99.73     
==================================================

Eval

vllm serve nvidia/Qwen3-30B-A3B-NVFP4
python tests/evals/gsm8k/gsm8k_eval.py --port 8000

# MAIN
Results:
Accuracy: 0.883
Invalid responses: 0.000
Total latency: 20.371 s
Questions per second: 64.748
Total output tokens: 153967
Output tokens per second: 7558.026

# PR
Results:
Accuracy: 0.886
Invalid responses: 0.001
Total latency: 19.811 s
Questions per second: 66.578
Total output tokens: 153574
Output tokens per second: 7751.853

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: mgoin <mgoin64@gmail.com>
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a performance optimization for the nvfp4 CUTLASS MoE kernels by fusing the stride tensor initialization. Instead of using torch::full to create and initialize stride tensors, which launches separate CUDA kernels, this change allocates uninitialized tensors with torch::empty and performs the initialization inside the existing __get_group_gemm_starts kernel. This effectively reduces kernel launch overhead, leading to the performance improvements shown in the benchmarks. The changes are applied consistently for both sm100 and sm120 architectures. The implementation is correct and the optimization is a good practice for CUDA programming. I have no further comments.

@mgoin mgoin changed the title [Perf] Fuse stride preparation for nvfp4 cutlass moe [Perf] Fuse stride preparation for NVFP4 cutlass_moe Jan 6, 2026
@mgoin mgoin added performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed labels Jan 6, 2026
Copy link
Collaborator

@pavanimajety pavanimajety left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! We may be able to further reduce these by initializing in post processing AOT because the values remain constant and are deterministic.

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Jan 7, 2026
@robertgshaw2-redhat robertgshaw2-redhat merged commit f347ac6 into vllm-project:main Jan 7, 2026
102 checks passed
@robertgshaw2-redhat
Copy link
Collaborator

nice work

@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Jan 7, 2026
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Jan 9, 2026
akh64bit pushed a commit to akh64bit/vllm that referenced this pull request Jan 16, 2026
dsuhinin pushed a commit to dsuhinin/vllm that referenced this pull request Jan 21, 2026
)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: dsuhinin <suhinin.dmitriy@gmail.com>
ItzDEXX pushed a commit to ItzDEXX/vllm that referenced this pull request Feb 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

nvidia performance Performance-related issues ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

3 participants