
[Ascend] perf: optimize rope embedding with triton kernel for huge performance gain#5918

Merged
wangxiyuan merged 4 commits intovllm-project:mainfrom
ZCG12345:main
Jan 21, 2026

Conversation

@ZCG12345
Contributor

@ZCG12345 ZCG12345 commented Jan 15, 2026

What this PR does / why we need it?

  1. Implement a high-performance Triton custom kernel for the rotary position embedding (RoPE) operator on the Ascend NPU platform.
  2. Fix critical bugs in the Triton RoPE kernel registration and invocation process, including incorrect fake-impl function name matching, the wrong torch ops namespace for the kernel call, a missing self parameter when fetching the cos/sin slice, and syntax errors in function type annotations.
  3. Achieve a major performance optimization for the core RoPE operator: single-invocation latency drops from 57.1 μs to 9 μs, a 6.34x speedup and an 84.24% latency reduction.
  4. The RoPE operator is a hot path executed in every transformer layer during LLM inference, so this optimization directly reduces overall inference latency and improves the throughput of LLM serving on Ascend NPU.
  5. Keep full backward compatibility: the Triton kernel is enabled only when HAS_TRITON=True, and the code automatically falls back to the original Ascend NPU native implementation if Triton is not available, with no functional regression.
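The fallback behavior described in point 5 can be sketched as follows. This is a minimal illustration only: `rope_forward_native`, `rope_forward`, and the argument names are hypothetical stand-ins, not the exact symbols used in vllm-ascend.

```python
# Sketch of the HAS_TRITON dispatch pattern described above.
# All function names here are illustrative placeholders.
try:
    import triton  # noqa: F401
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

def rope_forward_native(positions, query, key):
    # Placeholder for the original Ascend NPU native implementation.
    return query, key

def rope_forward_triton(positions, query, key):
    # Placeholder for the Triton kernel path.
    return query, key

def rope_forward(positions, query, key):
    """Dispatch to the Triton kernel when available, else fall back."""
    if HAS_TRITON:
        return rope_forward_triton(positions, query, key)
    return rope_forward_native(positions, query, key)
```

Because both branches return the same result, callers see no functional difference whether or not Triton is installed.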

Does this PR introduce any user-facing change?

NO

  • No changes to any public APIs, interfaces, or inference behavior of vLLM.
  • No impact on the text generation quality or correctness of large models.
  • The optimization is transparent to end users; only inference speed (latency/throughput) improves, with no functional change.

How was this patch tested?

  1. Environment Validation: Tested on the Ascend NPU platform with the vLLM-Ascend framework, with the Triton library installed and enabled (HAS_TRITON=True).
  2. Kernel Registration Test: Verified the Triton RoPE kernel (rope_forward_triton) registers successfully in the torch.ops._C_ascend namespace without any ValueError/NameError/SyntaxError.
  3. Functional Correctness Test: Ran large-model (GLM4/MoE) inference on the Ascend NPU platform; the generated text is completely correct (no garbled text, no logical errors) and consistent with the original implementation.
  4. Performance Benchmark Test: Measured the single-execution latency of the RoPE operator before/after the optimization, confirming the latency is stably reduced from 57.1 μs to 9 μs; the performance gain is valid and stable.
  5. Fallback Mechanism Test: Manually disabled Triton (HAS_TRITON=False) and verified the code correctly falls back to the original Ascend NPU native RoPE implementation, with no service crash and normal inference.
  6. Compatibility Test: Tested with different tensor shapes/sizes for query/key; all cases work correctly with the Triton kernel, with no shape-mismatch errors.
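A latency comparison like the one in point 4 is typically taken with a warm-up-then-time loop. Below is a generic sketch of that methodology, timing a placeholder callable rather than the actual NPU kernel; on a real NPU, the timed region would additionally need device synchronization before and after to exclude async dispatch effects.

```python
import time

def benchmark(fn, *args, warmup=10, iters=1000):
    """Return the mean latency of fn(*args) in microseconds."""
    for _ in range(warmup):           # warm up caches / lazy compilation
        fn(*args)
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    elapsed = time.perf_counter() - start
    return elapsed / iters * 1e6      # seconds -> microseconds

# Trivial placeholder standing in for the RoPE operator call:
latency_us = benchmark(lambda: None)
print(f"{latency_us:.2f} us per call")
```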

Signed-off-by: ZCG12345 <2097562023@qq.com>
@github-actions
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by fulfilling the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a high-performance Triton kernel for Rotary Position Embedding (RoPE) on Ascend NPUs, which yields a significant performance improvement. The changes include the Triton kernel implementation, its registration as a custom PyTorch op, and integration into the existing RoPE logic. My review focuses on ensuring the correctness and maintainability of the new code. I've identified a potential critical issue regarding an unresolved reference that could lead to a NameError, and some maintainability issues with comments in Chinese that should be translated to English. Overall, the optimization is a great addition, and addressing these points will improve the quality of the contribution.

Comment on lines +380 to +384
direct_register_custom_op(op_name="rope_forward_triton",
op_func=rope_forward_triton,
fake_impl=_rope_forward_triton_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
Contributor

critical

The function rope_forward_triton is used here as the op_func for the custom operator, but it is not defined or imported within this file. This will likely result in a NameError when this module is imported. To fix this, you should import it from vllm_ascend.ops.triton.rope at the top of the file.

Comment thread vllm_ascend/ops/rotary_embedding.py Outdated
Comment on lines +193 to +194
cos = cos.view(-1, self.rotary_dim)#就是rope_dim 数据转换为2维 (i.e. rotary_dim; reshape to 2-D)
sin = sin.view(-1, self.rotary_dim)
Contributor

high

The comment on line 193 is in Chinese. To ensure the codebase is maintainable and accessible to all contributors, please provide comments in English. If the code is self-explanatory, you can remove the comment.

Suggested change
- cos = cos.view(-1, self.rotary_dim)#就是rope_dim 数据转换为2维
- sin = sin.view(-1, self.rotary_dim)
+ cos = cos.view(-1, self.rotary_dim)
+ sin = sin.view(-1, self.rotary_dim)

Comment thread vllm_ascend/ops/rotary_embedding.py Outdated
Comment on lines +195 to +197
q = query.contiguous().view(query.shape[0], -1,
                            self.head_size)#B,N,H,D 批次 tokens数量,头数,头的维度 (num tokens, num heads, head dim)
k = key.contiguous().view(key.shape[0], -1, self.head_size)
Contributor

high

This comment is in Chinese. Please translate it to English for better code maintainability. Additionally, the .view() call appears to be redundant as the query and key tensors should already have the correct 3D shape (num_tokens, num_heads, head_size) at this point. A simple .contiguous() call should be sufficient to ensure the tensors are contiguous in memory before passing them to the Triton kernel.

Suggested change
- q = query.contiguous().view(query.shape[0], -1,
-                             self.head_size)#B,N,H,D 批次 tokens数量,头数,头的维度
- k = key.contiguous().view(key.shape[0], -1, self.head_size)
+ q = query.contiguous()
+ k = key.contiguous()

Comment thread vllm_ascend/ops/rotary_embedding.py Outdated
cos, sin = get_cos_and_sin_slice()
if HAS_TRITON:

cos = cos.view(-1, self.rotary_dim)#就是rope_dim 数据转换为2维 (i.e. rotary_dim; reshape to 2-D)
Collaborator

remove chinese note

Comment thread .pre-commit-config.yaml Outdated
rev: v1.7.7
hooks:
- id: actionlint
# - repo: https://github.com/rhysd/actionlint
Collaborator

what's this?

Contributor Author

When I ran lint, this line needed to be commented out. Sorry, I forgot to restore it after local testing finished.

Comment thread vllm_ascend/ops/rotary_embedding.py Outdated
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
cos, sin = get_cos_and_sin_slice()
if HAS_TRITON:
Collaborator

This will go into the Triton RoPE path whenever Triton is installed. We need more performance data to support this change.

Contributor Author

Fine, I will change it to the torch_npu branch in the next commit; you can check again at that time.

Signed-off-by: ZCG12345 <2097562023@qq.com>
Signed-off-by: ZCG12345 <2097562023@qq.com>
self.cos_sin_cache = self.cos_sin_cache.to(query.device)
if self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
cos, sin = get_cos_and_sin_slice()
Collaborator

There might be problems to put this index_select here. Please contact @Angazenn to make sure of this.

Collaborator

get_cos_and_sin_slice does not perform index_select; I think it's OK to move it here.

Collaborator

fine

@wangxiyuan wangxiyuan added the ready (read for review) and ready-for-test (start test by label for PR) labels Jan 19, 2026
@wangxiyuan wangxiyuan enabled auto-merge (squash) January 19, 2026 08:48
@ZCG12345 ZCG12345 closed this Jan 20, 2026
auto-merge was automatically disabled January 20, 2026 01:22

Pull request was closed

@ZCG12345 ZCG12345 reopened this Jan 20, 2026
@ZCG12345 ZCG12345 closed this Jan 20, 2026
@ZCG12345 ZCG12345 reopened this Jan 20, 2026
@ZCG12345 ZCG12345 marked this pull request as draft January 21, 2026 01:32
@ZCG12345 ZCG12345 marked this pull request as ready for review January 21, 2026 01:32
@wangxiyuan wangxiyuan merged commit 8900e33 into vllm-project:main Jan 21, 2026
20 checks passed
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Jan 22, 2026
…to FIA_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend:
  [CI] Upgrade CANN to 8.5.0 (vllm-project#6070)
  Default enable MLAPO (vllm-project#5952)
  [Doc] Supplement PD separation parameters of DeepSeek V3.1 (vllm-project#6053)
  [Ascend] perf: optimize rope embedding with triton kernel for huge performance gain (vllm-project#5918)
  [Ops] update causal_conv1d_update (vllm-project#5984)
  [CI]Update triton ascend version in 3.2.0 (vllm-project#6067)
  [bugfix] fix the complex and potentially problematic generate_kv_idx. (vllm-project#5957)
starmountain1997 pushed a commit to starmountain1997/vllm-ascend that referenced this pull request Jan 31, 2026
…rformance gain (vllm-project#5918)

- Operator supplied by Hexiang Wang
- vLLM version: v0.13.0
- vLLM main: vllm-project/vllm@11b6af5

Signed-off-by: ZCG12345 <2097562023@qq.com>
@wangxiyuan wangxiyuan mentioned this pull request Feb 24, 2026
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
…rformance gain (vllm-project#5918)

Signed-off-by: ZCG12345 <2097562023@qq.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
…rformance gain (vllm-project#5918)

Signed-off-by: ZCG12345 <2097562023@qq.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
…rformance gain (vllm-project#5918)

Signed-off-by: ZCG12345 <2097562023@qq.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
…rformance gain (vllm-project#5918)

Signed-off-by: ZCG12345 <2097562023@qq.com>

Labels

module:ops · ready (read for review) · ready-for-test (start test by label for PR)


4 participants