From f2ca7f20215c66c35059a53a579f1121b18ff161 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Mon, 24 Jun 2024 18:53:25 -0700
Subject: [PATCH] Add decorator for custom op and inductor decomp registration

Summary:
This PR adds a decorator that registers a custom op and also an inductor
decomposition.

The goal is for the torch.export path to be able to see high level ops like
quantize_affine instead of their decompositions, because some backends like
xnnpack want to work with these higher level ops.

This is a redo of https://github.com/pytorch/ao/pull/408; the difference is
that this PR preserves the enums on the python side.

Test Plan:
regression tests:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

also need to check performance with:
python tutorials/quantize_vit/run_vit_b_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/integration/test_integration.py     |  17 +-
 torchao/quantization/quant_primitives.py | 208 +++++++++++++++++++++--
 2 files changed, 204 insertions(+), 21 deletions(-)

diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index b4fbcb152a..116b474f96 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1244,7 +1244,7 @@ def test_autoquant_manual(self, device, dtype):
         out3 = mod(example_input)
         sqnr2 = SQNR(out, out3)
         self.assertTrue(sqnr2 >= 30)
-        
+
     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [
@@ -1376,7 +1376,7 @@ class TestExport(unittest.TestCase):
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
     )
     @run_supported_device_dtype
-    def test_aoti(self, api, test_device, test_dtype):
+    def test_export(self, api, test_device, test_dtype):
         if not TORCH_VERSION_AFTER_2_4:
             self.skipTest("aoti compatibility requires 2.4+.")
@@ -1413,9 +1413,20 @@ def forward(self, x):
 
         # make sure it compiles
         example_inputs = (x,)
-        model = torch.export.export(model, example_inputs).module()
+        from torch._export import capture_pre_autograd_graph
+        # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
+        # we can re-enable this after non-functional IR is enabled in export
+        # model = torch.export.export(model, example_inputs).module()
+        model = capture_pre_autograd_graph(model, example_inputs)
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
+        if api is _int8da_int8w_api:
+            targets = [n.target for n in model.graph.nodes]
+            self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
+            self.assertTrue(torch.ops.quant.quantize_affine.default in targets)
+
+
+
 
 class TestUtils(unittest.TestCase):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index a78c42605a..1c14125d19 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -4,13 +4,16 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from enum import Enum
+from enum import Enum, auto
 from typing import List, Optional, Tuple, Dict
 
 import torch
 from torchao.kernel.intmm import int_scaled_matmul
 from torchao.kernel.intmm import safe_int_mm
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_5,
+)
 
 
 __all__ = [
@@ -34,8 +37,8 @@ class MappingType(Enum):
     based on this mapping
     e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
     """
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
+    SYMMETRIC = auto()
+    ASYMMETRIC = auto()
 
 class ZeroPointDomain(Enum):
     """Enum that indicate whether zero_point is in integer domain or floating point domain
@@ -43,8 +46,8 @@ class ZeroPointDomain(Enum):
     integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
     float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
     """
-    INT = 0
-    FLOAT = 1
+    INT = auto()
+    FLOAT = auto()
 
 """
 Map from dtype to the bound value of integers
@@ -69,6 +72,90 @@ class ZeroPointDomain(Enum):
 })
+
+# def register_custom_op(name: str):
+#     from torch._inductor.decomposition import register_decomposition
+
+#     def decorator(fn):
+#         if TORCH_VERSION_AFTER_2_5:
+#             opdef = torch.library.custom_op(name, mutates_args=())(fn)
+#             opdef.register_fake(fn)
+#             register_decomposition([opdef._opoverload])(fn)
+#             return opdef
+#         else:
+#             return fn
+
+#     return decorator
+
+quant_lib = torch.library.Library("quant", "FRAGMENT")
+
+# def register_custom_op(lib, schema: str):
+#     """This decorator is used to preserve some high level operators for torch.export.export
+#     while still allow them to be decomposed for inductor path
+
+#     NOTE: This should be applied at the top, after all other decorators have been applied
+#     """
+#     from torch._inductor.decomposition import register_decomposition
+
+#     def decorator(fn):
+#         if TORCH_VERSION_AFTER_2_5:
+#             # TODO: change order
+#             lib_namespace = lib.ns
+#             op_name = schema.split("(")[0]
+#             lib.define(schema)
+#             lib.impl(op_name, fn, "CompositeImplicitAutograd")
+#             op = getattr(getattr(torch.ops, lib_namespace), op_name)
+#             register_decomposition([op])(fn)
+#             return op
+#         else:
+#             return fn
+
+#     return decorator
+
+def register_custom_op(lib):
+    """This decorator is used to preserve some high level operators for torch.export.export
+    while still allowing them to be decomposed for the inductor path
+
+    requirement: make sure `fn.__name__[1:]` is the operator name you want to register
+
+    NOTE: This should be applied at the top, after all other decorators have been applied
+    NOTE: We haven't tested the case when `fn` accepts a tensor subclass instance as input,
+    e.g. a uint4 tensor subclass instance, and we'll probably need to figure out what would make
+    sense for downstream systems (like executorch) to accept as well
+
+    Example:
+        lib = torch.library.Library("my_namespace", "FRAGMENT")
+        @register_custom_op(lib)
+        def _the_op_that_needs_to_be_preserved(...):
+            ...
+
+        # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
+        # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
+        # torch.export.export
+
+    """
+    from torch._inductor.decomposition import register_decomposition
+
+    def decorator(fn):
+        if TORCH_VERSION_AFTER_2_5:
+            from torch._library.infer_schema import infer_schema
+
+            # assuming fn.__name__ starts with `_` and we want to take the rest
+            # to be the name of the custom op
+            op_name = fn.__name__[1:]
+            schema = op_name + infer_schema(fn)
+            lib.define(schema)
+            lib.impl(op_name, fn, "CompositeImplicitAutograd")
+
+            lib_namespace = lib.ns
+            op = getattr(getattr(torch.ops, lib_namespace), op_name)
+            register_decomposition([op])(fn)
+            return op
+        else:
+            return fn
+
+    return decorator
+
 
 
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +227,7 @@ def quantize_affine(
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +261,31 @@ def quantize_affine(
     Output:
       quantized tensor with requested dtype
     """
+    return _quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+    )
+
+
+@register_custom_op(quant_lib)
+def _quantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    output_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
     assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
@@ -188,12 +300,12 @@ def quantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         quant = torch.clamp(
             torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         ).to(output_dtype)
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
         min_val = zero_point - scale * mid_point
         quant = (
@@ -216,7 +328,7 @@ def dequantize_affine(
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     *,
     output_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +350,34 @@ def dequantize_affine(
     Output:
       dequantized Tensor, with requested dtype or fp32
     """
+    return _dequantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        input_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+        output_dtype=output_dtype,
+    )
+
+
+# @register_custom_op(quant_lib, 'dequantize_affine(Tensor input, int[] block_size, Tensor scale, Tensor zero_point, ScalarType input_dtype, int? quant_min=None, int? quant_max=None, str zero_point_domain="INT", ScalarType output_dtype=float) -> Tensor')
+@register_custom_op(quant_lib)
+def _dequantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    input_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
 
@@ -255,16 +395,16 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         # Force a copy to avoid input modification due
         # to upcoming in-place operations.
         dequant = input.to(torch.int32, copy=True)
         if zero_point is not None:
-            dequant -= zero_point.to(torch.int32)
+            dequant = dequant - zero_point.to(torch.int32)
         dequant = dequant.to(output_dtype)
-        dequant *= scale
+        dequant = dequant * scale
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}"
         mid_point = (quant_max + quant_min + 1) / 2
         # This should allocate new memory and avoid input modification
         dequant = input - mid_point
@@ -320,8 +460,39 @@ def choose_qparams_affine(
     Output:
         Tuple of scales and zero_points Tensor with requested dtype
     """
+    return _choose_qparams_affine(
+        input,
+        mapping_type.name,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain.name
+    )
+
+# @register_custom_op(quant_lib, 'choose_qparams_affine(Tensor input, str mapping_type, int[] block_size, ScalarType target_dtype, int? quant_min=None, int? quant_max=None, float? eps=None, ScalarType? scale_dtype=None, ScalarType? zero_point_dtype=None, bool preserve_zero=True, str zero_point_domain="INT") -> (Tensor, Tensor)')
+@register_custom_op(quant_lib)
+def _choose_qparams_affine(
+    input: torch.Tensor,
+    mapping_type: str,
+    block_size: List[int],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain: str = "INT",
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """op definition that has compatible signatures with custom op library
+    """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-    assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
+    assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"
 
     if scale_dtype is None:
         scale_dtype = input.dtype
@@ -342,21 +513,22 @@ def choose_qparams_affine(
         min_val_neg = min_val
         max_val_pos = max_val
 
-    if mapping_type == MappingType.SYMMETRIC:
+    if mapping_type == MappingType.SYMMETRIC.name:
         max_val_pos = torch.max(-min_val_neg, max_val_pos)
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        if zero_point_domain != ZeroPointDomain.INT:
+        if zero_point_domain != ZeroPointDomain.INT.name:
             raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
         zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
+        assert mapping_type == MappingType.ASYMMETRIC.name
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
         else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+            assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
             mid_point = (quant_max + quant_min + 1) / 2
             zero_point = min_val_neg + scale * mid_point
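
Usage sketch (not part of the patch): a minimal example of the `register_custom_op`
decorator defined above, following its docstring. The namespace `my_namespace` and the
op `_scale_by_two` are hypothetical names made up for illustration; on torch versions
before 2.5 the decorator is a no-op and the plain Python function is returned unchanged.

    import torch
    from torchao.quantization.quant_primitives import register_custom_op

    # hypothetical fragment library for the example namespace
    my_lib = torch.library.Library("my_namespace", "FRAGMENT")

    @register_custom_op(my_lib)
    def _scale_by_two(x: torch.Tensor) -> torch.Tensor:
        # the leading underscore matters: the registered op name is fn.__name__[1:]
        return x * 2.0

    # on torch 2.5+, `_scale_by_two` now refers to the torch.ops.my_namespace.scale_by_two
    # overload, so torch.export.export keeps it as a single node, while the registered
    # inductor decomposition still lets inductor break it down when compiling
    y = _scale_by_two(torch.randn(4))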