Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions python/sglang/srt/hardware_backend/npu/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,67 @@ def get_indexer_weight_stream():
if indexer_weight_stream is None:
indexer_weight_stream = torch.npu.Stream()
return indexer_weight_stream


# Module-level handles for the NPU dual-stream MoE execution path.
# `share_stream` is the side stream used for the shared-expert forward and
# `routed_stream` for the routed-expert forward. Both start as None and are
# lazily created/registered through their setter functions in this module.
share_stream = None
routed_stream = None


def get_share_stream():
    """Return the registered shared-expert stream, or None if not set."""
    # Reading a module global needs no `global` declaration.
    return share_stream


def set_share_stream(stream):
    """Register *stream* as the module-wide shared-expert stream."""
    global share_stream
    share_stream = stream
    # TODO(LKL): `torch.npu.set_stream_limit(share_stream, 8, 16)` stays
    # disabled — limiting the stream was observed to impact precision.


def get_routed_stream():
    """Return the registered routed-expert stream, or None if not set."""
    # Reading a module global needs no `global` declaration.
    return routed_stream


def set_routed_stream(stream):
    """Register *stream* as the module-wide routed-expert stream."""
    global routed_stream
    routed_stream = stream
    # TODO(LKL): `torch.npu.set_stream_limit(routed_stream, 16, 32)` stays
    # disabled — limiting the stream was observed to impact precision.


def wait_share_stream():
    """Block the current device stream until the shared-expert stream is done.

    No-op when no shared-expert stream has been registered.
    """
    side = get_share_stream()
    if side is None:
        return
    torch.get_device_module().current_stream().wait_stream(side)


def wait_routed_stream():
    """Block the current device stream until the routed-expert stream is done.

    No-op when no routed-expert stream has been registered.
    """
    side = get_routed_stream()
    if side is None:
        return
    torch.get_device_module().current_stream().wait_stream(side)


def process_shared_expert(hidden_states, forward_func):
    """Run the shared-expert forward on a dedicated side stream.

    Lazily creates and registers the side stream on first use. The side
    stream first waits on the current stream, so `hidden_states` is safe to
    read; callers must synchronize (e.g. via `wait_share_stream`) before
    consuming the returned tensor.
    """
    device_mod = torch.get_device_module()
    side_stream = get_share_stream()
    if side_stream is None:
        side_stream = device_mod.Stream()
        set_share_stream(side_stream)
    side_stream.wait_stream(device_mod.current_stream())
    with device_mod.stream(side_stream):
        return forward_func(hidden_states)


def process_routed_expert(hidden_states, topk_output, forward_func):
    """Run the routed-expert forward on a dedicated side stream.

    Lazily creates and registers the side stream on first use. The side
    stream first waits on the current stream, so the inputs are safe to
    read; callers must synchronize (e.g. via `wait_routed_stream`) before
    consuming the returned tensor.
    """
    stream = get_routed_stream()
    if stream is None:
        stream = torch.get_device_module().Stream()
        set_routed_stream(stream)
    stream.wait_stream(torch.get_device_module().current_stream())
    with torch.get_device_module().stream(stream):
        # Fix: the original bound this to `shared_output` (copy-paste from
        # the shared-expert helper); this is the routed-expert result.
        routed_output = forward_func(hidden_states, topk_output)
    return routed_output
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ def _rmsnorm_forward_oot(
residual: Optional[torch.Tensor] = None,
post_residual_addition: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The post_residual_addition parameter is removed from the function signature. While this simplifies the signature, ensure that all call sites of _rmsnorm_forward_oot no longer pass this argument or handle its absence appropriately. If this parameter was conditionally used, its removal should be justified by the new NPU kernel's behavior.

from sgl_kernel_npu.norm.add_rmsnorm_bias import add_rmsnorm_bias

if not x.is_contiguous():
x = x.contiguous()
if residual is not None:
if post_residual_addition is not None:
residual = residual + post_residual_addition
from sgl_kernel_npu.norm.add_rmsnorm_bias import add_rmsnorm_bias

out, residual_out = add_rmsnorm_bias(
x,
residual,
Expand Down
72 changes: 61 additions & 11 deletions python/sglang/srt/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.environ import envs
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
Expand Down Expand Up @@ -91,6 +92,7 @@
is_cuda,
is_hip,
is_non_idle_and_non_empty,
is_npu,
log_info_on_rank0,
make_layers,
)
Expand All @@ -102,10 +104,19 @@
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_npu = is_npu()
_device_sm = get_device_sm()

logger = logging.getLogger(__name__)

if _is_npu:
from sgl_kernel_npu.norm.split_qkv_rmsnorm_rope import split_qkv_rmsnorm_rope

from sglang.srt.hardware_backend.npu.utils import (
process_shared_expert,
wait_share_stream,
)


class Glm4MoeMLP(nn.Module):
def __init__(
Expand Down Expand Up @@ -278,17 +289,39 @@ def forward_prepare(
if hidden_states.shape[0] == 0:
return hidden_states, forward_batch, None
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = apply_qk_norm(
q=q,
k=k,
q_norm=self.q_norm,
k_norm=self.k_norm,
head_dim=self.head_dim,
alt_stream=self.alt_stream,

if (
not _is_npu
or forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed()
):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = apply_qk_norm(
q=q,
k=k,
q_norm=self.q_norm,
k_norm=self.k_norm,
head_dim=self.head_dim,
alt_stream=self.alt_stream,
)
q, k = self.rotary_emb(positions, q, k)
else:
if self.attn.layer_id == forward_batch.token_to_kv_pool.start_layer:
self.rotary_emb.get_cos_sin_with_position(positions)
q, k, v = split_qkv_rmsnorm_rope(
qkv,
self.rotary_emb.position_sin,
self.rotary_emb.position_cos,
self.q_size,
self.kv_size,
self.head_dim,
eps=self.q_norm.variance_epsilon,
q_weight=self.q_norm.weight,
k_weight=self.k_norm.weight,
q_bias=getattr(self.q_norm, "bias", None),
k_bias=getattr(self.k_norm, "bias", None),
)
q, k = self.rotary_emb(positions, q, k)

inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state

Expand Down Expand Up @@ -562,10 +595,24 @@ def forward_deepep(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
shared_output = None
enable_npu_dual_stream = (
_is_npu
and (
forward_batch.forward_mode.is_extend()
or forward_batch.forward_mode.is_target_verify()
)
and envs.SGLANG_NPU_USE_MULTI_STREAM.get()
)

if hidden_states.shape[0] > 0:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
shared_output = self._forward_shared_experts(hidden_states)
if enable_npu_dual_stream:
shared_output = process_shared_expert(
hidden_states, self._forward_shared_experts
)
else:
shared_output = self._forward_shared_experts(hidden_states)
topk_output = self.topk(
hidden_states,
router_logits,
Expand All @@ -576,10 +623,13 @@ def forward_deepep(
)
else:
topk_output = self.topk.empty_topk_output(hidden_states.device)

final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_output=topk_output,
)
if enable_npu_dual_stream:
wait_share_stream()

if shared_output is not None:
x = shared_output
Expand Down
21 changes: 19 additions & 2 deletions python/sglang/srt/models/glm4_moe_nextn.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Inference-only GLM-4.5, GLM-4.6 Speculative Decoding."""

import contextlib
import logging
from typing import Iterable, Optional, Tuple

Expand All @@ -22,6 +23,7 @@
from transformers import PretrainedConfig

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.environ import temp_set_env
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
from sglang.srt.layers.layernorm import RMSNorm
Expand Down Expand Up @@ -126,7 +128,10 @@ def __init__(
nn.Module.__init__(self)
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.quant_config = quant_config
self.needs_quant_draft = (
get_global_server_args().speculative_draft_model_quantization
)
quant_config = quant_config if self.needs_quant_draft else None
self.model = Glm4MoeModelNextN(
config, quant_config, prefix=add_prefix("model", prefix)
)
Expand All @@ -150,7 +155,19 @@ def forward(
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
# Support unquant speculative draft model
if self.needs_quant_draft:
cxt = contextlib.nullcontext()
else:
unquant_patch = {
"SGLANG_DEEPEP_BF16_DISPATCH": "1",
"DEEP_NORMAL_MODE_USE_INT8_QUANT": "0",
}
cxt = temp_set_env(allow_sglang=True, **unquant_patch)

with cxt:
hidden_states = self.model(input_ids, positions, forward_batch)

return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
Expand Down
Loading