diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py new file mode 100644 index 0000000000..0e50760518 --- /dev/null +++ b/test/quantization/test_observer.py @@ -0,0 +1,39 @@ +import torch +from torch.testing._internal.common_utils import TestCase +from torchao.quantization.observer import ( + AffineQuantizedMinMaxObserver, + PerTensor, + PerAxis, +) +from torchao.quantization.quant_primitives import ( + MappingType, +) +import unittest +# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao +from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver + +class TestQuantFlow(TestCase): + def _test_obs_helper(self, obs1, obs2): + example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)] + for example_input in example_inputs: + obs1(example_input) + obs2(example_input) + + scale1, zero_point1 = obs1.calculate_qparams() + scale2, zero_point2 = obs2.calculate_qparams() + self.assertTrue(torch.allclose(scale1, scale2)) + self.assertTrue(torch.allclose(zero_point1, zero_point2)) + + def test_min_max_per_tensor_affine(self): + obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int) + ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine) + self._test_obs_helper(obs, ref_obs) + + def test_min_max_per_channel_affine(self): + obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int) + ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine) + self._test_obs_helper(obs, ref_obs) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py new file mode 100644 index 0000000000..9cc3eb2ad6 --- /dev/null +++ b/torchao/quantization/observer.py @@ -0,0 +1,154 @@ +import torch +from .quant_primitives import ( + _get_reduction_params, + choose_qparams_affine_with_min_max, + MappingType, + ZeroPointDomain, +) + +from dataclasses import dataclass +from typing import Callable, List, Tuple, Optional +from functools import partial + +@dataclass(frozen=True) +class GranularityType: + pass + +@dataclass(frozen=True) +class PerTensor(GranularityType): + pass + +@dataclass(frozen=True) +class PerAxis(GranularityType): + axis: int + +# borrowed from torch.ao.quantization.observer +class _PartialWrapper: + def __init__(self, p): + self.p = p + + def __call__(self, *args, **keywords): + return self.p(*args, **keywords) + + def __repr__(self): + return self.p.__repr__() + + def with_args(self, *args, **kwargs): + return _with_args(self, *args, **kwargs) + +def _with_args(cls_or_self, *args, **kwargs): + r"""Wrapper that allows creation of class factories. + + This can be useful when there is a need to create classes with the same + constructor arguments, but different instances. + + Example:: + + >>> # xdoctest: +SKIP("Undefined vars") + >>> Foo.with_args = classmethod(_with_args) + >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) + >>> foo_instance1 = foo_builder() + >>> foo_instance2 = foo_builder() + >>> id(foo_instance1) == id(foo_instance2) + False + """ + r = _PartialWrapper(partial(cls_or_self, *args, **kwargs)) + return r + +def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]: + if isinstance(granularity_type, PerTensor): + return input_shape + elif isinstance(granularity_type, PerAxis): + block_size = list(input_shape) + block_size[granularity_type.axis] = 1 + return tuple(block_size) + raise ValueError(f"Unsupported GranularityType: {granularity_type}") + +class AffineQuantizedObserver(torch.nn.Module): + with_args = classmethod(_with_args) + + def __init__(self, + update_stats: Callable[[Callable, torch.Tensor], None], + calculate_qparams: Callable[[Callable], Tuple[torch.Tensor, torch.Tensor]], + mapping_type: MappingType, + target_dtype: torch.dtype, + block_size: Optional[Tuple[int, ...]] = None, + granularity_type: Optional[GranularityType] = None, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain = ZeroPointDomain.INT, + ): + """ + """ + super().__init__() + assert block_size is not None or granularity_type is not None, "Must specify either block_size or granularity_type" + self._update_stats = update_stats + self._calculate_qparams = calculate_qparams + self.mapping_type = mapping_type + self.target_dtype = target_dtype + self.block_size = block_size + self.granularity_type = granularity_type + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.scale_dtype = scale_dtype + self.zero_point_dtype = zero_point_dtype + self.preserve_zero = preserve_zero + self.zero_point_domain = zero_point_domain + + def forward(self, input: torch.Tensor) -> torch.Tensor: + self._update_stats(self, input) + return input + + def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: + return self._calculate_qparams(self) + +def get_min_max_funcs(): + + def update_stats_min_max(self, input: torch.Tensor): + if input.numel() == 0: + return + + input = input.detach() + if self.block_size is None: + self.block_size = get_block_size(input.shape, self.granularity_type) + + shape_for_reduction, reduction_dims = _get_reduction_params(self.block_size, input.size()) + input = input.view(shape_for_reduction) + min_val = torch.amin(input, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + if not hasattr(self, "min_val") or not hasattr(self, "max_val"): + self.min_val = min_val + self.max_val = max_val + else: + min_val = torch.min(self.min_val, min_val) + max_val = torch.max(self.max_val, max_val) + self.min_val.copy_(min_val) + self.max_val.copy_(max_val) + + def calculate_qparams_min_max(self) -> Tuple[torch.Tensor, torch.Tensor]: + assert hasattr(self, "min_val") and hasattr(self, "max_val"), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" + + return choose_qparams_affine_with_min_max( + self.min_val, + self.max_val, + self.mapping_type, + self.block_size, + self.target_dtype, + self.quant_min, + self.quant_max, + self.eps, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain + ) + + return update_stats_min_max, calculate_qparams_min_max + +_update_stats_min_max, _calculate_qparams_min_max = get_min_max_funcs() +AffineQuantizedMinMaxObserver = AffineQuantizedObserver.with_args(_update_stats_min_max, _calculate_qparams_min_max) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 1d958be840..a37c17403c 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -21,6 +21,7 @@ "safe_int_mm", "int_scaled_matmul", "choose_qparams_affine", + "choose_qparams_affine_with_min_max", "quantize_affine", "dequantize_affine", "fake_quantize_affine", @@ -570,9 +571,51 @@ def choose_qparams_affine( zero_point_domain.name ) + +def choose_qparams_affine_with_min_max( + min_val: torch.Tensor, + max_val: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain = ZeroPointDomain.INT, +) -> Tuple[torch.Tensor, torch.Tensor]: + """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine` + operator that pass in min_val and max_val directly instead of deriving these from a single input. + This is used for observers in static quantization where min_val and max_val may be obtained through + tracking all the data in calibration data set. + + Args: + Mostly same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`. with one + difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val + and then scale/zero_point, we pass in min_val/max_val directly + """ + return _choose_qparams_affine( + None, + mapping_type.name, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain.name, + min_val, + max_val, + ) + + @register_custom_op def _choose_qparams_affine( - input: torch.Tensor, + input: Optional[torch.Tensor], mapping_type: str, block_size: List[int], target_dtype: torch.dtype, @@ -583,23 +626,38 @@ def _choose_qparams_affine( zero_point_dtype: Optional[torch.dtype] = None, preserve_zero: bool = True, zero_point_domain: str = "INT", + min_val: Optional[torch.Tensor] = None, + max_val: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """op definition that has compatible signatures with custom op library """ quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max) assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}" - if scale_dtype is None: - scale_dtype = input.dtype - if zero_point_dtype is None: - zero_point_dtype = input.dtype + if input is not None: + if scale_dtype is None: + scale_dtype = input.dtype + if zero_point_dtype is None: + zero_point_dtype = input.dtype + if eps is None: + eps = torch.finfo(input.dtype).eps - assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}" - shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) - input = input.view(shape_for_reduction) + assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}" + shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) + input = input.view(shape_for_reduction) + + min_val = torch.amin(input, dim=reduction_dims, keepdim=False) + max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + else: + assert min_val is not None and max_val is not None, "Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}" + assert min_val.dtype == max_val.dtype, "Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}" - min_val = torch.amin(input, dim=reduction_dims, keepdim=False) - max_val = torch.amax(input, dim=reduction_dims, keepdim=False) + if scale_dtype is None: + scale_dtype = min_val.dtype + if zero_point_dtype is None: + zero_point_dtype = min_val.dtype + if eps is None: + eps = torch.finfo(min_val.dtype).eps if preserve_zero: min_val_neg = torch.min(min_val, torch.zeros_like(min_val)) @@ -615,10 +673,12 @@ def _choose_qparams_affine( raise ValueError("preserve_zero == False is not supported for symmetric quantization") if zero_point_domain != ZeroPointDomain.INT.name: raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization") + scale = torch.clamp(scale, min=eps) zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2)) else: assert mapping_type == MappingType.ASYMMETRIC.name scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min) + scale = torch.clamp(scale, min=eps) if preserve_zero: zero_point = quant_min - torch.round(min_val_neg / scale) zero_point = torch.clamp(zero_point, quant_min, quant_max) @@ -627,8 +687,4 @@ def _choose_qparams_affine( mid_point = (quant_max + quant_min + 1) / 2 zero_point = min_val_neg + scale * mid_point - if eps is None: - eps = torch.finfo(input.dtype).eps - scale = torch.clamp(scale, min=eps) - return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype) diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py index 7911f645e1..2c82ddebb6 100644 --- a/tutorials/calibration_flow/static_quant.py +++ b/tutorials/calibration_flow/static_quant.py @@ -4,8 +4,6 @@ import torch import copy -# TODO: use the generalized observer for affine qunatization in the future -from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver import torch.nn.functional as F from torch import Tensor from torchao.dtypes import to_affine_quantized_static @@ -13,7 +11,14 @@ from torchao.quantization import quantize_ from torchao.quantization import to_linear_activation_quantized from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter - +from torchao.quantization.observer import ( + AffineQuantizedMinMaxObserver, + PerTensor, + PerAxis, +) +from torchao.quantization.quant_primitives import ( + MappingType, +) class ObservedLinear(torch.nn.Linear): @@ -105,16 +110,20 @@ def forward(self, x): x = self.linear2(x) return x +torch.manual_seed(0) + dtype = torch.bfloat16 m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda") + +m_for_test = copy.deepcopy(m) + m_bf16 = copy.deepcopy(m) example_inputs = m.example_inputs(dtype=dtype, device="cuda") m_bf16 = torch.compile(m_bf16, mode='max-autotune') -# TODO: use the generalized observer for affine qunatization in the future -act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda") -weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda") +act_obs = AffineQuantizedMinMaxObserver(mapping_type=MappingType.ASYMMETRIC, target_dtype=torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32) +weight_obs = AffineQuantizedMinMaxObserver(mapping_type=MappingType.ASYMMETRIC, target_dtype=torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32) before_quant = m(*example_inputs)