Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,10 @@
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.unquant import (
UnquantizedFusedMoEMethod,
UnquantizedLinearMethod,
)
from sglang.srt.layers.quantization.utils import (
all_close_1d,
convert_to_channelwise,
Expand Down Expand Up @@ -117,15 +120,18 @@ def __init__(
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
weight_block_size: List[int] = None,
packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
use_mxfp8: bool = False,
) -> None:
super().__init__()
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
log_info_on_rank0(logger, "Detected fp8 checkpoint.")
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
self.ignored_layers = ignored_layers or []
self.packed_modules_mapping = packed_modules_mapping or {}
self.use_mxfp8 = use_mxfp8
if weight_block_size is not None:
if not is_checkpoint_fp8_serialized:
Expand Down Expand Up @@ -167,15 +173,20 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config:
use_mxfp8 = "mxfp8" in quant_method
is_checkpoint_fp8_serialized = ("fp8" in quant_method) or use_mxfp8
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
packed_modules_mapping = (
cls.get_from_keys_or(config, ["packed_modules_mapping"], {}) or {}
)
ignored_layers = cls.get_from_keys_or(
config, ["ignored_layers", "modules_to_not_convert"], None
)
if ignored_layers:
if "mistral3" in config.get("model_type", ""):
# hack for ministral
ignored_layers = [
layer.replace("model.", "") for layer in ignored_layers
]
# Keep both "model." and non-"model." variants for robust prefix matching.
normalized = []
for layer in ignored_layers:
base = layer.removeprefix("model.")
normalized.append(base)
normalized.append(f"model.{base}")
ignored_layers = normalized
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
if use_mxfp8 and weight_block_size is not None:
logger.warning(
Expand All @@ -187,6 +198,7 @@ def from_config(cls, config: Dict[str, Any]) -> Fp8Config:
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=weight_block_size,
packed_modules_mapping=packed_modules_mapping,
use_mxfp8=use_mxfp8,
)

Expand All @@ -198,10 +210,18 @@ def get_quant_method(
from sglang.srt.layers.radix_attention import RadixAttention

if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
if is_layer_skipped(
prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping
):
return UnquantizedLinearMethod()
return Fp8LinearMethod(self)
elif isinstance(layer, FusedMoE):
if is_layer_skipped(
prefix, self.ignored_layers, fused_mapping=self.packed_modules_mapping
):
return UnquantizedFusedMoEMethod(
layer.use_triton_kernels, layer.use_flashinfer_trtllm_moe
)
return Fp8MoEMethod(self)
elif isinstance(layer, RadixAttention):
return Fp8KVCacheMethod(self)
Expand Down
12 changes: 6 additions & 6 deletions python/sglang/srt/layers/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ def is_layer_skipped(
if prefix_gate in ignored_layers and prefix_up in ignored_layers:
is_skipped = True
elif "experts" in prefix:
is_skipped = any(
[
prefix in layer_name
for layer_name in ignored_layers
if "experts" in layer_name
]
# Expert names can include full module paths; keep coarse prefix matches
# (e.g., "model.layers.{i}.") while also checking expert-specific entries.
is_skipped = is_skipped or any(
prefix in layer_name
for layer_name in ignored_layers
if "experts" in layer_name
)

assert is_skipped is not None
Expand Down
Loading