Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions tests/quantization/autoawq/test_awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ class AwqTest(unittest.TestCase):
)

EXPECTED_OUTPUT_BF16 = [
"Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Journalism and minoring in Spanish"
"Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Journalism and minoring in Spanish",
"Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Exercise and Sport Science with a",
]

EXPECTED_OUTPUT_EXLLAMA = [
Expand All @@ -127,11 +128,7 @@ def setUpClass(cls):
Setup quantized model
"""
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
# Use GEMM so that test_save_pretrained() writes out the quantized weights.
quantization_config = AwqConfig(backend=AwqBackend.GEMM)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name, device_map=cls.device_map, quantization_config=quantization_config
)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(cls.model_name, device_map=cls.device_map)

def tearDown(self):
gc.collect()
Expand All @@ -142,8 +139,7 @@ def test_quantized_model_conversion(self):
"""
Simple test that checks if the quantized model has been converted properly
"""
from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear
from gptqmodel.nn_modules.qlinear.gemv_awq import AwqGEMVQuantLinear
from gptqmodel.nn_modules.qlinear import BaseQuantLinear

from transformers.integrations.awq import replace_with_awq_linear

Expand All @@ -159,10 +155,10 @@ def test_quantized_model_conversion(self):
if isinstance(module, torch.nn.Linear):
nb_linears += 1

model, _ = replace_with_awq_linear(model, quantization_config=quantization_config)
model = replace_with_awq_linear(model, quantization_config=quantization_config)
nb_awq_linear = 0
for module in model.modules():
if isinstance(module, (AwqGEMMQuantLinear, AwqGEMVQuantLinear)):
if isinstance(module, BaseQuantLinear):
nb_awq_linear += 1

self.assertEqual(nb_linears, nb_awq_linear)
Expand All @@ -171,12 +167,12 @@ def test_quantized_model_conversion(self):
with torch.device("meta"):
model = OPTForCausalLM(config)

model, _ = replace_with_awq_linear(
model = replace_with_awq_linear(
model, quantization_config=quantization_config, modules_to_not_convert=["lm_head"]
)
nb_awq_linear = 0
for module in model.modules():
if isinstance(module, (AwqGEMMQuantLinear, AwqGEMVQuantLinear)):
if isinstance(module, BaseQuantLinear):
nb_awq_linear += 1

self.assertEqual(nb_linears - 1, nb_awq_linear)
Expand Down Expand Up @@ -238,8 +234,12 @@ def test_save_pretrained(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
# Load a fresh model for saving — the shared self.quantized_model may have
# already been in-place transformed by a prior generate() call, and saving
# those transformed buffers then re-transforming on reload would corrupt data.
fresh_model = AutoModelForCausalLM.from_pretrained(self.model_name)
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
fresh_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)

input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
Expand Down
Loading