Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/transformers/utils/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from packaging import version

from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, logging
from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, is_torchao_available, is_gptqmodel_available, logging


if is_torch_available():
Expand Down Expand Up @@ -618,6 +618,7 @@ def __init__(
desc_act: bool = False,
sym: bool = True,
true_sequential: bool = True,
backend: Optional[str] = None,
checkpoint_format: Optional[str] = "gptq",
use_cuda_fp16: bool = False,
model_seqlen: Optional[int] = None,
Expand All @@ -641,6 +642,8 @@ def __init__(
self.desc_act = desc_act
self.sym = sym
self.true_sequential = true_sequential
self.backend = backend
self.checkpoint_format = checkpoint_format
self.use_cuda_fp16 = use_cuda_fp16
self.model_seqlen = model_seqlen
self.block_name_to_quantize = block_name_to_quantize
Expand All @@ -653,7 +656,6 @@ def __init__(
self.disable_exllama = kwargs.pop("disable_exllama", None)
self.cache_block_outputs = cache_block_outputs
self.modules_in_block_to_quantize = modules_in_block_to_quantize
self.checkpoint_format = checkpoint_format
self.post_init()

def get_loading_attributes(self):
Expand Down Expand Up @@ -690,6 +692,10 @@ def post_init(self):
['wikitext2','c4','c4-new'], but we found {self.dataset}"""
)

if self.backend is not None and not is_gptqmodel_available():
if self.backend == "auto_trainable":
self.use_exllama = False

if self.disable_exllama is None and self.use_exllama is None:
# New default behaviour
self.use_exllama = True
Expand Down Expand Up @@ -736,6 +742,12 @@ def post_init(self):
"Your current version of `optimum` does not support the `modules_in_block_to_quantize` quantization argument, please upgrade the `optimum` package to a version higher than 1.15.0."
)

if is_gptqmodel_available() and self.backend is None:
if self.exllama_config["version"] == ExllamaVersion.ONE and not self.use_exllama:
self.backend = "auto_trainable"
else:
self.backend = "auto"

def to_dict(self):
config_dict = super().to_dict()
config_dict.pop("disable_exllama", None)
Expand Down