Merged
vllm/model_executor/layers/mamba/gdn_linear_attn.py (38 changes: 26 additions, 12 deletions)
@@ -702,19 +702,33 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
         num_v_heads = self.num_v_heads // self.tp_size
         _, state_dtype = self.get_state_dtype()

-        # All kernels use BT = chunk_size (FLA_CHUNK_SIZE), so a single pass with
-        # T = chunk_size is sufficient to populate every autotuner cache.
+        # All kernels use BT = chunk_size, so a single pass with T = chunk_size
+        # is sufficient to populate every autotuner cache. Mirror the real
+        # prefill path here: build q/k/v/g/beta via fused_post_conv_prep and
+        # then run chunk_gated_delta_rule with in-kernel L2 norm disabled.
         T = FLA_CHUNK_SIZE
-        q = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype)
-        k = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype)
-        v = torch.randn(1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype)
-        # NOTE: g and beta must have the same dtypes as during
-        # inference, so we construct them with the same function
-        # (fused_gdn_gating). dummy_a and dummy_b are throwaway
-        # inputs required by that function.
+        dummy_mixed_qkv = torch.randn(
+            T, mixed_qkv.shape[-1], device=device, dtype=dtype
+        )
         dummy_a = torch.randn(T, num_v_heads, device=device, dtype=dtype)
         dummy_b = torch.randn(T, num_v_heads, device=device, dtype=dtype)
-        g, beta = fused_gdn_gating(self.A_log, dummy_a, dummy_b, self.dt_bias)
+        q, k, v, g, beta = fused_post_conv_prep(
+            conv_output=dummy_mixed_qkv,
+            a=dummy_a,
+            b=dummy_b,
+            A_log=self.A_log,
+            dt_bias=self.dt_bias,
+            num_k_heads=num_k_heads,
+            head_k_dim=self.head_k_dim,
+            head_v_dim=self.head_v_dim,
+            apply_l2norm=True,
+            output_g_exp=False,
+        )
+        q = q.unsqueeze(0)
+        k = k.unsqueeze(0)
+        v = v.unsqueeze(0)
+        g = g.unsqueeze(0)
+        beta = beta.unsqueeze(0)
         state = torch.zeros(
             1,
             num_v_heads,
@@ -735,7 +749,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
                 initial_state=state,
                 output_final_state=True,
                 cu_seqlens=cu_seqlens,
-                use_qk_l2norm_in_kernel=True,
+                use_qk_l2norm_in_kernel=False,
Member

why change this?

Contributor Author

Because warmup is for the prefill path, and the real prefill call here uses use_qk_l2norm_in_kernel=False. Leaving it as True means warmup does not match the actual inference path we are trying to prepare. (A sketch of the matched call pattern appears after the diff below.)

             )
         except Exception:
             logger.warning(
@@ -753,7 +767,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
                 self.prefix,
             )
         finally:
-            del q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens
+            del dummy_mixed_qkv, q, k, v, dummy_a, dummy_b, g, beta, state, cu_seqlens

         torch.accelerator.empty_cache()

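To make the review thread above concrete, here is a minimal sketch of the warmup/prefill call pattern this change aligns; it is not the file's actual code. It assumes fused_post_conv_prep and chunk_gated_delta_rule are importable as they are elsewhere in this file (import paths omitted), and that chunk_gated_delta_rule accepts q/k/v/g/beta positionally ahead of the keyword arguments visible in the diff; the helper name run_warmup_like_prefill and its plain-tensor parameters are illustrative only.

```python
# Sketch only: q/k are L2-normalized once, inside fused_post_conv_prep
# (apply_l2norm=True), so the chunked kernel is called with
# use_qk_l2norm_in_kernel=False -- the same flag combination the real
# prefill path uses, which is what warmup is meant to exercise.
import torch


def run_warmup_like_prefill(
    dummy_mixed_qkv: torch.Tensor,  # (T, hidden) random post-conv activations
    dummy_a: torch.Tensor,          # (T, num_v_heads) throwaway gating input
    dummy_b: torch.Tensor,          # (T, num_v_heads) throwaway gating input
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    state: torch.Tensor,            # zero-initialized recurrent state
    cu_seqlens: torch.Tensor,
    num_k_heads: int,
    head_k_dim: int,
    head_v_dim: int,
):
    # Same preparation the real prefill path uses: L2 norm is applied here.
    q, k, v, g, beta = fused_post_conv_prep(  # assumed import, as in this file
        conv_output=dummy_mixed_qkv,
        a=dummy_a,
        b=dummy_b,
        A_log=A_log,
        dt_bias=dt_bias,
        num_k_heads=num_k_heads,
        head_k_dim=head_k_dim,
        head_v_dim=head_v_dim,
        apply_l2norm=True,
        output_g_exp=False,
    )
    # Kernel call with in-kernel L2 norm disabled, matching real prefill.
    # The positional q/k/v/g/beta ordering is an assumption, not shown in the diff.
    return chunk_gated_delta_rule(  # assumed import, as in this file
        q.unsqueeze(0),
        k.unsqueeze(0),
        v.unsqueeze(0),
        g.unsqueeze(0),
        beta.unsqueeze(0),
        initial_state=state,
        output_final_state=True,
        cu_seqlens=cu_seqlens,
        use_qk_l2norm_in_kernel=False,
    )
```

The invariant being preserved is that q/k are normalized exactly once: with apply_l2norm=True in fused_post_conv_prep, the kernel must run with use_qk_l2norm_in_kernel=False, and warming up with any other combination would populate autotuner caches for configurations the real prefill path never executes.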