Add AffineQuantizedObserver #650
Merged
@@ -0,0 +1,39 @@
import torch
from torch.testing._internal.common_utils import TestCase
from torchao.quantization.observer import (
    AffineQuantizedMinMaxObserver,
    PerTensor,
    PerAxis,
)
from torchao.quantization.quant_primitives import (
    MappingType,
)
import unittest
# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver


class TestQuantFlow(TestCase):
    def _test_obs_helper(self, obs1, obs2):
        example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
        for example_input in example_inputs:
            obs1(example_input)
            obs2(example_input)

        scale1, zero_point1 = obs1.calculate_qparams()
        scale2, zero_point2 = obs2.calculate_qparams()
        self.assertTrue(torch.allclose(scale1, scale2))
        self.assertTrue(torch.allclose(zero_point1, zero_point2))

    def test_min_max_per_tensor_affine(self):
        obs = AffineQuantizedMinMaxObserver(
            MappingType.ASYMMETRIC,
            torch.uint8,
            granularity_type=PerTensor(),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
        )
        ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
        self._test_obs_helper(obs, ref_obs)

    def test_min_max_per_channel_affine(self):
        obs = AffineQuantizedMinMaxObserver(
            MappingType.ASYMMETRIC,
            torch.uint8,
            granularity_type=PerAxis(axis=0),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
        )
        ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
        self._test_obs_helper(obs, ref_obs)


if __name__ == "__main__":
    unittest.main()
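
For intuition, both observers in each test should converge to the standard asymmetric affine parameters. A rough sketch of that computation (illustrative only; the exact eps and clamping details live in choose_qparams_affine_with_min_max and torch.ao's MinMaxObserver):

# asymmetric affine qparams for uint8 (quant_min=0, quant_max=255);
# zero is pulled into the [min, max] range so it stays exactly representable
min_val_neg = min(min_val, 0.0)
max_val_pos = max(max_val, 0.0)
scale = max((max_val_pos - min_val_neg) / (quant_max - quant_min), eps)
zero_point = min(max(quant_min - round(min_val_neg / scale), quant_min), quant_max)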
@@ -0,0 +1,166 @@
import torch
from .quant_primitives import (
    _get_reduction_params,
    choose_qparams_affine_with_min_max,
    MappingType,
    ZeroPointDomain,
)

from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Callable, List, Tuple, Optional, Any
from functools import partial
import logging

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class GranularityType:
    pass

@dataclass(frozen=True)
class PerTensor(GranularityType):
    pass

@dataclass(frozen=True)
class PerAxis(GranularityType):
    axis: int

# borrowed from torch.ao.quantization.observer
class _PartialWrapper:
    def __init__(self, p):
        self.p = p

    def __call__(self, *args, **keywords):
        return self.p(*args, **keywords)

    def __repr__(self):
        return self.p.__repr__()

    def with_args(self, *args, **kwargs):
        return _with_args(self, *args, **kwargs)

def _with_args(cls_or_self, *args, **kwargs):
    r"""Wrapper that allows creation of class factories.

    This can be useful when there is a need to create classes with the same
    constructor arguments, but different instances.

    Example::

        >>> # xdoctest: +SKIP("Undefined vars")
        >>> Foo.with_args = classmethod(_with_args)
        >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
        >>> foo_instance1 = foo_builder()
        >>> foo_instance2 = foo_builder()
        >>> id(foo_instance1) == id(foo_instance2)
        False
    """
    r = _PartialWrapper(partial(cls_or_self, *args, **kwargs))
    return r

def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]:
    if isinstance(granularity_type, PerTensor):
        return input_shape
    elif isinstance(granularity_type, PerAxis):
        block_size = list(input_shape)
        block_size[granularity_type.axis] = 1
        return tuple(block_size)
    raise ValueError(f"Unsupported GranularityType: {granularity_type}")
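
# Illustrative examples (not part of the original diff):
#   get_block_size((10, 2048), PerTensor())     -> (10, 2048)  # one block covering the whole tensor
#   get_block_size((10, 2048), PerAxis(axis=0)) -> (1, 2048)   # one block per row (per-channel on dim 0)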

ABC: Any = ABCMeta("ABC", (object,), {})  # compatible with Python 2 *and* 3

class AffineQuantizedObserverBase(ABC, torch.nn.Module):
    """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)

    Args:
        `granularity_type` and `block_size`: the granularity of the quantization;
          at least one must be specified, and if both are specified `block_size` takes precedence.
          Currently supported granularity types are `PerTensor` and `PerAxis`.
        other args: please see :class:`torchao.dtypes.AffineQuantizedTensor`
    """
    with_args = classmethod(_with_args)

    def __init__(self,
        mapping_type: MappingType,
        target_dtype: torch.dtype,
        block_size: Optional[Tuple[int, ...]] = None,
        granularity_type: Optional[GranularityType] = None,
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        eps: Optional[float] = None,
        scale_dtype: Optional[torch.dtype] = None,
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain = ZeroPointDomain.INT,
    ):
        super().__init__()
        assert block_size is not None or granularity_type is not None, "Must specify either block_size or granularity_type"
        if block_size is not None and granularity_type is not None:
            logger.warning(f"Both block_size and granularity_type are specified, ignoring granularity_type. block_size: {block_size}, granularity_type: {granularity_type}")
        self.mapping_type = mapping_type
        self.target_dtype = target_dtype
        self.block_size = block_size
        self.granularity_type = granularity_type
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.eps = eps
        self.scale_dtype = scale_dtype
        self.zero_point_dtype = zero_point_dtype
        self.preserve_zero = preserve_zero
        self.zero_point_domain = zero_point_domain

    @abstractmethod
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """forward function should take the input tensor,
        update internal stats, and return the original input Tensor
        """
        pass

    @abstractmethod
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate quantization parameters based on the stats attached to the observer module
        and return a tuple of scale and zero_point Tensors
        """
        pass

class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase):
    def forward(self, input: torch.Tensor):
        if input.numel() == 0:
            return input

        input_detached = input.detach()
        if self.block_size is None:
            self.block_size = get_block_size(input_detached.shape, self.granularity_type)

        # reduce each block to a single min/max value
        shape_for_reduction, reduction_dims = _get_reduction_params(self.block_size, input_detached.size())
        input_detached = input_detached.view(shape_for_reduction)
        min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False)
        max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False)
        if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
            self.min_val = min_val
            self.max_val = max_val
        else:
            # accumulate the running min/max across calls, updating the stored stats in place
            min_val = torch.min(self.min_val, min_val)
            max_val = torch.max(self.max_val, max_val)
            self.min_val.copy_(min_val)
            self.max_val.copy_(max_val)
        # returning original input
        return input

    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        assert hasattr(self, "min_val") and hasattr(self, "max_val"), "Expecting the observer to have min_val and max_val; please run the observer before calling calculate_qparams"
        return choose_qparams_affine_with_min_max(
            self.min_val,
            self.max_val,
            self.mapping_type,
            self.block_size,
            self.target_dtype,
            self.quant_min,
            self.quant_max,
            self.eps,
            self.scale_dtype,
            self.zero_point_dtype,
            self.preserve_zero,
            self.zero_point_domain,
        )
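
For reference, a minimal sketch (not part of the diff) of how the with_args factory exposed on the base class can be used to build preconfigured observer constructors; the arguments mirror the tests above:

obs_ctr = AffineQuantizedMinMaxObserver.with_args(
    MappingType.ASYMMETRIC,
    torch.uint8,
    granularity_type=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
)
# each call to the factory creates a fresh observer with the same configuration
obs = obs_ctr()
obs(torch.randn(10, 2048))
scale, zero_point = obs.calculate_qparams()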
Review discussion on get_block_size:

should we merge this with _get_per_token_block_size from torchao/quantization/utils.py?

I feel in general PerTensor and PerAxis will be useful for the dynamic / weight only flows as well. We can do that in a future PR

yeah we can do this in a future PR, I feel we can merge everything into this function and move this to torchao/quantization/utils.py
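
For context, a hypothetical sketch of the merge discussed above: a PerToken granularity type could subsume _get_per_token_block_size. The PerToken name and the assumption that per-token means reducing over the last dimension only are mine, not part of this PR:

@dataclass(frozen=True)
class PerToken(GranularityType):
    pass

def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]:
    if isinstance(granularity_type, PerTensor):
        return input_shape
    elif isinstance(granularity_type, PerAxis):
        block_size = list(input_shape)
        block_size[granularity_type.axis] = 1
        return tuple(block_size)
    elif isinstance(granularity_type, PerToken):
        # one block per token: 1 for every leading dim, full size on the last dim
        return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
    raise ValueError(f"Unsupported GranularityType: {granularity_type}")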