
[kernel][perf] support uncontiguous input for rms_norm kernel #28103

Merged
vllm-bot merged 14 commits into vllm-project:main from izhuhaoran:rmsnorm-noncontiguous-input on Nov 21, 2025
Conversation

@izhuhaoran
Contributor

@izhuhaoran izhuhaoran commented Nov 5, 2025

Purpose

Currently, the main branch has a TODO:

vllm/vllm/_custom_ops.py

Lines 331 to 332 in 14a125a

# TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
input_contiguous = input.contiguous()

As titled, this PR adds support for non-contiguous input to the norm kernel and resolves that TODO, which was introduced by #17735.

Previously, the RMS norm kernel required a .contiguous() call because the q/k tensors used in qk-norm are sliced from qkv via .split() and then reshaped to [num_tokens, num_heads, head_dim]. This yields non-contiguous tensors whose first dimension retains qkv's original stride. The original kernel used input.view({-1, hidden_size}), which fails or produces incorrect results for such tensors.
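To see why the split produces non-contiguous tensors, here is a minimal NumPy sketch (illustrative only — vLLM uses PyTorch tensors and the sizes below are toy values, but the stride behavior is the same): slicing q out of the fused qkv buffer and reshaping it are both views, so the per-token stride still spans the full qkv row.

```python
import numpy as np

num_tokens, num_heads, head_dim = 4, 8, 64
q_size = num_heads * head_dim           # 512
kv_size = q_size                        # toy choice: k and v same size as q
# Fused qkv buffer: each row holds [q | k | v]
qkv = np.arange(num_tokens * (q_size + 2 * kv_size), dtype=np.float32)
qkv = qkv.reshape(num_tokens, q_size + 2 * kv_size)

# "Split" q out of qkv and reshape: both operations return views, not copies
q = qkv[:, :q_size].reshape(num_tokens, num_heads, head_dim)

# The per-token stride still spans the full qkv row, so q is non-contiguous
token_stride = q.strides[0] // q.itemsize
print(token_stride, q.flags["C_CONTIGUOUS"])  # 1536 False
```

A plain `q.reshape(-1, hidden_size)`-style flattening cannot represent this layout without a copy, which is exactly why the old kernel forced `.contiguous()`.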

This PR extends the kernel to accept explicit stride information and supports both 2D and 3D non-contiguous inputs (with the last dimension required to be contiguous).
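As a reference for what "explicit strides with a contiguous last dimension" means, the following is a hedged NumPy sketch of a stride-aware RMS norm. It illustrates the indexing contract only — it is not the CUDA kernel, and the helper name and epsilon default are illustrative.

```python
import numpy as np

def rms_norm_strided(x, weight, eps=1e-6):
    """Reference RMS norm over the last dim. Works for non-contiguous
    2D/3D inputs as long as the last dimension itself is contiguous,
    mirroring the constraint described above."""
    assert x.strides[-1] == x.itemsize, "last dim must be contiguous"
    out = np.empty(x.shape, dtype=x.dtype)
    for idx in np.ndindex(x.shape[:-1]):
        # NumPy resolves each row's base offset from x's explicit strides,
        # analogous to what a stride-aware kernel does per thread block
        row = x[idx].astype(np.float64)
        rms = np.sqrt(np.mean(row * row) + eps)
        out[idx] = (row / rms) * weight
    return out

# Non-contiguous 3D input sliced out of a fused qkv-like buffer
buf = np.random.rand(4, 3 * 8 * 16).astype(np.float32)
q = buf[:, : 8 * 16].reshape(4, 8, 16)
w = np.ones(16, dtype=np.float32)
assert not q.flags["C_CONTIGUOUS"]
y = rms_norm_strided(q, w)
```

Because only the leading strides differ from a packed layout, the result matches running the same computation on a contiguous copy of the data.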

BTW, this PR should be merged after #27165

Test Result

Timeline profile trace

  • Main: (profile trace image)
  • This PR: (profile trace image)

bench serve

Setting: qwen3-0.6b, tp1, num_requests=32, max_concurrency=8, in_len=out_len=1024
Result: mean TTFT improves from 86.76 ms to 85.73 ms, mean TPOT from 3.46 ms to 3.31 ms

  • Main
============ Serving Benchmark Result ============
Successful requests:                     32        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  14.49     
Total input tokens:                      32768     
Total generated tokens:                  32768     
Request throughput (req/s):              2.21      
Output token throughput (tok/s):         2261.23   
Peak output token throughput (tok/s):    2472.00   
Peak concurrent requests:                16.00     
Total Token throughput (tok/s):          4522.47   
---------------Time to First Token----------------
Mean TTFT (ms):                          86.76     
Median TTFT (ms):                        95.76     
P99 TTFT (ms):                           100.81    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          3.46      
Median TPOT (ms):                        3.45      
P99 TPOT (ms):                           3.52      
---------------Inter-token Latency----------------
Mean ITL (ms):                           3.46      
Median ITL (ms):                         3.48      
P99 ITL (ms):                            4.00      
----------------End-to-end Latency----------------
Mean E2EL (ms):                          3621.38   
Median E2EL (ms):                        3618.82   
P99 E2EL (ms):                           3634.57   
==================================================
  • This PR
============ Serving Benchmark Result ============
Successful requests:                     32        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  13.91     
Total input tokens:                      32768     
Total generated tokens:                  32768     
Request throughput (req/s):              2.30      
Output token throughput (tok/s):         2356.32   
Peak output token throughput (tok/s):    2528.00   
Peak concurrent requests:                16.00     
Total Token throughput (tok/s):          4712.63   
---------------Time to First Token----------------
Mean TTFT (ms):                          85.73     
Median TTFT (ms):                        93.82     
P99 TTFT (ms):                           97.15     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          3.31      
Median TPOT (ms):                        3.31      
P99 TPOT (ms):                           3.38      
---------------Inter-token Latency----------------
Mean ITL (ms):                           3.31      
Median ITL (ms):                         3.33      
P99 ITL (ms):                            3.87      
----------------End-to-end Latency----------------
Mean E2EL (ms):                          3475.61   
Median E2EL (ms):                        3476.51   
P99 E2EL (ms):                           3485.74   
==================================================

lm_eval

lm_eval --model local-completions --tasks gsm8k --batch_size 128 --model_args model=/mnt/data/nas/zhr/models/Qwen3-0.6B,base_url=http://localhost:8000/v1/completions,max_retries=3

  • Main
local-completions (model=/mnt/data/nas/zhr/models/Qwen3-0.6B,base_url=http://localhost:8000/v1/completions,max_retries=3), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4071|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.4094|±  |0.0135|
  • This PR
local-completions (model=/mnt/data/nas/zhr/models/Qwen3-0.6B,base_url=http://localhost:8000/v1/completions,max_retries=3), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4086|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.4124|±  |0.0136|

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully adds support for non-contiguous inputs to the rms_norm kernel, which resolves a TODO and improves performance by avoiding an explicit .contiguous() call. The changes in the CUDA kernel to handle 2D and 3D tensors with explicit strides are well-implemented. My review includes one suggestion to refactor duplicated code in the C++ dispatcher function to improve maintainability.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully adds support for non-contiguous inputs to the RMS norm kernel, which removes a .contiguous() call and provides a performance improvement. The changes in the CUDA kernel to handle 2D and 3D non-contiguous tensors using explicit strides are well-implemented.

However, I've identified a critical issue where the output tensor out is not guaranteed to be contiguous, which will cause a runtime failure in the C++ kernel. I've left a specific comment with details on how to address this. Once that is fixed, this PR should be in great shape.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully adds support for non-contiguous inputs to the RMS norm kernel, which removes a .contiguous() call and provides a performance improvement as shown in the benchmarks. The changes in the CUDA kernel correctly handle both 2D and 3D non-contiguous tensors by using explicit stride information. The Python and C++ wrapper code is updated accordingly. My feedback includes one suggestion to refactor the C++ code to improve maintainability by reducing code duplication.

@izhuhaoran izhuhaoran changed the title [kernel][perf] support uncontiguous input for norm kernel [kernel][perf] support uncontiguous input for rms_norm kernel Nov 5, 2025
@izhuhaoran
Contributor Author

@ProExpertProg, would you please take a look when you have time?

@izhuhaoran
Contributor Author

@ProExpertProg I think this PR is ready for review. Would you please have a look?

Collaborator

@ProExpertProg ProExpertProg left a comment


Just one nit vis-à-vis dispatching

@ProExpertProg ProExpertProg added the ready ONLY add when PR is ready to merge/full CI is needed label Nov 19, 2025
@ProExpertProg
Collaborator

cc @yewentao256

Member

@yewentao256 yewentao256 left a comment


Could you also test a larger model, e.g. R1, using lm_eval and vllm bench as well?

@izhuhaoran
Contributor Author

izhuhaoran commented Nov 20, 2025

> Could you also test a larger model, e.g. R1, using lm_eval and vllm bench as well?

Actually, could we test Qwen3-235b-fp8 instead? R1-fp8 is too large for my current hardware and would result in an OOM.
Here are the test results of Qwen3-235b-fp8:

bench serve

  • Main
============ Serving Benchmark Result ============
Successful requests:                     32        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  327.34    
Total input tokens:                      32768     
Total generated tokens:                  32768     
Request throughput (req/s):              0.10      
Output token throughput (tok/s):         100.10    
Peak output token throughput (tok/s):    104.00    
Peak concurrent requests:                16.00     
Total Token throughput (tok/s):          200.21    
---------------Time to First Token----------------
Mean TTFT (ms):                          670.19    
Median TTFT (ms):                        730.73    
P99 TTFT (ms):                           746.87    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          79.34     
Median TPOT (ms):                        79.32     
P99 TPOT (ms):                           79.81     
---------------Inter-token Latency----------------
Mean ITL (ms):                           79.34     
Median ITL (ms):                         79.25     
P99 ITL (ms):                            80.00     
----------------End-to-end Latency----------------
Mean E2EL (ms):                          81831.64  
Median E2EL (ms):                        81859.35  
P99 E2EL (ms):                           81888.10  
==================================================
  • This PR
============ Serving Benchmark Result ============
Successful requests:                     32        
Failed requests:                         0         
Maximum request concurrency:             8         
Benchmark duration (s):                  325.45    
Total input tokens:                      32768     
Total generated tokens:                  32768     
Request throughput (req/s):              0.10      
Output token throughput (tok/s):         100.68    
Peak output token throughput (tok/s):    104.00    
Peak concurrent requests:                16.00     
Total Token throughput (tok/s):          201.37    
---------------Time to First Token----------------
Mean TTFT (ms):                          663.54    
Median TTFT (ms):                        723.32    
P99 TTFT (ms):                           739.82    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          78.88     
Median TPOT (ms):                        78.86     
P99 TPOT (ms):                           79.35     
---------------Inter-token Latency----------------
Mean ITL (ms):                           78.88     
Median ITL (ms):                         78.80     
P99 ITL (ms):                            79.56     
----------------End-to-end Latency----------------
Mean E2EL (ms):                          81360.23  
Median E2EL (ms):                        81384.90  
P99 E2EL (ms):                           81407.92  
==================================================

lm_eval

  • Main
local-completions (model=/mnt/data/nas/models/Qwen3-235B-A22B-Thinking-2507-FP8,base_url=http://localhost:8000/v1/completions,max_retries=3), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6732|±  |0.0129|
|     |       |strict-match    |     5|exact_match|↑  |0.6209|±  |0.0134|
  • This PR
local-completions (model=/mnt/data/nas/models/Qwen3-235B-A22B-Thinking-2507-FP8,base_url=http://localhost:8000/v1/completions,max_retries=3), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 128
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.6823|±  |0.0128|
|     |       |strict-match    |     5|exact_match|↑  |0.6217|±  |0.0134|

izhuhaoran and others added 4 commits November 20, 2025 11:07
@izhuhaoran izhuhaoran force-pushed the rmsnorm-noncontiguous-input branch from 270c8cc to 5f39dcd Compare November 20, 2025 10:55
@izhuhaoran
Contributor Author

@yewentao256 I've updated the test results for the larger model and fixed the CI issues; please take a look when you have time. Also cc @ProExpertProg

BTW, there's currently a CI error: "ValueError: No available memory for the cache blocks. Try increasing gpu_memory_utilization when initializing the engine." This appears unrelated to this PR.

Member

@yewentao256 yewentao256 left a comment


LGTM, thanks for the work!

@yewentao256 yewentao256 enabled auto-merge (squash) November 20, 2025 16:54
@izhuhaoran
Contributor Author

@yewentao256 The CI failure is unrelated to this PR. The failing Plamo3 test is also failing on main and should be fixed by #29092.

@vllm-bot vllm-bot merged commit a982f5b into vllm-project:main Nov 21, 2025
87 of 89 checks passed
ywang96 pushed a commit to ywang96/vllm that referenced this pull request Nov 23, 2025
RunkaiTao pushed a commit to RunkaiTao/vllm that referenced this pull request Nov 24, 2025
devpatelio pushed a commit to SumanthRH/vllm that referenced this pull request Nov 29, 2025
kitaekatt pushed a commit to kitaekatt/vllm that referenced this pull request Dec 1, 2025
jikunshang added a commit to jikunshang/vllm that referenced this pull request Dec 12, 2025
yma11 pushed a commit to yma11/vllm that referenced this pull request Jan 6, 2026

Labels

ready ONLY add when PR is ready to merge/full CI is needed

4 participants