Skip to content

Commit d5efc49

Browse files
committed
Add test for choose_qparams for tinygemm ops
Summary: This is in preparation for replacing tinygemm q/dq ops with the unified quant primitive ops Test Plan: python test/quantization/test_quant_primitives.py -k test_tinygemm_get_groupwise_affine_qparams Reviewers: Subscribers: Tasks: Tags:
1 parent 63c5ac5 commit d5efc49

File tree

2 files changed

+95
-10
lines changed

2 files changed

+95
-10
lines changed

test/quantization/test_quant_primitives.py

+69-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from torchao.quantization.quant_primitives import (
1212
get_group_qparams_symmetric,
13+
get_groupwise_affine_qparams,
1314
quantize_affine,
1415
dequantize_affine,
1516
choose_qparams_affine,
@@ -56,8 +57,8 @@ def test_get_group_qparams_symmetric(self):
5657
scale_obs = scale_obs.reshape(weight.shape[0], -1)
5758

5859
# assert that scales are identical
59-
(scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
60-
torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)
60+
(scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16)
61+
torch.testing.assert_close(scale_obs, scale_ao, rtol=0, atol=0)
6162

6263
def test_choose_qparams_group_sym(self):
6364
"""Note: groupwise asymmetric quant is using a different way of computing zero_points, so
@@ -88,7 +89,7 @@ def test_choose_qparams_token_asym(self):
8889
scale_ref = scale_ref.squeeze()
8990
zp_ref = zp_ref.squeeze()
9091

91-
torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3)
92+
torch.testing.assert_close(scale, scale_ref, atol=10e-3, rtol=10e-3)
9293
self.assertTrue(torch.equal(zero_point, zp_ref))
9394

9495
def test_choose_qparams_tensor_asym(self):
@@ -257,7 +258,7 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
257258
quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
258259
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
259260
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
260-
torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)
261+
torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02)
261262

262263
def test_choose_qparams_tensor_asym_eps(self):
263264
input = torch.zeros(10, 10)
@@ -298,5 +299,69 @@ def test_raises(self):
298299
with self.assertRaisesRegex(RuntimeError, "is invalid for input of size 1"):
299300
_ = quantize_affine(input, block_size, scale, zero_point, dtype)
300301

302+
def test_not_preserve_zero_not_supported(self):
303+
"""Making sure preserve_zero == False is not supported for symmetric quant"""
304+
input = torch.randn(10, 256)
305+
n_bit = 4
306+
mapping_type = MappingType.SYMMETRIC
307+
dtype = torch.int8
308+
block_size = (1, 128)
309+
quant_min = 0
310+
quant_max = 2**n_bit - 1
311+
eps = 1e-6
312+
scale_dtype = torch.bfloat16
313+
zero_point_dtype = torch.bfloat16
314+
with self.assertRaisesRegex(ValueError, "preserve_zero == False is not supported for symmetric quantization"):
315+
choose_qparams_affine(
316+
input,
317+
mapping_type,
318+
block_size,
319+
dtype,
320+
quant_min,
321+
quant_max,
322+
eps,
323+
scale_dtype=scale_dtype,
324+
zero_point_dtype=zero_point_dtype,
325+
preserve_zero=False,
326+
)
327+
328+
329+
def test_tinygemm_get_groupwise_affine_qparams(self):
330+
input = torch.randn(10, 256)
331+
n_bit = 4
332+
scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
333+
334+
mapping_type = MappingType.ASYMMETRIC
335+
dtype = torch.int8
336+
block_size = (1, 128)
337+
quant_min = 0
338+
quant_max = 2**n_bit - 1
339+
eps = 1e-6
340+
scale_dtype = torch.bfloat16
341+
zero_point_dtype = torch.bfloat16
342+
scale, zero_point = \
343+
choose_qparams_affine(
344+
input,
345+
mapping_type,
346+
block_size,
347+
dtype,
348+
quant_min,
349+
quant_max,
350+
eps,
351+
scale_dtype=scale_dtype,
352+
zero_point_dtype=zero_point_dtype,
353+
preserve_zero=False,
354+
)
355+
356+
def int_zero_point_to_float(zero_point, scale, qaunt_min, mid_point):
357+
return (quant_min - zero_point + mid_point) * scale
358+
359+
mid_point = 2 ** (n_bit - 1)
360+
zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)
361+
362+
self.assertTrue(torch.equal(scale, scale_ref))
363+
torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.00001, atol=torch.max(scale)*0.03)
364+
365+
301366
if __name__ == "__main__":
302367
unittest.main()

torchao/quantization/quant_primitives.py

+26-6
Original file line numberDiff line numberDiff line change
@@ -256,19 +256,29 @@ def choose_qparams_affine(
256256
eps: Optional[float] = None,
257257
scale_dtype: Optional[torch.dtype] = None,
258258
zero_point_dtype: Optional[torch.dtype] = None,
259+
preserve_zero = True,
259260
) -> Tuple[torch.Tensor, torch.Tensor]:
260261
"""
261262
Args:
262263
input (torch.Tensor): fp32, bf16, fp16 input Tensor
263264
mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
264-
block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
265-
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
265+
block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
266+
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
266267
target_dtype (torch.dtype): dtype for target quantized Tensor
267268
quant_min (Optional[int]): minimum quantized value for target quantized Tensor
268269
quant_max (Optioanl[int]): maximum quantized value for target quantized Tensor
269270
eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
270271
scale_dtype (torch.dtype): dtype for scale Tensor
271272
zero_point_dtype (torch.dtype): dtype for zero_point Tensor
273+
preserve_zero (bool): a flag to indicate whether we need zero to be exactly
274+
representable or not, this is typically required for ops that needs zero padding, like convolution
275+
it's less important for ops that doesn't have zero padding in the op itself, like linear.
276+
277+
For example, given a floating point Tensor [1.2, 0.1, 3.0, 4.0, 0.4, 0], if `preserve_zero` is True,
278+
we'll make sure there is a integer value corresponding to the floating point 0, e.g. [-3, -8, 3, 7, -7, -8], 0 will be mapped to `-8` without loss. But if `preserve_zero` is not True, there won't be such
279+
gurantee.
280+
281+
If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point
272282
273283
Output:
274284
Tuple of scales and zero_points Tensor with requested dtype
@@ -288,17 +298,27 @@ def choose_qparams_affine(
288298
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
289299
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
290300

291-
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
292-
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
301+
if preserve_zero:
302+
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
303+
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
304+
else:
305+
min_val_neg = min_val
306+
max_val_pos = max_val
293307

294308
if mapping_type == MappingType.SYMMETRIC:
295309
max_val_pos = torch.max(-min_val_neg, max_val_pos)
296310
scale = max_val_pos / (float(quant_max - quant_min) / 2)
311+
if not preserve_zero:
312+
raise ValueError("preserve_zero == False is not supported for symmetric quantization")
297313
zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
298314
else:
299315
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
300-
zero_point = quant_min - torch.round(min_val_neg / scale)
301-
zero_point = torch.clamp(zero_point, quant_min, quant_max)
316+
if preserve_zero:
317+
zero_point = quant_min - torch.round(min_val_neg / scale)
318+
zero_point = torch.clamp(zero_point, quant_min, quant_max)
319+
else:
320+
zero_point = quant_min - min_val_neg / scale
321+
302322

303323
if eps is None:
304324
eps = torch.finfo(input.dtype).eps

0 commit comments

Comments
 (0)