diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index 5479c7aca76..3d9dbf64f28 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -330,6 +330,12 @@ def __init__(
         self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
         self.expert_map = None
+
+        if enable_flashinfer_moe and quant_config is None:
+            logger.warning("Disable flashinfer MoE when quantization config is None.")
+            enable_flashinfer_moe = False
+            enable_ep_moe = False
+
         self.enable_flashinfer_moe = enable_flashinfer_moe
         if enable_ep_moe:
             assert (
diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py
index 7ca945d3a98..270b1143624 100644
--- a/python/sglang/srt/models/deepseek_nextn.py
+++ b/python/sglang/srt/models/deepseek_nextn.py
@@ -44,6 +44,12 @@ def __init__(
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+            logger.warning(
+                "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+            )
+            quant_config = None
+
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index c73200400e8..a7749037ca2 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -2201,7 +2201,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
                             q_a_proj_weight = cached_a_proj[q_a_proj_name]
                             kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                             cat_dim = 0
-                            if (
+                            if self.quant_config is not None and (
                                 self.quant_config.get_name() == "awq"
                                 or self.quant_config.get_name() == "moe_wna16"
                             ):
@@ -2232,6 +2232,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=Fal
                         for scale in ["k_scale", "v_scale"]:
                             if scale in name:
                                 name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+                                break
+                        if name not in params_dict:
+                            # modelopt ckpt contains not needed weights for MTP module:
+                            # model.decoder.self_attn.attn_mqa.v_scale and
+                            # model.decoder.self_attn.attn_mqa.k_scale
+                            logger.warning(f"{name} not found in params_dict.")
+                            continue
                         param = params_dict[name]
                         weight_loader = getattr(
                             param, "weight_loader", default_weight_loader
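
Note on the load_weights change in deepseek_v2.py: modelopt checkpoints can ship k_scale/v_scale tensors for the MTP (nextn) module that the model never registers as parameters, so the loader now logs and skips any checkpoint entry whose (possibly renamed) key is missing from params_dict instead of raising a KeyError. The snippet below is a minimal, self-contained sketch of that tolerant-loading pattern, not sglang's actual loader; the function name load_weights_tolerantly and the toy params_dict/weights arguments are hypothetical and only illustrate the guard added by this diff.

    # Sketch (assumptions: hypothetical function and argument names; the rename
    # rule mirrors the k_proj/v_proj -> attn_mqa remapping seen in the diff).
    import logging
    from typing import Dict, Iterable, Tuple

    import torch

    logger = logging.getLogger(__name__)


    def load_weights_tolerantly(
        params_dict: Dict[str, torch.nn.Parameter],
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> None:
        for name, loaded_weight in weights:
            # Remap per-head kv scale entries onto the fused "attn_mqa" module,
            # e.g. "...self_attn.k_proj.k_scale" -> "...self_attn.attn_mqa.k_scale".
            for scale in ["k_scale", "v_scale"]:
                if scale in name:
                    name = name.replace(f"{scale[0]}_proj", "attn_mqa")
                    break
            if name not in params_dict:
                # The checkpoint may contain tensors the module never registers
                # (e.g. MTP attn_mqa k_scale/v_scale); skip them instead of failing.
                logger.warning("%s not found in params_dict, skipping.", name)
                continue
            params_dict[name].data.copy_(loaded_weight)

The point of the guard is ordering: the rename runs first, and only then is the (final) key checked against params_dict, so legitimate remapped scales still load while leftover checkpoint-only tensors are dropped with a warning.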