Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 98 additions & 1 deletion vllm/model_executor/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,101 @@ def forward(
core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
output[:num_tokens], _ = self.out_proj(core_attn_out)

def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None:
    """Force autotuning of the GDN prefill (chunked) kernels.

    V1 profile runs bail out of ``_forward_core`` early (no
    ``attn_metadata``), so the autotuned Triton kernels behind
    ``chunk_gated_delta_rule`` (``solve_tril``,
    ``chunk_scaled_dot_kkt``, ...) never run during profiling. Once
    profiling finishes, vLLM hands most of the remaining GPU memory
    to the KV cache, and the first real request then OOMs inside the
    autotuner's benchmark sweep.

    Running a few tiny dummy prefill passes here — while memory is
    still plentiful — populates the autotuner caches up front. The
    caches are global, so only the first layer pays the benchmark
    cost.

    ``chunk_fwd_kernel_o`` derives its block size from the sequence
    length as ``min(64, max(16, next_power_of_2(T)))`` and keys its
    autotune cache on that value, so we sweep T over 16/32/64 to hit
    every possible block size; the remaining kernels use the fixed
    ``BT = chunk_size`` (64) and are covered by any single pass.

    The decode path (``fused_sigmoid_gating_delta_rule_update``) has
    fixed kernel parameters and needs no warmup.

    Args:
        mixed_qkv: Any tensor from the current forward pass; only its
            ``device`` and ``dtype`` are read for the dummy inputs.
    """
    # One-shot guard: warm up at most once per layer instance.
    if getattr(self, "_prefill_kernels_warmed_up", False):
        return
    self._prefill_kernels_warmed_up = True

    dev = mixed_qkv.device
    dt = mixed_qkv.dtype
    # Per-rank head counts under tensor parallelism.
    k_heads = self.num_k_heads // self.tp_size
    v_heads = self.num_v_heads // self.tp_size
    state_dtype = self.get_state_dtype()[1]

    # T=16 → BT=16, T=32 → BT=32, T=64 → BT=64 for chunk_fwd_kernel_o;
    # every other kernel's cache is filled on the first iteration.
    for seq_len in (16, 32, 64):
        dummy_q = torch.randn(
            1, seq_len, k_heads, self.head_k_dim, device=dev, dtype=dt
        )
        dummy_k = torch.randn(
            1, seq_len, k_heads, self.head_k_dim, device=dev, dtype=dt
        )
        dummy_v = torch.randn(
            1, seq_len, v_heads, self.head_v_dim, device=dev, dtype=dt
        )
        dummy_g = torch.randn(1, seq_len, v_heads, device=dev, dtype=dt)
        dummy_beta = torch.randn(1, seq_len, v_heads, device=dev, dtype=dt)
        dummy_state = torch.zeros(
            1,
            v_heads,
            self.head_v_dim,
            self.head_k_dim,
            device=dev,
            dtype=state_dtype,
        )
        seq_bounds = torch.tensor([0, seq_len], device=dev, dtype=torch.long)

        try:
            self.chunk_gated_delta_rule(
                q=dummy_q,
                k=dummy_k,
                v=dummy_v,
                g=dummy_g,
                beta=dummy_beta,
                initial_state=dummy_state,
                output_final_state=False,
                cu_seqlens=seq_bounds,
                use_qk_l2norm_in_kernel=True,
            )
        except Exception:
            # Best-effort: a failed warmup only risks an OOM later,
            # so log loudly but keep going.
            logger.warning(
                "GDN prefill kernel warmup (T=%d) failed for "
                "layer %s. First inference may OOM due to "
                "autotuner.",
                seq_len,
                self.prefix,
                exc_info=True,
            )
        else:
            logger.debug(
                "GDN prefill kernel warmup (T=%d) completed for layer %s",
                seq_len,
                self.prefix,
            )
        finally:
            # Drop the dummy tensors promptly so the profile run's
            # memory accounting stays accurate.
            del (
                dummy_q,
                dummy_k,
                dummy_v,
                dummy_g,
                dummy_beta,
                dummy_state,
                seq_bounds,
            )

    torch.accelerator.empty_cache()

def _forward_core(
self,
mixed_qkv: torch.Tensor,
Expand All @@ -659,7 +754,9 @@ def _forward_core(
attn_metadata: AttentionMetadata = forward_context.attn_metadata

if attn_metadata is None:
# V1 profile run
# V1 profile run — warm up prefill kernels so that
# autotuning completes before KV cache allocation.
self._warmup_prefill_kernels(mixed_qkv)
return

assert isinstance(attn_metadata, dict)
Expand Down
Loading