From 518621414a7720bcb2f3ffa3f418570fa13c732b Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Mon, 24 Jun 2024 18:53:25 -0700
Subject: [PATCH] Add decorator for custom op and inductor decomp registration

Summary:
This PR adds a decorator that registers a custom op and an inductor
decomposition for it at the same time. The goal is for the torch.export
path to see high-level ops like quantize_affine instead of their
decompositions, since some backends (e.g. XNNPACK) want to work with
these higher-level ops.

This is a redo of https://github.com/pytorch/ao/pull/408; the difference
is that this PR preserves the enums on the Python side.

Test Plan:
regression tests:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

also check performance with:
python tutorials/quantize_vit/run_vit_b_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/integration/test_integration.py     |   7 +-
 torchao/quantization/quant_primitives.py | 137 ++++++++++++++++++++---
 2 files changed, 123 insertions(+), 21 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index b4fbcb152a..2288427043 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1244,7 +1244,7 @@ def test_autoquant_manual(self, device, dtype):
         out3 = mod(example_input)
         sqnr2 = SQNR(out, out3)
         self.assertTrue(sqnr2 >= 30)
-
+
     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [
@@ -1375,8 +1375,8 @@ class TestExport(unittest.TestCase):
     @parameterized.expand(
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
     )
-    @run_supported_device_dtype
-    def test_aoti(self, api, test_device, test_dtype):
+    # @run_supported_device_dtype
+    def test_export(self, api, test_device, test_dtype):
         if not TORCH_VERSION_AFTER_2_4:
             self.skipTest("aoti compatibility requires 2.4+.")
@@ -1416,6 +1416,7 @@ def forward(self, x):
         model = torch.export.export(model, example_inputs).module()
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
+        print("model:", model)
 
 class TestUtils(unittest.TestCase):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index a78c42605a..0a015d9f63 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -4,13 +4,16 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from enum import Enum
+from enum import Enum, auto
 from typing import List, Optional, Tuple, Dict
 
 import torch
 from torchao.kernel.intmm import int_scaled_matmul
 from torchao.kernel.intmm import safe_int_mm
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_5,
+)
 
 
 __all__ = [
@@ -34,8 +37,8 @@ class MappingType(Enum):
     based on this mapping e.g.
     scale = (10.2 - (-3.5)) / (7 - (-8))
     """
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
+    SYMMETRIC = auto()
+    ASYMMETRIC = auto()
 
 class ZeroPointDomain(Enum):
     """Enum that indicate whether zero_point is in integer domain or floating point domain
@@ -43,8 +46,8 @@ class ZeroPointDomain(Enum):
     integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
     float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
     """
-    INT = 0
-    FLOAT = 1
+    INT = auto()
+    FLOAT = auto()
 
 """
 Map from dtype to the bound value of integers
@@ -69,6 +72,20 @@ class ZeroPointDomain(Enum):
 })
 
 
+def register_custom_op(name: str):
+    from torch._inductor.decomposition import register_decomposition
+
+    def decorator(fn):
+        if TORCH_VERSION_AFTER_2_5:
+            opdef = torch.library.custom_op(name, mutates_args=())(fn)
+            opdef.register_fake(fn)
+            register_decomposition([opdef._opoverload])(fn)
+            return opdef
+        else:
+            return fn
+
+    return decorator
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +157,7 @@ def quantize_affine(
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +191,31 @@
     Output:
       quantized tensor with requested dtype
     """
+    return _quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+    )
+
+
+@register_custom_op("quant::quantize_affine")
+def _quantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    output_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+) -> torch.Tensor:
+    """Op definition that has a signature compatible with the custom op library.
+    """
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
     assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
@@ -188,12 +230,12 @@ def quantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         quant = torch.clamp(
             torch.round(input * (1.0 / scale)) + zero_point,
             quant_min, quant_max
         ).to(output_dtype)
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
         min_val = zero_point - scale * mid_point
         quant = (
@@ -216,7 +258,7 @@ def dequantize_affine(
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     *,
     output_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +280,34 @@
     Output:
       dequantized Tensor, with requested dtype or fp32
     """
+    return _dequantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        input_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+        output_dtype=output_dtype,
+    )
+
+
+@register_custom_op("quant::dequantize_affine")
+def _dequantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    input_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+    *,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """Op definition that has a signature compatible with the custom op library.
+    """
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +325,16 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         # Force a copy to avoid input modification due
         # to upcoming in-place operations.
         dequant = input.to(torch.int32, copy=True)
         if zero_point is not None:
-            dequant -= zero_point.to(torch.int32)
+            dequant = dequant - zero_point.to(torch.int32)
         dequant = dequant.to(output_dtype)
-        dequant *= scale
+        dequant = dequant * scale
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}"
         mid_point = (quant_max + quant_min + 1) / 2
         # This should allocate new memory and avoid input modification
         dequant = input - mid_point
@@ -320,8 +390,38 @@ def choose_qparams_affine(
     Output:
         Tuple of scales and zero_points Tensor with requested dtype
     """
+    return _choose_qparams_affine(
+        input,
+        mapping_type.name,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain.name
+    )
+
+
+@register_custom_op("quant::choose_qparams_affine")
+def _choose_qparams_affine(
+    input: torch.Tensor,
+    mapping_type: str,
+    block_size: List[int],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain: str = "INT",
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Op definition that has a signature compatible with the custom op library.
+    """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-    assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
+    assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"
 
     if scale_dtype is None:
         scale_dtype = input.dtype
@@ -342,21 +442,22 @@ def choose_qparams_affine(
     min_val_neg = min_val
     max_val_pos = max_val
 
-    if mapping_type == MappingType.SYMMETRIC:
+    if mapping_type == MappingType.SYMMETRIC.name:
         max_val_pos = torch.max(-min_val_neg, max_val_pos)
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        if zero_point_domain != ZeroPointDomain.INT:
+        if zero_point_domain != ZeroPointDomain.INT.name:
             raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
         zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
+        assert mapping_type == MappingType.ASYMMETRIC.name
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
         else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+            assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
             mid_point = (quant_max + quant_min + 1) / 2
             zero_point = min_val_neg + scale * mid_point
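
Usage sketch (not part of the patch): a minimal example of what the registered
ops are expected to look like under torch.export once this patch is applied,
mirroring the test_export test above. It assumes torch >= 2.5, so that the
TORCH_VERSION_AFTER_2_5 branch of register_custom_op is taken; the module M,
the input shape, and the per-row block_size are illustrative choices only.

    import torch
    from torchao.quantization.quant_primitives import (
        MappingType,
        choose_qparams_affine,
        dequantize_affine,
        quantize_affine,
    )

    class M(torch.nn.Module):
        def forward(self, x):
            # Treat each row of x as one quantization block.
            block_size = (1, x.shape[-1])
            scale, zero_point = choose_qparams_affine(
                x, MappingType.ASYMMETRIC, block_size, torch.int8,
                eps=torch.finfo(torch.float32).eps,
            )
            q = quantize_affine(x, block_size, scale, zero_point, torch.int8)
            return dequantize_affine(
                q, block_size, scale, zero_point, torch.int8,
                output_dtype=x.dtype,
            )

    example_inputs = (torch.randn(4, 8),)
    m = torch.export.export(M(), example_inputs).module()
    # With the custom ops registered, the exported graph should contain
    # torch.ops.quant.quantize_affine / dequantize_affine /
    # choose_qparams_affine as single nodes instead of their aten-level
    # decompositions; the enum arguments are converted to strings
    # (e.g. "INT") at the op boundary by the thin public wrappers.
    print(m)

On torch < 2.5 register_custom_op is a no-op and the exported graph shows the
decomposed aten ops instead. The inductor decomposition registered by the
decorator is what lets torch.compile still break these ops down for codegen,
which is why the test plan checks performance with run_vit_b_quant.py.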