Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions vllm/_xpu_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,72 @@ def _int4_gemm_w4a16_fake(
return torch.empty((M, N), dtype=input.dtype, device=input.device)


def _gdn_attention_core_xpu_impl(
core_attn_out: torch.Tensor,
z: torch.Tensor,
projected_states_qkvz: torch.Tensor,
projected_states_ba: torch.Tensor,
layer_name: str,
) -> None:
"""Custom op wrapping the XPU SYCL GDN kernel for torch.compile."""
from vllm.forward_context import get_forward_context
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata

forward_context = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
attn_metadata_raw = forward_context.attn_metadata

if attn_metadata_raw is None:
return

assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)

# TODO: xpu does not support speculative decoding yet
assert attn_metadata.spec_sequence_masks is None # type: ignore[attr-defined]

conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)

torch.ops._xpu_C.gdn_attention(
core_attn_out,
z,
projected_states_qkvz,
projected_states_ba,
self.num_k_heads,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
conv_state=self.kv_cache[0],
ssm_state=self.kv_cache[1],
conv_weights=conv_weights,
conv_bias=self.conv1d.bias,
activation=self.activation,
A_log=self.A_log,
dt_bias=self.dt_bias,
num_prefills=attn_metadata.num_prefills, # type: ignore[attr-defined]
num_decodes=attn_metadata.num_decodes, # type: ignore[attr-defined]
has_initial_state=attn_metadata.has_initial_state, # type: ignore[attr-defined]
non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc, # type: ignore[attr-defined]
non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor, # type: ignore[attr-defined]
num_actual_tokens=attn_metadata.num_actual_tokens, # type: ignore[attr-defined]
tp_size=self.tp_size,
reorder_input=not self.gqa_interleaved_layout,
)


def _gdn_attention_core_xpu_fake(
core_attn_out: torch.Tensor,
z: torch.Tensor,
projected_states_qkvz: torch.Tensor,
projected_states_ba: torch.Tensor,
layer_name: str,
) -> None:
return


def _xpu_ops_deepseek_scaling_rope_impl(
positions: torch.Tensor,
query: torch.Tensor,
Expand Down Expand Up @@ -618,6 +684,13 @@ def register_ops_once() -> None:
fake_impl=_xpu_mxfp4_quantize_fake,
)

direct_register_custom_op(
op_name="gdn_attention_core_xpu",
op_func=_gdn_attention_core_xpu_impl,
mutates_args=["core_attn_out", "z"],
fake_impl=_gdn_attention_core_xpu_fake,
)

_OPS_REGISTERED = True


Expand Down
1 change: 1 addition & 0 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,7 @@ class CompilationConfig:
"vllm::linear_attention",
"vllm::plamo2_mamba_mixer",
"vllm::gdn_attention_core",
"vllm::gdn_attention_core_xpu",
"vllm::olmo_hybrid_gdn_full_forward",
"vllm::kda_attention",
"vllm::sparse_attn_indexer",
Expand Down
48 changes: 7 additions & 41 deletions vllm/model_executor/layers/mamba/gdn_linear_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,54 +618,20 @@ def forward_xpu(
# ============================================================
# Part 2: Core Attention
# ============================================================
forward_context = get_forward_context()
attn_metadata_raw = forward_context.attn_metadata
core_attn_out = torch.zeros(
(num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
z = torch.empty_like(core_attn_out)
if attn_metadata_raw is not None:
assert isinstance(attn_metadata_raw, dict)
attn_metadata = attn_metadata_raw[self.prefix]

# TODO: xpu does not support this param yet
spec_sequence_masks = attn_metadata.spec_sequence_masks # type: ignore[attr-defined]
assert spec_sequence_masks is None

conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)

conv_state = self.kv_cache[0]
ssm_state = self.kv_cache[1]

torch.ops._xpu_C.gdn_attention(
core_attn_out,
z,
projected_states_qkvz,
projected_states_ba,
self.num_k_heads,
self.num_v_heads,
self.head_k_dim,
self.head_v_dim,
conv_state=conv_state,
ssm_state=ssm_state,
conv_weights=conv_weights,
conv_bias=self.conv1d.bias,
activation=self.activation,
A_log=self.A_log,
dt_bias=self.dt_bias,
num_prefills=attn_metadata.num_prefills, # type: ignore[attr-defined]
num_decodes=attn_metadata.num_decodes, # type: ignore[attr-defined]
has_initial_state=attn_metadata.has_initial_state, # type: ignore[attr-defined]
non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc, # type: ignore[attr-defined]
non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor, # type: ignore[attr-defined]
num_actual_tokens=attn_metadata.num_actual_tokens, # type: ignore[attr-defined]
tp_size=self.tp_size,
reorder_input=not self.gqa_interleaved_layout,
)
torch.ops.vllm.gdn_attention_core_xpu(
core_attn_out,
z,
projected_states_qkvz,
projected_states_ba,
self.prefix,
)
Comment on lines +628 to +634
Copy link

Copilot AI Apr 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.ops.vllm.gdn_attention_core_xpu is registered in vllm/_xpu_ops.py, but this module is not imported anywhere in this file. If the runtime path that hits forward_xpu hasn’t already imported vllm._xpu_ops, the operator won’t exist and this call will raise at runtime. Consider importing vllm._xpu_ops (or from vllm._xpu_ops import xpu_ops to trigger registration) under the XPU path before invoking the op, or registering the op alongside gdn_attention_core in this module to make registration unconditionally happen when this layer is imported.

Copilot uses AI. Check for mistakes.

# ============================================================
# Part 3: Output Projection
Expand Down
Loading