|
8 | 8 | # This test takes a long time to run
|
9 | 9 | import unittest
|
10 | 10 | import torch
|
11 |
| -from torchao.quantization.quant_primitives import get_group_qparams_symmetric |
12 |
| -from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3 |
| 11 | +from torchao.quantization.quant_primitives import ( |
| 12 | + get_group_qparams_symmetric, |
| 13 | + quantize_affine, |
| 14 | + dequantize_affine, |
| 15 | + choose_qparams_affine, |
| 16 | + MappingType, |
| 17 | +) |
| 18 | + |
| 19 | +from torchao.quantization.utils import ( |
| 20 | + TORCH_VERSION_AFTER_2_3, |
| 21 | + TORCH_VERSION_AFTER_2_4, |
| 22 | +) |
| 23 | + |
| 24 | +_SEED = 1234 |
| 25 | +torch.manual_seed(_SEED) |
13 | 26 |
|
14 | 27 | class TestQuantPrimitives(unittest.TestCase):
|
15 | 28 | SEED = 123
|
@@ -46,5 +59,176 @@ def test_get_group_qparams_symmetric(self):
|
46 | 59 | (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
|
47 | 60 | torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)
|
48 | 61 |
|
| 62 | + def test_choose_qparams_group_sym(self): |
| 63 | + """Note: groupwise asymmetric quant is using a different way of computing zero_points, so |
| 64 | + we don't include it here. We may just replace it with per block quant |
| 65 | + """ |
| 66 | + input = torch.randn(10, 10) |
| 67 | + mapping_type = MappingType.SYMMETRIC |
| 68 | + dtype = torch.int8 |
| 69 | + block_size = (1, 2) |
| 70 | + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) |
| 71 | + |
| 72 | + scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2) |
| 73 | + |
| 74 | + self.assertTrue(torch.equal(scale, scale_ref)) |
| 75 | + self.assertTrue(torch.equal(zero_point, zp_ref)) |
| 76 | + |
| 77 | + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") |
| 78 | + def test_choose_qparams_token_asym(self): |
| 79 | + input = torch.randn(10, 10) |
| 80 | + mapping_type = MappingType.ASYMMETRIC |
| 81 | + dtype = torch.int8 |
| 82 | + block_size = (1, 10) |
| 83 | + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) |
| 84 | + |
| 85 | + scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype) |
| 86 | + scale_ref = scale_ref.squeeze() |
| 87 | + zp_ref = zp_ref.squeeze() |
| 88 | + |
| 89 | + torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3) |
| 90 | + self.assertTrue(torch.equal(zero_point, zp_ref)) |
| 91 | + |
| 92 | + def test_choose_qparams_tensor_asym(self): |
| 93 | + input = torch.randn(10, 10) |
| 94 | + mapping_type = MappingType.ASYMMETRIC |
| 95 | + dtype = torch.int8 |
| 96 | + block_size = (10, 10) |
| 97 | + eps = torch.finfo(torch.float32).eps |
| 98 | + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps) |
| 99 | + |
| 100 | + |
| 101 | + quant_min = -128 |
| 102 | + quant_max = 127 |
| 103 | + scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(input, quant_min, quant_max, eps, dtype) |
| 104 | + scale_ref = scale_ref.squeeze() |
| 105 | + zp_ref = zp_ref.squeeze() |
| 106 | + |
| 107 | + self.assertTrue(torch.equal(scale, scale_ref)) |
| 108 | + self.assertTrue(torch.equal(zero_point, zp_ref)) |
| 109 | + |
| 110 | + def test_choose_qparams_tensor_sym(self): |
| 111 | + input = torch.randn(10, 10) |
| 112 | + mapping_type = MappingType.SYMMETRIC |
| 113 | + dtype = torch.int8 |
| 114 | + block_size = (10, 10) |
| 115 | + eps = torch.finfo(torch.float32).eps |
| 116 | + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps) |
| 117 | + |
| 118 | + quant_min = -128 |
| 119 | + quant_max = 127 |
| 120 | + scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(input, quant_min, quant_max, eps, dtype) |
| 121 | + scale_ref = scale_ref.squeeze() |
| 122 | + zp_ref = zp_ref.squeeze() |
| 123 | + |
| 124 | + self.assertTrue(torch.equal(scale, scale_ref)) |
| 125 | + self.assertTrue(torch.equal(zero_point, zp_ref)) |
| 126 | + |
| 127 | + |
| 128 | + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") |
| 129 | + def test_quantize_dequantize_group_sym(self): |
| 130 | + input = torch.randn(10, 10) |
| 131 | + mapping_type = MappingType.SYMMETRIC |
| 132 | + dtype = torch.int8 |
| 133 | + block_size = (1, 2) |
| 134 | + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) |
| 135 | + |
| 136 | + quantized = quantize_affine(input, block_size, scale, zero_point, dtype) |
| 137 | + dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) |
| 138 | + |
| 139 | + group_size = 2 |
| 140 | + quant_min = -128 |
| 141 | + quant_max = 127 |
| 142 | + quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel_group( |
| 143 | + input, scale, zero_point, quant_min, quant_max, torch.int8, group_size |
| 144 | + ) |
| 145 | + dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel_group( |
| 146 | + quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, group_size, output_dtype=torch.float32 |
| 147 | + ) |
| 148 | + |
| 149 | + self.assertTrue(torch.equal(quantized, quantized_ref)) |
| 150 | + self.assertTrue(torch.equal(dequantized, dequantized_ref)) |
| 151 | + |
| 152 | + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") |
| 153 | + def test_quantize_dequantize_channel_asym(self): |
| 154 | + input = torch.randn(10, 10) |
| 155 | + mapping_type = MappingType.ASYMMETRIC |
| 156 | + dtype = torch.int8 |
| 157 | + block_size = (10, 1) |
| 158 | + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) |
| 159 | + output_dtype = torch.float32 |
| 160 | + quantized = quantize_affine(input, block_size, scale, zero_point, dtype) |
| 161 | + dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype) |
| 162 | + |
| 163 | + axis = 1 |
| 164 | + quant_min = -128 |
| 165 | + quant_max = 127 |
| 166 | + quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel( |
| 167 | + input, scale, zero_point, axis, quant_min, quant_max, torch.int8 |
| 168 | + ) |
| 169 | + dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel( |
| 170 | + quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=output_dtype |
| 171 | + ) |
| 172 | + self.assertTrue(torch.equal(quantized, quantized_ref)) |
| 173 | + self.assertTrue(torch.equal(dequantized, dequantized_ref)) |
| 174 | + |
| 175 | + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") |
| 176 | + def test_quantize_dequantize_tensor_asym(self): |
| 177 | + input = torch.randn(10, 10) |
| 178 | + mapping_type = MappingType.ASYMMETRIC |
| 179 | + dtype = torch.int8 |
| 180 | + block_size = (10, 10) |
| 181 | + output_dtype = torch.float32 |
| 182 | + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) |
| 183 | + quantized = quantize_affine(input, block_size, scale, zero_point, dtype) |
| 184 | + dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype) |
| 185 | + |
| 186 | + axis = 1 |
| 187 | + quant_min = -128 |
| 188 | + quant_max = 127 |
| 189 | + quantized_ref = torch.ops.quantized_decomposed.quantize_per_tensor( |
| 190 | + input, scale, zero_point, quant_min, quant_max, torch.int8 |
| 191 | + ) |
| 192 | + dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_tensor( |
| 193 | + quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, out_dtype=output_dtype |
| 194 | + ) |
| 195 | + self.assertTrue(torch.equal(quantized, quantized_ref)) |
| 196 | + self.assertTrue(torch.equal(dequantized, dequantized_ref)) |
| 197 | + |
| 198 | + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower") |
| 199 | + def test_quantize_dequantize_channel_asym_4d(self): |
| 200 | + input = torch.randn(3, 3, 10, 10) |
| 201 | + mapping_type = MappingType.ASYMMETRIC |
| 202 | + dtype = torch.int8 |
| 203 | + block_size = (3, 3, 1, 10) |
| 204 | + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) |
| 205 | + quantized = quantize_affine(input, block_size, scale, zero_point, dtype) |
| 206 | + dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) |
| 207 | + |
| 208 | + axis = 2 |
| 209 | + quant_min = -128 |
| 210 | + quant_max = 127 |
| 211 | + quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel( |
| 212 | + input, scale, zero_point, axis, quant_min, quant_max, torch.int8 |
| 213 | + ) |
| 214 | + dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel( |
| 215 | + quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=torch.float32 |
| 216 | + ) |
| 217 | + self.assertTrue(torch.equal(quantized, quantized_ref)) |
| 218 | + self.assertTrue(torch.equal(dequantized, dequantized_ref)) |
| 219 | + |
| 220 | + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") |
| 221 | + def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self): |
| 222 | + input = torch.randn(3, 3, 10, 10) |
| 223 | + mapping_type = MappingType.ASYMMETRIC |
| 224 | + dtype = torch.int8 |
| 225 | + block_size = (3, 3, 2, 2) |
| 226 | + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps) |
| 227 | + quantized = quantize_affine(input, block_size, scale, zero_point, dtype) |
| 228 | + dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32) |
| 229 | + # we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float |
| 230 | + torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02) |
| 231 | + |
| 232 | + |
49 | 233 | if __name__ == "__main__":
|
50 | 234 | unittest.main()
|
0 commit comments