diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 647629cfd8..64518c0599 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -22,10 +22,22 @@
 
 
 # TODO: put this in a common test utils file
+class Sub(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float)
+
+    def example_inputs(self):
+        return (torch.randn(1, 32).to(torch.float),)
+
+    def forward(self, x):
+        return self.linear(x)
+
 class M(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
+        self.sub = Sub()
         self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float)
 
     def example_inputs(self):
@@ -33,6 +45,7 @@ def example_inputs(self):
 
     def forward(self, x):
         x = self.linear1(x)
+        x = self.sub(x)
         x = self.linear2(x)
         return x
 
@@ -160,6 +173,9 @@ def test_qat_8da4w_quantizer(self):
         self._set_ptq_weight(
             ptq_model.linear1, qat_model.linear1.weight, group_size,
         )
+        self._set_ptq_weight(
+            ptq_model.sub.linear, qat_model.sub.linear.weight, group_size,
+        )
         self._set_ptq_weight(
             ptq_model.linear2, qat_model.linear2.weight, group_size,
         )
diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py
index 87f28ec96a..621f4bf80f 100644
--- a/torchao/quantization/prototype/qat.py
+++ b/torchao/quantization/prototype/qat.py
@@ -24,7 +24,7 @@ class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
     dynamic per token fake quantized activations and int4
     fake quantized grouped per channel weights.
     """
-    
+
     def __init__(
         self,
         groupsize: int = 256,
@@ -37,7 +37,7 @@ def __init__(
         self.padding_allowed: bool = padding_allowed
         self.precision: torch.dtype = precision
         self.scales_precision: torch.dtype = scales_precision
-    
+
     def prepare(
         self,
         model: torch.nn.Module,
@@ -53,7 +53,7 @@ def prepare(
             Int8DynActInt4WeightQATLinear,
         )
         return model
-    
+
     def convert(
         self,
         model: torch.nn.Module,
@@ -62,19 +62,19 @@ def convert(
     ) -> torch.nn.Module:
         # TODO: replace Int8DynActInt4WeightQATLinear -> Int8DynActInt4WeightLinear
         pass
-    
-    
+
+
 class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
     """
     This module implements a linear layer with int8 dynamic per token fake
     quantized activations with int4 fake quantized grouped per channel weights.
-    
+
     args:
         groupsize: the number of elements in each quantized group for weights
         precision: precision of weights
         scales_precision: precision of per group scales and zero points
     """
-    
+
     def __init__(
         self,
         in_features: int,
@@ -97,7 +97,7 @@ def __init__(
         assert not bias, "require bias=False"
         self.groupsize = groupsize
         self.scales_precision = scales_precision
-    
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # activations: int8 dynamic asymmetric quant
        (act_qmin, act_qmax) = self._get_qmin_qmax(8)
@@ -107,7 +107,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_fq = fake_quantize_per_token(
             x, act_scales, act_zp, act_qmin, act_qmax,
         )
-    
+
         # weights: int4 grouped per channel symmetric quant
         (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
         (weight_scales, weight_zp) = get_group_qparams_symmetric(
@@ -122,7 +122,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             self.groupsize,
         )
         return torch.nn.functional.linear(x_fq, w_fq)
-    
+
     def _get_qmin_qmax(self, n_bit: int):
         qmin = -(2 ** (n_bit - 1))
         qmax = 2 ** (n_bit - 1) - 1
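
Note (not part of the patch): for readers of the forward pass above, this is a minimal sketch of the arithmetic that grouped symmetric weight fake quantization stands for. It is not torchao's fake_quantize_per_channel_group; it assumes one common symmetric convention (per-group scale = max|w| / qmax, zero point 0) and adds a straight-through estimator so gradients flow during QAT, which torchao's op handles internally.

import torch

def fake_quant_group_symmetric(w: torch.Tensor, groupsize: int, n_bit: int = 4) -> torch.Tensor:
    # Same qmin/qmax as _get_qmin_qmax above: int4 gives -8..7, int8 gives -128..127.
    qmin, qmax = -(2 ** (n_bit - 1)), 2 ** (n_bit - 1) - 1
    # Each group of `groupsize` consecutive input-channel elements shares one scale.
    grouped = w.reshape(-1, groupsize)
    scales = grouped.abs().amax(dim=1, keepdim=True).clamp(min=1e-6) / qmax
    # Quantize, clamp to the int range, then dequantize back to float ("fake" quant).
    q = torch.clamp(torch.round(grouped / scales), qmin, qmax)
    w_fq = (q * scales).reshape(w.shape)
    # Straight-through estimator: forward uses w_fq, backward sees the identity.
    return w + (w_fq - w).detach()

# e.g. a (32, 64) weight with groups of 32 along the input dimension:
w_fq = fake_quant_group_symmetric(torch.randn(32, 64), groupsize=32)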
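
Usage note (also not part of the patch): a sketch of how the nested-module handling being tested here could be exercised end to end. It assumes the test model M defined above and the module path shown in this diff; groupsize=32 is an arbitrary choice that divides every in_features in M. convert() is still a stub in this patch, so only the fake-quantized QAT forward is run.

import torch
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

model = M()  # nested test model from test_qat.py above
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)

# prepare() swaps each torch.nn.Linear, including the nested
# model.sub.linear, for Int8DynActInt4WeightQATLinear.
qat_model = quantizer.prepare(model)

# The forward pass now fake-quantizes activations (int8, dynamic per token)
# and weights (int4, symmetric per channel group) while staying in float.
out = qat_model(*model.example_inputs())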