
[CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic#29873

Merged
vllm-bot merged 15 commits into vllm-project:main from shen-shanshan:rope
Dec 16, 2025
Conversation

@shen-shanshan
Contributor

@shen-shanshan shen-shanshan commented Dec 2, 2025

Purpose

  1. Some modeling files call the apply_rotary_emb function directly with a pre-computed cos/sin cache, e.g. https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_5_vl.py#L383-L385. This is just a function, not an op, so it cannot be overridden by plugins (e.g., vllm-ascend). By extracting it as a CustomOp, a plugin can simply extend this class, implement its own forward_oot(), and register the subclass to replace this op with the out-of-tree device's implementation.
  2. There are several dispatch functions for apply_rotary_emb that may confuse users or developers, such as https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding/common.py#L56-L70 and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding/common.py#L73-L97. This PR unifies these separate dispatch paths into one CustomOp to make the logic clearer.
  3. There are also redundant definitions of some helper functions, such as rotate_half() and apply_rotary_emb_torch(). This PR removes the duplicated definitions from the individual modeling files and keeps a single copy in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding/common.py.
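To make the plugin-override mechanism in item 1 concrete, here is a hypothetical, dependency-free sketch of the dispatch pattern (vLLM's real CustomOp base class, platform detection, and registration API differ in detail):

```python
# Hypothetical sketch of the CustomOp dispatch pattern; vLLM's real base
# class lives in vllm/model_executor/custom_op.py and is more involved.
PLATFORM = "oot"  # stand-in for what vLLM's platform detection would report


class CustomOp:
    def __init__(self):
        # Resolve the platform-specific implementation once, at init time.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_oot(self, *args, **kwargs):
        # Out-of-tree platforms (e.g. vllm-ascend) override this hook.
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        if PLATFORM == "cuda":
            return self.forward_cuda
        if PLATFORM == "oot":
            return self.forward_oot
        return self.forward_native


class ApplyRotaryEmb(CustomOp):
    def forward_native(self, x):
        return ("native", x)  # pure-PyTorch path in the real op


class AscendApplyRotaryEmb(ApplyRotaryEmb):
    def forward_oot(self, x):
        # A plugin only implements forward_oot; dispatch picks it up.
        return ("ascend", x)
```

With this shape, replacing the op for an OOT device is just subclass-and-register; no modeling-file changes are needed.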

Recent updates:

  1. Move find_spec("flash_attn") into __init__() of ApplyRotaryEmb.
  2. Add an enable_fp32_compute param to ApplyRotaryEmb that determines whether the input is converted to float for the computation inside the op and restored to its original dtype afterwards.
  3. Correct the forward dispatch of ApplyRotaryEmb in XxxRope: make sure forward_native() -> forward_native(), and forward_cuda() / forward_hip() -> forward().
  4. Support object-level enabling of ApplyRotaryEmb in the ViT part to avoid extra configs when launching the server. Find more details in [CustomOp] Support object-level enable for CustomOp #30547.
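The enable_fp32_compute behaviour in item 2 amounts to an upcast/downcast wrapper around the rotary math. A dependency-free sketch follows; the FakeTensor stand-in and the function name are illustrative only, since the real op works on torch tensors:

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Tiny stand-in for a torch.Tensor, tracking only values and dtype."""
    values: list
    dtype: str

    def to(self, dtype: str) -> "FakeTensor":
        return FakeTensor(self.values, dtype)


def rotary_with_optional_fp32(x: FakeTensor, enable_fp32_compute: bool) -> FakeTensor:
    """Apply the (elided) rotary math, optionally in float32 precision."""
    orig_dtype = x.dtype
    if enable_fp32_compute:
        x = x.to("float32")  # upcast before the rotary computation
    # ... the actual rotary application would happen here ...
    out = FakeTensor(list(x.values), x.dtype)
    if enable_fp32_compute:
        out = out.to(orig_dtype)  # restore the caller's dtype afterwards
    return out
```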

Test Plan

  • Functional test on Ascend A2 NPU.
  • Functional test on NVIDIA A100 GPU.
  • Benchmark on NVIDIA A100 GPU.
  • Add UT to make sure the dispatching logic is correct.

Test Result

✅ Functional test on Ascend A2 NPU

I have tested this PR on Ascend A2 NPU together with vllm-project/vllm-ascend#4667.

Run:

vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
--max_model_len 16384 \
--max-num-batched-tokens 16384 \
--tensor-parallel-size 2 \
--enforce-eager

Output:

{"id":"chatcmpl-9ab4de23690c85aa","object":"chat.completion","created":1764748509,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the image reads \"TONGYI Qwen.\" The word \"TONGYI\" is written in blue, and \"Qwen\" is written in gray. The font appears to be modern and clean, with \"TONGYI\" being slightly larger than \"Qwen.\" The design includes a geometric, abstract shape on the left side of the logo, which complements the text.","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":78,"total_tokens":162,"completion_tokens":84,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}

✅ Functional test on NVIDIA A100 GPU

Run:

python examples/offline_inference/vision_language.py -m dots_ocr

Output:

--------------------------------------------------
 <doc> Tokyo Tower is seen through the pink cherry blossoms in central Tokyo. The city is in full bloom in April.</doc>
--------------------------------------------------
 <doc> Tokyo Tower is seen through the pink cherry blossoms in central Tokyo. The city is in full bloom in March.</doc>
--------------------------------------------------
 <doc> Tokyo Tower is seen through the pink cherry blossoms in central Tokyo. The city is in full bloom in March.</doc>
--------------------------------------------------
 <doc> Tokyo Skytree is seen through aarray of pink cherry blossoms in this April 2016 photo. (AP Photo/Kyodo, File)AP Photo/Kyodo, File) </doc>
--------------------------------------------------

✅ Benchmark on NVIDIA A100 GPU

Thanks to @gcanlin for running this benchmark on an NVIDIA A100 GPU and adding a UT to test the dispatching logic.

Run:

vllm serve rednote-hilab/dots.ocr \
--trust-remote-code

vllm bench serve \
--backend openai-chat \
--model rednote-hilab/dots.ocr \
--endpoint /v1/chat/completions \
--dataset-name hf \
--dataset-path lmarena-ai/VisionArena-Chat \
--hf-split train \
--num-prompts 1000

Before this PR:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  102.22
Total input tokens:                      94327
Total generated tokens:                  48876
Request throughput (req/s):              9.78
Output token throughput (tok/s):         478.14
Peak output token throughput (tok/s):    4560.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          1400.92
---------------Time to First Token----------------
Mean TTFT (ms):                          47767.79
Median TTFT (ms):                        42990.88
P99 TTFT (ms):                           100079.91
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          221.51
Median TPOT (ms):                        231.44
P99 TPOT (ms):                           427.04
---------------Inter-token Latency----------------
Mean ITL (ms):                           229.47
Median ITL (ms):                         223.69
P99 ITL (ms):                            1545.02
==================================================

After this PR:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  101.26
Total input tokens:                      94327
Total generated tokens:                  48775
Request throughput (req/s):              9.88
Output token throughput (tok/s):         481.70
Peak output token throughput (tok/s):    4436.00
Peak concurrent requests:                1000.00
Total Token throughput (tok/s):          1413.27
---------------Time to First Token----------------
Mean TTFT (ms):                          47802.28
Median TTFT (ms):                        42908.37
P99 TTFT (ms):                           99437.50
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          219.32
Median TPOT (ms):                        230.00
P99 TPOT (ms):                           430.33
---------------Inter-token Latency----------------
Mean ITL (ms):                           225.50
Median ITL (ms):                         219.57
P99 ITL (ms):                            1602.67
==================================================

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: shen-shanshan 467638484@qq.com
Co-authored-by: gcanlin canlinguosdu@gmail.com

@mergify mergify bot added the qwen Related to Qwen models label Dec 2, 2025
@chatgpt-codex-connector chatgpt-codex-connector bot left a comment
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request successfully refactors the apply_rotary_emb functionality into a CustomOp class, unifying dispatch logic and removing redundant definitions across several files. This is a positive step towards better modularity and extensibility. However, there are a few critical issues that need to be addressed to ensure correctness and prevent runtime errors.

Collaborator

@ProExpertProg ProExpertProg left a comment


Why is this necessary - there's already a RotaryEmbedding custom op class?

@shen-shanshan
Contributor Author

Why is this necessary - there's already a RotaryEmbedding custom op class?

@ProExpertProg

  1. Some modeling files call the apply_rotary_emb function directly with a pre-computed cos/sin cache, e.g. https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_5_vl.py#L383-L385. This is just a function, not an op, so it cannot be overridden by plugins (e.g., vllm-ascend). By extracting it as a CustomOp, a plugin can simply extend this class, implement its own forward_oot(), and register the subclass to replace this op with the out-of-tree device's implementation.
  2. There are several dispatch functions for apply_rotary_emb that may confuse users or developers, such as https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding/common.py#L56-L70 and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding/common.py#L73-L97. This PR unifies these separate dispatch paths into one CustomOp to make the logic clearer.
  3. There are also redundant definitions of some helper functions, such as rotate_half() and apply_rotary_emb_torch(). This PR removes the duplicated definitions from the individual modeling files and keeps a single copy in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding/common.py.

@shen-shanshan
Contributor Author

shen-shanshan commented Dec 3, 2025

Why is this necessary - there's already a RotaryEmbedding custom op class?

I have also tested this PR together with vllm-project/vllm-ascend#4667 on Ascend NPU.
Maybe we need to add a ready label to see whether this PR breaks other backends, such as CUDA, ROCm, ...

@robertgshaw2-redhat
Collaborator

@ProExpertProg - mind following up on this?

Collaborator

@ProExpertProg ProExpertProg left a comment


Nice dispatch cleanup, approving so that it's not blocked while I'm gone for the next 2 weeks but please address comments!!

@ProExpertProg ProExpertProg added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 5, 2025
@shen-shanshan
Contributor Author

shen-shanshan commented Dec 6, 2025

Nice dispatch cleanup, approving so that it's not blocked while I'm gone for the next 2 weeks but please address comments!!

Thanks a lot for your review. I will address the comments and fix the CI errors soon.

@shen-shanshan shen-shanshan force-pushed the rope branch 2 times, most recently from c18f8a6 to 15797ae Compare December 9, 2025 07:19
@shen-shanshan shen-shanshan changed the title [CustomOp] Extract apply_rotary_emb as CustomOp and unify the dispatch logic [CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic Dec 9, 2025
@tjtanaa
Collaborator

tjtanaa commented Dec 10, 2025

@shen-shanshan Can you check whether there are performance regressions? When you switch to CustomOps, it will by default run forward_native. This is called in the ViT, which is not traced by torch.compile or captured in CUDA graphs. In my last test it was faster to use the forward_cuda and forward_hip functions, where we use the Triton kernels.

CC @DarkLight1337

@tjtanaa
Collaborator

tjtanaa commented Dec 10, 2025

@DarkLight1337 @ProExpertProg
I am not quite sure we want this small function to be a custom op. It makes the user experience really clunky. Now we need to set

--compilation_config.custom_ops '["+apply_rotary_emb"]'

@tjtanaa
Collaborator

tjtanaa commented Dec 10, 2025

I optimized the ViT rotary embedding before, so the best config we want is

  • The RotaryEmbedding class (for the LLM) uses forward_native + torch.compile.
  • For the ViT part, we use forward_cuda/forward_hip (with Triton kernels), since we don't have torch.compile there.

I tested the new apply_rotary_emb custom op. It is affecting RotaryEmbedding as well.

vllm serve Qwen/Qwen2.5-VL-7B-Instruct \
 -tp 1 \
 --trust_remote_code \
 --compilation_config.custom_ops '["+apply_rotary_emb"]' \
> server_Qwen_Qwen2.5-VL-7B-Instruct-aiter-v1-customops-2.log 2>&1

I am getting error on ROCm.

(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]    File "/app/reviewrope/rope/vllm/model_executor/models/qwen2.py", line 413, in forward
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]     hidden_states, residual = layer(positions, hidden_states, residual)
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]   File "/app/reviewrope/rope/vllm/model_executor/models/qwen2.py", line 267, in forward
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]     hidden_states = self.self_attn(
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]   File "/app/reviewrope/rope/vllm/model_executor/models/qwen2.py", line 201, in forward
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]     q, k = self.rotary_emb(positions, q, k)
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]   File "/app/reviewrope/rope/vllm/model_executor/custom_op.py", line 46, in forward
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]     return self._forward_method(*args, **kwargs)
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]   File "/app/reviewrope/rope/vllm/model_executor/layers/rotary_embedding/mrope.py", line 303, in forward_native
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]     query_rot = self.apply_rotary_emb(
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]   File "/app/reviewrope/rope/vllm/model_executor/custom_op.py", line 46, in forward
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]     return self._forward_method(*args, **kwargs)
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]   File "/app/reviewrope/rope/vllm/model_executor/layers/rotary_embedding/common.py", line 203, in forward_hip
(EngineCore_DP0 pid=9298) ERROR 12-10 09:11:24 [core.py:867]     if find_spec("flash_attn") is not None:

In MRoPE, despite calling its own forward_native(), it ends up using ApplyRotaryEmb.forward_hip. Not the behaviour we want.

Let's make sure that the RotaryEmbedding classes in vllm/model_executor/layers/rotary_embedding/ behave as follows:

  • RotaryEmbedding.forward_native() uses ApplyRotaryEmb.forward_native().
  • RotaryEmbedding.forward_cuda() uses ApplyRotaryEmb.forward() (not ApplyRotaryEmb.forward_cuda()). NOTE: there are cases where we don't have forward_hip and forward_hip falls back to forward_cuda, so we let ApplyRotaryEmb.forward() automatically pick the platform's implementation.
  • RotaryEmbedding.forward_hip() uses ApplyRotaryEmb.forward() (not ApplyRotaryEmb.forward_hip()).
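The rules above can be sketched with minimal stand-in classes (hypothetical; the real classes take tensors and also have forward_hip and OOT paths):

```python
class ApplyRotaryEmb:
    def forward(self, x):
        # Entry point that picks the platform implementation itself; for this
        # sketch we pretend the current platform resolved to CUDA.
        return self.forward_cuda(x)

    def forward_native(self, x):
        return ("apply_native", x)

    def forward_cuda(self, x):
        return ("apply_cuda", x)


class RotaryEmbedding:
    def __init__(self):
        self.apply_rotary_emb = ApplyRotaryEmb()

    def forward_native(self, x):
        # native -> native, so torch.compile only ever traces pure PyTorch
        return self.apply_rotary_emb.forward_native(x)

    def forward_cuda(self, x):
        # cuda/hip -> forward(), letting the inner op dispatch per platform
        return self.apply_rotary_emb.forward(x)

    # forward_hip would likewise call self.apply_rotary_emb.forward(x)
```

This keeps the platform choice in one place (the inner op) instead of every RotaryEmbedding subclass hardcoding forward_cuda or forward_hip.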

Related comments https://github.com/vllm-project/vllm/pull/29873/files#r2605818486

CC @DarkLight1337

Collaborator

@tjtanaa tjtanaa left a comment


It is breaking on ROCm, and I think it also breaks CUDA when we enable custom_ops +apply_rotary_emb.

The bug was not caught because we don't have tests for this new custom op, even on CUDA.

@shen-shanshan can you add unit tests? Especially unit tests with a small model.

@shen-shanshan
Contributor Author

@DarkLight1337 @ProExpertProg I am not quite sure if we want this small function to be a custom op. It makes user experience really clunky. Now we need to set

--compilation_config.custom_ops '["+apply_rotary_emb"]'

Because we want to replace some plugin-device ops here, such as https://github.com/vllm-project/vllm-ascend/pull/4667/changes#diff-2d5d0eeaa29da1fddb16a845a92692c9277676ed31cf60604eccbc3bb0ed764aR441-R480.

@shen-shanshan
Contributor Author

It is breaking on ROCm, and I think it also breaks CUDA when we enable custom_ops +apply_rotary_emb.

The bug was not caught because we don't have tests for this new custom op, even on CUDA.

@shen-shanshan can you add unit tests? Especially unit tests with a small model.

Ok, I will add the related UTs later.

@tjtanaa
Collaborator

tjtanaa commented Dec 10, 2025

@shen-shanshan Thanks for going through my long reviews. Ping me when you are ready.

@mergify
Contributor

mergify bot commented Dec 11, 2025

Documentation preview: https://vllm--29873.org.readthedocs.build/en/29873/

@mergify mergify bot added the documentation Improvements or additions to documentation label Dec 11, 2025
@shen-shanshan shen-shanshan requested a review from mgoin as a code owner December 12, 2025 02:37
shen-shanshan and others added 3 commits December 15, 2025 03:29
Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
@tjtanaa
Collaborator

tjtanaa commented Dec 15, 2025

Performance benchmark on A100

Command

vllm serve rednote-hilab/dots.ocr  --trust-remote-code
vllm bench serve --backend openai-chat --model rednote-hilab/dots.ocr --endpoint /v1/chat/completions --dataset-name hf --dataset-path lmarena-ai/VisionArena-Chat --hf-split train --num-prompts 1000

Result

Before:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  102.22
Total input tokens:                      94327
Total generated tokens:                  48876
Request throughput (req/s):              9.78
Output token throughput (tok/s):         478.14
Peak output token throughput (tok/s):    4560.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          1400.92
---------------Time to First Token----------------
Mean TTFT (ms):                          47767.79
Median TTFT (ms):                        42990.88
P99 TTFT (ms):                           100079.91
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          221.51
Median TPOT (ms):                        231.44
P99 TPOT (ms):                           427.04
---------------Inter-token Latency----------------
Mean ITL (ms):                           229.47
Median ITL (ms):                         223.69
P99 ITL (ms):                            1545.02
==================================================

After:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  101.26
Total input tokens:                      94327
Total generated tokens:                  48775
Request throughput (req/s):              9.88
Output token throughput (tok/s):         481.70
Peak output token throughput (tok/s):    4436.00
Peak concurrent requests:                1000.00
Total Token throughput (tok/s):          1413.27
---------------Time to First Token----------------
Mean TTFT (ms):                          47802.28
Median TTFT (ms):                        42908.37
P99 TTFT (ms):                           99437.50
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          219.32
Median TPOT (ms):                        230.00
P99 TPOT (ms):                           430.33
---------------Inter-token Latency----------------
Mean ITL (ms):                           225.50
Median ITL (ms):                         219.57
P99 ITL (ms):                            1602.67
==================================================

Does this improvement come from enforcing the custom ViT op, or from using forward_native?

@tjtanaa tjtanaa added the rocm Related to AMD ROCm label Dec 15, 2025
@Isotr0py Isotr0py added this to the v0.13.0 milestone Dec 15, 2025
@gcanlin
Contributor

gcanlin commented Dec 15, 2025

Does this improvement come from enforcing the custom ViT op, or from using forward_native?

I think it's just random deviation? Let me test again based on the newest commit.

@gcanlin
Contributor

gcanlin commented Dec 15, 2025

@tjtanaa I tested again with the default parameters and the same command. I'm not sure what makes the performance better; maybe the difference is within the noise of the benchmark. Now the behavior of forward matches our expectations.

vllm serve rednote-hilab/dots.ocr  --trust-remote-code
vllm bench serve --backend openai-chat --model rednote-hilab/dots.ocr --endpoint /v1/chat/completions --dataset-name hf --dataset-path lmarena-ai/VisionArena-Chat --hf-split train --num-prompts 1000

Before:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  114.76
Total input tokens:                      94327
Total generated tokens:                  48697
Request throughput (req/s):              8.71
Output token throughput (tok/s):         424.34
Peak output token throughput (tok/s):    4733.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          1246.29
---------------Time to First Token----------------
Mean TTFT (ms):                          60138.18
Median TTFT (ms):                        57816.07
P99 TTFT (ms):                           112421.83
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          237.48
Median TPOT (ms):                        226.24
P99 TPOT (ms):                           648.82
---------------Inter-token Latency----------------
Mean ITL (ms):                           222.26
Median ITL (ms):                         235.03
P99 ITL (ms):                            1282.04
==================================================

After:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  109.97
Total input tokens:                      94327
Total generated tokens:                  48493
Request throughput (req/s):              9.09
Output token throughput (tok/s):         440.98
Peak output token throughput (tok/s):    4878.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          1298.75
---------------Time to First Token----------------
Mean TTFT (ms):                          57628.96
Median TTFT (ms):                        53848.18
P99 TTFT (ms):                           107828.06
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          230.14
Median TPOT (ms):                        221.27
P99 TPOT (ms):                           576.06
---------------Inter-token Latency----------------
Mean ITL (ms):                           215.87
Median ITL (ms):                         214.18
P99 ITL (ms):                            1367.58
==================================================

Collaborator

@tjtanaa tjtanaa left a comment


LGTM

@shen-shanshan
Contributor Author

The AMD CI failure seems to be unrelated to this PR.

Member

@mgoin mgoin left a comment


Thanks for the nice work!

@vllm-bot vllm-bot merged commit 3bd9c49 into vllm-project:main Dec 16, 2025
56 of 57 checks passed
khluu pushed a commit that referenced this pull request Dec 17, 2025
…logic (#29873)

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
(cherry picked from commit 3bd9c49)
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Dec 23, 2025
…patch (#4667)

### What this PR does / why we need it?

Following vllm-project/vllm#29873, register
`AscendApplyRotaryEmb` CustomOp and remove related patch.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

#### ✅ Test Qwen2.5-VL

Run:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
--max_model_len 16384
```

Output:

```
{"id":"chatcmpl-b02c1ff3415d2462","object":"chat.completion","created":1766129265,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is \"TONGYI Qwen.\" The word \"TONGYI\" is written in blue, and \"Qwen\" is written in gray. The text appears to be part of a logo or branding design.","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":78,"total_tokens":129,"completion_tokens":51,"prompt_tokens_d
```

#### ✅ Test Qwen3-VL

Run:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
--max_model_len 16384
```

Output:

```
{"id":"chatcmpl-a3a7de5a900a9321","object":"chat.completion","created":1766129586,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is **“TONGYI Qwen”**.\n\n### How it looks:\n- **“TONGYI”** is written in **uppercase letters** in a **bold, modern sans-serif font**, colored **blue**.\n- **“Qwen”** is written in **lowercase letters** in a **slightly thinner, elegant sans-serif font**, colored **dark gray**.\n- The two lines of text are stacked vertically, with “TONG","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":112,"total_tokens":212,"completion_tokens":100,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```

- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
Majid-Taheri pushed a commit to Majid-Taheri/vllm that referenced this pull request Dec 23, 2025
…logic (vllm-project#29873)

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
Signed-off-by: Ubuntu <mjtaheri68@gmail.com>
@DarkLight1337 DarkLight1337 mentioned this pull request Jan 6, 2026
5 tasks
fort726 pushed a commit to fort726/vllm that referenced this pull request Jan 6, 2026
…logic (vllm-project#29873)

Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
…patch (vllm-project#4667)

### What this PR does / why we need it?

Following vllm-project/vllm#29873, register
`AscendApplyRotaryEmb` CustomOp and remove related patch.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

#### ✅ Test Qwen2.5-VL

Run:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-Instruct \
--max_model_len 16384
```

Output:

```
{"id":"chatcmpl-b02c1ff3415d2462","object":"chat.completion","created":1766129265,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen2.5-VL-7B-In struct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is \"TONGYI Qwen.\" The word \"TONGYI\" is writ  ten in blue, and \"Qwen\" is written in gray. The text appears to be part of a logo or branding design.","refusal":null,"annotations":null,"audio":   null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"stop","stop_reason":null,"tok    en_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":78,"total_tokens":129,"completion_tokens":51,"prompt_tokens_d
```

#### ✅ Test Qwen3-VL

Run:

```bash
vllm serve /root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct \
--max_model_len 16384
```

Output:

```
{"id":"chatcmpl-a3a7de5a900a9321","object":"chat.completion","created":1766129586,"model":"/root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"The text in the illustration is **“TONGYI Qwen”**.\n\n### How it looks:\n- **“TONGYI”** is written in **uppercase letters** in a **bold, modern sans-serif font**, colored **blue**.\n- **“Qwen”** is written in **lowercase letters** in a **slightly thinner, elegant sans-serif font**, colored **dark gray**.\n- The two lines of text are stacked vertically, with “TONG","refusal":null,"annotations":null,"audio":null,"function_call":null,"tool_calls":[],"reasoning":null,"reasoning_content":null},"logprobs":null,"finish_reason":"length","stop_reason":null,"token_ids":null}],"service_tier":null,"system_fingerprint":null,"usage":{"prompt_tokens":112,"total_tokens":212,"completion_tokens":100,"prompt_tokens_details":null},"prompt_logprobs":null,"prompt_token_ids":null,"kv_transfer_params":null}
```
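The outputs above come from the OpenAI-compatible `/v1/chat/completions` endpoint that `vllm serve` exposes. A hypothetical request payload of the shape that would produce such a response looks like this (the image URL and prompt text are placeholders, not the exact inputs used in the tests above):

```python
import json

# Hypothetical request payload for vLLM's OpenAI-compatible
# /v1/chat/completions endpoint; the image URL and prompt text
# are placeholders, not the exact inputs used above.
payload = {
    "model": "/root/.cache/modelscope/hub/models/Qwen/Qwen3-VL-8B-Instruct",
    "messages": [{
        "role": "user",
        "content": [
            {"type": "image_url",
             "image_url": {"url": "https://example.com/logo.png"}},
            {"type": "text", "text": "What is the text in the illustration?"},
        ],
    }],
    "max_tokens": 100,
}
body = json.dumps(payload)
```

Such a body can be POSTed to `http://localhost:8000/v1/chat/completions` (vLLM's default port) with `curl` or any HTTP client.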

- vLLM version: v0.12.0
- vLLM main: vllm-project/vllm@ad32e3e

---------

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
…patch (vllm-project#4667)

### What this PR does / why we need it?

Following vllm-project/vllm#29873, register the
`AscendApplyRotaryEmb` CustomOp and remove the related patch.
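The dispatch pattern this relies on can be sketched in isolation. The following is a simplified, pure-Python stand-in for vLLM's `CustomOp`/`forward_oot` machinery (not its actual API): the op falls back to a native implementation unless a platform plugin has registered an out-of-tree override.

```python
import math

class ApplyRotaryEmbSketch:
    """Simplified, hypothetical stand-in for vLLM's CustomOp dispatch:
    forward() routes to an out-of-tree override when a plugin has
    registered one, otherwise falls back to the native path."""

    def __init__(self):
        self._oot_impl = None  # set by a platform plugin (e.g. vllm-ascend)

    def register_oot(self, fn):
        self._oot_impl = fn

    @staticmethod
    def rotate_half(x):
        # Move the negated second half in front of the first:
        # [x1, x2] -> [-x2, x1].
        half = len(x) // 2
        return [-v for v in x[half:]] + list(x[:half])

    def forward_native(self, x, cos, sin):
        # out = x * cos + rotate_half(x) * sin, elementwise.
        rot = self.rotate_half(x)
        return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rot, sin)]

    def forward(self, x, cos, sin):
        impl = self._oot_impl or self.forward_native
        return impl(x, cos, sin)

op = ApplyRotaryEmbSketch()
theta = math.pi / 2  # rotate by 90 degrees
cos = [math.cos(theta)] * 4
sin = [math.sin(theta)] * 4
out = op.forward([1.0, 0.0, 0.0, 0.0], cos, sin)
# A plugin would swap in its own kernel like this:
op.register_oot(lambda x, c, s: op.forward_native(x, c, s))
```

With the override registered, callers in modeling code keep invoking `forward()` unchanged; only the dispatch target differs per platform.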

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Covered by the same Qwen2.5-VL and Qwen3-VL serving tests shown above.

Labels

- `documentation`: Improvements or additions to documentation
- `qwen`: Related to Qwen models
- `ready`: ONLY add when PR is ready to merge/full CI is needed
- `rocm`: Related to AMD ROCm
