diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index 8297124cc4c0..db09bdb06d52 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -173,12 +173,10 @@ def _modelopt_override_quantization_method( # Check if this is a ModelOpt config quant_algo = hf_quant_config.get("quant_algo", "").upper() - # If user specified generic "modelopt", auto-detect the specific method - if user_quant == "modelopt": - if "FP8" in quant_algo: - return "modelopt_fp8" - elif "NVFP4" in quant_algo or "FP4" in quant_algo: - return "modelopt_fp4" + if "FP8" in quant_algo: + return "modelopt_fp8" + elif "NVFP4" in quant_algo or "FP4" in quant_algo: + return "modelopt_fp4" # The hf_quant_config may be a parsed quant config, so we need to check the # quant_method. diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index b1092aec65ad..fee6c0656338 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -1513,7 +1513,7 @@ def _slice_scale(w): # For other backends, ensure the per-input block dimension is aligned to 16. assert ( weight_scale.shape[assert_dim] % block_size == 0 - ), f"Expected {name}_weight_scale.dim({assert_dim}) to be divisible by {block_size}" + ), f"Expected {name}_weight_scale.dim({assert_dim}) to be divisible by {block_size}, got {weight_scale.shape[assert_dim]}" assert ( weight_scale.dtype == torch.float8_e4m3fn ), f"{name} Weight Blockscale must be represented as FP8-E4M3" diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index be71d0d28167..0737b13c14cb 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -63,6 +63,7 @@ from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.moe.utils import RoutingMethodType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.radix_attention import RadixAttention @@ -375,6 +376,7 @@ def __init__( intermediate_size=config.moe_intermediate_size, quant_config=quant_config, routed_scaling_factor=self.routed_scaling_factor, + routing_method_type=RoutingMethodType.DeepSeekV3, prefix=add_prefix("experts", prefix), ) diff --git a/python/sglang/srt/models/glm4_moe_nextn.py b/python/sglang/srt/models/glm4_moe_nextn.py index 5e0c3ac5992e..6e80a1c75c21 100644 --- a/python/sglang/srt/models/glm4_moe_nextn.py +++ b/python/sglang/srt/models/glm4_moe_nextn.py @@ -21,7 +21,8 @@ from torch import nn from transformers import PretrainedConfig -from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size +from sglang.srt.environ import envs from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.layers.dp_attention import is_dp_attention_enabled from sglang.srt.layers.layernorm import RMSNorm @@ -34,11 +35,14 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM from sglang.srt.server_args import get_global_server_args -from sglang.srt.utils import add_prefix +from sglang.srt.utils import add_prefix, is_cuda logger = logging.getLogger(__name__) +_is_cuda = is_cuda() + + class Glm4MoeModelNextN(nn.Module): def __init__( self, @@ -67,12 +71,19 @@ def __init__( self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False) + self.alt_stream = ( + torch.cuda.Stream() + if _is_cuda or envs.SGLANG_NPU_USE_MULTI_STREAM.get() + else None + ) + self.decoder = Glm4MoeDecoderLayer( config, 0, quant_config=quant_config, is_nextn=True, prefix=add_prefix("decoder", prefix), + alt_stream=self.alt_stream, ) self.shared_head = nn.Module() @@ -127,6 +138,9 @@ def __init__( self.config = config self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config + self.pp_group = get_pp_group() + self.num_fused_shared_experts = 0 + self.determine_num_fused_shared_experts() self.model = Glm4MoeModelNextN( config, quant_config, prefix=add_prefix("model", prefix) ) @@ -139,10 +153,6 @@ def __init__( ) self.logits_processor = LogitsProcessor(config) - self.num_fused_shared_experts = ( - 0 if get_global_server_args().disable_shared_experts_fusion else 1 - ) - @torch.no_grad() def forward( self,