From 0507815b794fe59f0663249693f5f63bfc963d0f Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Tue, 10 Mar 2026 22:18:38 -0700 Subject: [PATCH 1/2] [Perf] Enable dual stream execution of input projection for Qwen3 Next Signed-off-by: Xin Yang --- vllm/model_executor/models/qwen3_next.py | 64 ++++++++++++++++++++++-- vllm/utils/multi_stream_utils.py | 51 +++++++++++++++++++ 2 files changed, 112 insertions(+), 3 deletions(-) create mode 100644 vllm/utils/multi_stream_utils.py diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index cfd4c7a56b43..788442b7fd5f 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -82,7 +82,11 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.triton_utils import tl, triton -from vllm.utils.torch_utils import direct_register_custom_op +from vllm.utils.torch_utils import ( + aux_stream, + direct_register_custom_op, +) +from vllm.utils.multi_stream_utils import maybe_execute_in_parallel from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata @@ -387,6 +391,12 @@ def __init__( self.act = ACT2FN[config.hidden_act] self.layer_norm_epsilon = config.rms_norm_eps self.prefix = prefix + self.aux_stream = aux_stream() + self.events = ( + [torch.cuda.Event(), torch.cuda.Event()] + if current_platform.is_cuda() + else [None, None] + ) self.config = config self.model_config = model_config @@ -615,8 +625,12 @@ def forward( # ============================================================ # Part 1: Input Projection # ============================================================ - projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) - projected_states_ba, _ = self.in_proj_ba(hidden_states) + projected_states_qkvz, projected_states_ba = torch.ops.vllm.gdn_in_proj( + hidden_states, + self.in_proj_qkvz.weight.shape[0], + self.in_proj_ba.weight.shape[0], + self.prefix, + ) query, key, value, z, b, a = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba ) @@ -751,6 +765,18 @@ def _warmup_prefill_kernels(self, mixed_qkv: torch.Tensor) -> None: torch.accelerator.empty_cache() + def _forward_in_proj( + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + projected_states_qkvz, projected_states_ba = maybe_execute_in_parallel( + lambda: self.in_proj_qkvz(hidden_states)[0], + lambda: self.in_proj_ba(hidden_states)[0], + self.events[0], + self.events[1], + self.aux_stream, + ) + return projected_states_qkvz, projected_states_ba + def _forward_core( self, mixed_qkv: torch.Tensor, @@ -1638,6 +1664,32 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() +def gdn_in_proj( + hidden_states: torch.Tensor, + qkvz_output_size: int, + ba_output_size: int, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Custom op for the input projection. + """ + forward_context: ForwardContext = get_forward_context() + self = forward_context.no_compile_layers[layer_name] + return self._forward_in_proj(hidden_states) + + +def gdn_in_proj_fake( + hidden_states: torch.Tensor, + qkvz_output_size: int, + ba_output_size: int, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fake implementation for torch.compile.""" + return hidden_states.new_empty( + hidden_states.shape[0], qkvz_output_size + ), hidden_states.new_empty(hidden_states.shape[0], ba_output_size) + + def gdn_attention_core( mixed_qkv: torch.Tensor, b: torch.Tensor, @@ -1671,6 +1723,12 @@ def gdn_attention_core_fake( return +direct_register_custom_op( + op_name="gdn_in_proj", + op_func=gdn_in_proj, + fake_impl=gdn_in_proj_fake, +) + direct_register_custom_op( op_name="gdn_attention_core", op_func=gdn_attention_core, diff --git a/vllm/utils/multi_stream_utils.py b/vllm/utils/multi_stream_utils.py new file mode 100644 index 000000000000..e30bf670d6d2 --- /dev/null +++ b/vllm/utils/multi_stream_utils.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Callable +from typing import Any + +import torch + +from vllm.forward_context import get_forward_context +from vllm.utils.torch_utils import aux_stream, direct_register_custom_op + + +def maybe_execute_in_parallel( + fn0: Callable[[], Any], + fn1: Callable[[], Any], + event0: torch.cuda.Event, + event1: torch.cuda.Event, + aux_stream: torch.cuda.Stream | None = None, +) -> tuple[Any, Any]: + """Run two functions potentially in parallel on separate CUDA streams. + + When aux_stream is provided, fn0 runs on the current (default) stream and + fn1 runs on aux_stream, synchronized via CUDA events. When aux_stream is + None, both functions execute sequentially on the current stream. + + This design follows TensorRT-LLM's maybe_execute_in_parallel pattern + (tensorrt_llm/_torch/modules/multi_stream_utils.py). + + Args: + fn0: Callable for the default stream. + fn1: Callable for the auxiliary stream. + event0: CUDA event recorded before fn0 so aux_stream can wait. + event1: CUDA event recorded after fn1 so default stream can wait. + aux_stream: The second CUDA stream for fn1. + Multi-stream is disabled when aux_stream is None. + + Returns: + Tuple of (fn0_result, fn1_result). + """ + if aux_stream is not None: + event0.record() + result0 = fn0() + with torch.cuda.stream(aux_stream): + event0.wait() + result1 = fn1() + event1.record() + event1.wait() + else: + result0 = fn0() + result1 = fn1() + return (result0, result1) From 62b7d747125cb2501ace08f6c2a5674d67ab1ee4 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Thu, 12 Mar 2026 20:41:38 -0700 Subject: [PATCH 2/2] Qwen3.5 Signed-off-by: Xin Yang --- vllm/model_executor/models/qwen3_5.py | 8 ++++++-- vllm/model_executor/models/qwen3_next.py | 2 +- vllm/utils/multi_stream_utils.py | 3 --- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py index 9b1dc7468fb6..e5967c122945 100644 --- a/vllm/model_executor/models/qwen3_5.py +++ b/vllm/model_executor/models/qwen3_5.py @@ -180,12 +180,16 @@ def forward( # ============================================================ # Part 1: Input Projection # ============================================================ - mixed_qkvz, _ = self.in_proj_qkvz(hidden_states) + mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj( + hidden_states, + self.in_proj_qkvz.weight.shape[0], + self.in_proj_ba.weight.shape[0], + self.prefix, + ) qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size z_size = self.value_dim // self.tp_size mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1) z = z.reshape(z.size(0), -1, self.head_v_dim) - ba, _ = self.in_proj_ba(hidden_states) b, a = ba.chunk(2, dim=-1) b = b.contiguous() diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 788442b7fd5f..1095b11d6f3a 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -82,11 +82,11 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.triton_utils import tl, triton +from vllm.utils.multi_stream_utils import maybe_execute_in_parallel from vllm.utils.torch_utils import ( aux_stream, direct_register_custom_op, ) -from vllm.utils.multi_stream_utils import maybe_execute_in_parallel from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata diff --git a/vllm/utils/multi_stream_utils.py b/vllm/utils/multi_stream_utils.py index e30bf670d6d2..3ade910bf99c 100644 --- a/vllm/utils/multi_stream_utils.py +++ b/vllm/utils/multi_stream_utils.py @@ -6,9 +6,6 @@ import torch -from vllm.forward_context import get_forward_context -from vllm.utils.torch_utils import aux_stream, direct_register_custom_op - def maybe_execute_in_parallel( fn0: Callable[[], Any],