From 3756e8ec40f833f6b8c43474e31bbcbcfb06161c Mon Sep 17 00:00:00 2001 From: Ibrahim Arshad <38925737+ibrahim1023@users.noreply.github.com> Date: Tue, 7 Apr 2026 13:27:58 +0400 Subject: [PATCH] fix(gdn): Align prefill warmup with real prefill path Signed-off-by: Ibrahim Arshad <38925737+ibrahim1023@users.noreply.github.com> --- .../layers/mamba/gdn_linear_attn.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py index aec855d9aeb1..3ac90c942e73 100644 --- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py +++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py @@ -702,19 +702,33 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: num_v_heads = self.num_v_heads // self.tp_size _, state_dtype = self.get_state_dtype() - # All kernels use BT = chunk_size (FLA_CHUNK_SIZE4), so a single pass with - # T = chunk_size is sufficient to populate every autotuner cache. + # All kernels use BT = chunk_size, so a single pass with T = chunk_size + # is sufficient to populate every autotuner cache. Mirror the real + # prefill path here: build q/k/v/g/beta via fused_post_conv_prep and + # then run chunk_gated_delta_rule with in-kernel L2 norm disabled. T = FLA_CHUNK_SIZE - q = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) - k = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) - v = torch.randn(1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype) - # NOTE: g and beta must have the same dtypes as during - # inference, so we construct them with the same function - # (fused_gdn_gating). dummy_a and dummy_b are throwaway - # inputs required by that function. + dummy_mixed_qkv = torch.randn( + T, mixed_qkv.shape[-1], device=device, dtype=dtype + ) dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype) dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype) - g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias) + q, k, v, g, beta = fused_post_conv_prep( + conv_output=dummy_mixed_qkv, + a=dummy_a, + b=dummy_b, + A_log=self.A_log, + dt_bias=self.dt_bias, + num_k_heads=num_k_heads, + head_k_dim=self.head_k_dim, + head_v_dim=self.head_v_dim, + apply_l2norm=True, + output_g_exp=False, + ) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + g = g.unsqueeze(0) + beta = beta.unsqueeze(0) state = torch.zeros( 1, num_v_heads, @@ -735,7 +749,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: initial_state=state, output_final_state=True, cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=True, + use_qk_l2norm_in_kernel=False, ) except Exception: logger.warning( @@ -753,7 +767,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: self.prefix, ) finally: - del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens + del dummy_mixed_qkv, q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens torch.accelerator.empty_cache()