Refactor QAT into its own module
Summary: Refactor QAT into its own module so that future QAT features
can live under the same folder without making qat.py longer, and so that
a separate QAT README can be added later.

Test Plan:
python test/quantization/test_qat.py
andrewor14 committed Jul 30, 2024
1 parent ec317fc commit f470f49
Showing 4 changed files with 185 additions and 158 deletions.
26 changes: 13 additions & 13 deletions test/quantization/test_qat.py
@@ -12,11 +12,11 @@

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.quantization.prototype.qat import (
from torchao.quantization.prototype.qat.utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
_GenericFakeQuantize,
fake_quantize_per_channel_group,
fake_quantize_per_token,
)
from torchao.quantization.quant_primitives import (
fake_quantize_affine,
@@ -85,7 +85,7 @@ def test_fake_quantize_per_channel_group(self):
x2 = copy.deepcopy(x)

# fake quant op
out = fake_quantize_per_channel_group(
out = _fake_quantize_per_channel_group(
x, s, zp, qmin, qmax, group_size,
)
out.sum().backward()
@@ -110,7 +110,7 @@ def test_fake_quantize_per_token(self):
(s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)

# fake quant op
out = fake_quantize_per_token(x, s, zp, qmin, qmax)
out = _fake_quantize_per_token(x, s, zp, qmin, qmax)
out.sum().backward()

# compare against PTQ ops
@@ -135,7 +135,7 @@ def _set_ptq_weight(
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
from torchao.quantization.prototype.qat import (
from torchao.quantization.prototype.qat.api import (
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear,
)
@@ -167,7 +167,7 @@ def _set_ptq_weight(

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
from torchao.quantization.prototype.qat.api import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

group_size = 128
@@ -192,7 +192,7 @@ def test_qat_8da4w_linear(self):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.prototype.qat.api import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer

group_size = 16
@@ -226,7 +226,7 @@ def test_qat_8da4w_quantizer(self):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.prototype.qat.api import Int8DynActInt4WeightQATQuantizer

with torch.device("meta"):
m = M()
@@ -241,7 +241,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
"""
from torchao.quantization.prototype.qat import (
from torchao.quantization.prototype.qat.api import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
enable_8da4w_fake_quant,
@@ -294,7 +294,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
"""
from torchao.quantization.prototype.qat import (
from torchao.quantization.prototype.qat.api import (
Int8DynActInt4WeightQATQuantizer,
disable_8da4w_fake_quant,
)
@@ -425,7 +425,7 @@ def test_qat_4w_primitives(self):
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_linear(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear

group_size = 128
@@ -455,7 +455,7 @@ def test_qat_4w_linear(self):
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.prototype.qat.api import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

group_size = 32
17 changes: 17 additions & 0 deletions torchao/quantization/prototype/qat/__init__.py
@@ -0,0 +1,17 @@
from .api import (
disable_4w_fake_quant,
disable_8da4w_fake_quant,
enable_4w_fake_quant,
enable_8da4w_fake_quant,
Int4WeightOnlyQATQuantizer,
Int8DynActInt4WeightQATQuantizer,
)

__all__ = [
"disable_4w_fake_quant",
"disable_8da4w_fake_quant",
"enable_4w_fake_quant",
"enable_8da4w_fake_quant",
"Int4WeightOnlyQATQuantizer",
"Int8DynActInt4WeightQATQuantizer",
]
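
Downstream code only needs the package-level names re-exported above. The sketch below shows the intended two-step QAT flow; it assumes the prepare/convert interface of TwoStepQuantizer (imported in api.py below) and a groupsize constructor argument mirroring the PTQ Int8DynActInt4WeightQuantizer, and the toy model and sizes are made up.

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Hypothetical toy model; in_features chosen to be divisible by the group size.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(256, 128, bias=False)

    def forward(self, x):
        return self.linear(x)

model = TinyModel()

# Assumed two-step flow: prepare() swaps in fake-quantized QAT linears for
# fine-tuning, convert() later swaps in the actual quantized linears.
qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=128)
model = qat_quantizer.prepare(model)
# ... fine-tune `model` with fake quantization in the training loop ...
model = qat_quantizer.convert(model)
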
torchao/quantization/prototype/{qat.py → qat/api.py}
@@ -4,12 +4,10 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List, Optional, Tuple
from typing import Any, Optional

import torch
import torch.nn.functional as F
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
from torch.library import impl

from torchao.quantization.GPTQ import (
_check_linear_int4_k,
@@ -20,14 +18,13 @@
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
from torchao.quantization.quant_primitives import (
fake_quantize_affine_cachemask,
ZeroPointDomain,
)
from torchao.quantization.quant_primitives import ZeroPointDomain
from torchao.quantization.unified import TwoStepQuantizer
from torchao.quantization.utils import (
_get_per_token_block_size,
get_group_qparams_symmetric,
from torchao.quantization.utils import get_group_qparams_symmetric
from .utils import (
_choose_qparams_per_token_asymmetric,
_fake_quantize_per_channel_group,
_fake_quantize_per_token,
)


@@ -163,7 +160,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x, self.scales_precision, self.zero_points_precision,
)
(act_qmin, act_qmax) = self._get_qmin_qmax(8)
x_fq = fake_quantize_per_token(
x_fq = _fake_quantize_per_token(
x, act_scales, act_zp, act_qmin, act_qmax,
)
else:
@@ -177,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# TODO: pass zp dtype to `get_group_qparams_symmetric` instead
weight_zp = weight_zp.to(self.zero_points_precision)
(weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
w_fq = fake_quantize_per_channel_group(
w_fq = _fake_quantize_per_channel_group(
self.weight,
weight_scales,
weight_zp,
@@ -349,7 +346,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
scales, zero_points = get_groupwise_affine_qparams(
self.weight, n_bit, self.groupsize, self.scales_precision,
)
w_fq = fake_quantize_per_channel_group(
w_fq = _fake_quantize_per_channel_group(
self.weight,
scales,
zero_points,
@@ -373,135 +370,3 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
"""
if isinstance(mod, Int4WeightOnlyQATLinear):
mod.disable_fake_quant()


# ========================
# | QUANT PRIMITIVES |
# ========================

class _GenericFakeQuantize(torch.autograd.Function):
"""
Implementation of generic fake quantize with backward STE.
With the appropriate input tensor shape, this can be used to express
grouped per channel fake quantize or per token fake quantize.
"""

@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
block_size: List[int],
zero_point_domain: ZeroPointDomain=ZeroPointDomain.INT,
) -> torch.Tensor:
# Note: for bf16 inputs, casting them to fp32 has the unexpected
# side effect of reducing memory footprint significantly, presumably
# because bf16 * fp32 kernels are not as memory efficient
assert input.dtype == torch.float32
assert scales.dtype == torch.float32
assert zero_points.dtype == torch.int32

(fq, mask) = fake_quantize_affine_cachemask(
input,
block_size,
scales,
zero_points,
torch.int32,
quant_min,
quant_max,
zero_point_domain,
)

ctx.save_for_backward(mask)
return fq

@staticmethod
def backward(ctx, gy):
(mask,) = ctx.saved_tensors
return gy * mask, None, None, None, None, None, None

def fake_quantize_per_channel_group(
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
group_size: int,
zero_point_domain: ZeroPointDomain=ZeroPointDomain.INT,
) -> torch.Tensor:
assert group_size > 1
assert input.shape[-1] % group_size == 0
assert input.dim() == 2
block_size = (1, group_size)
return _GenericFakeQuantize.apply(
input, scales, zero_points, quant_min, quant_max, block_size, zero_point_domain,
)

def fake_quantize_per_token(
input: torch.Tensor,
scales: torch.Tensor,
zero_points: torch.Tensor,
quant_min: int,
quant_max: int,
) -> torch.Tensor:
from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check

_per_token_quant_qparam_dim_check(input, scales, zero_points)
block_size = _get_per_token_block_size(input)
fq_input = input.to(torch.float32)
fq = _GenericFakeQuantize.apply(
fq_input, scales, zero_points, quant_min, quant_max, block_size,
)
return fq.reshape_as(input).to(input.dtype)

# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
# The version in pytorch does not have backward support yet so we add
# it here for now until https://github.com/pytorch/pytorch/pull/123452
# is landed.
def _choose_qparams_per_token_asymmetric(
input: torch.Tensor,
scales_precision: torch.dtype = torch.float32,
zero_points_precision: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Choose quantization parameters for per token quantization. This means for a N dimension Tensor
(M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
every N elements with the same quantization parameter. The dimension for scales/zero_points
will be (M1 * M2 ... * Mn)
Args:
input (torch.Tensor): original float32/float16 Tensor
scales_precision (torch.dtype): precision of returned scales
zero_points_precision (torch.dtype): precision of returned zero points
Returns:
scales and zero_points, both float32 Tensors
"""
# Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
qmin, qmax = -128, 127
min_val = torch.amin(input, dim=-1, keepdim=True)
max_val = torch.amax(input, dim=-1, keepdim=True)
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
eps = torch.finfo(torch.float32).eps # use xnnpack eps?

# scale
scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
scale = scale.clamp(min=eps)

# zero point
descaled_min = min_val_neg / scale
descaled_max = max_val_pos / scale
zero_point_from_min_error = qmin + descaled_min
zero_point_from_max_error = qmax + descaled_max
zero_point = torch.where(
zero_point_from_min_error + zero_point_from_max_error > 0,
qmin - descaled_min,
qmax - descaled_max,
)
zero_point = torch.clamp(zero_point, qmin, qmax).round()

return scale.to(scales_precision), zero_point.to(zero_points_precision)
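
The helpers above (moved into the new qat/utils.py module by this commit, per the updated test imports) compose the same way test_fake_quantize_per_token exercises them. A minimal standalone sketch, with an illustrative tensor shape and the int8 bounds used above (-128, 127):

import torch
from torchao.quantization.prototype.qat.utils import (
    _choose_qparams_per_token_asymmetric,
    _fake_quantize_per_token,
)

# Illustrative input: 16 tokens with 64 features each, fp32 as the asserts above expect.
x = torch.randn(16, 64, requires_grad=True)

# One (scale, zero_point) pair per token, in the dtypes _GenericFakeQuantize requires.
scales, zero_points = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)

# Fake quantize to the int8 range; the backward pass is a straight-through
# estimator that zeroes gradients only where values were clamped.
out = _fake_quantize_per_token(x, scales, zero_points, -128, 127)
out.sum().backward()
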