diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 6186714e3b..8547532b78 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -231,6 +231,16 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
         # we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
         torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)
 
+    def test_choose_qparams_tensor_asym_eps(self):
+        # eps is not passed here, so the scale of an all-zero input is expected to be clamped to float32 eps
+        input = torch.zeros(10, 10)
+        mapping_type = MappingType.ASYMMETRIC
+        dtype = torch.int8
+        block_size = (10, 10)
+        scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype)
+        eps = torch.finfo(torch.float32).eps
+        self.assertEqual(scale, eps)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index bd4bcce1aa..6c80504ec4 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -287,8 +287,10 @@ def choose_qparams_affine(
     else:
         raise RuntimeError(f"Unsupported mapping type: {mapping_type}")
 
-    if eps is not None:
-        scale = torch.clamp(scale, min=eps)
+    if eps is None:
+        # torch.finfo takes a dtype (not a tensor); default the clamp floor to the eps of the input's dtype
+        eps = torch.finfo(input.dtype).eps
+    scale = torch.clamp(scale, min=eps)
 
     return scale.to(dtype=scale_dtype), zero_point.to(dtype=zero_point_dtype)
 