Add HQQ support #605

Merged: 21 commits, Aug 15, 2024
Changes from 11 commits
98 changes: 98 additions & 0 deletions test/hqq/test_hqq_affine.py
@@ -0,0 +1,98 @@
import unittest
import torch
from torchao.prototype.hqq.core import HQQQuantizer
from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
ZeroPointDomain,
PlainAQTLayout,
PlainLayoutType,
TensorCoreTiledAQTLayout,
TensorCoreTiledLayoutType,
MappingType,
)

torch.random.manual_seed(100)

#Parameters
device = 'cuda:0'
compute_dtype = torch.bfloat16
group_size = 64
mapping_type = MappingType.ASYMMETRIC
block_size = (1, group_size) #axis=1
preserve_zero = False
zero_point_domain = ZeroPointDomain.FLOAT
zero_point_dtype = compute_dtype
inner_k_tiles = 8


in_features, out_features = 4096, 11800
linear_layer = torch.nn.Linear(in_features, out_features, bias=False, device=device)
x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)/20.
y_ref = linear_layer(x)
W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)

def _eval_hqq(nbits, W, y_ref, layout_type):
q_tensor_hqq = AffineQuantizedTensor.from_float(
input_float=W,
mapping_type=mapping_type,
block_size=block_size,
target_dtype=torch.int32 if isinstance(layout_type, TensorCoreTiledLayoutType) else torch.uint8,
quant_min=0,
quant_max=2**nbits - 1,
zero_point_domain=zero_point_domain,
preserve_zero=preserve_zero,
layout_type=layout_type,
use_hqq=True,
)

quant_linear_layer = torch.nn.Linear(W.shape[1], W.shape[0], bias=False, device=W.device)
del quant_linear_layer.weight
quant_linear_layer.weight = q_tensor_hqq
dequantize_error = (W - q_tensor_hqq.dequantize()).abs().mean().item()
dot_product_error = (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()

return dequantize_error, dot_product_error

class TestHQQ(unittest.TestCase):
def test_hqq_plain_8bit(self):
dequantize_error, dot_product_error = _eval_hqq(8, W, y_ref, PlainLayoutType())
self.assertTrue(dequantize_error < 5e-5)
self.assertTrue(dot_product_error < 0.00013)

def test_hqq_plain_7bit(self):
dequantize_error, dot_product_error = _eval_hqq(7, W, y_ref, PlainLayoutType())
self.assertTrue(dequantize_error < 6e-05)
self.assertTrue(dot_product_error < 0.000193)

def test_hqq_plain_6bit(self):
dequantize_error, dot_product_error = _eval_hqq(6, W, y_ref, PlainLayoutType())
self.assertTrue(dequantize_error < 0.0001131)
self.assertTrue(dot_product_error < 0.000353)

def test_hqq_plain_5bit(self):
dequantize_error, dot_product_error = _eval_hqq(5, W, y_ref, PlainLayoutType())
self.assertTrue(dequantize_error < 0.00023)
self.assertTrue(dot_product_error < 0.000704)

def test_hqq_plain_4bit(self):
dequantize_error, dot_product_error = _eval_hqq(4, W, y_ref, PlainLayoutType())
self.assertTrue(dequantize_error < 0.000487)
self.assertTrue(dot_product_error < 0.001472)

def test_hqq_tensorcore_4bit(self):
dequantize_error, dot_product_error = _eval_hqq(4, W, y_ref, TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles))
self.assertTrue(dequantize_error < 0.000487)
self.assertTrue(dot_product_error < 0.00147)

def test_hqq_plain_3bit(self):
dequantize_error, dot_product_error = _eval_hqq(3, W, y_ref, PlainLayoutType())
self.assertTrue(dequantize_error < 0.00101)
self.assertTrue(dot_product_error < 0.003047)

def test_hqq_plain_2bit(self):
dequantize_error, dot_product_error = _eval_hqq(2, W, y_ref, PlainLayoutType())
self.assertTrue(dequantize_error < 0.002366)
self.assertTrue(dot_product_error < 0.007255)

if __name__ == "__main__":
unittest.main()
23 changes: 18 additions & 5 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -2,13 +2,15 @@
from typing import Dict, Callable, Any, Tuple, Optional
from collections import defaultdict
import functools
import math
from torchao.quantization.quant_primitives import (
choose_qparams_affine,
quantize_affine,
dequantize_affine,
ZeroPointDomain,
MappingType,
int_scaled_matmul,
quantize_affine_hqq,
)
from torchao.quantization.utils import (
pack_tinygemm_scales_and_zeros,
@@ -25,6 +27,7 @@
PlainLayoutType,
is_device,
)

from dataclasses import dataclass
from torchao.utils import TORCH_VERSION_AFTER_2_5

@@ -75,7 +78,6 @@ def _get_to_kwargs(self, *args, **kwargs):
##############################
# Tensor Subclass Definition #
##############################

class AffineQuantizedTensor(torch.Tensor):
"""
Affine quantized tensor subclass. Affine quantization means we quantize the floating point tensor with an affine transformation:
@@ -194,14 +196,25 @@ def from_float(
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
layout_type: LayoutType = PlainLayoutType(),
use_hqq: bool = False,
):
original_shape = input_float.shape
input_float = layout_type.pre_process(input_float)

scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
int_data = layout_type.post_process(int_data)
if(use_hqq):
assert zero_point_domain == ZeroPointDomain.FLOAT and mapping_type == MappingType.ASYMMETRIC and quant_min==0, "Invalid input parameters for HQQ quantization."
nbits = int(math.log2(quant_max + 1))
axis = 1 if (block_size[0]==1) else 0
group_size = max(block_size)
compute_dtype = zero_point_dtype if (zero_point_dtype is not None) else input_float.dtype
device = input_float.device
int_data, scale, zero_point, _ = quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)

else:
input_float = layout_type.pre_process(input_float)
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
int_data = layout_type.post_process(int_data)

layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
layout_tensor = layout_tensor_ctr(int_data, scale, zero_point, layout_type)
return cls(
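For context, a minimal sketch of how the new use_hqq flag is exercised, mirroring the test file above (API as of this PR; the shapes and parameter values are illustrative):

import torch
from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    MappingType,
    ZeroPointDomain,
    PlainLayoutType,
)

W = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")
q_tensor = AffineQuantizedTensor.from_float(
    input_float=W,
    mapping_type=MappingType.ASYMMETRIC,
    block_size=(1, 64),            # group along axis=1 with group_size=64
    target_dtype=torch.uint8,      # torch.int32 for the tensor-core tiled layout
    quant_min=0,
    quant_max=2**4 - 1,            # 4-bit
    zero_point_domain=ZeroPointDomain.FLOAT,
    preserve_zero=False,
    layout_type=PlainLayoutType(),
    use_hqq=True,
)
W_dq = q_tensor.dequantize()       # dequantize back for a quick error check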
184 changes: 184 additions & 0 deletions torchao/prototype/hqq/core.py
@@ -0,0 +1,184 @@
import torch
import math
from torch import Tensor, float16, float32
from typing import Union


# Shrinking operator (proximal operator for the lp norm)
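# For lp_norm == 1 this is plain soft-thresholding:
#   shrink(x) = sign(x) * max(|x| - 1/beta, 0)
# For lp_norm < 1 the threshold is additionally scaled by |x|^(lp_norm - 1).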
def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor:
if lp_norm == 1:
return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
else:
return torch.sign(x) * torch.nn.functional.relu(
torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1)
)


# Proximal solver || W - dequantize(quantize(W))||_p^p
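# Alternating scheme: quantize with the current (scale, zero), shrink the
# quantization residual with shrink_lp_op, re-estimate the zero-point from the
# shrunk residual, and grow beta by a factor kappa each iteration. Only the
# zero-point is updated here; the scale stays fixed. With early_stop, the loop
# exits as soon as the mean absolute error stops improving.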
@torch.inference_mode()
def optimize_weights_proximal_legacy(
tensor: Tensor,
scale: Tensor,
zero: Tensor,
min_max: list,
axis: int = 0,
dtype: Union[torch.dtype, None] = None,
device: Union[str, None] = None,
verbose: bool = False,
opt_params: dict = {
"lp_norm": 0.7,
"beta": 1e1,
"kappa": 1.01,
"iters": 20,
"early_stop": True,
},
) -> tuple:
lp_norm, beta, kappa, iters, early_stop = (
opt_params["lp_norm"],
opt_params["beta"],
opt_params["kappa"],
opt_params["iters"],
opt_params["early_stop"],
)

device = tensor.device if (device is None) else torch.device(device)

if dtype is None:
dtype = float16 if (device.type == "cuda") else float32

W_f = tensor.to(dtype=dtype, device=device)
scale = scale.to(dtype=dtype, device=device)
zero = zero.to(dtype=dtype, device=device)

best_error = 1e4
for i in range(iters):
W_q = torch.round(W_f * scale + zero).clamp(min_max[0], min_max[1])
W_r = (W_q - zero) / scale
W_e = shrink_lp_op(W_f - W_r, beta, lp_norm)
zero = torch.mean(W_q - (W_f - W_e) * scale, axis=axis, keepdim=True)
beta *= kappa

current_error = float(torch.abs(W_f - W_r).mean())
if verbose:
print("Iter " + str(i + 1), " | Error: " + str(current_error))
if early_stop:
if current_error < best_error:
best_error = current_error
else:
break

scale = scale.to(tensor.device)
zero = zero.to(tensor.device)
del W_f, W_q, W_r, W_e
torch.cuda.empty_cache()

W_q = torch.round(tensor * scale + zero).clamp(min_max[0], min_max[1])
return W_q, scale, zero


# Default: fast with early stopping
optimize_weights_proximal = optimize_weights_proximal_legacy


# Checks that val1 is divisible by val2 (used to verify numel() is divisible by group_size)
def is_divisible(val1: int, val2: int) -> bool:
return int(val2 * math.ceil(val1 / val2)) == val1


# Converts hqq format W_dequant = (W_q - zero)*scale into affinequantized format: (W_q - mid_point)*scale_ao + zero_ao
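# Derivation: (W_q - zero) * scale
#           = (W_q - mid_point) * scale + (mid_point - zero) * scale
# hence scale_ao = scale and zero_ao = (mid_point - zero) * scale,
# with mid_point = (quant_max + quant_min + 1) / 2.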
def convert_to_affinequantized_format(W_q, scale, zero, nbits, shape):
quant_min = 0
quant_max = 2**nbits - 1
mid_point = (quant_max + quant_min + 1) / 2
zero_ao = ((mid_point - zero.float()) * scale.float()).to(zero.dtype)
scale_ao = scale
W_q_ao = W_q.view(shape)
return W_q_ao, scale_ao, zero_ao


# Main HQQ Quantizer - simplified, no bitpacking.
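# Pipeline: reshape the weight into groups along `axis`, initialize (scale, zero)
# from each group's min/max, optionally refine them with the proximal solver above,
# invert the scale, and (unless raw_output is set) convert the result to the
# affine-quantized convention via convert_to_affinequantized_format.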
class HQQQuantizer:
optimize_weights = optimize_weights_proximal

@classmethod
def quantize(
cls,
tensor: Tensor,
nbits: float = 4,
group_size: int = 64,
optimize: bool = True,
axis: int = 1,
compute_dtype: torch.dtype = float16,
device: str = "cuda",
verbose: bool = False, # to check the optimizer error
raw_output: bool = False, # If True, it will return the quant params in hqq lib format
) -> tuple:
assert axis in [0, 1], "axis should be either 0 or 1"
if group_size is not None:
assert is_divisible(tensor.numel(), group_size), (
"group_size should be divisble by the total tensor dimensions. shape: "
+ str(tensor.shape)
+ ", group_size: "
+ str(group_size)
)

W = tensor.to(device=device, dtype=torch.float32)
shape = W.shape

# Reshape for grouping
if group_size is not None:
W = (
W.reshape([-1, group_size])
if (axis == 1)
else W.reshape([group_size, -1])
)

# Get min/max values
_min = W.min(axis=axis, keepdim=True)[0]
_max = W.max(axis=axis, keepdim=True)[0]

max_v = round(2**nbits - 1)
min_v = 0
min_max = [min_v, max_v]

# Clamp to avoid fp16 issues
scale = (max_v / (_max - _min)).clamp(max=2e4)
zero = -_min * scale
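        # This choice maps _min -> 0 and _max -> max_v before rounding (up to the
        # clamp above); the scale is inverted further below so dequantization
        # follows W_dequant = (W_q - zero) * scale.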

# Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14
if nbits in [4]:
zero = torch.round(zero)

# Fine-tune weights
if optimize:
W_q, scale, zero = HQQQuantizer.optimize_weights(
tensor=W,
scale=scale,
zero=zero,
min_max=min_max,
axis=axis,
device=device,
verbose=verbose,
)
else:
W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])

# Store meta-data (we invert the scale for dequantization)
scale = 1.0 / scale

        # Convert to the affine-quantized format
if raw_output is False:
W_q, scale, zero = convert_to_affinequantized_format(
W_q, scale, zero, nbits, shape
)

# Make sure all the weights are in the right compute_dtype/device
W_q = W_q.to(dtype=torch.uint8, device=device)
scale = scale.to(dtype=compute_dtype, device=device)
zero = zero.to(dtype=compute_dtype, device=device)

# cleanup
del W, _min, _max
torch.cuda.empty_cache()

return W_q, scale, zero, shape
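Standalone use of the quantizer follows directly from the signature above (the values below are illustrative; raw_output=True instead keeps the native HQQ-library convention):

import torch
from torchao.prototype.hqq.core import HQQQuantizer

W = torch.randn(4096, 4096, device="cuda")
W_q, scale, zero, shape = HQQQuantizer.quantize(
    W,
    nbits=4,
    group_size=64,
    optimize=True,            # run the proximal solver
    axis=1,
    compute_dtype=torch.bfloat16,
    device="cuda",
    raw_output=False,         # return params in the affine-quantized convention
)
# W_q is uint8 with the original weight shape; scale and zero are per-group
# tensors stored in compute_dtype.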