From 877d333196c09ea0a40bb2099df042e342a45859 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Fri, 7 Mar 2025 01:52:40 +0000 Subject: [PATCH 1/5] Some fixes for AWQ --- src/compressed_tensors/quantization/quant_scheme.py | 13 +++++++++++++ .../quantization/utils/helpers.py | 9 +++++---- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 36b886044..460858a17 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -199,6 +199,18 @@ def is_preset_scheme(name: str) -> bool: ), ) +# AWQ quantization +AWQ = dict( + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.INT, + strategy=QuantizationStrategy.GROUP, + symmetric=False, + dynamic=False, + group_size=128, + ), +) + PRESET_SCHEMES = { # Unquantized (no-op) "UNQUANTIZED": UNQUANTIZED, @@ -212,4 +224,5 @@ def is_preset_scheme(name: str) -> bool: # Float weight and activation schemes "FP8": FP8, "FP8_DYNAMIC": FP8_DYNAMIC, + "AWQ": AWQ, } diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 9f65ee330..ad7237715 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -64,8 +64,7 @@ def calculate_qparams( :param quantization_args: settings to quantization :return: tuple of the calculated scale(s) and zero point(s) """ - min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) - max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) + device = min_vals.device bit_min, bit_max = calculate_range(quantization_args, device) @@ -73,15 +72,17 @@ def calculate_qparams( zp_dtype = quantization_args.pytorch_dtype() if quantization_args.symmetric: + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) scales = max_val_pos / (float(bit_range) / 2) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: scales = (max_vals - min_vals) / float(bit_range) - scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) + scales = torch.clamp(scales, min=1e-5) zero_points = bit_min - (min_vals / scales) - zero_points = torch.clamp(zero_points, bit_min, bit_max) + zero_points = torch.clamp(torch.round(zero_points), bit_min, bit_max) # match zero-points to quantized type zero_points = zero_points.to(zp_dtype) From a129ae8bbd6662086b6d810de4e408d0d78b5ff6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 11 Mar 2025 20:18:51 +0000 Subject: [PATCH 2/5] revert clamp to 1e-5 Signed-off-by: Brian Dellabetta --- src/compressed_tensors/quantization/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index ad7237715..a52f0340a 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -80,7 +80,7 @@ def calculate_qparams( zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype) else: scales = (max_vals - min_vals) / float(bit_range) - scales = torch.clamp(scales, min=1e-5) + scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = bit_min - 
(min_vals / scales) zero_points = torch.clamp(torch.round(zero_points), bit_min, bit_max) From 571265811edb9fa5ba7e0eceeb23f417abac5b3d Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 11 Mar 2025 20:24:16 +0000 Subject: [PATCH 3/5] rename awq quant preset to W4A16_ASYM Signed-off-by: Brian Dellabetta --- .../quantization/quant_scheme.py | 26 +++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index 460858a17..9fcc0d55d 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -142,6 +142,18 @@ def is_preset_scheme(name: str) -> bool: ), ) +# 4 bit integer weights only asymmetric quantization +W4A16_ASYM = dict( + weights=QuantizationArgs( + num_bits=4, + type=QuantizationType.INT, + strategy=QuantizationStrategy.GROUP, + group_size=128, + symmetric=False, + dynamic=False, + ), +) + # 4 bit integer weights and 8 bit activations quantization INT8_W4A8 = dict( weights=QuantizationArgs( @@ -199,24 +211,13 @@ def is_preset_scheme(name: str) -> bool: ), ) -# AWQ quantization -AWQ = dict( - weights=QuantizationArgs( - num_bits=4, - type=QuantizationType.INT, - strategy=QuantizationStrategy.GROUP, - symmetric=False, - dynamic=False, - group_size=128, - ), -) - PRESET_SCHEMES = { # Unquantized (no-op) "UNQUANTIZED": UNQUANTIZED, # Integer weight only schemes "W8A16": W8A16, "W4A16": W4A16, + "W4A16_ASYM": W4A16_ASYM, # Integer weight and activation schemes "W8A8": INT8_W8A8, "INT8": INT8_W8A8, # alias for W8A8 @@ -224,5 +225,4 @@ def is_preset_scheme(name: str) -> bool: # Float weight and activation schemes "FP8": FP8, "FP8_DYNAMIC": FP8_DYNAMIC, - "AWQ": AWQ, } From 399a1d059c7f89ab31ddecad2ffb4b902a603778 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 25 Mar 2025 20:09:40 +0000 Subject: [PATCH 4/5] revert changes to min_vals/max_vals Signed-off-by: Brian Dellabetta --- src/compressed_tensors/quantization/utils/helpers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index a52f0340a..55a16f443 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -64,6 +64,10 @@ def calculate_qparams( :param quantization_args: settings to quantization :return: tuple of the calculated scale(s) and zero point(s) """ + # based on the implementations for consuming quantized values, + # 0.0 must always be representable within the quantized range + min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) + max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) device = min_vals.device @@ -72,8 +76,6 @@ def calculate_qparams( zp_dtype = quantization_args.pytorch_dtype() if quantization_args.symmetric: - min_vals = torch.min(min_vals, torch.zeros_like(min_vals)) - max_vals = torch.max(max_vals, torch.zeros_like(max_vals)) max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals)) scales = max_val_pos / (float(bit_range) / 2) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) @@ -97,7 +99,7 @@ def calculate_qparams( def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs): """ Returns the computed scales and zero points for dynamic activation - qunatization. + quantization. 
:param value: tensor to calculate quantization parameters for :param args: quantization args From f4769706ef6927c83fc92538cacff661806f88a6 Mon Sep 17 00:00:00 2001 From: Brian Dellabetta Date: Tue, 25 Mar 2025 20:38:37 +0000 Subject: [PATCH 5/5] only round if casting to int type Signed-off-by: Brian Dellabetta --- src/compressed_tensors/quantization/utils/helpers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py index 55a16f443..d7e6d5f81 100644 --- a/src/compressed_tensors/quantization/utils/helpers.py +++ b/src/compressed_tensors/quantization/utils/helpers.py @@ -84,9 +84,12 @@ def calculate_qparams( scales = (max_vals - min_vals) / float(bit_range) scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps) zero_points = bit_min - (min_vals / scales) - zero_points = torch.clamp(torch.round(zero_points), bit_min, bit_max) + zero_points = torch.clamp(zero_points, bit_min, bit_max) # match zero-points to quantized type + # if casting to int, use round instead of truncate + if quantization_args.type == QuantizationType.INT: + zero_points = torch.round(zero_points) zero_points = zero_points.to(zp_dtype) if scales.ndim == 0:
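Taken together, patches 1, 2, 4, and 5 settle on the following behavior for the asymmetric branch of calculate_qparams, and patch 3 exposes the matching preset under the name "W4A16_ASYM" rather than "AWQ". The sketch below illustrates that end state; it is not the library code. The function name asymmetric_qparams, the hard-coded signed 4-bit range [-8, 7], the torch.int8 zero-point dtype, and the example grouping are assumptions chosen to mirror the W4A16_ASYM preset (group-wise INT4 weights, group_size=128, symmetric=False); the real implementation lives in src/compressed_tensors/quantization/utils/helpers.py and derives the range and dtype from QuantizationArgs.

import torch

def asymmetric_qparams(min_vals: torch.Tensor, max_vals: torch.Tensor,
                       bit_min: int = -8, bit_max: int = 7):
    # 0.0 must always be representable within the quantized range
    # (the clamp that patch 4 keeps at the top of calculate_qparams)
    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))

    bit_range = float(bit_max - bit_min)  # 15 for a signed 4-bit range
    scales = (max_vals - min_vals) / bit_range
    # patch 2 reverts the 1e-5 floor back to float32 eps
    scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)

    zero_points = bit_min - (min_vals / scales)
    zero_points = torch.clamp(zero_points, bit_min, bit_max)
    # patch 5: round explicitly before an integer cast, since .to() truncates
    zero_points = torch.round(zero_points).to(torch.int8)
    return scales, zero_points

# Illustrative usage: per-group mins/maxes over groups of 128 columns,
# roughly what the GROUP strategy with group_size=128 operates on.
w = torch.randn(1, 256)
groups = w.reshape(-1, 128)
scales, zps = asymmetric_qparams(groups.min(dim=1).values,
                                 groups.max(dim=1).values)

The explicit rounding matters because a bare cast to an integer dtype truncates toward zero, which biases fractional zero points; patch 5 restricts torch.round to the QuantizationType.INT case so float zero-point dtypes are left untouched. After patch 3, the scheme is selected by the "W4A16_ASYM" key in PRESET_SCHEMES and there is no separate "AWQ" key.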