diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index aa5592c7d9d9..57294682a8b9 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1538,6 +1538,14 @@ def save_pretrained(
             kwargs:
                 Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
+        # Checks if the model has been loaded in 8-bit
+        if getattr(self, "is_loaded_in_8bit", False):
+            warnings.warn(
+                "You are calling `save_pretrained` to an 8-bit converted model you may likely encounter unexpected"
+                " behaviors. ",
+                UserWarning,
+            )
+
         if "save_config" in kwargs:
             warnings.warn(
                 "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
@@ -2340,6 +2348,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 load_in_8bit=load_in_8bit,
             )
 
+        # Record the flag on this instance (not on `cls`) so other models of the same class are unaffected
+        model.is_loaded_in_8bit = load_in_8bit
+
         # make sure token embedding weights are still tied if needed
         model.tie_weights()
 
diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/mixed_int8/test_mixed_int8.py
index 2911d6774880..a459ffa84d0e 100644
--- a/tests/mixed_int8/test_mixed_int8.py
+++ b/tests/mixed_int8/test_mixed_int8.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import gc
+import tempfile
 import unittest
 
 from transformers import (
@@ -107,6 +108,13 @@ def test_generate_quality(self):
 
         self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
+    def test_warns_save_pretrained(self):
+        r"""
+        Test whether trying to save a model after converting it in 8-bit will throw a warning.
+        """
+        with self.assertWarns(UserWarning), tempfile.TemporaryDirectory() as tmpdirname:
+            self.model_8bit.save_pretrained(tmpdirname)
+
 
 class MixedInt8ModelClassesTest(BaseMixedInt8Test):
     def setUp(self):