Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1538,6 +1538,14 @@ def save_pretrained(
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
# Checks if the model has been loaded in 8-bit
if getattr(self, "is_loaded_in_8bit", False):
warnings.warn(
"You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected"
" behaviors. ",
UserWarning,
)

if "save_config" in kwargs:
warnings.warn(
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
Expand Down Expand Up @@ -2340,6 +2348,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
load_in_8bit=load_in_8bit,
)

cls.is_loaded_in_8bit = load_in_8bit

# make sure token embedding weights are still tied if needed
model.tie_weights()

Expand Down
8 changes: 8 additions & 0 deletions tests/mixed_int8/test_mixed_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest

from transformers import (
Expand Down Expand Up @@ -107,6 +108,13 @@ def test_generate_quality(self):

self.assertEqual(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

def test_warns_save_pretrained(self):
    """
    Check that calling `save_pretrained` on the 8-bit converted model emits a `UserWarning`.

    The model is saved into a temporary directory that is removed when the test finishes.
    """
    # Enter the warning assertion first so it is the last context to exit, matching
    # the evaluation order of a single combined `with` statement.
    with self.assertWarns(UserWarning):
        with tempfile.TemporaryDirectory() as save_directory:
            self.model_8bit.save_pretrained(save_directory)


class MixedInt8ModelClassesTest(BaseMixedInt8Test):
def setUp(self):
Expand Down