diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index b4fbcb152a..116b474f96 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1244,7 +1244,7 @@ def test_autoquant_manual(self, device, dtype):
         out3 = mod(example_input)
         sqnr2 = SQNR(out, out3)
         self.assertTrue(sqnr2 >= 30)
-        
+
     @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE,
         [
@@ -1376,7 +1376,7 @@ class TestExport(unittest.TestCase):
         list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
     )
     @run_supported_device_dtype
-    def test_aoti(self, api, test_device, test_dtype):
+    def test_export(self, api, test_device, test_dtype):
         if not TORCH_VERSION_AFTER_2_4:
             self.skipTest("aoti compatibility requires 2.4+.")
@@ -1413,9 +1413,20 @@ def forward(self, x):
 
         # make sure it compiles
         example_inputs = (x,)
-        model = torch.export.export(model, example_inputs).module()
+        from torch._export import capture_pre_autograd_graph
+        # TODO: export changes numerics right now, this is because of functionalization according to Zhengxu
+        # we can re-enable this after non-functional IR is enabled in export
+        # model = torch.export.export(model, example_inputs).module()
+        model = capture_pre_autograd_graph(model, example_inputs)
         after_export = model(x)
         self.assertTrue(torch.equal(after_export, ref))
+        if api is _int8da_int8w_api:
+            targets = [n.target for n in model.graph.nodes]
+            self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
+            self.assertTrue(torch.ops.quant.quantize_affine.default in targets)
+
+
+
 
 class TestUtils(unittest.TestCase):
     @parameterized.expand(COMMON_DEVICE_DTYPE)
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index a78c42605a..39f76928c4 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -4,13 +4,17 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from enum import Enum
+from enum import Enum, auto
 from typing import List, Optional, Tuple, Dict
 
 import torch
 from torchao.kernel.intmm import int_scaled_matmul
 from torchao.kernel.intmm import safe_int_mm
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_5,
+)
+from torchao.utils import _register_custom_op
 
 
 __all__ = [
@@ -34,8 +38,8 @@ class MappingType(Enum):
     based on this mapping e.g.
     scale = (10.2 - (-3.5)) / (7 - (-8))
     """
-    SYMMETRIC = 0
-    ASYMMETRIC = 1
+    SYMMETRIC = auto()
+    ASYMMETRIC = auto()
 
 class ZeroPointDomain(Enum):
     """Enum that indicate whether zero_point is in integer domain or floating point domain
@@ -43,8 +47,8 @@ class ZeroPointDomain(Enum):
     integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
     float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
     """
-    INT = 0
-    FLOAT = 1
+    INT = auto()
+    FLOAT = auto()
 
 """
 Map from dtype to the bound value of integers
@@ -69,6 +73,10 @@ class ZeroPointDomain(Enum):
 })
 
+quant_lib = torch.library.Library("quant", "FRAGMENT")
+
+register_custom_op = _register_custom_op(quant_lib)
+
 # TODO: decide on if we want to allow custom quant_min/quant_max here
 def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
     """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +148,7 @@ def quantize_affine(
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +182,31 @@ def quantize_affine(
     Output:
       quantized tensor with requested dtype
     """
+    return _quantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        output_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+    )
+
+
+@register_custom_op
+def _quantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    output_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
     assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
@@ -188,12 +221,12 @@ def quantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         quant = torch.clamp(
             torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
         ).to(output_dtype)
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
         min_val = zero_point - scale * mid_point
         quant = (
@@ -216,7 +249,7 @@ def dequantize_affine(
     zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
     *,
     output_dtype: torch.dtype = torch.float32,
-):
+) -> torch.Tensor:
     """
     Args:
       input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +271,32 @@ def dequantize_affine(
     Output:
       dequantized Tensor, with requested dtype or fp32
     """
+    return _dequantize_affine(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        input_dtype,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+        output_dtype=output_dtype,
+    )
+
+@register_custom_op
+def _dequantize_affine(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    input_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: str = "INT",
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
+    """op definition that has compatible signatures with custom op library
+    """
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +314,16 @@ def dequantize_affine(
     if zero_point is not None:
         zero_point = zero_point.view(shape_after_reduction)
 
-    if zero_point_domain == ZeroPointDomain.INT:
+    if zero_point_domain == ZeroPointDomain.INT.name:
         # Force a copy to avoid input modification due
         # to upcoming in-place operations.
         dequant = input.to(torch.int32, copy=True)
         if zero_point is not None:
-            dequant -= zero_point.to(torch.int32)
+            dequant = dequant - zero_point.to(torch.int32)
         dequant = dequant.to(output_dtype)
-        dequant *= scale
+        dequant = dequant * scale
     else:
-        assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}"
+        assert zero_point_domain == ZeroPointDomain.FLOAT.name, f"Unexpected zero point domain: {zero_point_domain}"
         mid_point = (quant_max + quant_min + 1) / 2
         # This should allocate new memory and avoid input modification
         dequant = input - mid_point
@@ -320,8 +379,38 @@ def choose_qparams_affine(
     Output:
       Tuple of scales and zero_points Tensor with requested dtype
     """
+    return _choose_qparams_affine(
+        input,
+        mapping_type.name,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain.name
+    )
+
+@register_custom_op
+def _choose_qparams_affine(
+    input: torch.Tensor,
+    mapping_type: str,
+    block_size: List[int],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain: str = "INT",
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """op definition that has compatible signatures with custom op library
+    """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
-    assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"
+    assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"
 
     if scale_dtype is None:
         scale_dtype = input.dtype
@@ -342,21 +431,22 @@ def choose_qparams_affine(
     min_val_neg = min_val
     max_val_pos = max_val
 
-    if mapping_type == MappingType.SYMMETRIC:
+    if mapping_type == MappingType.SYMMETRIC.name:
         max_val_pos = torch.max(-min_val_neg, max_val_pos)
         scale = max_val_pos / (float(quant_max - quant_min) / 2)
         if not preserve_zero:
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
-        if zero_point_domain != ZeroPointDomain.INT:
+        if zero_point_domain != ZeroPointDomain.INT.name:
             raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
         zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
+        assert mapping_type == MappingType.ASYMMETRIC.name
        scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
         else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT, "if not preserve_zero, zero_point must be in FLOAT domain"
+            assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
             mid_point = (quant_max + quant_min + 1) / 2
             zero_point = min_val_neg + scale * mid_point
 
diff --git a/torchao/utils.py b/torchao/utils.py
index cc18b5b458..7741328b48 100644
--- a/torchao/utils.py
+++ b/torchao/utils.py
@@ -13,6 +13,7 @@
     "skip_if_compute_capability_less_than",
     "benchmark_torch_function_in_microseconds",
     "find_multiple",
+    "_register_custom_op",
     "get_model_size_in_bytes",
     "unwrap_tensor_subclass",
     "TORCH_VERSION_AFTER_2_2",
@@ -65,7 +66,7 @@ def wrapper(*args, **kwargs):
 def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
     import torch.utils.benchmark as benchmark # this avoids importing numpy when torchao module is loaded
-    
+
     # Manual warmup
     f(*args, **kwargs)
     f(*args, **kwargs)
@@ -84,6 +85,54 @@ def find_multiple(n: int, *args: Tuple[int]) -> int:
         return n
     return n + k - (n % k)
 
+def _register_custom_op(lib):
+    """This decorator is used to preserve some high-level operators for torch.export.export
+    while still allowing them to be decomposed for the inductor path
+
+    requirement: make sure `fn.__name__[1:]` is the operator name you want to register
+
+    NOTE: This should be applied at the top, after all other decorators have been applied
+    NOTE: We haven't tested the case when `fn` accepts a tensor subclass instance as input,
+    e.g. a uint4 tensor subclass instance, and we'll probably need to figure out what would make
+    sense for downstream systems (like executorch) to accept as well
+
+    Example:
+        lib = torch.library.Library("my_namespace", "FRAGMENT")
+
+        register_custom_op = _register_custom_op(lib)
+
+        @register_custom_op
+        def _the_op_that_needs_to_be_preserved(...):
+            ...
+
+        # after this, `_the_op_that_needs_to_be_preserved` will be preserved as the
+        # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
+        # torch.export.export / torch._export.capture_pre_autograd_graph
+
+    """
+    from torch._inductor.decomposition import register_decomposition
+
+    def decorator(fn):
+        if TORCH_VERSION_AFTER_2_5:
+            from torch._library.infer_schema import infer_schema
+
+            # expecting fn.__name__ to start with `_`; we take the rest
+            # to be the name of the custom op
+            assert fn.__name__[0] == "_", f"Expecting function name to start with `_`, got {fn.__name__}"
+            op_name = fn.__name__[1:]
+            schema = op_name + infer_schema(fn)
+            lib.define(schema)
+            lib.impl(op_name, fn, "CompositeImplicitAutograd")
+
+            lib_namespace = lib.ns
+            op = getattr(getattr(torch.ops, lib_namespace), op_name)
+            register_decomposition([op])(fn)
+            return op
+        else:
+            return fn
+
+    return decorator
+
 def get_model_size_in_bytes(model, ignore_embeddings=False):
     """
     Returns the model size in bytes. The option to ignore embeddings
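
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the new _register_custom_op helper is meant to be used with a fragment library, and how a preserved op can be checked in a captured graph, mirroring the updated test_export. The "my_namespace" library, the _scale_and_round op, and the M module are hypothetical names invented for this example; it assumes PyTorch 2.5+ so that the schema can be inferred (on older versions the decorator leaves the function unchanged).

    # hypothetical example, assuming torch >= 2.5 and the patched torchao.utils
    import torch
    from torch._export import capture_pre_autograd_graph
    from torchao.utils import _register_custom_op

    my_lib = torch.library.Library("my_namespace", "FRAGMENT")
    register_custom_op = _register_custom_op(my_lib)

    @register_custom_op
    def _scale_and_round(x: torch.Tensor, scale: float) -> torch.Tensor:
        # this body becomes the CompositeImplicitAutograd impl and the inductor
        # decomposition; the op itself stays visible to export
        return torch.round(x * scale)

    class M(torch.nn.Module):
        def forward(self, x):
            # _scale_and_round is now the torch.ops.my_namespace.scale_and_round packet
            return _scale_and_round(x, 4.0)

    example_inputs = (torch.randn(3),)
    m = capture_pre_autograd_graph(M(), example_inputs)
    targets = [n.target for n in m.graph.nodes]
    # the high-level op is preserved in the captured graph, as in test_export
    assert torch.ops.my_namespace.scale_and_round.default in targets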