diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py
index b873bfa7f024..91f5e0290583 100644
--- a/vllm/_xpu_ops.py
+++ b/vllm/_xpu_ops.py
@@ -8,6 +8,7 @@
 
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils.torch_utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 
@@ -54,6 +55,37 @@ def _int4_gemm_w4a16_fake(
     return torch.empty((M, N), dtype=input.dtype, device=input.device)
 
 
+def _xpu_ops_deepseek_scaling_rope_impl(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor | None,
+    offsets: torch.Tensor | None,
+    cos_sin_cache: torch.Tensor | None,
+    rotary_dim: int,
+    is_neox_style: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert key is not None
+    return torch.ops._xpu_C.deepseek_scaling_rope(
+        positions, query, key, offsets, cos_sin_cache, rotary_dim, is_neox_style
+    )
+
+
+def _xpu_ops_deepseek_scaling_rope_fake(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor | None,
+    offsets: torch.Tensor | None,
+    cos_sin_cache: torch.Tensor | None,
+    rotary_dim: int,
+    is_neox_style: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return query, key
+
+
+# Global flag to ensure ops are registered only once
+_OPS_REGISTERED = False
+
+
 class xpu_ops:
     @staticmethod
     def flash_attn_varlen_func(
@@ -402,3 +434,21 @@ def top_k_per_row_decode(
         raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
             topk_indices
         )
+
+    @staticmethod
+    def register_ops_once() -> None:
+        global _OPS_REGISTERED
+        if not _OPS_REGISTERED:
+            # register all the custom ops here
+            direct_register_custom_op(
+                op_name="xpu_ops_deepseek_scaling_rope",
+                op_func=_xpu_ops_deepseek_scaling_rope_impl,
+                mutates_args=[],
+                fake_impl=_xpu_ops_deepseek_scaling_rope_fake,
+                dispatch_key=current_platform.dispatch_key,
+            )
+
+            _OPS_REGISTERED = True
+
+
+xpu_ops.register_ops_once()
diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
index c3abdc1563b1..69c1101664d0 100644
--- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
+++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
@@ -152,6 +152,23 @@ def forward_native(
             key = key_rot
         return query, key
 
+    def forward_xpu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor | None = None,
+        offsets: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+        return torch.ops.vllm.xpu_ops_deepseek_scaling_rope(
+            positions,
+            query,
+            key,
+            offsets,
+            self._match_cos_sin_cache_dtype(query),
+            self.rotary_dim,
+            self.is_neox_style,
+        )
+
     def forward_hip(
         self,
         positions: torch.Tensor,
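
For context, below is a minimal sketch of the real/fake registration pattern that the first hunk relies on via `direct_register_custom_op`, applied to a toy op so it can run without an XPU build. The op name `toy_scale`, the `"CPU"` dispatch key, and the example tensors are assumptions for illustration only; they are not part of this change.

```python
import torch

from vllm.utils.torch_utils import direct_register_custom_op


def _toy_scale_impl(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real kernel: runs when the op is dispatched on the chosen backend.
    return x * factor


def _toy_scale_fake(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake (meta) kernel: produces only shape/dtype information for tracing,
    # mirroring the role of _xpu_ops_deepseek_scaling_rope_fake above.
    return torch.empty_like(x)


# Registers the op under the vllm namespace, making it callable as
# torch.ops.vllm.toy_scale, just like xpu_ops_deepseek_scaling_rope.
direct_register_custom_op(
    op_name="toy_scale",  # hypothetical op name
    op_func=_toy_scale_impl,
    mutates_args=[],
    fake_impl=_toy_scale_fake,
    dispatch_key="CPU",  # assumed here; the diff uses current_platform.dispatch_key
)

out = torch.ops.vllm.toy_scale(torch.ones(3), 2.0)
print(out)  # tensor([2., 2., 2.])
```

The fake implementation matters because graph capture (e.g. `torch.compile`) traces the op without executing the XPU kernel; it only needs tensors of the right shape and dtype, which is why the diff's fake simply returns `query, key` unchanged.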