From 07a3b9e966ad3ff363814dd95777c4e8c7158072 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 20 Mar 2024 19:04:44 -0700
Subject: [PATCH] [reland] Add support for Int8DynActInt4WeightQuantizer (#66)

Summary:
Adds Int8DynActInt4WeightQuantizer, an eager-mode quantizer that applies int8 dynamic
quantization to activations and 4-bit groupwise symmetric quantization to linear weights,
implemented on top of the unified single-call Quantizer API.

Test Plan:
python test/quantization/test_quant_api.py -k test_8da4w_quantizer

Reviewed By: cpuhrsch

Differential Revision: D55101038

Pulled By: jerryzh168

ghstack-source-id: 4b83757687d2241dd811cc3127d9e8dd6b02a2dc
Pull Request resolved: https://github.com/pytorch-labs/ao/pull/74
---
 test/quantization/test_quant_api.py |  27 +++++--
 torchao/quantization/GPTQ.py        |   3 +-
 torchao/quantization/quant_api.py   | 106 ++++++++++++++++++++++++----
 3 files changed, 115 insertions(+), 21 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index ea3012f475..f02a3154a4 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -24,6 +24,8 @@
     Quantizer,
     TwoStepQuantizer,
     Int8DynActInt4WeightGPTQQuantizer,
+    Int8DynActInt4WeightQuantizer,
+    Int8DynActInt4WeightLinear,
 )
 from pathlib import Path
 from sentencepiece import SentencePieceProcessor
@@ -85,8 +87,11 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
 class M(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear1 = torch.nn.Linear(5, 5).to(torch.float)
-        self.linear2 = torch.nn.Linear(5, 5).to(torch.float)
+        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+
+    def example_inputs(self):
+        return (torch.randn(1, 64).to(torch.float),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -97,8 +102,7 @@ class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = M().eval()
         m = _apply_dynamic_quant(m)
-        example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
-        quantized = m(*example_inputs)
+        quantized = m(*m.example_inputs())
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
         # m = torch.compile(m, mode="max-autotune")
@@ -110,9 +114,9 @@ def test_dynamic_quant_gpu_singleline(self):
     def test_dynamic_quant_gpu_unified_api_unified_impl(self):
         quantizer = XNNPackDynamicQuantizer()
         m = M().eval()
+        example_inputs = m.example_inputs()
         m = quantizer.prepare(m)
         m = quantizer.convert(m)
-        example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
         quantized = m(*example_inputs)
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
@@ -125,15 +129,24 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self):
     def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         quantizer = TorchCompileDynamicQuantizer()
         m = M().eval()
+        example_inputs = m.example_inputs()
         m = quantizer.quantize(m)
-        example_inputs = (torch.randn(1, 5).to(dtype=torch.float32),)
         quantized = m(*example_inputs)
         m = torch.compile(m, mode="max-autotune")
         compiled = m(*example_inputs)
         torch.testing.assert_close(quantized, compiled, atol=0, rtol=0)
 
+    def test_8da4w_quantizer(self):
+        quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
+        m = M().eval()
+        example_inputs = m.example_inputs()
+        m = quantizer.quantize(m)
+        assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
+        assert isinstance(m.linear2, Int8DynActInt4WeightLinear)
+        m(*example_inputs)
+
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
-    def test_gptq(self):
+    def test_gptq_quantizer(self):
         # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cpu"
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 52f36553a7..b642ebda3f 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -19,6 +19,7 @@
 # from model import Transformer  # pyre-ignore[21]
 from torch.utils._pytree import tree_flatten, tree_unflatten
+import logging
 
 # pyre-fixme[5]: Global expression must be annotated.
 aten = torch.ops.aten
@@ -63,7 +64,7 @@ def model_forward(model, x, input_pos):
     get_task_dict = tasks.get_task_dict
     evaluate = evaluator.evaluate
 else:
-    print("lm_eval is not installed, GPTQ may not be usable")
+    logging.info("lm_eval is not installed, GPTQ may not be usable")
 
 # pyre-fixme[3]: Return type must be annotated.
 def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 2997fef125..16271e68d7 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -36,8 +36,10 @@
 from .quant_primitives import (
     get_group_qparams_symmetric,
     per_token_dynamic_quant,
+    group_quantize_tensor_symmetric,
 )
-from typing import Dict, Tuple
+from typing import Dict, Tuple, Any
+import logging
 
 __all__ = [
     "apply_weight_only_int8_quant",
@@ -54,21 +56,18 @@
 ############################# Unified Quantization APIs ##############################
 # API 1, single quantize call to create a quantized model with quantized state_dict
 class Quantizer:
-    # pyre-fixme[2]: Parameter must be annotated.
-    def quantize(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass
 
 
 # API 2, flow that needs calibration or training
 class TwoStepQuantizer:
-    # pyre-fixme[2]: Parameter must be annotated.
-    def prepare(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def prepare(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass
 
-    # pyre-fixme[2]: Parameter must be annotated.
-    def convert(self, model: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
+    def convert(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
         # pyre-fixme[7]: Expected `Module` but got implicit return value of `None`.
         pass
 
@@ -260,7 +259,7 @@ def replace_conv2d_1x1(conv):
         MultiInput,
     )
 else:
-    print("lm_eval not available, skip defining GPTQQuantizer")
+    logging.info("lm_eval not available, skip defining GPTQQuantizer")
 
 
 class GPTQQuantizer(Quantizer):
@@ -442,11 +441,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":
 
     @torch.no_grad()
     # pyre-fixme[14]: `quantize` overrides method defined in `Quantizer` inconsistently.
-    def quantize(
-        self,
-        # pyre-fixme[2]: Parameter must be annotated.
-        model,
-    ) -> torch.nn.Module:
+    def quantize(self, model: torch.nn.Module, **kwargs: Any) -> torch.nn.Module:
         state_dict = self._create_quantized_state_dict(
             model,
             # pyre-fixme[16]: `GPTQQuantizer` has no attribute `tokenizer`.
@@ -686,6 +681,91 @@ def replace_linear_8da4w(
     )
 
 
+class Int8DynActInt4WeightQuantizer(Quantizer):
+    def __init__(
+        self,
+        group_size: int = 256,
+        padding_allowed: bool = False,
+        precision: torch.dtype = torch.float32,
+        scales_precision: torch.dtype = torch.float32,
+    ) -> None:
+        self.group_size: int = group_size
+        self.padding_allowed: bool = padding_allowed
+        self.precision: torch.dtype = precision
+        self.scales_precision: torch.dtype = scales_precision
+        # assert group_size in [32, 64, 128, 256]
+
+    @torch.no_grad()
+    def _create_quantized_state_dict(self, model: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        cur_state_dict = model.state_dict()
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Linear):
+                assert not mod.bias
+                out_features = mod.out_features
+                in_features = mod.in_features
+                # assert out_features % 8 == 0, "require out_features % 8 == 0"
+                print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                assert (
+                    in_features % self.group_size == 0
+                ), f"require in_features:{in_features} % self.group_size:{self.group_size} == 0"
+
+                weight = mod.weight.data
+                """
+                if not _check_linear_int4_k(
+                    in_features, self.group_size
+                ):
+                    if self.padding_allowed:
+                        print(
+                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
+                        )
+                        padded_in_features = _calc_padded_size_linear_int4(
+                            in_features, self.group_size
+                        )
+                        weight = F.pad(
+                            weight, pad=(0, padded_in_features - in_features)
+                        )
+                    else:
+                        raise RuntimeError(
+                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
+                            + "and that group_size"
+                        )
+                """
+                (
+                    weight_int8,
+                    scales,
+                    zeros,
+                ) = group_quantize_tensor_symmetric(
+                    weight.to(self.precision),
+                    4,  # n_bit
+                    self.group_size,
+                    self.scales_precision,
+                )
+                cur_state_dict[f"{fqn}.weight"] = weight_int8.to("cpu")
+                cur_state_dict[f"{fqn}.scales"] = scales.to("cpu")
+                cur_state_dict[f"{fqn}.zeros"] = zeros.to("cpu")
+                # TODO: support bias?
+
+        return cur_state_dict
+
+    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
+        replace_linear_8da4w(
+            model,
+            self.group_size,
+            self.padding_allowed,
+            self.precision,
+            self.scales_precision,
+        )
+        return model
+
+    def quantize(self, model: torch.nn.Module, *args: Any, **kwargs: Any) -> torch.nn.Module:
+        state_dict = self._create_quantized_state_dict(model)
+        model = self._convert_for_runtime(model)
+        # TODO: make it strict
+        model.load_state_dict(state_dict, strict=False)
+        return model
+
+
 class Int8DynActInt4WeightGPTQQuantizer(GPTQQuantizer):
     # pyre-fixme[3]: Return type must be annotated.
     def __init__(
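
Usage sketch (not part of the diff above): mirroring the new test_8da4w_quantizer test, this is roughly how the quantizer added in this patch is meant to be applied to an eager model. The TwoLinear module below is illustrative only; the requirements visible in the patch are that each nn.Linear has no bias and an in_features divisible by the chosen group_size.

    import torch
    from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

    # Toy model: every Linear is bias-free and its in_features is divisible by
    # the chosen group_size, matching the asserts in _create_quantized_state_dict.
    class TwoLinear(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(64, 32, bias=False)
            self.linear2 = torch.nn.Linear(32, 64, bias=False)

        def forward(self, x):
            return self.linear2(self.linear1(x))

    model = TwoLinear().eval()
    example_inputs = (torch.randn(1, 64),)

    # Single-call flow (API 1 / Quantizer): int4 groupwise symmetric weights are
    # computed into a state_dict, Linear modules are swapped for
    # Int8DynActInt4WeightLinear, and the quantized weights are loaded back.
    quantizer = Int8DynActInt4WeightQuantizer(group_size=32)
    model = quantizer.quantize(model)
    out = model(*example_inputs)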