Add test for choose_qparams for tinygemm ops
Summary:
This is in preparation for replacing tinygemm q/dq ops with the unified quant primitive ops

Test Plan:
python test/quantization/test_quant_primitives.py -k test_tinygemm_get_groupwise_affine_qparams

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 8, 2024
1 parent 63c5ac5 commit 7e53a5e
Showing 2 changed files with 67 additions and 10 deletions.
46 changes: 42 additions & 4 deletions test/quantization/test_quant_primitives.py
@@ -10,6 +10,7 @@
import torch
from torchao.quantization.quant_primitives import (
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
quantize_affine,
dequantize_affine,
choose_qparams_affine,
@@ -56,8 +57,8 @@ def test_get_group_qparams_symmetric(self):
scale_obs = scale_obs.reshape(weight.shape[0], -1)

# assert that scales are identical
-        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
-        torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)
+        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16)
+        torch.testing.assert_close(scale_obs, scale_ao, rtol=0, atol=0)

def test_choose_qparams_group_sym(self):
"""Note: groupwise asymmetric quant is using a different way of computing zero_points, so
@@ -88,7 +89,7 @@ def test_choose_qparams_token_asym(self):
scale_ref = scale_ref.squeeze()
zp_ref = zp_ref.squeeze()

-        torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3)
+        torch.testing.assert_close(scale, scale_ref, atol=10e-3, rtol=10e-3)
self.assertTrue(torch.equal(zero_point, zp_ref))

def test_choose_qparams_tensor_asym(self):
@@ -257,7 +258,7 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
-        torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)
+        torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02)

def test_choose_qparams_tensor_asym_eps(self):
input = torch.zeros(10, 10)
Expand Down Expand Up @@ -298,5 +299,42 @@ def test_raises(self):
with self.assertRaisesRegex(RuntimeError, "is invalid for input of size 1"):
_ = quantize_affine(input, block_size, scale, zero_point, dtype)

def test_tinygemm_get_groupwise_affine_qparams(self):
input = torch.randn(10, 256)
n_bit = 4
scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)

mapping_type = MappingType.ASYMMETRIC
dtype = torch.int8
block_size = (1, 128)
quant_min = 0
quant_max = 2**n_bit - 1
eps = 1e-6
scale_dtype = torch.bfloat16
zero_point_dtype = torch.bfloat16
scale, zero_point = \
choose_qparams_affine(
input,
mapping_type,
block_size,
dtype,
quant_min,
quant_max,
eps,
scale_dtype=scale_dtype,
zero_point_dtype=zero_point_dtype,
preserve_zero=False,
)

def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
return (quant_min - zero_point + mid_point) * scale

mid_point = 2 ** (n_bit - 1)
zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)

self.assertTrue(torch.equal(scale, scale_ref))
torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.00001, atol=torch.max(scale)*0.03)


if __name__ == "__main__":
unittest.main()
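To make the conversion exercised by the new test concrete, here is a small worked example with made-up numbers (illustrative only, mirroring int_zero_point_to_float above and the affine qparam formulas in quant_primitives.py below; n_bit=4, so quant_min=0, quant_max=15, mid_point=8):

# One group with min_val=-1.0 and max_val=2.0 (hypothetical values):
scale = (2.0 - (-1.0)) / (15 - 0)                  # 0.2
zero_point = 0 - (-1.0) / scale                    # 5.0 (preserve_zero=False, so no rounding)
# tinygemm keeps its zero point in the floating point domain:
zero_point_float = (0 - zero_point + 8) * scale    # 0.6, i.e. min_val + mid_point * scale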
31 changes: 25 additions & 6 deletions torchao/quantization/quant_primitives.py
@@ -256,19 +256,29 @@ def choose_qparams_affine(
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
input (torch.Tensor): fp32, bf16, fp16 input Tensor
mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
block_size: (Tuple[int, ...]): granularity of quantization, i.e. the size of the block of tensor elements that share the same qparam,
e.g. when the block size is the same as the input tensor shape, we are doing per-tensor quantization
target_dtype (torch.dtype): dtype for target quantized Tensor
quant_min (Optional[int]): minimum quantized value for target quantized Tensor
quant_max (Optional[int]): maximum quantized value for target quantized Tensor
eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
scale_dtype (torch.dtype): dtype for scale Tensor
zero_point_dtype (torch.dtype): dtype for zero_point Tensor
preserve_zero (bool): a flag to indicate whether we need zero to be exactly
representable. This is typically required for ops that need zero padding, like convolution;
it's less important for ops that don't have zero padding in the op itself, like linear.
For example, given a floating point Tensor [1.2, 0.1, 3.0, 4.0, 0.4, 0], if `preserve_zero` is True,
we'll make sure there is an integer value corresponding to the floating point 0, e.g. [-3, -8, 3, 7, -7, -8], where 0 is mapped to `-8` without loss. But if `preserve_zero` is not True, there is no such
guarantee.
If we don't need zero to be exactly representable, we don't do rounding and clamping for zero_point.
Output:
Tuple of scales and zero_points Tensor with requested dtype
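To see concretely what the new preserve_zero flag changes, here is a minimal sketch of the two zero_point paths with made-up numbers (illustrative only; 4-bit asymmetric quantization with quant_min=0, quant_max=15 and one group with min_val=0.5, max_val=3.5):

import torch

quant_min, quant_max = 0, 15
min_val, max_val = torch.tensor(0.5), torch.tensor(3.5)

# preserve_zero=True: extend the range to include 0.0, then round and clamp
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))         # 0.0
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))         # 3.5
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)  # ~0.2333
zero_point = torch.clamp(quant_min - torch.round(min_val_neg / scale),
                         quant_min, quant_max)                      # 0, so float 0.0 quantizes exactly

# preserve_zero=False: use the raw min/max and keep the zero point fractional
scale_nz = (max_val - min_val) / float(quant_max - quant_min)       # 0.2
zero_point_nz = quant_min - min_val / scale_nz                      # -2.5, no rounding or clamping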
@@ -288,17 +298,26 @@
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)

-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    if preserve_zero:
+        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
+        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
+    else:
+        min_val_neg = min_val
+        max_val_pos = max_val

if mapping_type == MappingType.SYMMETRIC:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
assert preserve_zero, "non-representable zero path is not implemented for symmetric quantization"
zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
else:
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
-        zero_point = quant_min - torch.round(min_val_neg / scale)
-        zero_point = torch.clamp(zero_point, quant_min, quant_max)
+        if preserve_zero:
+            zero_point = quant_min - torch.round(min_val_neg / scale)
+            zero_point = torch.clamp(zero_point, quant_min, quant_max)
+        else:
+            zero_point = quant_min - min_val_neg / scale


if eps is None:
eps = torch.finfo(input.dtype).eps
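As a side note on the symmetric branch above: the fixed zero point int((quant_min + quant_max + 1) / 2) is just the integer mid-point of the quantized range, e.g.:

# int8 range [-128, 127]: int((-128 + 127 + 1) / 2) = 0
# 4-bit unsigned [0, 15]: int((0 + 15 + 1) / 2)     = 8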
