[StaticQuant] Update how block_size is calculated with Observers
stack-info: PR: #815, branch: drisspg/stack/10
drisspg committed Sep 5, 2024
1 parent 848e123 commit 155e41c
Showing 2 changed files with 114 additions and 3 deletions.
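
For orientation before the diff: the change stops caching block_size for PerTensor granularity, because a per-tensor block size simply mirrors the (possibly varying) input shape. Below is a minimal sketch of the block-size semantics; the PerAxis branch is truncated in the diff, so its body here is an assumption inferred from the min_val shapes the tests assert.

from typing import Tuple

from torchao.quantization.observer import PerAxis, PerTensor  # defined in this commit


def sketch_get_block_size(input_shape: Tuple[int, ...], granularity_type) -> Tuple[int, ...]:
    # PerTensor: one block covering the whole tensor, so the block size
    # simply mirrors the input shape (and must not be cached across calls).
    if isinstance(granularity_type, PerTensor):
        return input_shape
    # PerAxis(axis): assumed to give one block per slice along `axis`,
    # i.e. block size 1 on that axis and the full extent elsewhere.
    if isinstance(granularity_type, PerAxis):
        block_size = list(input_shape)
        block_size[granularity_type.axis] = 1
        return tuple(block_size)
    raise ValueError(f"Unsupported granularity type: {granularity_type}")


# Worked shapes, matching the tests below:
#   (10, 2048) with PerTensor() -> (10, 2048): one scale for the whole tensor
#   (10, 2048) with PerAxis(0)  -> (1, 2048):  one scale per row    -> min_val shape (10,)
#   (10, 2048) with PerAxis(1)  -> (10, 1):    one scale per column -> min_val shape (2048,)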
74 changes: 74 additions & 0 deletions test/quantization/test_observer.py
@@ -1,3 +1,4 @@
import re
import torch
from torch.testing._internal.common_utils import TestCase
from torchao.quantization.observer import (
@@ -34,6 +35,79 @@ def test_min_max_per_channel_affine(self):
ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
self._test_obs_helper(obs, ref_obs)

def test_block_size_calc_success(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerTensor(),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
example_inputs = [
torch.randn(10, 2048),
torch.randn(9, 2048),
torch.randn(7, 2048),
]
for example_input in example_inputs:
obs(example_input)

obs.calculate_qparams()

obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(1),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
for example_input in example_inputs:
obs(example_input)

obs.calculate_qparams()

def test_block_size_row_errors(self):
obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(0),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
example_inputs = [
torch.randn(10, 2048),
torch.randn(9, 2048),
]
expected_error_msg = "Can't update existing min_val - shape mismatch, self.min_val:torch.Size([10]) != min_val:torch.Size([9])"
escaped_error_msg = re.escape(expected_error_msg)
with self.assertRaisesRegex(AssertionError, escaped_error_msg):
for example_input in example_inputs:
obs(example_input)

obs = AffineQuantizedMinMaxObserver(
MappingType.SYMMETRIC,
torch.float8_e4m3fn,
granularity_type=PerAxis(1),
eps=torch.finfo(torch.float32).eps,
scale_dtype=torch.float,
zero_point_dtype=torch.int,
zero_point_domain=None,
)
example_inputs = [
torch.randn(10, 2048),
torch.randn(9, 2047),
]
expected_error_msg = "Can't update existing min_val - shape mismatch, self.min_val:torch.Size([2048]) != min_val:torch.Size([2047])"
escaped_error_msg = re.escape(expected_error_msg)
with self.assertRaisesRegex(AssertionError, escaped_error_msg):
for example_input in example_inputs:
obs(example_input)


if __name__ == "__main__":
unittest.main()
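
The shape-mismatch failures exercised above come from the observer's forward pass, shown in the next file. As a standalone illustration (a sketch, not the library's API), the running min/max update reduces each batch per block and then requires the reduced shapes to agree across batches:

import torch


def sketch_update(running_min, running_max, batch_min, batch_max):
    # First batch: adopt the reduced min/max directly.
    if running_min is None:
        return batch_min, batch_max
    # Later batches: the per-block shapes must match before the running
    # stats can merge; this is what the new assertions in forward() enforce.
    assert running_min.shape == batch_min.shape, (
        f"Can't update existing min_val - shape mismatch, "
        f"self.min_val:{running_min.shape} != min_val:{batch_min.shape}"
    )
    return torch.min(running_min, batch_min), torch.max(running_max, batch_max)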
43 changes: 40 additions & 3 deletions torchao/quantization/observer.py
@@ -16,14 +16,35 @@

@dataclass(frozen=True)
class GranularityType:
"""
Base class for representing the granularity of quantization.
This class serves as a parent for specific granularity types used in
quantization operations, such as per-tensor or per-axis quantization.
"""
pass

@dataclass(frozen=True)
class PerTensor(GranularityType):
"""
Represents per-tensor granularity in quantization.
This granularity type calculates the quantization parameters
based on the entire tensor.
"""
pass

@dataclass(frozen=True)
class PerAxis(GranularityType):
"""
Represents per-axis granularity in quantization.
This granularity type calculates different quantization parameters
along a specified axis of the tensor.
Attributes:
axis (int): The axis along which reduction is performed.
"""
axis: int

# borrowed from torch.ao.quantization.observer
@@ -59,7 +80,16 @@ def _with_args(cls_or_self, *args, **kwargs):
r = _PartialWrapper(partial(cls_or_self, *args, **kwargs))
return r

def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]:

def get_block_size(
input_shape: Tuple[int, ...], granularity_type: GranularityType
) -> Tuple[int, ...]:
"""Get the block size based on the input shape and granularity type.
Args:
input_shape: The input tensor shape, possibly with more than 2 dimensions
granularity_type: The granularity type of the quantization
"""
if isinstance(granularity_type, PerTensor):
return input_shape
elif isinstance(granularity_type, PerAxis):
@@ -130,8 +160,13 @@ def forward(self, input: torch.Tensor):
return input

input_detached = input.detach()
if self.block_size is None:
self.block_size = get_block_size(input_detached.shape, self.granularity_type)
assert self.granularity_type is not None, "granularity_type is None"
block_size = get_block_size(input_detached.shape, self.granularity_type)

# If we are doing PerTensor quantization then we do not need to reduce along a given
# dimension; the input tensors can vary in size, so we don't want to cache a block_size
if self.block_size is None and not isinstance(self.granularity_type, PerTensor):
self.block_size = block_size

# use the locally computed block_size: for PerTensor, self.block_size stays None
shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input_detached.size())
input_detached = input_detached.view(shape_for_reduction)
@@ -141,6 +176,8 @@ def forward(self, input: torch.Tensor):
self.min_val = min_val
self.max_val = max_val
else:
assert self.min_val.shape == min_val.shape, f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}"
assert self.max_val.shape == max_val.shape, f"Can't update existing max_val - shape mismatch, self.max_val:{self.max_val.shape} != max_val:{max_val.shape}"
min_val = torch.min(self.min_val, min_val)
max_val = torch.max(self.max_val, max_val)
self.min_val.copy_(min_val)

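For completeness, a hedged end-to-end calibration sketch in the style of the tests above. The import locations and the (scale, zero_point) return convention of calculate_qparams are assumptions about the surrounding package, not shown in this commit:

import torch

from torchao.quantization.observer import AffineQuantizedMinMaxObserver, PerTensor
from torchao.quantization.quant_primitives import MappingType  # assumed location

obs = AffineQuantizedMinMaxObserver(
    MappingType.SYMMETRIC,
    torch.float8_e4m3fn,
    granularity_type=PerTensor(),
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
    zero_point_domain=None,
)

# Because PerTensor no longer caches block_size, calibration batches may
# have different leading dimensions.
for batch in (torch.randn(10, 2048), torch.randn(9, 2048), torch.randn(7, 2048)):
    obs(batch)

scale, zero_point = obs.calculate_qparams()  # assumed return convention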