diff --git a/src/compressed_tensors/compressors/__init__.py b/src/compressed_tensors/compressors/__init__.py
index d3bb61f59..6cffc6d7e 100644
--- a/src/compressed_tensors/compressors/__init__.py
+++ b/src/compressed_tensors/compressors/__init__.py
@@ -17,8 +17,12 @@
 from .base import Compressor
 from .dense import DenseCompressor
 from .helpers import load_compressed, save_compressed, save_compressed_model
-from .int_quantized import IntQuantizationCompressor
 from .marlin_24 import Marlin24Compressor
 from .model_compressor import ModelCompressor, map_modules_to_quant_args
+from .naive_quantized import (
+    FloatQuantizationCompressor,
+    IntQuantizationCompressor,
+    QuantizationCompressor,
+)
 from .pack_quantized import PackedQuantizationCompressor
 from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
diff --git a/src/compressed_tensors/compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressor.py
index 269fe870e..e57dbff6d 100644
--- a/src/compressed_tensors/compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressor.py
@@ -16,9 +16,12 @@
 import logging
 import operator
 import os
+import re
 from copy import deepcopy
 from typing import Any, Dict, Optional, Union
 
+import torch
+import transformers
 from compressed_tensors.base import (
     COMPRESSION_CONFIG_NAME,
     QUANTIZATION_CONFIG_NAME,
@@ -236,6 +239,11 @@ def compress(
             compressed_state_dict
         )
 
+        # HACK: Override the dtype_byte_size function in transformers to
+        # support float8 types. Fix is posted upstream
+        # https://github.com/huggingface/transformers/pull/30488
+        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
+
         return compressed_state_dict
 
     def decompress(self, model_path: str, model: Module):
@@ -313,3 +321,15 @@ def map_modules_to_quant_args(model: Module) -> Dict:
             quantized_modules_to_args[name] = submodule.quantization_scheme.weights
 
     return quantized_modules_to_args
+
+
+# HACK: Override the dtype_byte_size function in transformers to support float8 types
+# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
+def new_dtype_byte_size(dtype):
+    if dtype == torch.bool:
+        return 1 / 8
+    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
diff --git a/src/compressed_tensors/compressors/int_quantized.py b/src/compressed_tensors/compressors/naive_quantized.py
similarity index 85%
rename from src/compressed_tensors/compressors/int_quantized.py
rename to src/compressed_tensors/compressors/naive_quantized.py
index 69bd4e63c..e3a9a42fd 100644
--- a/src/compressed_tensors/compressors/int_quantized.py
+++ b/src/compressed_tensors/compressors/naive_quantized.py
@@ -27,17 +27,21 @@
 from tqdm import tqdm
 
 
-__all__ = ["IntQuantizationCompressor"]
+__all__ = [
+    "QuantizationCompressor",
+    "IntQuantizationCompressor",
+    "FloatQuantizationCompressor",
+]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
-@Compressor.register(name=CompressionFormat.int_quantized.value)
-class IntQuantizationCompressor(Compressor):
+@Compressor.register(name=CompressionFormat.naive_quantized.value)
+class QuantizationCompressor(Compressor):
     """
-    Integer compression for quantized models. Weight of each quantized layer is
-    converted from its original float type to the format specified by the layer's
-    quantization scheme.
+    Implements naive compression for quantized models. Weight of each
+    quantized layer is converted from its original float type to the closest Pytorch
+    type to the type specified by the layer's QuantizationArgs.
     """
 
     COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]
 
@@ -77,7 +81,7 @@ def compress(
                     scale=scale,
                     zero_point=zp,
                     args=quant_args,
-                    dtype=torch.int8,
+                    dtype=quant_args.pytorch_dtype(),
                 )
             elif name.endswith("zero_point"):
                 if torch.all(value == 0):
@@ -114,13 +118,27 @@ def decompress(
             if "weight_scale" in weight_data:
                 zero_point = weight_data.get("weight_zero_point", None)
                 scale = weight_data["weight_scale"]
-                if zero_point is None:
-                    # zero_point assumed to be 0 if not included in state_dict
-                    zero_point = torch.zeros_like(scale)
-
                 decompressed = dequantize(
                     x_q=weight_data["weight"],
                     scale=scale,
                     zero_point=zero_point,
                 )
                 yield merge_names(weight_name, "weight"), decompressed
+
+
+@Compressor.register(name=CompressionFormat.int_quantized.value)
+class IntQuantizationCompressor(QuantizationCompressor):
+    """
+    Alias for integer quantized models
+    """
+
+    pass
+
+
+@Compressor.register(name=CompressionFormat.float_quantized.value)
+class FloatQuantizationCompressor(QuantizationCompressor):
+    """
+    Alias for fp quantized models
+    """
+
+    pass
diff --git a/src/compressed_tensors/compressors/pack_quantized.py b/src/compressed_tensors/compressors/pack_quantized.py
index 9d997204b..74b78132d 100644
--- a/src/compressed_tensors/compressors/pack_quantized.py
+++ b/src/compressed_tensors/compressors/pack_quantized.py
@@ -126,10 +126,6 @@ def decompress(
             if "weight_scale" in weight_data:
                 zero_point = weight_data.get("weight_zero_point", None)
                 scale = weight_data["weight_scale"]
-                if zero_point is None:
-                    # zero_point assumed to be 0 if not included in state_dict
-                    zero_point = torch.zeros_like(scale)
-
                 weight = weight_data["weight_packed"]
                 original_shape = torch.Size(weight_data["weight_shape"])
                 unpacked = unpack_4bit_ints(weight, original_shape)
diff --git a/src/compressed_tensors/config/base.py b/src/compressed_tensors/config/base.py
index b9ecab8ef..6f7c2f53d 100644
--- a/src/compressed_tensors/config/base.py
+++ b/src/compressed_tensors/config/base.py
@@ -26,6 +26,8 @@ class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
     int_quantized = "int-quantized"
+    float_quantized = "float-quantized"
+    naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
 
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 8806861c3..f31bdc78b 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -215,15 +215,11 @@ def _load_quant_args_from_state_dict(
     scale = getattr(module, scale_name, None)
     zp = getattr(module, zp_name, None)
     if scale is not None:
-        state_dict_scale = state_dict.get(f"{module_name}.{scale_name}")
-        if state_dict_scale is not None:
-            scale.data = state_dict_scale.to(device).to(scale.dtype)
-        else:
-            scale.data = scale.data.to(device)
-
+        state_dict_scale = state_dict[f"{module_name}.{scale_name}"]
+        scale.data = state_dict_scale.to(device).to(scale.dtype)
     if zp is not None:
         zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
         if zp_from_state is not None:  # load the non-zero zero points
-            zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
+            zp.data = zp_from_state.to(device).to(zp.dtype)
         else:  # fill with zeros matching scale shape
-            zp.data = torch.zeros_like(scale, dtype=torch.int8).to(device)
+            zp.data = torch.zeros_like(scale, dtype=zp.dtype).to(device)
diff --git a/src/compressed_tensors/quantization/lifecycle/forward.py b/src/compressed_tensors/quantization/lifecycle/forward.py
index b0b952c03..263796eb2 100644
--- a/src/compressed_tensors/quantization/lifecycle/forward.py
+++ b/src/compressed_tensors/quantization/lifecycle/forward.py
@@ -17,9 +17,11 @@
 from typing import Optional
 
 import torch
+from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
+    round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
@@ -80,8 +82,9 @@ def quantize(
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
     args: QuantizationArgs = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -91,16 +94,9 @@ def dequantize(
     :param scale: scale tensor
     :param zero_point: zero point tensor
     :param args: quantization args used to quantize x_q
+    :param dtype: optional dtype to cast the dequantized output to
     :return: dequantized float tensor
     """
-    # ensure all tensors are on the same device
-    # assumes that the target device is the input
-    # tensor's device
-    if x_q.device != scale.device:
-        scale = scale.to(x_q.device)
-    if x_q.device != zero_point.device:
-        zero_point = zero_point.to(x_q.device)
-
     if args is None:
         if scale.ndim == 0 or scale.ndim == 1:
             args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
@@ -115,8 +111,12 @@ def dequantize(
         else:
             raise ValueError(
                 f"Could not infer a quantization strategy from scale with {scale.ndim} "
-                "dimmensions. Expected 0-2 dimmensions."
+                "dimmensions. Expected 0 or 2 dimmensions."
             )
+
+    if dtype is None:
+        dtype = scale.dtype
+
     return _process_quantization(
         x=x_q,
         scale=scale,
@@ -124,6 +124,7 @@ def dequantize(
         args=args,
         do_quantize=False,
         do_dequantize=True,
+        dtype=dtype,
     )
 
@@ -167,19 +168,13 @@ def _process_quantization(
     do_quantize: bool = True,
     do_dequantize: bool = True,
 ) -> torch.Tensor:
-    bit_range = 2**args.num_bits
-    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
-    q_min = torch.tensor(-bit_range / 2, device=x.device)
+
+    q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
     if args.strategy == QuantizationStrategy.GROUP:
-
-        if do_dequantize and not do_quantize:
-            # if dequantizing a quantized type infer the output type from the scale
-            output = torch.zeros_like(x, dtype=scale.dtype)
-        else:
-            output_dtype = dtype if dtype is not None else x.dtype
-            output = torch.zeros_like(x, dtype=output_dtype)
+        output_dtype = dtype if dtype is not None else x.dtype
+        output = torch.zeros_like(x).to(output_dtype)
 
         # TODO: vectorize the for loop
         # TODO: fix genetric assumption about the tensor size for computing group
@@ -189,7 +184,7 @@ def _process_quantization(
         while scale.ndim < 2:
             # pad scale and zero point dims for slicing
             scale = scale.unsqueeze(1)
-            zero_point = zero_point.unsqueeze(1)
+            zero_point = zero_point.unsqueeze(1) if zero_point is not None else None
 
         columns = x.shape[1]
         if columns >= group_size:
@@ -202,12 +197,18 @@ def _process_quantization(
                 # scale.shape should be [nchan, ndim]
                 # sc.shape should be [nchan, 1] after unsqueeze
                 sc = scale[:, i].view(-1, 1)
-                zp = zero_point[:, i].view(-1, 1)
+                zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None
                 idx = i * group_size
                 if do_quantize:
                     output[:, idx : (idx + group_size)] = _quantize(
-                        x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
+                        x[:, idx : (idx + group_size)],
+                        sc,
+                        zp,
+                        q_min,
+                        q_max,
+                        args,
+                        dtype=dtype,
                     )
                 if do_dequantize:
                     input = (
@@ -219,7 +220,15 @@ def _process_quantization(
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
-            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+            output = _quantize(
+                x,
+                scale,
+                zero_point,
+                q_min,
+                q_max,
+                args,
+                dtype=dtype,
+            )
         if do_dequantize:
             output = _dequantize(output if do_quantize else x, scale, zero_point)
 
@@ -313,14 +322,18 @@ def _quantize(
     zero_point: torch.Tensor,
     q_min: torch.Tensor,
     q_max: torch.Tensor,
+    args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-    quantized_value = torch.clamp(
-        torch.round(x / scale + zero_point),
+
+    scaled = x / scale + zero_point.to(x.dtype)
+    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
+    clamped_value = torch.clamp(
+        scaled,
         q_min,
         q_max,
     )
-
+    quantized_value = round_to_quantized_type(clamped_value, args)
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)
 
@@ -331,6 +344,16 @@ def _quantize(
 def _dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-    return (x_q - zero_point) * scale
+
+    dequant_value = x_q
+    if zero_point is not None:
+        dequant_value = dequant_value - zero_point.to(scale.dtype)
+    dequant_value = dequant_value.to(scale.dtype) * scale
+
+    if dtype is not None:
+        dequant_value = dequant_value.to(dtype)
+
+    return dequant_value
diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py
index 3b68f05a2..6f1c457e9 100644
--- a/src/compressed_tensors/quantization/lifecycle/initialize.py
+++ b/src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -120,8 +120,9 @@ def _initialize_scale_zero_point_observer(
     )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
+    zp_dtype = quantization_args.pytorch_dtype()
     init_zero_point = Parameter(
-        torch.empty(expected_shape, device=device, dtype=int),
+        torch.empty(expected_shape, device=device, dtype=zp_dtype),
         requires_grad=False,
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py
index 1454d7693..4b1537bb1 100644
--- a/src/compressed_tensors/quantization/observers/helpers.py
+++ b/src/compressed_tensors/quantization/observers/helpers.py
@@ -15,11 +15,15 @@
 from typing import Tuple
 
 import torch
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    FP8_DTYPE,
+    QuantizationArgs,
+    QuantizationType,
+)
 from torch import FloatTensor, IntTensor, Tensor
 
 
-__all__ = ["calculate_qparams"]
+__all__ = ["calculate_qparams", "calculate_range"]
 
 
 def calculate_qparams(
@@ -37,22 +41,53 @@ def calculate_qparams(
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     device = min_vals.device
 
-    bit_range = 2**quantization_args.num_bits - 1
-    bit_min = -(bit_range + 1) / 2
-    bit_max = bit_min + bit_range
+    bit_min, bit_max = calculate_range(quantization_args, device)
+    bit_range = bit_max - bit_min
+    zp_dtype = quantization_args.pytorch_dtype()
+
     if quantization_args.symmetric:
-        max_val_pos = torch.max(-min_vals, max_vals)
+        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = bit_min - torch.round(min_vals / scales)
-        zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
+        zero_points = bit_min - (min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+
+    # match zero-points to quantized type
+    zero_points = zero_points.to(zp_dtype)
 
     if scales.ndim == 0:
         scales = scales.reshape(1)
         zero_points = zero_points.reshape(1)
 
     return scales, zero_points
+
+
+def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
+    """
+    Calculated the effective quantization range for the given Quantization Args
+
+    :param quantization_args: quantization args to get range of
+    :param device: device to store the range to
+    :return: tuple endpoints for the given quantization range
+    """
+    if quantization_args.type == QuantizationType.INT:
+        bit_range = 2**quantization_args.num_bits
+        q_max = torch.tensor(bit_range / 2 - 1, device=device)
+        q_min = torch.tensor(-bit_range / 2, device=device)
+    elif quantization_args.type == QuantizationType.FLOAT:
+        if quantization_args.num_bits != 8:
+            raise ValueError(
+                "Floating point quantization is only supported for 8 bits,"
+                f"got {quantization_args.num_bits}"
+            )
+        fp_range_info = torch.finfo(FP8_DTYPE)
+        q_max = torch.tensor(fp_range_info.max, device=device)
+        q_min = torch.tensor(fp_range_info.min, device=device)
+    else:
+        raise ValueError(f"Invalid quantization type {quantization_args.type}")
+
+    return q_min, q_max
diff --git a/src/compressed_tensors/quantization/quant_args.py b/src/compressed_tensors/quantization/quant_args.py
index 0c84fdda0..6ff2b90d3 100644
--- a/src/compressed_tensors/quantization/quant_args.py
+++ b/src/compressed_tensors/quantization/quant_args.py
@@ -15,10 +15,19 @@
 from enum import Enum
 from typing import Any, Dict, Optional
 
+import torch
 from pydantic import BaseModel, Field, validator
 
 
-__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
+__all__ = [
+    "FP8_DTYPE",
+    "QuantizationType",
+    "QuantizationStrategy",
+    "QuantizationArgs",
+    "round_to_quantized_type",
+]
+
+FP8_DTYPE = torch.float8_e4m3fn
 
 
 class QuantizationType(str, Enum):
@@ -123,3 +132,38 @@ def validate_strategy(cls, value, values):
             return QuantizationStrategy.TENSOR
 
         return value
+
+    def pytorch_dtype(self) -> torch.dtype:
+        if self.type == QuantizationType.FLOAT:
+            return FP8_DTYPE
+        elif self.type == QuantizationType.INT:
+            if self.num_bits <= 8:
+                return torch.int8
+            elif self.num_bits <= 16:
+                return torch.int16
+            else:
+                return torch.int32
+        else:
+            raise ValueError(f"Invalid quantization type {self.type}")
+
+
+def round_to_quantized_type(
+    tensor: torch.Tensor, args: QuantizationArgs
+) -> torch.Tensor:
+    """
+    Rounds each element of the input tensor to the nearest quantized representation,
+    keeping to original dtype
+
+    :param tensor: tensor to round
+    :param args: QuantizationArgs to pull appropriate dtype from
+    :return: rounded tensor
+    """
+    original_dtype = tensor.dtype
+    if args.type == QuantizationType.FLOAT:
+        rounded = tensor.to(FP8_DTYPE)
+    elif args.type == QuantizationType.INT:
+        rounded = torch.round(tensor)
+    else:
+        raise ValueError(f"Invalid quantization type {args.type}")
+
+    return rounded.to(original_dtype)
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index 2b228acb8..ff2f7eb1d 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -15,7 +15,11 @@
 from copy import deepcopy
 from typing import List, Optional
 
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from pydantic import BaseModel
 
 
@@ -107,13 +111,17 @@ def is_preset_scheme(name: str) -> bool:
     return name.upper() in PRESET_SCHEMES
 
 
-W8A8 = dict(
-    weights=QuantizationArgs(), input_activations=QuantizationArgs(symmetric=True)
+W8A8 = dict(weights=QuantizationArgs(), input_activations=QuantizationArgs())
+
+W4A16 = dict(
+    weights=QuantizationArgs(num_bits=4, strategy=QuantizationStrategy.CHANNEL)
+)
+
+FP8 = dict(
+    weights=QuantizationArgs(type=QuantizationType.FLOAT),
+    input_activations=QuantizationArgs(type=QuantizationType.FLOAT),
 )
 
-W4A16 = dict(weights=QuantizationArgs(num_bits=4, symmetric=True))
+PRESET_SCHEMES = {"W8A8": W8A8, "W4A16": W4A16, "FP8": FP8}
 
-PRESET_SCHEMES = {
-    "W8A8": W8A8,
-    "W4A16": W4A16,
-}
+PRESET_SCHEMES = {"W8A8": W8A8, "W4A16": W4A16, "FP8": FP8}
diff --git a/tests/test_compressors/test_fp8_quant.py b/tests/test_compressors/test_fp8_quant.py
new file mode 100644
index 000000000..eb21b575c
--- /dev/null
+++ b/tests/test_compressors/test_fp8_quant.py
@@ -0,0 +1,137 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+from collections import OrderedDict
+
+import pytest
+import torch
+from compressed_tensors import FloatQuantizationCompressor
+from compressed_tensors.quantization import (
+    QuantizationArgs,
+    QuantizationConfig,
+    QuantizationScheme,
+    QuantizationStatus,
+    QuantizationStrategy,
+    apply_quantization_config,
+    apply_quantization_status,
+)
+from compressed_tensors.quantization.lifecycle.forward import fake_quantize
+from safetensors.torch import save_file
+from torch.nn.modules import Linear, Sequential
+
+
+def get_dummy_quant_config(strategy, group_size=None):
+    config_groups = {
+        "group_1": QuantizationScheme(
+            targets=["Linear"],
+            weights=QuantizationArgs(
+                strategy=strategy, type="float", group_size=group_size
+            ),
+        ),
+    }
+    ignore = ["lm_head"]
+    quant_config = QuantizationConfig(
+        config_groups=config_groups,
+        ignore=ignore,
+    )
+
+    return quant_config
+
+
+@pytest.mark.parametrize(
+    "strategy,group_size,sc,zp",
+    [
+        [QuantizationStrategy.TENSOR, None, 0.01, 0],
+        [
+            QuantizationStrategy.GROUP,
+            128,
+            torch.rand((512, 8, 1)) * 0.01,
+            torch.zeros((512, 8, 1), dtype=torch.int8),
+        ],
+        [
+            QuantizationStrategy.CHANNEL,
+            128,
+            torch.rand((512, 1)) * 0.01,
+            torch.zeros((512, 1), dtype=torch.int8),
+        ],
+    ],
+)
+def test_quant_format(strategy, group_size, sc, zp):
+    dense_state_dict = {
+        "dummy.weight": torch.rand((512, 1024)),
+        "dummy.weight_scale": torch.tensor(sc, dtype=torch.float32),
+        "dummy.weight_zero_point": torch.tensor(zp, dtype=torch.float32),
+    }
+    quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)
+
+    compressor = FloatQuantizationCompressor(config=quant_config)
+    quantized_modules_to_args = {"dummy": quant_config.config_groups["group_1"].weights}
+    compressed_state_dict = compressor.compress(
+        dense_state_dict, model_quant_args=quantized_modules_to_args
+    )
+
+    # state_dict params should be the same, minus the zero_point if symmetric
+    assert len(dense_state_dict) == len(compressed_state_dict) + 1
+
+    # check compressed to int8
+    assert compressed_state_dict["dummy.weight"].dtype == torch.float8_e4m3fn
+    assert compressed_state_dict["dummy.weight_scale"].dtype == torch.float32
+
+
+@pytest.mark.parametrize(
+    "strategy,group_size",
+    [
+        [QuantizationStrategy.TENSOR, None],
+        [QuantizationStrategy.CHANNEL, None],
+    ],
+)
+def test_reload_match(strategy, group_size, tmp_path):
+    model = Sequential(
+        OrderedDict(
+            [
+                ("dummy", Linear(512, 1024, bias=None)),
+            ]
+        )
+    )
+    quant_config = get_dummy_quant_config(strategy=strategy, group_size=group_size)
+    apply_quantization_config(model, quant_config)
+    apply_quantization_status(model, QuantizationStatus.CALIBRATION)
+
+    for _ in range(16):
+        inputs = torch.rand((512, 512))
+        _ = model(inputs)
+
+    compressor = FloatQuantizationCompressor(config=quant_config)
+    quantized_modules_to_args = {
+        "dummy": quant_config.config_groups["group_1"].weights,
+    }
+    compressed_state_dict = compressor.compress(
+        model.state_dict(), model_quant_args=quantized_modules_to_args
+    )
+    save_file(compressed_state_dict, tmp_path / "model.safetensors")
+    reconstructed_dense_gen = compressor.decompress(tmp_path)
+    reconstructed_dense = {}
+    for name, value in reconstructed_dense_gen:
+        reconstructed_dense[name] = value
+
+    fake_quant_dummy = fake_quantize(
+        model.dummy.weight,
+        scale=model.dummy.weight_scale,
+        zero_point=model.dummy.weight_zero_point,
+        args=quantized_modules_to_args["dummy"],
+    )
+    assert torch.equal(fake_quant_dummy, reconstructed_dense["dummy.weight"])
+
+    shutil.rmtree(tmp_path)
diff --git a/tests/test_compressors/test_pack_quant.py b/tests/test_compressors/test_pack_quant.py
index b111d596b..46bf091e5 100644
--- a/tests/test_compressors/test_pack_quant.py
+++ b/tests/test_compressors/test_pack_quant.py
@@ -110,10 +110,10 @@ def test_reload_match(tmp_path):
     dense_state_dict = {
         "dummy.weight": torch.rand((511, 350)),
         "dummy.weight_scale": torch.tensor(0.01, dtype=torch.float32),
-        "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int32),
+        "dummy.weight_zero_point": torch.tensor(0, dtype=torch.int8),
         "dummy2.weight": torch.rand((128, 280)),
         "dummy2.weight_scale": torch.tensor(0.02, dtype=torch.float32),
-        "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int32),
+        "dummy2.weight_zero_point": torch.tensor(15, dtype=torch.int8),
     }
 
     quant_config = get_dummy_quant_config()
diff --git a/tests/test_quantization/test_configs/test_bit_depths.py b/tests/test_quantization/test_configs/test_bit_depths.py
index 225509047..cf0d16fa6 100644
--- a/tests/test_quantization/test_configs/test_bit_depths.py
+++ b/tests/test_quantization/test_configs/test_bit_depths.py
@@ -25,10 +25,15 @@
 from torch.nn import Linear
 
 
-def create_config(bit_depth, input_symmetry, weight_symmetry):
-    weights = QuantizationArgs(num_bits=bit_depth, symmetric=weight_symmetry)
+def create_config(bit_depth, quant_type, input_symmetry, weight_symmetry):
+    print(quant_type)
+    weights = QuantizationArgs(
+        num_bits=bit_depth, type=quant_type, symmetric=weight_symmetry
+    )
     if input_symmetry is not None:
-        inputs = QuantizationArgs(num_bits=bit_depth, symmetric=input_symmetry)
+        inputs = QuantizationArgs(
+            num_bits=bit_depth, type=quant_type, symmetric=input_symmetry
+        )
     else:
         inputs = None
 
@@ -45,11 +50,12 @@ def create_config(bit_depth, input_symmetry, weight_symmetry):
 
 @torch.no_grad
 @pytest.mark.parametrize("bit_depth", [4, 8])
+@pytest.mark.parametrize("quant_type", ["int"])
 @pytest.mark.parametrize("input_symmetry", [True, False, None])
 @pytest.mark.parametrize("weight_symmetry", [True, False])
-def test_bit_depths(bit_depth, input_symmetry, weight_symmetry):
+def test_bit_depths(bit_depth, quant_type, input_symmetry, weight_symmetry):
     model = Linear(64, 64)
-    quant_config = create_config(bit_depth, input_symmetry, weight_symmetry)
+    quant_config = create_config(bit_depth, quant_type, input_symmetry, weight_symmetry)
     apply_quantization_config(model, quant_config)
 
     min = -1 * int(2**bit_depth / 2)
@@ -92,3 +98,60 @@ def test_bit_depths(bit_depth, input_symmetry, weight_symmetry):
     )
     assert not torch.any(quantized_weight < min).item()
     assert not torch.any(quantized_weight > max).item()
+
+
+@torch.no_grad
+@pytest.mark.parametrize("bit_depth", [8])
+@pytest.mark.parametrize("quant_type", ["float"])
+@pytest.mark.parametrize("input_symmetry", [True, False, None])
+@pytest.mark.parametrize("weight_symmetry", [True, False])
+def test_fp8(bit_depth, quant_type, input_symmetry, weight_symmetry):
+    model = Linear(64, 64)
+    quant_config = create_config(bit_depth, quant_type, input_symmetry, weight_symmetry)
+    apply_quantization_config(model, quant_config)
+
+    dtype_info = torch.finfo(torch.float8_e4m3fn)
+    min = dtype_info.min
+    max = dtype_info.max
+
+    inputs = torch.randn(32, 64)
+    model(inputs)
+    assert model.weight_zero_point.dtype == torch.float8_e4m3fn
+    model.weight_zero_point.data = model.weight_zero_point.to(model.weight.dtype)
+    if input_symmetry is not None:
+        assert model.input_zero_point.dtype == torch.float8_e4m3fn
+        model.input_zero_point.data = model.input_zero_point.to(model.weight.dtype)
+        assert model.input_zero_point >= min
+        assert model.input_zero_point <= max
+
+        input_max = torch.max(inputs)
+        input_min = torch.min(inputs)
+        diff_from_max = abs(
+            abs(model.input_scale * (max - model.input_zero_point)) - abs(input_max)
+        )
+        diff_from_min = abs(
+            abs(model.input_scale * abs(min - model.input_zero_point)) - abs(input_min)
+        )
+        assert diff_from_max < model.input_scale or diff_from_min < model.input_scale
+
+    assert model.weight_zero_point >= min
+    assert model.weight_zero_point <= max
+
+    weight_max = torch.max(model.weight)
+    weight_min = torch.min(model.weight)
+    diff_from_max = abs(
+        abs(model.weight_scale * (max - model.weight_zero_point)) - abs(weight_max)
+    )
+    diff_from_min = abs(
+        abs(model.weight_scale * abs(min - model.weight_zero_point)) - abs(weight_min)
+    )
+    assert diff_from_max < model.weight_scale or diff_from_min < model.weight_scale
+
+    quantized_weight = fake_quantize(
+        model.weight,
+        model.weight_scale,
+        model.weight_zero_point,
+        model.quantization_scheme.weights,
+    )
+    assert not torch.any(quantized_weight < min).item()
+    assert not torch.any(quantized_weight > max).item()
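
Usage sketch for the new dtype helpers added to quant_args.py above; not part of the patch itself, and it assumes a torch build that ships float8_e4m3fn (torch >= 2.1):

import torch

from compressed_tensors.quantization.quant_args import (
    FP8_DTYPE,
    QuantizationArgs,
    round_to_quantized_type,
)

# float args resolve to torch.float8_e4m3fn, int args to the smallest fitting int type
fp8_args = QuantizationArgs(type="float", num_bits=8)
assert fp8_args.pytorch_dtype() is FP8_DTYPE

# rounding is a round-trip cast through fp8: values land on representable fp8 points
# while the tensor keeps its original (here float32) dtype
x = torch.tensor([0.1234, -3.5678])
print(round_to_quantized_type(x, fp8_args))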
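
The new calculate_range helper is what both the observers and the forward pass now use for clamping; a small sketch of the values it produces (the fp8 endpoints come from torch.finfo(torch.float8_e4m3fn)):

import torch

from compressed_tensors.quantization.observers.helpers import calculate_range
from compressed_tensors.quantization.quant_args import QuantizationArgs

int_min, int_max = calculate_range(QuantizationArgs(num_bits=8, type="int"), "cpu")
fp8_min, fp8_max = calculate_range(QuantizationArgs(num_bits=8, type="float"), "cpu")

print(int_min.item(), int_max.item())  # -128.0 127.0
print(fp8_min.item(), fp8_max.item())  # -448.0 448.0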
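
Rough end-to-end sketch of an fp8 per-tensor round trip through the updated quantize/dequantize path; the scale and zero_point construction here is illustrative and not taken from the patch, only the keyword arguments mirror the compressor call above:

import torch

from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from compressed_tensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(type="float", num_bits=8)  # tensor strategy by default
x = torch.randn(16, 32)
scale = (x.abs().max() / 448.0).reshape(1)  # 448 = torch.finfo(float8_e4m3fn).max
zero_point = torch.zeros(1, dtype=args.pytorch_dtype())

x_fp8 = quantize(
    x, scale=scale, zero_point=zero_point, args=args, dtype=args.pytorch_dtype()
)
x_back = dequantize(x_fp8, scale=scale)  # zero_point is now optional on dequantize

print(x_fp8.dtype, x_back.dtype)  # torch.float8_e4m3fn torch.float32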