diff --git a/tests/transformers/tests/test_modeling_common.py b/tests/transformers/tests/test_modeling_common.py index d33cf1e58d..c2a818f257 100755 --- a/tests/transformers/tests/test_modeling_common.py +++ b/tests/transformers/tests/test_modeling_common.py @@ -83,6 +83,7 @@ if is_torch_available(): import torch + from safetensors.torch import save_file as safe_save_file from torch import nn from transformers import MODEL_MAPPING, AdaptiveEmbedding from transformers.pytorch_utils import id_tensor_storage @@ -408,7 +409,7 @@ class CopyClass(base_class): # check that certain keys didn't get saved with the model with tempfile.TemporaryDirectory() as tmpdirname: - model.config.save_pretrained(tmpdirname) + model.save_pretrained(tmpdirname) torch.save(state_dict, os.path.join(tmpdirname, "pytorch_model.bin")) model_fast_init = base_class_copy.from_pretrained(tmpdirname) @@ -1661,8 +1662,8 @@ def test_model_weights_reload_no_missing_tied_weights(self): # We are nuking ALL weights on file, so every parameter should # yell on load. We're going to detect if we yell too much, or too little. - with open(os.path.join(tmp_dir, "pytorch_model.bin"), "wb") as f: - torch.save({}, f) + placeholder_dict = {"tensor": torch.tensor([1, 2])} + safe_save_file(placeholder_dict, os.path.join(tmp_dir, "model.safetensors"), metadata={"format": "pt"}) model_reloaded, infos = model_class.from_pretrained(tmp_dir, output_loading_info=True) prefix = f"{model_reloaded.base_model_prefix}."