diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index b52c0575686..af917a26df3 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -37,6 +37,14 @@ class scalar_types: logger = logging.getLogger(__name__) +def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool: + # compat: gptqmodel and autogptq (eol) main use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + return hf_quant_cfg.get("checkpoint_format") == "marlin" or hf_quant_cfg.get( + "is_marlin_format", False + ) + + class GPTQConfig(QuantizationConfig): """Config class for GPTQ. @@ -262,13 +270,15 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: + is_marlin_format = check_marlin_format(hf_quant_cfg) + can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) is_valid_user_quant = ( user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin" ) - if can_convert and is_valid_user_quant: + if not is_marlin_format and can_convert and is_valid_user_quant: msg = ( "The model is convertible to {} during runtime." " Using {} kernel.".format(cls.get_name(), cls.get_name()) @@ -276,7 +286,7 @@ def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str] logger.info(msg) return cls.get_name() - if can_convert and user_quant == "gptq": + if not is_marlin_format and can_convert and user_quant == "gptq": logger.info( "Detected that the model can run with gptq_marlin" ", however you specified quantization=gptq explicitly," @@ -401,11 +411,7 @@ def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: - # compat: autogptq >=0.8.0 use checkpoint_format: str - # compat: autogptq <=0.7.1 is_marlin_format: bool - is_marlin_format = hf_quant_cfg.get( - "checkpoint_format" - ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False) + is_marlin_format = check_marlin_format(hf_quant_cfg) is_valid_user_quant = ( user_quant is None or user_quant == "gptq" or user_quant == "marlin"