Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions vllm/_xpu_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op

logger = init_logger(__name__)

Expand Down Expand Up @@ -54,6 +55,37 @@ def _int4_gemm_w4a16_fake(
return torch.empty((M, N), dtype=input.dtype, device=input.device)


def _xpu_ops_deepseek_scaling_rope_impl(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None,
    offsets: torch.Tensor | None,
    cos_sin_cache: torch.Tensor | None,
    rotary_dim: int,
    is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Real implementation: apply DeepSeek scaling RoPE via the XPU kernel.

    Thin wrapper over ``torch.ops._xpu_C.deepseek_scaling_rope`` so the op can
    be registered as a vLLM custom op. ``key`` is typed as optional only to
    mirror the caller-facing signature; the underlying kernel requires it.
    """
    # The XPU kernel has no key-less variant, so a missing key is a caller bug.
    assert key is not None
    rotated = torch.ops._xpu_C.deepseek_scaling_rope(
        positions,
        query,
        key,
        offsets,
        cos_sin_cache,
        rotary_dim,
        is_neox_style,
    )
    return rotated


def _xpu_ops_deepseek_scaling_rope_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None,
offsets: torch.Tensor | None,
cos_sin_cache: torch.Tensor | None,
rotary_dim: int,
is_neox_style: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
return query, key


# Module-level guard so xpu_ops.register_ops_once() performs custom-op
# registration at most once per process, no matter how often it is called.
_OPS_REGISTERED = False


class xpu_ops:
@staticmethod
def flash_attn_varlen_func(
Expand Down Expand Up @@ -402,3 +434,21 @@ def top_k_per_row_decode(
raw_topk_indices[: topk_indices.shape[0], : topk_indices.shape[1]] = (
topk_indices
)

@staticmethod
def register_ops_once() -> None:
global _OPS_REGISTERED
if not _OPS_REGISTERED:
# register all the custom ops here
direct_register_custom_op(
op_name="xpu_ops_deepseek_scaling_rope",
op_func=_xpu_ops_deepseek_scaling_rope_impl,
mutates_args=[],
fake_impl=_xpu_ops_deepseek_scaling_rope_fake,
dispatch_key=current_platform.dispatch_key,
)

_OPS_REGISTERED = True


Comment thread
yitingw1 marked this conversation as resolved.
xpu_ops.register_ops_once()
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,23 @@ def forward_native(
key = key_rot
return query, key

def forward_xpu(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
return torch.ops.vllm.xpu_ops_deepseek_scaling_rope(
positions,
query,
key,
offsets,
self._match_cos_sin_cache_dtype(query),
self.rotary_dim,
self.is_neox_style,
)

def forward_hip(
self,
positions: torch.Tensor,
Expand Down
Loading