From 41a40cbb3d46b532bc1fa1fa82785c4b221aacf4 Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Wed, 2 Oct 2024 15:09:55 -0700
Subject: [PATCH] Subclass API (#966)

Summary: Adds new int8_dynamic_activation_intx_weight quantization with subclass API

Differential Revision: D62464487
---
 ...8bit_act_xbit_weight_subclass_quantizer.py | 408 ++++++++++++++++++
 torchao/experimental/quant_api.py             |  49 +++
 ...8bit_act_xbit_weight_subclass_quantizer.py | 153 +++++++
 torchao/quantization/quant_primitives.py      |  17 +-
 4 files changed, 620 insertions(+), 7 deletions(-)
 create mode 100644 torchao/experimental/_linear_8bit_act_xbit_weight_subclass_quantizer.py
 create mode 100644 torchao/experimental/tests/test_linear_8bit_act_xbit_weight_subclass_quantizer.py

diff --git a/torchao/experimental/_linear_8bit_act_xbit_weight_subclass_quantizer.py b/torchao/experimental/_linear_8bit_act_xbit_weight_subclass_quantizer.py
new file mode 100644
index 0000000000..003cf687b8
--- /dev/null
+++ b/torchao/experimental/_linear_8bit_act_xbit_weight_subclass_quantizer.py
@@ -0,0 +1,408 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+from typing import List, Optional, Tuple
+
+import torch
+from torch.ao.quantization.fx._decomposed import (
+    dequantize_per_channel_group,
+    quantize_per_channel_group,
+)
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from torchao.dtypes.affine_quantized_tensor import (
+    AQTLayout,
+    register_aqt_quantized_linear_dispatch,
+    register_layout_cls,
+)
+from torchao.dtypes.utils import LayoutType
+from torchao.quantization.quant_primitives import (
+    choose_qparams_affine,
+    MappingType,
+    ZeroPointDomain,
+)
+from torchao.utils import TorchAOBaseTensor
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+import sys
+
+handler = logging.StreamHandler(sys.stdout)
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+
+
+_VALID_TARGETS: List[str] = ["native", "fallback"]
+
+
+def _target_equal(target, other):
+    assert other in _VALID_TARGETS
+    assert target in _VALID_TARGETS
+    return target == other
+
+
+# This format is intended for use with int8 dynamic quantization
+class IntxWeightLayoutType(LayoutType):
+    nbit: int
+    group_size: int
+
+    # The target platform for the layout, either 'native' or 'fallback'.
+    target: str
+
+    def __init__(
+        self,
+        nbit: int,
+        group_size: int,
+        target: str,
+    ):
+        assert nbit <= 7
+        self.nbit = nbit
+        self.group_size = group_size
+        assert target in _VALID_TARGETS
+        self.target = target
+
+    def extra_repr(self):
+        return f"nbit={self.nbit}, group_size={self.group_size}, target={self.target}"
+
+
+def _pack_weights_native(
+    int_data: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    layout_type: LayoutType,
+):
+    assert isinstance(layout_type, IntxWeightLayoutType)
+    assert _target_equal(layout_type.target, "native")
+    nbit = layout_type.nbit
+    group_size = layout_type.group_size
+    has_weight_zeros = zero_point is not None
+
+    if has_weight_zeros:
+        args = [
+            int_data.to(torch.int8),
+            scale.reshape(-1),
+            zero_point.reshape(-1).to(torch.int8),
+            torch.empty(0, group_size, dtype=torch.int8),
+        ]
+    else:
+        args = [
+            int_data.to(torch.int8),
+            scale.reshape(-1),
+            torch.empty(0, group_size, dtype=torch.int8),
+        ]
+
+    wzp_suffix = "" if has_weight_zeros else "0zp"
+    return getattr(torch.ops.torchao, f"_pack_8bit_act_{nbit}bit{wzp_suffix}_weight")(
+        *args
+    )
+
+
+@register_layout_cls(IntxWeightLayoutType)
+class IntxWeightAQTLayout(AQTLayout):
+    def __new__(
+        cls,
+        packed_weight: torch.Tensor,
+        scale: Optional[torch.Tensor],
+        zero_point: Optional[torch.Tensor],
+        layout_type: LayoutType,
+    ):
+        kwargs = {}
+        kwargs["device"] = packed_weight.device
+        kwargs["dtype"] = packed_weight.dtype
+        assert not packed_weight.requires_grad
+        kwargs["requires_grad"] = False
+        shape = packed_weight.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        packed_weight: torch.Tensor,
+        scale: Optional[torch.Tensor],
+        zero_point: Optional[torch.Tensor],
+        layout_type: LayoutType,
+    ):
+        assert isinstance(layout_type, IntxWeightLayoutType)
+
+        # In the native case, scale and zero_point information is inside
+        # the packed_weight
+        if _target_equal(layout_type.target, "native"):
+            assert scale is None
+            assert zero_point is None
+
+        self.packed_weight = packed_weight
+        self.scale = scale
+        self.zero_point = zero_point
+        self.layout_type = layout_type
+
+    def __repr__(self):
+        layout_type = self.get_layout_type()
+        return f"{self.__class__.__name__}(packed_weight={str(self.packed_weight)}, scale={str(self.scale)}, zero_point={str(self.zero_point)}, layout_type={layout_type})"
+
+    def get_layout_type(self) -> LayoutType:
+        return self.layout_type
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        layout_type: LayoutType,
+    ):
+        assert isinstance(layout_type, IntxWeightLayoutType)
+
+        try:
+            if _target_equal(layout_type.target, "native"):
+                packed_weight = _pack_weights_native(
+                    int_data, scale, zero_point, layout_type
+                )
+                scale = None
+                zero_point = None
+                return cls(packed_weight, scale, zero_point, layout_type)
+        except Exception as e:
+            logger.warning(
+                f"A failure occurred when packing weights with IntxWeightLayoutType.target={layout_type.target}: {e}\n"
+                + "Falling back to **slow** implementation IntxWeightLayoutType.target=fallback."
+            )
+            layout_type.target = "fallback"
+
+        # Fallback
+        assert _target_equal(layout_type.target, "fallback")
+        packed_weight = int_data.to(torch.int8)
+        return cls(packed_weight, scale, zero_point, layout_type)
+
+    def _apply_fn_to_data(self, fn):
+        self.packed_weight = fn(self.packed_weight)
+        if self.scale is not None:
+            self.scale = fn(self.scale)
+
+        if self.zero_point is not None:
+            self.zero_point = fn(self.zero_point)
+        return self
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is torch.ops.aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+        if func is torch.ops.aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        raise NotImplementedError(
+            f"IntxWeightAQTLayout dispatch: attempting to run {func}, this is not supported"
+        )
+
+    def __tensor_flatten__(self):
+        if _target_equal(self.get_layout_type().target, "native"):
+            return ["packed_weight"], [self.get_layout_type()]
+
+        # fallback
+        assert _target_equal(self.get_layout_type().target, "fallback")
+        if self.zero_point is None:
+            return ["packed_weight", "scale"], [self.get_layout_type()]
+        return ["packed_weight", "scale", "zero"], [self.get_layout_type()]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        packed_weight, scale, zero_point = (
+            tensor_data_dict["packed_weight"],
+            tensor_data_dict.get("scale", None),
+            tensor_data_dict.get("zero_point", None),
+        )
+        (layout_type,) = tensor_attributes
+        return cls(packed_weight, scale, zero_point, layout_type)
+
+
+def _linear_int8_dynamic_activation_intx_weight_check(
+    input_tensor, weight_tensor, bias
+):
+    layout_type = weight_tensor.layout_tensor.layout_type
+    return isinstance(layout_type, IntxWeightLayoutType) and bias is None
+
+
+def _linear_int8_dynamic_activation_intx_weight_fallback_impl(
+    input_tensor, weight_tensor, bias
+):
+    assert weight_tensor.layout_tensor.layout_type.target == "fallback"
+    assert bias is None
+
+    def _impl_2d(input_tensor, weight_tensor):
+        assert input_tensor.dim() == 2
+        assert weight_tensor.dim() == 2
+
+        weight_qvals = weight_tensor.layout_tensor.packed_weight.to(torch.int32)
+        weight_scales = weight_tensor.layout_tensor.scale
+        weight_zeros = weight_tensor.layout_tensor.zero_point
+        group_size = weight_tensor.layout_tensor.layout_type.group_size
+        has_weight_zeros = weight_zeros is not None
+        m, k = input_tensor.shape
+        n, k_ = weight_tensor.shape
+        assert k_ == k
+
+        weights_dequantized = dequantize_per_channel_group(
+            w_int8=weight_qvals,
+            scales=weight_scales,
+            zero_points=(
+                weight_zeros if has_weight_zeros else torch.zeros_like(weight_scales)
+            ),
+            quant_min=None,  # TODO: why is this an arg for this function
+            quant_max=None,  # TODO: why is this an arg for this function
+            dtype=None,  # TODO: why is this an arg for this function
+            group_size=group_size,
+            output_dtype=torch.float32,
+        )
+
+        # Quantize activations
+        activation_scales, activation_zeros = choose_qparams_affine(
+            input=input_tensor,
+            mapping_type=MappingType.ASYMMETRIC,
+            block_size=(1, k),
+            target_dtype=torch.int32,
+            quant_min=-128,
+            quant_max=127,
+            eps=0.0,
+            scale_dtype=torch.float32,
+            zero_point_dtype=torch.int32,
+            preserve_zero=True,
+            zero_point_domain=ZeroPointDomain.INT,
+        )
+        activation_qvals = quantize_per_channel_group(
+            input=input_tensor,
+            scales=activation_scales,
+            zero_points=activation_zeros,
+            quant_min=-128,
+            quant_max=127,
+            dtype=torch.int8,
+            group_size=k,
+        )
+        activations_dequantized = dequantize_per_channel_group(
+            w_int8=activation_qvals,
+            scales=activation_scales,
+            zero_points=activation_zeros,
+            quant_min=None,  # TODO: why is this an arg for this function
+            quant_max=None,  # TODO: why is this an arg for this function
+            dtype=None,  # TODO: why is this an arg for this function
+            group_size=k,
+            output_dtype=torch.float32,
+        )
+
+        return torch.matmul(
+            activations_dequantized, weights_dequantized.transpose(1, 0)
+        )
+
+    if input_tensor.dim() == 2:
+        return _impl_2d(input_tensor, weight_tensor)
+
+    assert input_tensor.dim() >= 3
+    lead_shape = input_tensor.shape[0:-2]
+    m, k = input_tensor.shape[-2], input_tensor.shape[-1]
+    n, k_ = weight_tensor.shape
+    assert k_ == k
+
+    input_tensor = input_tensor.reshape(-1, m, k)
+
+    res = [
+        _impl_2d(input_tensor[i, :, :], weight_tensor)
+        for i in range(input_tensor.shape[0])
+    ]
+    res = torch.stack(res)
+    res = res.reshape(*lead_shape, m, n)
+    return res
+
+
+def _linear_int8_dynamic_activation_intx_weight_native_impl(
+    input_tensor, weight_tensor, bias
+):
+    assert weight_tensor.layout_tensor.layout_type.target == "native"
+    assert bias is None
+
+    def _impl_2d(input_tensor, weight_tensor):
+        assert input_tensor.dim() == 2
+        assert weight_tensor.dim() == 2
+
+        m, k = input_tensor.shape
+        n, k_ = weight_tensor.shape
+        assert k_ == k
+        group_size = weight_tensor.layout_tensor.layout_type.group_size
+        packed_weight = weight_tensor.layout_tensor.packed_weight
+
+        # TODO(T200095131): convert self.n, self.k, self.group_size to
+        # int when supported by AOTI
+        args = (
+            input_tensor,
+            packed_weight,
+            torch.empty(0, group_size, dtype=torch.int8),
+            torch.empty(0, n, dtype=torch.int8),
+            torch.empty(0, k, dtype=torch.int8),
+        )
+
+        has_weight_zeros = weight_tensor.zero_point_domain is not None
+
+        assert len(weight_tensor.block_size) == 2
+        assert weight_tensor.block_size[0] == 1
+        group_size = weight_tensor.block_size[1]
+        assert group_size == weight_tensor.layout_tensor.layout_type.group_size
+        nbit = weight_tensor.layout_tensor.layout_type.nbit
+
+        n, k = weight_tensor.shape
+        m, k_ = input_tensor.shape
+        assert k_ == k
+
+        packed_weight = weight_tensor.layout_tensor.packed_weight
+        wzp_suffix = "" if has_weight_zeros else "0zp"
+        return getattr(
+            torch.ops.torchao, f"_linear_8bit_act_{nbit}bit{wzp_suffix}_weight"
+        )(*args)
+
+    if input_tensor.dim() == 2:
+        return _impl_2d(input_tensor, weight_tensor)
+
+    assert input_tensor.dim() >= 3
+    lead_shape = input_tensor.shape[0:-2]
+    m, k = input_tensor.shape[-2], input_tensor.shape[-1]
+    n, k_ = weight_tensor.shape
+    assert k_ == k
+
+    input_tensor = input_tensor.reshape(-1, m, k)
+
+    res = [
+        _impl_2d(input_tensor[i, :, :], weight_tensor)
+        for i in range(input_tensor.shape[0])
+    ]
+    res = torch.stack(res)
+    res = res.reshape(*lead_shape, m, n)
+    return res
+
+
+def _linear_int8_dynamic_activation_intx_weight_impl(input_tensor, weight_tensor, bias):
+    target = weight_tensor.layout_tensor.layout_type.target
+    if _target_equal(target, "native"):
+        return _linear_int8_dynamic_activation_intx_weight_native_impl(
+            input_tensor, weight_tensor, bias
+        )
+
+    if _target_equal(target, "fallback"):
+        return _linear_int8_dynamic_activation_intx_weight_fallback_impl(
+            input_tensor, weight_tensor, bias
+        )
+
+    assert False, f"Unknown target {target}"
+
+
+register_aqt_quantized_linear_dispatch(
+    _linear_int8_dynamic_activation_intx_weight_check,
+    _linear_int8_dynamic_activation_intx_weight_impl,
+)
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index ac21c75221..c5bdf36d46 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -325,3 +325,52 @@ def quantize(self, model: nn.Module) -> nn.Module:
             },
         )
         return model
+
+
+from _linear_8bit_act_xbit_weight_subclass_quantizer import IntxWeightLayoutType
+from torchao.quantization.quant_api import (
+    _get_linear_subclass_inserter,
+    MappingType,
+    to_affine_quantized_intx,
+    ZeroPointDomain,
+)
+
+
+def int8_dyn_act_intx_weight(
+    group_size: int = 128,
+    nbit: int = 4,
+    has_weight_zeros: bool = False,
+    target: str = "native",
+):
+
+    def apply(weight):
+        assert weight.shape[-1] % group_size == 0
+        use_hqq = False
+        layout_type = IntxWeightLayoutType(
+            nbit=nbit, group_size=group_size, target=target
+        )
+        mapping_type = MappingType.ASYMMETRIC
+        eps = torch.finfo(torch.float32).eps
+        block_size = (1, group_size)
+        target_dtype = torch.int32
+        quant_min = -(1 << (nbit - 1))
+        quant_max = (1 << (nbit - 1)) - 1
+        zero_point_dtype = torch.int8
+        preserve_zero = has_weight_zeros
+        zero_point_domain = ZeroPointDomain.INT if has_weight_zeros else None
+        return to_affine_quantized_intx(
+            weight,
+            mapping_type,
+            block_size,
+            target_dtype,
+            quant_min,
+            quant_max,
+            eps,
+            zero_point_dtype=zero_point_dtype,
+            preserve_zero=preserve_zero,
+            zero_point_domain=zero_point_domain,
+            layout_type=layout_type,
+            use_hqq=use_hqq,
+        )
+
+    return _get_linear_subclass_inserter(apply)
diff --git a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_subclass_quantizer.py b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_subclass_quantizer.py
new file mode 100644
index 0000000000..0a2181b8f7
--- /dev/null
+++ b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_subclass_quantizer.py
@@ -0,0 +1,153 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+import glob
+import os
+
+import sys
+import tempfile
+import unittest
+
+import torch
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+from quant_api import int8_dyn_act_intx_weight
+from torchao.quantization.quant_api import quantize_
+
+from torchao.utils import unwrap_tensor_subclass
+from quant_api import (
+    _Int8DynActIntxWeightQuantizedLinearFallback,
+)
+
+libs = glob.glob("/tmp/cmake-out/torchao/lib/libtorchao_ops_aten.*")
+libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+if len(libs) == 0:
+    print(
+        "Could not find library libtorchao_ops_aten; please run `sh build_torchao_ops.sh aten` in torchao/experimental to build the op library. A slow fallback kernel will be used instead."
+    )
+else:
+    torch.ops.load_library(libs[0])
+
+
+class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
+    def test_accuracy(self):
+        group_size = 128
+        m = 1
+        n = 1071
+        k = 4096
+        activations = torch.randn(m, k, dtype=torch.float32)
+        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
+
+        for nbit in [1, 2, 3, 4, 5, 6, 7]:
+            for has_weight_zeros in [True, False]:
+                print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
+                quantized_model = copy.deepcopy(model)
+                quantize_(
+                    quantized_model,
+                    int8_dyn_act_intx_weight(
+                        group_size=group_size,
+                        nbit=nbit,
+                        has_weight_zeros=has_weight_zeros,
+                    ),
+                )
+
+                quantized_model_reference = copy.deepcopy(model)
+                quantize_(
+                    quantized_model_reference,
+                    int8_dyn_act_intx_weight(
+                        group_size=group_size,
+                        nbit=nbit,
+                        has_weight_zeros=has_weight_zeros,
+                        target="fallback",
+                    ),
+                )
+
+                with torch.no_grad():
+                    result = quantized_model(activations)
+                    expected_result = quantized_model_reference(activations)
+
+                # TODO: remove expected_result2 checks when we deprecate non-subclass API
+                reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
+                reference_impl.quantize_and_pack_weights(
+                    model[0].weight, nbit, group_size, has_weight_zeros
+                )
+                expected_result2 = reference_impl(activations)
+
+                num_mismatch_at_low_tol = 0
+                num_mismatch_at_low_tol2 = 0
+                num_total = result.reshape(-1).shape[0]
+                for i in range(num_total):
+                    actual_val = result.reshape(-1)[i]
+                    expected_val = expected_result.reshape(-1)[i]
+                    expected_val2 = expected_result2.reshape(-1)[i]
+                    self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))
+                    if not torch.allclose(actual_val, expected_val):
+                        num_mismatch_at_low_tol += 1
+
+                    self.assertTrue(torch.allclose(expected_val, expected_val2, atol=1e-2, rtol=1e-1))
+                    if not torch.allclose(expected_val, expected_val2):
+                        num_mismatch_at_low_tol2 += 1
+
+                # Assert at most 5% of entries are not close at a low tolerance
+                self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)
+                self.assertTrue(num_mismatch_at_low_tol2 / num_total <= 0.01)
+
+    def test_export_compile_aoti(self):
+        group_size = 32
+        m = 3
+        k0 = 512
+        k1 = 256
+        k2 = 128
+        k3 = 1024
+        nbit = 4
+        has_weight_zeros = False
+        layers = [
+            torch.nn.Linear(k0, k1, bias=False),
+            torch.nn.Linear(k1, k2, bias=False),
+            torch.nn.Linear(k2, k3, bias=False),
+        ]
+        model = torch.nn.Sequential(*layers)
+        activations = torch.randn(2, 1, m, k0, dtype=torch.float32)
+
+        print("Quantizing model")
+        quantize_(
+            model,
+            int8_dyn_act_intx_weight(
+                group_size=group_size,
+                nbit=nbit,
+                has_weight_zeros=has_weight_zeros,
+                target="native",
+            ),
+        )
+
+        unwrapped_model = copy.deepcopy(model)
+        unwrap_tensor_subclass(model)
+
+        print("Exporting quantized model")
+        exported = torch.export.export(model, (activations,))
+
+        print("Compiling quantized model")
+        compiled = torch.compile(unwrapped_model)
+        with torch.no_grad():
+            compiled(activations)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            print("Exporting quantized model with AOTI")
+            torch._export.aot_compile(
+                model,
+                (activations,),
+                options={"aot_inductor.output_path": f"{tmpdirname}/model.so"},
+            )
+
+            print("Running quantized model in AOTI")
+            fn = torch._export.aot_load(f"{tmpdirname}/model.so", "cpu")
+            fn(activations)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index b1561e4cff..cffb7c3a95 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -300,7 +300,7 @@ def _quantize_affine_no_dtype_cast(
     elif zero_point_domain is None:
         # This case handles quantization for float8 we expect no zero point and no zero point domain
         assert zero_point is None, "zero_point should be None when zero_point_domain is None"
-        quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
+        quant = torch.clamp(torch.round(input * (1.0 / scale)), quant_min, quant_max)
     else:
         assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
@@ -766,13 +766,16 @@ def _choose_qparams_affine(
         assert mapping_type == MappingType.ASYMMETRIC.name
         scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
         scale = torch.clamp(scale, min=eps)
-        if preserve_zero:
-            zero_point = quant_min - torch.round(min_val_neg / scale)
-            zero_point = torch.clamp(zero_point, quant_min, quant_max)
+        if zero_point_domain is None:
+            zero_point = torch.zeros_like(scale)
         else:
-            assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
-            mid_point = (quant_max + quant_min + 1) / 2
-            zero_point = min_val_neg + scale * mid_point
+            if preserve_zero:
+                zero_point = quant_min - torch.round(min_val_neg / scale)
+                zero_point = torch.clamp(zero_point, quant_min, quant_max)
+            else:
+                assert zero_point_domain == ZeroPointDomain.FLOAT.name, "if not preserve_zero, zero_point must be in FLOAT domain"
+                mid_point = (quant_max + quant_min + 1) / 2
+                zero_point = min_val_neg + scale * mid_point
 
     return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype)
 
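
Example usage of the new subclass API (a minimal sketch distilled from the test above; it assumes torchao/experimental is on sys.path so that its quant_api.py is importable, and it uses target="fallback" so no native op library is needed; pass target="native" after building the ops with `sh build_torchao_ops.sh aten`):

    import torch

    # quant_api here is torchao/experimental/quant_api.py; the test above adds
    # torchao/experimental to sys.path before importing it.
    from quant_api import int8_dyn_act_intx_weight
    from torchao.quantization.quant_api import quantize_

    model = torch.nn.Sequential(torch.nn.Linear(4096, 1024, bias=False))
    activations = torch.randn(1, 4096, dtype=torch.float32)

    # Swap each nn.Linear weight for the int8-dynamic-activation / intx-weight
    # subclass. group_size must divide the weight's last dimension and nbit <= 7.
    quantize_(
        model,
        int8_dyn_act_intx_weight(
            group_size=128,
            nbit=4,
            has_weight_zeros=False,
            target="fallback",  # "native" requires the torchao ATen op library
        ),
    )

    with torch.no_grad():
        out = model(activations)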