diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index 3b40a6067d9a..5b67519c4df7 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -257,7 +257,13 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         assert (
             param.dtype == loaded_weight.dtype
         ), "init para dtype and loaded weight dtype should be the same"
-        assert param.size() == loaded_weight.size()
+        assert param.size() == loaded_weight.size(), (
+            f"ReplicatedLinear weight size mismatch: "
+            f"param.size()={list(param.size())}, "
+            f"loaded_weight.size()={list(loaded_weight.size())}, "
+            f"param.dtype={param.dtype}, "
+            f"loaded_weight.dtype={loaded_weight.dtype}"
+        )
         param.data.copy_(loaded_weight)
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index c74415b9dbc0..9dec2fa63b01 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -290,9 +290,13 @@ def _get_quant_method(
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
-            if is_layer_skipped(
-                prefix, self.exclude_modules, self.packed_modules_mapping
-            ) or self.is_layer_excluded(prefix):
+            skipped = is_layer_skipped(
+                prefix,
+                self.exclude_modules,
+                self.packed_modules_mapping or {},
+            )
+            excluded = self.is_layer_excluded(prefix)
+            if skipped or excluded:
                 return UnquantizedLinearMethod()
             return Linear(self)
         elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
@@ -1024,24 +1028,44 @@ def from_config(cls, config: Dict[str, Any]) -> ModelOptFp4Config:
         )
 
     def is_layer_excluded(self, prefix: str):
-        import regex as re
+        import re
 
         fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
-        prefix_split = prefix.split(".")
-        for pattern in self.exclude_modules:
-            regex_str = pattern.replace(".", r"\.").replace("*", r".*")
-            pattern_split = pattern.split(".")
-            if re.fullmatch(regex_str, prefix):
-                return True
-            elif (
-                pattern_split[-1] in fused_patterns
-                and pattern_split[-1] in prefix_split[-1]
-            ):
-                # Check if the last part of the excluded pattern is contained in the last part of the prefix
-                # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
-                # e.g., model.layers.{i}.self_attn.{fused_weight_name}
-                assert len(prefix_split) == 5 and len(pattern_split) == 5
-                return True
+        # Build candidate prefixes to handle naming mismatches between
+        # SGLang model prefixes and checkpoint ignore patterns.
+        # E.g., Kimi K2.5 VLM: SGLang prefix is "model.layers.X.self_attn.Y"
+        # but checkpoint ignore patterns use "language_model.layers.X.self_attn*".
+        prefixes_to_check = [prefix]
+        if prefix.startswith("language_model.model."):
+            # language_model.model.X -> language_model.X (drop inner "model.")
+            prefixes_to_check.append(
+                "language_model." + prefix.removeprefix("language_model.model.")
+            )
+            # language_model.model.X -> model.X (drop "language_model.")
+            prefixes_to_check.append(prefix.removeprefix("language_model."))
+        elif prefix.startswith("model."):
+            # model.X -> language_model.X (replace "model." with "language_model.")
+            prefixes_to_check.append("language_model." + prefix.removeprefix("model."))
+        elif prefix.startswith("language_model."):
+            prefixes_to_check.append(prefix.removeprefix("language_model."))
+
+        for check_prefix in prefixes_to_check:
+            check_prefix_split = check_prefix.split(".")
+            for pattern in self.exclude_modules:
+                regex_str = pattern.replace(".", r"\.").replace("*", r".*")
+                pattern_split = pattern.split(".")
+                if re.fullmatch(regex_str, check_prefix):
+                    return True
+                elif (
+                    pattern_split[-1] in fused_patterns
+                    and pattern_split[-1] in check_prefix_split[-1]
+                    and len(check_prefix_split) == 5
+                    and len(pattern_split) == 5
+                ):
+                    # Check if the last part of the excluded pattern is contained in the last part of the prefix
+                    # This handles fused modules like fused_qkv_a_proj_with_mqa that contain q_a_proj and kv_a_proj_with_mqa
+                    # e.g., model.layers.{i}.self_attn.{fused_weight_name}
+                    return True
         return False
 
     def get_quant_method(self, layer: torch.nn.Module, prefix: str):
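For reference, a minimal standalone sketch of the candidate-prefix expansion introduced in is_layer_excluded above. The helper name build_candidate_prefixes and the example prefix are illustrative only and do not appear in the patch; they just show which aliases get checked against the checkpoint ignore patterns.

def build_candidate_prefixes(prefix: str) -> list:
    # Mirror of the prefix-aliasing logic added to is_layer_excluded (illustrative).
    candidates = [prefix]
    if prefix.startswith("language_model.model."):
        # language_model.model.X -> language_model.X and model.X
        candidates.append("language_model." + prefix.removeprefix("language_model.model."))
        candidates.append(prefix.removeprefix("language_model."))
    elif prefix.startswith("model."):
        # model.X -> language_model.X
        candidates.append("language_model." + prefix.removeprefix("model."))
    elif prefix.startswith("language_model."):
        # language_model.X -> X
        candidates.append(prefix.removeprefix("language_model."))
    return candidates

# A "model.*" prefix now also matches ignore patterns written against "language_model.*":
print(build_candidate_prefixes("model.layers.0.self_attn.q_b_proj"))
# ['model.layers.0.self_attn.q_b_proj', 'language_model.layers.0.self_attn.q_b_proj']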