Add AffineQuantizedObserver
Summary:
In our static_quant flow tutorial we were still using observers from `torch.ao`, which we plan to deprecate. This PR adds a more general observer for `AffineQuantizedTensor` and shows that it can replace the old observers (min/max observer). Future work could improve performance and add new types
of observation, e.g. tracking stats other than min/max, a moving average observer, or a histogram observer.
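
For example, the tutorial's per-tensor activation observer

    act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)

becomes (see the static_quant.py diff below):

    act_obs = AffineQuantizedMinMaxObserver(
        mapping_type=MappingType.ASYMMETRIC,
        target_dtype=torch.uint8,
        granularity_type=PerTensor(),
        eps=torch.finfo(torch.float32).eps,
        scale_dtype=torch.float32,
        zero_point_dtype=torch.int32,
    )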

Test Plan:
python test/quantization/test_observer.py
python tutorials/calibration_flow/static_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Aug 9, 2024
1 parent 433cd14 commit afe8b62
Showing 4 changed files with 278 additions and 20 deletions.
39 changes: 39 additions & 0 deletions test/quantization/test_observer.py
@@ -0,0 +1,39 @@
import torch
from torch.testing._internal.common_utils import TestCase
from torchao.quantization.observer import (
    AffineQuantizedMinMaxObserver,
    PerTensor,
    PerAxis,
)
from torchao.quantization.quant_primitives import (
    MappingType,
)
import unittest
# NOTE: we can copy paste these here if we decide to deprecate them in torch.ao
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

class TestQuantFlow(TestCase):
    def _test_obs_helper(self, obs1, obs2):
        example_inputs = [torch.randn(10, 2048), torch.randn(10, 2048), torch.randn(10, 2048)]
        for example_input in example_inputs:
            obs1(example_input)
            obs2(example_input)

        scale1, zero_point1 = obs1.calculate_qparams()
        scale2, zero_point2 = obs2.calculate_qparams()
        self.assertTrue(torch.allclose(scale1, scale2))
        self.assertTrue(torch.allclose(zero_point1, zero_point2))

    def test_min_max_per_tensor_affine(self):
        obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
        ref_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine)
        self._test_obs_helper(obs, ref_obs)

    def test_min_max_per_channel_affine(self):
        obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float, zero_point_dtype=torch.int)
        ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
        self._test_obs_helper(obs, ref_obs)


if __name__ == "__main__":
    unittest.main()
154 changes: 154 additions & 0 deletions torchao/quantization/observer.py
@@ -0,0 +1,154 @@
import torch
from .quant_primitives import (
    _get_reduction_params,
    choose_qparams_affine_with_min_max,
    MappingType,
    ZeroPointDomain,
)

from dataclasses import dataclass
from typing import Callable, List, Tuple, Optional
from functools import partial

@dataclass(frozen=True)
class GranularityType:
    pass

@dataclass(frozen=True)
class PerTensor(GranularityType):
    pass

@dataclass(frozen=True)
class PerAxis(GranularityType):
    axis: int

# borrowed from torch.ao.quantization.observer
class _PartialWrapper:
    def __init__(self, p):
        self.p = p

    def __call__(self, *args, **keywords):
        return self.p(*args, **keywords)

    def __repr__(self):
        return self.p.__repr__()

    def with_args(self, *args, **kwargs):
        return _with_args(self, *args, **kwargs)

def _with_args(cls_or_self, *args, **kwargs):
    r"""Wrapper that allows creation of class factories.
    This can be useful when there is a need to create classes with the same
    constructor arguments, but different instances.
    Example::
        >>> # xdoctest: +SKIP("Undefined vars")
        >>> Foo.with_args = classmethod(_with_args)
        >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
        >>> foo_instance1 = foo_builder()
        >>> foo_instance2 = foo_builder()
        >>> id(foo_instance1) == id(foo_instance2)
        False
    """
    r = _PartialWrapper(partial(cls_or_self, *args, **kwargs))
    return r

def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]:
    if isinstance(granularity_type, PerTensor):
        return input_shape
    elif isinstance(granularity_type, PerAxis):
        block_size = list(input_shape)
        block_size[granularity_type.axis] = 1
        return tuple(block_size)
    raise ValueError(f"Unsupported GranularityType: {granularity_type}")
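
# Illustrative examples: for an input of shape (10, 2048),
#   get_block_size((10, 2048), PerTensor())     == (10, 2048)  # one block, i.e. one scale/zero_point for the whole tensor
#   get_block_size((10, 2048), PerAxis(axis=0)) == (1, 2048)   # one block (and one scale/zero_point) per index along axis 0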

class AffineQuantizedObserver(torch.nn.Module):
    with_args = classmethod(_with_args)

    def __init__(self,
        update_stats: Callable[[Callable, torch.Tensor], None],
        calculate_qparams: Callable[[Callable], Tuple[torch.Tensor, torch.Tensor]],
        mapping_type: MappingType,
        target_dtype: torch.dtype,
        block_size: Optional[Tuple[int, ...]] = None,
        granularity_type: Optional[GranularityType] = None,
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        eps: Optional[float] = None,
        scale_dtype: Optional[torch.dtype] = None,
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain = ZeroPointDomain.INT,
    ):
        """Observer for affine quantization. Statistics are recorded through the
        `update_stats` callback and quantization parameters are derived through the
        `calculate_qparams` callback. Either `block_size` or `granularity_type` must be
        specified; when `block_size` is None it is derived from `granularity_type` and
        the shape of the first observed input.
        """
        super().__init__()
        assert block_size is not None or granularity_type is not None, "Must specify either block_size or granularity_type"
        self._update_stats = update_stats
        self._calculate_qparams = calculate_qparams
        self.mapping_type = mapping_type
        self.target_dtype = target_dtype
        self.block_size = block_size
        self.granularity_type = granularity_type
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.eps = eps
        self.scale_dtype = scale_dtype
        self.zero_point_dtype = zero_point_dtype
        self.preserve_zero = preserve_zero
        self.zero_point_domain = zero_point_domain

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        self._update_stats(self, input)
        return input

    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self._calculate_qparams(self)

def get_min_max_funcs():

    def update_stats_min_max(self, input: torch.Tensor):
        if input.numel() == 0:
            return

        input = input.detach()
        if self.block_size is None:
            self.block_size = get_block_size(input.shape, self.granularity_type)

        shape_for_reduction, reduction_dims = _get_reduction_params(self.block_size, input.size())
        input = input.view(shape_for_reduction)
        min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
        max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
        if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
            self.min_val = min_val
            self.max_val = max_val
        else:
            min_val = torch.min(self.min_val, min_val)
            max_val = torch.max(self.max_val, max_val)
            self.min_val.copy_(min_val)
            self.max_val.copy_(max_val)

    def calculate_qparams_min_max(self) -> Tuple[torch.Tensor, torch.Tensor]:
        assert hasattr(self, "min_val") and hasattr(self, "max_val"), "Expecting the observer to have min_val and max_val, please run the observer before calling calculate_qparams"

        return choose_qparams_affine_with_min_max(
            self.min_val,
            self.max_val,
            self.mapping_type,
            self.block_size,
            self.target_dtype,
            self.quant_min,
            self.quant_max,
            self.eps,
            self.scale_dtype,
            self.zero_point_dtype,
            self.preserve_zero,
            self.zero_point_domain
        )

    return update_stats_min_max, calculate_qparams_min_max

_update_stats_min_max, _calculate_qparams_min_max = get_min_max_funcs()
AffineQuantizedMinMaxObserver = AffineQuantizedObserver.with_args(_update_stats_min_max, _calculate_qparams_min_max)
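
# Usage sketch (mirrors test/quantization/test_observer.py): `AffineQuantizedMinMaxObserver`
# is a factory created via `with_args`, so calling it constructs an `AffineQuantizedObserver`
# with the min/max callbacks pre-bound; `calibration_data` below is a hypothetical iterable:
#
#   obs = AffineQuantizedMinMaxObserver(
#       MappingType.ASYMMETRIC,
#       torch.uint8,
#       granularity_type=PerTensor(),
#       eps=torch.finfo(torch.float32).eps,
#       scale_dtype=torch.float,
#       zero_point_dtype=torch.int,
#   )
#   for batch in calibration_data:
#       obs(batch)  # forward() records min/max stats and returns the input unchanged
#   scale, zero_point = obs.calculate_qparams()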
84 changes: 70 additions & 14 deletions torchao/quantization/quant_primitives.py
@@ -21,6 +21,7 @@
     "safe_int_mm",
     "int_scaled_matmul",
     "choose_qparams_affine",
+    "choose_qparams_affine_with_min_max",
     "quantize_affine",
     "dequantize_affine",
     "fake_quantize_affine",
@@ -570,9 +571,51 @@ def choose_qparams_affine(
         zero_point_domain.name
     )
 
+
+def choose_qparams_affine_with_min_max(
+    min_val: torch.Tensor,
+    max_val: torch.Tensor,
+    mapping_type: MappingType,
+    block_size: Tuple[int, ...],
+    target_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    eps: Optional[float] = None,
+    scale_dtype: Optional[torch.dtype] = None,
+    zero_point_dtype: Optional[torch.dtype] = None,
+    preserve_zero: bool = True,
+    zero_point_domain = ZeroPointDomain.INT,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`
+    that passes in min_val and max_val directly instead of deriving them from a single input.
+    This is used for observers in static quantization, where min_val and max_val may be obtained
+    by tracking all the data in the calibration data set.
+
+    Args:
+        Mostly the same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`, with one
+        difference: instead of passing in an `input` Tensor and using that to calculate min_val/max_val
+        and then scale/zero_point, we pass in min_val/max_val directly.
+    """
+    return _choose_qparams_affine(
+        None,
+        mapping_type.name,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        scale_dtype,
+        zero_point_dtype,
+        preserve_zero,
+        zero_point_domain.name,
+        min_val,
+        max_val,
+    )


 @register_custom_op
 def _choose_qparams_affine(
-    input: torch.Tensor,
+    input: Optional[torch.Tensor],
     mapping_type: str,
     block_size: List[int],
     target_dtype: torch.dtype,
@@ -583,23 +626,38 @@ def _choose_qparams_affine(
     zero_point_dtype: Optional[torch.dtype] = None,
     preserve_zero: bool = True,
     zero_point_domain: str = "INT",
+    min_val: Optional[torch.Tensor] = None,
+    max_val: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """op definition that has compatible signatures with custom op library
     """
     quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
     assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"
 
-    if scale_dtype is None:
-        scale_dtype = input.dtype
-    if zero_point_dtype is None:
-        zero_point_dtype = input.dtype
+    if input is not None:
+        if scale_dtype is None:
+            scale_dtype = input.dtype
+        if zero_point_dtype is None:
+            zero_point_dtype = input.dtype
+        if eps is None:
+            eps = torch.finfo(input.dtype).eps
 
-    assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}"
-    shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
-    input = input.view(shape_for_reduction)
+        assert len(block_size) == input.dim(), f"Got input dim:{input.dim()}, block_size: {block_size}"
+        shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
+        input = input.view(shape_for_reduction)
 
-    min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
-    max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
+        min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
+        max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
+    else:
+        assert min_val is not None and max_val is not None, f"Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}"
+        assert min_val.dtype == max_val.dtype, f"Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}"
+
+        if scale_dtype is None:
+            scale_dtype = min_val.dtype
+        if zero_point_dtype is None:
+            zero_point_dtype = min_val.dtype
+        if eps is None:
+            eps = torch.finfo(min_val.dtype).eps
 
     if preserve_zero:
         min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
@@ -615,10 +673,12 @@ def _choose_qparams_affine(
             raise ValueError("preserve_zero == False is not supported for symmetric quantization")
         if zero_point_domain != ZeroPointDomain.INT.name:
             raise ValueError("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization")
+        scale = torch.clamp(scale, min=eps)
         zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
     else:
         assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
+        scale = torch.clamp(scale, min=eps)
         if preserve_zero:
             zero_point = quant_min - torch.round(min_val_neg / scale)
             zero_point = torch.clamp(zero_point, quant_min, quant_max)
@@ -627,8 +687,4 @@ def _choose_qparams_affine(
             mid_point = (quant_max + quant_min + 1) / 2
             zero_point = min_val_neg + scale * mid_point
 
-    if eps is None:
-        eps = torch.finfo(input.dtype).eps
-    scale = torch.clamp(scale, min=eps)
-
     return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype)
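
For reference, the new operator can also be called directly with precomputed statistics; a minimal sketch with made-up values, assuming per-row qparams for a hypothetical (2, 128) input:

    import torch
    from torchao.quantization.quant_primitives import (
        choose_qparams_affine_with_min_max,
        MappingType,
    )

    min_val = torch.tensor([-1.0, -0.5])  # per-row minimums tracked during calibration
    max_val = torch.tensor([1.0, 2.0])    # per-row maximums tracked during calibration
    scale, zero_point = choose_qparams_affine_with_min_max(
        min_val,
        max_val,
        MappingType.ASYMMETRIC,
        block_size=(1, 128),
        target_dtype=torch.uint8,
    )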
21 changes: 15 additions & 6 deletions tutorials/calibration_flow/static_quant.py
@@ -4,16 +4,21 @@
 import torch
 import copy
 
-# TODO: use the generalized observer for affine qunatization in the future
-from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
 import torch.nn.functional as F
 from torch import Tensor
 from torchao.dtypes import to_affine_quantized_static
 from torchao.quantization.utils import compute_error
 from torchao.quantization import quantize_
 from torchao.quantization import to_linear_activation_quantized
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 
+from torchao.quantization.observer import (
+    AffineQuantizedMinMaxObserver,
+    PerTensor,
+    PerAxis,
+)
+from torchao.quantization.quant_primitives import (
+    MappingType,
+)
 
 
 class ObservedLinear(torch.nn.Linear):
@@ -105,16 +110,20 @@ def forward(self, x):
         x = self.linear2(x)
         return x
 
+torch.manual_seed(0)
+
 dtype = torch.bfloat16
 m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
 
+m_for_test = copy.deepcopy(m)
+
 m_bf16 = copy.deepcopy(m)
 example_inputs = m.example_inputs(dtype=dtype, device="cuda")
 
 m_bf16 = torch.compile(m_bf16, mode='max-autotune')
 
-# TODO: use the generalized observer for affine qunatization in the future
-act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda")
-weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda")
+act_obs = AffineQuantizedMinMaxObserver(mapping_type=MappingType.ASYMMETRIC, target_dtype=torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32)
+weight_obs = AffineQuantizedMinMaxObserver(mapping_type=MappingType.ASYMMETRIC, target_dtype=torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32)
 
 before_quant = m(*example_inputs)

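Beyond this point the tutorial (elided above) calibrates and converts the model. As a rough sketch of how the new observers are exercised during calibration, assuming a hypothetical `calibration_batches` iterable and a `linear1` attribute on the toy model:

    for batch in calibration_batches:
        act_obs(batch)                # track activation min/max across batches
    weight_obs(m.linear1.weight)      # track weight min/max per output channel (PerAxis(axis=0))
    act_scale, act_zero_point = act_obs.calculate_qparams()
    w_scale, w_zero_point = weight_obs.calculate_qparams()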
