From dba8e1028e821abc9cabfc3f63207d9c719b37cb Mon Sep 17 00:00:00 2001
From: cpuhrsch
Date: Wed, 15 May 2024 12:58:55 -0700
Subject: [PATCH] =?UTF-8?q?Revert=20"Enable=20dispatch=20to=20tinygemm=20i?=
 =?UTF-8?q?nt4=20and=20int8=20kernels=20for=20quantized=20tenso=E2=80=A6"?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This reverts commit 10da375e52eaea8240b963f90845569c5e735c2b.
---
 test/quantization/test_quant_api.py        | 105 ++-------------------
 test/quantization/test_quant_primitives.py |  11 ++-
 torchao/quantization/autoquant.py          |   1 -
 torchao/quantization/quant_primitives.py   |  85 ++++--------------
 torchao/quantization/subclass.py           |  84 ++---------------
 5 files changed, 43 insertions(+), 243 deletions(-)

diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index cea659e61d..10d36f0c1b 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -9,6 +9,7 @@
 import unittest
 import torch
 import os
+from torch._export import capture_pre_autograd_graph
 from torch.ao.quantization.quantize_pt2e import (
     prepare_pt2e,
     convert_pt2e,
@@ -35,7 +36,7 @@
 
 
 def dynamic_quant(model, example_inputs):
-    m = torch.export.export(model, example_inputs).module()
+    m = capture_pre_autograd_graph(model, example_inputs)
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
     m = prepare_pt2e(m, quantizer)
     m = convert_pt2e(m)
@@ -49,14 +50,14 @@ def _apply_dynamic_quant(model):
     """
     _replace_with_custom_fn_if_matches_filter(
         model,
-        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features),)),
+        lambda linear_mod: dynamic_quant(linear_mod, (torch.randn(1, linear_mod.in_features))),
         lambda mod, fqn: isinstance(mod, torch.nn.Linear),
     )
     return model
 
 
 def capture_and_prepare(model, example_inputs):
-    m = torch.export.export(model, example_inputs)
+    m = capture_pre_autograd_graph(model, example_inputs)
     quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config(is_dynamic=True))
     m = prepare_pt2e(m, quantizer)
     # TODO: we can run the weight observer in convert_pt2e so that user don't need to run this
@@ -87,13 +88,13 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
         return model
 
 class ToyLinearModel(torch.nn.Module):
-    def __init__(self, m=64, n=32, k=64):
+    def __init__(self):
         super().__init__()
-        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
-        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
 
     def example_inputs(self):
-        return (torch.randn(1, self.linear1.in_features).to(torch.float),)
+        return (torch.randn(1, 64).to(torch.float),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -103,9 +104,8 @@ def forward(self, x):
 class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
         m = ToyLinearModel().eval()
-        example_inputs = m.example_inputs()
         m = _apply_dynamic_quant(m)
-        quantized = m(*example_inputs)
+        quantized = m(*m.example_inputs())
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
         # While executing %choose_qparams_tensor_1 : [num_users=2] = call_function[target=torch.ops.quantized_decomposed.choose_qparams.tensor](args = (%arg0_3, -128, 127, 0.000244140625, torch.int8), kwargs = {})
         # m = torch.compile(m, mode="max-autotune")
@@ -442,94 +442,7 @@ def get_per_token_block_size(x):
         ref = m_copy(*example_inputs)
         self.assertTrue(torch.equal(res, ref))
 
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quantized_tensor_subclass_int4(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
-        from torchao.quantization.quant_primitives import MappingType
-        from torchao.quantization.quant_primitives import ZeroPointDomain
-        import copy
-
-        # weight settings
-        groupsize = 32
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, groupsize)
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        eps = 1e-6
-        preserve_zero = False
-        zero_point_dtype = torch.bfloat16
-
-        # weight only quantization
-        input_quant_func = None
-
-        # use 1024 so that we don't need padding
-        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
-        m_copy = copy.deepcopy(m)
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
-
-        def to_quantized(weight):
-            return AffineQuantizedTensor.from_float(
-                weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps,
-                zero_point_dtype=zero_point_dtype,
-                preserve_zero=preserve_zero,
-                zero_point_domain=ZeroPointDomain.FLOAT,
-                input_quant_func=input_quant_func,
-            )
-
-        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
-        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
-        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
-
-        # reference
-        from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
-        change_linear_weights_to_int4_woqtensors(m_copy, groupsize=groupsize)
-
-        res = m(*example_inputs)
-        ref = m_copy(*example_inputs)
-
-        self.assertTrue(torch.equal(res, ref))
-
-
-    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_quantized_tensor_subclass_int8(self):
-        from torchao.quantization.subclass import AffineQuantizedTensor
-        from torchao.quantization.quant_primitives import MappingType
-        import copy
-
-        # weight settings
-        mapping_type = MappingType.SYMMETRIC
-        target_dtype = torch.int8
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int64
-
-        # weight only quantization
-        input_quant_func = None
-
-        m = ToyLinearModel().eval().to(torch.bfloat16)
-        m_copy = copy.deepcopy(m)
-        example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs()))
-
-        def to_quantized(weight):
-            block_size = (1, weight.shape[1])
-            return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func)
-
-        m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False)
-        m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False)
-        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
-        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
-
-        # reference
-        from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
-        change_linear_weights_to_int8_woqtensors(m_copy)
-
-        res = m(*example_inputs)
-        ref = m_copy(*example_inputs)
-        torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2)
 
 
 if __name__ == "__main__":
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index a64439a25e..291039e42a 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -327,8 +327,6 @@ def test_not_preserve_zero_not_supported(self):
 
 
     def test_tinygemm_get_groupwise_affine_qparams(self):
-        from torchao.quantization.quant_primitives import ZeroPointDomain
-
         input = torch.randn(10, 256)
         n_bit = 4
         scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
@@ -353,11 +351,16 @@ def test_tinygemm_get_groupwise_affine_qparams(self):
             scale_dtype=scale_dtype,
             zero_point_dtype=zero_point_dtype,
             preserve_zero=False,
-            zero_point_domain=ZeroPointDomain.FLOAT,
         )
 
+        def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
+            return (quant_min - zero_point + mid_point) * scale
+
+        mid_point = 2 ** (n_bit - 1)
+        zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)
+
         self.assertTrue(torch.equal(scale, scale_ref))
-        self.assertTrue(torch.equal(zero_point, zero_point_ref))
+        torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.00001, atol=torch.max(scale)*0.03)
 
 
 if __name__ == "__main__":
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index 4c0ae53ce8..4331d9b042 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -9,7 +9,6 @@
     quantize_activation_per_token_absmax,
     safe_int_mm,
 )
-from .utils import TORCH_VERSION_AFTER_2_4
 import torch.nn.functional as F
 try:
     from torch._inductor.utils import do_bench
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 4f39a6055d..3975284b61 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -72,14 +72,6 @@ def guard_dtype_size(tensor_arg, arg_name, dtype=None, size=None):
     torch.uint7: (0, 2**7-1),
 })
 
-class MappingType(Enum):
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
-
-class ZeroPointDomain(Enum):
-    INT = 0
-    FLOAT = 1
-
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -149,8 +141,7 @@ def quantize_affine(
     zero_point: Optional[torch.Tensor],
     output_dtype: torch.dtype,
     quant_min: Optional[int] = None,
-    quant_max: Optional[int] = None,
-    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+    quant_max: Optional[int] = None
 ):
     """
     Args:
@@ -162,12 +153,6 @@
         output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
         quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
         quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
-        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
-            if zero_point is in integer domain, zero point is added to the quantized integer value during
-            quantization
-            if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
-            value during quantization
-            default is ZeroPointDomain.INT
 
     Note:
       How can block_size represent different granularities?
@@ -199,19 +184,9 @@ def quantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    if zero_point_domain == ZeroPointDomain.INT:
-        quant = torch.clamp(
-            torch.round(input / scale) + zero_point, quant_min, quant_max
-        ).to(output_dtype)
-    else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT
-        mid_point = (quant_max + quant_min + 1) / 2
-        min_val = zero_point - scale * mid_point
-        quant = (
-            torch.clamp(
-                torch.round((input - min_val) / scale),
-                quant_min, quant_max)
-        ).to(output_dtype)
+    quant = torch.clamp(
+        torch.round(input / scale) + zero_point, quant_min, quant_max
+    ).to(output_dtype)
     quant = quant.view(original_shape)
     return quant
 
@@ -224,7 +199,6 @@ def dequantize_affine(
     input_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
-    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     *,
     output_dtype: torch.dtype = torch.float32,
 ):
@@ -239,12 +213,6 @@
         quant_min (Optional[int]): minimum quantized value for input Tensor
         quant_max (Optional[int]): maximum quantized value for input Tensor
         output_dtype (torch.dtype): dtype for output Tensor, default is fp32
-        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
-            if zero_point is in integer domain, zero point is added to the quantized integer value during
-            quantization
-            if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
-            value during quantization
-            default is ZeroPointDomain.INT
 
     Output:
       dequantized Tensor, with requested dtype or fp32
@@ -265,22 +233,18 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    if zero_point_domain == ZeroPointDomain.INT:
-        dequant = input.to(torch.int32)
-        if zero_point is not None:
-            dequant -= zero_point.to(torch.int32)
-        dequant = dequant.to(output_dtype)
-        dequant *= scale
-    else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
-        mid_point = (quant_max + quant_min + 1) / 2
-        dequant = input - mid_point
-        dequant = dequant.to(output_dtype)
-        dequant *= scale
-        if zero_point is not None:
-            dequant += zero_point
+    dequant = input.to(torch.int32)
+    if zero_point is not None:
+        dequant -= zero_point.to(torch.int32)
+    dequant = dequant.to(output_dtype)
+    dequant *= scale
+    dequant = dequant.view(original_shape)
+    return dequant.to(output_dtype)
 
-    return dequant.view(original_shape).to(output_dtype)
+
+class MappingType(Enum):
+    SYMMETRIC = 0
+    ASYMMETRIC = 1
 
 def choose_qparams_affine(
     input: torch.Tensor,
@@ -292,8 +256,7 @@
     eps: Optional[float] = None,
     scale_dtype: Optional[torch.dtype] = None,
     zero_point_dtype: Optional[torch.dtype] = None,
-    preserve_zero: bool = True,
-    zero_point_domain = ZeroPointDomain.INT,
+    preserve_zero = True,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
@@ -317,13 +280,6 @@
         If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point
 
-        zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
-            if zero_point is in integer domain, zero point is added to the quantized integer value during
-            quantization
-            if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
-            value during quantization
-            default is ZeroPointDomain.INT
-
     Output:
         Tuple of scales and zero_points
        Tensor with requested dtype
    """
@@ -354,18 +310,15 @@ def choose_qparams_affine(
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        if zero_point_domain != ZeroPointDomain.INT:
-            raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
-        zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
+        zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
     else:
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
         else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
-            mid_point = (quant_max + quant_min + 1) / 2
-            zero_point = min_val_neg + scale * mid_point
+            zero_point = quant_min - min_val_neg / scale
+
 
     if eps is None:
         eps = torch.finfo(input.dtype).eps
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 607cb77766..6128720d4d 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -14,13 +14,10 @@
     dynamically_quantize_per_channel,
     groupwise_affine_quantize_tensor,
     quant_int8_dynamic_per_token_linear,
-    pack_tinygemm_scales_and_zeros,
     unpack_tinygemm_scales_and_zeros,
-    groupwise_affine_quantize_tensor_from_qparams,
     choose_qparams_affine,
     quantize_affine,
     dequantize_affine,
-    ZeroPointDomain,
 )
 from .utils import find_multiple
 from typing import Tuple, Optional, Callable
@@ -622,13 +619,7 @@ class AffineQuantizedTensor(torch.Tensor):
       shape (torch.Size): the shape for the Tensor
       quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
       quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data`
-      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
-          if zero_point is in integer domain, zero point is added to the quantized integer value during
-          quantization
-          if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
-          value during quantization
-          default is ZeroPointDomain.INT
-      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object
+      input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes input Tensor as input and outputs an AffineQuantizedTensor object
       dtype: dtype for external representation of the tensor, e.g. torch.float32
    """
@@ -642,10 +633,8 @@ def __new__(
         shape: torch.Size,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         input_quant_func: Optional[Callable] = None,
         dtype=None,
-        # TODO: remove args and kwargs
         *args,
         **kwargs
     ):
@@ -669,7 +658,6 @@ def __init__(
         shape: torch.Size,
         quant_min: Optional[int] = None,
         quant_max: Optional[int] = None,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         input_quant_func: Optional[Callable] = None,
         dtype=None,
         *args,
@@ -681,7 +669,6 @@ def __init__(
         self.block_size = block_size
         self.quant_min = quant_min
         self.quant_max = quant_max
-        self.zero_point_domain = zero_point_domain
         self.input_quant_func = input_quant_func
 
     def __repr__(self):
@@ -690,20 +677,18 @@
             f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})"
         )
 
-    def dequantize(self, output_dtype=None):
-        if output_dtype is None:
-            output_dtype = self.dtype
-        return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype)
+    def dequantize(self, output_dtype=torch.float32):
+        return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, output_dtype=output_dtype)
 
     def __tensor_flatten__(self):
-        return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.input_quant_func, self.dtype]
+        return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.input_quant_func, self.dtype]
 
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
-        block_size, shape, quant_min, quant_max, zero_point_domain, input_quant_func, dtype = tensor_attributes
+        block_size, shape, quant_min, quant_max, input_quant_func, dtype = tensor_attributes
         return cls(
             int_data,
             scale,
@@ -712,7 +697,6 @@ def __tensor_unflatten__(
             shape if outer_size is None else outer_size,
             quant_min,
             quant_max,
-            zero_point_domain,
             input_quant_func=input_quant_func,
             dtype=dtype,
             strides=outer_stride,
         )
@@ -731,11 +715,9 @@ def from_float(
         scale_dtype = None,
         zero_point_dtype = None,
         input_quant_func = None,
-        preserve_zero = True,
-        zero_point_domain = ZeroPointDomain.INT,
     ):
-        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
-        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+        scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype)
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
         return cls(
             int_data,
             scale,
@@ -744,7 +726,6 @@ def from_float(
             input_float.shape,
             quant_min,
             quant_max,
-            zero_point_domain,
             input_quant_func=input_quant_func,
             dtype=input_float.dtype
         )
@@ -759,54 +740,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
                 args[1],
                 args[2] if len(args) > 2 else None,
             )
-            if weight_qtensor.input_quant_func is None:
-                is_cuda = args[0].is_cuda
-                is_cpu = args[0].device == torch.device("cpu")
-                # weight only quantization
-                is_int8 = (
-                    weight_qtensor.int_data.dtype == torch.int8 and
-                    weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128 and
-                    weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127
-                )
-                is_uint4 = (
-                    weight_qtensor.int_data.dtype == torch.int32 and
-                    weight_qtensor.quant_min == 0 and
-                    weight_qtensor.quant_max == 15
-                )
-
-                # TODO: enable cpu and mps path as well
-                # TODO: make sure weight dimension matches the expectation of the int4mm kernel
-                # TODO: move this to TinygemmAffineQuantizedTensor
-                if (
-                    is_cuda and
-                    is_uint4 and
-                    weight_qtensor.dtype == torch.bfloat16 and
-                    len(weight_qtensor.shape) == 2 and
-                    weight_qtensor.block_size[0] == 1 and
-                    weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT
-                ):
-                    # groupwise int4 quantization
-                    # TODO: currently doing packing on the fly, we'll need to figure out
-                    # the API to do packing before hand
-                    # TODO: expose the arg
-                    innerKTiles = 8
-                    packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles)
-                    scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point)
-                    groupsize = weight_qtensor.block_size[-1]
-                    return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros)
-                elif (
-                    is_cpu and
-                    is_int8 and
-                    len(weight_qtensor.shape) == 2 and
-                    len(weight_qtensor.block_size) == 2 and
-                    weight_qtensor.block_size[0] == 1 and
-                    weight_qtensor.block_size[1] == weight_qtensor.shape[1]
-                ):
-                    # TODO: enable mps path as well
-                    # per channel int8 weight only quantizated mm
-                    return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale)
-            else:
-                # dynamic quantization
+            if weight_qtensor.input_quant_func is not None:
                 input_tensor = weight_qtensor.input_quant_func(input_tensor)
                 input_tensor = input_tensor.dequantize()
             weight_tensor = weight_qtensor.dequantize()
@@ -843,7 +777,6 @@ def to(self, *args, **kwargs):
             self.shape,
             self.quant_min,
             self.quant_max,
-            self.zero_point_domain,
             self.input_quant_func,
             **kwargs,
         )
@@ -857,7 +790,6 @@ def _apply_fn_to_data(self, fn):
             self.shape,
             self.quant_min,
             self.quant_max,
-            self.zero_point_domain,
             self.input_quant_func,
             dtype=self.dtype,
        )
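
Note for readers (not part of the patch itself): the hunks above remove the ZeroPointDomain.FLOAT branch from quantize_affine and the tinygemm int4 / int8 kernel dispatch that relied on it. As a quick reference, below is a minimal Python sketch restating the two zero-point conventions from the removed branches; the helper names are hypothetical and do not exist in torchao.

    import torch

    def quantize_int_domain(x, scale, zero_point, quant_min, quant_max):
        # ZeroPointDomain.INT: an integer zero_point is added after scaling.
        # This is the only path quantize_affine keeps after this revert.
        return torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)

    def quantize_float_domain(x, scale, zero_point, quant_min, quant_max):
        # ZeroPointDomain.FLOAT (tinygemm convention, removed by this revert):
        # zero_point lives in the floating-point domain and is anchored to the
        # mid quantized value, mirroring the removed branch of quantize_affine.
        mid_point = (quant_max + quant_min + 1) / 2
        min_val = zero_point - scale * mid_point
        return torch.clamp(torch.round((x - min_val) / scale), quant_min, quant_max)

    if __name__ == "__main__":
        # quant_min=0, quant_max=15 matches the uint4 range used in the removed test.
        x = torch.randn(4, 32, dtype=torch.bfloat16)
        scale = torch.full((4, 1), 0.1, dtype=torch.bfloat16)
        zp_float = torch.zeros(4, 1, dtype=torch.bfloat16)  # float-domain zero point
        print(quantize_float_domain(x, scale, zp_float, 0, 15))

The FLOAT-domain form is the convention the removed torch.ops.aten._weight_int4pack_mm dispatch assumed (via pack_tinygemm_scales_and_zeros), which is why that dispatch is dropped together with the ZeroPointDomain plumbing.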