diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 026a20667985..2f04df97e863 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1309,7 +1309,8 @@ def get_apply_tensor_subclass(self): return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs) def __repr__(self): - return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})" + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" @dataclass diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index c7c701e49aec..c3ab06ee61ba 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -74,6 +74,13 @@ def test_post_init_check(self): with self.assertRaisesRegex(ValueError, "Unexpected keyword arg"): _ = TorchAoConfig("int4_weight_only", group_size1=32) + def test_repr(self): + """ + Check that there is no error in the repr + """ + quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8) + repr(quantization_config) + @require_torch_gpu @require_torchao