From f09e865b9fbd457f7f50475a220019fb2409df42 Mon Sep 17 00:00:00 2001
From: haosdent
Date: Sat, 9 May 2026 17:38:43 +0800
Subject: [PATCH] [Bugfix] Drop @torch.compile from GDN qkv reshape helpers

The @torch.compile(fullgraph=True) decorators added in #40711 on
prepare_gdn_attention_core_inputs and rearrange_mixed_qkv crash
CUDA-graph capture for Qwen3.5 MTP / MoeMTP: Inductor's first-call
Triton autotune runs torch.cuda.synchronize(), which is illegal during
stream capture. The non-spec path is autotuned during eager warmup; the
spec path's mixed_qkv_spec is None during warmup and only becomes a
tensor during capture, so autotune fires inside torch.cuda.graph(...)
and the engine core dies with cudaErrorStreamCaptureInvalidated.
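A minimal standalone sketch of the failure mode (hypothetical repro,
not vLLM code, and deliberately simplified: here the entire first
compile lands inside capture, whereas in vLLM only the spec path's
first autotune did; exact error text varies by PyTorch version):

    import torch

    @torch.compile(fullgraph=True)
    def split_contig(x):
        # Non-contiguous split followed by contiguous copies, shaped
        # like the helpers this patch touches.
        a, b = torch.split(x, x.shape[-1] // 2, dim=-1)
        return a.contiguous(), b.contiguous()

    x = torch.randn(8, 64, device="cuda")
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # First call -> Inductor compiles and autotunes during capture;
        # the torch.cuda.synchronize() it issues invalidates the capture.
        split_contig(x)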
Removing the decorators fixes the crash. The cat-then-slice bodies were
pessimizations without compile fusion, so simplify them to plain
split/contiguous/view (and drop the fused round-trip in
prepare_gdn_attention_core_inputs).

Signed-off-by: haosdent
---
 .../layers/mamba/gdn_linear_attn.py          | 77 +++----------------
 1 file changed, 10 insertions(+), 67 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
index 518e9d4f0cff..84d20e1e85bb 100644
--- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -524,7 +524,6 @@ def fix_query_key_value_ordering(
 
         return query, key, value, z, b, a
 
-    @torch.compile(fullgraph=True)
     def prepare_gdn_attention_core_inputs(
         self,
         mixed_qkvz: torch.Tensor,
@@ -586,7 +585,7 @@ def prepare_gdn_attention_core_inputs(
         (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=-1)
         (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=-1)
 
-        mixed_qkv_logical = torch.cat(
+        mixed_qkv_out = torch.cat(
             [
                 query.reshape(num_tokens, -1),
                 key.reshape(num_tokens, -1),
@@ -594,59 +593,15 @@ def prepare_gdn_attention_core_inputs(
             ],
             dim=-1,
         )
-
-        # The split above produces non-contiguous views into the interleaved
-        # buffer. Concatenating everything into a single flat tensor forces a
-        # contiguous copy, then slicing back out gives contiguous q/k/v/z/b/a
-        # tensors that downstream kernels require. Doing this in one cat+slice
-        # keeps torch.compile in a single Triton graph instead of emitting
-        # separate copy kernels per tensor. The original code used
-        # rearrange(...).contiguous() on each tensor individually.
-        fused = torch.cat(
-            [
-                mixed_qkv_logical.reshape(-1),
-                z.reshape(-1),
-                b.reshape(-1),
-                a.reshape(-1),
-            ],
-            dim=0,
-        )
-
-        curr = 0
-        qkv_numel = mixed_qkv_logical.numel()
-        z_numel = z.numel()
-        b_numel = b.numel()
-        a_numel = a.numel()
-
-        mixed_qkv_out = fused[curr : curr + qkv_numel].view(num_tokens, -1)
-        curr += qkv_numel
-
-        z_out = fused[curr : curr + z_numel].view(
-            num_tokens, self.num_v_heads // self.tp_size, self.head_v_dim
-        )
-        curr += z_numel
-
-        b_out = fused[curr : curr + b_numel].view(
-            num_tokens, self.num_v_heads // self.tp_size
-        )
-        curr += b_numel
-
-        a_out = fused[curr : curr + a_numel].view(
-            num_tokens, self.num_v_heads // self.tp_size
-        )
+        num_v_per_tp = self.num_v_heads // self.tp_size
+        z_out = z.reshape(num_tokens, num_v_per_tp, self.head_v_dim)
+        b_out = b.reshape(num_tokens, num_v_per_tp)
+        a_out = a.reshape(num_tokens, num_v_per_tp)
 
         return mixed_qkv_out, z_out, b_out, a_out
 
-    @torch.compile(fullgraph=True)
     def rearrange_mixed_qkv(self, mixed_qkv):
-        """Split packed qkv into contiguous (1, seq, heads, dim) tensors.
-
-        The original code used ``rearrange(x, "l (h d) -> 1 l h d", d=...)``
-        followed by ``.contiguous()`` on each tensor. This version flattens
-        all three splits into a single buffer via ``torch.cat`` so that
-        torch.compile emits one Triton copy kernel instead of three separate
-        contiguous() calls.
-        """
+        """Split packed qkv into contiguous (1, seq, heads, dim) tensors."""
         if mixed_qkv is None:
             return None, None, None
 
@@ -656,24 +611,12 @@ def rearrange_mixed_qkv(self, mixed_qkv):
         v_dim = self.value_dim // self.tp_size
 
         query, key, value = torch.split(mixed_qkv, [q_dim, k_dim, v_dim], dim=-1)
-
-        fused = torch.cat(
-            [query.reshape(-1), key.reshape(-1), value.reshape(-1)], dim=0
+        return (
+            query.contiguous().view(1, seq_len, -1, self.head_k_dim),
+            key.contiguous().view(1, seq_len, -1, self.head_k_dim),
+            value.contiguous().view(1, seq_len, -1, self.head_v_dim),
         )
-
-        q_size = seq_len * q_dim
-        k_size = seq_len * k_dim
-
-        q_contig = fused[0:q_size]
-        k_contig = fused[q_size : q_size + k_size]
-        v_contig = fused[q_size + k_size :]
-
-        query = q_contig.view(1, seq_len, -1, self.head_k_dim)
-        key = k_contig.view(1, seq_len, -1, self.head_k_dim)
-        value = v_contig.view(1, seq_len, -1, self.head_v_dim)
-
-        return query, key, value
-
     def forward(
         self,
         hidden_states: torch.Tensor,