diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 6f6ec68d8eb3..7ebc73377e1a 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -60,7 +60,10 @@ apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, ) -from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.layers.quantization.unquant import ( + UnquantizedFusedMoEMethod, + UnquantizedLinearMethod, +) from sglang.srt.layers.quantization.utils import ( all_close_1d, convert_to_channelwise, @@ -117,8 +120,10 @@ def __init__( activation_scheme: str = "dynamic", ignored_layers: Optional[List[str]] = None, weight_block_size: List[int] = None, + packed_modules_mapping: Optional[Dict[str, List[str]]] = None, use_mxfp8: bool = False, ) -> None: + super().__init__() self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: log_info_on_rank0(logger, "Detected fp8 checkpoint.") @@ -126,6 +131,7 @@ def __init__( raise ValueError(f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] + self.packed_modules_mapping = packed_modules_mapping or {} self.use_mxfp8 = use_mxfp8 if weight_block_size is not None: if not is_checkpoint_fp8_serialized: @@ -167,15 +173,20 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: use_mxfp8 = "mxfp8" in quant_method is_checkpoint_fp8_serialized = ("fp8" in quant_method) or use_mxfp8 activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + packed_modules_mapping = ( + cls.get_from_keys_or(config, ["packed_modules_mapping"], {}) or {} + ) ignored_layers = cls.get_from_keys_or( config, ["ignored_layers", "modules_to_not_convert"], None ) if ignored_layers: - if "mistral3" in config.get("model_type", ""): - # hack for ministral - ignored_layers = [ - layer.replace("model.", "") for layer in ignored_layers - ] + # Keep both "model." and non-"model." variants for robust prefix matching. + normalized = [] + for layer in ignored_layers: + base = layer.removeprefix("model.") + normalized.append(base) + normalized.append(f"model.{base}") + ignored_layers = normalized weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) if use_mxfp8 and weight_block_size is not None: logger.warning( @@ -187,6 +198,7 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: activation_scheme=activation_scheme, ignored_layers=ignored_layers, weight_block_size=weight_block_size, + packed_modules_mapping=packed_modules_mapping, use_mxfp8=use_mxfp8, ) @@ -198,10 +210,18 @@ def get_quant_method( from sglang.srt.layers.radix_attention import RadixAttention if isinstance(layer, LinearBase): - if is_layer_skipped(prefix, self.ignored_layers): + if is_layer_skipped( + prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping + ): return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): + if is_layer_skipped( + prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping + ): + return UnquantizedFusedMoEMethod( + layer.use_triton_kernels, layer.use_flashinfer_trtllm_moe + ) return Fp8MoEMethod(self) elif isinstance(layer, RadixAttention): return Fp8KVCacheMethod(self) diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index a3346d6ae836..198b201de159 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -84,12 +84,12 @@ def is_layer_skipped( if prefix_gate in ignored_layers and prefix_up in ignored_layers: is_skipped = True elif "experts" in prefix: - is_skipped = any( - [ - prefix in layer_name - for layer_name in ignored_layers - if "experts" in layer_name - ] + # Expert names can include full module paths; keep coarse prefix matches + # (e.g., "model.layers.{i}.") while also checking expert-specific entries. + is_skipped = is_skipped or any( + prefix in layer_name + for layer_name in ignored_layers + if "experts" in layer_name ) assert is_skipped is not None