diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 63aa3e28e2c..28bde0747d7 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -129,6 +129,31 @@ def _record_cos_and_sin_cache(cos_cache, sin_cache):
     _sin_cache = sin_cache
 
 
+def _record_cos_and_sin_cache_interleaved(cos_sin_cache):
+    """Fill the module-level cos/sin caches from a fused cos_sin_cache.
+
+    The last dimension of ``cos_sin_cache`` is split into two equal halves;
+    each half is tiled twice along that dimension, so a cached row is as wide
+    as an input row. The first half goes to ``_cos_cache`` and the second to
+    ``_sin_cache``. Nothing happens if either global is already set.
+
+    Args:
+        cos_sin_cache: Tensor with an even-sized last dimension. Presumably
+            cos values followed by sin values (TODO confirm with the code
+            that builds it).
+    """
+    global _cos_cache
+    global _sin_cache
+    if _cos_cache is not None or _sin_cache is not None:
+        return
+    hidden_dim = cos_sin_cache.shape[-1] // 2
+    # View as (N, 2, h), tile the last dim to (N, 2, 2h), split along dim 1.
+    cos_cache, sin_cache = cos_sin_cache.view(-1, 2, hidden_dim).repeat(
+        1, 1, 2).chunk(2, dim=1)
+    _cos_cache = cos_cache.squeeze(1)
+    _sin_cache = sin_cache.squeeze(1)
+
+
 def update_cos_sin(positions):
     global _cos
     global _sin
@@ -252,6 +277,7 @@ def __init__(
         super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                          is_neox_style, dtype)
         _record_cos_sin_cache(self.cos_sin_cache)
+        _record_cos_and_sin_cache_interleaved(self.cos_sin_cache)
 
     def forward_oot(
         self,