Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions vllm_gaudi/models/qwen3_5.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from vllm.model_executor.layers.mamba.gdn_linear_attn import GatedDeltaNetAttention
from vllm.model_executor.layers.mamba.gdn.qwen_gdn_linear_attn import QwenGatedDeltaNetAttention
from vllm.forward_context import get_forward_context

from vllm_gaudi.ops.causal_conv1d_pytorch import (
Expand All @@ -26,7 +26,7 @@ def _save_ssm_state(core_attn_out, final_state, ssm_state, state_indices):
return core_attn_out


class HPUGatedDeltaNetAttention(GatedDeltaNetAttention):
class HPUGatedDeltaNetAttention(QwenGatedDeltaNetAttention):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -295,10 +295,10 @@ def forward(

# Replace the class in the upstream modules so that both Qwen3-Next and
# Qwen3.5 model definitions instantiate HPUGatedDeltaNetAttention.
import vllm.model_executor.layers.mamba.gdn_linear_attn as _gdn_module # noqa: E402
import vllm.model_executor.layers.mamba.gdn.qwen_gdn_linear_attn as _gdn_module # noqa: E402
import vllm.model_executor.models.qwen3_next as _qwen3_next_module # noqa: E402
import vllm.model_executor.models.qwen3_5 as _qwen3_5_module # noqa: E402

_gdn_module.GatedDeltaNetAttention = HPUGatedDeltaNetAttention
_qwen3_next_module.GatedDeltaNetAttention = HPUGatedDeltaNetAttention
_qwen3_5_module.GatedDeltaNetAttention = HPUGatedDeltaNetAttention
_gdn_module.QwenGatedDeltaNetAttention = HPUGatedDeltaNetAttention
_qwen3_next_module.QwenGatedDeltaNetAttention = HPUGatedDeltaNetAttention
_qwen3_5_module.QwenGatedDeltaNetAttention = HPUGatedDeltaNetAttention
Loading