Skip to content

[Main][Ops] Make triton rope support index_selecting from cos_sin_cache#5450

Merged
whx-sjtu merged 7 commits intovllm-project:mainfrom
Angazenn:cos_sin
Feb 11, 2026
Merged

[Main][Ops] Make triton rope support index_selecting from cos_sin_cache#5450
whx-sjtu merged 7 commits intovllm-project:mainfrom
Angazenn:cos_sin

Conversation

@Angazenn
Copy link
Copy Markdown
Collaborator

@Angazenn Angazenn commented Dec 28, 2025

What this PR does / why we need it?

This PR extends original rope_triton_forward and split_qkv_rmsnorm_rope to support cos_sin_cache && positions as inputs. This fully aligns to vLLM RoPE api interface. Compared with earlier implementation for RoPE, the benefits are:

  1. avoiding pre-computation of cos sin before model execution, which helps to remove redundant codes.
  2. allowing eagle3 draft model to have different rope parameters with main model (see [Bug]: Potential accuracy && accept rate degradation if rope parameters in eagle3 draft model is different from main model. #6612 ). This help to recover accept rate && accuracy in that case.

In addition, this kernel change only introduces very small performance degradation. Those index_select or chunk operations are now changed into simple memory access in triton kernel (For example, https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).

Highlights

  • RoPE Cache Unification: Replaced separate _sin and _cos global tensors with a unified cos_sin_cache and explicit positions tensor for Rotary Positional Embeddings (RoPE), streamlining data handling.
  • Triton Kernel Integration: Updated Triton kernels (split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the cos_sin_cache and positions for more efficient and integrated RoPE calculations.
  • Custom Operation Registration: Registered rope_forward_oot as a new custom operation, allowing its use in fused compilation passes and providing a dedicated entry point for the new RoPE implementation.
  • Refactored RoPE Forward Pass: Modified the rope_forward_oot function to accept the new cos_sin_cache and positions arguments, enabling a more flexible and integrated RoPE application within the system.

Does this PR introduce any user-facing change?

No.

How was this patch tested?

Additional test on Qwen3-235b accuracy:

Aime2024 GSM8K Livecodebench
83.33 96.26 70.23

@github-actions
Copy link
Copy Markdown
Contributor

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the Rotary Positional Embedding (RoPE) implementation to use a combined cos_sin_cache, which is a good simplification. However, I've found a few critical issues. The new Triton kernel for RoPE does not correctly use the token positions, leading to incorrect calculations. Additionally, an old RoPE function (rope_forward_triton) is now broken due to using undefined variables. Finally, the shape of the tensors returned by the new RoPE implementation is incorrect, which will likely cause errors in downstream attention operations. These issues need to be addressed before this PR can be merged.

Comment thread vllm_ascend/ops/triton/rope.py Outdated
Comment thread vllm_ascend/ops/triton/rope.py Outdated
Comment thread vllm_ascend/ops/rotary_embedding.py
Comment thread vllm_ascend/ops/triton/rope.py Outdated
Comment thread vllm_ascend/ops/triton/rope.py Outdated
Comment thread vllm_ascend/ops/triton/rope.py Outdated
Comment thread vllm_ascend/ops/rotary_embedding.py Outdated
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jan 5, 2026

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@wangxiyuan
Copy link
Copy Markdown
Collaborator

Any progress? If this PR is still alive, please rebase to main and make CI happy, otherwise you can close it. Thanks

@github-actions
Copy link
Copy Markdown
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Feb 2, 2026

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@Angazenn Angazenn force-pushed the cos_sin branch 2 times, most recently from 6b0ec47 to df33b34 Compare February 5, 2026 11:23
@Angazenn Angazenn force-pushed the cos_sin branch 2 times, most recently from e6a1b57 to b3f5744 Compare February 5, 2026 11:27
@Angazenn Angazenn marked this pull request as ready for review February 5, 2026 11:28
@Angazenn Angazenn marked this pull request as draft February 9, 2026 07:45
@Angazenn Angazenn marked this pull request as ready for review February 9, 2026 07:46
@whx-sjtu whx-sjtu changed the title [draft]rope support cos_sin_cache [Ops] Make triton rope support index_selecting from cos_sin_cache Feb 9, 2026
Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: Angazenn <supperccell@163.com>
@Angazenn Angazenn changed the title [Ops] Make triton rope support index_selecting from cos_sin_cache [Main][Ops] Make triton rope support index_selecting from cos_sin_cache Feb 10, 2026
Comment thread vllm_ascend/ops/triton/rope.py Outdated
Comment thread vllm_ascend/ops/triton/rope.py Outdated
Comment thread vllm_ascend/ops/triton/rope.py Outdated
Signed-off-by: Angazenn <supperccell@163.com>
Comment thread vllm_ascend/ops/register_custom_ops.py Outdated
Signed-off-by: Angazenn <supperccell@163.com>
@whx-sjtu whx-sjtu merged commit c0c2eb6 into vllm-project:main Feb 11, 2026
27 checks passed
wangxiyuan pushed a commit that referenced this pull request Feb 11, 2026
…cache (#6602)

### What this PR does / why we need it?
This PR adapts #5450, #6523 to v0.13.0, to fix #6612 .

This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:

1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see #6612 ). This help to recover accept rate && accuracy in
that case.

In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel

**Highlights**

- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.

---------

Signed-off-by: Angazenn <supperccell@163.com>
845473182 pushed a commit to 845473182/vllm-ascend that referenced this pull request Feb 12, 2026
…to qwen3next_rebase

* 'main' of https://github.com/vllm-project/vllm-ascend:
  [Docs] Fix GLM-5 deploy command (vllm-project#6711)
  [npugraph_ex]enable npugraph_ex by default (vllm-project#6664)
  [doc]add GLM5.md (vllm-project#6709)
  [Model] GLM5 adaptation (vllm-project#6642)
  [Bugfix] Update target probs to target logits in rejection sample (vllm-project#6685)
  [Main][Ops] Make triton rope support index_selecting from cos_sin_cache (vllm-project#5450)
  [CI]fix nightly multi node test error for wait for pod ready (vllm-project#6675)
  [main  to main] upgrade main 0210 (vllm-project#6673)
  [main][Quant] Remove unused rotation functions and parameters from W4A4 LAOS quantization (vllm-project#6648)
  [Test][BugFix] Fix torch.rand usage in triton penalty test (vllm-project#6680)
  Add Worker Interface:check_health (vllm-project#6681)
chenchuw886 pushed a commit to chenchuw886/vllm-ascend that referenced this pull request Feb 12, 2026
…he (vllm-project#5450)

### What this PR does / why we need it?

This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:

1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see vllm-project#6612 ). This help to recover accept rate && accuracy in
that case.

In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel (For example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).

**Highlights**

- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Additional test on Qwen3-235b accuracy:

| Aime2024 | GSM8K | Livecodebench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |

---------

Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: momochenchuw <chenchuw@huawei.com>
banxiaduhuo pushed a commit to banxiaduhuo/vllm-ascend that referenced this pull request Feb 26, 2026
…he (vllm-project#5450)

### What this PR does / why we need it?

This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:

1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see vllm-project#6612 ). This help to recover accept rate && accuracy in
that case.

In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel (For example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).

**Highlights**

- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Additional test on Qwen3-235b accuracy:

| Aime2024 | GSM8K | Livecodebench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |

---------

Signed-off-by: Angazenn <supperccell@163.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
…he (vllm-project#5450)

### What this PR does / why we need it?

This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:

1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see vllm-project#6612 ). This help to recover accept rate && accuracy in
that case.

In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel (For example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).

**Highlights**

- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Additional test on Qwen3-235b accuracy:

| Aime2024 | GSM8K | Livecodebench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |

---------

Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
maoxx241 pushed a commit to maoxx241/vllm-ascend that referenced this pull request Mar 2, 2026
…he (vllm-project#5450)

### What this PR does / why we need it?

This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:

1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see vllm-project#6612 ). This help to recover accept rate && accuracy in
that case.

In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel (For example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).

**Highlights**

- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Additional test on Qwen3-235b accuracy:

| Aime2024 | GSM8K | Livecodebench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |

---------

Signed-off-by: Angazenn <supperccell@163.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
…he (vllm-project#5450)

### What this PR does / why we need it?

This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:

1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see vllm-project#6612 ). This help to recover accept rate && accuracy in
that case.

In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel (For example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).

**Highlights**

- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Additional test on Qwen3-235b accuracy:

| Aime2024 | GSM8K | Livecodebench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |

---------

Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
…he (vllm-project#5450)

### What this PR does / why we need it?

This PR extends original `rope_triton_forward` and
`split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as
inputs. This fully aligns to vLLM RoPE api interface. Compared with
earlier implementation for RoPE, the benefits are:

1. avoiding pre-computation of `cos` `sin` before model execution, which
helps to remove redundant codes.
2. allowing eagle3 draft model to have different rope parameters with
main model (see vllm-project#6612 ). This help to recover accept rate && accuracy in
that case.

In addition, this kernel change only introduces very small performance
degradation. Those `index_select` or `chunk` operations are now changed
into simple memory access in triton kernel (For example,
https://github.com/vllm-project/vllm-ascend/pull/5450/changes#diff-a4c2d3071530df193b98f9bf38553874bc4d47571336711f116c26d019cfbb6aR77-R81).

**Highlights**

- **RoPE Cache Unification**: Replaced separate _sin and _cos global
tensors with a unified cos_sin_cache and explicit positions tensor for
Rotary Positional Embeddings (RoPE), streamlining data handling.
- **Triton Kernel Integration**: Updated Triton kernels
(split_qkv_rmsnorm_rope_kernel, _triton_rope) to directly consume the
cos_sin_cache and positions for more efficient and integrated RoPE
calculations.
- **Custom Operation Registration**: Registered `rope_forward_oot` as a
new custom operation, allowing its use in fused compilation passes and
providing a dedicated entry point for the new RoPE implementation.
- **Refactored RoPE Forward Pass**: Modified the rope_forward_oot
function to accept the new cos_sin_cache and positions arguments,
enabling a more flexible and integrated RoPE application within the
system.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
vllm-project/vllm@5326c89

Additional test on Qwen3-235b accuracy:

| Aime2024 | GSM8K | Livecodebench |
| -------- | -------- | -------- |
| 83.33 | 96.26 | 70.23 |

---------

Signed-off-by: Angazenn <supperccell@163.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module:ops ready read for review ready-for-test start test by label for PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants