Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions python/sglang/srt/layers/quantization/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ class scalar_types:
logger = logging.getLogger(__name__)


def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
    """Return whether a quantization config marks its checkpoint as Marlin.

    Two config layouts are recognised:

    * ``checkpoint_format`` (str) equal to ``"marlin"`` -- used by gptqmodel
      and autogptq (eol) main.
    * ``is_marlin_format`` (bool) -- used by autogptq <= 0.7.1.

    Args:
        hf_quant_cfg: The quantization config mapping to inspect.

    Returns:
        ``True`` if either key marks the checkpoint as Marlin, ``False``
        otherwise. The result is always a real ``bool``.
    """
    # compat: gptqmodel and autogptq (eol) main use checkpoint_format: str
    if hf_quant_cfg.get("checkpoint_format") == "marlin":
        return True

    # compat: autogptq <=0.7.1 is_marlin_format: bool
    # Coerce explicitly: a config may carry null (None) or another non-bool
    # value here, and returning it as-is would violate the `-> bool` contract.
    return bool(hf_quant_cfg.get("is_marlin_format", False))


class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.

Expand Down Expand Up @@ -262,21 +270,23 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":

@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
is_marlin_format = check_marlin_format(hf_quant_cfg)

can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
)

if can_convert and is_valid_user_quant:
if not is_marlin_format and can_convert and is_valid_user_quant:
msg = (
"The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name())
)
logger.info(msg)
return cls.get_name()

if can_convert and user_quant == "gptq":
if not is_marlin_format and can_convert and user_quant == "gptq":
logger.info(
"Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly,"
Expand Down Expand Up @@ -401,11 +411,7 @@ def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":

@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = hf_quant_cfg.get(
"checkpoint_format"
) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
is_marlin_format = check_marlin_format(hf_quant_cfg)

is_valid_user_quant = (
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
Expand Down
Loading