add int4 gptq and eval #116

Merged (6 commits, Apr 3, 2024)
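
This PR adds an int4 weight-only GPTQ quantizer (Int4WeightOnlyGPTQQuantizer) and a TransformerEvalWrapper for running lm-eval tasks against gpt-fast style models, and wires both into the quantization tests. Below is a minimal sketch, not part of the diff itself, condensing the flow that the new test_gptq_quantizer_int4wo exercises; the checkpoint path, calibration settings, and result key are taken from that test, and the local "model" import path is an assumption about the test directory layout.

import torch
from pathlib import Path
from sentencepiece import SentencePieceProcessor
from torchao.quantization.GPTQ import (
    Int4WeightOnlyGPTQQuantizer,
    InputRecorder,
    TransformerEvalWrapper,
)
from model import Transformer, prepare_inputs_for_model  # assumed: test/quantization/model.py

# load a gpt-fast style checkpoint (path mirrors the test)
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=torch.bfloat16, device="cpu").eval()
tokenizer = SentencePieceProcessor(model_file=str(checkpoint_path.parent / "tokenizer.model"))

# record calibration inputs from wikitext, then run int4 weight-only GPTQ
calibration_seq_length = 100
inputs = InputRecorder(
    tokenizer,
    calibration_seq_length,
    prepare_inputs_for_model,
    False,                      # pad_calibration_inputs
    model.config.vocab_size,
    device="cpu",
).record_inputs(["wikitext"], 1).get_inputs()
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
quantizer = Int4WeightOnlyGPTQQuantizer(128, 0.01, 128)  # blocksize, percdamp, groupsize
model = quantizer.quantize(model, inputs).cuda()

# evaluate the quantized model with lm-eval through the new wrapper
result = TransformerEvalWrapper(
    model,
    tokenizer,
    model.config.block_size,
    prepare_inputs_for_model,
    "cuda",
).run_eval(["wikitext"], 1)
print(result["results"]["wikitext"]["word_perplexity,none"])
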
13 changes: 4 additions & 9 deletions test/quantization/model.py
@@ -10,15 +10,15 @@
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torchao.quantization.utils import find_multiple

def prepare_inputs_for_model(inps):
def prepare_inputs_for_model(inps, max_new_tokens=1):
# this is because input from lm-eval is 2d
if input.dim() != 2:
raise ValueError(f"Expected input to be of dim 2, but got {input.dim()}")
if inps.dim() != 2:
raise ValueError(f"Expected input to be of dim 2, but got {inps.dim()}")

inps = inps.squeeze(0)
# setup inputs in correct format
max_new_tokens = 1
T = inps.size(0)
T_new = T + max_new_tokens
seq = torch.empty(T_new, dtype=inps.dtype, device=inps.device)
@@ -27,11 +27,6 @@ def prepare_inputs_for_model(inps):
x = seq.index_select(0, input_pos).view(1, -1)
return (x, input_pos)

def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)

@dataclass
class ModelArgs:
block_size: int = 2048
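
For reference, the helper removed above is a small rounding utility; the PR drops this local copy in favor of the find_multiple imported from torchao.quantization.utils. A minimal standalone sketch of its behavior (the example values are illustrative):

def find_multiple(n: int, k: int) -> int:
    # round n up to the nearest multiple of k
    if n % k == 0:
        return n
    return n + k - (n % k)

assert find_multiple(10, 8) == 16  # 10 rounds up to the next multiple of 8
assert find_multiple(16, 8) == 16  # already a multiple, returned unchanged
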
149 changes: 144 additions & 5 deletions test/quantization/test_quant_api.py
@@ -151,8 +151,8 @@ def test_8da4w_quantizer(self):
m(*example_inputs)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq_quantizer(self):
from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder
def test_8da4w_gptq_quantizer(self):
from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder, TransformerEvalWrapper
# should be similar to TorchCompileDynamicQuantizer
precision = torch.bfloat16
device = "cpu"
@@ -161,6 +161,7 @@
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device=device)
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
@@ -190,12 +191,60 @@ def test_gptq_quantizer(self):
blocksize,
percdamp,
groupsize,
precision=precision,
)
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
model = quantizer.quantize(model, inputs)
compiled = torch.compile(model, mode="max-autotune")
with torch.no_grad():
compiled(inputs[0].values[0], inputs[1].values[0])
result=TransformerEvalWrapper(
model,
tokenizer,
model.config.block_size,
prepare_inputs_for_model,
device,
).run_eval(
["wikitext"],
1,
)

assert result['results']['wikitext']['word_perplexity,none'] < 7.88, (
f"accuracy regressed from 7.87 to {result['results']['wikitext']['word_perplexity,none']}"
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_8da4w_quantizer_eval(self):
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
from torchao.quantization.GPTQ import TransformerEvalWrapper

precision = torch.bfloat16
device = "cpu"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device=device)
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
)

quantizer = Int8DynActInt4WeightQuantizer(groupsize=128, precision=precision)
q_model = quantizer.quantize(model)
result=TransformerEvalWrapper(
q_model,
tokenizer,
q_model.config.block_size,
prepare_inputs_for_model,
device,
).run_eval(
["wikitext"],
1,
)
assert result['results']['wikitext']['word_perplexity,none'] < 8.24, (
f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq_quantizer_gpt_fast(self):
@@ -248,5 +297,95 @@ def test_gptq_quantizer_gpt_fast(self):
with torch.no_grad():
compiled(inputs[0].values[0], inputs[1].values[0])

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq_quantizer_int4wo(self):
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
# should be similar to TorchCompileDynamicQuantizer
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device="cpu")
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
)
blocksize = 128
percdamp = 0.01
groupsize = 128
calibration_tasks = ["wikitext"]
calibration_limit = 1
calibration_seq_length = 100
input_prep_func = prepare_inputs_for_model
pad_calibration_inputs = False

inputs = InputRecorder(
tokenizer,
calibration_seq_length,
input_prep_func,
pad_calibration_inputs,
model.config.vocab_size,
device="cpu",
).record_inputs(
calibration_tasks,
calibration_limit,
).get_inputs()

quantizer = Int4WeightOnlyGPTQQuantizer(
blocksize,
percdamp,
groupsize,
)
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)

model = quantizer.quantize(model, inputs).cuda()
result = TransformerEvalWrapper(
model.cuda(),
tokenizer,
model.config.block_size,
prepare_inputs_for_model,
device,
).run_eval(
["wikitext"],
1,
)
assert result['results']['wikitext']['word_perplexity,none'] < 7.77, (
f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_eval_wrapper(self):
from torchao.quantization.GPTQ import TransformerEvalWrapper
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device=device)
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
)
result=TransformerEvalWrapper(
model,
tokenizer,
model.config.block_size,
prepare_inputs_for_model,
device,
).run_eval(
["wikitext"],
1,
)
assert result['results']['wikitext']['word_perplexity,none']<7.77, (
f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
)

if __name__ == "__main__":
unittest.main()