diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 68cd41895d2c..699e6fbe6ae1 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2546,7 +2546,7 @@ def from_pretrained( logger.warning( "You passed `quantization_config` to `from_pretrained` but the model you're loading already has a " "`quantization_config` attribute and has already quantized weights. However, loading attributes" - " (e.g. disable_exllama, use_cuda_fp16) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored." + " (e.g. disable_exllama, use_cuda_fp16, max_input_length) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored." ) if ( quantization_method_from_args == QuantizationMethod.GPTQ @@ -2556,7 +2556,11 @@ def from_pretrained( raise RuntimeError("GPU is required to quantize or run quantize model.") elif not (is_optimum_available() and is_auto_gptq_available()): raise ImportError( - "Loading GPTQ quantized model requires optimum library : `pip install optimum` and auto-gptq library 'pip install auto-gptq'" + "Loading a GPTQ quantized model requires optimum (`pip install optimum`) and auto-gptq library (`pip install auto-gptq`)" + ) + elif version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"): + raise ImportError( + "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`" ) else: # Need to protect the import diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index e302d621baa1..9b698947653d 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -346,6 +346,9 @@ class GPTQConfig(QuantizationConfigMixin): The pad token id. Needed to prepare the dataset when `batch_size` > 1. disable_exllama (`bool`, *optional*, defaults to `False`): Whether to use exllama backend. Only works with `bits` = 4. + max_input_length (`int`, *optional*) + The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input + length. It is specific to the exllama backend with act-order. """ def __init__( @@ -365,6 +368,7 @@ def __init__( batch_size: int = 1, pad_token_id: Optional[int] = None, disable_exllama: bool = False, + max_input_length: Optional[int] = None, **kwargs, ): self.quant_method = QuantizationMethod.GPTQ @@ -383,11 +387,12 @@ def __init__( self.batch_size = batch_size self.pad_token_id = pad_token_id self.disable_exllama = disable_exllama + self.max_input_length = max_input_length self.post_init() def get_loading_attributes(self): attibutes_dict = copy.deepcopy(self.__dict__) - loading_attibutes = ["disable_exllama", "use_cuda_fp16"] + loading_attibutes = ["disable_exllama", "use_cuda_fp16", "max_input_length"] loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} return loading_attibutes_dict diff --git a/tests/quantization/gptq/test_gptq.py b/tests/quantization/gptq/test_gptq.py index c7530471fa27..d715bd56c02e 100644 --- a/tests/quantization/gptq/test_gptq.py +++ b/tests/quantization/gptq/test_gptq.py @@ -86,6 +86,8 @@ class GPTQTest(unittest.TestCase): EXPECTED_OUTPUTS = set() EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I") + EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I") + EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of") EXPECTED_OUTPUTS.add("Hello my name is John and I am a very good looking man.") EXPECTED_OUTPUTS.add("Hello my name is Alyson, I am a student in the") EXPECTED_OUTPUTS.add("Hello my name is Alyson and I am a very sweet,") @@ -236,6 +238,82 @@ class GPTQTestDeviceMapExllama(GPTQTest): disable_exllama = False +@slow +@require_optimum +@require_auto_gptq +@require_torch_gpu +@require_accelerate +class GPTQTestActOrderExllama(unittest.TestCase): + """ + Test GPTQ model with exllama kernel and desc_act=True (also known as act-order). + More information on those arguments here: + https://huggingface.co/docs/transformers/main_classes/quantization#transformers.GPTQConfig + """ + + EXPECTED_OUTPUTS = set() + EXPECTED_OUTPUTS.add("Hello my name is Katie and I am a 20 year") + model_name = "hf-internal-testing/Llama-2-7B-GPTQ" + revision = "gptq-4bit-128g-actorder_True" + input_text = "Hello my name is" + + @classmethod + def setUpClass(cls): + """ + Setup quantized model + """ + + cls.quantization_config = GPTQConfig(bits=4, disable_exllama=False, max_input_length=4028) + cls.quantized_model = AutoModelForCausalLM.from_pretrained( + cls.model_name, + revision=cls.revision, + torch_dtype=torch.float16, + device_map={"": 0}, + quantization_config=cls.quantization_config, + ) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True) + + def check_inference_correctness(self, model): + """ + Test the generation quality of the quantized model and see that we are matching the expected output. + Given that we are operating on small numbers + the testing model is relatively small, we might not get + the same output across GPUs. So we'll generate few tokens (5-10) and check their output. + """ + + # Check that inference pass works on the model + encoded_input = self.tokenizer(self.input_text, return_tensors="pt") + + # Check the exactness of the results + output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10) + + # Get the generation + self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS) + + def test_generate_quality(self): + """ + Simple test to check the quality of the model by comapring the the generated tokens with the expected tokens + """ + self.check_inference_correctness(self.quantized_model) + + # this test will fail until the next release of optimum + @pytest.mark.skip + def test_max_input_length(self): + """ + Test if the max_input_length works. It modifies the maximum input length that of the model that runs with exllama backend. + """ + + prompt = "I am in Paris and" * 1000 + inp = self.tokenizer(prompt, return_tensors="pt").to(0) + self.assertTrue(inp["input_ids"].shape[1] > 4028) + with self.assertRaises(RuntimeError) as cm: + self.quantized_model.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3) + self.assertTrue("temp_state buffer is too small" in str(cm.exception)) + + prompt = "I am in Paris and" * 500 + inp = self.tokenizer(prompt, return_tensors="pt").to(0) + self.assertTrue(inp["input_ids"].shape[1] < 4028) + self.quantized_model.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3) + + # fail when run all together @pytest.mark.skip @require_accelerate