
Fix wrong scale eps applied
alexsamardzic committed Feb 28, 2025
1 parent 79e3366 commit 9a43a80
Showing 2 changed files with 75 additions and 1 deletion.
60 changes: 60 additions & 0 deletions test/quantization/test_quant_primitives.py
@@ -957,6 +957,66 @@ def test_float8_quant_primitives(self, hp_dtype, float8_dtype):
        torch.testing.assert_close(expected_quantized, quantized)
        torch.testing.assert_close(expected_dequantized, dequantized)

    @parameterized.expand(
        [
            torch.float64,
            torch.float32,
            torch.bfloat16,
            torch.float16,
        ]
    )
    def test_choose_qparams_affine_for_inf_scale_reciprocal(self, hp_dtype):
        # Fixed by #1770: the test will fail, for all the dtype
        # variants, before that fix, and will pass afterwards.
        #
        # The scale value must be forcefully clamped, within the
        # _choose_qparams_affine() function (which
        # choose_qparams_affine() and others call into), to a number
        # large enough that its reciprocal does not become Inf.
        # Otherwise, during quantization, multiplying by the scale
        # reciprocal would quantize all values to Inf, except for
        # zero, which would produce NaN (0 * Inf) as its quantized
        # value.
        #
        # The minimal normalized value for a given floating point
        # data type is torch.finfo(hp_dtype).tiny - let's call this
        # value "tiny". It can be verified that, for each of
        # torch.float64, torch.float32, torch.float16 and
        # torch.bfloat16, the denormalized number equal to tiny/4
        # produces Inf as its reciprocal.
        #
        # Thus, to reproduce the problem, one creates a tensor whose
        # absolute maximum, after being divided by the range of the
        # quantized data type (which is 57344 for torch.float8_e5m2),
        # produces a scale smaller than tiny/4. The eps parameter
        # must also be set to a value no greater than tiny/4, as the
        # scale is clamped from below to eps. With such inputs,
        # choose_qparams_affine() produces a scale whose reciprocal
        # is Inf.
        #
        # Note that this may seem like a contrived reproduction.
        # However, there is existing code that passes
        # torch.finfo(torch.float32).eps as the eps value, regardless
        # of scale_dtype. The float16 type has a rather small range,
        # so the smallest scale with a finite float16 reciprocal is
        # well above torch.finfo(torch.float32).eps; for such an eps
        # value, the code below would produce a scale with an Inf
        # reciprocal even for a float16 tensor whose maximum value
        # is 0.5.
        float8_dtype = torch.float8_e5m2
        tiny = torch.finfo(hp_dtype).tiny
        x = torch.tensor([[0, 100 * tiny]], dtype=hp_dtype)
        scale, _ = choose_qparams_affine(
            input=x,
            mapping_type=MappingType.SYMMETRIC,
            block_size=[1, 2],
            target_dtype=float8_dtype,
            eps=tiny / 4,
            scale_dtype=hp_dtype,
            preserve_zero=True,
            zero_point_domain=ZeroPointDomain.NONE,
        )
        scale_reciprocal = scale.reciprocal()
        assert not torch.any(torch.isinf(scale_reciprocal)).item()


if __name__ == "__main__":
    unittest.main()
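
For reference, the tiny/4 behavior described in the test comment can be checked directly. A minimal standalone sketch (assuming nothing beyond torch and the four dtypes the test parameterizes over):

import torch

# For each dtype, tiny is the smallest normalized value. The
# denormalized value tiny/4 is still representable, but its
# reciprocal, 4/tiny, exceeds the dtype's largest finite value,
# so it overflows to Inf.
for dtype in (torch.float64, torch.float32, torch.bfloat16, torch.float16):
    tiny = torch.finfo(dtype).tiny
    x = torch.tensor(tiny / 4, dtype=dtype)
    print(dtype, x.reciprocal())  # prints inf for every dtype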
16 changes: 15 additions & 1 deletion torchao/quantization/quant_primitives.py
@@ -856,6 +856,7 @@ def _choose_qparams_affine(
    3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
    and `zero_point_domain`
    """

    quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
    assert mapping_type in [
        MappingType.SYMMETRIC.name,
@@ -907,6 +908,13 @@ def _choose_qparams_affine(
        min_val_neg = min_val
        max_val_pos = max_val

    # Prevent the reciprocal of the scale, calculated below, from
    # becoming Inf.
    if torch.is_floating_point(max_val):
        eps = max(eps, torch.finfo(max_val.dtype).tiny)
    else:
        # In this case, the scale will be calculated below as torch.float32.
        eps = max(eps, torch.finfo(torch.float32).tiny)

    if (
        mapping_type == MappingType.SYMMETRIC.name
        or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
@@ -966,7 +974,13 @@ def _choose_qparams_affine(

    if zero_point is not None:
        zero_point = zero_point.to(dtype=zero_point_dtype)
    return scale.to(dtype=scale_dtype), zero_point
    scale = scale.to(dtype=scale_dtype)
    if torch.is_floating_point(scale):
        # Again, prevent the scale reciprocal from becoming Inf.
        scale = scale.clamp(
            min=torch.finfo(scale_dtype).tiny, max=torch.finfo(scale_dtype).max
        )
    return scale, zero_point


def choose_qparams_and_quantize_affine_qqq(
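
To see the failure mode and the fix end to end, the float16 scenario from the test comment can be exercised directly. This is a sketch, assuming the same public entry points the new test uses (choose_qparams_affine, MappingType and ZeroPointDomain from torchao.quantization.quant_primitives); before this commit the final assertion fails, and with it the scale reciprocal stays finite:

import torch

from torchao.quantization.quant_primitives import (
    MappingType,
    ZeroPointDomain,
    choose_qparams_affine,
)

# A float16 tensor whose maximum is 0.5. With eps taken from float32,
# the unclamped scale would be 0.5 / 57344 ~= 8.7e-6, whose float16
# reciprocal (~114688) overflows past the float16 maximum of 65504.
x = torch.tensor([[0.0, 0.5]], dtype=torch.float16)
scale, _ = choose_qparams_affine(
    input=x,
    mapping_type=MappingType.SYMMETRIC,
    block_size=[1, 2],
    target_dtype=torch.float8_e5m2,
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float16,
    preserve_zero=True,
    zero_point_domain=ZeroPointDomain.NONE,
)
# With the fix, eps is raised to torch.finfo(torch.float16).tiny, and
# the resulting scale's reciprocal is finite.
assert not torch.isinf(scale.reciprocal()).any()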
