diff --git a/tests/compile/passes/test_fusion.py b/tests/compile/passes/test_fusion.py index a13ee94f61b8..92d1902b2c2f 100644 --- a/tests/compile/passes/test_fusion.py +++ b/tests/compile/passes/test_fusion.py @@ -521,7 +521,7 @@ def __init__(self, num_v_heads: int, head_v_dim: int, tp_size: int = 1): self.head_v_dim = head_v_dim self.tp_size = tp_size - from vllm.model_executor.layers.mamba.gdn_linear_attn import ( + from vllm.model_executor.layers.mamba.gdn.base import ( GatedDeltaNetAttention, ) diff --git a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py index 0ee60d01a815..e7ba3385725b 100644 --- a/vllm/compilation/passes/fusion/rocm_aiter_fusion.py +++ b/vllm/compilation/passes/fusion/rocm_aiter_fusion.py @@ -560,11 +560,14 @@ def __init__(self, config: VllmConfig) -> None: # Discover (num_heads, head_dim) pairs for gated RMSNorm patterns # from GatedDeltaNetAttention layers in static_forward_context. - from vllm.model_executor.layers.mamba.gdn_linear_attn import ( + from vllm.model_executor.layers.mamba.gdn.base import ( GatedDeltaNetAttention, ) - gdn_layers = get_layers_from_vllm_config(config, GatedDeltaNetAttention) + gdn_layers = get_layers_from_vllm_config( + config, + GatedDeltaNetAttention, # type: ignore[type-abstract] + ) gated_norm_shapes: set[tuple[int, int]] = set() for layer in gdn_layers.values(): gated_norm_shapes.add(