diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 3a6e099576..a7db62a485 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -198,6 +198,7 @@ def test_qat_8da4w_linear(self):
         ptq_out = ptq_linear(x2)
         torch.testing.assert_close(ptq_out, qat_out, atol=0, rtol=0)
 
+    # TODO: compare against quantize_ API instead
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
@@ -225,14 +226,6 @@ def test_qat_8da4w_quantizer(self):
         converted_out = converted_model(*x)
         torch.testing.assert_close(ptq_out, converted_out, atol=0, rtol=0)
 
-        # TODO: enable this after supporting aten.eq.default in both subclasses
-        # Compare converted state dict
-        # ptq_state_dict = ptq_model.state_dict()
-        # converted_state_dict = converted_model.state_dict()
-        # self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
-        # for k in ptq_state_dict.keys():
-        #     torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
-
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_meta_weights(self):
         from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
@@ -245,6 +238,20 @@ def test_qat_8da4w_quantizer_meta_weights(self):
         qat_model = qat_quantizer.prepare(m)
         self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values()))
 
+    def _copy_subclass_weights(
+        self,
+        nn_linear: torch.nn.Linear,
+        subclass_linear: AffineFakeQuantizedTensor,
+    ):
+        nn_linear.weight = torch.nn.Parameter(subclass_linear.weight.original_tensor)
+
+    def _assert_matches_subclass_weights(
+        self,
+        nn_linear: torch.nn.Linear,
+        subclass_linear: AffineFakeQuantizedTensor,
+    ):
+        torch.testing.assert_close(nn_linear.weight, subclass_linear.weight.original_tensor, atol=0, rtol=0)
+
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_8da4w_quantizer_disable_fake_quant(self):
         """
@@ -279,9 +286,9 @@ def assert_fake_quant_enabled(m: torch.nn.Linear, enabled: bool):
         assert_fake_quant_enabled(qat_model.sub.linear, enabled=False)
 
         # Disabled fake quant is just a normal linear
-        m2.linear1.weight = torch.nn.Parameter(qat_model.linear1.weight.original_tensor)
-        m2.linear2.weight = torch.nn.Parameter(qat_model.linear2.weight.original_tensor)
-        m2.sub.linear.weight = torch.nn.Parameter(qat_model.sub.linear.weight.original_tensor)
+        self._copy_subclass_weights(m2.linear1, qat_model.linear1)
+        self._copy_subclass_weights(m2.linear2, qat_model.linear2)
+        self._copy_subclass_weights(m2.sub.linear, qat_model.sub.linear)
         torch.manual_seed(self.SEED)
         x = m.example_inputs()
         x2 = copy.deepcopy(x)
@@ -318,10 +325,6 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self):
             disable_8da4w_fake_quant,
         )
 
-        def get_qat_weight(qat_linear: torch.nn.Linear):
-            assert isinstance(qat_linear.weight, AffineFakeQuantizedTensor)
-            return qat_linear.weight.original_tensor
-
         group_size = 16
         torch.manual_seed(self.SEED)
         m = M()
@@ -329,9 +332,9 @@ def get_qat_weight(qat_linear: torch.nn.Linear):
         quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size)
         qat_model = quantizer.prepare(m)
         qat_model.apply(disable_8da4w_fake_quant)
-        nn_model.linear1.weight = torch.nn.Parameter(get_qat_weight(qat_model.linear1))
-        nn_model.linear2.weight = torch.nn.Parameter(get_qat_weight(qat_model.linear2))
-        nn_model.sub.linear.weight = torch.nn.Parameter(get_qat_weight(qat_model.sub.linear))
+        self._copy_subclass_weights(nn_model.linear1, qat_model.linear1)
+        self._copy_subclass_weights(nn_model.linear2, qat_model.linear2)
+        self._copy_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
 
         # Simulate training for both models
         optimizer1 = torch.optim.SGD(nn_model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
@@ -353,9 +356,9 @@ def get_qat_weight(qat_linear: torch.nn.Linear):
         optimizer2.step()
 
         # After 1 training step, weights should match exactly
-        torch.testing.assert_close(nn_model.linear1.weight, get_qat_weight(qat_model.linear1), atol=0, rtol=0)
-        torch.testing.assert_close(nn_model.linear2.weight, get_qat_weight(qat_model.linear2), atol=0, rtol=0)
-        torch.testing.assert_close(nn_model.sub.linear.weight, get_qat_weight(qat_model.sub.linear), atol=0, rtol=0)
+        self._assert_matches_subclass_weights(nn_model.linear1, qat_model.linear1)
+        self._assert_matches_subclass_weights(nn_model.linear2, qat_model.linear2)
+        self._assert_matches_subclass_weights(nn_model.sub.linear, qat_model.sub.linear)
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     def test_qat_generic_fake_quantize(self):
@@ -396,6 +399,7 @@ def _assert_close_4w(self, val, ref):
         print(mean_err)
         self.assertTrue(mean_err < 0.05)
 
+    # TODO: compare against quantize_ API instead
    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
     def test_qat_4w_quantizer(self):
@@ -429,14 +433,6 @@ def test_qat_4w_quantizer(self):
         converted_out = converted_model(*x)
         torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0)
 
-        # TODO: enable this after supporting aten.eq.default in both subclasses
-        # Compare converted state dict
-        # ptq_state_dict = ptq_model.state_dict()
-        # converted_state_dict = converted_model.state_dict()
-        # self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys())
-        # for k in ptq_state_dict.keys():
-        #     torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0)
-
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
index 0d7b7e4fa1..f15cbf5c59 100644
--- a/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
+++ b/torchao/quantization/prototype/qat/affine_fake_quantized_tensor.py
@@ -14,7 +14,10 @@
     _dispatch__torch_function__,
     _dispatch__torch_dispatch__,
 )
-from .utils import _GenericFakeQuantize
+from .utils import (
+    _GenericFakeQuantize,
+    _UnwrapAffineFakeQuantizedTensor,
+)
 
 aten = torch.ops.aten
 
@@ -129,7 +132,7 @@ def get_value(self) -> torch.Tensor:
         if self.fake_quant_enabled:
             return self.apply_fake_quant_fn(self)
         else:
-            return self.original_tensor
+            return _UnwrapAffineFakeQuantizedTensor.apply(self)
 
     def _get_to_kwargs(self, *args, **kwargs):
         device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
diff --git a/torchao/quantization/prototype/qat/api.py b/torchao/quantization/prototype/qat/api.py
index 99b0ddbffd..d9414b610c 100644
--- a/torchao/quantization/prototype/qat/api.py
+++ b/torchao/quantization/prototype/qat/api.py
@@ -145,6 +145,7 @@ def convert(
         quantize_(model, quantize_fn)
         return model
 
+# TODO: deprecate
 class Int8DynActInt4WeightQATLinear(torch.nn.Linear):
     """
     This module implements a linear layer with int8 dynamic per token fake
@@ -322,7 +323,7 @@ def convert(
         quantize_(model, quantize_fn)
         return model
 
-
+# TODO: deprecate
 class Int4WeightOnlyQATLinear(torch.nn.Linear):
     """
     This module implements a linear layer with int4 fake quantized grouped
diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py
index f58b4d69c7..90f912a07e 100644
--- a/torchao/quantization/prototype/qat/utils.py
+++ b/torchao/quantization/prototype/qat/utils.py
@@ -73,6 +73,31 @@ def backward(ctx, gy):
         (mask,) = ctx.saved_tensors
         return gy * mask, None, None, None, None, None, None
 
+
+class _UnwrapAffineFakeQuantizedTensor(torch.autograd.Function):
+    """
+    Helper autograd function to unwrap `AffineFakeQuantizedTensor` while ensuring
+    gradients are still passed to the tensor subclass. This is used in place of
+    `_GenericFakeQuantize` when fake quant is disabled.
+    """
+
+    @staticmethod
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        input: torch.Tensor,
+    ) -> torch.Tensor:
+        # avoid circular dependencies
+        from torchao.quantization.prototype.qat.affine_fake_quantized_tensor import (
+            AffineFakeQuantizedTensor,
+        )
+        assert isinstance(input, AffineFakeQuantizedTensor)
+        return input.original_tensor
+
+    @staticmethod
+    def backward(ctx, gy):
+        return gy,
+
+
 def _fake_quantize_per_channel_group(
     input: torch.Tensor,
     scales: torch.Tensor,
@@ -206,8 +231,9 @@ def _enable_fake_quant(mod: torch.nn.Module, enable: bool):
     if hasattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK):
         (prehook, handle) = getattr(mod, _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK)
         if enable:
-            handle = mod.register_forward_pre_hook(prehook)
+            if handle is None:
+                handle = mod.register_forward_pre_hook(prehook)
+                _forward_pre_hook_handler(mod, prehook, handle)
         else:
             handle.remove()
-            handle = None
-            _forward_pre_hook_handler(mod, prehook, handle)
+            _forward_pre_hook_handler(mod, prehook, None)
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index deac2b246e..5de7083c5c 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -267,7 +267,7 @@ def _get_linear_subclass_inserter(
     requires_grad: bool = False,
 ) -> Callable:
     """
-    Return a functinon that inserts wraps the weight and/or input activation of a
+    Return a function that wraps the weight and/or input activation of a
     linear module in tensor subclasses.
 
     Args:
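
Reviewer note: below is a minimal standalone sketch of the forward-pre-hook bookkeeping that the `_enable_fake_quant` hunk in torchao/quantization/prototype/qat/utils.py moves to, namely re-registering the prehook only when its handle has been cleared, so enabling twice no longer stacks duplicate hooks, and recording `None` as the handle on disable. It is illustrative only, not the torchao implementation: the names `_PREHOOK_STATE`, `_print_prehook`, `attach_prehook`, and `set_prehook_enabled` are invented for the sketch, and storing the `(prehook, handle)` pair with `setattr` merely stands in for `_forward_pre_hook_handler`, whose definition is outside this diff.

import torch

# attribute name invented for this sketch; the patch uses _QAT_LINEAR_SUBCLASS_INPUT_PREHOOK
_PREHOOK_STATE = "_sketch_prehook_state"

def _print_prehook(mod, args):
    # stand-in for the input-activation fake-quant prehook
    print(f"prehook fired on {type(mod).__name__}")

def attach_prehook(mod: torch.nn.Module) -> None:
    handle = mod.register_forward_pre_hook(_print_prehook)
    setattr(mod, _PREHOOK_STATE, (_print_prehook, handle))

def set_prehook_enabled(mod: torch.nn.Module, enable: bool) -> None:
    if not hasattr(mod, _PREHOOK_STATE):
        return
    prehook, handle = getattr(mod, _PREHOOK_STATE)
    if enable:
        # re-register only if the hook was previously removed, so that
        # enabling twice in a row does not register the same prehook twice
        if handle is None:
            handle = mod.register_forward_pre_hook(prehook)
            setattr(mod, _PREHOOK_STATE, (prehook, handle))
    else:
        if handle is not None:
            handle.remove()
        # record that the hook is currently detached
        setattr(mod, _PREHOOK_STATE, (prehook, None))

linear = torch.nn.Linear(4, 4)
attach_prehook(linear)
set_prehook_enabled(linear, False)
linear(torch.randn(2, 4))            # silent: hook removed
set_prehook_enabled(linear, True)
set_prehook_enabled(linear, True)    # no-op: handle already set
linear(torch.randn(2, 4))            # prints once, not twice

The `_UnwrapAffineFakeQuantizedTensor` helper added in the same file follows the analogous straight-through idea: its forward simply exposes `original_tensor`, and its backward hands the incoming gradient back to the subclass unchanged, so disabling fake quant no longer cuts the weight out of autograd.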