From ce78e79c0aacdd68ffa1cfd69590d2c3f8a21b8e Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Mon, 6 May 2024 12:32:31 -0400 Subject: [PATCH] Copy weights and preserve device for 8da4w QAT linear (#211) * Copy weights and preserve device for 8da4w QAT linear Summary: This fixes two correctness bugs. First, we never copied over the weights from the existing linear, so we would start from random weights even when loading from checkpoints. Second, we never preserved the device of the original linear. This is important for settings like FSDP, where we expect non-zero ranks to have their parameters on the meta device in order to initialize these parameters correctly. Test Plan: python test/quantization/test_qat.py -k test_qat_8da4w_quantizer python test/quantization/test_qat.py -k test_qat_8da4w_quantizer_meta_weights Reviewers: jerryzh168, cpuhrsch Subscribers: jerryzh168, cpuhrsch, supriyar * Update test_qat.py --- test/quantization/test_qat.py | 23 ++++++++++++----------- torchao/quantization/GPTQ.py | 27 ++++++++++++++++----------- torchao/quantization/prototype/qat.py | 4 +++- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 3a12d9b636..a0587d3ff0 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -169,17 +169,6 @@ def test_qat_8da4w_quantizer(self): qat_model = qat_quantizer.prepare(m) ptq_model = ptq_quantizer.quantize(m2) - # Force the weights to be the same - self._set_ptq_weight( - ptq_model.linear1, qat_model.linear1.weight, group_size, - ) - self._set_ptq_weight( - ptq_model.sub.linear, qat_model.sub.linear.weight, group_size, - ) - self._set_ptq_weight( - ptq_model.linear2, qat_model.linear2.weight, group_size, - ) - # Compare model values torch.manual_seed(self.SEED) x = m.example_inputs() @@ -200,6 +189,18 @@ def test_qat_8da4w_quantizer(self): for k in ptq_state_dict.keys(): torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") + def test_qat_8da4w_quantizer_meta_weights(self): + from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer + + with torch.device("meta"): + m = M() + self.assertTrue(all(v.is_meta for v in m.state_dict().values())) + group_size = 16 + qat_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=group_size) + qat_model = qat_quantizer.prepare(m) + self.assertTrue(all(v.is_meta for v in qat_model.state_dict().values())) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "skipping when torch verion is 2.3 or lower") def test_qat_8da4w_quantizer_disable_fake_quant(self): """ diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 3e007ec75b..e7176b4fd2 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -1127,22 +1127,26 @@ def _replace_linear_8da4w( precision: torch.dtype, scales_precision: torch.dtype, linear_class: Type[torch.nn.Module], + copy_weights: bool = False, ): for name, child in module.named_children(): if isinstance(child, nn.Linear): if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed: - setattr( - module, - name, - linear_class( - child.in_features, - child.out_features, - bias=False, - groupsize=groupsize, - precision=precision, - scales_precision=scales_precision, - ), + new_linear = linear_class( + child.in_features, + child.out_features, + bias=False, + device=child.weight.device, + groupsize=groupsize, + precision=precision, + scales_precision=scales_precision, ) + # In distributed training, the model may be instantiated + # on the meta device, in which case there is no need to + # copy the weights, and doing so will result in an error + if copy_weights and child.weight.device != torch.device("meta"): + new_linear.weight = child.weight + setattr(module, name, new_linear) else: _replace_linear_8da4w( child, @@ -1151,6 +1155,7 @@ def _replace_linear_8da4w( precision, scales_precision, linear_class, + copy_weights, ) def replace_linear_8da4w( diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py index 7901fa8b5b..d15e841d74 100644 --- a/torchao/quantization/prototype/qat.py +++ b/torchao/quantization/prototype/qat.py @@ -54,6 +54,7 @@ def prepare( self.precision, self.scales_precision, Int8DynActInt4WeightQATLinear, + copy_weights = True, ) return model @@ -111,6 +112,7 @@ def __init__( in_features: int, out_features: int, bias: bool = False, + device: torch.device = None, groupsize: int = 256, precision: torch.dtype = torch.float32, scales_precision: torch.dtype = torch.float32, @@ -119,7 +121,7 @@ def __init__( in_features, out_features, bias, - device=None, + device=device, dtype=precision, ) assert (