
Commit

Merge branch 'main' into dyn_quant
jerryzh168 authored May 16, 2024
2 parents b43bce7 + cae3d82 commit 166353f
Showing 5 changed files with 82 additions and 37 deletions.
2 changes: 1 addition & 1 deletion test/integration/test_integration.py
@@ -1124,7 +1124,7 @@ def test_weight_only_quant_force_mixed_mm(self, device, dtype):
m_c = torch.compile(m, mode="max-autotune")
y_wo, (code,) = run_and_get_code(m_c, x)
sqnr = compute_error(y_ref, y_wo)
self.assertGreater(sqnr, 43.0)
self.assertGreaterEqual(sqnr, 42.75)
if device == "cuda":
self.assertTrue("mixed_mm" in code)
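For context, the asserted SQNR value comes from compute_error, which measures in decibels how closely the weight-only-quantized output tracks the reference. A minimal sketch of such an SQNR-style metric is below; sqnr_db is an illustrative stand-in, not necessarily the exact torchao implementation.

import torch

def sqnr_db(reference: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: 20 * log10(|signal| / |noise|).
    signal = torch.linalg.norm(reference)
    noise = torch.linalg.norm(reference - actual)
    return 20 * torch.log10(signal / noise)

# Higher is better: small quantization noise relative to the signal gives a large dB value.
ref = torch.randn(64, 64)
approx = ref + 1e-3 * torch.randn(64, 64)
print(sqnr_db(ref, approx))  # roughly 60 dB for this noise level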

47 changes: 35 additions & 12 deletions test/quantization/test_qat.py
@@ -14,11 +14,12 @@
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torchao.quantization.prototype.qat import (
_choose_qparams_per_token_asymmetric,
_GenericFakeQuantize,
fake_quantize_per_channel_group,
fake_quantize_per_token,
)
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4


# TODO: put this in a common test utils file
@@ -58,7 +59,7 @@ def _get_qmin_qmax(self, n_bit: int):
qmax = 2 ** (n_bit - 1) - 1
return (qmin, qmax)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantize_per_channel_group(self):
n_bit = 4
(qmin, qmax) = self._get_qmin_qmax(n_bit)
@@ -67,6 +68,7 @@ def test_fake_quantize_per_channel_group(self):
torch.manual_seed(self.SEED)
x = torch.randn(100, 256).requires_grad_()
(s, zp) = get_group_qparams_symmetric(x, n_bit, group_size)
zp = zp.to(torch.int32)
x2 = copy.deepcopy(x)

# fake quant op
@@ -84,18 +86,15 @@ def test_fake_quantize_per_channel_group(self):
)
torch.testing.assert_close(out, out_ptq, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantize_per_token(self):
(qmin, qmax) = self._get_qmin_qmax(8)

torch.manual_seed(self.SEED)
x = torch.randn(100, 256).requires_grad_()
x2 = copy.deepcopy(x)
# TODO: use torch.ops.aten.quantized_decomposed version instead
(s, zp) = _choose_qparams_per_token_asymmetric(
x,
torch.int8, # not used
)
(s, zp) = _choose_qparams_per_token_asymmetric(x, torch.float32, torch.int32)

# fake quant op
out = fake_quantize_per_token(x, s, zp, qmin, qmax)
@@ -130,7 +129,7 @@ def _set_ptq_weight(
ptq_linear.scales = s
ptq_linear.zeros = zp

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATLinear
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
@@ -155,7 +154,7 @@ def test_qat_8da4w_linear(self):
ptq_out = ptq_linear(x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightQuantizer
@@ -189,7 +188,7 @@ def test_qat_8da4w_quantizer(self):
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_meta_weights(self):
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

@@ -201,7 +200,7 @@ def test_qat_8da4w_quantizer_meta_weights(self):
qat_model = qat_quantizer.prepare(m)
self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
@@ -254,7 +253,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
qat_out2 = qat_model2(*x2)
torch.testing.assert_close(qat_out, qat_out2, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
"""
Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
@@ -299,6 +298,30 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
torch.testing.assert_close(nn_model.linear2.weight, qat_model.linear2.weight, atol=0, rtol=0)
torch.testing.assert_close(nn_model.sub.linear.weight, qat_model.sub.linear.weight, atol=0, rtol=0)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_generic_fake_quantize(self):
"""
Test that the generic fake quantize used in 8da4w QAT matches
the numerics of existing fake quantize ops in Pytorch in both
the forward and the backward passes.
"""
(qmin, qmax) = self._get_qmin_qmax(4)
py_input = torch.randn(16, 64).float().requires_grad_()
py_s = torch.randn(16).float()
py_zp = torch.randint(qmax, size=(16,), dtype=torch.int32)
py_out = torch.fake_quantize_per_channel_affine(py_input, py_s, py_zp, 0, qmin, qmax)
py_out.sum().backward()

ao_input = copy.deepcopy(py_input)
ao_input.grad.data.zero_()
ao_s = copy.deepcopy(py_s).reshape(-1, 1)
ao_zp = copy.deepcopy(py_zp).reshape(-1, 1)
ao_out = _GenericFakeQuantize.apply(ao_input, ao_s, ao_zp, qmin, qmax)
ao_out.sum().backward()

torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
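Taken together, the quantizer tests above exercise a two-step prepare/convert flow. A rough sketch of that flow is shown here; the toy model, group size, and training step are illustrative, and per this diff the quantizer classes now require PyTorch 2.4+.

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Toy two-layer model; each in_features must be divisible by the chosen group size.
model = torch.nn.Sequential(torch.nn.Linear(256, 128), torch.nn.Linear(128, 64))

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
model = quantizer.prepare(model)   # swaps nn.Linear for fake-quantized QAT linears

# Fine-tune as usual: fake quant runs in forward, gradients flow through the clamp mask.
out = model(torch.randn(8, 256))
out.sum().backward()

model = quantizer.convert(model)   # swaps QAT linears for quantized Int8DynActInt4WeightLinear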
2 changes: 1 addition & 1 deletion test/quantization/test_quant_primitives.py
@@ -156,7 +156,7 @@ def test_quantize_activation_per_token_abs_max_zero_input(self):
quantized_ref, scale_ref = quantize_activation_per_token_absmax(input)


@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower")
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.4 or lower")
def test_quantize_dequantize_group_sym(self):
input = torch.randn(10, 10)
mapping_type = MappingType.SYMMETRIC
64 changes: 43 additions & 21 deletions torchao/quantization/prototype/qat.py
@@ -4,18 +4,18 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional, Tuple
from typing import Any, Tuple

import torch
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib
from torch.library import impl

from torchao.quantization.utils import TORCH_VERSION_AFTER_2_3
from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
from torchao.quantization.quant_primitives import get_group_qparams_symmetric
from torchao.quantization.unified import TwoStepQuantizer


if TORCH_VERSION_AFTER_2_3:
if TORCH_VERSION_AFTER_2_4:
from torchao.quantization.GPTQ import (
_replace_linear_8da4w,
Int8DynActInt4WeightLinear,
@@ -54,7 +54,7 @@ def prepare(
self.precision,
self.scales_precision,
Int8DynActInt4WeightQATLinear,
copy_weights = True,
copy_weights=True,
)
return model

@@ -95,7 +95,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):
quantized_linear.zeros = zp
else:
_convert_qat_linear_8da4w(child)

class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
"""
This module implements a linear layer with int8 dynamic per token fake
@@ -131,6 +131,8 @@ def __init__(
self.groupsize = groupsize
self.precision = precision
self.scales_precision = scales_precision
# TODO: make this configurable?
self.zero_points_precision = torch.int32
self._fake_quant_enabled = True

def enable_fake_quant(self, enabled: bool = True):
@@ -142,8 +144,8 @@ def disable_fake_quant(self):
def forward(self, x: torch.Tensor) -> torch.Tensor:
# activations: int8 dynamic asymmetric quant
if self._fake_quant_enabled:
(act_scales, act_zp) =_choose_qparams_per_token_asymmetric(
x, torch.int8, # dtype not used
(act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
x, self.scales_precision, self.zero_points_precision,
)
(act_qmin, act_qmax) = self._get_qmin_qmax(8)
x_fq = fake_quantize_per_token(
@@ -157,6 +159,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
(weight_scales, weight_zp) = get_group_qparams_symmetric(
self.weight, 4, self.groupsize, self.scales_precision,
)
# TODO: pass zp dtype to `get_group_qparams_symmetric` instead
weight_zp = weight_zp.to(self.zero_points_precision)
(weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
w_fq = fake_quantize_per_channel_group(
self.weight,
@@ -190,6 +194,20 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module):
if isinstance(mod, Int8DynActInt4WeightQATLinear):
mod.disable_fake_quant()

else: # not TORCH_VERSION_AFTER_2_4

class Int8DynActInt4WeightQATQuantizer:
def __init__(*args, **kwargs):
raise ValueError(
"Int8DynActInt4WeightQATQuantizer is only supported after PyTorch 2.4+"
)

class Int8DynActInt4WeightQATLinear:
def __init__(*args, **kwargs):
raise ValueError(
"Int8DynActInt4WeightQATLinear is only supported after PyTorch 2.4+"
)


# ========================
# | QUANT PRIMITIVES |
@@ -205,13 +223,14 @@ class _GenericFakeQuantize(torch.autograd.Function):

@staticmethod
def forward(ctx, input, scales, zero_points, quant_min, quant_max):
# Note: this diverges from `torch.fake_quantize_per_channel_affine`,
# which rounds first before adding the zero points. However, this
# is what `quantize_per_channel_group` and `quantize_per_token`
# do and here we try to match that behavior as closely as possible.
q = input.div(scales).add(zero_points).round()
# Note: for bf16 inputs, casting them to fp32 has the unexpected
# side effect of reducing memory footprint significantly, presumably
# because bf16 * fp32 kernels are not as memory efficient
assert input.dtype == torch.float32
assert scales.dtype == torch.float32
assert zero_points.dtype == torch.int32
q = input.mul(1.0 / scales).round().add(zero_points)
dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
# TODO: do we need this mask?
mask = torch.logical_and((q >= quant_min), (q <= quant_max))
ctx.save_for_backward(mask)
return dq
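The backward of _GenericFakeQuantize is unchanged by this commit and not shown here; given the saved mask, it presumably behaves as a straight-through estimator along the lines of this simplified sketch (an illustrative stand-in, not the exact implementation).

import torch

class _FakeQuantizeSketch(torch.autograd.Function):
    # Simplified stand-in for _GenericFakeQuantize, illustrating the mask-based backward.

    @staticmethod
    def forward(ctx, input, scales, zero_points, quant_min, quant_max):
        q = input.mul(1.0 / scales).round().add(zero_points)
        dq = q.clamp(quant_min, quant_max).sub(zero_points).mul(scales)
        # Only values that landed inside [quant_min, quant_max] pass gradients through.
        mask = torch.logical_and(q >= quant_min, q <= quant_max)
        ctx.save_for_backward(mask)
        return dq

    @staticmethod
    def backward(ctx, gy):
        (mask,) = ctx.saved_tensors
        # Straight-through estimator: identity gradient inside the clamp range, zero outside.
        return gy * mask, None, None, None, None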
@@ -239,14 +258,13 @@ def fake_quantize_per_channel_group(
assert group_size > 1
assert input.shape[-1] % group_size == 0
assert input.dim() == 2
assert torch.isnan(input).sum() == 0
grouped_input = input.reshape(-1, group_size)
grouped_input = input.reshape(-1, group_size).to(torch.float32)
scales = scales.reshape(-1, 1)
zero_points = zero_points.reshape(-1, 1)
fq = _GenericFakeQuantize.apply(
grouped_input, scales, zero_points, quant_min, quant_max,
)
return fq.reshape_as(input)
return fq.reshape_as(input).to(input.dtype)

# TODO: move this to core
quantized_decomposed_lib.define(
@@ -266,17 +284,20 @@ def fake_quantize_per_token(
from torch.ao.quantization.fx._decomposed import _per_token_quant_qparam_dim_check

_per_token_quant_qparam_dim_check(input, scales, zero_points)
return _GenericFakeQuantize.apply(
input, scales, zero_points, quant_min, quant_max,
fq_input = input.to(torch.float32)
fq = _GenericFakeQuantize.apply(
fq_input, scales, zero_points, quant_min, quant_max,
)
return fq.reshape_as(input).to(input.dtype)

# TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
# The version in pytorch does not have backward support yet so we add
# it here for now until https://github.com/pytorch/pytorch/pull/123452
# is landed.
def _choose_qparams_per_token_asymmetric(
input: torch.Tensor,
dtype: torch.dtype,
scales_precision: torch.dtype = torch.float32,
zero_points_precision: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Choose quantization parameters for per token quantization. This means for a N dimension Tensor
(M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
@@ -285,7 +306,8 @@ def _choose_qparams_per_token_asymmetric(
Args:
input (torch.Tensor): original float32/float16 Tensor
dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
scales_precision (torch.dtype): precision of returned scales
zero_points_precision (torch.dtype): precision of returned zero points
Returns:
scales and zero_points, both float32 Tensors
@@ -314,4 +336,4 @@ def _choose_qparams_per_token_asymmetric(
)
zero_point = torch.clamp(zero_point, qmin, qmax).round()

return scale.to(torch.float32), zero_point.to(torch.float32)
return scale.to(scales_precision), zero_point.to(zero_points_precision)
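To make the scale/zero point math above concrete, here is a small illustrative calculation for one token row with int8 bounds; it follows the standard affine min/max formulas and glosses over the eps and edge-case handling in the real function.

import torch

def per_token_qparams_sketch(row: torch.Tensor, qmin: int = -128, qmax: int = 127):
    # Affine (asymmetric) quantization: map [min(row, 0), max(row, 0)] onto [qmin, qmax].
    min_val = torch.minimum(row.min(), torch.zeros(()))
    max_val = torch.maximum(row.max(), torch.zeros(()))
    scale = (max_val - min_val) / float(qmax - qmin)
    zero_point = torch.clamp(qmin - torch.round(min_val / scale), qmin, qmax)
    return scale, zero_point

row = torch.tensor([-1.0, 0.5, 2.0])
print(per_token_qparams_sketch(row))  # scale ~= 3/255 ~= 0.0118, zero_point = -43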
4 changes: 2 additions & 2 deletions torchao/quantization/quant_primitives.py
@@ -201,7 +201,7 @@ def quantize_affine(

if zero_point_domain == ZeroPointDomain.INT:
quant = torch.clamp(
torch.round(input / scale) + zero_point, quant_min, quant_max
torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
).to(output_dtype)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT
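This is the same division-to-reciprocal rewrite as in _GenericFakeQuantize above, presumably so the PTQ and QAT fake-quant paths round through exactly the same floating point operations. x / s and x * (1.0 / s) are not guaranteed to be bit-identical, and a one-ulp difference can land on the other side of a rounding boundary; a quick, purely illustrative check:

import torch

x = torch.randn(1 << 20)
s = torch.rand(()) + 0.1
a = torch.round(x / s)
b = torch.round(x * (1.0 / s))
# Usually zero mismatches, but the two forms are not guaranteed to agree element-for-element.
print((a != b).sum().item())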
@@ -764,7 +764,7 @@ def groupwise_affine_dequantize_tensor(
)


# TODO: replace this with torch.ao.quantization.PerChannelMinMaxObserver
# TODO: separate scale and zero point precision
def get_group_qparams_symmetric(w, n_bit=4, groupsize=128, precision=torch.float32):
# needed for GPTQ with padding
if groupsize > w.shape[-1]:
