@@ -10,6 +10,7 @@
 import torch
 from torchao.quantization.quant_primitives import (
     get_group_qparams_symmetric,
+    get_groupwise_affine_qparams,
     quantize_affine,
     dequantize_affine,
     choose_qparams_affine,
@@ -56,8 +57,8 @@ def test_get_group_qparams_symmetric(self):
         scale_obs = scale_obs.reshape(weight.shape[0], -1)

         # assert that scales are identical
-        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
-        torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)
+        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize, precision=torch.float16)
+        torch.testing.assert_close(scale_obs, scale_ao, rtol=0, atol=0)

     def test_choose_qparams_group_sym(self):
         """Note: groupwise asymmetric quant is using a different way of computing zero_points, so
@@ -88,7 +89,7 @@ def test_choose_qparams_token_asym(self):
         scale_ref = scale_ref.squeeze()
         zp_ref = zp_ref.squeeze()

-        torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3)
+        torch.testing.assert_close(scale, scale_ref, atol=10e-3, rtol=10e-3)
         self.assertTrue(torch.equal(zero_point, zp_ref))

     def test_choose_qparams_tensor_asym(self):
@@ -257,7 +258,7 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
         quantized = quantize_affine(input, block_size, scale, zero_point, dtype)
         dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)
         # we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
-        torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)
+        torch.testing.assert_close(dequantized, input, rtol=2, atol=0.02)

     def test_choose_qparams_tensor_asym_eps(self):
         input = torch.zeros(10, 10)
@@ -298,5 +299,69 @@ def test_raises(self):
         with self.assertRaisesRegex(RuntimeError, "is invalid for input of size 1"):
             _ = quantize_affine(input, block_size, scale, zero_point, dtype)

+    def test_not_preserve_zero_not_supported(self):
+        """Making sure preserve_zero == False is not supported for symmetric quant"""
+        input = torch.randn(10, 256)
+        n_bit = 4
+        mapping_type = MappingType.SYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 128)
+        quant_min = 0
+        quant_max = 2**n_bit - 1
+        eps = 1e-6
+        scale_dtype = torch.bfloat16
+        zero_point_dtype = torch.bfloat16
+        with self.assertRaisesRegex(ValueError, "preserve_zero == False is not supported for symmetric quantization"):
+            choose_qparams_affine(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype=scale_dtype,
+                zero_point_dtype=zero_point_dtype,
+                preserve_zero=False,
+            )
+
+
+    def test_tinygemm_get_groupwise_affine_qparams(self):
+        input = torch.randn(10, 256)
+        n_bit = 4
+        scale_ref, zero_point_ref = get_groupwise_affine_qparams(input, n_bit=n_bit, groupsize=128, dtype=torch.bfloat16)
+
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (1, 128)
+        quant_min = 0
+        quant_max = 2**n_bit - 1
+        eps = 1e-6
+        scale_dtype = torch.bfloat16
+        zero_point_dtype = torch.bfloat16
+        scale, zero_point = \
+            choose_qparams_affine(
+                input,
+                mapping_type,
+                block_size,
+                dtype,
+                quant_min,
+                quant_max,
+                eps,
+                scale_dtype=scale_dtype,
+                zero_point_dtype=zero_point_dtype,
+                preserve_zero=False,
+            )
+
+        def int_zero_point_to_float(zero_point, scale, quant_min, mid_point):
+            return (quant_min - zero_point + mid_point) * scale
+
+        mid_point = 2 ** (n_bit - 1)
+        zero_point_float = int_zero_point_to_float(zero_point, scale, quant_min, mid_point)
+
+        self.assertTrue(torch.equal(scale, scale_ref))
+        torch.testing.assert_close(zero_point_float, zero_point_ref, rtol=0.00001, atol=torch.max(scale)*0.03)
+
+
 if __name__ == "__main__":
     unittest.main()
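
Note on the conversion checked by test_tinygemm_get_groupwise_affine_qparams: the test compares two zero_point conventions, and int_zero_point_to_float follows directly from their dequantization formulas. Below is a minimal standalone sketch of that relationship for a single group; the variable names and the qparam arithmetic are illustrative assumptions, not the exact internals of choose_qparams_affine or get_groupwise_affine_qparams (which also handle eps clamping and dtype casts).

import torch

# One hypothetical group of weights (illustrative only).
n_bit = 4
quant_min, quant_max = 0, 2**n_bit - 1
mid_point = 2 ** (n_bit - 1)
group = torch.randn(128)
min_val, max_val = group.min(), group.max()

# Integer zero_point convention: dequant(q) = (q - zero_point_int) * scale.
scale = (max_val - min_val) / (quant_max - quant_min)
zero_point_int = quant_min - min_val / scale  # left unrounded here, loosely mirroring preserve_zero=False

# tinygemm-style float convention: dequant(q) = (q - mid_point) * scale + zero_point_float,
# i.e. the float zero_point is anchored at the midpoint of the quantized range.
zero_point_float = min_val + scale * mid_point

# The test's conversion maps one convention onto the other:
#   (quant_min - zero_point_int + mid_point) * scale
# = (min_val / scale + mid_point) * scale
# = min_val + scale * mid_point
converted = (quant_min - zero_point_int + mid_point) * scale
torch.testing.assert_close(converted, zero_point_float)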