diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index c5c285ca0fc..0b68a21919b 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -14,6 +14,8 @@
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+else:
+    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -82,12 +84,6 @@ def __init__(
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
         if not _is_cuda:
             cache = cache.to(dtype)
-
-        if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
-            from vllm._custom_ops import rotary_embedding
-
-            self.vllm_rotary_embedding = rotary_embedding
-
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -164,7 +160,7 @@ def forward_cuda(
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            self.vllm_rotary_embedding(
+            vllm_rotary_embedding(
                 positions,
                 query,
                 key,
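
The diff replaces the per-instance `self.vllm_rotary_embedding` attribute (bound via a deferred import inside `__init__`) with a single module-level `if/else` import, so the fallback op is resolved once at import time and called directly in `forward_cuda`. Below is a minimal sketch of that pattern under stated assumptions; all names in it (`_fused_rope`, `_fallback_rope`, `_apply_rope`, `SimpleRotaryEmbedding`) are hypothetical stand-ins, not the sglang or vllm APIs.

```python
import torch


def _fused_rope(positions, query, key):
    # Stand-in for the CUDA fast path (apply_rope_with_cos_sin_cache_inplace).
    return query, key


def _fallback_rope(positions, query, key):
    # Stand-in for the fallback path (vllm_rotary_embedding in the diff).
    return query, key


# Backend chosen once when the module is imported, mirroring the new
# `if _is_cuda: ... else: ...` block at the top of rotary_embedding.py.
_apply_rope = _fused_rope if torch.cuda.is_available() else _fallback_rope


class SimpleRotaryEmbedding(torch.nn.Module):
    def forward(self, positions, query, key):
        # The forward pass calls the module-level name directly; no
        # per-instance attribute has to be created in __init__.
        return _apply_rope(positions, query, key)
```

Binding the fallback at module scope keeps `__init__` free of conditional imports and makes the non-CUDA dependency on vllm explicit at the top of the file.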