Skip to content

Commit

Permalink
quant primitives: always set min val for scale (#201)
Browse files Browse the repository at this point in the history
Summary:
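This avoids division by zero in quantize: `scale` is now always clamped to a minimum of `eps`, which defaults to `torch.finfo(input.dtype).eps` when no `eps` is given.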

Test Plan:
python test/quantization/test_quant_primitives.py -k test_choose_qparams_tensor_asym_eps

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 authored May 3, 2024
1 parent eb03753 commit ec9d9d8
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
9 changes: 9 additions & 0 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,15 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)

def test_choose_qparams_tensor_asym_eps(self):
    """Check that an all-zero input produces a scale equal to float32 eps.

    The scale chosen for a constant-zero tensor must be clamped up to
    machine epsilon rather than left at zero, so that a later division
    by the scale cannot divide by zero.
    """
    zeros = torch.zeros(10, 10)
    # block_size matches the input shape, i.e. one block spans the whole tensor.
    scale, _zero_point = choose_qparams_affine(
        zeros,
        MappingType.ASYMMETRIC,
        (10, 10),
        torch.int8,
    )
    self.assertEqual(scale, torch.finfo(torch.float32).eps)


if __name__ == "__main__":
unittest.main()
5 changes: 3 additions & 2 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,9 @@ def choose_qparams_affine(
else:
raise RuntimeError(f"Unsupported mapping type: {mapping_type}")

if eps is not None:
scale = torch.clamp(scale, min=eps)
if eps is None:
eps = torch.finfo(input.dtype).eps
scale = torch.clamp(scale, min=eps)

return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype)

Expand Down

0 comments on commit ec9d9d8

Please sign in to comment.