From d181692c4e8bd93bf16c302df75f9555de607de8 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 24 Apr 2024 16:45:17 -0700 Subject: [PATCH] deduplicate code for `get_group_qparams_symmetric` Summary: This just removes the implementation, we can have follow up PRs to remove the call all together after we have replaced all implementation with the new blockwise quant code Test Plan: CI Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_quant_primitives.py | 6 ++-- torchao/quantization/quant_primitives.py | 33 +++++++++------------- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 2830e1acfa..6186714e3b 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -67,9 +67,11 @@ def test_choose_qparams_group_sym(self): mapping_type = MappingType.SYMMETRIC dtype = torch.int8 block_size = (1, 2) - scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) + eps = torch.finfo(torch.float32).eps + precision = torch.float32 + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision) - scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2) + scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision) self.assertTrue(torch.equal(scale, scale_ref)) self.assertTrue(torch.equal(zero_point, zp_ref)) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index febe65e124..93ebef5f39 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -40,6 +40,9 @@ "groupwise_affine_dequantize_tensor_from_qparams", "groupwise_affine_quantize_tensor", "groupwise_affine_dequantize_tensor", + "choose_qparams_affine", + "quantize_affine", + "dequantize_affine", # TODO: need to clean up above functions ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else []) @@ -728,26 +731,18 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float assert groupsize > 1 assert w.shape[-1] % groupsize == 0 assert w.dim() == 2 + assert n_bit <= 8, f"unsupported n_bit: {n_bit}" - to_quant = w.reshape(-1, groupsize) - assert torch.isnan(to_quant).sum() == 0 - - max_val = to_quant.amax(dim=1, keepdim=True) - min_val = to_quant.amin(dim=1, keepdim=True) - min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) - max_val_pos = torch.max(max_val, torch.zeros_like(max_val)) - - max_val_abs = torch.max(-min_val_neg, max_val_pos) - max_int = 2 ** (n_bit - 1) - 1 - min_int = -(2 ** (n_bit - 1)) - - scales = max_val_abs / (float(max_int - min_int) / 2) - scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps)) - # TODO: make sure abs(scales) is not too small? - zeros = torch.full_like(scales, 0) - return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape( - w.shape[0], -1 - ) + block_size = (1, groupsize) + mapping_type = MappingType.SYMMETRIC + eps = torch.finfo(torch.float32).eps + ranges = {} + ranges[1] = (-1, 0) + # generating ranges for bit 2 to 8 + for i in range(2, 9): + ranges[i] = (-(2 ** (i - 1)), 2 ** (i - 1) - 1) + quant_min, quant_max = ranges[n_bit] + return choose_qparams_affine(w, mapping_type, block_size, target_dtype=torch.int8, quant_min=quant_min, quant_max=quant_max, eps=eps, scale_dtype=precision, zero_point_dtype=precision) if TORCH_VERSION_AFTER_2_3: