From e31b5757c5068caa019c283722835bb29816781e Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 17 Jul 2024 13:02:14 -0700
Subject: [PATCH] Add static quantization as an example for calibration flow
 (#487)

Summary:
So far the quantization flow API we provide (`quantize_`) does not require
calibration (calibrating a model with sample data). This PR adds a static
quantization example that serves as an example for the calibration flow:

* 1. first prepare the model for calibration
* 2. calibrate the prepared model with sample data
* 3. convert the calibrated model to a quantized model

Test Plan:
python tutorials/calibration_flow/static_quant.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/dtypes/test_affine_quantized.py       |   6 +-
 torchao/dtypes/__init__.py                 |   2 +
 torchao/dtypes/affine_quantized_tensor.py  |  36 ++++-
 torchao/prototype/quant_llm/quant_llm.py   |   5 +-
 torchao/quantization/quant_api.py          |  21 +--
 tutorials/calibration_flow/static_quant.py | 145 +++++++++++++++++++++
 6 files changed, 202 insertions(+), 13 deletions(-)
 create mode 100644 tutorials/calibration_flow/static_quant.py

diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 7d37af3e04..a938e46e2e 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -14,10 +14,12 @@ class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
     def test_tensor_core_layout_transpose(self):
-        t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        t = l.weight
         shape = t.shape
         apply_int4_weight_only_quant = int4_weight_only(group_size=32)
-        aqt = apply_int4_weight_only_quant(t)
+        ql = apply_int4_weight_only_quant(l)
+        aqt = ql.weight
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index e72b89156f..39372fe27f 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -4,6 +4,7 @@
 from .affine_quantized_tensor import (
     AffineQuantizedTensor,
     to_affine_quantized,
+    to_affine_quantized_static,
     LayoutType,
     PlainLayoutType,
     TensorCoreTiledLayoutType,
@@ -15,6 +16,7 @@
     "UInt4Tensor"
     "AffineQuantizedTensor",
     "to_affine_quantized",
+    "to_affine_quantized_static",
     "LayoutType",
     "PlainLayoutType",
     "TensorCoreTiledLayoutType",
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index f0a0affbcb..513ba2a8ca 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -232,6 +232,7 @@ def from_float(
         scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
         int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+        int_data = layout_type.post_process(int_data)

         layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
@@ -246,8 +247,40 @@ def from_float(
             dtype=input_float.dtype
         )

+    @classmethod
+    def from_float_static(
+        cls,
+        input_float: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        block_size: Tuple[int, ...],
+        target_dtype: torch.dtype,
+        quant_min: Optional[int] = None,
+        quant_max: Optional[int] = None,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        layout_type: LayoutType = PlainLayoutType(),
+    ):
+        original_shape = input_float.shape
+        input_float = layout_type.pre_process(input_float)
+
+        int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
+
+        int_data = layout_type.post_process(int_data)
+
+        layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
+        layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
+        return cls(
+            layout_tensor,
+            block_size,
+            original_shape,
+            quant_min,
+            quant_max,
+            zero_point_domain,
+            dtype=input_float.dtype,
+        )
+
     @property
-    def layout_type(self) -> str:
+    def layout_type(self) -> LayoutType:
         return self.layout_tensor.layout_type

     @classmethod
@@ -809,3 +842,4 @@ def t(func, *args, **kwargs):
     return return_and_correct_aliasing(func, args, kwargs, new)

 to_affine_quantized = AffineQuantizedTensor.from_float
+to_affine_quantized_static = AffineQuantizedTensor.from_float_static
diff --git a/torchao/prototype/quant_llm/quant_llm.py b/torchao/prototype/quant_llm/quant_llm.py
index 33cf1b47b7..3a5dafb52a 100644
--- a/torchao/prototype/quant_llm/quant_llm.py
+++ b/torchao/prototype/quant_llm/quant_llm.py
@@ -7,6 +7,7 @@
 from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones
 from torchao.ops import quant_llm_linear
 from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
+from torchao.quantization.quant_api import _get_linear_subclass_inserter


 _ONES_TABLE = [_n_ones(i) for i in range(8)]
@@ -456,8 +457,8 @@ def apply_quant_llm(weight: Tensor) -> Tensor:
         if (in_dim % 64 != 0) or (out_dim % 256 != 0):
             return weight
         return QuantLlmLinearWeight.from_float(weight, ebits, mbits)
-    return apply_quant_llm
+    return _get_linear_subclass_inserter(apply_quant_llm)


 def fp6_llm_weight_only():
-    return quant_llm_fpx_weight_only(3, 2)
+    return _get_linear_subclass_inserter(quant_llm_fpx_weight_only(3, 2))
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 095dbde0b0..f45baaf8a5 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -259,12 +259,12 @@ def insert_subclass(lin):

     return insert_subclass

-def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Tensor], torch.Tensor], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
+def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module], filter_fn: Optional[Callable[[torch.nn.Module, str], bool]]=None, set_inductor_config: bool=True):
     """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace

     Args:
         model (torch.nn.Module): input model
-        apply_tensor_subclass (Callable[[torch.Tensor], torch.Tensor]): function that convert a floating point Tensor to a (quantized) tensor subclass instance (e.g. affine quantized tensor instance)
+        apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and returns the module (e.g.
+            convert the weight tensor of linear to affine quantized tensor)
         filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on the weight of the module
         set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
@@ -300,19 +300,24 @@ def quantize_(model: torch.nn.Module, apply_tensor_subclass: Callable[[torch.Ten
             x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
             zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")

+        def apply_weight_quant_to_linear(linear):
+            linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False)
+            return linear
+
         # apply to modules under block0 submodule
         def filter_fn(module: nn.Module, fqn: str) -> bool:
             return isinstance(module, nn.Linear)

         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
-        quantize_(m, apply_weight_quant, filter_fn)
+        quantize_(m, apply_weight_quant_to_linear, filter_fn)

     """
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
+
     _replace_with_custom_fn_if_matches_filter(
         model,
-        _get_linear_subclass_inserter(apply_tensor_subclass),
+        apply_tensor_subclass,
         _is_linear if filter_fn is None else filter_fn,
     )
@@ -356,7 +361,7 @@ def get_per_token_block_size(x):
         weight = to_linear_act_quantized(weight, input_quant_func)
         return weight

-    return apply_int8_dynamic_activation_int4_weight_quant
+    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant)


 def int4_weight_only(group_size=128, inner_k_tiles=8):
@@ -394,7 +399,7 @@ def apply_int4_weight_only_quant(weight):
         layout_type = TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles)
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain, layout_type=layout_type)

-    return apply_int4_weight_only_quant
+    return _get_linear_subclass_inserter(apply_int4_weight_only_quant)


 def int8_weight_only():
@@ -412,7 +417,7 @@ def apply_int8wo_quant(weight):
         block_size = (1, weight.shape[1])
         return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)

-    return apply_int8wo_quant
+    return _get_linear_subclass_inserter(apply_int8wo_quant)

 def int8_dynamic_activation_int8_weight():
     """
@@ -454,4 +459,4 @@ def get_per_token_block_size(x):
         weight = to_linear_act_quantized(weight, input_quant_func)
         return weight

-    return apply_int8_dynamic_activation_int8_weight_quant
+    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int8_weight_quant)
diff --git a/tutorials/calibration_flow/static_quant.py b/tutorials/calibration_flow/static_quant.py
new file mode 100644
index 0000000000..a546c5ab89
--- /dev/null
+++ b/tutorials/calibration_flow/static_quant.py
@@ -0,0 +1,145 @@
+"""
+Demo for static quantization flow
+"""
+import torch
+import copy
+
+# TODO: use the generalized observer for affine quantization in the future
+from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
+import torch.nn.functional as F
+from torch import Tensor
+from torchao.dtypes import to_affine_quantized_static
+from torchao.quantization.utils import compute_error
+from torchao.quantization import quantize_
+from torchao.quantization.subclass import to_linear_act_quantized
+from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+
+
+class ObservedLinear(torch.nn.Linear):
+    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, bias: bool = True, device=None, dtype=None):
+        super().__init__(in_features, out_features, bias, device, dtype)
+        self.act_obs = act_obs
+        self.weight_obs = weight_obs
+
+    def forward(self, input: Tensor):
+        observed_input = self.act_obs(input)
+        observed_weight = self.weight_obs(self.weight)
+        return F.linear(observed_input, observed_weight, self.bias)
+
+    @classmethod
+    def from_float(cls, float_linear, act_obs, weight_obs):
+        observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, weight_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
+        observed_linear.weight = float_linear.weight
+        observed_linear.bias = float_linear.bias
+        return observed_linear
+
+def insert_observers_(model, act_obs, weight_obs):
+    _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)
+    replacement_fn = lambda m: ObservedLinear.from_float(m, act_obs, weight_obs)
+    act_obs = copy.deepcopy(act_obs)
+    weight_obs = copy.deepcopy(weight_obs)
+    _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)
+
+# converting observed linear module to a linear module with quantized weights (and quantized activations)
+# with tensor subclasses
+def apply_static_quant(observed_linear):
+    target_dtype = torch.uint8
+
+    # weight quantization
+    weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
+    def weight_quant_func(weight):
+        block_size = (1, weight.shape[1])
+        return to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
+    linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
+    linear.weight = observed_linear.weight
+    linear.bias = observed_linear.bias
+
+    linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)
+
+    # activation quantization
+    act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
+    input_quant_func = lambda x: to_affine_quantized_static(x, act_scale, act_zero_point, x.shape, target_dtype)
+    linear.weight = torch.nn.Parameter(to_linear_act_quantized(linear.weight, input_quant_func), requires_grad=False)
+
+    return linear
+
+
+# alternative for converting observed linear module to quantized linear module
+class QuantizedLinear(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor):
+        super().__init__()
+        self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
+        weight_scale, weight_zero_point = weight_obs.calculate_qparams()
+        assert weight.dim() == 2
+        block_size = (1, weight.shape[1])
+        target_dtype = torch.uint8
+        self.qweight = to_affine_quantized_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
+        self.bias = bias
+
+    def forward(self, input: Tensor):
+        block_size = input.shape
+        target_dtype = torch.uint8
+        qinput = to_affine_quantized_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype)
+        return F.linear(qinput, self.qweight, self.bias)
+
+    @classmethod
+    def from_observed(cls, observed_linear):
+        quantized_linear = cls(observed_linear.in_features, observed_linear.out_features,
+                               observed_linear.act_obs, observed_linear.weight_obs, observed_linear.weight, observed_linear.bias)
+        return quantized_linear
+
+def apply_static_quant2(observed_linear):
+    return QuantizedLinear.from_observed(observed_linear)
+
+class ToyLinearModel(torch.nn.Module):
+    def __init__(self, m=64, n=32, k=64):
+        super().__init__()
+        self.linear1 = torch.nn.Linear(m, n, bias=False)
+        self.linear2 = torch.nn.Linear(n, k, bias=False)
+
+    def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
+        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        return x
+
+dtype = torch.bfloat16
+m = ToyLinearModel(1024, 1024, 1024).eval().to(dtype).to("cuda")
+m_bf16 = copy.deepcopy(m)
+example_inputs = m.example_inputs(dtype=dtype, device="cuda")
+
+m_bf16 = torch.compile(m_bf16, mode='max-autotune')
+
+# TODO: use the generalized observer for affine quantization in the future
+act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda")
+weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda")
+
+before_quant = m(*example_inputs)
+
+insert_observers_(m, act_obs, weight_obs)
+# calibrating / training
+for _ in range(10):
+    m(*example_inputs)
+
+after_obs = m(*example_inputs)
+
+m2 = copy.deepcopy(m)
+
+is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
+
+# quantized linear represented as an nn.Linear with modified tensor subclass weights
+# for both activation and weight quantization
+quantize_(m, apply_static_quant, is_observed_linear)
+print("quantized model (applying tensor subclass to weight):", m)
+after_quant = m(*example_inputs)
+assert compute_error(before_quant, after_quant) > 30
+print("test passed")

+# quantized linear as a standalone module
+quantize_(m2, apply_static_quant2, is_observed_linear)
+print("quantized model (quantized module):", m2)
+after_quant = m2(*example_inputs)
+assert compute_error(before_quant, after_quant) > 30
+print("test passed")
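
Quick-reference sketch of the three-step flow exercised by the new tutorial (prepare -> calibrate -> convert). This is a condensed recap of tutorials/calibration_flow/static_quant.py above, not a separate API: ObservedLinear, insert_observers_, apply_static_quant, and ToyLinearModel are helpers defined in that script (not exported from the torchao package), so the snippet assumes those definitions are already in scope.

import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torchao.quantization import quantize_

# 1. prepare: swap each nn.Linear for an ObservedLinear that records activation/weight statistics
model = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
act_obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_affine).to("cuda")
weight_obs = PerChannelMinMaxObserver(dtype=torch.uint8, qscheme=torch.per_channel_affine).to("cuda")
insert_observers_(model, act_obs, weight_obs)

# 2. calibrate: run representative sample data through the observed model
for _ in range(10):
    model(*model.example_inputs(dtype=torch.bfloat16, device="cuda"))

# 3. convert: replace each ObservedLinear with a statically quantized linear
quantize_(model, apply_static_quant, lambda m, fqn: isinstance(m, ObservedLinear))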