diff --git a/vllm/model_executor/layers/mamba/gdn_linear_attn.py b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
index a621ab962f0a..4f3af330c5eb 100644
--- a/vllm/model_executor/layers/mamba/gdn_linear_attn.py
+++ b/vllm/model_executor/layers/mamba/gdn_linear_attn.py
@@ -56,10 +56,12 @@ from vllm.platforms import current_platform
 from vllm.transformers_utils.configs.qwen3_next import Qwen3NextConfig
 from vllm.triton_utils import tl, triton
+from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
 from vllm.utils.torch_utils import (
     LayerNameType,
     _encode_layer_name,
     _resolve_layer_name,
+    aux_stream,
     direct_register_custom_op,
 )
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
@@ -275,6 +277,19 @@ def __init__(
             self.forward_xpu if current_platform.is_xpu() else self.forward_cuda
         )
 
+        # Dual-stream dispatch for the two input projections (in_proj_qkvz and
+        # in_proj_ba). Re-enabled via the LayerName opaque type (PR #38123) so
+        # torch.compile no longer regresses cold compile times.
+        self.aux_stream = aux_stream()
+        self.events = (
+            [
+                torch.cuda.Event(enable_timing=False),
+                torch.cuda.Event(enable_timing=False),
+            ]
+            if self.aux_stream is not None
+            else [None, None]
+        )
+
         # QKV
         self.conv_dim = self.key_dim * 2 + self.value_dim
         self.conv1d = ColumnParallelLinear(
@@ -541,8 +556,12 @@ def forward_cuda(
             b = b.contiguous()
             a = a.contiguous()
         else:
-            mixed_qkvz, _ = self.in_proj_qkvz(hidden_states)
-            ba, _ = self.in_proj_ba(hidden_states)
+            mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj(
+                hidden_states,
+                self.in_proj_qkvz.weight.shape[0],
+                self.in_proj_ba.weight.shape[0],
+                _encode_layer_name(self.prefix),
+            )
 
         if self.gqa_interleaved_layout:
             # Qwen3-Next: unpack the interleaved GQA layout
@@ -594,6 +613,23 @@ def forward_cuda(
         core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)")
         output[:num_tokens], _ = self.out_proj(core_attn_out)
 
+    def _forward_in_proj(
+        self, hidden_states: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Run in_proj_qkvz and in_proj_ba in parallel on two CUDA streams.
+
+        Falls back to sequential execution when aux_stream is unavailable
+        (e.g. non-CUDA platforms).
+        """
+        projected_states_qkvz, projected_states_ba = maybe_execute_in_parallel(
+            lambda: self.in_proj_qkvz(hidden_states)[0],
+            lambda: self.in_proj_ba(hidden_states)[0],
+            self.events[0],
+            self.events[1],
+            self.aux_stream,
+        )
+        return projected_states_qkvz, projected_states_ba
+
     def forward_xpu(
         self,
         hidden_states: torch.Tensor,
@@ -1063,6 +1099,39 @@ def _forward_core_decode_non_spec(
     return
 
 
+def gdn_in_proj(
+    hidden_states: torch.Tensor,
+    qkvz_output_size: int,
+    ba_output_size: int,
+    layer_name: LayerNameType,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Custom op wrapping the dual-stream input projection."""
+    layer_name = _resolve_layer_name(layer_name)
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    return self._forward_in_proj(hidden_states)
+
+
+def gdn_in_proj_fake(
+    hidden_states: torch.Tensor,
+    qkvz_output_size: int,
+    ba_output_size: int,
+    layer_name: LayerNameType,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Fake implementation for torch.compile."""
+    return (
+        hidden_states.new_empty(hidden_states.shape[0], qkvz_output_size),
+        hidden_states.new_empty(hidden_states.shape[0], ba_output_size),
+    )
+
+
+direct_register_custom_op(
+    op_name="gdn_in_proj",
+    op_func=gdn_in_proj,
+    fake_impl=gdn_in_proj_fake,
+)
+
+
 def gdn_attention_core(
     mixed_qkv: torch.Tensor,
     b: torch.Tensor,