From 05ec481eb5ac9149a40491d1198d1cfe819e3e3c Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 24 Apr 2024 16:45:17 -0700 Subject: [PATCH] deduplicate code for `get_group_qparams_symmetric` Summary: This just removes the implementation, we can have follow up PRs to remove the call all together after we have replaced all implementation with the new blockwise quant code Test Plan: CI Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 16 ++- test/quantization/test_quant_primitives.py | 6 +- torchao/quantization/quant_primitives.py | 159 ++++++++------------- torchao/quantization/subclass.py | 6 +- 4 files changed, 75 insertions(+), 112 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 2425d341e2..abee0d3979 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -36,6 +36,7 @@ quant_int8_dynamic_per_token_linear, quantize_activation_per_token_absmax, safe_int_mm, + dequantize_affine, ) from torchao.quantization.smoothquant import ( @@ -385,11 +386,11 @@ def _test_dynamic_quant_per_tensor_numerics_impl( # to rounding assert torch.max(torch.abs(y_vals - y_ref.int_repr())).item() <= 1 torch.testing.assert_close( - y_scale, torch.tensor([y_ref.q_scale()], device=device, dtype=float_dtype) + y_scale, torch.tensor(y_ref.q_scale(), device=device, dtype=float_dtype) ) if y_zero_point is not None: assert torch.equal( - y_zero_point, torch.tensor([y_ref.q_zero_point()], device=device) + y_zero_point, torch.tensor(y_ref.q_zero_point(), device=device) ) else: self.assertTrue(y_ref.q_zero_point() == 0) @@ -558,8 +559,8 @@ def _test_dynamic_quant_per_channel_numerics_impl( assert torch.max(torch.abs(y_vals - y_ref.int_repr())) <= 1 # dequantize - x_dq = dequantize_per_channel(y_vals, y_scale, y_zero_point) - x_ref_dq = y_ref.dequantize() + x_dq = dequantize_per_channel(y_vals, y_scale, y_zero_point, out_dtype=float_dtype) + x_ref_dq = y_ref.dequantize().to(float_dtype) # off-by-one for scale is okay torch.testing.assert_close( x_dq, x_ref_dq, atol=torch.max(y_scale).item() * 1.01, rtol=0.0001 @@ -582,7 +583,8 @@ def test_dynamic_quant_per_channel_numerics_cuda(self): def _test_quantize_per_token_impl(self, device, dtype): x = torch.randn(3, 3, 3, device=device, dtype=dtype) xq, scales = quantize_activation_per_token_absmax(x) - x_dq = dequantize_per_tensor(xq, scales, None).to(x.dtype) + block_size = (1, 1, 3) + x_dq = dequantize_affine(xq, block_size, scales, None, torch.int8, output_dtype=x.dtype) sqnr = compute_error(x, x_dq) self.assertTrue(sqnr >= 45.0) @@ -1173,7 +1175,7 @@ def forward(self, x): model_qc = torch.compile(model, mode="max-autotune") ref_q = model_qc(x).detach() - assert SQNR(ref_f, ref_q) > min_sqnr + assert SQNR(ref_f, ref_q) > min_sqnr, f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}" # load model structure with torch.device('meta'): @@ -1190,7 +1192,7 @@ def forward(self, x): model_qc = torch.compile(model, mode="max-autotune") test = model_qc(x).detach() - assert SQNR(ref_f, test) > min_sqnr + assert SQNR(ref_f, test) > min_sqnr, f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}" self.assertTrue(torch.equal(ref_q, test)) @parameterized.expand(COMMON_DEVICE_DTYPE) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 2830e1acfa..6186714e3b 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -67,9 +67,11 @@ def test_choose_qparams_group_sym(self): mapping_type = MappingType.SYMMETRIC dtype = torch.int8 block_size = (1, 2) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + eps = torch.finfo(torch.float32).eps + precision = torch.float32 + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision) - scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2) + scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision) self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index febe65e124..bd4bcce1aa 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -40,6 +40,9 @@ "groupwise_affine_dequantize_tensor_from_qparams", "groupwise_affine_quantize_tensor", "groupwise_affine_dequantize_tensor", + "choose_qparams_affine", + "quantize_affine", + "dequantize_affine", # TODO: need to clean up above functions ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else []) @@ -219,10 +222,8 @@ def dequantize_affine( if zero_point is not None: zero_point = zero_point.view(shape_after_reduction) - dequant = input.to(torch.float32) - scale = scale.to(torch.float32) + dequant = input.to(output_dtype) if zero_point is not None: - zero_point = zero_point.to(torch.float32) dequant -= zero_point dequant *= scale dequant = dequant.view(original_shape) @@ -260,9 +261,9 @@ def choose_qparams_affine( """ quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) if scale_dtype is None: - scale_dtype = torch.float32 + scale_dtype = input.dtype if zero_point_dtype is None: - zero_point_dtype = torch.float32 + zero_point_dtype = input.dtype assert len(block_size) == input.dim() shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) @@ -301,47 +302,18 @@ def dynamically_quantize_per_tensor( target_dtype, qscheme=torch.per_tensor_affine, # for now, reuse existing qscheme enum ): - # assumes affine quantization - - # default setup for affine quantization of activations eps = torch.finfo(torch.float32).eps - - if qscheme == torch.per_tensor_affine: - # get min and max - # TODO(future): make torch.aminmax work on cpu-half - # min_val, max_val = torch.aminmax(x) - min_val = torch.min(x) - max_val = torch.max(x) - - # calculate scale and zero point based on min and max - # reference: https://fburl.com/code/srbiybme - min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) - max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) - - scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) - # TODO(future): make torch.clamp with scalar work on cpu-half - scale = torch.clamp(scale, min=eps).reshape(1) - zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int) - zero_point = torch.clamp(zero_point, quant_min, quant_max) - - # quantize based on qmin/qmax/scale/zp - # reference: torch/ao/quantization/fx/_decomposed.py?lines=63 - quant = torch.clamp( - torch.round(x / scale) + zero_point, quant_min, quant_max - ).to(target_dtype) - - else: - assert qscheme == torch.per_tensor_symmetric, f"unsupported qscheme {qscheme}" - # assert quant_min == -1 * quant_max, "unsupported quant_min/quant_max" - amax = torch.max(torch.abs(x)) - scale = amax / (float(quant_max - quant_min) / 2) - scale = torch.clamp(scale, min=eps).reshape(1) - quant = torch.clamp(torch.round(x / scale), quant_min, quant_max).to( - target_dtype - ) - # do not create a tensor for zero_point as this is expensive - zero_point = None - + block_size = x.shape + zero_point_dtype = torch.int32 + + qscheme_to_mapping_type = { + torch.per_tensor_affine: MappingType.ASYMMETRIC, + torch.per_tensor_symmetric: MappingType.SYMMETRIC, + } + assert qscheme in qscheme_to_mapping_type, f"unsupported qscheme {qscheme}" + mapping_type = qscheme_to_mapping_type[qscheme] + scale, zero_point = choose_qparams_affine(x, mapping_type, block_size, target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=eps, zero_point_dtype=zero_point_dtype) + quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max) return quant, scale, zero_point @@ -374,33 +346,15 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): # assumes dense memory format # TODO(future): relax ^ as needed - # default setup for affine quantization of activations - eps = torch.finfo(torch.float32).eps + assert x.dim() == 2, "only support 2d Tensors" - # get min and max - min_val, max_val = torch.aminmax(x, dim=1) - - # calculate scale and zero point based on min and max - # reference: https://fburl.com/code/srbiybme - min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) - max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) - device = min_val_neg.device - - # reference: https://fburl.com/code/4wll53rk - max_val_pos = torch.max(-min_val_neg, max_val_pos) - scale = max_val_pos / (float(quant_max - quant_min) / 2) - # ensure scale is the same dtype as the original tensor - scale = torch.clamp(scale, min=eps).to(x.dtype) - zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) - - # quantize based on qmin/qmax/scale/zp - # reference: torch/ao/quantization/fx/_decomposed.py?lines=63 - x_div = x.transpose(0, 1) / scale - x_round = torch.round(x_div) - x_zp = x_round + zero_point - x_zp = x_zp.transpose(0, 1) - quant = torch.clamp(x_zp, quant_min, quant_max).to(target_dtype) + eps = torch.finfo(torch.float32).eps + block_size = (1, x.shape[1]) + zero_point_dtype = torch.int64 + mapping_type = MappingType.SYMMETRIC + scale, zero_point = choose_qparams_affine(x, mapping_type, block_size, target_dtype=target_dtype, quant_min=quant_min, quant_max=quant_max, eps=eps, zero_point_dtype=zero_point_dtype) + quant = quantize_affine(x, block_size, scale, zero_point, target_dtype, quant_min, quant_max) return quant, scale, zero_point @@ -408,23 +362,29 @@ def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): def dequantize_per_tensor(int_repr, scale, zero_point, out_dtype=torch.float32): - y = int_repr.to(out_dtype) - if zero_point is not None: - y -= zero_point - return y * scale + block_size = int_repr.shape + input_dtype = int_repr.dtype + assert scale.numel() == 1, f"scale size: {scale.numel()}" + dequantized = dequantize_affine(int_repr, block_size, scale, zero_point, input_dtype, output_dtype=out_dtype) + return dequantized # reference: https://fburl.com/code/org0fmi3 def dequantize_per_channel(int_repr, scales, zero_points, out_dtype=torch.float32): - # assumes axis is 0 - y = int_repr.transpose(0, 1) - y = y.to(out_dtype) - y = y - zero_points - y = y * scales - y = y.transpose(0, 1) - return y + assert int_repr.dim() == 2, "only support 2d Tensors" + # channel axis == 0 + # block_size before transpose should be (1, int_repr.shape[1]) for axis == 0 per channel quant + + # TODO: transpose is for perf reasons for torch.compile, we should separate this to lowering step + int_repr = int_repr.t() + # transpose for block_size as well + block_size = (int_repr.shape[0], 1) + input_dtype = int_repr.dtype + dequantized = dequantize_affine(int_repr, block_size, scales, zero_points, input_dtype, output_dtype=out_dtype) + dequantized = dequantized.t() + return dequantized def quant_int8_dynamic_linear( @@ -595,7 +555,7 @@ def quant_int8_per_token_matmul( def get_groupwise_affine_qparams(w, n_bit=4, groupsize=128): - """ """ + """This is tinygemm specific, we'll keep this for now""" if groupsize > w.shape[-1]: groupsize = w.shape[-1] assert groupsize > 1 @@ -644,6 +604,7 @@ def groupwise_affine_quantize_tensor_from_qparams( n_bit=4, groupsize=128, ): + """This is tinygemm specific, we'll keep this for now""" assert groupsize > 1 # needed for GPTQ single column quantize if groupsize > w.shape[-1] and scales.shape[-1] == 1: @@ -679,6 +640,7 @@ def groupwise_affine_dequantize_tensor_from_qparams( n_bit=4, groupsize=128, ): + """This is tinygemm specific, we'll keep this for now""" assert groupsize > 1 # needed for GPTQ single column dequantize if groupsize > w_int4x8.shape[-1] and scales.shape[-1] == 1: @@ -728,26 +690,19 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float assert groupsize > 1 assert w.shape[-1] % groupsize == 0 assert w.dim() == 2 + assert n_bit <= 8, f"unsupported n_bit: {n_bit}" - to_quant = w.reshape(-1, groupsize) - assert torch.isnan(to_quant).sum() == 0 - - max_val = to_quant.amax(dim=1, keepdim=True) - min_val = to_quant.amin(dim=1, keepdim=True) - min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) - max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) - - max_val_abs = torch.max(-min_val_neg, max_val_pos) - max_int = 2 ** (n_bit - 1) - 1 - min_int = -(2 ** (n_bit - 1)) - - scales = max_val_abs / (float(max_int - min_int) / 2) - scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps)) - # TODO: make sure abs(scales) is not too small? - zeros = torch.full_like(scales, 0) - return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape( - w.shape[0], -1 - ) + mapping_type = MappingType.SYMMETRIC + block_size = (1, groupsize) + eps = torch.finfo(torch.float32).eps + ranges = {} + ranges[1] = (-1, 0) + # generating ranges for bit 2 to 8 + for i in range(2, 9): + ranges[i] = (-(2 ** (i - 1)), 2 ** (i - 1) - 1) + quant_min, quant_max = ranges[n_bit] + scale, zero_point = choose_qparams_affine(w, mapping_type, block_size, target_dtype=torch.int8, quant_min=quant_min, quant_max=quant_max, eps=eps, scale_dtype=precision, zero_point_dtype=precision) + return scale.reshape(w.shape[0], -1), zero_point.reshape(w.shape[0], -1) if TORCH_VERSION_AFTER_2_3: @@ -796,7 +751,7 @@ def pack_int4_from_int8(int8_data: torch.Tensor) -> torch.Tensor: @impl(quantized_decomposed_lib, "unpack_int4_to_int8", "CompositeExplicitAutograd") def unpack_int4_to_int8(int8_data: torch.Tensor) -> torch.Tensor: - """Get the original weight from the normalized float weight format""" + """ Get the original weight from the normalized float weight format""" # since we are using int8 we will decode 2 entries per byte # Shift elements down 4 and select out the bottom 4 bits shape = int8_data.shape diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 64689b8d95..e9b532d6d6 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -218,8 +218,11 @@ def dequantize(self, dtype=None): """ Obtain the dequantized version of the quantized tensor subclass """ + zero_points = torch.zeros(self.q_scales.shape, device=self.q_scales.device, dtype=self.q_scales.dtype) + # zero_points = 0 + # TODO: fix dtype here? `to(self.dtype)` is not overwritten by `dtype` arg? dq_t = dequantize_per_channel( - self.int_data.t(), self.q_scales, 0, self.dtype if dtype is None else dtype + self.int_data.t(), self.q_scales, zero_points, self.dtype if dtype is None else dtype ).to(self.dtype) # data was transposed to dequantize so make sure shape is correct return dq_t if not self.transposed else dq_t.t() @@ -292,6 +295,7 @@ def from_float(cls, input_float, qmin=-128, qmax=127): Int8DynamicallyQuantizedLinearWeight.from_float(model.lin_mod.weight) ) """ + # because we call transpose in dequantization w_int_repr, w_scales, _ = dynamically_quantize_per_channel( input_float, qmin, qmax, torch.int8 )