Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions src/transformers/modeling_tf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
SAFE_WEIGHTS_NAME,
TF2_WEIGHTS_INDEX_NAME,
TF2_WEIGHTS_NAME,
TF_WEIGHTS_NAME,
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
ModelOutput,
Expand Down Expand Up @@ -2392,7 +2393,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
save directory.
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
configuration JSON file named *config.json* is found in the directory.
from_pt: (`bool`, *optional*, defaults to `False`):
from_pt (`bool`, *optional*, defaults to `False`):
Load the model weights from a PyTorch state_dict save file (see docstring of
`pretrained_model_name_or_path` argument).
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
Expand Down Expand Up @@ -2531,7 +2532,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
if pretrained_model_name_or_path is not None:
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if is_local:
if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
# Load from a PyTorch checkpoint in priority if from_pt
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
Expand Down Expand Up @@ -2559,7 +2560,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)
is_sharded = True
# At this stage we don't have a weight file so we will raise an error.
elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME):
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile(
os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
):
raise EnvironmentError(
f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} "
"but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those "
Expand Down Expand Up @@ -2630,6 +2633,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None and filename == WEIGHTS_NAME:
# Maybe the checkpoint is sharded, we try to grab the index name in this case.
resolved_archive_file = cached_file(
pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs
)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error
# message.
Expand All @@ -2646,8 +2656,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
)
else:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME},"
f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
)

except EnvironmentError:
Expand All @@ -2661,7 +2671,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
" from 'https://huggingface.co/models', make sure you don't have a local directory with the"
f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}."
f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}"
)
if is_local:
logger.info(f"loading weights file {archive_file}")
Expand Down
8 changes: 8 additions & 0 deletions tests/test_modeling_tf_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2127,6 +2127,14 @@ def test_checkpoint_sharding_local_from_pt(self):
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())

@is_pt_tf_cross_test
def test_checkpoint_sharding_hub_from_pt(self):
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
# the model above is the same as the model below, just a sharded pytorch version.
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
for p1, p2 in zip(model.weights, ref_model.weights):
assert np.allclose(p1.numpy(), p2.numpy())

def test_shard_checkpoint(self):
# This is the model we will use, total size 340,000 bytes.
model = tf.keras.Sequential(
Expand Down