Skip to content

Commit febc832

Browse files
authored
[quant] Add per block quantization primitives (pytorch#159)
Summary: We want to use these primitives to replace all q/dq/choose_qparams ops in https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/quant_primitives.py and https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py. Test Plan: python test/quantization/test_quant_primitives.py. Reviewers: Subscribers: Tasks: Tags:
1 parent 9e5d9cb commit febc832

File tree

3 files changed

+437
-3
lines changed

3 files changed

+437
-3
lines changed

test/quantization/test_quant_primitives.py

+186-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,21 @@
88
# This test takes a long time to run
99
import unittest
1010
import torch
11-
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
12-
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
11+
from torchao.quantization.quant_primitives import (
12+
get_group_qparams_symmetric,
13+
quantize_affine,
14+
dequantize_affine,
15+
choose_qparams_affine,
16+
MappingType,
17+
)
18+
19+
from torchao.quantization.utils import (
20+
TORCH_VERSION_AFTER_2_3,
21+
TORCH_VERSION_AFTER_2_4,
22+
)
23+
24+
_SEED = 1234
25+
torch.manual_seed(_SEED)
1326

1427
class TestQuantPrimitives(unittest.TestCase):
1528
SEED = 123
@@ -46,5 +59,176 @@ def test_get_group_qparams_symmetric(self):
4659
(scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
4760
torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)
4861

62+
def test_choose_qparams_group_sym(self):
    """Check symmetric qparams for (1, 2) blocks against the group-wise helper.

    Note: group-wise asymmetric quantization computes zero_points in a
    different way, so it is not covered here. It may simply be replaced by
    per-block quantization.
    """
    x = torch.randn(10, 10)
    float32_eps = torch.finfo(torch.float32).eps

    # (1, 2) blocks are checked against the reference's groupsize=2 grouping.
    scale, zero_point = choose_qparams_affine(
        x, MappingType.SYMMETRIC, (1, 2), torch.int8, eps=float32_eps
    )
    expected_scale, expected_zero_point = get_group_qparams_symmetric(
        x, n_bit=8, groupsize=2
    )

    # Exact equality: both paths must produce identical qparams.
    self.assertTrue(torch.equal(scale, expected_scale))
    self.assertTrue(torch.equal(zero_point, expected_zero_point))
76+
77+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_choose_qparams_token_asym(self):
    """Compare asymmetric qparams for (1, 10) blocks with the per-token op.

    The reference is ``quantized_decomposed.choose_qparams_per_token_asymmetric``;
    its outputs are squeezed before being compared. Scales are compared with a
    tolerance, zero points must match exactly.
    """
    x = torch.randn(10, 10)
    dtype = torch.int8
    scale, zero_point = choose_qparams_affine(
        x,
        MappingType.ASYMMETRIC,
        (1, 10),
        dtype,
        eps=torch.finfo(torch.float32).eps,
    )

    scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(x, dtype)
    scale_ref = scale_ref.squeeze()
    zp_ref = zp_ref.squeeze()

    # torch.testing.assert_allclose is deprecated; assert_close with
    # check_dtype=False and equal_nan=True keeps the same comparison semantics.
    torch.testing.assert_close(
        scale,
        scale_ref,
        atol=10e-3,
        rtol=10e-3,
        equal_nan=True,
        check_dtype=False,
    )
    self.assertTrue(torch.equal(zero_point, zp_ref))
91+
92+
def test_choose_qparams_tensor_asym(self):
    """Compare asymmetric qparams for one (10, 10) block with ``choose_qparams``.

    The block covers the whole 10x10 input; the reference is
    ``quantized_decomposed.choose_qparams`` with the int8 range [-128, 127].
    Both scale and zero point must match exactly after squeezing the reference.
    """
    x = torch.randn(10, 10)
    float32_eps = torch.finfo(torch.float32).eps
    scale, zero_point = choose_qparams_affine(
        x, MappingType.ASYMMETRIC, (10, 10), torch.int8, eps=float32_eps
    )

    reference = torch.ops.quantized_decomposed.choose_qparams(
        x, -128, 127, float32_eps, torch.int8
    )
    expected_scale = reference[0].squeeze()
    expected_zero_point = reference[1].squeeze()

    self.assertTrue(torch.equal(scale, expected_scale))
    self.assertTrue(torch.equal(zero_point, expected_zero_point))
109+
110+
def test_choose_qparams_tensor_sym(self):
    """Compare symmetric qparams for one (10, 10) block with ``choose_qparams_symmetric``.

    The block covers the whole 10x10 input; the reference is
    ``quantized_decomposed.choose_qparams_symmetric`` with the int8 range
    [-128, 127]. Both outputs must match exactly after squeezing the reference.
    """
    x = torch.randn(10, 10)
    float32_eps = torch.finfo(torch.float32).eps
    scale, zero_point = choose_qparams_affine(
        x, MappingType.SYMMETRIC, (10, 10), torch.int8, eps=float32_eps
    )

    reference = torch.ops.quantized_decomposed.choose_qparams_symmetric(
        x, -128, 127, float32_eps, torch.int8
    )
    expected_scale = reference[0].squeeze()
    expected_zero_point = reference[1].squeeze()

    self.assertTrue(torch.equal(scale, expected_scale))
    self.assertTrue(torch.equal(zero_point, expected_zero_point))
126+
127+
128+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_quantize_dequantize_group_sym(self):
    """Round-trip with (1, 2) blocks and compare with the per-channel-group ops.

    Quantizes and dequantizes a 10x10 input with symmetric qparams, then
    checks both results for exact equality against
    ``quantize_per_channel_group`` / ``dequantize_per_channel_group`` called
    with group size 2 and the int8 range [-128, 127].
    """
    x = torch.randn(10, 10)
    blocks = (1, 2)
    scale, zero_point = choose_qparams_affine(
        x,
        MappingType.SYMMETRIC,
        blocks,
        torch.int8,
        eps=torch.finfo(torch.float32).eps,
    )

    quantized = quantize_affine(x, blocks, scale, zero_point, torch.int8)
    dequantized = dequantize_affine(
        quantized, blocks, scale, zero_point, torch.int8, output_dtype=torch.float32
    )

    # Positional arguments shared by both reference ops:
    # scale, zero_point, quant_min, quant_max, dtype, group_size.
    ref_args = (scale, zero_point, -128, 127, torch.int8, 2)
    expected_quantized = torch.ops.quantized_decomposed.quantize_per_channel_group(
        x, *ref_args
    )
    expected_dequantized = torch.ops.quantized_decomposed.dequantize_per_channel_group(
        expected_quantized, *ref_args, output_dtype=torch.float32
    )

    self.assertTrue(torch.equal(quantized, expected_quantized))
    self.assertTrue(torch.equal(dequantized, expected_dequantized))
151+
152+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_dequantize_channel_asym(self):
    """Round-trip with (10, 1) blocks and compare with the per-channel ops.

    Quantizes and dequantizes a 10x10 input with asymmetric qparams, then
    checks both results for exact equality against ``quantize_per_channel`` /
    ``dequantize_per_channel`` called with axis 1 and the int8 range
    [-128, 127].
    """
    x = torch.randn(10, 10)
    blocks = (10, 1)
    scale, zero_point = choose_qparams_affine(
        x,
        MappingType.ASYMMETRIC,
        blocks,
        torch.int8,
        eps=torch.finfo(torch.float32).eps,
    )
    float_dtype = torch.float32

    quantized = quantize_affine(x, blocks, scale, zero_point, torch.int8)
    dequantized = dequantize_affine(
        quantized, blocks, scale, zero_point, torch.int8, output_dtype=float_dtype
    )

    # Positional arguments shared by both reference ops:
    # scale, zero_point, axis, quant_min, quant_max, dtype.
    ref_args = (scale, zero_point, 1, -128, 127, torch.int8)
    expected_quantized = torch.ops.quantized_decomposed.quantize_per_channel(
        x, *ref_args
    )
    expected_dequantized = torch.ops.quantized_decomposed.dequantize_per_channel(
        expected_quantized, *ref_args, out_dtype=float_dtype
    )

    self.assertTrue(torch.equal(quantized, expected_quantized))
    self.assertTrue(torch.equal(dequantized, expected_dequantized))
174+
175+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_dequantize_tensor_asym(self):
    """Round-trip with one (10, 10) block and compare with the per-tensor ops.

    Quantizes and dequantizes a 10x10 input with asymmetric qparams, then
    checks both results for exact equality against ``quantize_per_tensor`` /
    ``dequantize_per_tensor`` called with the int8 range [-128, 127].
    """
    x = torch.randn(10, 10)
    mapping_type = MappingType.ASYMMETRIC
    dtype = torch.int8
    block_size = (10, 10)
    output_dtype = torch.float32
    scale, zero_point = choose_qparams_affine(
        x, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps
    )
    quantized = quantize_affine(x, block_size, scale, zero_point, dtype)
    dequantized = dequantize_affine(
        quantized, block_size, scale, zero_point, dtype, output_dtype=output_dtype
    )

    # The per-tensor reference ops take no axis argument, so none is defined here.
    quant_min = -128
    quant_max = 127
    quantized_ref = torch.ops.quantized_decomposed.quantize_per_tensor(
        x, scale, zero_point, quant_min, quant_max, torch.int8
    )
    dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_tensor(
        quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, out_dtype=output_dtype
    )
    self.assertTrue(torch.equal(quantized, quantized_ref))
    self.assertTrue(torch.equal(dequantized, dequantized_ref))
197+
198+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_dequantize_channel_asym_4d(self):
    """Round-trip a 4-D input with (3, 3, 1, 10) blocks against the per-channel ops.

    Quantizes and dequantizes a (3, 3, 10, 10) input with asymmetric qparams,
    then checks both results for exact equality against
    ``quantize_per_channel`` / ``dequantize_per_channel`` called with axis 2
    and the int8 range [-128, 127].
    """
    x = torch.randn(3, 3, 10, 10)
    blocks = (3, 3, 1, 10)
    scale, zero_point = choose_qparams_affine(
        x,
        MappingType.ASYMMETRIC,
        blocks,
        torch.int8,
        eps=torch.finfo(torch.float32).eps,
    )
    quantized = quantize_affine(x, blocks, scale, zero_point, torch.int8)
    dequantized = dequantize_affine(
        quantized, blocks, scale, zero_point, torch.int8, output_dtype=torch.float32
    )

    # Positional arguments shared by both reference ops:
    # scale, zero_point, axis, quant_min, quant_max, dtype.
    ref_args = (scale, zero_point, 2, -128, 127, torch.int8)
    expected_quantized = torch.ops.quantized_decomposed.quantize_per_channel(
        x, *ref_args
    )
    expected_dequantized = torch.ops.quantized_decomposed.dequantize_per_channel(
        expected_quantized, *ref_args, out_dtype=torch.float32
    )

    self.assertTrue(torch.equal(quantized, expected_quantized))
    self.assertTrue(torch.equal(dequantized, expected_dequantized))
219+
220+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
    """Round-trip a 4-D input with (3, 3, 2, 2) blocks and compare with the float input.

    There is no corresponding op among the existing primitives, so this only
    checks that the round trip runs and that the dequantized result is close
    to the original float tensor.
    """
    x = torch.randn(3, 3, 10, 10)
    mapping_type = MappingType.ASYMMETRIC
    dtype = torch.int8
    block_size = (3, 3, 2, 2)
    scale, zero_point = choose_qparams_affine(
        x, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps
    )
    quantized = quantize_affine(x, block_size, scale, zero_point, dtype)
    dequantized = dequantize_affine(
        quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32
    )

    # torch.testing.assert_allclose is deprecated; assert_close with
    # check_dtype=False and equal_nan=True keeps the same comparison semantics.
    torch.testing.assert_close(
        dequantized,
        x,
        rtol=2,
        atol=0.02,
        equal_nan=True,
        check_dtype=False,
    )
231+
232+
49233
if __name__ == "__main__":
50234
unittest.main()

torchao/quantization/autoquant.py

-1
Original file line numberDiff line numberDiff line change
@@ -391,4 +391,3 @@ def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filte
391391
model(*example_input)
392392
change_autoquantizable_to_quantized(model, **kwargs)
393393
return model
394-

0 commit comments

Comments
 (0)