From f69ed125963e6ac34eef5ef91179d992bbbf53bd Mon Sep 17 00:00:00 2001 From: yitingw1 Date: Tue, 10 Mar 2026 01:24:40 -0700 Subject: [PATCH 1/3] [XPU] Add deepseek_scaling_rope fused kernel Signed-off-by: yitingw1 --- vllm/_xpu_ops.py | 50 +++++++++++++++++++ .../rotary_embedding/deepseek_scaling_rope.py | 17 +++++++ 2 files changed, 67 insertions(+) diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py index 1f64aacd421a..42a6b8dd2261 100644 --- a/vllm/_xpu_ops.py +++ b/vllm/_xpu_ops.py @@ -7,6 +7,8 @@ from vllm_xpu_kernels.flash_attn_interface import flash_attn_varlen_func from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -53,6 +55,36 @@ def _int4_gemm_w4a16_fake( return torch.empty((M, N), dtype=input.dtype, device=input.device) +def _xpu_ops_deepseek_scaling_rope_impl( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None, + offsets: torch.Tensor | None, + cos_sin_cache: torch.Tensor | None, + rotary_dim: int, + is_neox_style: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.ops._xpu_C.deepseek_scaling_rope( + positions, query, key, offsets, cos_sin_cache, rotary_dim, is_neox_style + ) + + +def _xpu_ops_deepseek_scaling_rope_fake( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None, + offsets: torch.Tensor | None, + cos_sin_cache: torch.Tensor | None, + rotary_dim: int, + is_neox_style: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + return query, key + + +# Global flag to ensure ops are registered only once +_OPS_REGISTERED = False + + class xpu_ops: @staticmethod def flash_attn_varlen_func( @@ -157,3 +189,21 @@ def get_scheduler_metadata( "get_scheduler_metadata is not implemented for xpu_ops, returning None." ) return None + + @staticmethod + def register_ops_once() -> None: + global _OPS_REGISTERED + if not _OPS_REGISTERED: + # register all the custom ops here + direct_register_custom_op( + op_name="xpu_ops_deepseek_scaling_rope", + op_func=_xpu_ops_deepseek_scaling_rope_impl, + mutates_args=[], + fake_impl=_xpu_ops_deepseek_scaling_rope_fake, + dispatch_key=current_platform.dispatch_key, + ) + + _OPS_REGISTERED = True + + +xpu_ops.register_ops_once() diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index c3abdc1563b1..1e4fdcdf79f9 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -152,6 +152,23 @@ def forward_native( key = key_rot return query, key + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor | None = None, + offsets: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + return torch.ops.vllm.xpu_ops_deepseek_scaling_rope( + positions, + query, + key, + offsets, + self.cos_sin_cache, + self.rotary_dim, + self.is_neox_style, + ) + def forward_hip( self, positions: torch.Tensor, From aef8de4fd27d9fe29eb9a5477c4aa6527991f620 Mon Sep 17 00:00:00 2001 From: yitingw1 Date: Tue, 10 Mar 2026 19:10:47 -0700 Subject: [PATCH 2/3] Fix Signed-off-by: yitingw1 --- vllm/_xpu_ops.py | 6 +++--- .../layers/rotary_embedding/deepseek_scaling_rope.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py index 42a6b8dd2261..febb8b596e2c 100644 --- a/vllm/_xpu_ops.py +++ b/vllm/_xpu_ops.py @@ -63,7 +63,7 @@ def _xpu_ops_deepseek_scaling_rope_impl( cos_sin_cache: torch.Tensor | None, rotary_dim: int, is_neox_style: bool, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor | None]: return torch.ops._xpu_C.deepseek_scaling_rope( positions, query, key, offsets, cos_sin_cache, rotary_dim, is_neox_style ) @@ -77,7 +77,7 @@ def _xpu_ops_deepseek_scaling_rope_fake( cos_sin_cache: torch.Tensor | None, rotary_dim: int, is_neox_style: bool, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor | None]: return query, key @@ -203,7 +203,7 @@ def register_ops_once() -> None: dispatch_key=current_platform.dispatch_key, ) - _OPS_REGISTERED = True + _OPS_REGISTERED = True xpu_ops.register_ops_once() diff --git a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py index 1e4fdcdf79f9..69c1101664d0 100644 --- a/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +++ b/vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py @@ -164,7 +164,7 @@ def forward_xpu( query, key, offsets, - self.cos_sin_cache, + self._match_cos_sin_cache_dtype(query), self.rotary_dim, self.is_neox_style, ) From 28048e94ab21b999fed768aadeec15fb5cfe9f71 Mon Sep 17 00:00:00 2001 From: yitingw1 Date: Thu, 12 Mar 2026 00:41:02 -0700 Subject: [PATCH 3/3] fix2 Signed-off-by: yitingw1 --- vllm/_xpu_ops.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py index f638d9b03108..91f5e0290583 100644 --- a/vllm/_xpu_ops.py +++ b/vllm/_xpu_ops.py @@ -63,7 +63,8 @@ def _xpu_ops_deepseek_scaling_rope_impl( cos_sin_cache: torch.Tensor | None, rotary_dim: int, is_neox_style: bool, -) -> tuple[torch.Tensor, torch.Tensor | None]: +) -> tuple[torch.Tensor, torch.Tensor]: + assert key is not None return torch.ops._xpu_C.deepseek_scaling_rope( positions, query, key, offsets, cos_sin_cache, rotary_dim, is_neox_style ) @@ -77,7 +78,7 @@ def _xpu_ops_deepseek_scaling_rope_fake( cos_sin_cache: torch.Tensor | None, rotary_dim: int, is_neox_style: bool, -) -> tuple[torch.Tensor, torch.Tensor | None]: +) -> tuple[torch.Tensor, torch.Tensor]: return query, key