Skip to content

Commit

Permalink
Add general fake_quantize_affine op (pytorch#492)
Browse files Browse the repository at this point in the history
Summary: Add a general `fake_quantize_affine` op that simulates
`quantize_affine` + `dequantize_affine` but without casting the
intermediate quantized values to lower bit-widths, intended for
quantization-aware training (QAT).

Test Plan:
python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine
  • Loading branch information
andrewor14 authored Jul 11, 2024
1 parent 2526b17 commit caafb88
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 5 deletions.
20 changes: 20 additions & 0 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import unittest
import torch
from torchao.quantization.quant_primitives import (
fake_quantize_affine,
quantize_affine,
dequantize_affine,
choose_qparams_affine,
Expand Down Expand Up @@ -503,5 +504,24 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):

self.assertTrue(torch.equal(w_bf16, w_bf16_ref))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantize_affine(self):
input = torch.randn(10, 10)

mapping_type = MappingType.SYMMETRIC
block_size = list(input.shape)
for i in range(len(block_size) - 1):
block_size[i] = 1
dtype = torch.int8
eps = 1e-5
quant_min = -127
quant_max = 127
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float)

quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max)
fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
torch.testing.assert_close(dequantized, fake_quantized)

if __name__ == "__main__":
unittest.main()
110 changes: 105 additions & 5 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"choose_qparams_affine",
"quantize_affine",
"dequantize_affine",
"fake_quantize_affine",
]

class MappingType(Enum):
Expand Down Expand Up @@ -203,14 +204,34 @@ def _quantize_affine(
output_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: str = "INT",
zero_point_domain: str = ZeroPointDomain.INT.name,
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library
"""
quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
return _quantize_affine_no_dtype_cast(
input,
block_size,
scale,
zero_point,
quant_min,
quant_max,
zero_point_domain,
).to(output_dtype)


def _quantize_affine_no_dtype_cast(
input: torch.Tensor,
block_size: List[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
quant_min: int,
quant_max: int,
zero_point_domain: str = ZeroPointDomain.INT.name,
) -> torch.Tensor:
# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
original_shape = input.shape
input = input.view(shape_for_reduction)
Expand All @@ -224,7 +245,7 @@ def _quantize_affine(
if zero_point_domain == ZeroPointDomain.INT.name:
quant = torch.clamp(
torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
).to(output_dtype)
)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT.name
mid_point = (quant_max + quant_min + 1) / 2
Expand All @@ -233,11 +254,12 @@ def _quantize_affine(
torch.clamp(
torch.round((input - min_val) / scale),
quant_min, quant_max)
).to(output_dtype)
)
quant = quant.view(original_shape)

return quant


def dequantize_affine(
input: torch.Tensor,
block_size: Tuple[int, ...],
Expand Down Expand Up @@ -283,6 +305,7 @@ def dequantize_affine(
output_dtype=output_dtype,
)


@register_custom_op
def _dequantize_affine(
input: torch.Tensor,
Expand All @@ -292,7 +315,7 @@ def _dequantize_affine(
input_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: str = "INT",
zero_point_domain: str = ZeroPointDomain.INT.name,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library
Expand All @@ -303,7 +326,28 @@ def _dequantize_affine(
assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
return _dequantize_affine_no_dtype_check(
input,
block_size,
scale,
zero_point,
quant_min,
quant_max,
zero_point_domain,
output_dtype,
)


def _dequantize_affine_no_dtype_check(
input: torch.Tensor,
block_size: List[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
quant_min: int,
quant_max: int,
zero_point_domain: str = ZeroPointDomain.INT.name,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
original_shape = input.shape
input = input.view(shape_for_reduction)
Expand Down Expand Up @@ -335,6 +379,62 @@ def _dequantize_affine(

return dequant.view(original_shape).to(output_dtype)


def fake_quantize_affine(
input: torch.Tensor,
block_size: Tuple[int, ...],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
quant_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
"""
General fake quantize op for quantization-aware training (QAT).
This is equivalent to calling `quantize_affine` + `dequantize_affine`
but without the dtype casts.
Args:
input (torch.Tensor): original float32, float16 or bfloat16 Tensor
block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
scale (float): quantization parameter for affine quantization
zero_point (int): quantization parameter for affine quantization
quant_dtype (torch.dtype): desired quantized dtype for determining and validating quant_min and quant_max values.
quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
"""
input_dtype = input.dtype
quant_min, quant_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max)
q = _quantize_affine_no_dtype_cast(
input,
block_size,
scale,
zero_point,
quant_min,
quant_max,
zero_point_domain.name,
)
dq = _dequantize_affine_no_dtype_check(
q,
block_size,
scale,
zero_point,
quant_min,
quant_max,
zero_point_domain.name,
output_dtype=input_dtype,
)
return dq


def choose_qparams_affine(
input: torch.Tensor,
mapping_type: MappingType,
Expand Down

0 comments on commit caafb88

Please sign in to comment.