Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions vllm_ascend/patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,14 @@
# https://github.com/vllm-project/vllm/pull/36225
# Future Plan:
# Remove this patch when vLLM merges the PR.
#
# ** 17. File: worker/patch_qwen3_5.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.qwen3_5.Qwen3_5GatedDeltaNet._forward_core`
# Why:
# The class Qwen3_5GatedDeltaNet reuse the `_forward_core` method of Qwen3NextGatedDeltaNet,
# but the ascendC ops of Qwen3NextGatedDeltaNet do not support ssm_state with float32 format.
# How:
# patch Qwen3_5GatedDeltaNet._forward_core to use triton ops like `fused_recurrent_gated_delta_rule`.
# Future Plan:
# Remove this patch when all ops in _forward_core support both Qwen3_5 and Qwen3Next.
5 changes: 5 additions & 0 deletions vllm_ascend/patch/worker/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@

from vllm.triton_utils import HAS_TRITON

from vllm_ascend.utils import vllm_version_is

if HAS_TRITON:
import vllm_ascend.patch.worker.patch_triton

if not vllm_version_is("v0.16.0"):
import vllm_ascend.patch.worker.patch_qwen3_5 # noqa

# isort: off
import vllm_ascend.patch.platform.patch_sched_yield # noqa
import vllm_ascend.patch.worker.patch_unquantized_gemm # noqa
Expand Down
257 changes: 257 additions & 0 deletions vllm_ascend/patch/worker/patch_qwen3_5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# from collections.abc import Iterable
# mypy: ignore-errors


import torch
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fla.ops import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update
from vllm.model_executor.models.qwen3_5 import Qwen3_5GatedDeltaNet
from vllm.v1.attention.backend import AttentionMetadata # type: ignore
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
from vllm.v1.attention.backends.utils import PAD_SLOT_ID

from vllm_ascend.attention.utils import maybe_save_kv_layer_to_connector
from vllm_ascend.ops.triton.fla.sigmoid_gating import fused_sigmoid_gating_delta_rule_update
from vllm_ascend.ops.triton.fused_gdn_gating import fused_gdn_gating_patch
from vllm_ascend.utils import enable_sp


class AscendQwen3_5GatedDeltaNet(Qwen3_5GatedDeltaNet):
def _forward_core(
self,
mixed_qkv: torch.Tensor,
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
):
# Core attention computation (called by custom op).

# NOTE: The processing logic of Qwen3_5GatedDeltaNet is the same as Qwen3NextGatedDeltaNet.
# However, because the ops `torch_npu.npu_recurrent_gated_delta_rule`
# currently does not support `ssm_state` inputs in float32 format,
# we temporarily retain the current _forward_core implementation.
# Once the ops supports float32 `ssm_state`, this patch should be removed.

forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata

if attn_metadata is None:
# V1 profile run
return

assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)
has_initial_state = attn_metadata.has_initial_state
spec_query_start_loc = attn_metadata.spec_query_start_loc
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
spec_sequence_masks = attn_metadata.spec_sequence_masks
spec_token_indx = attn_metadata.spec_token_indx
non_spec_token_indx = attn_metadata.non_spec_token_indx
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
num_accepted_tokens = attn_metadata.num_accepted_tokens

if not enable_sp():
mixed_qkv = mixed_qkv[:num_actual_tokens]
b = b[:num_actual_tokens]
a = a[:num_actual_tokens]

# 1. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2))
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
mixed_qkv_spec = mixed_qkv
mixed_qkv_non_spec = None
else:
mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
else:
mixed_qkv_spec = None
mixed_qkv_non_spec = mixed_qkv

# 1.1: Process the multi-query part
if spec_sequence_masks is not None:
mixed_qkv_spec = causal_conv1d_update(
mixed_qkv_spec,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=spec_state_indices_tensor[:, 0][: attn_metadata.num_spec_decodes],
num_accepted_tokens=num_accepted_tokens,
query_start_loc=spec_query_start_loc,
max_query_len=spec_state_indices_tensor.size(-1),
validate_data=False,
)

# 1.2: Process the remaining part
if attn_metadata.num_prefills > 0:
if mixed_qkv_non_spec is not None:
conv_weights_T = conv_weights.transpose(0, 1)
mixed_qkv_non_spec = torch.ops._C_ascend.causal_conv1d_fn(
mixed_qkv_non_spec,
conv_weights_T,
self.conv1d.bias,
activation=self.activation,
conv_state=self_kv_cache[0],
has_initial_state=has_initial_state,
non_spec_state_indices_tensor=non_spec_state_indices_tensor,
non_spec_query_start_loc=non_spec_query_start_loc,
pad_slot_id=PAD_SLOT_ID,
)
elif attn_metadata.num_decodes > 0:
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv_non_spec,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=non_spec_state_indices_tensor[: attn_metadata.num_actual_tokens],
validate_data=True,
)
else:
mixed_qkv_non_spec = None
query_spec, key_spec, value_spec = self.rearrange_mixed_qkv(mixed_qkv_spec)
query_non_spec, key_non_spec, value_non_spec = self.rearrange_mixed_qkv(mixed_qkv_non_spec)

if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
if spec_sequence_masks is not None:
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
g_spec = g
beta_spec = beta
g_non_spec = None
beta_non_spec = None
else:
g_spec = g.index_select(1, spec_token_indx)
beta_spec = beta.index_select(1, spec_token_indx)
g_non_spec = g.index_select(1, non_spec_token_indx)
beta_non_spec = beta.index_select(1, non_spec_token_indx)
else:
g_spec = None
beta_spec = None
g_non_spec = g
beta_non_spec = beta

# 2. Recurrent attention

# 2.1: Process the multi-query part
if spec_sequence_masks is not None:
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_spec,
k=key_spec,
v=value_spec,
g=g_spec,
beta=beta_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
ssm_state_indices=spec_state_indices_tensor,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=True,
)
else:
core_attn_out_spec, last_recurrent_state = None, None

# 2.2: Process the remaining part
if attn_metadata.num_prefills > 0:
initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
initial_state[~has_initial_state, ...] = 0
(
core_attn_out_non_spec,
last_recurrent_state,
) = chunk_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=initial_state,
output_final_state=True,
cu_seqlens=non_spec_query_start_loc,
head_first=False,
use_qk_l2norm_in_kernel=True,
)
# Init cache
ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype)
elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
q=query_non_spec,
k=key_non_spec,
v=value_non_spec,
g=g_non_spec,
beta=beta_non_spec,
initial_state=ssm_state,
inplace_final_state=True,
cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
ssm_state_indices=non_spec_state_indices_tensor,
use_qk_l2norm_in_kernel=True,
)
else:
core_attn_out_non_spec, last_recurrent_state = None, None

elif attn_metadata.num_decodes > 0:
core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update(
A_log=self.A_log.contiguous(),
dt_bias=self.dt_bias.contiguous(),
q=query_non_spec.contiguous(),
k=key_non_spec.contiguous(),
v=value_non_spec.contiguous(),
a=a.contiguous(),
b=b.contiguous(),
initial_state_source=ssm_state,
initial_state_indices=non_spec_state_indices_tensor,
cu_seqlens=non_spec_query_start_loc,
use_qk_l2norm_in_kernel=True,
softplus_beta=1.0,
softplus_threshold=20.0,
)

# 3. Merge core attention output
if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
merged_out = torch.empty(
(1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
dtype=core_attn_out_non_spec.dtype,
device=core_attn_out_non_spec.device,
)
merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
if not enable_sp():
core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
else:
core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens]
elif spec_sequence_masks is not None:
if not enable_sp():
core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
else:
core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens]
else:
if not enable_sp():
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
else:
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens]
Comment on lines +138 to +253
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The variable core_attn_out_non_spec is not initialized on all code paths, which can lead to an UnboundLocalError. Specifically, if num_prefills == 0, spec_sequence_masks is None, and num_decodes == 0, the variable is never assigned a value, but it is accessed in the final else block of the merge logic. This is a critical issue that could cause a runtime crash.

        core_attn_out_non_spec = None
        core_attn_out_spec = None
        if attn_metadata.num_prefills > 0 or spec_sequence_masks is not None:
            g, beta = fused_gdn_gating_patch(self.A_log, a, b, self.dt_bias)
            if spec_sequence_masks is not None:
                if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
                    g_spec = g
                    beta_spec = beta
                    g_non_spec = None
                    beta_non_spec = None
                else:
                    g_spec = g.index_select(1, spec_token_indx)
                    beta_spec = beta.index_select(1, spec_token_indx)
                    g_non_spec = g.index_select(1, non_spec_token_indx)
                    beta_non_spec = beta.index_select(1, non_spec_token_indx)
            else:
                g_spec = None
                beta_spec = None
                g_non_spec = g
                beta_non_spec = beta

            # 2. Recurrent attention

            # 2.1: Process the multi-query part
            if spec_sequence_masks is not None:
                core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                    q=query_spec,
                    k=key_spec,
                    v=value_spec,
                    g=g_spec,
                    beta=beta_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
                    cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
                    ssm_state_indices=spec_state_indices_tensor,
                    num_accepted_tokens=num_accepted_tokens,
                    use_qk_l2norm_in_kernel=True,
                )
            else:
                core_attn_out_spec, last_recurrent_state = None, None

            # 2.2: Process the remaining part
            if attn_metadata.num_prefills > 0:
                initial_state = ssm_state[non_spec_state_indices_tensor].contiguous()
                initial_state[~has_initial_state, ...] = 0
                (
                    core_attn_out_non_spec,
                    last_recurrent_state,
                ) = chunk_gated_delta_rule(
                    q=query_non_spec,
                    k=key_non_spec,
                    v=value_non_spec,
                    g=g_non_spec,
                    beta=beta_non_spec,
                    initial_state=initial_state,
                    output_final_state=True,
                    cu_seqlens=non_spec_query_start_loc,
                    head_first=False,
                    use_qk_l2norm_in_kernel=True,
                )
                # Init cache
                ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(ssm_state.dtype)
            elif attn_metadata.num_decodes > 0:
                core_attn_out_non_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
                    q=query_non_spec,
                    k=key_non_spec,
                    v=value_non_spec,
                    g=g_non_spec,
                    beta=beta_non_spec,
                    initial_state=ssm_state,
                    inplace_final_state=True,
                    cu_seqlens=non_spec_query_start_loc[: attn_metadata.num_decodes + 1],
                    ssm_state_indices=non_spec_state_indices_tensor,
                    use_qk_l2norm_in_kernel=True,
                )
            else:
                core_attn_out_non_spec, last_recurrent_state = None, None

        elif attn_metadata.num_decodes > 0:
            core_attn_out_non_spec = fused_sigmoid_gating_delta_rule_update(
                A_log=self.A_log.contiguous(),
                dt_bias=self.dt_bias.contiguous(),
                q=query_non_spec.contiguous(),
                k=key_non_spec.contiguous(),
                v=value_non_spec.contiguous(),
                a=a.contiguous(),
                b=b.contiguous(),
                initial_state_source=ssm_state,
                initial_state_indices=non_spec_state_indices_tensor,
                cu_seqlens=non_spec_query_start_loc,
                use_qk_l2norm_in_kernel=True,
                softplus_beta=1.0,
                softplus_threshold=20.0,
            )

        # 3. Merge core attention output
        if spec_sequence_masks is not None and core_attn_out_non_spec is not None:
            merged_out = torch.empty(
                (1, num_actual_tokens, *core_attn_out_spec.shape[2:]),
                dtype=core_attn_out_non_spec.dtype,
                device=core_attn_out_non_spec.device,
            )
            merged_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
            merged_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
            if not enable_sp():
                core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)
            else:
                core_attn_out[:num_actual_tokens] = merged_out.squeeze(0)[:num_actual_tokens]
        elif spec_sequence_masks is not None:
            if not enable_sp():
                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)
            else:
                core_attn_out[:num_actual_tokens] = core_attn_out_spec.squeeze(0)[:num_actual_tokens]
        elif core_attn_out_non_spec is not None:
            if not enable_sp():
                core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
            else:
                core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)[:num_actual_tokens]

maybe_save_kv_layer_to_connector("", [])


Qwen3_5GatedDeltaNet._forward_core = AscendQwen3_5GatedDeltaNet._forward_core
2 changes: 2 additions & 0 deletions vllm_ascend/spec_decode/eagle_proposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ def load_model(self, model: nn.Module) -> None:
"Qwen2_5_VLForConditionalGeneration",
"Qwen3VLForConditionalGeneration",
"Qwen3VLMoeForConditionalGeneration",
"Qwen3_5ForConditionalGeneration",
"Qwen3_5MoeForConditionalGeneration",
]:
self.model.config.image_token_index = model.config.image_token_id
elif self.get_model_name(model) == "PixtralForConditionalGeneration":
Expand Down
Loading