diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py
index f18f477b3944..56d411fde60c 100644
--- a/tests/quantization/autoawq/test_awq.py
+++ b/tests/quantization/autoawq/test_awq.py
@@ -110,7 +110,8 @@ class AwqTest(unittest.TestCase):
     )

     EXPECTED_OUTPUT_BF16 = [
-        "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Journalism and minoring in Spanish"
+        "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Journalism and minoring in Spanish",
+        "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Exercise and Sport Science with a",
     ]

     EXPECTED_OUTPUT_EXLLAMA = [
@@ -127,11 +128,7 @@ def setUpClass(cls):
         Setup quantized model
         """
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
-        # Use GEMM so that test_save_pretrained() writes out the quantized weights.
-        quantization_config = AwqConfig(backend=AwqBackend.GEMM)
-        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
-            cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
-        )
+        cls.quantized_model = AutoModelForCausalLM.from_pretrained(cls.model_name, device_map=cls.device_map)

     def tearDown(self):
         gc.collect()
@@ -142,8 +139,7 @@ def test_quantized_model_conversion(self):
         """
         Simple test that checks if the quantized model has been converted properly
         """
-        from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear
-        from gptqmodel.nn_modules.qlinear.gemv_awq import AwqGEMVQuantLinear
+        from gptqmodel.nn_modules.qlinear import BaseQuantLinear

         from transformers.integrations.awq import replace_with_awq_linear

@@ -159,10 +155,10 @@ def test_quantized_model_conversion(self):
             if isinstance(module, torch.nn.Linear):
                 nb_linears += 1

-        model, _ = replace_with_awq_linear(model, quantization_config=quantization_config)
+        model = replace_with_awq_linear(model, quantization_config=quantization_config)
         nb_awq_linear = 0
         for module in model.modules():
-            if isinstance(module, (AwqGEMMQuantLinear, AwqGEMVQuantLinear)):
+            if isinstance(module, BaseQuantLinear):
                 nb_awq_linear += 1

         self.assertEqual(nb_linears, nb_awq_linear)
@@ -171,12 +167,12 @@ def test_quantized_model_conversion(self):
         with torch.device("meta"):
             model = OPTForCausalLM(config)

-        model, _ = replace_with_awq_linear(
+        model = replace_with_awq_linear(
             model, quantization_config=quantization_config, modules_to_not_convert=["lm_head"]
         )
         nb_awq_linear = 0
         for module in model.modules():
-            if isinstance(module, (AwqGEMMQuantLinear, AwqGEMVQuantLinear)):
+            if isinstance(module, BaseQuantLinear):
                 nb_awq_linear += 1

         self.assertEqual(nb_linears - 1, nb_awq_linear)
@@ -238,8 +234,12 @@ def test_save_pretrained(self):
         """
         Simple test that checks if the quantized model is working properly after being saved and loaded
         """
+        # Load a fresh model for saving: the shared self.quantized_model may have
+        # already been in-place transformed by a prior generate() call, and saving
+        # those transformed buffers then re-transforming on reload would corrupt data.
+        fresh_model = AutoModelForCausalLM.from_pretrained(self.model_name)
         with tempfile.TemporaryDirectory() as tmpdirname:
-            self.quantized_model.save_pretrained(tmpdirname)
+            fresh_model.save_pretrained(tmpdirname)

             model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)

             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
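
For context, a minimal standalone sketch of the save/reload pattern the new test code adopts: load a fresh, untouched copy of the model before save_pretrained(), because a prior generate() call on the shared instance may have transformed its quantized buffers in place. The checkpoint id, prompt, and device_map below are illustrative placeholders, not taken from this diff.

import tempfile

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "org/some-awq-checkpoint"  # hypothetical AWQ-quantized model id

tokenizer = AutoTokenizer.from_pretrained(model_name)
shared_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# generate() may repack the quantized buffers in place for the active kernel,
# so the shared instance is no longer safe to serialize afterwards.
inputs = tokenizer("Hello my name is", return_tensors="pt").to(shared_model.device)
shared_model.generate(**inputs, max_new_tokens=8)

# Load a fresh, untouched copy purely for saving, as test_save_pretrained() now does.
fresh_model = AutoModelForCausalLM.from_pretrained(model_name)

with tempfile.TemporaryDirectory() as tmpdirname:
    fresh_model.save_pretrained(tmpdirname)
    reloaded = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="auto")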