diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index c5c02d4bcc98..c5d311acf26d 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -645,6 +645,101 @@ def forward( core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") output[:num_tokens], _ = self.out_proj(core_attn_out) + def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: + """Warm up GDN prefill kernels during V1 profiling. + + During V1 profile runs, ``_forward_core`` returns early because + ``attn_metadata`` is ``None``, so the autotuned kernels used by + ``chunk_gated_delta_rule`` (e.g. ``solve_tril``, + ``chunk_scaled_dot_kkt``) are never invoked. After profiling, + vLLM allocates KV cache using most of the remaining GPU memory. + When the first real inference triggers the autotuner it OOMs + because there is not enough memory left for benchmarking. + + This method runs minimal forward passes through + ``chunk_gated_delta_rule`` with small dummy tensors to force + autotuning while GPU memory is still plentiful. The autotuner + results are cached globally, so only the first layer incurs + actual benchmarking cost. + + Most kernels use a fixed ``BT = chunk_size`` (64), but + ``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence + length: ``min(64, max(16, next_power_of_2(T)))``. Since ``BT`` + is part of its autotune key, we run warmup passes with T = 16, + 32, and 64 to cover all possible ``BT`` values. + + The decode path uses ``fused_sigmoid_gating_delta_rule_update`` + which has fixed kernel parameters (no autotuning), so only the + prefill (chunked) path needs warming up. + """ + if hasattr(self, "_prefill_kernels_warmed_up"): + return + self._prefill_kernels_warmed_up = True + + device = mixed_qkv.device + dtype = mixed_qkv.dtype + num_k_heads = self.num_k_heads // self.tp_size + num_v_heads = self.num_v_heads // self.tp_size + _, state_dtype = self.get_state_dtype() + + # Run warmup for each possible BT value of chunk_fwd_kernel_o: + # T=16 → BT=16, T=32 → BT=32, T=64 → BT=64. + # Other kernels always use BT=chunk_size(64), so their autotune + # cache is populated on the first pass and reused thereafter. + for T in (16, 32, 64): + q = torch.randn( + 1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype + ) + k = torch.randn( + 1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype + ) + v = torch.randn( + 1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype + ) + g = torch.randn(1, T, num_v_heads, device=device, dtype=dtype) + beta = torch.randn(1, T, num_v_heads, device=device, dtype=dtype) + state = torch.zeros( + 1, + num_v_heads, + self.head_v_dim, + self.head_k_dim, + device=device, + dtype=state_dtype, + ) + cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long) + + try: + self.chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=state, + output_final_state=False, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + except Exception: + logger.warning( + "GDN prefill kernel warmup (T=%d) failed for " + "layer %s. First inference may OOM due to " + "autotuner.", + T, + self.prefix, + exc_info=True, + ) + else: + logger.debug( + "GDN prefill kernel warmup (T=%d) completed for layer %s", + T, + self.prefix, + ) + finally: + del q, k, v, g, beta, state, cu_seqlens + + torch.accelerator.empty_cache() + def _forward_core( self, mixed_qkv: torch.Tensor, @@ -659,7 +754,9 @@ def _forward_core( attn_metadata: AttentionMetadata = forward_context.attn_metadata if attn_metadata is None: - # V1 profile run + # V1 profile run — warm up prefill kernels so that + # autotuning completes before KV cache allocation. + self._warmup_prefill_kernels(mixed_qkv) return assert isinstance(attn_metadata, dict)