From 79f2c7f0ce69069ca14983c7b9955ef48e0a064e Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 10 Jun 2024 13:59:55 -0700 Subject: [PATCH] Fix dimension issues for int4 weight only quant path (#330) Summary: Currently the accepted dimension of _quantized_linear is not clear, this PR fixes the issue. Currently the "tensor_core_tiled" layout tensor does not do repacking in view operation, which is incorrect, this PR removes the view support (which is not needed right now), and restrict the use case to transpose op, and records the transpose status of the tensor instead of doing repacking for performance. Test Plan: python test/quantization/test_quant_api.py python test/integration/test_integration.py Reviewers: Subscribers: Tasks: Tags: --- test/dtypes/test_aq.py | 29 ++++++++++ test/integration/test_integration.py | 6 +- torchao/dtypes/aqt.py | 84 ++++++++++++++-------------- 3 files changed, 74 insertions(+), 45 deletions(-) create mode 100644 test/dtypes/test_aq.py diff --git a/test/dtypes/test_aq.py b/test/dtypes/test_aq.py new file mode 100644 index 0000000000..6967e5f310 --- /dev/null +++ b/test/dtypes/test_aq.py @@ -0,0 +1,29 @@ +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) +from torchao.quantization.quant_api import get_apply_int4wo_quant +import torch +import unittest + + +class TestAQ(TestCase): + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_tensor_core_layout_transpose(self): + t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda") + shape = t.shape + apply_int4wo_quant = get_apply_int4wo_quant(groupsize=32) + aqt = apply_int4wo_quant(t) + aqt_shape = aqt.shape + self.assertEqual(aqt_shape, shape) + + # transpose shape test + for _ in range(10): + t = t.t() + aqt = aqt.t() + shape = t.shape + aqt_shape = aqt.shape + self.assertEqual(aqt_shape, shape) + +if __name__ == "__main__": + run_tests() diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 1aced2f69e..4461355fb3 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -217,9 +217,9 @@ def _test_smooth_linear_impl(self, x_shape, lin_shape, device): # rtol=0.00001), \ # 'y_smooth_fq_only not close to y_dynamic_q' - self.assertTrue(sqnr_smooth_fq.item() >= 40.0) - self.assertTrue(sqnr_dynamic_q.item() >= 40.0) - self.assertTrue(sqnr_fq.item() >= 40.0) + self.assertTrue(sqnr_smooth_fq.item() >= 40.0, f"got: {sqnr_smooth_fq.item()}") + self.assertTrue(sqnr_dynamic_q.item() >= 40.0, f"got: {sqnr_dynamic_q.item()}") + self.assertTrue(sqnr_fq.item() >= 40.0, f"got: {sqnr_fq.item()}") # Restore backend torch.backends.quantized.engine = orig_backend diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py index b1ac739da4..ae05720b05 100644 --- a/torchao/dtypes/aqt.py +++ b/torchao/dtypes/aqt.py @@ -188,11 +188,6 @@ def _apply_fn_to_data(self, fn): fn(self.zero_point), ) - def _change_shape(self, shape): - return self.__class__( - self.int_data.view(shape), self.scale, self.zero_point - ) - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs @@ -202,9 +197,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) - if func is aten.view.default: - assert len(args) == 2 - new = args[0]._change_shape(args[1]) + if func is aten.t.default: + tensor = args[0] + new = tensor.__class__( + tensor.int_data.view(tenssor.shape[::-1]), tenssor.scale, tenssor.zero_point + ) return return_and_correct_aliasing(func, args, kwargs, new) raise NotImplementedError( @@ -241,6 +238,7 @@ def __new__( cls, packed_weight: torch.Tensor, scale_and_zero: torch.Tensor, + transposed: bool, ): kwargs = {} kwargs["device"] = packed_weight.device @@ -256,19 +254,22 @@ def __init__( self, packed_weight: torch.Tensor, scale_and_zero: torch.Tensor, + transposed: bool, ): self.packed_weight = packed_weight self.scale_and_zero = scale_and_zero + self.transposed = False def __tensor_flatten__(self): - return ["packed_weight", "scale_and_zero"], [] + return ["packed_weight", "scale_and_zero"], [self.transposed] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"] - return cls(packed_weight, scale_and_zero) + transposed, = tensor_attributes + return cls(packed_weight, scale_and_zero, transposed) @classmethod def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8): @@ -276,7 +277,7 @@ def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8): scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) - return cls(packed_weight, scale_and_zero) + return cls(packed_weight, scale_and_zero, False) def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) @@ -285,7 +286,8 @@ def to(self, *args, **kwargs): raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device") return self.__class__( self.packed_weight.to(kwargs["device"]), - self.scale_and_zero.to(kwargs["device"]) + self.scale_and_zero.to(kwargs["device"]), + self.transposed ) def _apply_fn_to_data(self, fn): @@ -293,14 +295,6 @@ def _apply_fn_to_data(self, fn): self.scale_and_zero = fn(self.scale_and_zero) return self - def _change_shape(self, shape): - # int_data, scale, zero = self.get_plain() - # int_data = int_data.view(shape) - # changed = self.from_plain(int_data, scale, zero) - # return changed - # TODO: changing shape is no-op for int4 packed weight right now - return self - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs @@ -310,10 +304,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) - if func is aten.view.default: - assert len(args) == 2 - new = args[0]._change_shape(args[1]) - return return_and_correct_aliasing(func, args, kwargs, new) + if func is aten.t.default: + """we don't need to repack the weight and just rely on external + shape being changed and record the status of transpose/no-transpose + """ + args[0].transposed = not args[0].transposed + return return_and_correct_aliasing(func, args, kwargs, args[0]) raise NotImplementedError( f"TensorCoreTiledAQTLayout dispatch: attempting to run {func}, this is not supported" @@ -329,8 +325,7 @@ def get_plain(self): ) cur_shape = self.shape assert len(cur_shape) == 4 - # TODO: expose the arg - inner_k_tiles = self.cur_shape[-1] * 2 + inner_k_tiles = cur_shape[-1] * 2 original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16)) eye_shape = original_shape[1] block_size = (1, 32) @@ -557,11 +552,6 @@ def _apply_fn_to_data(self, fn): strides=self.stride(), ) - def _change_shape(self, shape, block_size): - return self.__class__( - self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride() - ) - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): # Note: we only added cpu path here for 8da4w, this is for executorch, in the future @@ -583,13 +573,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" ) -def _quantized_linear_op(input_tensor, weight_qtensor, bias, _from_flinear=True): - # TODO: the old tensor subclass can use the single implementation for both F.linear dispatch - # and aten.addmm/aten.mm dispatch because `_change_shape` is not implmeneted correctly (got ignored - # for the int_data), this makes the dimension for weight_qtensor indeterministic, we need to fix - # the issue and make sure we have a clear accepted dimension for `_quantized_linear_op` - # after that we can remove _from_linear flag +def _quantized_linear_op(input_tensor, weight_qtensor, bias): + """ + Quantized version of F.linear operator + Args: + input_tensor: dimension is (batch_size, in_features) + weight_tensor: dimension is (out_features, in_features) + bias: dimension is (out_features,) + """ is_cuda = weight_qtensor.is_cuda is_cpu = weight_qtensor.device == torch.device("cpu") if isinstance(weight_qtensor, AffineQuantizedTensor): @@ -649,9 +641,11 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias, _from_flinear=True) weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and weight_qtensor.layout == "tensor_core_tiled" ): - if not _from_flinear: - weight_qtensor = weight_qtensor.t() assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}" + assert input_tensor.shape[-1] == weight_qtensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to match weight_tensor shape: {weight_qtensor.shape} second dim " + ) # TODO: check groupsize quantization # avoid circular dep, TODO: move this to a common util.py @@ -747,7 +741,8 @@ def aten_mm(func, *args, **kwargs): args[0], ) try: - return _quantized_linear_op(input_tensor, weight_tensor, bias, _from_flinear=False) + weight_tensor = weight_tensor.t() + return _quantized_linear_op(input_tensor, weight_tensor, bias) except: if isinstance(input_tensor, AffineQuantizedTensor): input_tensor = input_tensor.dequantize() @@ -761,7 +756,8 @@ def aten_mm(func, *args, **kwargs): None ) try: - return _quantized_linear_op(input_tensor, weight_tensor, bias, _from_flinear=False) + weight_tensor = weight_tensor.t() + return _quantized_linear_op(input_tensor, weight_tensor, bias) except: if isinstance(input_tensor, AffineQuantizedTensor): input_tensor = input_tensor.dequantize() @@ -797,7 +793,11 @@ def t(func, *args, **kwargs): block_size = args[0].block_size assert len(block_size) == 2 transposed_block_size = (block_size[1], block_size[0]) - new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size) + tensor = args[0] + shape = tensor.shape[::-1] + new = tensor.__class__( + tensor.layout_tensor.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() + ) return return_and_correct_aliasing(func, args, kwargs, new) to_aq = AffineQuantizedTensor.from_float