diff --git a/python/sglang/srt/hardware_backend/npu/utils.py b/python/sglang/srt/hardware_backend/npu/utils.py
index 49332ccf4ca1..a0515f4f80e5 100644
--- a/python/sglang/srt/hardware_backend/npu/utils.py
+++ b/python/sglang/srt/hardware_backend/npu/utils.py
@@ -178,3 +178,67 @@ def get_indexer_weight_stream():
     if indexer_weight_stream is None:
         indexer_weight_stream = torch.npu.Stream()
     return indexer_weight_stream
+
+
+share_stream = None
+routed_stream = None
+
+
+def get_share_stream():
+    global share_stream
+    return share_stream
+
+
+def set_share_stream(stream):
+    global share_stream
+    share_stream = stream
+    # TODO LKL: set stream limit has impact on precision
+    # torch.npu.set_stream_limit(share_stream, 8, 16)
+
+
+def get_routed_stream():
+    global routed_stream
+    return routed_stream
+
+
+def set_routed_stream(stream):
+    global routed_stream
+    routed_stream = stream
+    # TODO LKL: set stream limit has impact on precision
+    # torch.npu.set_stream_limit(routed_stream, 16, 32)
+
+
+def wait_share_stream():
+    stream = get_share_stream()
+    if stream is not None:
+        cur_stream = torch.get_device_module().current_stream()
+        cur_stream.wait_stream(stream)
+
+
+def wait_routed_stream():
+    stream = get_routed_stream()
+    if stream is not None:
+        cur_stream = torch.get_device_module().current_stream()
+        cur_stream.wait_stream(stream)
+
+
+def process_shared_expert(hidden_states, forward_func):
+    # Run the shared-expert forward on a dedicated side stream; lazily
+    # create the stream on first use. The side stream must wait on the
+    # current stream on EVERY call so it sees up-to-date hidden_states.
+    stream = get_share_stream()
+    if stream is None:
+        stream = torch.get_device_module().Stream()
+        set_share_stream(stream)
+    stream.wait_stream(torch.get_device_module().current_stream())
+    with torch.get_device_module().stream(stream):
+        shared_output = forward_func(hidden_states)
+    return shared_output
+
+
+def process_routed_expert(hidden_states, topk_output, forward_func):
+    # Same pattern as process_shared_expert, but for the routed experts.
+    stream = get_routed_stream()
+    if stream is None:
+        stream = torch.get_device_module().Stream()
+        set_routed_stream(stream)
+    stream.wait_stream(torch.get_device_module().current_stream())
+    with torch.get_device_module().stream(stream):
+        routed_output = forward_func(hidden_states, topk_output)
+    return routed_output
diff --git a/python/sglang/srt/layers/quantization/modelslim/modelslim.py b/python/sglang/srt/layers/quantization/modelslim/modelslim.py
index 84acecccc415..86a7bfbebfbf 100644
--- a/python/sglang/srt/layers/quantization/modelslim/modelslim.py
+++ b/python/sglang/srt/layers/quantization/modelslim/modelslim.py
@@ -58,13 +58,13 @@ def _rmsnorm_forward_oot(
     residual: Optional[torch.Tensor] = None,
     post_residual_addition: Optional[torch.Tensor] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-    from sgl_kernel_npu.norm.add_rmsnorm_bias import add_rmsnorm_bias
-
     if not x.is_contiguous():
         x = x.contiguous()
     if residual is not None:
         if post_residual_addition is not None:
             residual = residual + post_residual_addition
+        from sgl_kernel_npu.norm.add_rmsnorm_bias import add_rmsnorm_bias
+
         out, residual_out = add_rmsnorm_bias(
             x,
             residual,
diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py
index fcf62efc8727..ab40fdfcd9d5 100644
--- a/python/sglang/srt/models/glm4_moe.py
+++ b/python/sglang/srt/models/glm4_moe.py
@@ -34,6 +34,7 @@
 from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     use_symmetric_memory,
 )
+from sglang.srt.environ import envs
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
@@ -91,6 +92,7 @@
     is_cuda,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
     log_info_on_rank0,
     make_layers,
 )
@@ -102,10 +104,19 @@
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_npu = is_npu()
 _device_sm = get_device_sm()
 
 logger = logging.getLogger(__name__)
 
+if _is_npu:
+    from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope
+
+    from sglang.srt.hardware_backend.npu.utils import (
+        process_shared_expert,
+        wait_share_stream,
+    )
+
 
 class Glm4MoeMLP(nn.Module):
     def __init__(
@@ -278,17 +289,39 @@ def forward_prepare(
         if hidden_states.shape[0] == 0:
             return hidden_states, forward_batch, None
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        if self.use_qk_norm:
-            q, k = apply_qk_norm(
-                q=q,
-                k=k,
-                q_norm=self.q_norm,
-                k_norm=self.k_norm,
-                head_dim=self.head_dim,
-                alt_stream=self.alt_stream,
+
+        if (
+            not _is_npu
+            or forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed()
+        ):
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+            if self.use_qk_norm:
+                q, k = apply_qk_norm(
+                    q=q,
+                    k=k,
+                    q_norm=self.q_norm,
+                    k_norm=self.k_norm,
+                    head_dim=self.head_dim,
+                    alt_stream=self.alt_stream,
+                )
+            q, k = self.rotary_emb(positions, q, k)
+        else:
+            if self.attn.layer_id == forward_batch.token_to_kv_pool.start_layer:
+                self.rotary_emb.get_cos_sin_with_position(positions)
+            q, k, v = split_qkv_rmsnorm_rope(
+                qkv,
+                self.rotary_emb.position_sin,
+                self.rotary_emb.position_cos,
+                self.q_size,
+                self.kv_size,
+                self.head_dim,
+                eps=self.q_norm.variance_epsilon,
+                q_weight=self.q_norm.weight,
+                k_weight=self.k_norm.weight,
+                q_bias=getattr(self.q_norm, "bias", None),
+                k_bias=getattr(self.k_norm, "bias", None),
             )
-        q, k = self.rotary_emb(positions, q, k)
+
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -562,10 +595,24 @@ def forward_deepep(
         self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
         shared_output = None
+        enable_npu_dual_stream = (
+            _is_npu
+            and (
+                forward_batch.forward_mode.is_extend()
+                or forward_batch.forward_mode.is_target_verify()
+            )
+            and envs.SGLANG_NPU_USE_MULTI_STREAM.get()
+        )
+
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-            shared_output = self._forward_shared_experts(hidden_states)
+            if enable_npu_dual_stream:
+                shared_output = process_shared_expert(
+                    hidden_states, self._forward_shared_experts
+                )
+            else:
+                shared_output = self._forward_shared_experts(hidden_states)
             topk_output = self.topk(
                 hidden_states,
                 router_logits,
@@ -576,10 +623,13 @@
             )
         else:
             topk_output = self.topk.empty_topk_output(hidden_states.device)
+
         final_hidden_states = self.experts(
             hidden_states=hidden_states,
             topk_output=topk_output,
         )
+        if enable_npu_dual_stream:
+            wait_share_stream()
 
         if shared_output is not None:
             x = shared_output
diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py
index 1f6e753646cb..6d2ef89724da 100644
--- a/python/sglang/srt/models/glm4_moe_nextn.py
+++ b/python/sglang/srt/models/glm4_moe_nextn.py
@@ -14,6 +14,7 @@
 
 """Inference-only GLM-4.5, GLM-4.6 Speculative Decoding."""
 
+import contextlib
 import logging
 from typing import Iterable, Optional, Tuple
 
@@ -22,6 +23,7 @@
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.environ import temp_set_env
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
@@ -126,7 +128,10 @@
         nn.Module.__init__(self)
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
-        self.quant_config = quant_config
+        self.needs_quant_draft = (
+            get_global_server_args().speculative_draft_model_quantization
+        )
+        quant_config = quant_config if self.needs_quant_draft else None
         self.model = Glm4MoeModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
@@ -150,7 +155,19 @@
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch)
+        # Support unquant speculative draft model
+        if self.needs_quant_draft:
+            cxt = contextlib.nullcontext()
+        else:
+            unquant_patch = {
+                "SGLANG_DEEPEP_BF16_DISPATCH": "1",
+                "DEEP_NORMAL_MODE_USE_INT8_QUANT": "0",
+            }
+            cxt = temp_set_env(allow_sglang=True, **unquant_patch)
+
+        with cxt:
+            hidden_states = self.model(input_ids, positions, forward_batch)
+
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )