diff --git a/vllm_gaudi/models/qwen3_5.py b/vllm_gaudi/models/qwen3_5.py index d8b3dabf7f..c37b39fa92 100644 --- a/vllm_gaudi/models/qwen3_5.py +++ b/vllm_gaudi/models/qwen3_5.py @@ -1,5 +1,5 @@ import torch -from vllm.model_executor.layers.mamba.gdn_linear_attn import GatedDeltaNetAttention +from vllm.model_executor.layers.mamba.gdn.qwen_gdn_linear_attn import QwenGatedDeltaNetAttention from vllm.forward_context import get_forward_context from vllm_gaudi.ops.causal_conv1d_pytorch import ( @@ -26,7 +26,7 @@ def _save_ssm_state(core_attn_out, final_state, ssm_state, state_indices): return core_attn_out -class HPUGatedDeltaNetAttention(GatedDeltaNetAttention): +class HPUGatedDeltaNetAttention(QwenGatedDeltaNetAttention): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -295,10 +295,10 @@ def forward( # Replace the class in the upstream modules so that both Qwen3-Next and # Qwen3.5 model definitions instantiate HPUGatedDeltaNetAttention. -import vllm.model_executor.layers.mamba.gdn_linear_attn as _gdn_module # noqa: E402 +import vllm.model_executor.layers.mamba.gdn.qwen_gdn_linear_attn as _gdn_module # noqa: E402 import vllm.model_executor.models.qwen3_next as _qwen3_next_module # noqa: E402 import vllm.model_executor.models.qwen3_5 as _qwen3_5_module # noqa: E402 -_gdn_module.GatedDeltaNetAttention = HPUGatedDeltaNetAttention -_qwen3_next_module.GatedDeltaNetAttention = HPUGatedDeltaNetAttention -_qwen3_5_module.GatedDeltaNetAttention = HPUGatedDeltaNetAttention +_gdn_module.QwenGatedDeltaNetAttention = HPUGatedDeltaNetAttention +_qwen3_next_module.QwenGatedDeltaNetAttention = HPUGatedDeltaNetAttention +_qwen3_5_module.QwenGatedDeltaNetAttention = HPUGatedDeltaNetAttention