From 38ac7c47a9679a23a44d8ce5d9d377d183c73a87 Mon Sep 17 00:00:00 2001
From: Minglei Zhu
Date: Mon, 8 Dec 2025 03:28:10 +0000
Subject: [PATCH] fix load_weights for glm4v_moe with shared_experts fusion

---
 python/sglang/srt/models/glm4v_moe.py | 56 +++++++++++++++++++++++----
 python/sglang/srt/server_args.py      |  6 ---
 2 files changed, 48 insertions(+), 14 deletions(-)

diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py
index 4adefc7304bb..324de18b49b7 100644
--- a/python/sglang/srt/models/glm4v_moe.py
+++ b/python/sglang/srt/models/glm4v_moe.py
@@ -6,10 +6,14 @@
 import torch.nn as nn
 from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.distributed.parallel_state import get_pp_group
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -19,10 +23,11 @@
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import add_prefix, is_cuda
+from sglang.srt.utils import add_prefix, get_device_sm, is_cuda, log_info_on_rank0
 from sglang.srt.utils.hf_transformers_utils import get_processor
 
 _is_cuda = is_cuda()
+_device_sm = get_device_sm()
 
 logger = logging.getLogger(__name__)
 
@@ -44,11 +49,8 @@ def __init__(
         vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.num_fused_shared_experts = (
-            0
-            if get_global_server_args().disable_shared_experts_fusion
-            else config.n_shared_experts
-        )
+        self.num_fused_shared_experts = 0
+        self.determine_num_fused_shared_experts()
 
         self.model = Glm4MoeModel(
             config,
@@ -84,6 +86,36 @@ def __init__(
         # For EAGLE3 support
         self.capture_aux_hidden_states = False
 
+    def determine_num_fused_shared_experts(self):
+        if get_global_server_args().disable_shared_experts_fusion:
+            return
+
+        disable_reason = None
+        if not getattr(self.config, "n_shared_experts", None):
+            disable_reason = "No shared experts are defined in the config."
+        elif not _is_cuda:
+            disable_reason = "Shared experts fusion currently requires CUDA devices."
+        elif _is_cuda and (_device_sm is not None) and (_device_sm < 80):
+            disable_reason = "Shared experts fusion requires SM80 or newer GPUs."
+        elif get_moe_expert_parallel_world_size() > 1:
+            disable_reason = "Shared experts fusion is not supported together with expert parallelism yet."
+        elif get_moe_a2a_backend().is_deepep():
+            disable_reason = "Shared experts fusion is not supported when Deepep MoE backend is enabled."
+
+        if disable_reason is not None:
+            get_global_server_args().disable_shared_experts_fusion = True
+            log_info_on_rank0(
+                logger,
+                f"{disable_reason} Shared experts fusion optimization is disabled.",
+            )
+            return
+
+        self.num_fused_shared_experts = self.config.n_shared_experts
+        assert (
+            self.num_fused_shared_experts == 1
+        ), "Only 1 fused shared expert is supported for Glm4vMoeForConditionalGeneration"
+        log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
@@ -111,7 +143,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
 
         if is_nextn:
@@ -128,6 +160,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
         for name, loaded_weight in weights:
             weight_names.append(name)
 
+            if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                # Shared expert becomes expert ID = n_routed_experts
+                name = name.replace(
+                    "mlp.shared_experts",
+                    f"mlp.experts.{self.config.n_routed_experts}",
+                )
+
             if not is_nextn:
                 if hasattr(self.config, "num_nextn_predict_layers"):
                     num_nextn_layers = self.config.num_nextn_predict_layers
@@ -163,6 +202,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
                 name = name.replace("model.visual.", "visual.")
             if "rotary_emb.inv_freq" in name:
                 continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 # Skip non-stacked layers and experts (experts handled below).
                 if weight_name not in name:
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 7845381c9227..8e7753dab60a 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1083,12 +1083,6 @@ def _handle_model_specific_adjustments(self):
                     "Use flashinfer_trtllm as MoE runner backend on sm100 for DeepseekV3ForCausalLM"
                 )
 
-        elif model_arch in [
-            "Glm4vMoeForConditionalGeneration",
-            "Glm4vForConditionalGeneration",
-        ]:
-            # TODO: fixme - It does not work for GLM4V - https://github.com/sgl-project/sglang/issues/14582
-            self.disable_shared_experts_fusion = True
         elif model_arch in ["GptOssForCausalLM"]:
             if (
                 self.attention_backend is None
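
Reviewer note: a minimal, self-contained sketch of the weight-name remapping that load_weights now performs for the fused shared expert. `remap_shared_expert_name` is a hypothetical helper introduced only for illustration (the patch does the replace inline while iterating over checkpoint weights), and 128 routed experts is an illustrative value, not taken from the actual model config.

def remap_shared_expert_name(
    name: str, n_routed_experts: int, num_fused_shared_experts: int
) -> str:
    """Map mlp.shared_experts.* onto the extra fused expert slot.

    make_expert_params_mapping is now called with
    num_experts = n_routed_experts + num_fused_shared_experts, so the shared
    expert's weights land on expert ID == n_routed_experts.
    """
    if num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
        return name.replace("mlp.shared_experts", f"mlp.experts.{n_routed_experts}")
    return name


# Illustrative check: with 128 routed experts and 1 fused shared expert, the
# shared expert's gate projection is loaded as expert 128.
assert (
    remap_shared_expert_name(
        "model.layers.3.mlp.shared_experts.gate_proj.weight", 128, 1
    )
    == "model.layers.3.mlp.experts.128.gate_proj.weight"
)

With the loader accounting for the extra expert, the unconditional disable for the Glm4v architectures in server_args.py (the workaround for https://github.com/sgl-project/sglang/issues/14582) can be dropped; determine_num_fused_shared_experts() still falls back to the unfused path when the config defines no shared experts, on non-CUDA or pre-SM80 devices, under expert parallelism, or when the DeepEP a2a backend is enabled.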