diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py
index 7db074bf9205..0b39a4000126 100644
--- a/vllm/_xpu_ops.py
+++ b/vllm/_xpu_ops.py
@@ -92,6 +92,72 @@ def _int4_gemm_w4a16_fake(
     return torch.empty((M, N), dtype=input.dtype, device=input.device)
 
 
+def _gdn_attention_core_xpu_impl(
+    core_attn_out: torch.Tensor,
+    z: torch.Tensor,
+    projected_states_qkvz: torch.Tensor,
+    projected_states_ba: torch.Tensor,
+    layer_name: str,
+) -> None:
+    """Custom op wrapping the XPU SYCL GDN kernel for torch.compile."""
+    from vllm.forward_context import get_forward_context
+    from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
+
+    forward_context = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    attn_metadata_raw = forward_context.attn_metadata
+
+    if attn_metadata_raw is None:
+        return
+
+    assert isinstance(attn_metadata_raw, dict)
+    attn_metadata = attn_metadata_raw[self.prefix]
+    assert isinstance(attn_metadata, GDNAttentionMetadata)
+
+    # TODO: xpu does not support speculative decoding yet
+    assert attn_metadata.spec_sequence_masks is None  # type: ignore[attr-defined]
+
+    conv_weights = self.conv1d.weight.view(
+        self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+    )
+
+    torch.ops._xpu_C.gdn_attention(
+        core_attn_out,
+        z,
+        projected_states_qkvz,
+        projected_states_ba,
+        self.num_k_heads,
+        self.num_v_heads,
+        self.head_k_dim,
+        self.head_v_dim,
+        conv_state=self.kv_cache[0],
+        ssm_state=self.kv_cache[1],
+        conv_weights=conv_weights,
+        conv_bias=self.conv1d.bias,
+        activation=self.activation,
+        A_log=self.A_log,
+        dt_bias=self.dt_bias,
+        num_prefills=attn_metadata.num_prefills,  # type: ignore[attr-defined]
+        num_decodes=attn_metadata.num_decodes,  # type: ignore[attr-defined]
+        has_initial_state=attn_metadata.has_initial_state,  # type: ignore[attr-defined]
+        non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc,  # type: ignore[attr-defined]
+        non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor,  # type: ignore[attr-defined]
+        num_actual_tokens=attn_metadata.num_actual_tokens,  # type: ignore[attr-defined]
+        tp_size=self.tp_size,
+        reorder_input=not self.gqa_interleaved_layout,
+    )
+
+
+def _gdn_attention_core_xpu_fake(
+    core_attn_out: torch.Tensor,
+    z: torch.Tensor,
+    projected_states_qkvz: torch.Tensor,
+    projected_states_ba: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
 def _xpu_ops_deepseek_scaling_rope_impl(
     positions: torch.Tensor,
     query: torch.Tensor,
@@ -618,6 +684,13 @@ def register_ops_once() -> None:
         fake_impl=_xpu_mxfp4_quantize_fake,
     )
 
+    direct_register_custom_op(
+        op_name="gdn_attention_core_xpu",
+        op_func=_gdn_attention_core_xpu_impl,
+        mutates_args=["core_attn_out", "z"],
+        fake_impl=_gdn_attention_core_xpu_fake,
+    )
+
     _OPS_REGISTERED = True
 
 
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index f7483db52a42..5b726899c2f5 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -744,6 +744,7 @@ class CompilationConfig:
         "vllm::linear_attention",
         "vllm::plamo2_mamba_mixer",
         "vllm::gdn_attention_core",
+        "vllm::gdn_attention_core_xpu",
         "vllm::olmo_hybrid_gdn_full_forward",
         "vllm::kda_attention",
         "vllm::sparse_attn_indexer",
diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
index 7a0b54335baa..a621ab962f0a 100644
--- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -618,54 +618,20 @@ def forward_xpu(
 
         # ============================================================
         # Part 2: Core Attention
         # ============================================================
-        forward_context = get_forward_context()
-        attn_metadata_raw = forward_context.attn_metadata
         core_attn_out = torch.zeros(
             (num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim),
             dtype=hidden_states.dtype,
             device=hidden_states.device,
         )
         z = torch.empty_like(core_attn_out)
-        if attn_metadata_raw is not None:
-            assert isinstance(attn_metadata_raw, dict)
-            attn_metadata = attn_metadata_raw[self.prefix]
-
-            # TODO: xpu does not support this param yet
-            spec_sequence_masks = attn_metadata.spec_sequence_masks  # type: ignore[attr-defined]
-            assert spec_sequence_masks is None
-            conv_weights = self.conv1d.weight.view(
-                self.conv1d.weight.size(0), self.conv1d.weight.size(2)
-            )
-
-            conv_state = self.kv_cache[0]
-            ssm_state = self.kv_cache[1]
-
-            torch.ops._xpu_C.gdn_attention(
-                core_attn_out,
-                z,
-                projected_states_qkvz,
-                projected_states_ba,
-                self.num_k_heads,
-                self.num_v_heads,
-                self.head_k_dim,
-                self.head_v_dim,
-                conv_state=conv_state,
-                ssm_state=ssm_state,
-                conv_weights=conv_weights,
-                conv_bias=self.conv1d.bias,
-                activation=self.activation,
-                A_log=self.A_log,
-                dt_bias=self.dt_bias,
-                num_prefills=attn_metadata.num_prefills,  # type: ignore[attr-defined]
-                num_decodes=attn_metadata.num_decodes,  # type: ignore[attr-defined]
-                has_initial_state=attn_metadata.has_initial_state,  # type: ignore[attr-defined]
-                non_spec_query_start_loc=attn_metadata.non_spec_query_start_loc,  # type: ignore[attr-defined]
-                non_spec_state_indices_tensor=attn_metadata.non_spec_state_indices_tensor,  # type: ignore[attr-defined]
-                num_actual_tokens=attn_metadata.num_actual_tokens,  # type: ignore[attr-defined]
-                tp_size=self.tp_size,
-                reorder_input=not self.gqa_interleaved_layout,
-            )
+        torch.ops.vllm.gdn_attention_core_xpu(
+            core_attn_out,
+            z,
+            projected_states_qkvz,
+            projected_states_ba,
+            self.prefix,
+        )
 
         # ============================================================
         # Part 3: Output Projection
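
Note: the sketch below is a minimal, self-contained illustration of the impl/fake
custom-op pattern that _gdn_attention_core_xpu_impl and _gdn_attention_core_xpu_fake
follow above: an eager implementation that writes into preallocated output buffers,
plus a no-op "fake" (meta) implementation so torch.compile can trace the call
without running the kernel. It uses plain torch.library (PyTorch >= 2.4) rather
than vLLM's direct_register_custom_op helper, and the op name demo::inplace_scale
and its behavior are illustrative assumptions, not vLLM code.

    import torch

    # Eager implementation: fills the preallocated buffer in place, mirroring
    # how gdn_attention_core_xpu writes into core_attn_out and z.
    @torch.library.custom_op("demo::inplace_scale", mutates_args=("out",))
    def inplace_scale(out: torch.Tensor, x: torch.Tensor, scale: float) -> None:
        out.copy_(x * scale)

    # Fake (meta) implementation: no computation and no return value, since the
    # op communicates only through its mutated arguments -- the same shape as
    # _gdn_attention_core_xpu_fake.
    @inplace_scale.register_fake
    def _(out: torch.Tensor, x: torch.Tensor, scale: float) -> None:
        return

    x = torch.randn(4)
    out = torch.empty_like(x)
    torch.ops.demo.inplace_scale(out, x, 2.0)  # dispatches via torch.ops, as in the diff

As with sibling entries such as vllm::gdn_attention_core and vllm::linear_attention,
the new op name is also added to the op list in CompilationConfig so the compiled
graph is split around calls to it, mirroring how the other GDN/linear-attention ops
are treated; the op body can then reach into get_forward_context() at runtime
without being traced.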