
Commit 5fe9967

Add support for int4 weight-only QAT (pytorch#383)
* Add cachemask variant for fake_quantize_affine

Summary: In QAT, we often wish to filter out the gradients corresponding to values outside the expected quantization range, for example:

```
q = _quantize_affine_no_dtype_cast(...)
dq = _dequantize_affine_no_dtype_check(...)
mask = torch.logical_and((q >= quant_min), (q <= quant_max))
grad = grad * mask
```

The existing `fake_quantize_affine` returns only the dequantized values, so callers do not have access to this mask. This commit adds a variant of this op that returns both the dequantized values and the mask, similar to `fake_quantize_per_tensor_affine_cachemask` in core.

Test Plan:
python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine_cachemask

* Add support for int4 weight-only QAT

Summary: This commit adds support for int4 weight-only QAT, which simulates the numerics of the existing Int4WeightOnlyQuantizer. The main motivation is to provide an end-to-end path for running QAT and lowering to the efficient int4 tinygemm cuda kernel. To enable this, we add new fake quantization primitives that match the numerics of the tinygemm kernel, which in turn required refactoring the existing quant primitives to skip dtype casting.

Test Plan:
python test/quantization/test_qat.py -k test_qat_4w_linear

Reviewers: jerryzh168, msaroufim

Subscribers: jerryzh168, msaroufim, HDCharles, supriyar
1 parent aaf209d commit 5fe9967
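As a rough illustration of the first change, below is a minimal sketch of how the cachemask variant could be used to filter gradients during QAT. The op name `fake_quantize_affine_cachemask`, its argument order, and the example qparams are assumptions inferred from the test plan above, not shown in this diff:

```python
import torch
# Assumed export location and name, inferred from the test plan above.
from torchao.quantization.quant_primitives import fake_quantize_affine_cachemask

# A (256, 512) weight fake-quantized to int4 with one qparam per group of 32.
weight = torch.randn(256, 512)
block_size = (1, 32)
scales = torch.full((256, 16), 0.05)                  # illustrative qparams only
zero_points = torch.zeros(256, 16, dtype=torch.int32)

# The cachemask variant returns the dequantized values *and* a mask marking
# which elements fell inside [quant_min, quant_max], so callers no longer need
# to recompute the intermediate quantized values just to build the mask.
w_fq, mask = fake_quantize_affine_cachemask(
    weight, block_size, scales, zero_points, torch.int8, quant_min=-8, quant_max=7,
)

# During the backward pass, gradients for clamped values can then be zeroed out.
grad = torch.ones_like(weight)
grad = grad * mask
```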
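And a sketch of the end-to-end int4 weight-only QAT flow this commit enables, following the prepare/convert calls exercised in `test_qat_4w_quantizer` below; the model, optimizer, and training step are placeholders for illustration:

```python
import torch
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer

# Toy model; the int4 tinygemm kernel targets bf16 weights on CUDA.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 256, bias=False),
    torch.nn.Linear(256, 512, bias=False),
).to("cuda").to(torch.bfloat16)

# Prepare: swap in QAT linears that fake-quantize the weights to int4,
# simulating the numerics of the existing Int4WeightOnlyQuantizer.
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
model = quantizer.prepare(model)

# Placeholder fine-tuning step.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
x = torch.randn(16, 512, dtype=torch.bfloat16, device="cuda")
model(x).sum().backward()
optimizer.step()

# Convert: replace the QAT linears with real int4 weights packed for the
# tinygemm kernel (torch.ops.aten._weight_int4pack_mm).
model = quantizer.convert(model)
```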

3 files changed: +477 -41 lines changed

test/quantization/test_qat.py

+175 -18
@@ -18,31 +18,41 @@
     fake_quantize_per_channel_group,
     fake_quantize_per_token,
 )
-from torchao.quantization.utils import get_group_qparams_symmetric
+from torchao.quantization.quant_primitives import (
+    fake_quantize_affine,
+    ZeroPointDomain,
+)
+from torchao.quantization.utils import (
+    get_group_qparams_symmetric,
+    get_groupwise_affine_qparams,
+    groupwise_affine_quantize_tensor,
+)
 from torchao.utils import TORCH_VERSION_AFTER_2_4
 
 
 # TODO: put this in a common test utils file
+_CUDA_IS_AVAILABLE = torch.cuda.is_available()
+
 class Sub(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float)
+        self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float)
 
     def example_inputs(self):
-        return (torch.randn(1, 32).to(torch.float),)
+        return (torch.randn(1, 256).to(torch.float),)
 
     def forward(self, x):
         return self.linear(x)
 
 class M(torch.nn.Module):
     def __init__(self):
         super().__init__()
-        self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
+        self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float)
         self.sub = Sub()
-        self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
+        self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float)
 
     def example_inputs(self):
-        return (torch.randn(1, 64).to(torch.float),)
+        return (torch.randn(1, 512).to(torch.float),)
 
     def forward(self, x):
         x = self.linear1(x)
@@ -111,23 +121,46 @@ def test_fake_quantize_per_token(self):
 
     def _set_ptq_weight(
         self,
-        ptq_linear: "Int8DynActInt4WeightLinear",
-        fp32_weight: torch.Tensor,
-        group_size: int,
+        ptq_linear: torch.nn.Module,
+        qat_linear: torch.nn.Module,
     ):
         """
         Set the weight to the quantized version of the given fp32 weights,
         for making linear outputs comparable with QAT.
         """
+        from torchao.quantization.GPTQ import (
+            Int8DynActInt4WeightLinear,
+            WeightOnlyInt4Linear,
+        )
+        from torchao.quantization.prototype.qat import (
+            Int8DynActInt4WeightQATLinear,
+            Int4WeightOnlyQATLinear,
+        )
         n_bit = 4
         (qmin, qmax) = self._get_qmin_qmax(n_bit)
-        (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
-        q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
-            fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
-        )
-        ptq_linear.weight = q_weight
-        ptq_linear.scales = s
-        ptq_linear.zeros = zp
+        if isinstance(ptq_linear, Int8DynActInt4WeightLinear):
+            assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear)
+            fp32_weight = qat_linear.weight
+            group_size = qat_linear.groupsize
+            (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size)
+            q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
+                fp32_weight, s, zp, qmin, qmax, torch.int8, group_size,
+            )
+            ptq_linear.weight = q_weight
+            ptq_linear.scales = s
+            ptq_linear.zeros = zp
+        elif isinstance(ptq_linear, WeightOnlyInt4Linear):
+            assert isinstance(qat_linear, Int4WeightOnlyQATLinear)
+            (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+                qat_linear.weight, n_bit, qat_linear.groupsize,
+            )
+            q_weight = torch.ops.aten._convert_weight_to_int4pack(
+                q_weight.to("cuda"), qat_linear.inner_k_tiles,
+            )
+            ptq_linear.weight = q_weight
+            ptq_linear.scales_and_zeros = scales_and_zeros
+        else:
+            raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear))
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_linear(self):
@@ -144,7 +177,7 @@ def test_qat_8da4w_linear(self):
         )
 
         # Force the weights to be the same
-        self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size)
+        self._set_ptq_weight(ptq_linear, qat_linear)
 
         # Compare linear values
         torch.manual_seed(self.SEED)
@@ -280,7 +313,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
         loss_fn1 = torch.nn.CrossEntropyLoss()
         loss_fn2 = torch.nn.CrossEntropyLoss()
         example_inputs = nn_model.example_inputs()
-        target = torch.randn(1, 64).float()
+        target = torch.randn(1, 512).float()
         output1 = nn_model(*example_inputs)
         output2 = qat_model(*example_inputs)
         torch.testing.assert_close(output1, output2, atol=0, rtol=0)
@@ -322,6 +355,130 @@ def test_qat_generic_fake_quantize(self):
         torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0)
         torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0)
 
+    def _assert_close_4w(self, val, ref):
+        # Note: for int4 weight-only quantization, we do not expect exact match
+        # because torch._weight_int4pack_mm and torch.mm do not match exactly.
+        # Here we use the same error bar as PyTorch core to determine closeness:
+        # https://github.com/pytorch/pytorch/blob/6079c5091091d872b8dafbaa4e31a5b6194647ad/test/test_linalg.py#L6079
+        mean_err = ((val - ref) / ref).mean().abs()
+        print(mean_err)
+        self.assertTrue(mean_err < 0.05)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_primitives(self):
+        n_bit = 4
+        group_size = 32
+        inner_k_tiles = 8
+        scales_precision = torch.bfloat16
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256, dtype=dtype, device=device)
+        weight = torch.randn(512, 256, dtype=dtype, device=device)
+
+        # PTQ
+        (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
+            weight, n_bit, group_size, scales_precision,
+        )
+        q_weight = torch.ops.aten._convert_weight_to_int4pack(
+            q_weight.to(device), inner_k_tiles,
+        )
+        ptq_out = torch.ops.aten._weight_int4pack_mm(
+            x, q_weight, group_size, scales_and_zeros
+        )
+
+        # QAT
+        block_size = (1, group_size)
+        quant_min = 0
+        quant_max = 2 ** n_bit - 1
+        scales, zero_points = get_groupwise_affine_qparams(
+            weight, n_bit, group_size, scales_precision,
+        )
+        w_fq = fake_quantize_affine(
+            weight,
+            block_size,
+            scales,
+            zero_points,
+            torch.int32,
+            quant_min,
+            quant_max,
+            zero_point_domain = ZeroPointDomain.FLOAT,
+        )
+        qat_out = torch.nn.functional.linear(x, w_fq)
+
+        self._assert_close_4w(qat_out, ptq_out)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_linear(self):
+        from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
+        from torchao.quantization.GPTQ import WeightOnlyInt4Linear
+
+        group_size = 128
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        qat_linear = Int4WeightOnlyQATLinear(
+            256, 688, bias=False, groupsize=group_size, device=device,
+        )
+        ptq_linear = WeightOnlyInt4Linear(
+            256, 688, bias=False, groupsize=group_size, device=device,
+        )
+
+        # Force the weights to be the same
+        self._set_ptq_weight(ptq_linear, qat_linear)
+
+        # Compare linear values
+        torch.manual_seed(self.SEED)
+        x = torch.randn(100, 256, dtype=dtype, device=device)
+        x2 = copy.deepcopy(x)
+        qat_out = qat_linear(x)
+        ptq_out = ptq_linear(x2)
+        self._assert_close_4w(qat_out, ptq_out)
+
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
+    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    def test_qat_4w_quantizer(self):
+        from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
+        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
+
+        group_size = 32
+        inner_k_tiles = 8
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+        torch.manual_seed(self.SEED)
+        m = M().to(device).to(dtype)
+        m2 = copy.deepcopy(m)
+        qat_quantizer = Int4WeightOnlyQATQuantizer(
+            groupsize=group_size, inner_k_tiles=inner_k_tiles,
+        )
+        ptq_quantizer = Int4WeightOnlyQuantizer(
+            groupsize=group_size, inner_k_tiles=inner_k_tiles,
+        )
+        qat_model = qat_quantizer.prepare(m)
+        ptq_model = ptq_quantizer.quantize(m2)
+
+        # Compare model values
+        torch.manual_seed(self.SEED)
+        x = [i.to(device).to(dtype) for i in m.example_inputs()]
+        x2 = copy.deepcopy(x)
+        qat_out = qat_model(*x)
+        ptq_out = ptq_model(*x2)
+        self._assert_close_4w(qat_out, ptq_out)
+
+        # Convert QAT model and compare model values
+        converted_model = qat_quantizer.convert(qat_model)
+        converted_out = converted_model(*x)
+        torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
+
+        # Compare converted state dict
+        ptq_state_dict = ptq_model.state_dict()
+        converted_state_dict = converted_model.state_dict()
+        self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
+        for k in ptq_state_dict.keys():
+            torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
+
 
 if __name__ == "__main__":
     unittest.main()
