From 70f7d7c1a1fb275cf400de23800b1f720efc3a6d Mon Sep 17 00:00:00 2001 From: AuYang <459461160@qq.com> Date: Tue, 10 Mar 2026 13:57:07 +0800 Subject: [PATCH 1/5] [Bugfix] Warm up Triton autotuner for GDN layers during V1 profiling Signed-off-by: AuYang <459461160@qq.com> --- vllm/model_executor/models/qwen3_next.py | 79 +++++++++++++++++++++++- 1 file changed, 78 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 4c4ff0ccf365..3a9ea364aaff 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -624,6 +624,81 @@ def forward( core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") output[:num_tokens], _ = self.out_proj(core_attn_out) + @torch.no_grad() + def _warmup_triton_kernels(self, mixed_qkv: torch.Tensor) -> None: + """Trigger Triton autotuner warmup for GDN prefill kernels. + + During V1 profile runs, ``_forward_core`` returns early because + ``attn_metadata`` is ``None``, so the Triton-autotuned kernels + (``solve_tril``, ``chunk_scaled_dot_kkt``, etc.) are never + invoked during profiling. After profiling, vLLM allocates KV + cache using most of the remaining GPU memory. When the first + real inference triggers the Triton autotuner it OOMs because + there is not enough memory left for benchmarking. + + This method runs a minimal forward pass through + ``chunk_gated_delta_rule`` with small dummy tensors (B=1, T=64) + to force autotuning while GPU memory is still plentiful. The + autotuner results are cached globally by Triton, so only the + first layer incurs actual benchmarking cost. + """ + if hasattr(self, "_triton_kernels_warmed_up"): + return + self._triton_kernels_warmed_up = True + + device = mixed_qkv.device + dtype = mixed_qkv.dtype + # Use T=64 to match the chunk_size used by chunk_gated_delta_rule. + T = 64 + num_k_heads = self.num_k_heads // self.tp_size + num_v_heads = self.num_v_heads // self.tp_size + + # Tensor shapes mirror what _forward_core feeds into + # chunk_gated_delta_rule during prefill (IS_VARLEN=True path): + # q/k : [1, T, num_k_heads/tp, head_k_dim] + # v : [1, T, num_v_heads/tp, head_v_dim] + # g/β : [1, T, num_v_heads/tp] + # state: [N, num_v_heads/tp, head_v_dim, head_k_dim] + q = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) + k = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) + v = torch.randn(1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype) + g = torch.randn(1, T, num_v_heads, device=device, dtype=dtype) + beta = torch.randn(1, T, num_v_heads, device=device, dtype=dtype) + state = torch.zeros( + 1, + num_v_heads, + self.head_v_dim, + self.head_k_dim, + device=device, + dtype=torch.float32, + ) + cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long) + + try: + self.chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=state, + output_final_state=False, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + except Exception: + logger.warning( + "GDN Triton kernel warmup failed for layer %s. " + "First inference may OOM due to Triton autotuner.", + self.prefix, + exc_info=True, + ) + finally: + del q, k, v, g, beta, state, cu_seqlens + torch.accelerator.empty_cache() + + logger.info("GDN Triton kernel warmup completed for layer %s", self.prefix) + def _forward_core( self, mixed_qkv: torch.Tensor, @@ -638,7 +713,9 @@ def _forward_core( attn_metadata: AttentionMetadata = forward_context.attn_metadata if attn_metadata is None: - # V1 profile run + # V1 profile run — trigger Triton autotuner warmup so that + # autotuning completes before KV cache allocation. + self._warmup_triton_kernels(mixed_qkv) return assert isinstance(attn_metadata, dict) From 60a8085cb89d1c97de633197f1238bebe7f413ad Mon Sep 17 00:00:00 2001 From: AuYang <459461160@qq.com> Date: Tue, 10 Mar 2026 14:48:24 +0800 Subject: [PATCH 2/5] Move warmup success log inside try block Signed-off-by: AuYang <459461160@qq.com> --- vllm/model_executor/models/qwen3_next.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 3a9ea364aaff..72ef0f45e050 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -686,6 +686,7 @@ def _warmup_triton_kernels(self, mixed_qkv: torch.Tensor) -> None: cu_seqlens=cu_seqlens, use_qk_l2norm_in_kernel=True, ) + logger.info("GDN Triton kernel warmup completed for layer %s", self.prefix) except Exception: logger.warning( "GDN Triton kernel warmup failed for layer %s. " @@ -697,8 +698,6 @@ def _warmup_triton_kernels(self, mixed_qkv: torch.Tensor) -> None: del q, k, v, g, beta, state, cu_seqlens torch.accelerator.empty_cache() - logger.info("GDN Triton kernel warmup completed for layer %s", self.prefix) - def _forward_core( self, mixed_qkv: torch.Tensor, From 16cf800e49ba06658c7c94ab34012f12ed3c7abd Mon Sep 17 00:00:00 2001 From: AuYang <459461160@qq.com> Date: Tue, 10 Mar 2026 15:36:14 +0800 Subject: [PATCH 3/5] Address review: rename method, remove @no_grad, clarify decode path Signed-off-by: AuYang <459461160@qq.com> --- vllm/model_executor/models/qwen3_next.py | 42 ++++++++++++++---------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 72ef0f45e050..de078e7296db 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -624,27 +624,30 @@ def forward( core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") output[:num_tokens], _ = self.out_proj(core_attn_out) - @torch.no_grad() - def _warmup_triton_kernels(self, mixed_qkv: torch.Tensor) -> None: - """Trigger Triton autotuner warmup for GDN prefill kernels. + def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: + """Warm up GDN prefill kernels during V1 profiling. During V1 profile runs, ``_forward_core`` returns early because - ``attn_metadata`` is ``None``, so the Triton-autotuned kernels - (``solve_tril``, ``chunk_scaled_dot_kkt``, etc.) are never - invoked during profiling. After profiling, vLLM allocates KV - cache using most of the remaining GPU memory. When the first - real inference triggers the Triton autotuner it OOMs because - there is not enough memory left for benchmarking. + ``attn_metadata`` is ``None``, so the autotuned kernels used by + ``chunk_gated_delta_rule`` (e.g. ``solve_tril``, + ``chunk_scaled_dot_kkt``) are never invoked. After profiling, + vLLM allocates KV cache using most of the remaining GPU memory. + When the first real inference triggers the autotuner it OOMs + because there is not enough memory left for benchmarking. This method runs a minimal forward pass through ``chunk_gated_delta_rule`` with small dummy tensors (B=1, T=64) to force autotuning while GPU memory is still plentiful. The - autotuner results are cached globally by Triton, so only the - first layer incurs actual benchmarking cost. + autotuner results are cached globally, so only the first layer + incurs actual benchmarking cost. + + The decode path uses ``fused_sigmoid_gating_delta_rule_update`` + which has fixed kernel parameters (no autotuning), so only the + prefill (chunked) path needs warming up. """ - if hasattr(self, "_triton_kernels_warmed_up"): + if hasattr(self, "_prefill_kernels_warmed_up"): return - self._triton_kernels_warmed_up = True + self._prefill_kernels_warmed_up = True device = mixed_qkv.device dtype = mixed_qkv.dtype @@ -686,11 +689,14 @@ def _warmup_triton_kernels(self, mixed_qkv: torch.Tensor) -> None: cu_seqlens=cu_seqlens, use_qk_l2norm_in_kernel=True, ) - logger.info("GDN Triton kernel warmup completed for layer %s", self.prefix) + logger.info( + "GDN prefill kernel warmup completed for layer %s", + self.prefix, + ) except Exception: logger.warning( - "GDN Triton kernel warmup failed for layer %s. " - "First inference may OOM due to Triton autotuner.", + "GDN prefill kernel warmup failed for layer %s. " + "First inference may OOM due to autotuner.", self.prefix, exc_info=True, ) @@ -712,9 +718,9 @@ def _forward_core( attn_metadata: AttentionMetadata = forward_context.attn_metadata if attn_metadata is None: - # V1 profile run — trigger Triton autotuner warmup so that + # V1 profile run — warm up prefill kernels so that # autotuning completes before KV cache allocation. - self._warmup_triton_kernels(mixed_qkv) + self._warmup_prefill_kernels(mixed_qkv) return assert isinstance(attn_metadata, dict) From 907ae157d22721ec900e0d586ac3c99e30904d16 Mon Sep 17 00:00:00 2001 From: AuYang <459461160@qq.com> Date: Tue, 10 Mar 2026 23:24:52 +0800 Subject: [PATCH 4/5] Warmup all BT values for chunk_fwd_kernel_o Signed-off-by: AuYang <459461160@qq.com> --- vllm/model_executor/models/qwen3_next.py | 116 +++++++++++++---------- 1 file changed, 65 insertions(+), 51 deletions(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index de078e7296db..9656512a6b68 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -635,11 +635,17 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: When the first real inference triggers the autotuner it OOMs because there is not enough memory left for benchmarking. - This method runs a minimal forward pass through - ``chunk_gated_delta_rule`` with small dummy tensors (B=1, T=64) - to force autotuning while GPU memory is still plentiful. The - autotuner results are cached globally, so only the first layer - incurs actual benchmarking cost. + This method runs minimal forward passes through + ``chunk_gated_delta_rule`` with small dummy tensors to force + autotuning while GPU memory is still plentiful. The autotuner + results are cached globally, so only the first layer incurs + actual benchmarking cost. + + Most kernels use a fixed ``BT = chunk_size`` (64), but + ``chunk_fwd_kernel_o`` recomputes ``BT`` from the sequence + length: ``min(64, max(16, next_power_of_2(T)))``. Since ``BT`` + is part of its autotune key, we run warmup passes with T = 16, + 32, and 64 to cover all possible ``BT`` values. The decode path uses ``fused_sigmoid_gating_delta_rule_update`` which has fixed kernel parameters (no autotuning), so only the @@ -651,58 +657,66 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: device = mixed_qkv.device dtype = mixed_qkv.dtype - # Use T=64 to match the chunk_size used by chunk_gated_delta_rule. - T = 64 num_k_heads = self.num_k_heads // self.tp_size num_v_heads = self.num_v_heads // self.tp_size - # Tensor shapes mirror what _forward_core feeds into - # chunk_gated_delta_rule during prefill (IS_VARLEN=True path): - # q/k : [1, T, num_k_heads/tp, head_k_dim] - # v : [1, T, num_v_heads/tp, head_v_dim] - # g/β : [1, T, num_v_heads/tp] - # state: [N, num_v_heads/tp, head_v_dim, head_k_dim] - q = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) - k = torch.randn(1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype) - v = torch.randn(1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype) - g = torch.randn(1, T, num_v_heads, device=device, dtype=dtype) - beta = torch.randn(1, T, num_v_heads, device=device, dtype=dtype) - state = torch.zeros( - 1, - num_v_heads, - self.head_v_dim, - self.head_k_dim, - device=device, - dtype=torch.float32, - ) - cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long) - - try: - self.chunk_gated_delta_rule( - q=q, - k=k, - v=v, - g=g, - beta=beta, - initial_state=state, - output_final_state=False, - cu_seqlens=cu_seqlens, - use_qk_l2norm_in_kernel=True, + # Run warmup for each possible BT value of chunk_fwd_kernel_o: + # T=16 → BT=16, T=32 → BT=32, T=64 → BT=64. + # Other kernels always use BT=chunk_size(64), so their autotune + # cache is populated on the first pass and reused thereafter. + for T in (16, 32, 64): + q = torch.randn( + 1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype + ) + k = torch.randn( + 1, T, num_k_heads, self.head_k_dim, device=device, dtype=dtype ) - logger.info( - "GDN prefill kernel warmup completed for layer %s", - self.prefix, + v = torch.randn( + 1, T, num_v_heads, self.head_v_dim, device=device, dtype=dtype ) - except Exception: - logger.warning( - "GDN prefill kernel warmup failed for layer %s. " - "First inference may OOM due to autotuner.", - self.prefix, - exc_info=True, + g = torch.randn(1, T, num_v_heads, device=device, dtype=dtype) + beta = torch.randn(1, T, num_v_heads, device=device, dtype=dtype) + state = torch.zeros( + 1, + num_v_heads, + self.head_v_dim, + self.head_k_dim, + device=device, + dtype=torch.float32, ) - finally: - del q, k, v, g, beta, state, cu_seqlens - torch.accelerator.empty_cache() + cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long) + + try: + self.chunk_gated_delta_rule( + q=q, + k=k, + v=v, + g=g, + beta=beta, + initial_state=state, + output_final_state=False, + cu_seqlens=cu_seqlens, + use_qk_l2norm_in_kernel=True, + ) + except Exception: + logger.warning( + "GDN prefill kernel warmup (T=%d) failed for " + "layer %s. First inference may OOM due to " + "autotuner.", + T, + self.prefix, + exc_info=True, + ) + else: + logger.debug( + "GDN prefill kernel warmup (T=%d) completed for layer %s", + T, + self.prefix, + ) + finally: + del q, k, v, g, beta, state, cu_seqlens + + torch.accelerator.empty_cache() def _forward_core( self, From 0220d643e8bd6f34b081bdec250efea444ad9d5e Mon Sep 17 00:00:00 2001 From: AuYang <459461160@qq.com> Date: Wed, 11 Mar 2026 09:25:40 +0800 Subject: [PATCH 5/5] Use get_state_dtype() for warmup state tensor Signed-off-by: AuYang <459461160@qq.com> --- vllm/model_executor/models/qwen3_next.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 9656512a6b68..27b46affe705 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -659,6 +659,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: dtype = mixed_qkv.dtype num_k_heads = self.num_k_heads // self.tp_size num_v_heads = self.num_v_heads // self.tp_size + _, state_dtype = self.get_state_dtype() # Run warmup for each possible BT value of chunk_fwd_kernel_o: # T=16 → BT=16, T=32 → BT=32, T=64 → BT=64. @@ -682,7 +683,7 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: self.head_v_dim, self.head_k_dim, device=device, - dtype=torch.float32, + dtype=state_dtype, ) cu_seqlens = torch.tensor([0, T], device=device, dtype=torch.long)