From a02d06181ac41f2cd09c00e35542957e5818b55c Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Mon, 22 Apr 2024 16:45:05 -0700
Subject: [PATCH] [quant] Add per block quantization primitives

Summary:
We want to use this to replace all q/dq/choose_qparams ops in
https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/quant_primitives.py and
https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py

Test Plan:
python test/quantization/test_quant_primitives.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/quantization/test_quant_primitives.py | 140 ++++++++++++-
 torchao/quantization/autoquant.py          |   7 +-
 torchao/quantization/quant_primitives.py   | 227 +++++++++++++++++++++
 3 files changed, 371 insertions(+), 3 deletions(-)

diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 82533d6e47..9a9649c00b 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -8,7 +8,14 @@
 # This test takes a long time to run
 import unittest
 import torch
-from torchao.quantization.quant_primitives import get_group_qparams_symmetric
+from torchao.quantization.quant_primitives import (
+    get_group_qparams_symmetric,
+    quantize_affine_per_block,
+    dequantize_affine_per_block,
+    choose_qparams_affine_per_block,
+    MappingType,
+)
+
 from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3

 class TestQuantPrimitives(unittest.TestCase):
@@ -46,5 +53,136 @@ def test_get_group_qparams_symmetric(self):
         (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
         torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)

+    def test_choose_qparams_group_sym(self):
+        """Note: groupwise asymmetric quant is using a different way of computing zero_points, so
+        we don't include it here.
+        We may just replace it with per block quant.
+        """
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 2)
+        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+
+        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2)
+
+        self.assertTrue(torch.equal(scale, scale_ref))
+        self.assertTrue(torch.equal(zero_point, zp_ref))
+
+    def test_choose_qparams_token_asym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 10)
+        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+
+        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype)
+        scale_ref = scale_ref.squeeze()
+        zp_ref = zp_ref.squeeze()
+
+        torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3)
+        self.assertTrue(torch.equal(zero_point, zp_ref))
+
+    def test_choose_qparams_tensor_asym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (10, 10)
+        eps = torch.finfo(torch.float32).eps
+        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=eps)
+
+        quant_min = -128
+        quant_max = 127
+        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(input, quant_min, quant_max, eps, dtype)
+        scale_ref = scale_ref.squeeze()
+        zp_ref = zp_ref.squeeze()
+
+        self.assertTrue(torch.equal(scale, scale_ref))
+        self.assertTrue(torch.equal(zero_point, zp_ref))
+
+    def test_choose_qparams_tensor_sym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (10, 10)
+        eps = torch.finfo(torch.float32).eps
+        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=eps)
+
+        quant_min = -128
+        quant_max = 127
+        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(input, quant_min, quant_max, eps, dtype)
+        scale_ref = scale_ref.squeeze()
+        zp_ref = zp_ref.squeeze()
+
+        self.assertTrue(torch.equal(scale, scale_ref))
+        self.assertTrue(torch.equal(zero_point, zp_ref))
+
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is 2.4 or lower")
+    def test_quantize_dequantize_group_sym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 2)
+        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+
+        quantized = quantize_affine_per_block(input, block_size, scale, zero_point, dtype)
+        dequantized = dequantize_affine_per_block(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
+
+        group_size = 2
+        quant_min = -128
+        quant_max = 127
+        quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel_group(
+            input, scale, zero_point, quant_min, quant_max, torch.int8, group_size
+        )
+        dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel_group(
+            quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, group_size, output_dtype=torch.float32
+        )
+
+        self.assertTrue(torch.equal(quantized, quantized_ref))
+        self.assertTrue(torch.equal(dequantized, dequantized_ref))
+
+
+    def test_quantize_dequantize_channel_asym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (10, 1)
+        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        quantized = quantize_affine_per_block(input, block_size, scale, zero_point, dtype)
+        dequantized = dequantize_affine_per_block(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
+
+        axis = 1
+        quant_min = -128
+        quant_max = 127
+        quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel(
+            input, scale, zero_point, axis, quant_min, quant_max, torch.int8
+        )
+        dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel(
+            quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=torch.float32
+        )
+        self.assertTrue(torch.equal(quantized, quantized_ref))
+        self.assertTrue(torch.equal(dequantized, dequantized_ref))
+
+    def test_quantize_dequantize_tensor_asym(self):
+        input = torch.randn(10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (10, 10)
+        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        quantized = quantize_affine_per_block(input, block_size, scale, zero_point, dtype)
+        dequantized = dequantize_affine_per_block(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
+
+        axis = 1
+        quant_min = -128
+        quant_max = 127
+        quantized_ref = torch.ops.quantized_decomposed.quantize_per_tensor(
+            input, scale, zero_point, quant_min, quant_max, torch.int8
+        )
+        dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_tensor(
+            quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, out_dtype=torch.float32
+        )
+        self.assertTrue(torch.equal(quantized, quantized_ref))
+        self.assertTrue(torch.equal(dequantized, dequantized_ref))
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index f1f387d7b5..9f2b59f20a 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -10,7 +10,11 @@
     safe_int_mm,
 )
 import torch.nn.functional as F
-from torch._inductor.utils import do_bench
+try:
+    from torch._inductor.utils import do_bench
+except ImportError:
+    from torch._inductor.runtime.runtime_utils import do_bench
+
 aten = torch.ops.aten

 AUTOQUANT_CACHE = {}
@@ -387,4 +391,3 @@ def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filte
     model(*example_input)
     change_autoquantizable_to_quantized(model, **kwargs)
     return model
-
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 88eafd4b2a..4410bef867 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+from enum import Enum
+from typing import List, Optional, Tuple
 import torch
 from torch._dynamo import is_compiling as dynamo_is_compiling
 from torch._higher_order_ops.out_dtype import out_dtype
@@ -41,6 +43,231 @@
 # TODO: need to clean up above functions
 ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])
+
+_DTYPE_TO_QVALUE_BOUNDS = {
+    torch.uint8: (0, 255),
+    torch.int8: (-128, 127),
+    torch.int16: (-(2**15), 2**15 - 1),
+    torch.int32: (-(2**31), 2**31 - 1),
+}
+
+if TORCH_VERSION_AFTER_2_3:
+    _DTYPE_TO_QVALUE_BOUNDS.update({
+        torch.uint1: (0, 2**1-1),
+        torch.uint2: (0, 2**2-1),
+        torch.uint3: (0, 2**3-1),
+        torch.uint4: (0, 2**4-1),
+        torch.uint5: (0, 2**5-1),
+        torch.uint6: (0, 2**6-1),
+        torch.uint7: (0, 2**7-1),
+    })
+
+def _get_qmin_qmax(dtype, quant_min, quant_max):
+    assert dtype in _DTYPE_TO_QVALUE_BOUNDS
+    if quant_min is None:
+        quant_min = _DTYPE_TO_QVALUE_BOUNDS[dtype][0]
+    if quant_max is None:
+        quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype][1]
+    return quant_min, quant_max
+
+def _get_reduction_params(block_size, input_size):
+    """Given block_size and input size find the parameters for reduction:
+
+    Output:
+        shape_for_reduction: the shape we use to `view` input to prepare it for reduction
+        reduction_dims: the dims we'll do reduction over
+    """
+    assert len(block_size) == len(input_size)
+    shape_for_reduction = []
+    reduction_dims = []
+    cur_dim = 0
+    for i in range(len(block_size)):
+        if block_size[i] != input_size[i] and block_size[i] > 1:
+            assert input_size[i] % block_size[i] == 0, f"Expecting input size at {i} dimension: {input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}"
+            shape_for_reduction.append(input_size[i] // block_size[i])
+            shape_for_reduction.append(block_size[i])
+            # reduce over the block_size[i] dim
+            reduction_dims.append(cur_dim + 1)
+            cur_dim += 2
+        elif block_size[i] == 1:
+            # shape does not change if block_size[i] == 1, and we don't reduce over this dim
+            shape_for_reduction.append(input_size[i])
+            cur_dim += 1
+        else:
+            shape_for_reduction.append(block_size[i])
+            reduction_dims.append(cur_dim)
+            cur_dim += 1
+    return shape_for_reduction, reduction_dims
+
+
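# Illustrative walk-through (an editorial example, not part of the patch): for the groupwise
# setting used in the tests above, input_size = (10, 10) and block_size = (1, 2):
#   dim 0: block_size 1 -> keep the dim as-is, no reduction        -> shape so far [10]
#   dim 1: block_size 2 -> split 10 into 5 blocks of 2, reduce over the block dim -> [10, 5, 2]
# so _get_reduction_params returns shape_for_reduction = [10, 5, 2] and reduction_dims = [2];
# the qparams computed over those blocks have shape (10, 5) and are later viewed as
# (10, 5, 1) so that they broadcast against the reshaped input.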
+def quantize_affine_per_block(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None
+):
+    """
+    Args:
+      input (torch.Tensor): original float32 or bfloat16 Tensor
+      block_size (List[int]): granularity of quantization, i.e. the size of the block of tensor elements that share the same qparams,
+        e.g. when the block size is the same as the input tensor shape, we are doing per tensor quantization
+      scale (torch.Tensor): quantization parameter for affine quantization
+      zero_point (Optional[torch.Tensor]): quantization parameter for affine quantization
+      dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
+      quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
+      quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
+
+    Note:
+      How can block_size represent different granularities?
+      Let's say we have a Tensor of size (3, 3, 10, 10); here is the table showing how block_size represents different
+      granularities:
+
+        granularity type                    |   block_size
+        per_tensor                          |   (3, 3, 10, 10)
+        per_axis (axis=0)                   |   (1, 3, 10, 10)
+        per_axis (axis=1)                   |   (3, 1, 10, 10)
+        per_group (groupsize=2)             |   (3, 3, 10, 2)
+        per_group (groupsize=2) for axis=2  |   (3, 3, 2, 10)
+
+    Output:
+      quantized tensor with requested dtype
+    """
+    # TODO: validations
+    quant_min, quant_max = _get_qmin_qmax(dtype, quant_min, quant_max)
+    shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
+    original_shape = input.shape
+    input = input.view(shape_for_reduction)
+    shape_after_reduction = shape_for_reduction
+    for i in reduction_dims:
+        shape_after_reduction[i] = 1
+    scale = scale.view(shape_after_reduction)
+    if zero_point is not None:
+        zero_point = zero_point.view(shape_after_reduction)
+
+    quant = torch.clamp(
+        torch.round(input / scale) + zero_point, quant_min, quant_max
+    ).to(dtype)
+    quant = quant.view(original_shape)
+
+    return quant
+
+def dequantize_affine_per_block(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    *,
+    output_dtype: Optional[torch.dtype] = None,
+):
+    """
+    Args:
+      input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
+      block_size (List[int]): granularity of quantization, i.e. the size of the block of tensor elements that share the same qparams,
+        e.g. when the block size is the same as the input tensor shape, we are doing per tensor quantization
+      scale (torch.Tensor): quantization parameter for affine quantization
+      zero_point (Optional[torch.Tensor]): quantization parameter for affine quantization
+      dtype (torch.dtype): dtype (e.g. torch.uint8) of the quantized `input` Tensor
+      quant_min (Optional[int]): minimum quantized value for the input Tensor
+      quant_max (Optional[int]): maximum quantized value for the input Tensor
+      output_dtype (Optional[torch.dtype]): dtype for the output Tensor, default is fp32
+
+    Output:
+      dequantized Tensor, with requested dtype or fp32
+    """
+    # TODO: validations
+    assert input.dtype == dtype
+    quant_min, quant_max = _get_qmin_qmax(dtype, quant_min, quant_max)
+
+    shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
+    original_shape = input.shape
+    input = input.view(shape_for_reduction)
+    shape_after_reduction = shape_for_reduction
+    for i in reduction_dims:
+        shape_after_reduction[i] = 1
+    scale = scale.view(shape_after_reduction)
+    if zero_point is not None:
+        zero_point = zero_point.view(shape_after_reduction)
+
+    dequant = input.to(torch.float32)
+    scale = scale.to(torch.float32)
+    if zero_point is not None:
+        zero_point = zero_point.to(torch.float32)
+        dequant -= zero_point
+    dequant *= scale
+    dequant = dequant.view(original_shape)
+    return dequant.to(output_dtype)
+
+
+class MappingType(Enum):
+    SYMMETRIC = 0
+    ASYMMETRIC = 1
+
+def choose_qparams_affine_per_block(
+    input: torch.Tensor,
+    mapping_type: MappingType,
+    block_size: List[int],
+    dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Args:
+      input (torch.Tensor): fp32, bf16, fp16 input Tensor
+      mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
+      block_size (List[int]): granularity of quantization, i.e. the size of the block of tensor elements that share the same qparams
+      dtype (torch.dtype): dtype for target quantized Tensor
+      quant_min (Optional[int]): minimum quantized value for target quantized Tensor
+      quant_max (Optional[int]): maximum quantized value for target quantized Tensor
+      eps (Optional[float]): minimum scale
+      scale_dtype (torch.dtype): dtype for scales
+      zero_point_dtype (torch.dtype): dtype for zero_points
+
+    Output:
+      Tuple of scale and zero_point Tensors with the requested dtypes
+    """
+    quant_min, quant_max = _get_qmin_qmax(dtype, quant_min, quant_max)
+    if scale_dtype is None:
+        scale_dtype = torch.float32
+    if zero_point_dtype is None:
+        zero_point_dtype = torch.float32
+
+    assert len(block_size) == input.dim()
+    shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
+    input = input.view(shape_for_reduction)
+
+    if mapping_type == MappingType.SYMMETRIC:
+        amax = torch.amax(torch.abs(input), dim=reduction_dims, keepdim=False)
+        scale = amax / (float(quant_max - quant_min) / 2)
+        zero_point = torch.ones_like(scale)
+        zero_point *= int((quant_min + quant_max + 1) / 2)
+    elif mapping_type == MappingType.ASYMMETRIC:
+        min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
+        max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
+
+        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+
+        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+        zero_point = quant_min - torch.round(min_val_neg / scale)
+        zero_point = torch.clamp(zero_point, quant_min, quant_max)
+    else:
+        raise RuntimeError(f"Unsupported mapping type: {mapping_type}")
+
+    if eps is not None:
+        scale = torch.clamp(scale, min=eps)
+
+    return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype)
+
+
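# Illustrative numbers (an editorial example, not part of the patch), assuming dtype=torch.int8
# so that (quant_min, quant_max) = (-128, 127):
#   SYMMETRIC:  a block with amax = 1.0 gets scale = 1.0 / 127.5 ~= 0.00784 and
#               zero_point = int((-128 + 127 + 1) / 2) = 0
#   ASYMMETRIC: a block with (min, max) = (-1.0, 3.0) gets scale = 4.0 / 255 ~= 0.0157 and
#               zero_point = -128 - round(-1.0 / 0.0157) ~= -64, clamped to [-128, 127]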
 # copy-pasta of https://www.internalfb.com/intern/anp/view/?id=3350736
 def dynamically_quantize_per_tensor(
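For reference, the three new primitives compose as follows. This is a minimal usage sketch, assuming the patch above has been applied and the functions are importable from torchao.quantization.quant_primitives; the weight shape and group size are arbitrary example values.

import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine_per_block,
    quantize_affine_per_block,
    dequantize_affine_per_block,
)

weight = torch.randn(128, 256)
block_size = (1, 32)  # groupwise: 32 consecutive elements per row share one scale/zero_point
dtype = torch.int8

scale, zero_point = choose_qparams_affine_per_block(
    weight, MappingType.SYMMETRIC, block_size, dtype, eps=torch.finfo(torch.float32).eps
)
q = quantize_affine_per_block(weight, block_size, scale, zero_point, dtype)
dq = dequantize_affine_per_block(q, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

# dq approximates weight; the elementwise error is bounded by roughly scale / 2 per block
max_err = (weight - dq).abs().max()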