diff --git a/docs/source/en/quantization/gptq.md b/docs/source/en/quantization/gptq.md index dbbc95e7c1c5..350f680456f3 100644 --- a/docs/source/en/quantization/gptq.md +++ b/docs/source/en/quantization/gptq.md @@ -119,7 +119,7 @@ Only 4-bit models are supported, and we recommend deactivating the ExLlama kerne -The ExLlama kernels are only supported when the entire model is on the GPU. If you're doing inference on a CPU with AutoGPTQ (version > 0.4.2) or GPTQModel, then you'll need to disable the ExLlama kernel. This overwrites the attributes related to the ExLlama kernels in the quantization config of the config.json file. +The ExLlama kernels are only supported when the entire model is on the GPU. If you're doing inference on a CPU with AutoGPTQ (version > 0.4.2) or GPTQModel (version >= 1.4.2), then you'll need to disable the ExLlama kernel. This overwrites the attributes related to the ExLlama kernels in the quantization config of the config.json file. ```py import torch diff --git a/src/transformers/quantizers/quantizer_gptq.py b/src/transformers/quantizers/quantizer_gptq.py index a6ae314da071..8051461c7372 100644 --- a/src/transformers/quantizers/quantizer_gptq.py +++ b/src/transformers/quantizers/quantizer_gptq.py @@ -70,13 +70,13 @@ def validate_environment(self, *args, **kwargs): "0.4.2" ): raise ImportError( - "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq` or use gptqmodel by `pip install gptqmodel`. Please notice that auto-gptq will be deprecated in the future." + "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq` or use gptqmodel by `pip install gptqmodel>=1.4.2`. Please notice that auto-gptq will be deprecated in the future." 
) elif is_gptqmodel_available() and ( - version.parse(importlib.metadata.version("gptqmodel")) <= version.parse("1.3.1") + version.parse(importlib.metadata.version("gptqmodel")) < version.parse("1.4.2") or version.parse(importlib.metadata.version("optimum")) < version.parse("1.23.99") ): - raise ImportError("The gptqmodel version should be >= 1.3.2, optimum version should >= 1.24.0") + raise ImportError("The gptqmodel version should be >= 1.4.2, optimum version should be >= 1.24.0") def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": if torch_dtype is None: