Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add convert path for 8da4w QAT #154

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,18 @@ def test_qat_8da4w_quantizer(self):
ptq_out = ptq_model(*x2)
torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)

# Convert QAT model and compare model values
converted_model = qat_quantizer.convert(qat_model)
converted_out = converted_model(*x)
torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)

# Compare converted state dict
ptq_state_dict = ptq_model.state_dict()
converted_state_dict = converted_model.state_dict()
self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
for k in ptq_state_dict.keys():
torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)


if __name__ == "__main__":
unittest.main()
41 changes: 37 additions & 4 deletions torchao/quantization/prototype/qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@


if TORCH_VERSION_AFTER_2_3:
from torchao.quantization.GPTQ import _replace_linear_8da4w
from torchao.quantization.GPTQ import (
_replace_linear_8da4w,
Int8DynActInt4WeightLinear,
)

class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
"""
Expand Down Expand Up @@ -60,10 +63,38 @@ def convert(
*args: Any,
**kwargs: Any
) -> torch.nn.Module:
# TODO: replace Int8DynActInt4WeightQATLinear -> Int8DynActInt4WeightLinear
pass

_convert_qat_linear_8da4w(model)
return model

def _convert_qat_linear_8da4w(module: torch.nn.Module):
"""
Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
"""
for name, child in module.named_children():
if isinstance(child, Int8DynActInt4WeightQATLinear):
quantized_linear = Int8DynActInt4WeightLinear(
child.in_features,
child.out_features,
bias=False,
groupsize=child.groupsize,
precision=child.precision,
scales_precision=child.scales_precision,
)
setattr(module, name, quantized_linear)

# Load weights and qparams into quantized linear
n_bit = 4
(qmin, qmax) = child._get_qmin_qmax(n_bit)
(s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group(
child.weight, s, zp, qmin, qmax, torch.int8, child.groupsize,
)
quantized_linear.weight = q_weight
quantized_linear.scales = s
quantized_linear.zeros = zp
else:
_convert_qat_linear_8da4w(child)

class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
"""
This module implements a linear layer with int8 dynamic per token fake
Expand Down Expand Up @@ -96,6 +127,7 @@ def __init__(
), f"require in_features:{in_features} % groupsize:{groupsize} == 0"
assert not bias, "require bias=False"
self.groupsize = groupsize
self.precision = precision
self.scales_precision = scales_precision

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -123,6 +155,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
return torch.nn.functional.linear(x_fq, w_fq)

# TODO: move this to common util
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, this probably doesn't have to live here, I'm also adding a new util for this as well in quant_primitives.py as well

def _get_qmin_qmax(self, n_bit: int):
qmin = -(2 ** (n_bit - 1))
qmax = 2 ** (n_bit - 1) - 1
Expand Down
Loading