-
Notifications
You must be signed in to change notification settings - Fork 191
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: In our static_quant flow tutorial we were still using observers from `torch.ao` which we plan to deprecate, this PR adds a more general observer for `AffineQuantizedTensor`, and has shown that we can replace the old observers (min max observer), there could be futhre work to improve perf, add new types of observation, e.g. tracking stats other than just min/max, moving average observer, histogram observer. Test Plan: python test/quantization/test_observer.py python tutorials/calibration_flow/static_quant.py Reviewers: Subscribers: Tasks: Tags:
- Loading branch information
1 parent
433cd14
commit afe8b62
Showing
4 changed files
with
278 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
from torch.testing._internal.common_utils import TestCase | ||
from torchao.quantization.observer import ( | ||
AffineQuantizedMinMaxObserver, | ||
PerTensor, | ||
PerAxis, | ||
) | ||
from torchao.quantization.quant_primitives import ( | ||
MappingType, | ||
) | ||
import unittest | ||
# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao | ||
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver | ||
|
||
class TestQuantFlow(TestCase): | ||
def _test_obs_helper(self, obs1, obs2): | ||
example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)] | ||
for example_input in example_inputs: | ||
obs1(example_input) | ||
obs2(example_input) | ||
|
||
scale1, zero_point1 = obs1.calculate_qparams() | ||
scale2, zero_point2 = obs2.calculate_qparams() | ||
self.assertTrue(torch.allclose(scale1, scale2)) | ||
self.assertTrue(torch.allclose(zero_point1, zero_point2)) | ||
|
||
def test_min_max_per_tensor_affine(self): | ||
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int) | ||
ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine) | ||
self._test_obs_helper(obs, ref_obs) | ||
|
||
def test_min_max_per_channel_affine(self): | ||
obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int) | ||
ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine) | ||
self._test_obs_helper(obs, ref_obs) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import torch | ||
from .quant_primitives import ( | ||
_get_reduction_params, | ||
choose_qparams_affine_with_min_max, | ||
MappingType, | ||
ZeroPointDomain, | ||
) | ||
|
||
from dataclasses import dataclass | ||
from typing import Callable, List, Tuple, Optional | ||
from functools import partial | ||
|
||
@dataclass(frozen=True) | ||
class GranularityType: | ||
pass | ||
|
||
@dataclass(frozen=True) | ||
class PerTensor(GranularityType): | ||
pass | ||
|
||
@dataclass(frozen=True) | ||
class PerAxis(GranularityType): | ||
axis: int | ||
|
||
# borrowed from torch.ao.quantization.observer | ||
class _PartialWrapper: | ||
def __init__(self, p): | ||
self.p = p | ||
|
||
def __call__(self, *args, **keywords): | ||
return self.p(*args, **keywords) | ||
|
||
def __repr__(self): | ||
return self.p.__repr__() | ||
|
||
def with_args(self, *args, **kwargs): | ||
return _with_args(self, *args, **kwargs) | ||
|
||
def _with_args(cls_or_self, *args, **kwargs): | ||
r"""Wrapper that allows creation of class factories. | ||
This can be useful when there is a need to create classes with the same | ||
constructor arguments, but different instances. | ||
Example:: | ||
>>> # xdoctest: +SKIP("Undefined vars") | ||
>>> Foo.with_args = classmethod(_with_args) | ||
>>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42) | ||
>>> foo_instance1 = foo_builder() | ||
>>> foo_instance2 = foo_builder() | ||
>>> id(foo_instance1) == id(foo_instance2) | ||
False | ||
""" | ||
r = _PartialWrapper(partial(cls_or_self, *args, **kwargs)) | ||
return r | ||
|
||
def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]: | ||
if isinstance(granularity_type, PerTensor): | ||
return input_shape | ||
elif isinstance(granularity_type, PerAxis): | ||
block_size = list(input_shape) | ||
block_size[granularity_type.axis] = 1 | ||
return tuple(block_size) | ||
raise ValueError(f"Unsupported GranularityType: {granularity_type}") | ||
|
||
class AffineQuantizedObserver(torch.nn.Module): | ||
with_args = classmethod(_with_args) | ||
|
||
def __init__(self, | ||
update_stats: Callable[[Callable, torch.Tensor], None], | ||
calculate_qparams: Callable[[Callable], Tuple[torch.Tensor, torch.Tensor]], | ||
mapping_type: MappingType, | ||
target_dtype: torch.dtype, | ||
block_size: Optional[Tuple[int, ...]] = None, | ||
granularity_type: Optional[GranularityType] = None, | ||
quant_min: Optional[int] = None, | ||
quant_max: Optional[int] = None, | ||
eps: Optional[float] = None, | ||
scale_dtype: Optional[torch.dtype] = None, | ||
zero_point_dtype: Optional[torch.dtype] = None, | ||
preserve_zero: bool = True, | ||
zero_point_domain = ZeroPointDomain.INT, | ||
): | ||
""" | ||
""" | ||
super().__init__() | ||
assert block_size is not None or granularity_type is not None, "Must specify either block_size or granularity_type" | ||
self._update_stats = update_stats | ||
self._calculate_qparams = calculate_qparams | ||
self.mapping_type = mapping_type | ||
self.target_dtype = target_dtype | ||
self.block_size = block_size | ||
self.granularity_type = granularity_type | ||
self.quant_min = quant_min | ||
self.quant_max = quant_max | ||
self.eps = eps | ||
self.scale_dtype = scale_dtype | ||
self.zero_point_dtype = zero_point_dtype | ||
self.preserve_zero = preserve_zero | ||
self.zero_point_domain = zero_point_domain | ||
|
||
def forward(self, input: torch.Tensor) -> torch.Tensor: | ||
self._update_stats(self, input) | ||
return input | ||
|
||
def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]: | ||
return self._calculate_qparams(self) | ||
|
||
def get_min_max_funcs(): | ||
|
||
def update_stats_min_max(self, input: torch.Tensor): | ||
if input.numel() == 0: | ||
return | ||
|
||
input = input.detach() | ||
if self.block_size is None: | ||
self.block_size = get_block_size(input.shape, self.granularity_type) | ||
|
||
shape_for_reduction, reduction_dims = _get_reduction_params(self.block_size, input.size()) | ||
input = input.view(shape_for_reduction) | ||
min_val = torch.amin(input, dim=reduction_dims, keepdim=False) | ||
max_val = torch.amax(input, dim=reduction_dims, keepdim=False) | ||
if not hasattr(self, "min_val") or not hasattr(self, "max_val"): | ||
self.min_val = min_val | ||
self.max_val = max_val | ||
else: | ||
min_val = torch.min(self.min_val, min_val) | ||
max_val = torch.max(self.max_val, max_val) | ||
self.min_val.copy_(min_val) | ||
self.max_val.copy_(max_val) | ||
|
||
def calculate_qparams_min_max(self) -> Tuple[torch.Tensor, torch.Tensor]: | ||
assert hasattr(self, "min_val") and hasattr(self, "max_val"), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams" | ||
|
||
return choose_qparams_affine_with_min_max( | ||
self.min_val, | ||
self.max_val, | ||
self.mapping_type, | ||
self.block_size, | ||
self.target_dtype, | ||
self.quant_min, | ||
self.quant_max, | ||
self.eps, | ||
self.scale_dtype, | ||
self.zero_point_dtype, | ||
self.preserve_zero, | ||
self.zero_point_domain | ||
) | ||
|
||
return update_stats_min_max, calculate_qparams_min_max | ||
|
||
_update_stats_min_max, _calculate_qparams_min_max = get_min_max_funcs() | ||
AffineQuantizedMinMaxObserver = AffineQuantizedObserver.with_args(_update_stats_min_max, _calculate_qparams_min_max) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters