diff --git a/tests/quantization/gptq/test_gptq.py b/tests/quantization/gptq/test_gptq.py index 20cfd12d102f..eef0a7d80609 100644 --- a/tests/quantization/gptq/test_gptq.py +++ b/tests/quantization/gptq/test_gptq.py @@ -265,7 +265,16 @@ def test_serialization_big_model_inference(self): """ with tempfile.TemporaryDirectory() as tmpdirname: self.quantized_model.save_pretrained(tmpdirname) - quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map) + + quantization_config = GPTQConfig( + bits=self.bits, + group_size=self.group_size, + desc_act=self.desc_act, + sym=self.sym, + use_exllama=self.use_exllama, + ) + quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map, + quantization_config=quantization_config) self.check_inference_correctness(quantized_model_from_saved)