diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index fa6df4e40acf..a642e883b0be 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -47,6 +47,7 @@ SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, ModelOutput, @@ -2392,7 +2393,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): save directory. - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named *config.json* is found in the directory. - from_pt: (`bool`, *optional*, defaults to `False`): + from_pt (`bool`, *optional*, defaults to `False`): Load the model weights from a PyTorch state_dict save file (see docstring of `pretrained_model_name_or_path` argument). ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): @@ -2531,7 +2532,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): if pretrained_model_name_or_path is not None: pretrained_model_name_or_path = str(pretrained_model_name_or_path) is_local = os.path.isdir(pretrained_model_name_or_path) - if os.path.isdir(pretrained_model_name_or_path): + if is_local: if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): # Load from a PyTorch checkpoint in priority if from_pt archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) @@ -2559,7 +2560,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME) is_sharded = True # At this stage we don't have a weight file so we will raise an error. - elif os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME): + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)) or os.path.isfile( + os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + ): raise EnvironmentError( f"Error no file named {TF2_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} " "but there is a file for PyTorch weights. Use `from_pt=True` to load this model from those " @@ -2630,6 +2633,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): ) if resolved_archive_file is not None: is_sharded = True + if resolved_archive_file is None and filename == WEIGHTS_NAME: + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, WEIGHTS_INDEX_NAME, **cached_file_kwargs + ) + if resolved_archive_file is not None: + is_sharded = True if resolved_archive_file is None: # Otherwise, maybe there is a PyTorch or Flax model file. We try those to give a helpful error # message. @@ -2646,8 +2656,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): ) else: raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}." + f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}," + f" {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}" ) except EnvironmentError: @@ -2661,7 +2671,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" " from 'https://huggingface.co/models', make sure you don't have a local directory with the" f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a file named {TF2_WEIGHTS_NAME} or {WEIGHTS_NAME}." + f" directory containing a file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME} or {TF_WEIGHTS_NAME}" ) if is_local: logger.info(f"loading weights file {archive_file}") diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 679776dd8c90..6908c73c99ec 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -2127,6 +2127,14 @@ def test_checkpoint_sharding_local_from_pt(self): for p1, p2 in zip(model.weights, ref_model.weights): assert np.allclose(p1.numpy(), p2.numpy()) + @is_pt_tf_cross_test + def test_checkpoint_sharding_hub_from_pt(self): + model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) + # the model above is the same as the model below, just a sharded pytorch version. + ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + for p1, p2 in zip(model.weights, ref_model.weights): + assert np.allclose(p1.numpy(), p2.numpy()) + def test_shard_checkpoint(self): # This is the model we will use, total size 340,000 bytes. model = tf.keras.Sequential(