From 4ecc35de476e4d4b1a52e86624638b68a60604f6 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 15 May 2024 22:21:48 +0000 Subject: [PATCH 01/19] small fixes --- src/compressed_tensors/compressors/pack_quantized.py | 9 +++++---- .../quantization/lifecycle/initialize.py | 4 +++- tests/test_compressors/test_pack_quant.py | 4 ++-- tests/test_quantization/test_configs/test_strategies.py | 1 - 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/compressed_tensors/compressors/pack_quantized.py b/src/compressed_tensors/compressors/pack_quantized.py index 16b6f2a13..3fe9c4ddb 100644 --- a/src/compressed_tensors/compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/pack_quantized.py @@ -41,7 +41,7 @@ class PackedQuantizationCompressor(Compressor): """ COMPRESSION_PARAM_NAMES = [ - "weight", + "weight_packed", "weight_scale", "weight_zero_point", "weight_shape", @@ -73,7 +73,6 @@ def compress( zp = model_state.get(merge_names(prefix, "weight_zero_point"), None) shape = torch.tensor(value.shape) if scale is not None and zp is not None: - # weight is quantized, compress it # weight is quantized, compress it quant_args = model_quant_args[prefix] if can_quantize(value, quant_args): @@ -85,8 +84,10 @@ def compress( args=quant_args, dtype=torch.int8, ) - value = pack_4bit_ints(value.cpu()) + value = pack_4bit_ints(value.cpu()) compressed_dict[merge_names(prefix, "weight_shape")] = shape + compressed_dict[merge_names(prefix, "weight_packed")] = value + continue compressed_dict[name] = value.to("cpu") @@ -116,7 +117,7 @@ def decompress( weight_data[param_name] = f.get_tensor(full_name) if len(weight_data) == len(self.COMPRESSION_PARAM_NAMES): - weight = weight_data["weight"] + weight = weight_data["weight_packed"] original_shape = torch.Size(weight_data["weight_shape"]) unpacked = unpack_4bit_ints(weight, original_shape) decompressed = dequantize( diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 45ad86479..db62d94e9 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -90,7 +90,9 @@ def _initialize_scale_zero_point_observer( device = next(module.parameters()).device # initializes empty scale and zero point parameters for the module - init_scale = Parameter(torch.empty(0, device=device), requires_grad=False) + init_scale = Parameter( + torch.empty(0, dtype=torch.float16, device=device), requires_grad=False + ) module.register_parameter(f"{base_name}_scale", init_scale) init_zero_point = Parameter( diff --git a/tests/test_compressors/test_pack_quant.py b/tests/test_compressors/test_pack_quant.py index 262451651..a20ccd720 100644 --- a/tests/test_compressors/test_pack_quant.py +++ b/tests/test_compressors/test_pack_quant.py @@ -74,10 +74,10 @@ def test_quant_format(shape): assert len(dense_state_dict) + 1 == len(compressed_state_dict) # check compressed and packed - assert compressed_state_dict["dummy.weight"].dtype == torch.int32 + assert compressed_state_dict["dummy.weight_packed"].dtype == torch.int32 expected_rows = shape[0] expected_columns = math.ceil(shape[1] / 8) # round each row up to nearest int32 - assert compressed_state_dict["dummy.weight"].shape == ( + assert compressed_state_dict["dummy.weight_packed"].shape == ( expected_rows, expected_columns, ) diff --git a/tests/test_quantization/test_configs/test_strategies.py b/tests/test_quantization/test_configs/test_strategies.py index 
4e8a2ca5f..fc48ba607 100644 --- a/tests/test_quantization/test_configs/test_strategies.py +++ b/tests/test_quantization/test_configs/test_strategies.py @@ -22,7 +22,6 @@ QuantizationStrategy, apply_quantization_config, ) -from compressed_tensors.quantization.lifecycle.forward import fake_quantize from torch.nn import Linear From d53eac1da8812cc4bd58c28359da837151c58d4b Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 17 May 2024 15:24:01 +0000 Subject: [PATCH 02/19] initial commit --- .../compressors/__init__.py | 1 + .../compressors/fp8_quantized.py | 32 +++++++++++++++++ .../compressors/int_quantized.py | 3 +- .../compressors/model_compressor.py | 18 +++++++++- src/compressed_tensors/config/base.py | 1 + .../quantization/lifecycle/forward.py | 6 ++-- .../quantization/observers/helpers.py | 36 +++++++++++++++---- 7 files changed, 86 insertions(+), 11 deletions(-) create mode 100644 src/compressed_tensors/compressors/fp8_quantized.py diff --git a/src/compressed_tensors/compressors/__init__.py b/src/compressed_tensors/compressors/__init__.py index b6f2c7d61..59beeb5b1 100644 --- a/src/compressed_tensors/compressors/__init__.py +++ b/src/compressed_tensors/compressors/__init__.py @@ -21,3 +21,4 @@ from .model_compressor import ModelCompressor from .pack_quantized import PackedQuantizationCompressor from .sparse_bitmask import BitmaskCompressor, BitmaskTensor +from .fp8_quantized import FloatQuantizationCompressor diff --git a/src/compressed_tensors/compressors/fp8_quantized.py b/src/compressed_tensors/compressors/fp8_quantized.py new file mode 100644 index 000000000..926d2233e --- /dev/null +++ b/src/compressed_tensors/compressors/fp8_quantized.py @@ -0,0 +1,32 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from compressed_tensors.compressors import Compressor, IntQuantizationCompressor +from compressed_tensors.config import CompressionFormat + + +__all__ = ["FloatQuantizationCompressor"] + + + +@Compressor.register(name=CompressionFormat.float_quantized.value) +class FloatQuantizationCompressor(IntQuantizationCompressor): + """ + Compression for quantized FP8 models. Weight of each quantized layer is + converted from its original float type to the format specified by the layer's + quantization scheme. 
+ """ + COMPRESSED_DTYPE = torch.float8_e4m3fn diff --git a/src/compressed_tensors/compressors/int_quantized.py b/src/compressed_tensors/compressors/int_quantized.py index 07d78a4f4..2521a14d9 100644 --- a/src/compressed_tensors/compressors/int_quantized.py +++ b/src/compressed_tensors/compressors/int_quantized.py @@ -41,6 +41,7 @@ class IntQuantizationCompressor(Compressor): """ COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"] + COMPRESSED_DTYPE = torch.int8 def compress( self, @@ -76,7 +77,7 @@ def compress( scale=scale, zero_point=zp, args=quant_args, - dtype=torch.int8, + dtype=self.COMPRESSED_DTYPE, ) compressed_dict[name] = value.to("cpu") diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py index 9d1cf6dfa..f416f104d 100644 --- a/src/compressed_tensors/compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressor.py @@ -17,7 +17,9 @@ import operator import os from typing import Dict, Optional, Union - +import transformers +import torch +import re from compressed_tensors.base import ( COMPRESSION_CONFIG_NAME, QUANTIZATION_CONFIG_NAME, @@ -185,6 +187,9 @@ def compress( compressed_state_dict ) + + transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size + return compressed_state_dict def decompress(self, model_path: str, model: Module): @@ -262,3 +267,14 @@ def _get_weight_arg_mappings(model: Module) -> Dict: quantized_modules_to_args[name] = submodule.quantization_scheme.weights return quantized_modules_to_args + +# HACK: Override the dtype_byte_size function in transformers to support float8 types +# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488 +def new_dtype_byte_size(dtype): + if dtype == torch.bool: + return 1 / 8 + bit_search = re.search(r"[^\d](\d+)_?", str(dtype)) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 \ No newline at end of file diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index e2b0a97e0..f4a6ef6c2 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -26,6 +26,7 @@ class CompressionFormat(Enum): dense = "dense" sparse_bitmask = "sparse-bitmask" int_quantized = "int-quantized" + float_quantized = "float-quantized" pack_quantized = "pack-quantized" diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index e4c627c45..327b91f7d 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -23,6 +23,7 @@ ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.observers.helpers import calculate_range from torch.nn import Module @@ -145,9 +146,8 @@ def _process_quantization( do_quantize: bool = True, do_dequantize: bool = True, ) -> torch.Tensor: - bit_range = 2**args.num_bits - q_max = torch.tensor(bit_range / 2 - 1, device=x.device) - q_min = torch.tensor(-bit_range / 2, device=x.device) + + q_min, q_max = calculate_range(args, x.device) group_size = args.group_size # group diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index 95c038e4c..70e806703 100644 
--- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -15,11 +15,11 @@ from typing import Tuple import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationType from torch import FloatTensor, IntTensor, Tensor -__all__ = ["calculate_qparams"] +__all__ = ["calculate_qparams", "calculate_range"] def calculate_qparams( @@ -37,10 +37,15 @@ def calculate_qparams( max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) device = min_vals.device - bit_range = 2**quantization_args.num_bits - 1 - bit_min = -(bit_range + 1) / 2 - bit_max = bit_min + bit_range - if quantization_args.symmetric: + bit_min, bit_max = calculate_range(quantization_args, device) + bit_range = bit_max - bit_min + + if quantization_args.type == QuantizationType.FLOAT: + #TODO: don't assume symmetric + max_val_pos = torch.max(-min_vals, max_vals) + scales = (bit_max / max_val_pos.clamp(min=1e-12)).float().reciprocal() + zero_points = torch.zeros(scales.shape, device=device, dtype=torch.float8_e4m3fn) + elif quantization_args.symmetric: max_val_pos = torch.max(-min_vals, max_vals) scales = max_val_pos / (float(bit_range) / 2) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) @@ -52,3 +57,22 @@ def calculate_qparams( zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8) return scales, zero_points + +def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: + """ + """ + if quantization_args.type == QuantizationType.INT: + bit_range = 2**quantization_args.num_bits + q_max = torch.tensor(bit_range / 2 - 1, device=device) + q_min = torch.tensor(-bit_range / 2, device=device) + else: # QuantizationType.FLOAT + if quantization_args.num_bits != 8: + raise ValueError( + "Floating point quantization is only supported for 8 bits," + f"got {quantization_args.num_bits}" + ) + fp_range_info = torch.finfo(torch.float8_e4m3fn) + q_max = torch.tensor(fp_range_info.max, device=device) + q_min = torch.tensor(fp_range_info.min, device=device) + + return q_min, q_max From 816a0e1e7cbd5e731b17e0db0216dc34ea041aee Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 21 May 2024 20:44:15 +0000 Subject: [PATCH 03/19] bug fixes --- .../compressors/fp8_quantized.py | 2 +- .../compressors/model_compressor.py | 9 +- .../quantization/lifecycle/forward.py | 23 ++- .../quantization/lifecycle/initialize.py | 14 +- .../quantization/observers/helpers.py | 18 ++- tests/test_compressors/test_fp8_quant.py | 139 ++++++++++++++++++ .../test_configs/test_bit_depths.py | 71 ++++++++- 7 files changed, 252 insertions(+), 24 deletions(-) create mode 100644 tests/test_compressors/test_fp8_quant.py diff --git a/src/compressed_tensors/compressors/fp8_quantized.py b/src/compressed_tensors/compressors/fp8_quantized.py index 926d2233e..543f98360 100644 --- a/src/compressed_tensors/compressors/fp8_quantized.py +++ b/src/compressed_tensors/compressors/fp8_quantized.py @@ -21,7 +21,6 @@ __all__ = ["FloatQuantizationCompressor"] - @Compressor.register(name=CompressionFormat.float_quantized.value) class FloatQuantizationCompressor(IntQuantizationCompressor): """ @@ -29,4 +28,5 @@ class FloatQuantizationCompressor(IntQuantizationCompressor): converted from its original float type to the format specified by the layer's quantization scheme. 
""" + COMPRESSED_DTYPE = torch.float8_e4m3fn diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py index f416f104d..0329be1ec 100644 --- a/src/compressed_tensors/compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressor.py @@ -16,10 +16,11 @@ import logging import operator import os +import re from typing import Dict, Optional, Union -import transformers + import torch -import re +import transformers from compressed_tensors.base import ( COMPRESSION_CONFIG_NAME, QUANTIZATION_CONFIG_NAME, @@ -187,7 +188,6 @@ def compress( compressed_state_dict ) - transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size return compressed_state_dict @@ -268,6 +268,7 @@ def _get_weight_arg_mappings(model: Module) -> Dict: return quantized_modules_to_args + # HACK: Override the dtype_byte_size function in transformers to support float8 types # Fix is posted upstream https://github.com/huggingface/transformers/pull/30488 def new_dtype_byte_size(dtype): @@ -277,4 +278,4 @@ def new_dtype_byte_size(dtype): if bit_search is None: raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") bit_size = int(bit_search.groups()[0]) - return bit_size // 8 \ No newline at end of file + return bit_size // 8 diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 04f8e21ea..29085d434 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -17,13 +17,13 @@ from typing import Optional import torch +from compressed_tensors.quantization.observers.helpers import calculate_range from compressed_tensors.quantization.quant_args import ( QuantizationArgs, QuantizationStrategy, ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme -from compressed_tensors.quantization.observers.helpers import calculate_range from torch.nn import Module @@ -146,8 +146,8 @@ def _process_quantization( do_quantize: bool = True, do_dequantize: bool = True, ) -> torch.Tensor: - - q_min, q_max = calculate_range(args, x.device) + + q_min, q_max = calculate_range(args, x.device) group_size = args.group_size if args.strategy == QuantizationStrategy.GROUP: @@ -156,7 +156,11 @@ def _process_quantization( output = torch.zeros_like(x, dtype=scale.dtype) else: output_dtype = dtype if dtype is not None else x.dtype - output = torch.zeros_like(x, dtype=output_dtype) + if output_dtype is torch.float8_e4m3fn: + output = torch.zeros_like(x, dtype=torch.float32) + output = output.to(torch.float8_e4m3fn) + else: + output = torch.zeros_like(x, dtype=output_dtype) # TODO: vectorize the for loop # TODO: fix genetric assumption about the tensor size for computing group @@ -292,8 +296,13 @@ def _quantize( q_max: torch.Tensor, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + + if zero_point.dtype.is_floating_point: + rounded = (x / scale + zero_point).to(torch.float8_e4m3fn).to(x.dtype) + else: + rounded = torch.round(x / scale + zero_point) quantized_value = torch.clamp( - torch.round(x / scale + zero_point), + rounded, q_min, q_max, ) @@ -310,4 +319,8 @@ def _dequantize( scale: torch.Tensor, zero_point: torch.Tensor, ) -> torch.Tensor: + if x_q.dtype is torch.float8_e4m3fn: + return (x_q.to(scale.dtype) - zero_point.to(scale.dtype)) * scale.to( + scale.dtype + ) return (x_q - zero_point) * scale diff 
--git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index db62d94e9..2b3e4878d 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -20,7 +20,10 @@ from compressed_tensors.quantization.lifecycle.forward import ( wrap_module_forward_quantized, ) -from compressed_tensors.quantization.quant_args import QuantizationArgs +from compressed_tensors.quantization.quant_args import ( + QuantizationArgs, + QuantizationType, +) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme from torch.nn import Module, Parameter @@ -91,11 +94,16 @@ def _initialize_scale_zero_point_observer( # initializes empty scale and zero point parameters for the module init_scale = Parameter( - torch.empty(0, dtype=torch.float16, device=device), requires_grad=False + torch.empty(0, dtype=module.weight.dtype, device=device), requires_grad=False ) module.register_parameter(f"{base_name}_scale", init_scale) + zp_dtype = ( + torch.int8 + if quantization_args.type is QuantizationType.INT + else module.weight.dtype + ) init_zero_point = Parameter( - torch.empty(0, device=device, dtype=int), requires_grad=False + torch.empty(0, device=device, dtype=zp_dtype), requires_grad=False ) module.register_parameter(f"{base_name}_zero_point", init_zero_point) diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index 70e806703..6df300dc7 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -15,7 +15,10 @@ from typing import Tuple import torch -from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationType +from compressed_tensors.quantization.quant_args import ( + QuantizationArgs, + QuantizationType, +) from torch import FloatTensor, IntTensor, Tensor @@ -41,10 +44,13 @@ def calculate_qparams( bit_range = bit_max - bit_min if quantization_args.type == QuantizationType.FLOAT: - #TODO: don't assume symmetric + # TODO: don't assume symmetric max_val_pos = torch.max(-min_vals, max_vals) scales = (bit_max / max_val_pos.clamp(min=1e-12)).float().reciprocal() - zero_points = torch.zeros(scales.shape, device=device, dtype=torch.float8_e4m3fn) + zero_points = torch.zeros( + scales.shape, device=device, dtype=torch.float8_e4m3fn + ) + zero_points = zero_points.to(min_vals.dtype) elif quantization_args.symmetric: max_val_pos = torch.max(-min_vals, max_vals) scales = max_val_pos / (float(bit_range) / 2) @@ -58,14 +64,14 @@ def calculate_qparams( return scales, zero_points + def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: - """ - """ + """ """ if quantization_args.type == QuantizationType.INT: bit_range = 2**quantization_args.num_bits q_max = torch.tensor(bit_range / 2 - 1, device=device) q_min = torch.tensor(-bit_range / 2, device=device) - else: # QuantizationType.FLOAT + else: # QuantizationType.FLOAT if quantization_args.num_bits != 8: raise ValueError( "Floating point quantization is only supported for 8 bits," diff --git a/tests/test_compressors/test_fp8_quant.py b/tests/test_compressors/test_fp8_quant.py new file mode 100644 index 000000000..29a19dac8 --- /dev/null +++ b/tests/test_compressors/test_fp8_quant.py @@ -0,0 +1,139 @@ +# Copyright (c) 2021 - present / 
Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +from collections import OrderedDict + +import pytest +import torch +from compressed_tensors import FloatQuantizationCompressor +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationConfig, + QuantizationScheme, + QuantizationStatus, + QuantizationStrategy, + apply_quantization_config, + apply_quantization_status, +) +from compressed_tensors.quantization.lifecycle.forward import fake_quantize +from safetensors.torch import save_file +from torch.nn.modules import Linear, Sequential + + +def get_dummy_quant_config(strategy, group_size=None): + config_groups = { + "group_1": QuantizationScheme( + targets=["Linear"], + weights=QuantizationArgs( + strategy=strategy, type="float", group_size=group_size + ), + ), + } + ignore = ["lm_head"] + quant_config = QuantizationConfig( + config_groups=config_groups, + ignore=ignore, + ) + + return quant_config + + +@pytest.mark.parametrize( + "strategy,group_size,sc,zp", + [ + [QuantizationStrategy.TENSOR, None, 0.01, 0], + [ + QuantizationStrategy.GROUP, + 128, + torch.rand((512, 8, 1)) * 0.01, + torch.zeros((512, 8, 1), dtype=torch.int8), + ], + [ + QuantizationStrategy.CHANNEL, + 128, + torch.rand((512, 1)) * 0.01, + torch.zeros((512, 1), dtype=torch.int8), + ], + ], +) +def test_quant_format(strategy, group_size, sc, zp): + dense_state_dict = { + "dummy.weight": torch.rand((512, 1024)), + "dummy.weight_scale": torch.tensor(sc, dtype=torch.float32), + "dummy.weight_zero_point": torch.tensor(zp, dtype=torch.float32), + } + quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size) + + compressor = FloatQuantizationCompressor(config=quant_config) + quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights} + compressed_state_dict = compressor.compress( + dense_state_dict, model_quant_args=quantized_modules_to_args + ) + + # state_dict params should be the same + assert len(dense_state_dict) == len(compressed_state_dict) + + # check compressed to int8 + assert compressed_state_dict["dummy.weight"].dtype == torch.float8_e4m3fn + assert compressed_state_dict["dummy.weight_scale"].dtype == torch.float32 + assert compressed_state_dict["dummy.weight_zero_point"].dtype == torch.float32 + + +@pytest.mark.parametrize( + "strategy,group_size", + [ + [QuantizationStrategy.TENSOR, None], + # [QuantizationStrategy.GROUP, 128], + [QuantizationStrategy.CHANNEL, None], + ], +) +def test_reload_match(strategy, group_size, tmp_path): + model = Sequential( + OrderedDict( + [ + ("dummy", Linear(512, 1024, bias=None)), + ] + ) + ) + quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size) + apply_quantization_config(model, quant_config) + apply_quantization_status(model, QuantizationStatus.CALIBRATION) + + for _ in range(1): + inputs = torch.rand((512, 512)) + _ = model(inputs) + + compressor = FloatQuantizationCompressor(config=quant_config) + quantized_modules_to_args = { + 
"dummy": quant_config.config_groups["group_1"].weights, + } + compressed_state_dict = compressor.compress( + model.state_dict(), model_quant_args=quantized_modules_to_args + ) + save_file(compressed_state_dict, tmp_path / "model.safetensors") + reconstructed_dense_gen = compressor.decompress(tmp_path) + reconstructed_dense = {} + for name, value in reconstructed_dense_gen: + reconstructed_dense[name] = value + + fake_quant_dummy = fake_quantize( + model.dummy.weight, + scale=model.dummy.weight_scale, + zero_point=model.dummy.weight_zero_point, + args=quantized_modules_to_args["dummy"], + ) + assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy.weight"]) + + shutil.rmtree(tmp_path) diff --git a/tests/test_quantization/test_configs/test_bit_depths.py b/tests/test_quantization/test_configs/test_bit_depths.py index 225509047..5b82da4a5 100644 --- a/tests/test_quantization/test_configs/test_bit_depths.py +++ b/tests/test_quantization/test_configs/test_bit_depths.py @@ -25,10 +25,15 @@ from torch.nn import Linear -def create_config(bit_depth, input_symmetry, weight_symmetry): - weights = QuantizationArgs(num_bits=bit_depth, symmetric=weight_symmetry) +def create_config(bit_depth, quant_type, input_symmetry, weight_symmetry): + print(quant_type) + weights = QuantizationArgs( + num_bits=bit_depth, type=quant_type, symmetric=weight_symmetry + ) if input_symmetry is not None: - inputs = QuantizationArgs(num_bits=bit_depth, symmetric=input_symmetry) + inputs = QuantizationArgs( + num_bits=bit_depth, type=quant_type, symmetric=input_symmetry + ) else: inputs = None @@ -45,11 +50,12 @@ def create_config(bit_depth, input_symmetry, weight_symmetry): @torch.no_grad @pytest.mark.parametrize("bit_depth", [4, 8]) +@pytest.mark.parametrize("quant_type", ["int"]) @pytest.mark.parametrize("input_symmetry", [True, False, None]) @pytest.mark.parametrize("weight_symmetry", [True, False]) -def test_bit_depths(bit_depth, input_symmetry, weight_symmetry): +def test_bit_depths(bit_depth, quant_type, input_symmetry, weight_symmetry): model = Linear(64, 64) - quant_config = create_config(bit_depth, input_symmetry, weight_symmetry) + quant_config = create_config(bit_depth, quant_type, input_symmetry, weight_symmetry) apply_quantization_config(model, quant_config) min = -1 * int(2**bit_depth / 2) @@ -92,3 +98,58 @@ def test_bit_depths(bit_depth, input_symmetry, weight_symmetry): ) assert not torch.any(quantized_weight < min).item() assert not torch.any(quantized_weight > max).item() + + +@torch.no_grad +@pytest.mark.parametrize("bit_depth", [8]) +@pytest.mark.parametrize("quant_type", ["float"]) +@pytest.mark.parametrize("input_symmetry", [True, False, None]) +@pytest.mark.parametrize("weight_symmetry", [True, False]) +def test_fp8(bit_depth, quant_type, input_symmetry, weight_symmetry): + model = Linear(64, 64) + quant_config = create_config(bit_depth, quant_type, input_symmetry, weight_symmetry) + apply_quantization_config(model, quant_config) + + dtype_info = torch.finfo(torch.float8_e4m3fn) + min = dtype_info.min + max = dtype_info.max + + inputs = torch.randn(32, 64) + model(inputs) + assert model.weight_zero_point.dtype == model.weight.dtype == torch.float32 + if input_symmetry is not None: + assert model.input_zero_point.dtype == inputs.dtype == torch.float32 + assert model.input_zero_point >= min + assert model.input_zero_point <= max + + input_max = torch.max(inputs) + input_min = torch.min(inputs) + diff_from_max = abs( + abs(model.input_scale * (max - model.input_zero_point)) - abs(input_max) + ) + 
diff_from_min = abs( + abs(model.input_scale * abs(min - model.input_zero_point)) - abs(input_min) + ) + assert diff_from_max < model.input_scale or diff_from_min < model.input_scale + + assert model.weight_zero_point >= min + assert model.weight_zero_point <= max + + weight_max = torch.max(model.weight) + weight_min = torch.min(model.weight) + diff_from_max = abs( + abs(model.weight_scale * (max - model.weight_zero_point)) - abs(weight_max) + ) + diff_from_min = abs( + abs(model.weight_scale * abs(min - model.weight_zero_point)) - abs(weight_min) + ) + assert diff_from_max < model.weight_scale or diff_from_min < model.weight_scale + + quantized_weight = fake_quantize( + model.weight, + model.weight_scale, + model.weight_zero_point, + model.quantization_scheme.weights, + ) + assert not torch.any(quantized_weight < min).item() + assert not torch.any(quantized_weight > max).item() From 6ca7843412e848402130e6e666f3eede0dac0d13 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 21 May 2024 21:03:33 +0000 Subject: [PATCH 04/19] cleanup --- .../compressors/fp8_quantized.py | 4 ++-- .../compressors/pack_quantized.py | 9 ++++----- .../quantization/lifecycle/forward.py | 16 ++++++++++------ .../quantization/observers/helpers.py | 7 +++---- .../quantization/quant_args.py | 5 ++++- .../quantization/utils/helpers.py | 15 +++++++++++++++ tests/test_compressors/test_pack_quant.py | 4 ++-- 7 files changed, 40 insertions(+), 20 deletions(-) diff --git a/src/compressed_tensors/compressors/fp8_quantized.py b/src/compressed_tensors/compressors/fp8_quantized.py index 543f98360..e654a8bc8 100644 --- a/src/compressed_tensors/compressors/fp8_quantized.py +++ b/src/compressed_tensors/compressors/fp8_quantized.py @@ -13,9 +13,9 @@ # limitations under the License. -import torch from compressed_tensors.compressors import Compressor, IntQuantizationCompressor from compressed_tensors.config import CompressionFormat +from compressed_tensors.quantization import FP8_DTYPE __all__ = ["FloatQuantizationCompressor"] @@ -29,4 +29,4 @@ class FloatQuantizationCompressor(IntQuantizationCompressor): quantization scheme. 
""" - COMPRESSED_DTYPE = torch.float8_e4m3fn + COMPRESSED_DTYPE = FP8_DTYPE diff --git a/src/compressed_tensors/compressors/pack_quantized.py b/src/compressed_tensors/compressors/pack_quantized.py index 3fe9c4ddb..16b6f2a13 100644 --- a/src/compressed_tensors/compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/pack_quantized.py @@ -41,7 +41,7 @@ class PackedQuantizationCompressor(Compressor): """ COMPRESSION_PARAM_NAMES = [ - "weight_packed", + "weight", "weight_scale", "weight_zero_point", "weight_shape", @@ -73,6 +73,7 @@ def compress( zp = model_state.get(merge_names(prefix, "weight_zero_point"), None) shape = torch.tensor(value.shape) if scale is not None and zp is not None: + # weight is quantized, compress it # weight is quantized, compress it quant_args = model_quant_args[prefix] if can_quantize(value, quant_args): @@ -84,10 +85,8 @@ def compress( args=quant_args, dtype=torch.int8, ) - value = pack_4bit_ints(value.cpu()) + value = pack_4bit_ints(value.cpu()) compressed_dict[merge_names(prefix, "weight_shape")] = shape - compressed_dict[merge_names(prefix, "weight_packed")] = value - continue compressed_dict[name] = value.to("cpu") @@ -117,7 +116,7 @@ def decompress( weight_data[param_name] = f.get_tensor(full_name) if len(weight_data) == len(self.COMPRESSION_PARAM_NAMES): - weight = weight_data["weight_packed"] + weight = weight_data["weight"] original_shape = torch.Size(weight_data["weight_shape"]) unpacked = unpack_4bit_ints(weight, original_shape) decompressed = dequantize( diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 29085d434..1cae7de95 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -19,11 +19,13 @@ import torch from compressed_tensors.quantization.observers.helpers import calculate_range from compressed_tensors.quantization.quant_args import ( + FP8_DTYPE, QuantizationArgs, QuantizationStrategy, ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme +from compressed_tensors.quantization.utils import is_float_quantization from torch.nn import Module @@ -156,9 +158,9 @@ def _process_quantization( output = torch.zeros_like(x, dtype=scale.dtype) else: output_dtype = dtype if dtype is not None else x.dtype - if output_dtype is torch.float8_e4m3fn: - output = torch.zeros_like(x, dtype=torch.float32) - output = output.to(torch.float8_e4m3fn) + if output_dtype is FP8_DTYPE: + output = torch.zeros_like(x) + output = output.to(FP8_DTYPE) else: output = torch.zeros_like(x, dtype=output_dtype) @@ -297,10 +299,11 @@ def _quantize( dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + scaled = x / scale + zero_point if zero_point.dtype.is_floating_point: - rounded = (x / scale + zero_point).to(torch.float8_e4m3fn).to(x.dtype) + rounded = scaled.to(FP8_DTYPE).to(x.dtype) else: - rounded = torch.round(x / scale + zero_point) + rounded = torch.round(scaled) quantized_value = torch.clamp( rounded, q_min, @@ -319,7 +322,8 @@ def _dequantize( scale: torch.Tensor, zero_point: torch.Tensor, ) -> torch.Tensor: - if x_q.dtype is torch.float8_e4m3fn: + if is_float_quantization(x_q): + # can't perform arithmetic in fp8 types, need to convert first return (x_q.to(scale.dtype) - zero_point.to(scale.dtype)) * scale.to( scale.dtype ) diff --git a/src/compressed_tensors/quantization/observers/helpers.py 
b/src/compressed_tensors/quantization/observers/helpers.py index 6df300dc7..e01056246 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -16,6 +16,7 @@ import torch from compressed_tensors.quantization.quant_args import ( + FP8_DTYPE, QuantizationArgs, QuantizationType, ) @@ -47,9 +48,7 @@ def calculate_qparams( # TODO: don't assume symmetric max_val_pos = torch.max(-min_vals, max_vals) scales = (bit_max / max_val_pos.clamp(min=1e-12)).float().reciprocal() - zero_points = torch.zeros( - scales.shape, device=device, dtype=torch.float8_e4m3fn - ) + zero_points = torch.zeros(scales.shape, device=device, dtype=FP8_DTYPE) zero_points = zero_points.to(min_vals.dtype) elif quantization_args.symmetric: max_val_pos = torch.max(-min_vals, max_vals) @@ -77,7 +76,7 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: "Floating point quantization is only supported for 8 bits," f"got {quantization_args.num_bits}" ) - fp_range_info = torch.finfo(torch.float8_e4m3fn) + fp_range_info = torch.finfo(FP8_DTYPE) q_max = torch.tensor(fp_range_info.max, device=device) q_min = torch.tensor(fp_range_info.min, device=device) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index f8c82d8af..bad7f85eb 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -15,10 +15,13 @@ from enum import Enum from typing import Any, Dict, Optional +import torch from pydantic import BaseModel, Field, validator -__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"] +__all__ = ["FP8_DTYPE", "QuantizationType", "QuantizationStrategy", "QuantizationArgs"] + +FP8_DTYPE = torch.float8_e4m3fn class QuantizationType(str, Enum): diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 074fccc9d..e74d0552d 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -30,6 +30,7 @@ "calculate_compression_ratio", "get_torch_bit_depth", "can_quantize", + "is_float_quantization", ] _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -182,3 +183,17 @@ def calculate_compression_ratio(model: Module) -> float: total_uncompressed += uncompressed_bits * num_weights return total_uncompressed / total_compressed + + +def is_float_quantization(tensor: torch.Tensor) -> bool: + """ + :param tensor: tensor to check for quantization type + :return: True if a float quantization dtype, false otherwise + """ + if tensor.dtype is torch.float8_e5m2: + return True + + if tensor.dtype is torch.float8_e4m3fn: + return True + + return False diff --git a/tests/test_compressors/test_pack_quant.py b/tests/test_compressors/test_pack_quant.py index a20ccd720..262451651 100644 --- a/tests/test_compressors/test_pack_quant.py +++ b/tests/test_compressors/test_pack_quant.py @@ -74,10 +74,10 @@ def test_quant_format(shape): assert len(dense_state_dict) + 1 == len(compressed_state_dict) # check compressed and packed - assert compressed_state_dict["dummy.weight_packed"].dtype == torch.int32 + assert compressed_state_dict["dummy.weight"].dtype == torch.int32 expected_rows = shape[0] expected_columns = math.ceil(shape[1] / 8) # round each row up to nearest int32 - assert compressed_state_dict["dummy.weight_packed"].shape == ( + assert compressed_state_dict["dummy.weight"].shape == 
( expected_rows, expected_columns, ) From 9fdb7648c46a45882e25e10cc41693e346e3ba63 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 21 May 2024 21:11:33 +0000 Subject: [PATCH 05/19] clarity comments --- src/compressed_tensors/compressors/fp8_quantized.py | 3 +-- src/compressed_tensors/compressors/int_quantized.py | 3 +-- src/compressed_tensors/quantization/lifecycle/forward.py | 7 ++++++- src/compressed_tensors/quantization/observers/helpers.py | 8 +++++++- 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/compressors/fp8_quantized.py b/src/compressed_tensors/compressors/fp8_quantized.py index e654a8bc8..9bf352aa5 100644 --- a/src/compressed_tensors/compressors/fp8_quantized.py +++ b/src/compressed_tensors/compressors/fp8_quantized.py @@ -25,8 +25,7 @@ class FloatQuantizationCompressor(IntQuantizationCompressor): """ Compression for quantized FP8 models. Weight of each quantized layer is - converted from its original float type to the format specified by the layer's - quantization scheme. + converted from its original float type to float8_e4m3fn """ COMPRESSED_DTYPE = FP8_DTYPE diff --git a/src/compressed_tensors/compressors/int_quantized.py b/src/compressed_tensors/compressors/int_quantized.py index 2521a14d9..f4c1fe69b 100644 --- a/src/compressed_tensors/compressors/int_quantized.py +++ b/src/compressed_tensors/compressors/int_quantized.py @@ -36,8 +36,7 @@ class IntQuantizationCompressor(Compressor): """ Integer compression for quantized models. Weight of each quantized layer is - converted from its original float type to the format specified by the layer's - quantization scheme. + converted from its original float type to an int8 """ COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"] diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 1cae7de95..1de856682 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -154,11 +154,16 @@ def _process_quantization( if args.strategy == QuantizationStrategy.GROUP: - if do_dequantize: # if dequantizing the output should be a fp type + if do_dequantize: + # if dequantizing the output should match the original weight dtype, + # which is the same as the scale's output = torch.zeros_like(x, dtype=scale.dtype) else: + # outputting a quantized output, use the dtype passed in as a kwarg if its + # specified, otherwise default to the input type output_dtype = dtype if dtype is not None else x.dtype if output_dtype is FP8_DTYPE: + # zeros_like doesn't support fp8 types directly, workaround output = torch.zeros_like(x) output = output.to(FP8_DTYPE) else: diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index e01056246..234b3d94a 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -65,7 +65,13 @@ def calculate_qparams( def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: - """ """ + """ + Calculated the effective quantization range for the given Quantization Args + + :param quantization_args: quantization args to get range of + :param device: device to store the range to + :return: tuple endpoints for the given quantization range + """ if quantization_args.type == QuantizationType.INT: bit_range = 2**quantization_args.num_bits q_max = 
torch.tensor(bit_range / 2 - 1, device=device) From a2cdba6b2463e1df724dda54fb02b760bcbb20b2 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 22 May 2024 15:15:39 +0000 Subject: [PATCH 06/19] clean up compression classes --- .../compressors/__init__.py | 7 ++- .../compressors/fp8_quantized.py | 31 ----------- .../{int_quantized.py => naive_quantized.py} | 53 ++++++++++++++++--- src/compressed_tensors/config/base.py | 1 + .../quantization/lifecycle/forward.py | 41 ++++++++++---- .../quantization/observers/helpers.py | 3 +- tests/test_compressors/test_fp8_quant.py | 3 +- 7 files changed, 85 insertions(+), 54 deletions(-) delete mode 100644 src/compressed_tensors/compressors/fp8_quantized.py rename src/compressed_tensors/compressors/{int_quantized.py => naive_quantized.py} (76%) diff --git a/src/compressed_tensors/compressors/__init__.py b/src/compressed_tensors/compressors/__init__.py index 59beeb5b1..faeea8e3b 100644 --- a/src/compressed_tensors/compressors/__init__.py +++ b/src/compressed_tensors/compressors/__init__.py @@ -17,8 +17,11 @@ from .base import Compressor from .dense import DenseCompressor from .helpers import load_compressed, save_compressed, save_compressed_model -from .int_quantized import IntQuantizationCompressor from .model_compressor import ModelCompressor +from .naive_quantized import ( + FloatQuantizationCompressor, + IntQuantizationCompressor, + QuantizationCompressor, +) from .pack_quantized import PackedQuantizationCompressor from .sparse_bitmask import BitmaskCompressor, BitmaskTensor -from .fp8_quantized import FloatQuantizationCompressor diff --git a/src/compressed_tensors/compressors/fp8_quantized.py b/src/compressed_tensors/compressors/fp8_quantized.py deleted file mode 100644 index 9bf352aa5..000000000 --- a/src/compressed_tensors/compressors/fp8_quantized.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from compressed_tensors.compressors import Compressor, IntQuantizationCompressor -from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization import FP8_DTYPE - - -__all__ = ["FloatQuantizationCompressor"] - - -@Compressor.register(name=CompressionFormat.float_quantized.value) -class FloatQuantizationCompressor(IntQuantizationCompressor): - """ - Compression for quantized FP8 models. 
Weight of each quantized layer is - converted from its original float type to float8_e4m3fn - """ - - COMPRESSED_DTYPE = FP8_DTYPE diff --git a/src/compressed_tensors/compressors/int_quantized.py b/src/compressed_tensors/compressors/naive_quantized.py similarity index 76% rename from src/compressed_tensors/compressors/int_quantized.py rename to src/compressed_tensors/compressors/naive_quantized.py index f4c1fe69b..9280122fc 100644 --- a/src/compressed_tensors/compressors/int_quantized.py +++ b/src/compressed_tensors/compressors/naive_quantized.py @@ -18,7 +18,11 @@ import torch from compressed_tensors.compressors import Compressor from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization import QuantizationArgs +from compressed_tensors.quantization import ( + FP8_DTYPE, + QuantizationArgs, + QuantizationType, +) from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize from compressed_tensors.quantization.utils import can_quantize from compressed_tensors.utils import get_nested_weight_mappings, merge_names @@ -27,20 +31,24 @@ from tqdm import tqdm -__all__ = ["IntQuantizationCompressor"] +__all__ = [ + "QuantizationCompressor", + "IntQuantizationCompressor", + "FloatQuantizationCompressor", +] _LOGGER: logging.Logger = logging.getLogger(__name__) -@Compressor.register(name=CompressionFormat.int_quantized.value) -class IntQuantizationCompressor(Compressor): +@Compressor.register(name=CompressionFormat.naive_quantized.value) +class QuantizationCompressor(Compressor): """ - Integer compression for quantized models. Weight of each quantized layer is - converted from its original float type to an int8 + Implements naive compression for quantized models. Weight of each + quantized layer is converted from its original float type to the closest Pytorch + type to the type specified by the layer's QuantizationArgs. 
""" COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"] - COMPRESSED_DTYPE = torch.int8 def compress( self, @@ -76,7 +84,7 @@ def compress( scale=scale, zero_point=zp, args=quant_args, - dtype=self.COMPRESSED_DTYPE, + dtype=self._parse_compression_dtype(quant_args), ) compressed_dict[name] = value.to("cpu") @@ -113,3 +121,32 @@ def decompress( zero_point=weight_data["weight_zero_point"], ) yield merge_names(weight_name, "weight"), decompressed + + def _parse_compression_dtype(self, args: QuantizationArgs) -> torch.dtype: + if args.type is QuantizationType.FLOAT: + return FP8_DTYPE + else: # QuantizationType.INT + if args.num_bits <= 8: + return torch.int8 + elif args.num_bits <= 16: + return torch.int16 + else: + return torch.int32 + + +@Compressor.register(name=CompressionFormat.int_quantized.value) +class IntQuantizationCompressor(QuantizationCompressor): + """ + Alias for integer quantized models + """ + + pass + + +@Compressor.register(name=CompressionFormat.float_quantized.value) +class FloatQuantizationCompressor(QuantizationCompressor): + """ + Alias for fp quantized models + """ + + pass diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py index f4a6ef6c2..052130732 100644 --- a/src/compressed_tensors/config/base.py +++ b/src/compressed_tensors/config/base.py @@ -27,6 +27,7 @@ class CompressionFormat(Enum): sparse_bitmask = "sparse-bitmask" int_quantized = "int-quantized" float_quantized = "float-quantized" + naive_quantized = "naive-quantized" pack_quantized = "pack-quantized" diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 1de856682..20b21e188 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -22,6 +22,7 @@ FP8_DTYPE, QuantizationArgs, QuantizationStrategy, + QuantizationType, ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme @@ -195,7 +196,13 @@ def _process_quantization( idx = i * group_size if do_quantize: output[:, idx : (idx + group_size)] = _quantize( - x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype + x[:, idx : (idx + group_size)], + sc, + zp, + q_min, + q_max, + quantization_type=args.type, + dtype=dtype, ) if do_dequantize: input = ( @@ -207,7 +214,15 @@ def _process_quantization( else: # covers channel, token and tensor strategies if do_quantize: - output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype) + output = _quantize( + x, + scale, + zero_point, + q_min, + q_max, + quantization_type=args.type, + dtype=dtype, + ) if do_dequantize: output = _dequantize(output if do_quantize else x, scale, zero_point) @@ -301,19 +316,25 @@ def _quantize( zero_point: torch.Tensor, q_min: torch.Tensor, q_max: torch.Tensor, + quantization_type: Optional[QuantizationType] = QuantizationType.INT, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: scaled = x / scale + zero_point - if zero_point.dtype.is_floating_point: - rounded = scaled.to(FP8_DTYPE).to(x.dtype) + if quantization_type is QuantizationType.FLOAT: + # clamp first because cast isn't saturated + quantized_value = torch.clamp( + scaled, + q_min, + q_max, + ) + quantized_value = scaled.to(FP8_DTYPE).to(x.dtype) else: - rounded = torch.round(scaled) - quantized_value = torch.clamp( - rounded, - q_min, - q_max, - ) + quantized_value = torch.clamp( + 
torch.round(scaled), + q_min, + q_max, + ) if dtype is not None: quantized_value = quantized_value.to(dtype) diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index 234b3d94a..c4727b848 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -47,7 +47,8 @@ def calculate_qparams( if quantization_args.type == QuantizationType.FLOAT: # TODO: don't assume symmetric max_val_pos = torch.max(-min_vals, max_vals) - scales = (bit_max / max_val_pos.clamp(min=1e-12)).float().reciprocal() + scales = max_val_pos / (float(bit_range) / 2) + scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = torch.zeros(scales.shape, device=device, dtype=FP8_DTYPE) zero_points = zero_points.to(min_vals.dtype) elif quantization_args.symmetric: diff --git a/tests/test_compressors/test_fp8_quant.py b/tests/test_compressors/test_fp8_quant.py index 29a19dac8..48359856d 100644 --- a/tests/test_compressors/test_fp8_quant.py +++ b/tests/test_compressors/test_fp8_quant.py @@ -95,7 +95,6 @@ def test_quant_format(strategy, group_size, sc, zp): "strategy,group_size", [ [QuantizationStrategy.TENSOR, None], - # [QuantizationStrategy.GROUP, 128], [QuantizationStrategy.CHANNEL, None], ], ) @@ -111,7 +110,7 @@ def test_reload_match(strategy, group_size, tmp_path): apply_quantization_config(model, quant_config) apply_quantization_status(model, QuantizationStatus.CALIBRATION) - for _ in range(1): + for _ in range(16): inputs = torch.rand((512, 512)) _ = model(inputs) From 1854bb5dd364dbb6e5b5877b0628c47c6ef182b9 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 22 May 2024 19:15:45 +0000 Subject: [PATCH 07/19] fixing zero point issues --- .../quantization/lifecycle/forward.py | 7 ++--- .../quantization/observers/helpers.py | 31 ++++++++++++------- .../quantization/quant_args.py | 21 ++++++++++++- 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 20b21e188..3d77ad84b 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -23,6 +23,7 @@ QuantizationArgs, QuantizationStrategy, QuantizationType, + round_fp8, ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme @@ -328,7 +329,7 @@ def _quantize( q_min, q_max, ) - quantized_value = scaled.to(FP8_DTYPE).to(x.dtype) + quantized_value = round_fp8(quantized_value, FP8_DTYPE) else: quantized_value = torch.clamp( torch.round(scaled), @@ -350,7 +351,5 @@ def _dequantize( ) -> torch.Tensor: if is_float_quantization(x_q): # can't perform arithmetic in fp8 types, need to convert first - return (x_q.to(scale.dtype) - zero_point.to(scale.dtype)) * scale.to( - scale.dtype - ) + return (x_q.to(scale.dtype) - zero_point.to(scale.dtype)) * scale return (x_q - zero_point) * scale diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index c4727b848..2058da51b 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -19,6 +19,7 @@ FP8_DTYPE, QuantizationArgs, QuantizationType, + round_fp8, ) from torch import FloatTensor, IntTensor, Tensor @@ -44,23 +45,29 @@ def 
calculate_qparams( bit_min, bit_max = calculate_range(quantization_args, device) bit_range = bit_max - bit_min - if quantization_args.type == QuantizationType.FLOAT: - # TODO: don't assume symmetric - max_val_pos = torch.max(-min_vals, max_vals) + if quantization_args.symmetric: + max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) scales = max_val_pos / (float(bit_range) / 2) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) - zero_points = torch.zeros(scales.shape, device=device, dtype=FP8_DTYPE) - zero_points = zero_points.to(min_vals.dtype) - elif quantization_args.symmetric: - max_val_pos = torch.max(-min_vals, max_vals) - scales = max_val_pos / (float(bit_range) / 2) - scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) - zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8) + zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) + + # set zero_points to correct types + if quantization_args.type == QuantizationType.FLOAT: + zero_points = round_fp8(zero_points, FP8_DTYPE) + else: # QuantizationType.INT + zero_points = zero_points.to(torch.int8) else: scales = (max_vals - min_vals) / float(bit_range) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) - zero_points = bit_min - torch.round(min_vals / scales) - zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8) + + if quantization_args.type == QuantizationType.FLOAT: + zero_points = bit_min - (min_vals / scales) + zero_points = round_fp8( + torch.clamp(zero_points, bit_min, bit_max), FP8_DTYPE + ) + else: # QuantizationType.INT + zero_points = bit_min - torch.round(min_vals / scales) + zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8) return scales, zero_points diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index bad7f85eb..da5c02e7b 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -19,7 +19,13 @@ from pydantic import BaseModel, Field, validator -__all__ = ["FP8_DTYPE", "QuantizationType", "QuantizationStrategy", "QuantizationArgs"] +__all__ = [ + "FP8_DTYPE", + "QuantizationType", + "QuantizationStrategy", + "QuantizationArgs", + "round_fp8", +] FP8_DTYPE = torch.float8_e4m3fn @@ -126,3 +132,16 @@ def validate_strategy(cls, value, values): return QuantizationStrategy.TENSOR return value + + +def round_fp8(tensor: torch.Tensor, fp8_type: torch.dtype) -> torch.Tensor: + """ + Rounds each element of the input tensor to the nearest fp8 representation, + keeping to original dtype + + :param tensor: tensor to round + :param fp8_type: fp8 dtype to round to + :return: tensor rounded to fp8 + """ + original_dtype = tensor.dtype + return tensor.to(fp8_type).to(original_dtype) From 6a8420ac43ee48d8f36d5fa90e81885a469c9c5d Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Tue, 28 May 2024 15:40:33 +0000 Subject: [PATCH 08/19] comment for hack --- src/compressed_tensors/compressors/model_compressor.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py index 0329be1ec..b7c54344f 100644 --- a/src/compressed_tensors/compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressor.py @@ -188,6 +188,9 @@ def compress( compressed_state_dict ) + # HACK: Override the dtype_byte_size function in transformers to + # support float8 types. 
Fix is posted upstream + # https://github.com/huggingface/transformers/pull/30488 transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size return compressed_state_dict From ef57cf4201ff71a51a9cc978e90fe653178b3fb6 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 29 May 2024 19:46:56 +0000 Subject: [PATCH 09/19] update quant check --- src/compressed_tensors/quantization/utils/helpers.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index e74d0552d..404e1385e 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -188,11 +188,8 @@ def calculate_compression_ratio(model: Module) -> float: def is_float_quantization(tensor: torch.Tensor) -> bool: """ :param tensor: tensor to check for quantization type - :return: True if a float quantization dtype, false otherwise + :return: True if a supported float quantization dtype, false otherwise """ - if tensor.dtype is torch.float8_e5m2: - return True - if tensor.dtype is torch.float8_e4m3fn: return True From 3baefc572429ab71b549d1af174f031d10babe0c Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 29 May 2024 20:40:54 +0000 Subject: [PATCH 10/19] cleanup fp8 dtypes --- .../compressors/naive_quantized.py | 23 +--------- .../compressors/pack_quantized.py | 4 -- .../quantization/lifecycle/apply.py | 7 +-- .../quantization/lifecycle/forward.py | 45 ++++++++++--------- .../quantization/lifecycle/initialize.py | 13 ++---- .../quantization/quant_args.py | 11 +++++ 6 files changed, 43 insertions(+), 60 deletions(-) diff --git a/src/compressed_tensors/compressors/naive_quantized.py b/src/compressed_tensors/compressors/naive_quantized.py index 4ef443754..a9bdff7b0 100644 --- a/src/compressed_tensors/compressors/naive_quantized.py +++ b/src/compressed_tensors/compressors/naive_quantized.py @@ -18,11 +18,7 @@ import torch from compressed_tensors.compressors import Compressor from compressed_tensors.config import CompressionFormat -from compressed_tensors.quantization import ( - FP8_DTYPE, - QuantizationArgs, - QuantizationType, -) +from compressed_tensors.quantization import QuantizationArgs from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize from compressed_tensors.quantization.utils import can_quantize from compressed_tensors.utils import get_nested_weight_mappings, merge_names @@ -84,7 +80,7 @@ def compress( scale=scale, zero_point=zp, args=quant_args, - dtype=self._parse_compression_dtype(quant_args), + dtype=quant_args.pytorch_dtype(), ) elif name.endswith("zero_point"): if torch.all(value == 0): @@ -121,10 +117,6 @@ def decompress( if "weight_scale" in weight_data: zero_point = weight_data.get("weight_zero_point", None) scale = weight_data["weight_scale"] - if zero_point is None: - # zero_point assumed to be 0 if not included in state_dict - zero_point = torch.zeros_like(scale) - decompressed = dequantize( x_q=weight_data["weight"], scale=scale, @@ -132,17 +124,6 @@ def decompress( ) yield merge_names(weight_name, "weight"), decompressed - def _parse_compression_dtype(self, args: QuantizationArgs) -> torch.dtype: - if args.type is QuantizationType.FLOAT: - return FP8_DTYPE - else: # QuantizationType.INT - if args.num_bits <= 8: - return torch.int8 - elif args.num_bits <= 16: - return torch.int16 - else: - return torch.int32 - @Compressor.register(name=CompressionFormat.int_quantized.value) class 
IntQuantizationCompressor(QuantizationCompressor): diff --git a/src/compressed_tensors/compressors/pack_quantized.py b/src/compressed_tensors/compressors/pack_quantized.py index a2f7cd13e..305db56b7 100644 --- a/src/compressed_tensors/compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/pack_quantized.py @@ -125,10 +125,6 @@ def decompress( if "weight_scale" in weight_data: zero_point = weight_data.get("weight_zero_point", None) scale = weight_data["weight_scale"] - if zero_point is None: - # zero_point assumed to be 0 if not included in state_dict - zero_point = torch.zeros_like(scale) - weight = weight_data["weight_packed"] original_shape = torch.Size(weight_data["weight_shape"]) unpacked = unpack_4bit_ints(weight, original_shape) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index e9f16481b..36ff71a17 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -197,10 +197,11 @@ def _load_quant_args_from_state_dict( scale = getattr(module, scale_name, None) zp = getattr(module, zp_name, None) if scale is not None: - scale.data = state_dict[f"{module_name}.{scale_name}"].to(device) + state_dict_scale = state_dict[f"{module_name}.{scale_name}"] + scale.data = state_dict_scale.to(device).to(scale.dtype) if zp is not None: zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None) if zp_from_state is not None: # load the non-zero zero points - zp.data = state_dict[f"{module_name}.{zp_name}"].to(device) + zp.data = zp_from_state.to(device).to(scale.dtype) else: # fill with zeros matching scale shape - zp.data = torch.zeros_like(scale, dtype=torch.int8).to(device) + zp.data = torch.zeros_like(scale, dtype=zp.dtype).to(device) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 715d6c625..c70d55ef7 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -27,7 +27,6 @@ ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme -from compressed_tensors.quantization.utils import is_float_quantization from torch.nn import Module @@ -77,8 +76,9 @@ def quantize( def dequantize( x_q: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor, + zero_point: torch.Tensor = None, args: QuantizationArgs = None, + dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: """ Dequantize a quantized input tensor x_q based on the strategy specified in args. If @@ -106,6 +106,10 @@ def dequantize( f"Could not infer a quantization strategy from scale with {scale.ndim} " "dimmensions. Expected 0 or 2 dimmensions." 
) + + if dtype is None: + dtype = scale.dtype + return _process_quantization( x=x_q, scale=scale, @@ -113,6 +117,7 @@ def dequantize( args=args, do_quantize=False, do_dequantize=True, + dtype=dtype, ) @@ -161,19 +166,8 @@ def _process_quantization( group_size = args.group_size if args.strategy == QuantizationStrategy.GROUP: - if do_dequantize and not do_quantize: - # if dequantizing a quantized type infer the output type from the scale - output = torch.zeros_like(x, dtype=scale.dtype) - else: - # use the dtype passed in as a kwarg if its specified, otherwise default - # to the input type - output_dtype = dtype if dtype is not None else x.dtype - if output_dtype is FP8_DTYPE: - # zeros_like doesn't support fp8 types directly, workaround - output = torch.zeros_like(x) - output = output.to(FP8_DTYPE) - else: - output = torch.zeros_like(x, dtype=output_dtype) + output_dtype = dtype if dtype is not None else x.dtype + output = torch.zeros_like(x).to(output_dtype) # TODO: vectorize the for loop # TODO: fix genetric assumption about the tensor size for computing group @@ -183,7 +177,7 @@ def _process_quantization( while scale.ndim < 2: # pad scale and zero point dims for slicing scale = scale.unsqueeze(1) - zero_point = zero_point.unsqueeze(1) + zero_point = zero_point.unsqueeze(1) if zero_point is not None else None columns = x.shape[1] if columns >= group_size: @@ -196,7 +190,7 @@ def _process_quantization( # scale.shape should be [nchan, ndim] # sc.shape should be [nchan, 1] after unsqueeze sc = scale[:, i].view(-1, 1) - zp = zero_point[:, i].view(-1, 1) + zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None idx = i * group_size if do_quantize: @@ -351,9 +345,16 @@ def _quantize( def _dequantize( x_q: torch.Tensor, scale: torch.Tensor, - zero_point: torch.Tensor, + zero_point: torch.Tensor = None, + dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - if is_float_quantization(x_q): - # can't perform arithmetic in fp8 types, need to convert first - return (x_q.to(scale.dtype) - zero_point.to(scale.dtype)) * scale - return (x_q - zero_point) * scale + + dequant_value = x_q + if zero_point is not None: + dequant_value -= zero_point + dequant_value = dequant_value.to(scale.dtype) * scale + + if dtype is not None: + dequant_value = dequant_value.to(dtype) + + return dequant_value diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 4524b5b04..868bba7f5 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -20,10 +20,7 @@ from compressed_tensors.quantization.lifecycle.forward import ( wrap_module_forward_quantized, ) -from compressed_tensors.quantization.quant_args import ( - QuantizationArgs, - QuantizationType, -) +from compressed_tensors.quantization.quant_args import QuantizationArgs from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme from torch.nn import Module, Parameter @@ -94,15 +91,11 @@ def _initialize_scale_zero_point_observer( # initializes empty scale and zero point parameters for the module init_scale = Parameter( - torch.empty(0, dtype=torch.float16, device=device), requires_grad=False + torch.empty(0, dtype=module.weight.dtype, device=device), requires_grad=False ) module.register_parameter(f"{base_name}_scale", init_scale) - zp_dtype = ( - torch.int8 - if quantization_args.type is 
QuantizationType.INT - else module.weight.dtype - ) + zp_dtype = quantization_args.pytorch_dtype() init_zero_point = Parameter( torch.empty(0, device=device, dtype=zp_dtype), requires_grad=False ) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index da5c02e7b..ced2cb4b9 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -133,6 +133,17 @@ def validate_strategy(cls, value, values): return value + def pytorch_dtype(self) -> torch.dtype: + if self.type is QuantizationType.FLOAT: + return FP8_DTYPE + else: # QuantizationType.INT + if self.num_bits <= 8: + return torch.int8 + elif self.num_bits <= 16: + return torch.int16 + else: + return torch.int32 + def round_fp8(tensor: torch.Tensor, fp8_type: torch.dtype) -> torch.Tensor: """ From 6443bb86f831f0cb70d149f4f223041249eab0c5 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 29 May 2024 21:21:15 +0000 Subject: [PATCH 11/19] cleanup --- .../quantization/lifecycle/forward.py | 36 +++++++------------ .../quantization/observers/helpers.py | 7 ++-- .../quantization/quant_args.py | 19 ++++++---- .../quantization/utils/helpers.py | 12 ------- tests/test_compressors/test_pack_quant.py | 4 +-- .../test_configs/test_bit_depths.py | 6 ++-- 6 files changed, 34 insertions(+), 50 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index c70d55ef7..3b9dfed3e 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -19,11 +19,9 @@ import torch from compressed_tensors.quantization.observers.helpers import calculate_range from compressed_tensors.quantization.quant_args import ( - FP8_DTYPE, QuantizationArgs, QuantizationStrategy, - QuantizationType, - round_fp8, + round_to_quantized_type, ) from compressed_tensors.quantization.quant_config import QuantizationStatus from compressed_tensors.quantization.quant_scheme import QuantizationScheme @@ -200,7 +198,7 @@ def _process_quantization( zp, q_min, q_max, - quantization_type=args.type, + args=args, dtype=dtype, ) if do_dequantize: @@ -219,7 +217,7 @@ def _process_quantization( zero_point, q_min, q_max, - quantization_type=args.type, + args, dtype=dtype, ) if do_dequantize: @@ -315,26 +313,18 @@ def _quantize( zero_point: torch.Tensor, q_min: torch.Tensor, q_max: torch.Tensor, - quantization_type: Optional[QuantizationType] = QuantizationType.INT, + args: QuantizationArgs, dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: - scaled = x / scale + zero_point - if quantization_type is QuantizationType.FLOAT: - # clamp first because cast isn't saturated - quantized_value = torch.clamp( - scaled, - q_min, - q_max, - ) - quantized_value = round_fp8(quantized_value, FP8_DTYPE) - else: - quantized_value = torch.clamp( - torch.round(scaled), - q_min, - q_max, - ) - + scaled = x / scale + zero_point.to(x.dtype) + # clamp first because cast isn't guaranteed to be saturated (ie for fp8) + clamped_value = torch.clamp( + scaled, + q_min, + q_max, + ) + quantized_value = round_to_quantized_type(clamped_value, args) if dtype is not None: quantized_value = quantized_value.to(dtype) @@ -351,7 +341,7 @@ def _dequantize( dequant_value = x_q if zero_point is not None: - dequant_value -= zero_point + dequant_value = dequant_value - zero_point.to(scale.dtype) dequant_value = dequant_value.to(scale.dtype) * scale if dtype is not 
None: diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index 2058da51b..f1236410e 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -19,7 +19,6 @@ FP8_DTYPE, QuantizationArgs, QuantizationType, - round_fp8, ) from torch import FloatTensor, IntTensor, Tensor @@ -53,7 +52,7 @@ def calculate_qparams( # set zero_points to correct types if quantization_args.type == QuantizationType.FLOAT: - zero_points = round_fp8(zero_points, FP8_DTYPE) + zero_points = zero_points.to(FP8_DTYPE) else: # QuantizationType.INT zero_points = zero_points.to(torch.int8) else: @@ -62,9 +61,7 @@ def calculate_qparams( if quantization_args.type == QuantizationType.FLOAT: zero_points = bit_min - (min_vals / scales) - zero_points = round_fp8( - torch.clamp(zero_points, bit_min, bit_max), FP8_DTYPE - ) + zero_points = torch.clamp(zero_points, bit_min, bit_max).to(FP8_DTYPE) else: # QuantizationType.INT zero_points = bit_min - torch.round(min_vals / scales) zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index ced2cb4b9..38ff26d71 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -24,7 +24,7 @@ "QuantizationType", "QuantizationStrategy", "QuantizationArgs", - "round_fp8", + "round_to_quantized_type", ] FP8_DTYPE = torch.float8_e4m3fn @@ -145,14 +145,21 @@ def pytorch_dtype(self) -> torch.dtype: return torch.int32 -def round_fp8(tensor: torch.Tensor, fp8_type: torch.dtype) -> torch.Tensor: +def round_to_quantized_type( + tensor: torch.Tensor, args: QuantizationArgs +) -> torch.Tensor: """ - Rounds each element of the input tensor to the nearest fp8 representation, + Rounds each element of the input tensor to the nearest quantized representation, keeping to original dtype :param tensor: tensor to round - :param fp8_type: fp8 dtype to round to - :return: tensor rounded to fp8 + :param args: QuantizationArgs to pull appropriate dtype from + :return: rounded tensor """ original_dtype = tensor.dtype - return tensor.to(fp8_type).to(original_dtype) + if args.type is QuantizationType.FLOAT: + rounded = tensor.to(FP8_DTYPE) + else: # QuantizationType.INT + rounded = torch.round(tensor) + + return rounded.to(original_dtype) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 404e1385e..074fccc9d 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -30,7 +30,6 @@ "calculate_compression_ratio", "get_torch_bit_depth", "can_quantize", - "is_float_quantization", ] _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -183,14 +182,3 @@ def calculate_compression_ratio(model: Module) -> float: total_uncompressed += uncompressed_bits * num_weights return total_uncompressed / total_compressed - - -def is_float_quantization(tensor: torch.Tensor) -> bool: - """ - :param tensor: tensor to check for quantization type - :return: True if a supported float quantization dtype, false otherwise - """ - if tensor.dtype is torch.float8_e4m3fn: - return True - - return False diff --git a/tests/test_compressors/test_pack_quant.py b/tests/test_compressors/test_pack_quant.py index b111d596b..46bf091e5 100644 --- 
a/tests/test_compressors/test_pack_quant.py +++ b/tests/test_compressors/test_pack_quant.py @@ -110,10 +110,10 @@ def test_reload_match(tmp_path): dense_state_dict = { "dummy.weight": torch.rand((511, 350)), "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32), - "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int32), + "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int8), "dummy2.weight": torch.rand((128, 280)), "dummy2.weight_scale": torch.tensor(0.02, dtype=torch.float32), - "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int32), + "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int8), } quant_config = get_dummy_quant_config() diff --git a/tests/test_quantization/test_configs/test_bit_depths.py b/tests/test_quantization/test_configs/test_bit_depths.py index 5b82da4a5..cf0d16fa6 100644 --- a/tests/test_quantization/test_configs/test_bit_depths.py +++ b/tests/test_quantization/test_configs/test_bit_depths.py @@ -116,9 +116,11 @@ def test_fp8(bit_depth, quant_type, input_symmetry, weight_symmetry): inputs = torch.randn(32, 64) model(inputs) - assert model.weight_zero_point.dtype == model.weight.dtype == torch.float32 + assert model.weight_zero_point.dtype == torch.float8_e4m3fn + model.weight_zero_point.data = model.weight_zero_point.to(model.weight.dtype) if input_symmetry is not None: - assert model.input_zero_point.dtype == inputs.dtype == torch.float32 + assert model.input_zero_point.dtype == torch.float8_e4m3fn + model.input_zero_point.data = model.input_zero_point.to(model.weight.dtype) assert model.input_zero_point >= min assert model.input_zero_point <= max From 3a42557760eef1fdbe6e69d135f3fc65f29811e5 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 29 May 2024 21:26:34 +0000 Subject: [PATCH 12/19] clean up observer --- .../quantization/observers/helpers.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index f1236410e..330859cfd 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -43,28 +43,21 @@ def calculate_qparams( bit_min, bit_max = calculate_range(quantization_args, device) bit_range = bit_max - bit_min + zp_dtype = quantization_args.pytorch_dtype() if quantization_args.symmetric: max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) scales = max_val_pos / (float(bit_range) / 2) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) - - # set zero_points to correct types - if quantization_args.type == QuantizationType.FLOAT: - zero_points = zero_points.to(FP8_DTYPE) - else: # QuantizationType.INT - zero_points = zero_points.to(torch.int8) else: scales = (max_vals - min_vals) / float(bit_range) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + zero_points = bit_min - (min_vals / scales) + zero_points = torch.clamp(zero_points, bit_min, bit_max) - if quantization_args.type == QuantizationType.FLOAT: - zero_points = bit_min - (min_vals / scales) - zero_points = torch.clamp(zero_points, bit_min, bit_max).to(FP8_DTYPE) - else: # QuantizationType.INT - zero_points = bit_min - torch.round(min_vals / scales) - zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8) + # match zero-points to quantized type + zero_points = zero_points.to(zp_dtype) return scales, zero_points 
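A minimal sketch of what the cleaned-up observer helper produces, for reference while reviewing the dtype changes above. It assumes the calculate_qparams(min_vals, max_vals, quantization_args) signature implied by the hunk, the import paths used throughout these patches, and a torch build with float8_e4m3fn support; it is illustrative only, not part of the patch series.

    import torch
    from compressed_tensors.quantization import QuantizationArgs, QuantizationType
    from compressed_tensors.quantization.observers.helpers import calculate_qparams

    min_vals = torch.tensor([-0.5])
    max_vals = torch.tensor([1.5])

    # asymmetric int8: zero point is computed from min_vals, clamped, then cast
    # to the dtype reported by pytorch_dtype() (torch.int8 here)
    int8_args = QuantizationArgs(num_bits=8, type=QuantizationType.INT, symmetric=False)
    scale, zp = calculate_qparams(min_vals, max_vals, int8_args)
    assert zp.dtype == torch.int8

    # symmetric fp8: zero point stays zero but is now stored as torch.float8_e4m3fn
    # instead of the old int8 placeholder
    fp8_args = QuantizationArgs(type=QuantizationType.FLOAT, symmetric=True)
    scale, zp = calculate_qparams(min_vals, max_vals, fp8_args)
    assert zp.dtype == torch.float8_e4m3fn
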
From 3a4be1357cccd46241f5a6ce58b38e6d2944809f Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 29 May 2024 21:37:09 +0000 Subject: [PATCH 13/19] dtype fix --- src/compressed_tensors/quantization/lifecycle/apply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 36ff71a17..58d6f3fe9 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -202,6 +202,6 @@ def _load_quant_args_from_state_dict( if zp is not None: zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None) if zp_from_state is not None: # load the non-zero zero points - zp.data = zp_from_state.to(device).to(scale.dtype) + zp.data = zp_from_state.to(device).to(zp.dtype) else: # fill with zeros matching scale shape zp.data = torch.zeros_like(scale, dtype=zp.dtype).to(device) From bc98eee87cae131fd9d0acbae2244eeebcb872a6 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 29 May 2024 21:54:26 +0000 Subject: [PATCH 14/19] docstrings --- src/compressed_tensors/quantization/lifecycle/forward.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 3b9dfed3e..d75afd671 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -86,6 +86,7 @@ def dequantize( :param scale: scale tensor :param zero_point: zero point tensor :param args: quantization args used to quantize x_q + :param dtype: optional dtype to cast the dequantized output to :return: dequantized float tensor """ if args is None: @@ -198,7 +199,7 @@ def _process_quantization( zp, q_min, q_max, - args=args, + args, dtype=dtype, ) if do_dequantize: From 18846ec338185ea7806f68bc6ceac2802aaa0b36 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 14 Jun 2024 18:35:31 +0000 Subject: [PATCH 15/19] fixes after rebase --- .../quantization/lifecycle/initialize.py | 2 +- .../quantization/quant_args.py | 2 +- .../quantization/quant_scheme.py | 25 +++++++++++-------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 4beb7e45e..6f1c457e9 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -122,7 +122,7 @@ def _initialize_scale_zero_point_observer( zp_dtype = quantization_args.pytorch_dtype() init_zero_point = Parameter( - torch.empty(expected_shape, device=device, dtype=int), + torch.empty(expected_shape, device=device, dtype=zp_dtype), requires_grad=False, ) module.register_parameter(f"{base_name}_zero_point", init_zero_point) diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 11c9bf33b..2eb3c1685 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -134,7 +134,7 @@ def validate_strategy(cls, value, values): return value def pytorch_dtype(self) -> torch.dtype: - if self.type is QuantizationType.FLOAT: + if self.type is QuantizationType.FLOAT.value: return FP8_DTYPE else: # QuantizationType.INT if self.num_bits <= 8: diff --git a/src/compressed_tensors/quantization/quant_scheme.py 
b/src/compressed_tensors/quantization/quant_scheme.py index d2786617c..a529751eb 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -15,7 +15,11 @@ from copy import deepcopy from typing import List, Optional -from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationStrategy, QuantizationType +from compressed_tensors.quantization.quant_args import ( + QuantizationArgs, + QuantizationStrategy, + QuantizationType, +) from pydantic import BaseModel @@ -107,16 +111,15 @@ def is_preset_scheme(name: str) -> bool: return name.upper() in PRESET_SCHEMES -W8A8 = dict( - weights=QuantizationArgs(), input_activations=QuantizationArgs() -) +W8A8 = dict(weights=QuantizationArgs(), input_activations=QuantizationArgs()) -W4A16 = dict(weights=QuantizationArgs(num_bits=4, strategy=QuantizationStrategy.CHANNEL)) +W4A16 = dict( + weights=QuantizationArgs(num_bits=4, strategy=QuantizationStrategy.CHANNEL) +) -FP8 = dict(weights=QuantizationArgs(type=QuantizationType.FLOAT), input_activations=QuantizationArgs(type=QuantizationType.FLOAT)) +FP8 = dict( + weights=QuantizationArgs(type=QuantizationType.FLOAT), + input_activations=QuantizationArgs(type=QuantizationType.FLOAT), +) -PRESET_SCHEMES = { - "W8A8": W8A8, - "W4A16": W4A16, - "FP8": FP8 -} +PRESET_SCHEMES = {"W8A8": W8A8, "W4A16": W4A16, "FP8": FP8} From 7b08b9de2e8b1de4725204fdb62925ef33d13ba3 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 14 Jun 2024 19:42:38 +0000 Subject: [PATCH 16/19] test fixes --- .../quantization/observers/helpers.py | 4 +++- src/compressed_tensors/quantization/quant_args.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index 4eebb97f5..4b1537bb1 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -78,7 +78,7 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: bit_range = 2**quantization_args.num_bits q_max = torch.tensor(bit_range / 2 - 1, device=device) q_min = torch.tensor(-bit_range / 2, device=device) - else: # QuantizationType.FLOAT + elif quantization_args.type == QuantizationType.FLOAT: if quantization_args.num_bits != 8: raise ValueError( "Floating point quantization is only supported for 8 bits," @@ -87,5 +87,7 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple: fp_range_info = torch.finfo(FP8_DTYPE) q_max = torch.tensor(fp_range_info.max, device=device) q_min = torch.tensor(fp_range_info.min, device=device) + else: + raise ValueError(f"Invalid quantization type {quantization_args.type}") return q_min, q_max diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py index 2eb3c1685..6ff2b90d3 100644 --- a/src/compressed_tensors/quantization/quant_args.py +++ b/src/compressed_tensors/quantization/quant_args.py @@ -134,15 +134,17 @@ def validate_strategy(cls, value, values): return value def pytorch_dtype(self) -> torch.dtype: - if self.type is QuantizationType.FLOAT.value: + if self.type == QuantizationType.FLOAT: return FP8_DTYPE - else: # QuantizationType.INT + elif self.type == QuantizationType.INT: if self.num_bits <= 8: return torch.int8 elif self.num_bits <= 16: return torch.int16 else: return torch.int32 + else: + raise ValueError(f"Invalid quantization type 
{self.type}") def round_to_quantized_type( @@ -157,9 +159,11 @@ def round_to_quantized_type( :return: rounded tensor """ original_dtype = tensor.dtype - if args.type is QuantizationType.FLOAT: + if args.type == QuantizationType.FLOAT: rounded = tensor.to(FP8_DTYPE) - else: # QuantizationType.INT + elif args.type == QuantizationType.INT: rounded = torch.round(tensor) + else: + raise ValueError(f"Invalid quantization type {args.type}") return rounded.to(original_dtype) From 755fee5cbf571a285d5dbddbad4f3df92f0d605c Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 19 Jun 2024 17:47:53 +0000 Subject: [PATCH 17/19] style --- src/compressed_tensors/quantization/lifecycle/forward.py | 2 +- src/compressed_tensors/quantization/quant_scheme.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index e9da6a7a4..155616dd9 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -104,7 +104,7 @@ def dequantize( scale = scale.to(x_q.device) if x_q.device != zero_point.device: zero_point = zero_point.to(x_q.device) - + if args is None: if scale.ndim == 0 or scale.ndim == 1: args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 386d5968e..a78148697 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -115,7 +115,11 @@ def is_preset_scheme(name: str) -> bool: weights=QuantizationArgs(), input_activations=QuantizationArgs(symmetric=True) ) -W4A16 = dict(weights=QuantizationArgs(num_bits=4, symmetric=True)) +W4A16 = dict( + weights=QuantizationArgs( + num_bits=4, strategy=QuantizationStrategy.CHANNEL, symmetric=True + ) +) FP8 = dict( weights=QuantizationArgs(type=QuantizationType.FLOAT), From 6836c80a4f51ddbbad13a55f721288d33755f0cc Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 19 Jun 2024 17:57:02 +0000 Subject: [PATCH 18/19] get rid of broken segment --- src/compressed_tensors/quantization/quant_scheme.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index a78148697..ff2f7eb1d 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -111,14 +111,10 @@ def is_preset_scheme(name: str) -> bool: return name.upper() in PRESET_SCHEMES -W8A8 = dict( - weights=QuantizationArgs(), input_activations=QuantizationArgs(symmetric=True) -) +W8A8 = dict(weights=QuantizationArgs(), input_activations=QuantizationArgs()) W4A16 = dict( - weights=QuantizationArgs( - num_bits=4, strategy=QuantizationStrategy.CHANNEL, symmetric=True - ) + weights=QuantizationArgs(num_bits=4, strategy=QuantizationStrategy.CHANNEL) ) FP8 = dict( @@ -127,3 +123,5 @@ def is_preset_scheme(name: str) -> bool: ) PRESET_SCHEMES = {"W8A8": W8A8, "W4A16": W4A16, "FP8": FP8} + +PRESET_SCHEMES = {"W8A8": W8A8, "W4A16": W4A16, "FP8": FP8} From b6d2470e0f4f1c99f4389cdc843dd4aa5d039816 Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 19 Jun 2024 17:58:53 +0000 Subject: [PATCH 19/19] fix broken code --- src/compressed_tensors/quantization/lifecycle/forward.py | 8 -------- 1 file changed, 8 deletions(-) diff --git 
a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py index 155616dd9..263796eb2 100644 --- a/src/compressed_tensors/quantization/lifecycle/forward.py +++ b/src/compressed_tensors/quantization/lifecycle/forward.py @@ -97,14 +97,6 @@ def dequantize( :param dtype: optional dtype to cast the dequantized output to :return: dequantized float tensor """ - # ensure all tensors are on the same device - # assumes that the target device is the input - # tensor's device - if x_q.device != scale.device: - scale = scale.to(x_q.device) - if x_q.device != zero_point.device: - zero_point = zero_point.to(x_q.device) - if args is None: if scale.ndim == 0 or scale.ndim == 1: args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
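
With the helpers in their final state, an fp8 weight round trip looks roughly like the sketch below. It assumes the quantize/dequantize keyword arguments and the pytorch_dtype() helper exactly as they appear in the diffs above, plus the calculate_qparams(min_vals, max_vals, quantization_args) signature and a torch build with float8_e4m3fn; treat it as an illustration of the intended flow rather than a tested snippet.

    import torch
    from compressed_tensors.quantization import QuantizationArgs, QuantizationType
    from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
    from compressed_tensors.quantization.observers.helpers import calculate_qparams

    # defaults follow the validators above: 8 bits, symmetric, per-tensor strategy
    args = QuantizationArgs(type=QuantizationType.FLOAT)
    weight = torch.randn(64, 64)

    # per-tensor qparams; the zero point comes back in args.pytorch_dtype()
    scale, zero_point = calculate_qparams(weight.min(), weight.max(), args)

    # store the weight in fp8, then reconstruct; zero_point is now optional on the
    # dequantize side since symmetric fp8 uses an all-zero zero point
    w_fp8 = quantize(
        x=weight, scale=scale, zero_point=zero_point, args=args, dtype=args.pytorch_dtype()
    )
    w_rec = dequantize(x_q=w_fp8, scale=scale, args=args)

    assert w_fp8.dtype == torch.float8_e4m3fn
    assert w_rec.dtype == scale.dtype  # dequantized output falls back to the scale dtype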