diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 8678aaef364c..9bf82b2d30eb 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -113,18 +113,6 @@ def __init__( if not _is_cuda: cache = cache.to(dtype) - if ( - (not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]) - and not (_is_cpu and _is_cpu_amx_available) - and not (_is_xpu) - ): - from vllm._custom_ops import rotary_embedding - - self.use_fallback_kernel = True - self.fallback_rotary_embedding = rotary_embedding - else: - self.use_fallback_kernel = False - self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) @@ -276,34 +264,20 @@ def forward_cuda( offsets: Optional[torch.Tensor] = None, fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - if not self.use_fallback_kernel: - apply_rope_with_cos_sin_cache_inplace( - positions=positions, - query=query, - key=key, - head_size=self.head_size, - cos_sin_cache=self.cos_sin_cache, - is_neox=self.is_neox_style, - # Compatible with old sgl-kernel - **( - dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg) - if fused_set_kv_buffer_arg is not None - else {} - ), - ) - else: - assert ( - fused_set_kv_buffer_arg is None - ), "save kv cache is not supported for vllm_rotary_embedding." - self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - self.fallback_rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - ) + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=self.head_size, + cos_sin_cache=self.cos_sin_cache, + is_neox=self.is_neox_style, + # Compatible with old sgl-kernel + **( + dict(fused_set_kv_buffer_arg=fused_set_kv_buffer_arg) + if fused_set_kv_buffer_arg is not None + else {} + ), + ) return query, key def extra_repr(self) -> str: