25 changes: 20 additions & 5 deletions python/sglang/srt/models/qwen3.py
@@ -23,15 +23,17 @@
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, is_cuda
 
 Qwen3Config = None
 
 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
 
 
 class Qwen3Attention(nn.Module):
@@ -110,14 +112,27 @@ def __init__(
             prefix=add_prefix("attn", prefix),
         )
 
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
+
     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():

Review comment on the if-branch above:
hi @zhyncs, why is get_is_capture_mode() used as the condition for running the q/k norms on two streams?

+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k

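Side note for readers of the diff: the added branch is a standard fork-join overlap on a secondary CUDA stream. The alternate stream first waits on the current stream, the K norm then runs on the alternate stream while the Q norm runs on the main stream, and the main stream waits on the alternate stream before k is consumed. Below is a minimal self-contained sketch of the same pattern, assuming a CUDA device and using torch.nn.RMSNorm (PyTorch 2.4+) as a stand-in for sglang's RMSNorm; all names and shapes here are illustrative, not the PR's exact objects.

import torch
import torch.nn as nn

# Minimal sketch of the fork-join pattern in _apply_qk_norm above.
# q_norm / k_norm stand in for the per-head RMSNorm modules of the PR.
def overlapped_qk_norm(q, k, q_norm, k_norm, head_dim, alt_stream):
    current_stream = torch.cuda.current_stream()
    alt_stream.wait_stream(current_stream)     # fork: alt stream sees all prior main-stream work
    q_flat = q_norm(q.reshape(-1, head_dim))   # Q norm runs on the main stream
    with torch.cuda.stream(alt_stream):
        k_flat = k_norm(k.reshape(-1, head_dim))  # K norm runs concurrently on the alt stream
    current_stream.wait_stream(alt_stream)     # join: main stream waits before k is used
    return q_flat.view(q.shape), k_flat.view(k.shape)

head_dim = 128
q_norm = nn.RMSNorm(head_dim, eps=1e-6, device="cuda")
k_norm = nn.RMSNorm(head_dim, eps=1e-6, device="cuda")
q = torch.randn(8, 32 * head_dim, device="cuda")
k = torch.randn(8, 8 * head_dim, device="cuda")
q2, k2 = overlapped_qk_norm(q, k, q_norm, k_norm, head_dim, torch.cuda.Stream())

The final wait_stream is what makes it safe for the attention kernel that follows to read k; without the join, the main stream could race ahead of the alternate stream.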
24 changes: 19 additions & 5 deletions python/sglang/srt/models/qwen3_moe.py
@@ -67,6 +67,7 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
@@ -76,11 +77,12 @@
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
 from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
-from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty
+from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty
 
 Qwen3MoeConfig = None
 
 logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
 
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
@@ -421,15 +423,27 @@ def __init__(

         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
 
     def _apply_qk_norm(
         self, q: torch.Tensor, k: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        q_by_head = q.reshape(-1, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
+        # overlap qk norm
+        if self.alt_stream is not None and get_is_capture_mode():
+            current_stream = torch.cuda.current_stream()
+            self.alt_stream.wait_stream(current_stream)
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            with torch.cuda.stream(self.alt_stream):
+                k_by_head = k.reshape(-1, self.head_dim)
+                k_by_head = self.k_norm(k_by_head)
+            current_stream.wait_stream(self.alt_stream)
+        else:
+            q_by_head = q.reshape(-1, self.head_dim)
+            q_by_head = self.q_norm(q_by_head)
+            k_by_head = k.reshape(-1, self.head_dim)
+            k_by_head = self.k_norm(k_by_head)
         q = q_by_head.view(q.shape)
-        k_by_head = k.reshape(-1, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         return q, k

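On the reviewer's question about gating on get_is_capture_mode(): one plausible reading (an assumption, not confirmed in this thread) is that the two-stream path mainly pays off under CUDA graph capture, where the fork, both norm kernels, and the join are recorded once and replayed with no per-kernel launch overhead; in eager mode, the extra event synchronization can cost more than overlapping two small norm kernels saves. A hedged sketch of how such fork-join work is recorded into a single graph follows, with F.layer_norm as an illustrative stand-in for RMSNorm.

import torch
import torch.nn.functional as F

# Hedged sketch: recording two-stream fork-join work into one CUDA graph.
# F.layer_norm stands in for RMSNorm; all names and shapes are illustrative.
x = torch.randn(1024, 128, device="cuda")
y = torch.randn(1024, 128, device="cuda")
alt = torch.cuda.Stream()

# Warm up on a side stream before capture, as the PyTorch docs recommend.
warmup = torch.cuda.Stream()
warmup.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(warmup):
    F.layer_norm(x, (128,))
    F.layer_norm(y, (128,))
torch.cuda.current_stream().wait_stream(warmup)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    alt.wait_stream(torch.cuda.current_stream())  # fork is recorded into the graph
    a = F.layer_norm(x, (128,))                   # capture-stream kernel
    with torch.cuda.stream(alt):
        b = F.layer_norm(y, (128,))               # alt-stream kernel, overlapped
    torch.cuda.current_stream().wait_stream(alt)  # join is recorded into the graph

g.replay()  # both kernels plus the fork/join replay as one graph launch

At replay time the whole subgraph launches as a single unit, which would explain why the eager path in the PR keeps the simpler sequential code.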