Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions vllm_ascend/ops/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,18 @@ def _record_cos_and_sin_cache(cos_cache, sin_cache):
_sin_cache = sin_cache


def _record_cos_and_sin_cache_interleaved(cos_sin_cache):
global _cos_cache
global _sin_cache
if _cos_cache is not None or _sin_cache is not None:
return
hidden_dim = cos_sin_cache.shape[-1] // 2
cos_cache, sin_cache = cos_sin_cache.view(-1, 2, hidden_dim).repeat(
1, 1, 2).chunk(2, dim=1)
_cos_cache = cos_cache.squeeze(1)
_sin_cache = sin_cache.squeeze(1)
Comment on lines +132 to +141
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The implementation of this function is overly complex and hard to follow. It can be significantly simplified for better readability and maintainability. The current logic is equivalent to splitting the cos_sin_cache and repeating each part, which can be expressed more directly using chunk and repeat.

Additionally, the function name _record_cos_and_sin_cache_interleaved is misleading because the cos_sin_cache tensor is concatenated ([cos, sin]), not interleaved ([c0, s0, c1, s1, ...]). Consider renaming it to something like _record_cos_and_sin_cache_from_combined to more accurately reflect its purpose.

Suggested change
def _record_cos_and_sin_cache_interleaved(cos_sin_cache):
global _cos_cache
global _sin_cache
if _cos_cache is not None or _sin_cache is not None:
return
hidden_dim = cos_sin_cache.shape[-1] // 2
cos_cache, sin_cache = cos_sin_cache.view(-1, 2, hidden_dim).repeat(
1, 1, 2).chunk(2, dim=1)
_cos_cache = cos_cache.squeeze(1)
_sin_cache = sin_cache.squeeze(1)
def _record_cos_and_sin_cache_interleaved(cos_sin_cache):
global _cos_cache
global _sin_cache
if _cos_cache is not None or _sin_cache is not None:
return
# cos_sin_cache is concatenated from cos and sin, each of size rotary_dim/2.
cos_part, sin_part = cos_sin_cache.chunk(2, dim=-1)
# For neox style, cos and sin are duplicated to match rotary_dim.
_cos_cache = cos_part.repeat(1, 2)
_sin_cache = sin_part.repeat(1, 2)



def update_cos_sin(positions):
global _cos
global _sin
Expand Down Expand Up @@ -252,6 +264,7 @@ def __init__(
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
_record_cos_sin_cache(self.cos_sin_cache)
_record_cos_and_sin_cache_interleaved(self.cos_sin_cache)

def forward_oot(
self,
Expand Down
Loading