Merged
3 changes: 3 additions & 0 deletions python/sglang/srt/models/qwen2.py
@@ -185,6 +185,7 @@ def __init__(
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -246,6 +247,7 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer,
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.config = config
@@ -272,6 +274,7 @@
config=config,
quant_config=quant_config,
prefix=prefix,
alt_stream=alt_stream,
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
3 changes: 3 additions & 0 deletions python/sglang/srt/models/qwen2_moe.py
@@ -291,6 +291,7 @@ def __init__(
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.config = config
@@ -393,6 +394,7 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer_type: type[nn.Module] = Qwen2MoeDecoderLayer,
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
@@ -418,6 +420,7 @@ def __init__(
config=config,
quant_config=quant_config,
prefix=prefix,
alt_stream=alt_stream,
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
29 changes: 24 additions & 5 deletions python/sglang/srt/models/qwen3.py
@@ -23,15 +23,17 @@
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.utils import add_prefix, is_cuda

Qwen3Config = None

logger = logging.getLogger(__name__)
_is_cuda = is_cuda()


class Qwen3Attention(nn.Module):
@@ -49,6 +51,7 @@ def __init__(
rms_norm_eps: float = None,
attention_bias: bool = False,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -109,15 +112,27 @@ def __init__(
layer_id=layer_id,
prefix=add_prefix("attn", prefix),
)
self.alt_stream = alt_stream

def _apply_qk_norm(
self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# overlap qk norm
if self.alt_stream is not None and get_is_capture_mode():

[inline review comment] hi, @zhyncs,
why is get_is_capture_mode used as a condition to run in two streams?

current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.q_norm(q_by_head)
with torch.cuda.stream(self.alt_stream):
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.k_norm(k_by_head)
current_stream.wait_stream(self.alt_stream)
else:
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.q_norm(q_by_head)
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.k_norm(k_by_head)
q = q_by_head.view(q.shape)
k = k_by_head.view(k.shape)
return q, k
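
For readers skimming the diff: when a side stream exists and CUDA-graph capture is active, the branch above issues the q-norm on the current stream while the k-norm runs concurrently on alt_stream, with wait_stream fences on both sides. Below is a minimal, self-contained sketch of the same two-stream pattern; it is not code from this PR, the function name is illustrative, and torch.nn.LayerNorm stands in for sglang's RMSNorm.

import torch

def overlapped_qk_norm(q, k, q_norm, k_norm, alt_stream):
    main_stream = torch.cuda.current_stream()
    # Fence 1: the side stream must not start until the main stream
    # has finished producing q and k.
    alt_stream.wait_stream(main_stream)
    q = q_norm(q)  # issued on the main stream
    with torch.cuda.stream(alt_stream):
        k = k_norm(k)  # issued concurrently on the side stream
    # Fence 2: the main stream must not consume k until the side
    # stream is done; dropping it lets later kernels race the k-norm.
    main_stream.wait_stream(alt_stream)
    return q, k

if torch.cuda.is_available():
    norm = torch.nn.LayerNorm(128).cuda()
    q = torch.randn(32, 128, device="cuda")
    k = torch.randn(32, 128, device="cuda")
    q, k = overlapped_qk_norm(q, k, norm, norm, torch.cuda.Stream())

Both fences matter: the first orders the side stream after the producers of q and k, the second orders every later kernel after the k-norm.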

@@ -143,6 +158,7 @@ def __init__(
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -163,6 +179,7 @@
rms_norm_eps=config.rms_norm_eps,
attention_bias=config.attention_bias,
prefix=add_prefix("self_attn", prefix),
alt_stream=alt_stream,
)
self.mlp = Qwen3MLP(
hidden_size=self.hidden_size,
@@ -208,11 +225,13 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
alt_stream = torch.cuda.Stream() if _is_cuda else None
super().__init__(
config=config,
quant_config=quant_config,
prefix=prefix,
decoder_layer_type=Qwen3DecoderLayer,
alt_stream=alt_stream,
)
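
The alt_stream line above is the only place the side stream is created: one torch.cuda.Stream() per model when running on CUDA, None otherwise, threaded down through Qwen2Model into each Qwen3DecoderLayer and Qwen3Attention. A small sketch of that guard-and-thread pattern, with hypothetical class names:

import torch
from typing import Optional

class TinyAttention:
    def __init__(self, alt_stream: Optional[torch.cuda.Stream] = None):
        # None means "no side stream": consumers fall back to the
        # sequential q-norm / k-norm path.
        self.alt_stream = alt_stream

class TinyModel:
    def __init__(self, num_layers: int = 4):
        # Create the stream once at the top and share it across layers,
        # mirroring how this PR threads alt_stream through the stack.
        alt_stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.layers = [TinyAttention(alt_stream) for _ in range(num_layers)]

model = TinyModel()  # also runs on CPU-only machines: alt_stream is None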


29 changes: 24 additions & 5 deletions python/sglang/srt/models/qwen3_moe.py
@@ -67,6 +67,7 @@
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
@@ -76,11 +77,12 @@
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty

Qwen3MoeConfig = None

logger = logging.getLogger(__name__)
_is_cuda = is_cuda()


class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -352,6 +354,7 @@ def __init__(
attention_bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -421,15 +424,27 @@ def __init__(

self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.alt_stream = alt_stream

def _apply_qk_norm(
self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# overlap qk norm
if self.alt_stream is not None and get_is_capture_mode():
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.q_norm(q_by_head)
with torch.cuda.stream(self.alt_stream):
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.k_norm(k_by_head)
current_stream.wait_stream(self.alt_stream)
else:
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.q_norm(q_by_head)
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.k_norm(k_by_head)
q = q_by_head.view(q.shape)
k = k_by_head.view(k.shape)
return q, k
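
This _apply_qk_norm body is identical to the qwen3.py version earlier in the diff; only the surrounding attention configuration differs, so the stream-overlap sketch shown there applies here unchanged.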

@@ -489,6 +504,7 @@ def __init__(
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
self.config = config
@@ -514,6 +530,7 @@ def __init__(
attention_bias=attention_bias,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
alt_stream=alt_stream,
)

self.layer_id = layer_id
@@ -657,11 +674,13 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
alt_stream = torch.cuda.Stream() if _is_cuda else None
super().__init__(
config=config,
quant_config=quant_config,
prefix=prefix,
decoder_layer_type=Qwen3MoeDecoderLayer,
alt_stream=alt_stream,
)

