Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for int4 weight-only QAT #383

Merged
merged 2 commits into from
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 175 additions & 18 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,41 @@
fake_quantize_per_channel_group,
fake_quantize_per_token,
)
from torchao.quantization.utils import get_group_qparams_symmetric
from torchao.quantization.quant_primitives import (
fake_quantize_affine,
ZeroPointDomain,
)
from torchao.quantization.utils import (
get_group_qparams_symmetric,
get_groupwise_affine_qparams,
groupwise_affine_quantize_tensor,
)
from torchao.utils import TORCH_VERSION_AFTER_2_4


# TODO: put this in a common test utils file
_CUDA_IS_AVAILABLE = torch.cuda.is_available()

class Sub(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float)
self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 32).to(torch.float),)
return (torch.randn(1, 256).to(torch.float),)

def forward(self, x):
return self.linear(x)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float)
self.sub = Sub()
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 64).to(torch.float),)
return (torch.randn(1, 512).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
Expand Down Expand Up @@ -111,23 +121,46 @@ def test_fake_quantize_per_token(self):

def _set_ptq_weight(
self,
ptq_linear: "Int8DynActInt4WeightLinear",
fp32_weight: torch.Tensor,
group_size: int,
ptq_linear: torch.nn.Module,
qat_linear: torch.nn.Module,
):
"""
Set the weight to the quantized version of the given fp32 weights,
for making linear outputs comparable with QAT.
"""
from torchao.quantization.GPTQ import (
Int8DynActInt4WeightLinear,
WeightOnlyInt4Linear,
)
from torchao.quantization.prototype.qat import (
Int8DynActInt4WeightQATLinear,
Int4WeightOnlyQATLinear,
)
n_bit = 4
(qmin, qmax) = self._get_qmin_qmax(n_bit)
(s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
)
ptq_linear.weight = q_weight
ptq_linear.scales = s
ptq_linear.zeros = zp
if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
fp32_weight = qat_linear.weight
group_size = qat_linear.groupsize
(s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
)
ptq_linear.weight = q_weight
ptq_linear.scales = s
ptq_linear.zeros = zp
elif isinstance(ptq_linear, WeightOnlyInt4Linear):
assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
(q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
qat_linear.weight, n_bit, qat_linear.groupsize,
)
q_weight = torch.ops.aten._convert_weight_to_int4pack(
q_weight.to("cuda"), qat_linear.inner_k_tiles,
)
ptq_linear.weight = q_weight
ptq_linear.scales_and_zeros = scales_and_zeros
else:
raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear))

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_qat_8da4w_linear(self):
Expand All @@ -144,7 +177,7 @@ def test_qat_8da4w_linear(self):
)

# Force the weights to be the same
self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)
self._set_ptq_weight(ptq_linear, qat_linear)

# Compare linear values
torch.manual_seed(self.SEED)
Expand Down Expand Up @@ -280,7 +313,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
loss_fn1 = torch.nn.CrossEntropyLoss()
loss_fn2 = torch.nn.CrossEntropyLoss()
example_inputs = nn_model.example_inputs()
target = torch.randn(1, 64).float()
target = torch.randn(1, 512).float()
output1 = nn_model(*example_inputs)
output2 = qat_model(*example_inputs)
torch.testing.assert_close(output1, output2, atol=0, rtol=0)
Expand Down Expand Up @@ -322,6 +355,130 @@ def test_qat_generic_fake_quantize(self):
torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)

def _assert_close_4w(self, val, ref):
# Note: for int4 weight-only quantization, we do not expect exact match
# because torch._weight_int4pack_mm and torch.mm do not match exactly.
# Here we use the same error bar as PyTorch core to determine closeness:
# https://github.com/pytorch/pytorch/blob/6079c5091091d872b8dafbaa4e31a5b6194647ad/test/test_linalg.py#L6079
mean_err = ((val - ref) / ref).mean().abs()
print(mean_err)
self.assertTrue(mean_err < 0.05)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_primitives(self):
n_bit = 4
group_size = 32
inner_k_tiles = 8
scales_precision = torch.bfloat16
device = torch.device("cuda")
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
x = torch.randn(100, 256, dtype=dtype, device=device)
weight = torch.randn(512, 256, dtype=dtype, device=device)

# PTQ
(q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
weight, n_bit, group_size, scales_precision,
)
q_weight = torch.ops.aten._convert_weight_to_int4pack(
q_weight.to(device), inner_k_tiles,
)
ptq_out = torch.ops.aten._weight_int4pack_mm(
x, q_weight, group_size, scales_and_zeros
)

# QAT
block_size = (1, group_size)
quant_min = 0
quant_max = 2 ** n_bit - 1
scales, zero_points = get_groupwise_affine_qparams(
weight, n_bit, group_size, scales_precision,
)
w_fq = fake_quantize_affine(
weight,
block_size,
scales,
zero_points,
torch.int32,
quant_min,
quant_max,
zero_point_domain = ZeroPointDomain.FLOAT,
)
qat_out = torch.nn.functional.linear(x, w_fq)

self._assert_close_4w(qat_out, ptq_out)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_linear(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear

group_size = 128
device = torch.device("cuda")
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
qat_linear = Int4WeightOnlyQATLinear(
256, 688, bias=False, groupsize=group_size, device=device,
)
ptq_linear = WeightOnlyInt4Linear(
256, 688, bias=False, groupsize=group_size, device=device,
)

# Force the weights to be the same
self._set_ptq_weight(ptq_linear, qat_linear)

# Compare linear values
torch.manual_seed(self.SEED)
x = torch.randn(100, 256, dtype=dtype, device=device)
x2 = copy.deepcopy(x)
qat_out = qat_linear(x)
ptq_out = ptq_linear(x2)
self._assert_close_4w(qat_out, ptq_out)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

group_size = 32
inner_k_tiles = 8
device = torch.device("cuda")
dtype = torch.bfloat16
torch.manual_seed(self.SEED)
m = M().to(device).to(dtype)
m2 = copy.deepcopy(m)
qat_quantizer = Int4WeightOnlyQATQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
ptq_quantizer = Int4WeightOnlyQuantizer(
groupsize=group_size, inner_k_tiles=inner_k_tiles,
)
qat_model = qat_quantizer.prepare(m)
ptq_model = ptq_quantizer.quantize(m2)

# Compare model values
torch.manual_seed(self.SEED)
x = [i.to(device).to(dtype) for i in m.example_inputs()]
x2 = copy.deepcopy(x)
qat_out = qat_model(*x)
ptq_out = ptq_model(*x2)
self._assert_close_4w(qat_out, ptq_out)

# Convert QAT model and compare model values
converted_model = qat_quantizer.convert(qat_model)
converted_out = converted_model(*x)
torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
24 changes: 24 additions & 0 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
from torchao.quantization.quant_primitives import (
fake_quantize_affine,
fake_quantize_affine_cachemask,
quantize_affine,
dequantize_affine,
choose_qparams_affine,
Expand Down Expand Up @@ -523,5 +524,28 @@ def test_fake_quantize_affine(self):
fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
torch.testing.assert_close(dequantized, fake_quantized)

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
def test_fake_quantize_affine_cachemask(self):
input = torch.randn(10, 10)

mapping_type = MappingType.SYMMETRIC
block_size = list(input.shape)
for i in range(len(block_size) - 1):
block_size[i] = 1
dtype = torch.int8
eps = 1e-5
quant_min = -127
quant_max = 127
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float)

quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max)
dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max)
(fake_quantized, mask) = fake_quantize_affine_cachemask(
input, block_size, scale, zero_point, dtype, quant_min, quant_max,
)
expected_mask = torch.full(input.shape, True)
torch.testing.assert_close(dequantized, fake_quantized)
torch.testing.assert_close(expected_mask, mask)

if __name__ == "__main__":
unittest.main()
Loading
Loading