diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 24882b8418..70c2562bb3 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -186,9 +186,14 @@ def test_8da4w_quantizer(self):
         assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
         m(*example_inputs)
 
+    # TODO: save model weights as artifacts and re-enable in CI
+    # For now, to run this test, you will need to download the weights from HF
+    # and run this script to convert them:
+    # https://github.com/pytorch-labs/gpt-fast/blob/6253c6bb054e658d67566150f87329b87815ae63/scripts/convert_hf_checkpoint.py
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_8da4w_gptq_quantizer(self):
-        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer, InputRecorder, TransformerEvalWrapper
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer
+        from torchao._eval import InputRecorder, TransformerEvalWrapper
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
@@ -250,7 +255,7 @@ def test_8da4w_gptq_quantizer(self):
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
     def test_8da4w_quantizer_eval(self):
         from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
-        from torchao.quantization.GPTQ import TransformerEvalWrapper
+        from torchao._eval import TransformerEvalWrapper
 
         precision = torch.bfloat16
         device = "cpu"
@@ -284,7 +289,8 @@ def test_8da4w_quantizer_eval(self):
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_gptq_quantizer_int4wo(self):
-        from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
+        from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
+        from torchao._eval import InputRecorder, TransformerEvalWrapper
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -343,7 +349,8 @@ def test_gptq_quantizer_int4wo(self):
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_quantizer_int4wo(self):
-        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer, TransformerEvalWrapper
+        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+        from torchao._eval import TransformerEvalWrapper
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -378,7 +385,7 @@ def test_quantizer_int4wo(self):
 
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_eval_wrapper(self):
-        from torchao.quantization.GPTQ import TransformerEvalWrapper
+        from torchao._eval import TransformerEvalWrapper
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
diff --git a/torchao/_eval.py b/torchao/_eval.py
new file mode 100644
index 0000000000..c7e6ce8381
--- /dev/null
+++ b/torchao/_eval.py
@@ -0,0 +1,230 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F  # used by InputRecorder._model_call for padding
+
+from .utils import _lm_eval_available, _MultiInput
+
+if _lm_eval_available:
+    import lm_eval  # used below for lm_eval.tasks.initialize_tasks()
+    try:  # lm_eval version 0.4
+        from lm_eval.evaluator import evaluate  # pyre-ignore[21]
+        from lm_eval.models.huggingface import HFLM as eval_wrapper  # pyre-ignore[21]
+        from lm_eval.tasks import get_task_dict  # pyre-ignore[21]
+    except:  # lm_eval version 0.3
+        from lm_eval import base, evaluator, tasks
+
+        eval_wrapper = base.BaseLM
+        get_task_dict = tasks.get_task_dict
+        evaluate = evaluator.evaluate
+
+    class InputRecorder(eval_wrapper):
+        """
+        This is a fake evaluation wrapper from the lm_eval library that just records the inputs
+        so that they can be used in calibration.
+
+        If pad_calibration_inputs is enabled, the input recorder will take
+        each input and pad/truncate it down to the calibration_seq_length.
+        (if using padding you should set the embeddings for the pad_token to 0
+        in the model)
+
+        Note: after padding/truncation, input_prep_function is called to bring
+        it to the proper form to be inserted into a given model.
+
+        If not, it will only truncate inputs to the desired length.
+        """
+
+        def __init__(
+            self,
+            tokenizer,
+            calibration_seq_length,
+            input_prep_func=None,
+            pad_calibration_inputs=False,
+            vocab_size=32000,
+            pad_token=0,
+            device="cpu",
+        ):
+            super().__init__()
+            self._tokenizer = tokenizer
+            self._device = torch.device(device)
+            self.vocab_size = vocab_size
+            self._max_seq_length = calibration_seq_length
+            self.calibration_seq_length = calibration_seq_length
+
+            # need to take inps and convert to correct input
+            # for model
+            self.input_prep_func = (
+                input_prep_func if input_prep_func is not None
+                else lambda x: (x,)
+            )
+
+            self.pad_calibration_inputs = pad_calibration_inputs
+            self.pad_token = pad_token
+
+            self.inputs = None
+
+        @property
+        def eot_token_id(self):
+            try:
+                return self._tokenizer.eos_id()
+            except:
+                return self._tokenizer.eos_id
+
+        @property
+        def max_length(self):
+            return self._max_seq_length
+
+        @property
+        def max_gen_toks(self):
+            return 50
+
+        @property
+        def batch_size(self):
+            return 1
+
+        @property
+        def device(self):
+            return self._device
+
+        def tok_encode(self, string: str, **kwargs):
+            # TODO: verify this for multi-batch as well
+            tokens = self._tokenizer.encode(string)
+            if hasattr(self._tokenizer, "bos_id"):
+                try:
+                    tokens = [self._tokenizer.bos_id()] + tokens
+                except:
+                    tokens = [self._tokenizer.bos_id] + tokens
+            return tokens
+
+        def tok_decode(self, tokens):
+            decoded = self._tokenizer.decode(tokens)
+            return decoded
+
+        def add_input(self, args):
+            if self.inputs is None:
+                self.inputs = [_MultiInput([arg]) for arg in args]
+            else:
+                self.inputs = [
+                    multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
+                ]
+
+        def record_inputs(
+            self,
+            calibration_tasks,
+            calibration_limit,
+        ):
+            try:
+                lm_eval.tasks.initialize_tasks()
+            except:
+                pass
+
+            task_dict = get_task_dict(calibration_tasks)
+            print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
+
+            evaluate(
+                self,
+                task_dict,
+                limit=calibration_limit,
+            )
+            return self
+
+        def get_inputs(self):
+            return self.inputs
+
+        def _model_call(self, inps):
+            inps = inps.squeeze(0)
+            T = len(inps)
+            if (
+                # can't use inputs that are too short when padding disabled
+                (T < self.calibration_seq_length and not self.pad_calibration_inputs)
+                or
+                # can't use inputs that actually use token we use for padding
+                (self.pad_calibration_inputs and self.pad_token in inps)
+            ):
+                # give random output
+                return torch.randn(
+                    (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
+                )
+
+            # pad or truncate to the right size
+            if T >= self.calibration_seq_length:
+                inps = inps[: self.calibration_seq_length]
+            else:
+                inps = F.pad(inps, (self.pad_token, self.calibration_seq_length - T))
+
+            inps = inps.unsqueeze(0)
+            model_in = self.input_prep_func(inps)
+
+            self.add_input(model_in)
+
+            # output `something` with correct shape to keep eval going
+            return torch.randn(
+                (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
+            )
+
+        def _model_generate(self, context, max_length, eos_token_id):
+            raise Exception("unimplemented")
+
+    class TransformerEvalWrapper(InputRecorder):
+        """
+        A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
+        """
+        def __init__(
+            self,
+            model,
+            tokenizer,
+            max_seq_length,
+            input_prep_func=None,
+            device="cuda"
+        ):
+            super().__init__(None, None)
+            self._model = model
+            self._tokenizer = tokenizer
+            self._device = torch.device(device)
+            self._max_seq_length = max_seq_length
+
+            # need to take inps and convert to correct input
+            # for model
+            self.input_prep_func = (
+                input_prep_func if input_prep_func is not None
+                else lambda x: (x,)
+            )
+
+        def _model_call(self, inps):
+            # TODO: make batches work
+            input = self.input_prep_func(inps)
+
+            max_seq_length = min(inps.size(1), self.max_length)
+            with torch.device(self._device):
+                self._model.setup_caches(self.batch_size, max_seq_length)
+            logits = self._model(*input)
+            return logits
+
+        def _model_generate(self, context, max_length, eos_token_id):
+            raise Exception('unimplemented')
+
+        def run_eval(self, tasks, limit):
+            try:
+                lm_eval.tasks.initialize_tasks()
+            except:
+                pass
+
+            task_dict = get_task_dict(tasks)
+            print("Evaluating Model On: ", task_dict)
+            with torch.no_grad():
+                result = evaluate(
+                    self,
+                    task_dict,
+                    limit=limit,
+                )
+            for task, res in result["results"].items():
+                print(f"{task}: {res}")
+            return result
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index e7176b4fd2..f0c16f86d9 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -19,7 +19,12 @@
 
 from torch.utils._pytree import tree_flatten, tree_unflatten
 
-from .utils import TORCH_VERSION_AFTER_2_3, find_multiple
+from .utils import (
+    _lm_eval_available,
+    _MultiInput,
+    TORCH_VERSION_AFTER_2_3,
+    find_multiple,
+)
 
 from typing import Any, Dict, Optional
 from .unified import Quantizer
@@ -32,266 +37,20 @@
 )
 aten = torch.ops.aten
 
-## eval.py ##
-
-try:
-    import lm_eval  # pyre-ignore[21]  # noqa: F401
-
-    lm_eval_available = True
-except:
-    lm_eval_available = False
-
-if lm_eval_available:
-    try:  # lm_eval version 0.4
-        from lm_eval.evaluator import evaluate  # pyre-ignore[21]
-        from lm_eval.models.huggingface import HFLM as eval_wrapper  # pyre-ignore[21]
-        from lm_eval.tasks import get_task_dict  # pyre-ignore[21]
-    except:  # lm_eval version 0.3
-        from lm_eval import base, evaluator, tasks
-
-        eval_wrapper = base.BaseLM
-        get_task_dict = tasks.get_task_dict
-        evaluate = evaluator.evaluate
-else:
+if not _lm_eval_available:
     logging.info("lm_eval is not installed, GPTQ may not be usable")
 
 add_ons = []
 
-if lm_eval_available:
-    add_ons += ["InputRecorder", "TransformerEvalWrapper"]
-
 if TORCH_VERSION_AFTER_2_3:
     add_ons += ["Int8DynActInt4WeightQuantizer", "Int8DynActInt4WeightGPTQQuantizer"]
 
 
 __all__ = [
-    "MultiInput",
     "Int4WeightOnlyGPTQQuantizer",
     "Int4WeightOnlyQuantizer",
 ] + add_ons
 
-if lm_eval_available:
-    class InputRecorder(eval_wrapper):
-        """
-        This is a fake evaluation wrapper from the lm_eval library that just records the inputs
-        so that they can be used in calibration.
-
-        If pad_calibration_inputs is enabled, the input recorder will take
-        each input and pad/truncate it down to the calibration_seq_length.
-        (if using padding you should set the embeddings for the pad_token to 0
-        in the model)
-
-        Note: after padding/truncation, input_prep_function is called to bring
-        it to the proper form to be inserted into a given model.
-
-        If not, it will only truncate inputs to the desired length.
-        """
-
-        def __init__(
-            self,
-            tokenizer,
-            calibration_seq_length,
-            input_prep_func=None,
-            pad_calibration_inputs=False,
-            vocab_size=32000,
-            pad_token=0,
-            device="cpu",
-        ):
-            super().__init__()
-            self._tokenizer = tokenizer
-            self._device = torch.device(device)
-            self.vocab_size = vocab_size
-            self._max_seq_length = calibration_seq_length
-            self.calibration_seq_length = calibration_seq_length
-
-            # need to take inps and convert to corrent input
-            # for model
-            self.input_prep_func = (
-                input_prep_func if input_prep_func is not None
-                else lambda x: (x,)
-            )
-
-            self.pad_calibration_inputs = pad_calibration_inputs
-            self.pad_token = pad_token
-
-            self.inputs = None
-
-        @property
-        def eot_token_id(self):
-            try:
-                return self._tokenizer.eos_id()
-            except:
-                return self._tokenizer.eos_id
-
-        @property
-        def max_length(self):
-            return self._max_seq_length
-
-        @property
-        def max_gen_toks(self):
-            return 50
-
-        @property
-        def batch_size(self):
-            return 1
-
-        @property
-        def device(self):
-            return self._device
-
-        def tok_encode(self, string: str, **kwargs):
-            # TODO: verify this for multi-batch as well
-            tokens = self._tokenizer.encode(string)
-            if hasattr(self._tokenizer, "bos_id"):
-                try:
-                    tokens = [self._tokenizer.bos_id()] + tokens
-                except:
-                    tokens = [self._tokenizer.bos_id] + tokens
-            return tokens
-
-        def tok_decode(self, tokens):
-            decoded = self._tokenizer.decode(tokens)
-            return decoded
-
-        def add_input(self, args):
-            if self.inputs is None:
-                self.inputs = [MultiInput([arg]) for arg in args]
-            else:
-                self.inputs = [
-                    multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
-                ]
-
-        def record_inputs(
-            self,
-            calibration_tasks,
-            calibration_limit,
-        ):
-            try:
-                lm_eval.tasks.initialize_tasks()
-            except:
-                pass
-
-            task_dict = get_task_dict(calibration_tasks)
-            print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
-
-            evaluate(
-                self,
-                task_dict,
-                limit=calibration_limit,
-            )
-            return self
-
-        def get_inputs(self):
-            return self.inputs
-
-        def _model_call(self, inps):
-            inps = inps.squeeze(0)
-            T = len(inps)
-            if (
-                # can't use inputs that are too short when padding disabled
-                (T < self.calibration_seq_length and not self.pad_calibration_inputs)
-                or
-                # can't use inputs that actually use token we use for padding
-                (self.pad_calibration_inputs and self.pad_token in inps)
-            ):
-                # give random output
-                return torch.randn(
-                    (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
-                )
-
-            # pad or truncate to the right size
-            if T >= self.calibration_seq_length:
-                inps = inps[: self.calibration_seq_length]
-            else:
-                inps = F.pad(inps, (self.pad_token, self.calibration_seq_length - T))
-
-            inps = inps.unsqueeze(0)
-            model_in = self.input_prep_func(inps)
-
-            self.add_input(model_in)
-
-            # output `something` with correct shape to keep eval going
-            return torch.randn(
-                (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
-            )
-
-        def _model_generate(self, context, max_length, eos_token_id):
-            raise Exception("unimplemented")
-
-    class TransformerEvalWrapper(InputRecorder):
-        """
-        A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library.
-        """
-        def __init__(
-            self,
-            model,
-            tokenizer,
-            max_seq_length,
-            input_prep_func=None,
-            device="cuda"
-        ):
-            super().__init__(None, None)
-            self._model = model
-            self._tokenizer = tokenizer
-            self._device = torch.device(device)
-            self._max_seq_length = max_seq_length
-
-            # need to take inps and convert to corrent input
-            # for model
-            self.input_prep_func = (
-                input_prep_func if input_prep_func is not None
-                else lambda x: (x,)
-            )
-
-        def _model_call(self, inps):
-            # TODO: make batches work
-            input = self.input_prep_func(inps)
-
-            max_seq_length = min(inps.size(1), self.max_length)
-            with torch.device(self._device):
-                self._model.setup_caches(self.batch_size, max_seq_length)
-            logits = self._model(*input)
-            return logits
-
-        def _model_generate(self, context, max_length, eos_token_id):
-            raise Exception('unimplemented')
-
-        def run_eval(self, tasks, limit):
-            try:
-                lm_eval.tasks.initialize_tasks()
-            except:
-                pass
-
-            task_dict = get_task_dict(tasks)
-            print("Evaluating Model On: ", task_dict)
-            with torch.no_grad():
-                result = evaluate(
-                    self,
-                    task_dict,
-                    limit=limit,
-                )
-            for task, res in result["results"].items():
-                print(f"{task}: {res}")
-            return result
-
-class MultiInput:
-
-    def __init__(self, inputs):
-
-        self.values = list(inputs)
-
-    def add_input(self, input):
-        self.values.append(input)
-        return self
-
-    def __getitem__(self, slice):
-        return MultiInput(self.values[slice])
-
-    def cuda(self):
-        self.values = [
-            val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
-        ]
-
 
 class GenericGPTQRunner(fx.Interpreter):
     """
@@ -308,7 +67,7 @@ class GenericGPTQRunner(fx.Interpreter):
     def __init__(
         self,
        model,
-        inputs: MultiInput,
+        inputs: _MultiInput,
         blocksize=128,
         percdamp=0.01,
         groupsize=128,
@@ -407,22 +166,22 @@ def tensors_to_cuda(args):
 
         # flatten args and kwargs together
         flat_args, spec = tree_flatten((args, kwargs))
-        # move all single tensors to cuda, will move MultiInputs to cuda one at a time
+        # move all single tensors to cuda, will move _MultiInputs to cuda one at a time
         flat_args = tensors_to_cuda(flat_args)
 
-        has_multi_input = MultiInput in [type(x) for x in flat_args]
+        has_multi_input = _MultiInput in [type(x) for x in flat_args]
         if has_multi_input:
            # Just some trickery to convert
-            # [MultiInput[a, a, a], MultiInput(b, b, b)] => [a, b], [a, b], [a, b]
+            # [_MultiInput[a, a, a], _MultiInput(b, b, b)] => [a, b], [a, b], [a, b]
             multi_input_count = max(
-                [len(x.values) if isinstance(x, MultiInput) else 1 for x in flat_args]
+                [len(x.values) if isinstance(x, _MultiInput) else 1 for x in flat_args]
             )
             transposed_args = list(
                 zip(
                     *[
                         (
                             x.values
-                            if isinstance(x, MultiInput)
+                            if isinstance(x, _MultiInput)
                             else [x] * multi_input_count
                         )
                         for x in flat_args
@@ -551,7 +310,7 @@ def SQNR(x, y):
                 )
                 return new_out
 
-        return MultiInput(outputs) if has_multi_input else outputs[0]
+        return _MultiInput(outputs) if has_multi_input else outputs[0]
 
     def faster_quant(self, H, W):
         percdamp = self.percdamp
@@ -751,7 +510,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":
         raise NotImplementedError("_convert_for_runtime not implemented")
 
     @torch.no_grad()
-    def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, inputs: List[_MultiInput], **kwargs: Any) -> torch.nn.Module:
         pass
 
 def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
@@ -981,7 +740,7 @@ def _convert_for_runtime(self, model):
         )
         return model
 
-    def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, inputs: List[_MultiInput], **kwargs: Any) -> torch.nn.Module:
         state_dict = self._create_quantized_state_dict(
             model,
             inputs,
@@ -1327,7 +1086,7 @@ def _convert_for_runtime(self, model):
         )
         return model
 
-    def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, inputs: List[_MultiInput], **kwargs: Any) -> torch.nn.Module:
         state_dict = self._create_quantized_state_dict(
             model,
             inputs,
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 74cb7deb20..78a76863f3 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -22,6 +22,13 @@
     "TORCH_VERSION_AFTER_2_3",
 ]
 
+try:
+    import lm_eval  # pyre-ignore[21]  # noqa: F401
+
+    _lm_eval_available = True
+except:
+    _lm_eval_available = False
+
 
 def find_multiple(n: int, *args: Tuple[int]) -> int:
     k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))  # type: ignore[9]
@@ -146,6 +153,26 @@ def get_model_size_in_bytes(model):
             s += b.nelement() * b.element_size()
     return s
 
+
+class _MultiInput:
+
+    def __init__(self, inputs):
+
+        self.values = list(inputs)
+
+    def add_input(self, input):
+        self.values.append(input)
+        return self
+
+    def __getitem__(self, slice):
+        return _MultiInput(self.values[slice])
+
+    def cuda(self):
+        self.values = [
+            val.cuda() if isinstance(val, torch.Tensor) else val for val in self.values
+        ]
+
+
 # TODO: quantization namespace is not the right place ot have this
 if version.parse(torch.__version__) >= version.parse("2.4.0.dev"):
     TORCH_VERSION_AFTER_2_4 = True
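
Usage sketch for the relocated helpers (not part of the patch): it assumes lm_eval is installed and that model and tokenizer are supplied by the caller in the gpt-fast style the tests above rely on (a transformer exposing setup_caches and a sentencepiece-style tokenizer); the task name and limits are illustrative placeholders rather than an API contract.

# Minimal sketch, with assumptions: `model` and `tokenizer` are placeholders the
# caller loads (gpt-fast conventions: model.setup_caches(batch_size, seq_len),
# tokenizer with encode/decode and bos_id/eos_id); "wikitext" and limit=1 are
# just illustrative lm_eval settings.
from torchao._eval import InputRecorder, TransformerEvalWrapper

calibration_seq_length = 100

# Record calibration inputs by letting lm_eval drive a "fake" evaluation pass;
# get_inputs() returns a list of _MultiInput objects suitable for GPTQ calibration.
inputs = (
    InputRecorder(
        tokenizer,
        calibration_seq_length,
        pad_calibration_inputs=False,
    )
    .record_inputs(["wikitext"], calibration_limit=1)
    .get_inputs()
)

# Evaluate a model (for example, after quantization) with lm-evaluation-harness.
result = TransformerEvalWrapper(
    model,
    tokenizer,
    max_seq_length=calibration_seq_length,
    device="cuda",
).run_eval(["wikitext"], limit=1)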