Skip to content

[Fusion] [Graph] Add qknorm rope fusion operator#4711

Merged
wangxiyuan merged 40 commits intovllm-project:mainfrom
wxsIcey:qknorm_rope_fusion
Dec 17, 2025
Merged

[Fusion] [Graph] Add qknorm rope fusion operator#4711
wangxiyuan merged 40 commits intovllm-project:mainfrom
wxsIcey:qknorm_rope_fusion

Conversation

@wxsIcey
Copy link
Copy Markdown
Collaborator

@wxsIcey wxsIcey commented Dec 4, 2025

What this PR does / why we need it?

This PR add qkv_rmsnorm_rope operator and introduces a graph fusion pass for qknorm_rope operations. The implementation includes a new configuration flag, a pattern matching pass using torch._inductor.pattern_matcher, and a custom Triton kernel for the fused operation.

Co-authored-by: Angazenn supperccell@163.com

Does this PR introduce any user-facing change?

Yes, add new additional_config

How was this patch tested?

local test

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a graph fusion pass for qknorm_rope operations on Ascend hardware, which is a great step for performance optimization. The implementation includes a new configuration flag, a pattern matching pass using torch._inductor.pattern_matcher, and a custom Triton kernel for the fused operation. The code is well-structured, but I've identified several areas for improvement regarding code quality, robustness, and maintainability. My review comments focus on removing debug artifacts, improving code clarity and consistency, enhancing robustness by avoiding hardcoded values and unsafe module-level initializations, and addressing significant code duplication.

Comment thread vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py Outdated
Comment thread vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
Comment thread vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
Comment thread vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
Comment thread vllm_ascend/compilation/passes/qknorm_rope_fusion_pass.py
Comment thread vllm_ascend/compilation/passes/qkvnorm_rope_fusion_pass.py Outdated
Comment thread vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Dec 4, 2025

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:‌‌

  • A PR should do only one thing, smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests ‌to ensure it works and is not broken by other future PRs.
  • Write the commit message by fulfilling the PR description to help reviewer and future developers understand.

If CI fails, you can run linting and testing checks locally according Contributing and Testing.

@wxsIcey wxsIcey changed the title [Fusion] [Graph] Add qknorm rope fusion [Fusion] [Graph] Add qknorm rope fusion operator Dec 5, 2025
Comment thread vllm_ascend/worker/model_runner_v1.py Outdated
return q_output, k_output, v_output


direct_register_custom_op(op_name="qkv_rmsnorm_rope",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use import torch_npu._inductor

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The pattern_matcher method of inductor does not support the triton operator. It does support torch.ops.aten (aten operator), torch.ops.npu (custom operator), and torch.add (PyTorch API). Therefore, it is wrapped as a custom op.

return driver.active.utils.get_device_properties(device)


num_vectorcore = get_npu_properties()["num_vectorcore"]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this parameter has already been defined in triton/utils.py

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I have modified it.

@wxsIcey wxsIcey marked this pull request as ready for review December 9, 2025 02:04
@github-actions
Copy link
Copy Markdown
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@wxsIcey wxsIcey added ready read for review ready-for-test start test by label for PR labels Dec 11, 2025
@wxsIcey wxsIcey requested a review from whx-sjtu December 11, 2025 03:02
@github-actions
Copy link
Copy Markdown
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.


return q_rope, k_rope, v

def replacement(qkv: torch.Tensor, q_weight: torch.Tensor,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pattern in 'if xxx else: torch.ops.vllm.qkv_rmsnorm_rope ’ need support in future releases

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't perform any special checks in the pattern. You can add a new pattern match.

@github-actions
Copy link
Copy Markdown
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

@github-actions
Copy link
Copy Markdown
Contributor

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
@wangxiyuan wangxiyuan merged commit cadfa5d into vllm-project:main Dec 17, 2025
25 checks passed
@wangxiyuan
Copy link
Copy Markdown
Collaborator

Additional config doc should be updated. Please submit a follow-up PR. @wxsIcey

chenaoxuan pushed a commit to chenaoxuan/vllm-ascend that referenced this pull request Dec 20, 2025
### What this PR does / why we need it?
This PR add `qkv_rmsnorm_rope` operator and introduces a graph fusion
pass for `qknorm_rope` operations. The implementation includes a new
configuration flag, a pattern matching pass using
`torch._inductor.pattern_matcher`, and a custom Triton kernel for the
fused operation.

Co-authored-by: Angazenn
[supperccell@163.com](mailto:supperccell@163.com)

### Does this PR introduce _any_ user-facing change?
Yes, add new additional_config

- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Feb 28, 2026
### What this PR does / why we need it?
This PR add `qkv_rmsnorm_rope` operator and introduces a graph fusion
pass for `qknorm_rope` operations. The implementation includes a new
configuration flag, a pattern matching pass using
`torch._inductor.pattern_matcher`, and a custom Triton kernel for the
fused operation.

Co-authored-by: Angazenn
[supperccell@163.com](mailto:supperccell@163.com)

### Does this PR introduce _any_ user-facing change?
Yes, add new additional_config

- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
ZRJ026 pushed a commit to ZRJ026/vllm-ascend that referenced this pull request Mar 4, 2026
### What this PR does / why we need it?
This PR add `qkv_rmsnorm_rope` operator and introduces a graph fusion
pass for `qknorm_rope` operations. The implementation includes a new
configuration flag, a pattern matching pass using
`torch._inductor.pattern_matcher`, and a custom Triton kernel for the
fused operation.

Co-authored-by: Angazenn
[supperccell@163.com](mailto:supperccell@163.com)

### Does this PR introduce _any_ user-facing change?
Yes, add new additional_config

- vLLM version: v0.12.0
- vLLM main:
vllm-project/vllm@ad32e3e

---------

Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: zrj026 <zhangrunjiang026@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module:core module:ops ready read for review ready-for-test start test by label for PR

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants