deduplicate code for get_group_qparams_symmetric
Summary:
This just removes the duplicated implementation; we can have follow-up PRs to remove the call altogether
after we have replaced all implementations with the new blockwise quant code.

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Apr 25, 2024
1 parent f05c215 commit d181692
Showing 2 changed files with 18 additions and 21 deletions.
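For context, the math being deduplicated is standard per-group symmetric quantization: each group's scale is its largest absolute value divided by half the signed integer range, and the zero point is always 0. A minimal standalone sketch of that math (plain PyTorch, condensed from the implementation removed below; the function name is illustrative):

    import torch

    def group_symmetric_qparams_sketch(w: torch.Tensor, n_bit: int = 4, groupsize: int = 128):
        # Each row of to_quant is one quantization group.
        to_quant = w.reshape(-1, groupsize)
        # Signed symmetric range, e.g. [-8, 7] for n_bit=4.
        max_int = 2 ** (n_bit - 1) - 1
        min_int = -(2 ** (n_bit - 1))
        # Scale maps the largest magnitude in each group onto half the int range.
        max_val_abs = to_quant.abs().amax(dim=1, keepdim=True)
        scales = max_val_abs / (float(max_int - min_int) / 2)
        scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
        zeros = torch.zeros_like(scales)  # symmetric, so zero_point is 0
        return scales.reshape(w.shape[0], -1), zeros.reshape(w.shape[0], -1)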
6 changes: 4 additions & 2 deletions test/quantization/test_quant_primitives.py

@@ -67,9 +67,11 @@ def test_choose_qparams_group_sym(self):
         mapping_type = MappingType.SYMMETRIC
         dtype = torch.int8
         block_size = (1, 2)
-        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
+        eps = torch.finfo(torch.float32).eps
+        precision = torch.float32
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision)

-        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2)
+        scale_ref, zp_ref = get_group_qparams_symmetric(input, n_bit=8, groupsize=2, precision=precision)

         self.assertTrue(torch.equal(scale, scale_ref))
         self.assertTrue(torch.equal(zero_point, zp_ref))
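The updated test checks bit-for-bit equality between the two code paths. A runnable version of the same comparison outside the test harness (the import path and the 10x10 input shape are assumptions; the keyword arguments are taken from the diff above):

    import torch
    from torchao.quantization.quant_primitives import (
        MappingType,
        choose_qparams_affine,
        get_group_qparams_symmetric,
    )

    input = torch.randn(10, 10)  # shape is illustrative
    precision = torch.float32
    eps = torch.finfo(torch.float32).eps

    # New blockwise path: block_size (1, 2) means one group per 2 columns.
    scale, zero_point = choose_qparams_affine(
        input, MappingType.SYMMETRIC, (1, 2), torch.int8,
        eps=eps, scale_dtype=precision, zero_point_dtype=precision,
    )
    # Old groupwise entry point, now a thin wrapper over the call above.
    scale_ref, zp_ref = get_group_qparams_symmetric(
        input, n_bit=8, groupsize=2, precision=precision
    )
    assert torch.equal(scale, scale_ref)
    assert torch.equal(zero_point, zp_ref)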
33 changes: 14 additions & 19 deletions torchao/quantization/quant_primitives.py
@@ -40,6 +40,9 @@
     "groupwise_affine_dequantize_tensor_from_qparams",
     "groupwise_affine_quantize_tensor",
     "groupwise_affine_dequantize_tensor",
+    "choose_qparams_affine",
+    "quantize_affine",
+    "dequantize_affine",
     # TODO: need to clean up above functions
 ] + (_AFTER_TORCH_2_3_ONLY if TORCH_VERSION_AFTER_2_3 else [])

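Since choose_qparams_affine, quantize_affine, and dequantize_affine are now part of __all__, downstream code can import the affine primitives directly; a minimal sketch, assuming the module path of the file being edited:

    # Import path assumed from this repo's layout; the names match the
    # __all__ additions above.
    from torchao.quantization.quant_primitives import (
        choose_qparams_affine,
        quantize_affine,
        dequantize_affine,
    )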
@@ -728,26 +731,18 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float
     assert groupsize > 1
     assert w.shape[-1] % groupsize == 0
     assert w.dim() == 2
+    assert n_bit <= 8, f"unsupported n_bit: {n_bit}"

-    to_quant = w.reshape(-1, groupsize)
-    assert torch.isnan(to_quant).sum() == 0
-
-    max_val = to_quant.amax(dim=1, keepdim=True)
-    min_val = to_quant.amin(dim=1, keepdim=True)
-    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
-    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
-
-    max_val_abs = torch.max(-min_val_neg, max_val_pos)
-    max_int = 2 ** (n_bit - 1) - 1
-    min_int = -(2 ** (n_bit - 1))
-
-    scales = max_val_abs / (float(max_int - min_int) / 2)
-    scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps))
-    # TODO: make sure abs(scales) is not too small?
-    zeros = torch.full_like(scales, 0)
-    return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape(
-        w.shape[0], -1
-    )
+    block_size = (1, groupsize)
+    mapping_type = MappingType.SYMMETRIC
+    eps = torch.finfo(torch.float32).eps
+    ranges = {}
+    ranges[1] = (-1, 0)
+    # generating ranges for bit 2 to 8
+    for i in range(2, 9):
+        ranges[i] = (-(2 ** (i - 1)), 2 ** (i - 1) - 1)
+    quant_min, quant_max = ranges[n_bit]
+    return choose_qparams_affine(w, mapping_type, block_size, target_dtype=torch.int8, quant_min=quant_min, quant_max=quant_max, eps=eps, scale_dtype=precision, zero_point_dtype=precision)


 if TORCH_VERSION_AFTER_2_3:
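The rewritten body only has to translate n_bit into a (quant_min, quant_max) pair before delegating to choose_qparams_affine. A quick standalone check of that range table (pure Python, matching the loop in the new body):

    # Signed symmetric ranges per bit width, as built in the new body.
    ranges = {1: (-1, 0)}
    for i in range(2, 9):
        ranges[i] = (-(2 ** (i - 1)), 2 ** (i - 1) - 1)

    assert ranges[4] == (-8, 7)
    assert ranges[8] == (-128, 127)

Note that the delegation always passes target_dtype=torch.int8 and relies on quant_min/quant_max to narrow the range for sub-8-bit widths.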
