Skip to content

Commit

Permalink
deduplicate code for get_group_qparams_symmetric
Browse files Browse the repository at this point in the history
Summary:
This just removes the implementation; we can have follow-up PRs to remove the calls altogether
after we have replaced all implementations with the new blockwise quant code

Test Plan:
CI

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed Apr 24, 2024
1 parent f05c215 commit 28bd36c
Showing 1 changed file with 32 additions and 19 deletions.
51 changes: 32 additions & 19 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,25 +729,38 @@ def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float
assert w.shape[-1] % groupsize == 0
assert w.dim() == 2

to_quant = w.reshape(-1, groupsize)
assert torch.isnan(to_quant).sum() == 0

max_val = to_quant.amax(dim=1, keepdim=True)
min_val = to_quant.amin(dim=1, keepdim=True)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

max_val_abs = torch.max(-min_val_neg, max_val_pos)
max_int = 2 ** (n_bit - 1) - 1
min_int = -(2 ** (n_bit - 1))

scales = max_val_abs / (float(max_int - min_int) / 2)
scales = torch.max(scales, torch.full_like(scales, torch.finfo(torch.float32).eps))
# TODO: make sure abs(scales) is not too small?
zeros = torch.full_like(scales, 0)
return scales.to(precision).reshape(w.shape[0], -1), zeros.to(precision).reshape(
w.shape[0], -1
)
block_size = (w.shape[0], groupsize)
mapping_type = MappingType.SYMMETRIC
eps = torch.finfo(torch.float32).eps
if TORCH_VERSION_AFTER_2_3:
bit_to_dtype = {
1: torch.uint1,
2: torch.uint2,
3: torch.uint3,
4: torch.uint4,
5: torch.uint5,
6: torch.uint6,
7: torch.uint7,
8: torch.uint8,
}
assert n_bit in ranges, f"unsupported bit: {n_bit}"
target_dtype = bit_to_dtype[n_bit]
return choose_qparams_affine(w, mapping_type, block_size, target_dtype=target_dtype, eps=eps, scale_dtype=precision, zero_point_dtype=precision)
else:
ranges = {
1: (0, 2**1-1),
2: (0, 2**2-1),
3: (0, 2**3-1),
4: (0, 2**4-1),
5: (0, 2**5-1),
6: (0, 2**6-1),
7: (0, 2**7-1),
8: (0, 2**8-1),
}
assert n_bit in ranges, f"unsupported bit: {n_bit}"
quant_min, quant_max = ranges[n_bit]
# using uint8 to simulate uint4
return choose_qparams_affine(w, mapping_type, block_size, target_dtype=torch.uint8, quant_min=0, quant_max=15, eps=eps, scale_dtype=precision, zero_point_dtype=precision)


if TORCH_VERSION_AFTER_2_3:
Expand Down

0 comments on commit 28bd36c

Please sign in to comment.