Add convert path for 8da4w QAT
Summary: This commit implements the convert path for 8da4w QAT, which swaps
the QAT linear with the quantized linear and quantizes the weights the same
way as the PTQ flow. The result is a model identical to the one produced by
the PTQ flow.
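
For context, here is a minimal usage sketch of the two-step flow this convert path
fits into. The toy model, groupsize, and training step are illustrative assumptions
rather than part of this commit, and it assumes Int8DynActInt4WeightQATQuantizer
follows the standard TwoStepQuantizer prepare/convert API and accepts a groupsize
argument:

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Hypothetical toy model: any module containing torch.nn.Linear with bias=False
# and in_features divisible by groupsize works.
model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=False))

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
model = quantizer.prepare(model)   # swaps in Int8DynActInt4WeightQATLinear (fake quant)
# ... fine-tune `model` here, training through the fake-quantized forward pass ...
model = quantizer.convert(model)   # swaps in Int8DynActInt4WeightLinear (this commit)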

Test Plan:
python test/quantization/test_qat.py -k test_qat_8da4w_quantizer

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar
andrewor14 committed Apr 24, 2024
1 parent 92e7f12 commit 36fad62
Showing 2 changed files with 49 additions and 4 deletions.
12 changes: 12 additions & 0 deletions test/quantization/test_qat.py
@@ -188,6 +188,18 @@ def test_qat_8da4w_quantizer(self):
ptq_out = ptq_model(*x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

# Convert QAT model and compare model values
converted_model = qat_quantizer.convert(qat_model)
converted_out = converted_model(*x)
torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
    torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)


if __name__ == "__main__":
    unittest.main()
41 changes: 37 additions & 4 deletions torchao/quantization/prototype/qat.py
@@ -16,7 +16,10 @@


if TORCH_VERSION_AFTER_2_3:
-from torchao.quantization.GPTQ import _replace_linear_8da4w
+from torchao.quantization.GPTQ import (
+    _replace_linear_8da4w,
+    Int8DynActInt4WeightLinear,
+)

class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
"""
@@ -60,10 +63,38 @@ def convert(
*args: Any,
**kwargs: Any
) -> torch.nn.Module:
-# TODO: replace Int8DynActInt4WeightQATLinear -> Int8DynActInt4WeightLinear
-pass
+_convert_qat_linear_8da4w(model)
+return model

def _convert_qat_linear_8da4w(module: torch.nn.Module):
    """
    Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
    """
    for name, child in module.named_children():
        if isinstance(child, Int8DynActInt4WeightQATLinear):
            quantized_linear = Int8DynActInt4WeightLinear(
                child.in_features,
                child.out_features,
                bias=False,
                groupsize=child.groupsize,
                precision=child.precision,
                scales_precision=child.scales_precision,
            )
            setattr(module, name, quantized_linear)

            # Load weights and qparams into quantized linear
            n_bit = 4
            (qmin, qmax) = child._get_qmin_qmax(n_bit)
            (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
            q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
                child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize,
            )
            quantized_linear.weight = q_weight
            quantized_linear.scales = s
            quantized_linear.zeros = zp
        else:
            _convert_qat_linear_8da4w(child)
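
For reference, a standalone sketch of the same weight-quantization numerics on a
random tensor. The shapes and groupsize are illustrative, and the import paths for
get_group_qparams_symmetric and the quantized_decomposed op registration are
assumptions about where those helpers live, not something this diff shows:

import torch
# Importing this module is assumed to register the quantized_decomposed ops.
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from torchao.quantization.quant_primitives import get_group_qparams_symmetric

n_bit, groupsize = 4, 32
weight = torch.randn(64, 128)   # (out_features, in_features)
qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1   # (-8, 7) for int4
scales, zeros = get_group_qparams_symmetric(weight, n_bit, groupsize)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
    weight, scales, zeros, qmin, qmax, torch.int8, groupsize,
)
# q_weight is an int8 tensor of the same shape as weight, with values in [-8, 7];
# each group of `groupsize` input channels shares one (scale, zero point) pair,
# which is the same layout the PTQ flow produces.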

class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
"""
This module implements a linear layer with int8 dynamic per token fake
@@ -96,6 +127,7 @@ def __init__(
), f"require in_features:{in_features} % groupsize:{groupsize} == 0"
assert not bias, "require bias=False"
self.groupsize = groupsize
self.precision = precision
self.scales_precision = scales_precision

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -123,6 +155,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
return torch.nn.functional.linear(x_fq, w_fq)

# TODO: move this to common util
def _get_qmin_qmax(self, n_bit: int):
    qmin = -(2 ** (n_bit - 1))
    qmax = 2 ** (n_bit - 1) - 1
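    # e.g. n_bit=4 gives (qmin, qmax) = (-8, 7), the signed int4 range used by the convert path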
