From 0aa84eca76281e90e78f5b208966506b14bda9da Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Mon, 22 Apr 2024 16:45:05 -0700
Subject: [PATCH] [quant] Add per block quantization primitives

Summary:
We want to use this to replace all q/dq/choose_qparams ops in
https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/quant_primitives.py
and
https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py

Test Plan:
python test/quantization/test_quant_primitives.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/quantization/test_quant_primitives.py | 188 ++++++++++++++-
 torchao/quantization/autoquant.py          |   7 +-
 torchao/quantization/quant_primitives.py   | 251 +++++++++++++++++++++
 3 files changed, 442 insertions(+), 4 deletions(-)
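
Illustrative usage of the new primitives (a minimal sketch based on the tests
in this patch, not applied by the diff below): group-wise int8 quantization
with a group size of 2 composes choose_qparams_affine, quantize_affine and
dequantize_affine as follows.

    import torch
    from torchao.quantization.quant_primitives import (
        MappingType,
        choose_qparams_affine,
        quantize_affine,
        dequantize_affine,
    )

    x = torch.randn(10, 10)
    # one scale/zero_point per block of 2 elements along the last dimension
    block_size = (1, 2)
    scale, zero_point = choose_qparams_affine(
        x, MappingType.SYMMETRIC, block_size, torch.int8,
        eps=torch.finfo(torch.float32).eps,
    )
    xq = quantize_affine(x, block_size, scale, zero_point, torch.int8)
    xdq = dequantize_affine(
        xq, block_size, scale, zero_point, torch.int8, output_dtype=torch.float32
    )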

diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 82533d6e47..2830e1acfa 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -8,8 +8,21 @@
 # This test takes a long time to run
 import unittest
 import torch
-from torchao.quantization.quant_primitives import get_group_qparams_symmetric
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
+from torchao.quantization.quant_primitives import (
+    get_group_qparams_symmetric,
+    quantize_affine,
+    dequantize_affine,
+    choose_qparams_affine,
+    MappingType,
+)
+
+from torchao.quantization.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_4,
+)
+
+_SEED = 1234
+torch.manual_seed(_SEED)
 
 class TestQuantPrimitives(unittest.TestCase):
     SEED = 123
@@ -46,5 +59,176 @@ def test_get_group_qparams_symmetric(self):
         (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
         torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)
 
+    def test_choose_qparams_group_sym(self):
+        """Note: groupwise asymmetric quant is using a different way of computing zero_points, so
+        we don't include it here. We may just replace it with per block quant
+        """
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 2)
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+
+        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2)
+
+        self.assertTrue(torch.equal(scale, scale_ref))
+        self.assertTrue(torch.equal(zero_point, zp_ref))
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
+    def test_choose_qparams_token_asym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 10)
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+
+        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype)
+        scale_ref = scale_ref.squeeze()
+        zp_ref = zp_ref.squeeze()
+
+        torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3)
+        self.assertTrue(torch.equal(zero_point, zp_ref))
+
+    def test_choose_qparams_tensor_asym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (10, 10)
+        eps = torch.finfo(torch.float32).eps
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps)
+
+
+        quant_min = -128
+        quant_max = 127
+        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(input, quant_min, quant_max, eps, dtype)
+        scale_ref = scale_ref.squeeze()
+        zp_ref = zp_ref.squeeze()
+
+        self.assertTrue(torch.equal(scale, scale_ref))
+        self.assertTrue(torch.equal(zero_point, zp_ref))
+
+    def test_choose_qparams_tensor_sym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (10, 10)
+        eps = torch.finfo(torch.float32).eps
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps)
+
+        quant_min = -128
+        quant_max = 127
+        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(input, quant_min, quant_max, eps, dtype)
+        scale_ref = scale_ref.squeeze()
+        zp_ref = zp_ref.squeeze()
+
+        self.assertTrue(torch.equal(scale, scale_ref))
+        self.assertTrue(torch.equal(zero_point, zp_ref))
+
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
+    def test_quantize_dequantize_group_sym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 2)
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+
+        quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
+        dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
+
+        group_size = 2
+        quant_min = -128
+        quant_max = 127
+        quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel_group(
+            input, scale, zero_point, quant_min, quant_max, torch.int8, group_size
+        )
+        dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel_group(
+            quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, group_size, output_dtype=torch.float32
+        )
+
+        self.assertTrue(torch.equal(quantized, quantized_ref))
+        self.assertTrue(torch.equal(dequantized, dequantized_ref))
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_quantize_dequantize_channel_asym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (10, 1)
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        output_dtype = torch.float32
+        quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
+        dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
+
+        axis = 1
+        quant_min = -128
+        quant_max = 127
+        quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel(
+            input, scale, zero_point, axis, quant_min, quant_max, torch.int8
+        )
+        dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel(
+            quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=output_dtype
+        )
+        self.assertTrue(torch.equal(quantized, quantized_ref))
+        self.assertTrue(torch.equal(dequantized, dequantized_ref))
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_quantize_dequantize_tensor_asym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (10, 10)
+        output_dtype = torch.float32
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
+        dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype)
+
+        axis = 1
+        quant_min = -128
+        quant_max = 127
+        quantized_ref = torch.ops.quantized_decomposed.quantize_per_tensor(
+            input, scale, zero_point, quant_min, quant_max, torch.int8
+        )
+        dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_tensor(
+            quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, out_dtype=output_dtype
+        )
+        self.assertTrue(torch.equal(quantized, quantized_ref))
+        self.assertTrue(torch.equal(dequantized, dequantized_ref))
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    def test_quantize_dequantize_channel_asym_4d(self):
+        input = torch.randn(3, 3, 10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (3, 3, 1, 10)
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
+        dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
+
+        axis = 2
+        quant_min = -128
+        quant_max = 127
+        quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel(
+            input, scale, zero_point, axis, quant_min, quant_max, torch.int8
+        )
+        dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel(
+            quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=torch.float32
+        )
+        self.assertTrue(torch.equal(quantized, quantized_ref))
+        self.assertTrue(torch.equal(dequantized, dequantized_ref))
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.3 or lower")
+    def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
+        input = torch.randn(3, 3, 10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (3, 3, 2, 2)
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
+        dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
+        # we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
+        torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)
+
+
 if __name__ == "__main__":
     unittest.main()

diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index f1f387d7b5..9f2b59f20a 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -10,7 +10,11 @@
     safe_int_mm,
 )
 import torch.nn.functional as F
-from torch._inductor.utils import do_bench
+try:
+    from torch._inductor.utils import do_bench
+except:
+    from torch._inductor.runtime.runtime_utils import do_bench
+
 aten = torch.ops.aten
 
 AUTOQUANT_CACHE = {}
@@ -387,4 +391,3 @@ def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filte
     model(*example_input)
     change_autoquantizable_to_quantized(model, **kwargs)
     return model
-

diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 88eafd4b2a..febe65e124 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from enum import Enum
+from typing import List, Optional, Tuple
 import torch
 from torch._dynamo import is_compiling as dynamo_is_compiling
 from torch._higher_order_ops.out_dtype import out_dtype
@@ -41,6 +43,255 @@
 # TODO: need to clean up above functions
 ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])
 
+
+_DTYPE_TO_QVALUE_BOUNDS = {
+    torch.uint8: (0, 255),
+    torch.int8: (-128, 127),
+    torch.int16: (-(2**15), 2**15 - 1),
+    torch.int32: (-(2**31), 2**31 - 1),
+}
+
+if TORCH_VERSION_AFTER_2_3:
+    _DTYPE_TO_QVALUE_BOUNDS.update({
+        torch.uint1: (0, 2**1-1),
+        torch.uint2: (0, 2**2-1),
+        torch.uint3: (0, 2**3-1),
+        torch.uint4: (0, 2**4-1),
+        torch.uint5: (0, 2**5-1),
+        torch.uint6: (0, 2**6-1),
+        torch.uint7: (0, 2**7-1),
+    })
+
+# TODO: decide on if we want to allow custom quant_min/quant_max here
+def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
+    """Get quant_min and quant_max args based on dtype and also
+    verify that they are within the range of possible quant_min/quant_max
+    for dtype
+    """
+    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
+        raise ValueError(f"Unsupported dtype: {dtype}")
+    quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+    if quant_min is None:
+        quant_min = quant_min_lower_bound
+    if quant_max is None:
+        quant_max = quant_max_upper_bound
+
+    assert quant_min >= quant_min_lower_bound, \
+        "quant_min out of bound for dtype, " \
+        f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
+
+    assert quant_max <= quant_max_upper_bound, \
+        "quant_max out of bound for dtype, " \
+        f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
+    return quant_min, quant_max
+
+def _get_reduction_params(block_size, input_size):
+    """Given block_size and input size find the parameters for reduction:
+
+    Output:
+        shape_for_reduction: the shape we use to `view` input to prepare it for reduction
+        reduction_dims: the dims we'll do reduction over
+
+    Example::
+        Input:
+          block_size: (3, 3, 2, 10)
+          input_size: (3, 3, 10, 10)
+
+        Output:
+          shape_for_reduction: (3, 3, 5, 2, 10)
+          reduction_dim: [0, 1, 3, 4]
""" + assert len(block_size) == len(input_size) + shape_for_reduction = [] + reduction_dims = [] + cur_dim = 0 + for i in range(len(block_size)): + if block_size[i] != input_size[i] and block_size[i] > 1: + assert input_size[i] % block_size[i] == 0, f"Expecting input size at {i} dimension: {input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}" + shape_for_reduction.append(input_size[i] // block_size[i]) + shape_for_reduction.append(block_size[i]) + # reduce over the block_size[i] dim + reduction_dims.append(cur_dim + 1) + cur_dim += 2 + else: + # block_size[i] == input_size[i] or block_size[i] == 1 + shape_for_reduction.append(input_size[i]) + # we only need to reduce over the dimension if block_size is greater than 1 + # otherwise it's already the same as reduced dimension + if block_size[i] != 1: + reduction_dims.append(cur_dim) + cur_dim += 1 + return shape_for_reduction, reduction_dims + + +def quantize_affine( + input: torch.Tensor, + block_size: List[int], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + output_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None +): + """ + Args: + input (torch.Tensor): original float32 or bfloat16 Tensor + block_size: (List[int]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + scale (float): quantization parameter for affine quantization + zero_point (int): quantization parameter for affine quantization + output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor + quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype + quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype + + Note: + How can block_size represent different granularities? 
+def quantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    output_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None
+):
+    """
+    Args:
+        input (torch.Tensor): original float32 or bfloat16 Tensor
+        block_size (List[int]): granularity of quantization; this is the size of the block of tensor elements that share the same qparams,
+          e.g. when the block size is the same as the input tensor shape, we are using per tensor quantization
+        scale (Tensor): quantization parameter for affine quantization
+        zero_point (Optional[Tensor]): quantization parameter for affine quantization
+        output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+        quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
+        quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
+
+    Note:
+        How can block_size represent different granularities?
+        Let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different
+        granularities:
+
+        granularity type                        |  block_size
+        per_tensor                              |  (3, 3, 10, 10)
+        per_axis (axis=0)                       |  (1, 3, 10, 10)
+        per_axis (axis=1)                       |  (3, 1, 10, 10)
+        per_group (groupsize=2)                 |  (3, 3, 10, 2)
+        per_group (groupsize=2) for axis = 3    |  (3, 3, 2, 10)
+
+    Output:
+        quantized tensor with requested dtype
+    """
+    # TODO: validations
+    quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
+    shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
+    original_shape = input.shape
+    input = input.view(shape_for_reduction)
+    shape_after_reduction = shape_for_reduction
+    for i in reduction_dims:
+        shape_after_reduction[i] = 1
+    scale = scale.view(shape_after_reduction)
+    if zero_point is not None:
+        zero_point = zero_point.view(shape_after_reduction)
+
+    quant = torch.clamp(
+        torch.round(input / scale) + zero_point, quant_min, quant_max
+    ).to(output_dtype)
+    quant = quant.view(original_shape)
+
+    return quant
+
+def dequantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    input_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    *,
+    output_dtype: Optional[torch.dtype] = None,
+):
+    """
+    Args:
+        input (torch.Tensor): quantized tensor, should match the dtype given by the `input_dtype` argument
+        block_size (List[int]): granularity of quantization; this is the size of the block of tensor elements that share the same qparams,
+          e.g. when the block size is the same as the input tensor shape, we are using per tensor quantization
+        scale (Tensor): quantization parameter for affine quantization
+        zero_point (Tensor): quantization parameter for affine quantization
+        input_dtype (torch.dtype): dtype (e.g. torch.uint8) of the quantized input Tensor
+        quant_min (Optional[int]): minimum quantized value for input Tensor
+        quant_max (Optional[int]): maximum quantized value for input Tensor
+        output_dtype (Optional[torch.dtype]): dtype for output Tensor, default is fp32
+
+    Output:
+        dequantized Tensor, with requested dtype or fp32
+    """
+    # TODO: validations
+    assert input.dtype == input_dtype
+    quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
+
+    shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
+    original_shape = input.shape
+    input = input.view(shape_for_reduction)
+    shape_after_reduction = shape_for_reduction
+    for i in reduction_dims:
+        shape_after_reduction[i] = 1
+    scale = scale.view(shape_after_reduction)
+    if zero_point is not None:
+        zero_point = zero_point.view(shape_after_reduction)
+
+    dequant = input.to(torch.float32)
+    scale = scale.to(torch.float32)
+    if zero_point is not None:
+        zero_point = zero_point.to(torch.float32)
+        dequant -= zero_point
+    dequant *= scale
+    dequant = dequant.view(original_shape)
+    return dequant.to(output_dtype)
+
+
+class MappingType(Enum):
+    SYMMETRIC = 0
+    ASYMMETRIC = 1
+
+def choose_qparams_affine(
+    input: torch.Tensor,
+    mapping_type: MappingType,
+    block_size: List[int],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        input (torch.Tensor): fp32, bf16, fp16 input Tensor
+        mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
+        block_size (List[int]): granularity of quantization; the size of the block of tensor elements that share the same qparams
+        target_dtype (torch.dtype): dtype for target quantized Tensor
+        quant_min (Optional[int]): minimum quantized value for target quantized Tensor
+        quant_max (Optional[int]): maximum quantized value for target quantized Tensor
+        eps (Optional[float]): minimum scale
+        scale_dtype (torch.dtype): dtype for scales
+        zero_point_dtype (torch.dtype): dtype for zero_points
+
+    Output:
+        Tuple of scales and zero_points Tensor with requested dtype
+    """
+    quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
+    if scale_dtype is None:
+        scale_dtype = torch.float32
+    if zero_point_dtype is None:
+        zero_point_dtype = torch.float32
+
+    assert len(block_size) == input.dim()
+    shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
+    input = input.view(shape_for_reduction)
+
+    if mapping_type == MappingType.SYMMETRIC:
+        amax = torch.amax(torch.abs(input), dim=reduction_dims, keepdim=False)
+        scale = amax / (float(quant_max - quant_min) / 2)
+        zero_point = torch.ones_like(scale)
+        zero_point *= int((quant_min + quant_max + 1) / 2)
+    elif mapping_type == MappingType.ASYMMETRIC:
+        min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
+        max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
+
+        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+
+        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+        zero_point = quant_min - torch.round(min_val_neg / scale)
+        zero_point = torch.clamp(zero_point, quant_min, quant_max)
+    else:
+        raise RuntimeError(f"Unsupported mapping type: {mapping_type}")
+
+    if eps is not None:
+        scale = torch.clamp(scale, min=eps)
+
+    return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype)
+
+
 # copy-pasta of https://www.internalfb.com/intern/anp/view/?id=3350736
 def dynamically_quantize_per_tensor(