From 0a3df409f98ac694ddf5599f71f9c037815dc96f Mon Sep 17 00:00:00 2001
From: yizhang2077 <1109276519@qq.com>
Date: Thu, 3 Jul 2025 07:07:48 +0000
Subject: [PATCH 1/2] add two stream norm for qwen3

Co-authored-by: ispobock
---
 python/sglang/srt/models/qwen3.py     | 25 ++++++++++++++++++++-----
 python/sglang/srt/models/qwen3_moe.py | 24 +++++++++++++++++++-----
 2 files changed, 39 insertions(+), 10 deletions(-)

diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index ae7bbfd4cae..51b50b832bb 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -23,15 +23,17 @@
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, is_cuda

 Qwen3Config = None

 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()


 class Qwen3Attention(nn.Module):
@@ -110,14 +112,27 @@ def __init__(
             prefix=add_prefix("attn", prefix),
         )

+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
+
     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k

diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index c76326ec01b..35f564409fe 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -67,6 +67,7 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -76,11 +77,12 @@
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
+from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty

 Qwen3MoeConfig = None

 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()


 class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -421,15 +423,27 @@ def __init__(

         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None

     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k
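In both models, the q_norm half of _apply_qk_norm stays on the current CUDA stream while the k_norm half runs on the side stream (self.alt_stream), with wait_stream() fencing the fork and the join; the overlapped path is only taken when get_is_capture_mode() reports that CUDA graph capture is active, otherwise the original sequential code runs. A minimal standalone sketch of the same fork/join pattern, assuming a CUDA device and PyTorch 2.4+ for torch.nn.RMSNorm (the helper and variable names below are illustrative, not SGLang APIs):

    import torch

    def overlapped_qk_norm(q, k, q_norm, k_norm, head_dim, alt_stream):
        current_stream = torch.cuda.current_stream()
        # Fork: the side stream must see all prior work that produced q and k.
        alt_stream.wait_stream(current_stream)
        q_by_head = q_norm(q.reshape(-1, head_dim))      # runs on the current stream
        with torch.cuda.stream(alt_stream):
            k_by_head = k_norm(k.reshape(-1, head_dim))  # runs concurrently on the side stream
        # Join: later kernels on the current stream wait for k_norm to finish.
        current_stream.wait_stream(alt_stream)
        return q_by_head.view(q.shape), k_by_head.view(k.shape)

    if torch.cuda.is_available():
        head_dim = 128
        q = torch.randn(16, 32 * head_dim, device="cuda")
        k = torch.randn(16, 8 * head_dim, device="cuda")
        q_norm = torch.nn.RMSNorm(head_dim, eps=1e-6, device="cuda")
        k_norm = torch.nn.RMSNorm(head_dim, eps=1e-6, device="cuda")
        q_out, k_out = overlapped_qk_norm(q, k, q_norm, k_norm, head_dim, torch.cuda.Stream())

A likely motivation is that the two per-head norm kernels are small enough to execute concurrently, and restricting the overlap to graph capture bakes the extra stream synchronization into the captured graph instead of paying it on every eager call.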
From fab74719ef48c7ed2b340cb4ebd0937a98a869a3 Mon Sep 17 00:00:00 2001
From: yizhang2077 <1109276519@qq.com>
Date: Thu, 3 Jul 2025 09:51:12 +0000
Subject: [PATCH 2/2] move stream to model

---
 python/sglang/srt/models/qwen2.py     | 3 +++
 python/sglang/srt/models/qwen2_moe.py | 3 +++
 python/sglang/srt/models/qwen3.py     | 8 ++++++--
 python/sglang/srt/models/qwen3_moe.py | 7 ++++++-
 4 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py
index 10ac84eccc9..ab0569bc2e3 100644
--- a/python/sglang/srt/models/qwen2.py
+++ b/python/sglang/srt/models/qwen2.py
@@ -185,6 +185,7 @@ def __init__(
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -246,6 +247,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -272,6 +274,7 @@ def __init__(
                 config=config,
                 quant_config=quant_config,
                 prefix=prefix,
+                alt_stream=alt_stream,
             ),
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py
index 0968ba0f437..95f0fcb70be 100644
--- a/python/sglang/srt/models/qwen2_moe.py
+++ b/python/sglang/srt/models/qwen2_moe.py
@@ -291,6 +291,7 @@ def __init__(
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -393,6 +394,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer,
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -418,6 +420,7 @@ def __init__(
                 config=config,
                 quant_config=quant_config,
                 prefix=prefix,
+                alt_stream=alt_stream,
             ),
             pp_rank=self.pp_group.rank_in_group,
             pp_size=self.pp_group.world_size,
diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py
index 51b50b832bb..9423a1379bb 100644
--- a/python/sglang/srt/models/qwen3.py
+++ b/python/sglang/srt/models/qwen3.py
@@ -51,6 +51,7 @@ def __init__(
         rms_norm_eps: float = None,
         attention_bias: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -111,8 +112,7 @@ def __init__(
             layer_id=layer_id,
             prefix=add_prefix("attn", prefix),
         )
-
-        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
+        self.alt_stream = alt_stream

     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
@@ -158,6 +158,7 @@ def __init__(
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -178,6 +179,7 @@ def __init__(
             rms_norm_eps=config.rms_norm_eps,
             attention_bias=config.attention_bias,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
         self.mlp = Qwen3MLP(
             hidden_size=self.hidden_size,
@@ -223,11 +225,13 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
             config=config,
             quant_config=quant_config,
             prefix=prefix,
             decoder_layer_type=Qwen3DecoderLayer,
+            alt_stream=alt_stream,
         )

diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 35f564409fe..5a28444387a 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -354,6 +354,7 @@ def __init__(
         attention_bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -423,7 +424,7 @@ def __init__(

         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
-        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
+        self.alt_stream = alt_stream

     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
@@ -503,6 +504,7 @@ def __init__(
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -528,6 +530,7 @@ def __init__(
             attention_bias=attention_bias,
             quant_config=quant_config,
             prefix=add_prefix("self_attn", prefix),
+            alt_stream=alt_stream,
         )
         self.layer_id = layer_id

@@ -671,11 +674,13 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
         super().__init__(
             config=config,
             quant_config=quant_config,
             prefix=prefix,
             decoder_layer_type=Qwen3MoeDecoderLayer,
+            alt_stream=alt_stream,
         )
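With PATCH 2/2, the side stream is created once per model (in Qwen3Model / Qwen3MoeModel) and threaded down to every attention module through the new alt_stream keyword, instead of each Qwen3Attention / Qwen3MoeAttention constructing its own torch.cuda.Stream(); the Qwen2 base classes only gain the pass-through parameter. A rough sketch of that plumbing shape, with illustrative class names rather than the actual SGLang modules:

    from typing import Optional

    import torch
    from torch import nn

    class Attention(nn.Module):
        def __init__(self, alt_stream: Optional[torch.cuda.Stream] = None):
            super().__init__()
            self.alt_stream = alt_stream  # shared side stream; None on non-CUDA backends

    class DecoderLayer(nn.Module):
        def __init__(self, alt_stream: Optional[torch.cuda.Stream] = None):
            super().__init__()
            self.self_attn = Attention(alt_stream=alt_stream)

    class Model(nn.Module):
        def __init__(self, num_layers: int):
            super().__init__()
            # One stream for the whole model rather than one per attention module.
            alt_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
            self.layers = nn.ModuleList(
                DecoderLayer(alt_stream=alt_stream) for _ in range(num_layers)
            )

One likely motivation for the move: a single shared stream keeps the number of CUDA streams constant regardless of model depth, and every layer's k_norm work lands on the same side stream.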