[quant] Add per block quantization primitives
Summary:
We want to use this to replace all q/dq/choose_qparams ops in https://github.com/pytorch-labs/ao/blob/main/torchao/quantization/quant_primitives.py and https://github.com/pytorch/pytorch/blob/main/torch/ao/quantization/fx/_decomposed.py
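The new primitives share one affine-quantization API in which the granularity is set entirely by a block_size tuple. A minimal usage sketch, mirroring the calls exercised by the tests in this commit (per-group symmetric int8 quantization with group size 2):

import torch
from torchao.quantization.quant_primitives import (
    choose_qparams_affine_per_block,
    quantize_affine_per_block,
    dequantize_affine_per_block,
    MappingType,
)

input = torch.randn(10, 10)
block_size = (1, 2)  # one (scale, zero_point) pair per 1x2 block, i.e. group size 2 along dim 1
dtype = torch.int8

# Pick per-block scales/zero_points for the requested mapping
scale, zero_point = choose_qparams_affine_per_block(
    input, MappingType.SYMMETRIC, block_size, dtype, eps=torch.finfo(torch.float32).eps
)
# Quantize to int8, then dequantize back to float32
quantized = quantize_affine_per_block(input, block_size, scale, zero_point, dtype)
dequantized = dequantize_affine_per_block(
    quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32
)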

Test Plan:
python test/quantization/test_quant_primitives.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Apr 23, 2024
1 parent 3124382 commit a02d061
Showing 3 changed files with 371 additions and 3 deletions.
140 changes: 139 additions & 1 deletion test/quantization/test_quant_primitives.py
@@ -8,7 +8,14 @@
# This test takes a long time to run
import unittest
import torch
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.quant_primitives import (
    get_group_qparams_symmetric,
    quantize_affine_per_block,
    dequantize_affine_per_block,
    choose_qparams_affine_per_block,
    MappingType,
)

from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3

class TestQuantPrimitives(unittest.TestCase):
@@ -46,5 +53,136 @@ def test_get_group_qparams_symmetric(self):
        (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
        torch.testing.assert_allclose(scale_obs, scale_ao, rtol=0, atol=0)

    def test_choose_qparams_group_sym(self):
        """Note: groupwise asymmetric quant computes zero_points differently, so it is
        not included here. We may replace it with per-block quant later.
        """
        input = torch.randn(10, 10)
        mapping_type = MappingType.SYMMETRIC
        dtype = torch.int8
        block_size = (1, 2)
        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2)

        self.assertTrue(torch.equal(scale, scale_ref))
        self.assertTrue(torch.equal(zero_point, zp_ref))

    def test_choose_qparams_token_asym(self):
        input = torch.randn(10, 10)
        mapping_type = MappingType.ASYMMETRIC
        dtype = torch.int8
        block_size = (1, 10)
        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype)
        scale_ref = scale_ref.squeeze()
        zp_ref = zp_ref.squeeze()

        torch.testing.assert_allclose(scale, scale_ref, atol=10e-3, rtol=10e-3)
        self.assertTrue(torch.equal(zero_point, zp_ref))

    def test_choose_qparams_tensor_asym(self):
        input = torch.randn(10, 10)
        mapping_type = MappingType.ASYMMETRIC
        dtype = torch.int8
        block_size = (10, 10)
        eps = torch.finfo(torch.float32).eps
        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=eps)

        quant_min = -128
        quant_max = 127
        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams(input, quant_min, quant_max, eps, dtype)
        scale_ref = scale_ref.squeeze()
        zp_ref = zp_ref.squeeze()

        self.assertTrue(torch.equal(scale, scale_ref))
        self.assertTrue(torch.equal(zero_point, zp_ref))

    def test_choose_qparams_tensor_sym(self):
        input = torch.randn(10, 10)
        mapping_type = MappingType.SYMMETRIC
        dtype = torch.int8
        block_size = (10, 10)
        eps = torch.finfo(torch.float32).eps
        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=eps)

        quant_min = -128
        quant_max = 127
        scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_symmetric(input, quant_min, quant_max, eps, dtype)
        scale_ref = scale_ref.squeeze()
        zp_ref = zp_ref.squeeze()

        self.assertTrue(torch.equal(scale, scale_ref))
        self.assertTrue(torch.equal(zero_point, zp_ref))


    @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch version is lower than 2.3")
    def test_quantize_dequantize_group_sym(self):
        input = torch.randn(10, 10)
        mapping_type = MappingType.SYMMETRIC
        dtype = torch.int8
        block_size = (1, 2)
        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)

        quantized = quantize_affine_per_block(input, block_size, scale, zero_point, dtype)
        dequantized = dequantize_affine_per_block(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

        group_size = 2
        quant_min = -128
        quant_max = 127
        quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel_group(
            input, scale, zero_point, quant_min, quant_max, torch.int8, group_size
        )
        dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel_group(
            quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, group_size, output_dtype=torch.float32
        )

        self.assertTrue(torch.equal(quantized, quantized_ref))
        self.assertTrue(torch.equal(dequantized, dequantized_ref))


    def test_quantize_dequantize_channel_asym(self):
        input = torch.randn(10, 10)
        mapping_type = MappingType.ASYMMETRIC
        dtype = torch.int8
        block_size = (10, 1)
        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
        quantized = quantize_affine_per_block(input, block_size, scale, zero_point, dtype)
        dequantized = dequantize_affine_per_block(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

        axis = 1
        quant_min = -128
        quant_max = 127
        quantized_ref = torch.ops.quantized_decomposed.quantize_per_channel(
            input, scale, zero_point, axis, quant_min, quant_max, torch.int8
        )
        dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_channel(
            quantized_ref, scale, zero_point, axis, quant_min, quant_max, torch.int8, out_dtype=torch.float32
        )
        self.assertTrue(torch.equal(quantized, quantized_ref))
        self.assertTrue(torch.equal(dequantized, dequantized_ref))

    def test_quantize_dequantize_tensor_asym(self):
        input = torch.randn(10, 10)
        mapping_type = MappingType.ASYMMETRIC
        dtype = torch.int8
        block_size = (10, 10)
        scale, zero_point = choose_qparams_affine_per_block(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
        quantized = quantize_affine_per_block(input, block_size, scale, zero_point, dtype)
        dequantized = dequantize_affine_per_block(quantized, block_size, scale, zero_point, dtype, output_dtype=torch.float32)

        quant_min = -128
        quant_max = 127
        quantized_ref = torch.ops.quantized_decomposed.quantize_per_tensor(
            input, scale, zero_point, quant_min, quant_max, torch.int8
        )
        dequantized_ref = torch.ops.quantized_decomposed.dequantize_per_tensor(
            quantized_ref, scale, zero_point, quant_min, quant_max, torch.int8, out_dtype=torch.float32
        )
        self.assertTrue(torch.equal(quantized, quantized_ref))
        self.assertTrue(torch.equal(dequantized, dequantized_ref))

if __name__ == "__main__":
    unittest.main()
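Taken together, the tests above show how block_size alone selects the quantization granularity. A summary, using the (10, 10) input from the tests:

# block_size choices exercised in the tests above, for a (10, 10) input:
granularities = {
    "per_group (group_size=2)": (1, 2),    # one qparam pair per 1x2 block
    "per_token (per row)":      (1, 10),   # one qparam pair per row
    "per_channel (axis=1)":     (10, 1),   # one qparam pair per column
    "per_tensor":               (10, 10),  # one qparam pair for the whole tensor
}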
7 changes: 5 additions & 2 deletions torchao/quantization/autoquant.py
@@ -10,7 +10,11 @@
    safe_int_mm,
)
import torch.nn.functional as F
from torch._inductor.utils import do_bench
try:
    from torch._inductor.utils import do_bench
except ImportError:
    # do_bench moved to torch._inductor.runtime.runtime_utils in newer PyTorch versions
    from torch._inductor.runtime.runtime_utils import do_bench

aten = torch.ops.aten

AUTOQUANT_CACHE = {}
@@ -387,4 +391,3 @@ def autoquant(model, example_input, qtensor_class_list=DEFAULT_CLASS_LIST, filte
    model(*example_input)
    change_autoquantizable_to_quantized(model, **kwargs)
    return model

227 changes: 227 additions & 0 deletions torchao/quantization/quant_primitives.py
(diff not loaded)
