[v0.13.0][Ops] Make triton rope support index_selecting from cos_sin_cache#6602
Conversation
Signed-off-by: Angazenn <supperccell@163.com>
Summary of Changes (Gemini Code Assist): This pull request introduces a significant refactoring of how Rotary Positional Embeddings (RoPE) are handled within the system. The core change involves transitioning from separate global `sin` and `cos` caches to a unified `cos_sin_cache` indexed by a `positions` tensor.
Activity
Code Review
This pull request refactors the Rotary Position Embedding (RoPE) implementation to use a unified cos_sin_cache and positions tensor, moving away from global sin and cos caches. This is a good change for code clarity and maintainability. The PR introduces a new Triton kernel and updates fusion passes accordingly. I've found a couple of critical syntax errors that will prevent the code from running. Please see the detailed comments. Additionally, the PR title and description should be updated to follow the repository's style guide.
Suggested PR Title:
[main][Attention][Feature] Refactor RoPE to use cos_sin_cache

Suggested PR Summary:
### What this PR does / why we need it?
This PR refactors the Rotary Position Embedding (RoPE) implementation to use a unified `cos_sin_cache` and `positions` tensor. This change simplifies the RoPE logic by removing global `sin` and `cos` caches and making data dependencies explicit. It introduces a new Triton kernel `rope_forward_triton_with_positions` and updates fusion passes to leverage this new implementation.
This refactoring improves code clarity, maintainability, and is a prerequisite for further performance optimizations by enabling more fusion opportunities.
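To make the unified-cache design concrete, here is a minimal, hedged sketch of a non-fused reference RoPE that consumes a `cos_sin_cache` and a `positions` tensor. The function name `rope_ref` and the cache layout (first half cos, second half sin per row, as in vLLM's `RotaryEmbedding`) are illustrative assumptions, not the PR's exact API:

```python
import torch

def rope_ref(positions: torch.Tensor, q: torch.Tensor,
             cos_sin_cache: torch.Tensor, rotary_dim: int) -> torch.Tensor:
    """Reference (non-fused) neox-style RoPE using a unified cos_sin_cache.

    positions:     [num_tokens] long tensor of token positions
    q:             [num_tokens, num_heads, head_size]
    cos_sin_cache: [max_position, rotary_dim], row = [cos | sin]

    Hypothetical sketch: names/layout assumed, not taken from the PR.
    """
    # Gather per-token cos/sin rows by position -- this is the index_select
    # that the fused Triton kernel replaces with direct in-kernel loads.
    cos_sin = cos_sin_cache.index_select(0, positions)   # [T, rotary_dim]
    cos, sin = cos_sin.chunk(2, dim=-1)                  # [T, rotary_dim//2] each
    cos = cos.unsqueeze(1)                               # broadcast over heads
    sin = sin.unsqueeze(1)

    x1 = q[..., : rotary_dim // 2]
    x2 = q[..., rotary_dim // 2 : rotary_dim]
    out = q.clone()
    out[..., : rotary_dim // 2] = x1 * cos - x2 * sin
    out[..., rotary_dim // 2 : rotary_dim] = x2 * cos + x1 * sin
    return out
```

A fused kernel would perform the same per-token cache lookup inside the kernel body rather than materializing the gathered rows first.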
### Does this PR introduce _any_ user-facing change?
No, this is a refactoring of the internal implementation and does not introduce any user-facing changes.
### How was this patch tested?
CI should pass with existing and any new unit tests. It is recommended to add specific unit tests for the new Triton kernels and to verify the correctness of the `qknorm_rope_fusion_pass`.

Signed-off-by: Angazenn <supperccell@163.com>
Signed-off-by: Angazenn <supperccell@163.com>
…s not supported (#6749)

### What this PR does / why we need it?
With #6602, `npu_rotary_embedding` unifies all RoPE implementations in AscendRotaryEmbedding and allows a wider range of applications of the fusion op `split_qkv_rmsnorm_rope`. This PR restricts the fusion of `split_qkv_rmsnorm_rope` to only cases where `head_size` == 128 && `rotary_dim` == `head_size`. Further enhancement and generalization of this op will be accomplished by @whx-sjtu.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Signed-off-by: Angazenn <supperccell@163.com>
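The shape restriction described in #6749 amounts to a simple eligibility guard in the fusion pass. A minimal sketch, with a hypothetical function name (the actual pass structure in the repository may differ):

```python
def can_fuse_split_qkv_rmsnorm_rope(head_size: int, rotary_dim: int) -> bool:
    """Illustrative guard: after #6749 the fused split_qkv_rmsnorm_rope path
    is taken only for full-rotary, 128-dim heads; all other shapes fall back
    to the unfused ops. Name and placement are assumptions for illustration.
    """
    return head_size == 128 and rotary_dim == head_size
```

Models with partial rotary embeddings (`rotary_dim < head_size`) or other head sizes would therefore skip the fused path until the op is generalized.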
What this PR does / why we need it?
This PR adapts #5450 and #6523 to v0.13.0, to fix #6612.
This PR extends the original `rope_triton_forward` and `split_qkv_rmsnorm_rope` to support `cos_sin_cache` && `positions` as inputs. This fully aligns with the vLLM RoPE API interface. Compared with the earlier RoPE implementation, the benefit is that there is no need to generate `cos` and `sin` before model execution, which helps remove redundant code. In addition, this kernel change introduces only a very small performance degradation: the `index_select` or `chunk` operations are now replaced by simple memory accesses inside the Triton kernel.

Highlights
This PR registers `rope_forward_oot` as a new custom operation, allowing its use in fused compilation passes and providing a dedicated entry point for the new RoPE implementation.

Does this PR introduce any user-facing change?
How was this patch tested?
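The performance claim above rests on `index_select`/`chunk` gathers and direct indexed loads being equivalent, which is also the property a unit test for the new kernel would check. A small self-contained sketch of that equivalence (the real test would call the Triton kernel; here plain advanced indexing stands in for the in-kernel `tl.load`, and all names are illustrative):

```python
import torch

def gather_cos_sin_index_select(cache: torch.Tensor,
                                positions: torch.Tensor) -> torch.Tensor:
    # Pre-kernel approach: materialize per-token rows via index_select.
    return cache.index_select(0, positions)

def gather_cos_sin_direct(cache: torch.Tensor,
                          positions: torch.Tensor) -> torch.Tensor:
    # What the Triton kernel does per program: load cache[pos] directly;
    # advanced indexing stands in for tl.load in this sketch.
    return cache[positions]

positions = torch.tensor([3, 0, 2, 2])        # duplicates are allowed
cache = torch.randn(8, 128)                   # [max_position, rotary_dim]
assert torch.equal(gather_cos_sin_index_select(cache, positions),
                   gather_cos_sin_direct(cache, positions))
```

A correctness test for the kernel itself would follow the same shape: run the Triton path and a plain PyTorch reference on identical inputs and assert elementwise closeness.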