56 changes: 48 additions & 8 deletions python/sglang/srt/models/glm4v_moe.py
@@ -6,10 +6,14 @@
 import torch.nn as nn
 from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.distributed.parallel_state import get_pp_group
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -19,10 +23,11 @@
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import add_prefix, is_cuda
+from sglang.srt.utils import add_prefix, get_device_sm, is_cuda, log_info_on_rank0
 from sglang.srt.utils.hf_transformers_utils import get_processor
 
 _is_cuda = is_cuda()
+_device_sm = get_device_sm()
 
 logger = logging.getLogger(__name__)
 
@@ -44,11 +49,8 @@ def __init__(
         vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.num_fused_shared_experts = (
-            0
-            if get_global_server_args().disable_shared_experts_fusion
-            else config.n_shared_experts
-        )
+        self.num_fused_shared_experts = 0
+        self.determine_num_fused_shared_experts()
 
         self.model = Glm4MoeModel(
             config,
@@ -84,6 +86,36 @@ def __init__(
         # For EAGLE3 support
         self.capture_aux_hidden_states = False
 
+    def determine_num_fused_shared_experts(self):
+        if get_global_server_args().disable_shared_experts_fusion:
+            return
+
+        disable_reason = None
+        if not getattr(self.config, "n_shared_experts", None):
+            disable_reason = "No shared experts are defined in the config."
+        elif not _is_cuda:
+            disable_reason = "Shared experts fusion currently requires CUDA devices."
+        elif _is_cuda and (_device_sm is not None) and (_device_sm < 80):
+            disable_reason = "Shared experts fusion requires SM80 or newer GPUs."
+        elif get_moe_expert_parallel_world_size() > 1:
+            disable_reason = "Shared experts fusion is not supported together with expert parallelism yet."
+        elif get_moe_a2a_backend().is_deepep():
disable_reason = "Shared experts fusion is not supported when Deepep MoE backend is enabled."

if disable_reason is not None:
get_global_server_args().disable_shared_experts_fusion = True
log_info_on_rank0(
logger,
f"{disable_reason} Shared experts fusion optimization is disabled.",
)
return

self.num_fused_shared_experts = self.config.n_shared_experts
assert (
self.num_fused_shared_experts == 1
), "Only 1 fused shared expert is supported for Glm4vMoeForConditionalGeneration"
log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
Expand Down Expand Up @@ -111,7 +143,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts,
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
)

if is_nextn:
Expand All @@ -128,6 +160,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
for name, loaded_weight in weights:
weight_names.append(name)

if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
# Shared expert becomes expert ID = n_routed_experts
name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts}",
)

if not is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
Expand Down Expand Up @@ -163,6 +202,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
name = name.replace("model.visual.", "visual.")
if "rotary_emb.inv_freq" in name:
continue

for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
Expand Down
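Taken together, the two load_weights changes above implement the fusion at weight-loading time: shared-expert checkpoint names are rewritten into one extra expert slot, and the expert-params mapping is sized to cover that slot. A minimal self-contained sketch of this behavior follows; the expert count and checkpoint name are illustrative, not taken from the real GLM-4.5V config:

# Sketch of the shared-expert remap in load_weights; values are illustrative.
n_routed_experts = 128  # assumption for this example; the model reads it from config
num_fused_shared_experts = 1

name = "model.layers.3.mlp.shared_experts.gate_proj.weight"
if num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
    # Shared expert becomes expert ID = n_routed_experts, i.e. the extra slot.
    name = name.replace("mlp.shared_experts", f"mlp.experts.{n_routed_experts}")
print(name)  # -> model.layers.3.mlp.experts.128.gate_proj.weight

# The expert-params mapping must then cover IDs 0..n_routed_experts inclusive,
# which is why num_experts gains "+ self.num_fused_shared_experts" in the diff.
num_experts = n_routed_experts + num_fused_shared_experts  # 129 slots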
6 changes: 0 additions & 6 deletions python/sglang/srt/server_args.py
@@ -1083,12 +1083,6 @@ def _handle_model_specific_adjustments(self):
                "Use flashinfer_trtllm as MoE runner backend on sm100 for DeepseekV3ForCausalLM"
            )
 
-       elif model_arch in [
-           "Glm4vMoeForConditionalGeneration",
-           "Glm4vForConditionalGeneration",
-       ]:
-           # TODO: fixme - It does not work for GLM4V - https://github.com/sgl-project/sglang/issues/14582
-           self.disable_shared_experts_fusion = True
        elif model_arch in ["GptOssForCausalLM"]:
            if (
                self.attention_backend is None
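With the hard disable removed from _handle_model_specific_adjustments, shared-experts fusion for the GLM-4.5V architectures is now decided per model at construction time by determine_num_fused_shared_experts. Below is a condensed, standalone restatement of that gating order; the sglang helper results are passed in as plain values so the sketch runs without sglang, and it describes the logic rather than the actual API:

# Restates the gating cascade from determine_num_fused_shared_experts.
# Arguments mirror is_cuda(), get_device_sm(), get_moe_expert_parallel_world_size(),
# and get_moe_a2a_backend().is_deepep() from the diff.
def fusion_disable_reason(n_shared_experts, is_cuda, device_sm, ep_size, deepep):
    if not n_shared_experts:
        return "No shared experts are defined in the config."
    if not is_cuda:
        return "Shared experts fusion currently requires CUDA devices."
    if device_sm is not None and device_sm < 80:
        return "Shared experts fusion requires SM80 or newer GPUs."
    if ep_size > 1:
        return "Shared experts fusion is not supported together with expert parallelism yet."
    if deepep:
        return "Shared experts fusion is not supported when the DeepEP MoE backend is enabled."
    return None  # fusion stays on; num_fused_shared_experts = n_shared_experts (must be 1)

# Example: one shared expert on an SM90 CUDA GPU, no expert parallelism -> fusion enabled.
assert fusion_disable_reason(1, True, 90, 1, False) is None
# Example: expert parallelism > 1 disables fusion.
assert fusion_disable_reason(1, True, 90, 4, False) is not None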