diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 9f0d55ff1a24..4a526ee6f717 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -25,7 +25,7 @@ from packaging import version -from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, logging +from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, is_gptqmodel_available, logging if is_torch_available(): @@ -618,6 +618,7 @@ def __init__( desc_act: bool = False, sym: bool = True, true_sequential: bool = True, + backend: Optional[str] = None, checkpoint_format: Optional[str] = "gptq", use_cuda_fp16: bool = False, model_seqlen: Optional[int] = None, @@ -641,6 +642,8 @@ def __init__( self.desc_act = desc_act self.sym = sym self.true_sequential = true_sequential + self.backend = backend + self.checkpoint_format = checkpoint_format self.use_cuda_fp16 = use_cuda_fp16 self.model_seqlen = model_seqlen self.block_name_to_quantize = block_name_to_quantize @@ -653,7 +656,6 @@ def __init__( self.disable_exllama = kwargs.pop("disable_exllama", None) self.cache_block_outputs = cache_block_outputs self.modules_in_block_to_quantize = modules_in_block_to_quantize - self.checkpoint_format = checkpoint_format self.post_init() def get_loading_attributes(self): @@ -690,6 +692,10 @@ def post_init(self): ['wikitext2','c4','c4-new'], but we found {self.dataset}""" ) + if self.backend is not None and not is_gptqmodel_available(): + if self.backend == "auto_trainable": + self.use_exllama = False + if self.disable_exllama is None and self.use_exllama is None: # New default behaviour self.use_exllama = True @@ -736,6 +742,12 @@ def post_init(self): "You current version of `optimum` does not support `modules_in_block_to_quantize` quantization argument, please upgrade `optimum` package to a version superior than 1.15.0 ." ) + if is_gptqmodel_available() and self.backend is None: + if self.exllama_config["version"] == ExllamaVersion.ONE and not self.use_exllama: + self.backend = "auto_trainable" + else: + self.backend = "auto" + def to_dict(self): config_dict = super().to_dict() config_dict.pop("disable_exllama", None)