Skip to content

Commit

Permalink
Add test for quantizing models with hierarchies in qat 8da4w (#157)
Browse files Browse the repository at this point in the history
Summary:
att

Test Plan:
python test/quantization/test_qat.py

Reviewers:

Subscribers:

Tasks:

Tags:

Co-authored-by: Jerry Zhang <[email protected]>
  • Loading branch information
andrewor14 and jerryzh168 authored Apr 22, 2024
1 parent 3124382 commit 2003325
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 10 deletions.
16 changes: 16 additions & 0 deletions test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,30 @@


# TODO: put this in a common test utils file
class Sub(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 32).to(torch.float),)

def forward(self, x):
return self.linear(x)

class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
self.sub = Sub()
self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)

def example_inputs(self):
return (torch.randn(1, 64).to(torch.float),)

def forward(self, x):
x = self.linear1(x)
x = self.sub(x)
x = self.linear2(x)
return x

Expand Down Expand Up @@ -160,6 +173,9 @@ def test_qat_8da4w_quantizer(self):
self._set_ptq_weight(
ptq_model.linear1, qat_model.linear1.weight, group_size,
)
self._set_ptq_weight(
ptq_model.sub.linear, qat_model.sub.linear.weight, group_size,
)
self._set_ptq_weight(
ptq_model.linear2, qat_model.linear2.weight, group_size,
)
Expand Down
20 changes: 10 additions & 10 deletions torchao/quantization/prototype/qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
dynamic per token fake quantized activations and int4 fake quantized
grouped per channel weights.
"""

def __init__(
self,
groupsize: int = 256,
Expand All @@ -37,7 +37,7 @@ def __init__(
self.padding_allowed: bool = padding_allowed
self.precision: torch.dtype = precision
self.scales_precision: torch.dtype = scales_precision

def prepare(
self,
model: torch.nn.Module,
Expand All @@ -53,7 +53,7 @@ def prepare(
Int8DynActInt4WeightQATLinear,
)
return model

def convert(
self,
model: torch.nn.Module,
Expand All @@ -62,19 +62,19 @@ def convert(
) -> torch.nn.Module:
# TODO: replace Int8DynActInt4WeightQATLinear -> Int8DynActInt4WeightLinear
pass


class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
"""
This module implements a linear layer with int8 dynamic per token fake
quantized activations with int4 fake quantized grouped per channel weights.
args:
groupsize: the number of elements in each quantized group for weights
precision: precision of weights
scales_precision: precision of per group scales and zero points
"""

def __init__(
self,
in_features: int,
Expand All @@ -97,7 +97,7 @@ def __init__(
assert not bias, "require bias=False"
self.groupsize = groupsize
self.scales_precision = scales_precision

def forward(self, x: torch.Tensor) -> torch.Tensor:
# activations: int8 dynamic asymmetric quant
(act_qmin, act_qmax) = self._get_qmin_qmax(8)
Expand All @@ -107,7 +107,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
x_fq = fake_quantize_per_token(
x, act_scales, act_zp, act_qmin, act_qmax,
)

# weights: int4 grouped per channel symmetric quant
(weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
(weight_scales, weight_zp) = get_group_qparams_symmetric(
Expand All @@ -122,7 +122,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.groupsize,
)
return torch.nn.functional.linear(x_fq, w_fq)

def _get_qmin_qmax(self, n_bit: int):
qmin = -(2 ** (n_bit - 1))
qmax = 2 ** (n_bit - 1) - 1
Expand Down

0 comments on commit 2003325

Please sign in to comment.