From 5ef4e8babaa43c38738971f0ea2ac814aade4e38 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 12 Feb 2026 12:00:58 -0800 Subject: [PATCH 1/7] Initial attempt --- python/sglang/srt/layers/quantization/fp8.py | 25 +++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 6f6ec68d8eb3..fa7c8b854a6d 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -60,7 +60,10 @@ apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin, ) -from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.layers.quantization.unquant import ( + UnquantizedFusedMoEMethod, + UnquantizedLinearMethod, +) from sglang.srt.layers.quantization.utils import ( all_close_1d, convert_to_channelwise, @@ -119,6 +122,7 @@ def __init__( weight_block_size: List[int] = None, use_mxfp8: bool = False, ) -> None: + super().__init__() self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: log_info_on_rank0(logger, "Detected fp8 checkpoint.") @@ -167,6 +171,9 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: use_mxfp8 = "mxfp8" in quant_method is_checkpoint_fp8_serialized = ("fp8" in quant_method) or use_mxfp8 activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + packed_modules_mapping = cls.get_from_keys_or( + config, ["packed_modules_mapping"], {} + ) ignored_layers = cls.get_from_keys_or( config, ["ignored_layers", "modules_to_not_convert"], None ) @@ -182,13 +189,15 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: "MXFP8 ignoring incoming weight_block_size in config.json; it is fixed to [1, 32]." ) weight_block_size = [1, 32] - return cls( + quant_config = cls( is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, ignored_layers=ignored_layers, weight_block_size=weight_block_size, use_mxfp8=use_mxfp8, ) + quant_config.packed_modules_mapping = packed_modules_mapping or {} + return quant_config def get_quant_method( self, layer: torch.nn.Module, prefix: str @@ -198,10 +207,20 @@ def get_quant_method( from sglang.srt.layers.radix_attention import RadixAttention if isinstance(layer, LinearBase): - if is_layer_skipped(prefix, self.ignored_layers): + if is_layer_skipped( + prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping + ): return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): + # FusedMoE prefixes include ".experts", where layer-prefix rules + # (e.g., "model.layers.{i}.") should also force BF16 fallback. + if is_layer_skipped( + prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping + ) or any(ignored in prefix for ignored in self.ignored_layers): + return UnquantizedFusedMoEMethod( + layer.use_triton_kernels, layer.use_flashinfer_trtllm_moe + ) return Fp8MoEMethod(self) elif isinstance(layer, RadixAttention): return Fp8KVCacheMethod(self) From 8184c0f9dd376dfbad4a53390074d901c8fdd556 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 12 Feb 2026 12:53:19 -0800 Subject: [PATCH 2/7] Try fix --- python/sglang/srt/layers/quantization/fp8.py | 25 ++++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index fa7c8b854a6d..65e2a61e1ee0 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -178,11 +178,26 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: config, ["ignored_layers", "modules_to_not_convert"], None ) if ignored_layers: - if "mistral3" in config.get("model_type", ""): - # hack for ministral - ignored_layers = [ - layer.replace("model.", "") for layer in ignored_layers - ] + # Keep both "model." and non-"model." variants for robust prefix matching. + normalized: List[str] = [] + seen: set[str] = set() + for layer in ignored_layers: + candidates = [layer] + if layer.startswith("model."): + candidates.append(layer[len("model.") :]) + else: + candidates.append(f"model.{layer}") + for candidate in candidates: + if candidate not in seen: + normalized.append(candidate) + seen.add(candidate) + if len(normalized) != len(ignored_layers): + log_info_on_rank0( + logger, + "Fp8Config expanded ignored_layers to include both " + "'model.' and non-'model.' variants for matching.", + ) + ignored_layers = normalized weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) if use_mxfp8 and weight_block_size is not None: logger.warning( From fcca4a7df61e491e34b8f81a5ab2641997c7789e Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 12 Feb 2026 13:37:03 -0800 Subject: [PATCH 3/7] Simplify prefix processing fix --- python/sglang/srt/layers/quantization/fp8.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 65e2a61e1ee0..e53d01af2e48 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -179,24 +179,10 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: ) if ignored_layers: # Keep both "model." and non-"model." variants for robust prefix matching. - normalized: List[str] = [] - seen: set[str] = set() + normalized = [] for layer in ignored_layers: - candidates = [layer] - if layer.startswith("model."): - candidates.append(layer[len("model.") :]) - else: - candidates.append(f"model.{layer}") - for candidate in candidates: - if candidate not in seen: - normalized.append(candidate) - seen.add(candidate) - if len(normalized) != len(ignored_layers): - log_info_on_rank0( - logger, - "Fp8Config expanded ignored_layers to include both " - "'model.' and non-'model.' variants for matching.", - ) + base = layer.removeprefix("model.") + normalized.extend((base, f"model.{base}")) ignored_layers = normalized weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) if use_mxfp8 and weight_block_size is not None: From 3ea6e6ad0e43208f150bc83d8fcd2de924fde0eb Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 12 Feb 2026 14:27:10 -0800 Subject: [PATCH 4/7] Clean up and refactor --- python/sglang/srt/layers/quantization/fp8.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index e53d01af2e48..cf74416e2e40 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -120,6 +120,7 @@ def __init__( activation_scheme: str = "dynamic", ignored_layers: Optional[List[str]] = None, weight_block_size: List[int] = None, + packed_modules_mapping: Optional[Dict[str, List[str]]] = None, use_mxfp8: bool = False, ) -> None: super().__init__() @@ -130,6 +131,7 @@ def __init__( raise ValueError(f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] + self.packed_modules_mapping = packed_modules_mapping or {} self.use_mxfp8 = use_mxfp8 if weight_block_size is not None: if not is_checkpoint_fp8_serialized: @@ -171,8 +173,8 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: use_mxfp8 = "mxfp8" in quant_method is_checkpoint_fp8_serialized = ("fp8" in quant_method) or use_mxfp8 activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) - packed_modules_mapping = cls.get_from_keys_or( - config, ["packed_modules_mapping"], {} + packed_modules_mapping = ( + cls.get_from_keys_or(config, ["packed_modules_mapping"], {}) or {} ) ignored_layers = cls.get_from_keys_or( config, ["ignored_layers", "modules_to_not_convert"], None @@ -190,15 +192,14 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: "MXFP8 ignoring incoming weight_block_size in config.json; it is fixed to [1, 32]." ) weight_block_size = [1, 32] - quant_config = cls( + return cls( is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, ignored_layers=ignored_layers, weight_block_size=weight_block_size, + packed_modules_mapping=packed_modules_mapping, use_mxfp8=use_mxfp8, ) - quant_config.packed_modules_mapping = packed_modules_mapping or {} - return quant_config def get_quant_method( self, layer: torch.nn.Module, prefix: str @@ -214,11 +215,12 @@ def get_quant_method( return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): - # FusedMoE prefixes include ".experts", where layer-prefix rules - # (e.g., "model.layers.{i}.") should also force BF16 fallback. - if is_layer_skipped( + # FusedMoE prefixes include ".experts"; allow coarse layer-prefix + # rules (e.g., "model.layers.{i}.") to force BF16 fallback. + should_skip = is_layer_skipped( prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping - ) or any(ignored in prefix for ignored in self.ignored_layers): + ) or any(ignored in prefix for ignored in self.ignored_layers) + if should_skip: return UnquantizedFusedMoEMethod( layer.use_triton_kernels, layer.use_flashinfer_trtllm_moe ) From e1efddf48d6e1637df76b40958b8bd77bdd69e94 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 21 Feb 2026 21:04:22 -0800 Subject: [PATCH 5/7] Preserve coarse prefix matches when is_layer_skipped handles MoE experts --- python/sglang/srt/layers/quantization/fp8.py | 5 ++--- python/sglang/srt/layers/quantization/utils.py | 10 ++++------ 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index cf74416e2e40..6f211d2b23f7 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -217,10 +217,9 @@ def get_quant_method( elif isinstance(layer, FusedMoE): # FusedMoE prefixes include ".experts"; allow coarse layer-prefix # rules (e.g., "model.layers.{i}.") to force BF16 fallback. - should_skip = is_layer_skipped( + if is_layer_skipped( prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping - ) or any(ignored in prefix for ignored in self.ignored_layers) - if should_skip: + ): return UnquantizedFusedMoEMethod( layer.use_triton_kernels, layer.use_flashinfer_trtllm_moe ) diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index a3346d6ae836..98ecedf41ef9 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -84,12 +84,10 @@ def is_layer_skipped( if prefix_gate in ignored_layers and prefix_up in ignored_layers: is_skipped = True elif "experts" in prefix: - is_skipped = any( - [ - prefix in layer_name - for layer_name in ignored_layers - if "experts" in layer_name - ] + is_skipped = is_skipped or any( + prefix in layer_name + for layer_name in ignored_layers + if "experts" in layer_name ) assert is_skipped is not None From d9f2f47f691a185c85b6b3842ddc46f5aacc7c87 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 21 Feb 2026 21:09:08 -0800 Subject: [PATCH 6/7] Minor refactor --- python/sglang/srt/layers/quantization/fp8.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 6f211d2b23f7..ddbcac53c9d5 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -184,7 +184,8 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config: normalized = [] for layer in ignored_layers: base = layer.removeprefix("model.") - normalized.extend((base, f"model.{base}")) + normalized.append(base) + normalized.append(f"model.{base}") ignored_layers = normalized weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) if use_mxfp8 and weight_block_size is not None: From 604e66454a55f569355a4d8eee2a4ceda910db08 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Sat, 21 Feb 2026 21:13:04 -0800 Subject: [PATCH 7/7] Minor clean up --- python/sglang/srt/layers/quantization/fp8.py | 2 -- python/sglang/srt/layers/quantization/utils.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index ddbcac53c9d5..7ebc73377e1a 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -216,8 +216,6 @@ def get_quant_method( return UnquantizedLinearMethod() return Fp8LinearMethod(self) elif isinstance(layer, FusedMoE): - # FusedMoE prefixes include ".experts"; allow coarse layer-prefix - # rules (e.g., "model.layers.{i}.") to force BF16 fallback. if is_layer_skipped( prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping ): diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index 98ecedf41ef9..198b201de159 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -84,6 +84,8 @@ def is_layer_skipped( if prefix_gate in ignored_layers and prefix_up in ignored_layers: is_skipped = True elif "experts" in prefix: + # Expert names can include full module paths; keep coarse prefix matches + # (e.g., "model.layers.{i}.") while also checking expert-specific entries. is_skipped = is_skipped or any( prefix in layer_name for layer_name in ignored_layers