From 591472cff46f9bb5c9c0ecd054ad527e7b5fb9eb Mon Sep 17 00:00:00 2001
From: drisspg
Date: Wed, 4 Sep 2024 21:37:58 -0700
Subject: [PATCH] [StaticQuant] Update how block_size is calculated with Observers

stack-info: PR: https://github.com/pytorch/ao/pull/815, branch: drisspg/stack/10
---
 test/quantization/test_observer.py | 74 ++++++++++++++++++++++++++++++
 torchao/quantization/observer.py   | 51 +++++++++++++++++++-
 2 files changed, 123 insertions(+), 2 deletions(-)

diff --git a/test/quantization/test_observer.py b/test/quantization/test_observer.py
index 0e50760518..e0c9257a96 100644
--- a/test/quantization/test_observer.py
+++ b/test/quantization/test_observer.py
@@ -1,3 +1,4 @@
+import re
 import torch
 from torch.testing._internal.common_utils import TestCase
 from torchao.quantization.observer import (
@@ -34,6 +35,79 @@ def test_min_max_per_channel_affine(self):
         ref_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine)
         self._test_obs_helper(obs, ref_obs)
 
+    def test_block_size_calc_success(self):
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.SYMMETRIC,
+            torch.float8_e4m3fn,
+            granularity_type=PerTensor(),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+            zero_point_domain=None,
+        )
+        example_inputs = [
+            torch.randn(10, 2048),
+            torch.randn(9, 2048),
+            torch.randn(7, 2048),
+        ]
+        for example_input in example_inputs:
+            obs(example_input)
+
+        obs.calculate_qparams()
+
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.SYMMETRIC,
+            torch.float8_e4m3fn,
+            granularity_type=PerAxis(1),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+            zero_point_domain=None,
+        )
+        for example_input in example_inputs:
+            obs(example_input)
+
+        obs.calculate_qparams()
+
+    def test_block_size_row_errors(self):
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.SYMMETRIC,
+            torch.float8_e4m3fn,
+            granularity_type=PerAxis(0),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+            zero_point_domain=None,
+        )
+        example_inputs = [
+            torch.randn(10, 2048),
+            torch.randn(9, 2048),
+        ]
+        expected_error_msg = "Can't update existing min_val - shape mismatch, self.min_val:torch.Size([10]) != min_val:torch.Size([9])"
+        escaped_error_msg = re.escape(expected_error_msg)
+        with self.assertRaisesRegex(AssertionError, escaped_error_msg):
+            for example_input in example_inputs:
+                obs(example_input)
+
+        obs = AffineQuantizedMinMaxObserver(
+            MappingType.SYMMETRIC,
+            torch.float8_e4m3fn,
+            granularity_type=PerAxis(1),
+            eps=torch.finfo(torch.float32).eps,
+            scale_dtype=torch.float,
+            zero_point_dtype=torch.int,
+            zero_point_domain=None,
+        )
+        example_inputs = [
+            torch.randn(10, 2048),
+            torch.randn(9, 2047),
+        ]
+        expected_error_msg = "Can't update existing min_val - shape mismatch, self.min_val:torch.Size([2048]) != min_val:torch.Size([2047])"
+        escaped_error_msg = re.escape(expected_error_msg)
+        with self.assertRaisesRegex(AssertionError, escaped_error_msg):
+            for example_input in example_inputs:
+                obs(example_input)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py
index a8d10f73f3..984f2a765e 100644
--- a/torchao/quantization/observer.py
+++ b/torchao/quantization/observer.py
@@ -16,14 +16,41 @@
 @dataclass(frozen=True)
 class GranularityType:
+    """
+    Base class for representing the granularity of quantization.
+
+    This class serves as a parent for specific granularity types used in
+    quantization operations, such as per-tensor or per-axis quantization.
+    """
     pass
 
 @dataclass(frozen=True)
 class PerTensor(GranularityType):
+    """
+    Represents per-tensor granularity in quantization.
+
+    This granularity type calculates the quantization parameters
+    based on the entire tensor.
+    """
     pass
 
 @dataclass(frozen=True)
 class PerAxis(GranularityType):
+    """
+    Represents per-axis granularity in quantization.
+
+    This granularity type calculates different quantization parameters
+    along a specified axis of the tensor.
+
+    Attributes:
+        axis (int): The axis along which reduction is performed.
+    """
     axis: int
 
 # borrowed from torch.ao.quantization.observer
@@ -59,7 +86,20 @@ def _with_args(cls_or_self, *args, **kwargs):
         r = _PartialWrapper(partial(cls_or_self, *args, **kwargs))
         return r
 
 def get_block_size(input_shape: Tuple[int, ...], granularity_type: GranularityType) -> Tuple[int, ...]:
+    """Get the block size based on the input shape and granularity type.
+
+    Args:
+        input_shape: The input tensor shape, possibly with more than 2 dimensions
+        granularity_type: The granularity type of the quantization
+    """
     if isinstance(granularity_type, PerTensor):
         return input_shape
     elif isinstance(granularity_type, PerAxis):
@@ -130,8 +170,13 @@ def forward(self, input: torch.Tensor):
             return input
 
         input_detached = input.detach()
-        if self.block_size is None:
-            self.block_size = get_block_size(input_detached.shape, self.granularity_type)
+        assert self.granularity_type is not None, "granularity_type is None"
+        block_size = get_block_size(input_detached.shape, self.granularity_type)
+
+        # If we are doing PerTensor quantization then we do not need to reduce along a given
+        # dimension. The input tensor can vary in size and we don't want to cache a block_size
+        if self.block_size is None and not isinstance(self.granularity_type, PerTensor):
+            self.block_size = block_size
 
         shape_for_reduction, reduction_dims = _get_reduction_params(self.block_size, input_detached.size())
         input_detached = input_detached.view(shape_for_reduction)
@@ -141,6 +186,8 @@
             self.min_val = min_val
             self.max_val = max_val
         else:
+            assert self.min_val.shape == min_val.shape, f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}"
+            assert self.max_val.shape == max_val.shape, f"Can't update existing max_val - shape mismatch, self.max_val:{self.max_val.shape} != max_val:{max_val.shape}"
            min_val = torch.min(self.min_val, min_val)
            max_val = torch.max(self.max_val, max_val)
            self.min_val.copy_(min_val)
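
Below is a minimal usage sketch (not part of the diff above) of the block sizes `get_block_size` resolves for the granularities exercised by the new tests. It assumes `PerTensor`, `PerAxis`, and `get_block_size` are importable from `torchao.quantization.observer`, as the patched test file does; the `PerAxis` results are inferred from the qparam shapes asserted in `test_block_size_row_errors`, not stated verbatim in the patch.

```python
# Illustration only -- not part of the patch.
from torchao.quantization.observer import PerAxis, PerTensor, get_block_size

input_shape = (10, 2048)

# PerTensor: one block covering the whole tensor, so block_size == input_shape.
# Because the input shape can vary between observer calls (10x2048, 9x2048, ...),
# the patch computes block_size per call and skips caching it for PerTensor.
print(get_block_size(input_shape, PerTensor()))      # (10, 2048)

# PerAxis(0): one block per index along axis 0 -> qparams of shape [10],
# which is the shape the new min_val/max_val asserts compare against.
print(get_block_size(input_shape, PerAxis(axis=0)))  # expected: (1, 2048)

# PerAxis(1): one block per index along axis 1 -> qparams of shape [2048]
print(get_block_size(input_shape, PerAxis(axis=1)))  # expected: (10, 1)
```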