diff --git a/src/transformers/commands/pt_to_tf.py b/src/transformers/commands/pt_to_tf.py index 480e07ad5f09..3c4857139eef 100644 --- a/src/transformers/commands/pt_to_tf.py +++ b/src/transformers/commands/pt_to_tf.py @@ -34,7 +34,7 @@ is_tf_available, is_torch_available, ) -from ..utils import logging +from ..utils import TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging from . import BaseTransformersCLICommand @@ -48,7 +48,6 @@ MAX_ERROR = 5e-5 # larger error tolerance than in our internal tests, to avoid flaky user-facing errors -TF_WEIGHTS_NAME = "tf_model.h5" def convert_command_factory(args: Namespace): @@ -58,7 +57,13 @@ def convert_command_factory(args: Namespace): Returns: ServeCommand """ return PTtoTFCommand( - args.model_name, args.local_dir, args.new_weights, args.no_pr, args.push, args.extra_commit_description + args.model_name, + args.local_dir, + args.max_hidden_error, + args.new_weights, + args.no_pr, + args.push, + args.extra_commit_description, ) @@ -90,6 +95,15 @@ def register_subcommand(parser: ArgumentParser): default="", help="Optional local directory of the model repository. Defaults to /tmp/{model_name}", ) + train_parser.add_argument( + "--max-hidden-error", + type=float, + default=MAX_ERROR, + help=( + f"Maximum error tolerance for hidden layer outputs. Defaults to {MAX_ERROR}. If you suspect the hidden" + " layers outputs will be used for downstream applications, avoid increasing this tolerance." + ), + ) train_parser.add_argument( "--new-weights", action="store_true", @@ -112,14 +126,10 @@ def register_subcommand(parser: ArgumentParser): train_parser.set_defaults(func=convert_command_factory) @staticmethod - def find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input): + def find_pt_tf_differences(pt_outputs, tf_outputs): """ - Compares the TensorFlow and PyTorch models, given their inputs, returning a dictionary with all tensor - differences. + Compares the TensorFlow and PyTorch outputs, returning a dictionary with all tensor differences. """ - pt_outputs = pt_model(**pt_input, output_hidden_states=True) - tf_outputs = tf_model(**tf_input, output_hidden_states=True) - # 1. All output attributes must be the same pt_out_attrs = set(pt_outputs.keys()) tf_out_attrs = set(tf_outputs.keys()) @@ -158,6 +168,7 @@ def __init__( self, model_name: str, local_dir: str, + max_hidden_error: float, new_weights: bool, no_pr: bool, push: bool, @@ -167,6 +178,7 @@ def __init__( self._logger = logging.get_logger("transformers-cli/pt_to_tf") self._model_name = model_name self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name) + self._max_hidden_error = max_hidden_error self._new_weights = new_weights self._no_pr = no_pr self._push = push @@ -260,34 +272,49 @@ def run(self): pt_model = pt_class.from_pretrained(self._local_dir) tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True) pt_input, tf_input = self.get_inputs(pt_model, config) + pt_outputs = pt_model(**pt_input, output_hidden_states=True) + del pt_model # will no longer be used, and may have a large memory footprint + + tf_from_pt_model = tf_class.from_pretrained(self._local_dir, from_pt=True) + tf_from_pt_outputs = tf_from_pt_model(**tf_input, output_hidden_states=True) # Confirms that cross loading PT weights into TF worked. - crossload_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_from_pt_model, tf_input) - max_crossload_diff = max(crossload_differences.values()) - if max_crossload_diff > MAX_ERROR: + crossload_differences = self.find_pt_tf_differences(pt_outputs, tf_from_pt_outputs) + output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k} + hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k} + max_crossload_output_diff = max(output_differences.values()) + max_crossload_hidden_diff = max(hidden_differences.values()) + if max_crossload_output_diff > MAX_ERROR or max_crossload_hidden_diff > self._max_hidden_error: raise ValueError( - "The cross-loaded TensorFlow model has different outputs, something went wrong! Exaustive list of" - f" maximum tensor differences above the error threshold ({MAX_ERROR}):\n" - + "\n".join( - [f"{key}: {value:.3e}" for key, value in crossload_differences.items() if value > MAX_ERROR] - ) + "The cross-loaded TensorFlow model has different outputs, something went wrong!\n" + + f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n" + + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR]) + + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n" + + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error]) ) # Save the weights in a TF format (if needed) and confirms that the results are still good - tf_weights_path = os.path.join(self._local_dir, TF_WEIGHTS_NAME) - if not os.path.exists(tf_weights_path) or self._new_weights: - tf_from_pt_model.save_weights(tf_weights_path) + tf_weights_path = os.path.join(self._local_dir, TF2_WEIGHTS_NAME) + tf_weights_index_path = os.path.join(self._local_dir, TF2_WEIGHTS_INDEX_NAME) + if (not os.path.exists(tf_weights_path) and not os.path.exists(tf_weights_index_path)) or self._new_weights: + tf_from_pt_model.save_pretrained(self._local_dir) del tf_from_pt_model # will no longer be used, and may have a large memory footprint + tf_model = tf_class.from_pretrained(self._local_dir) - conversion_differences = self.find_pt_tf_differences(pt_model, pt_input, tf_model, tf_input) - max_conversion_diff = max(conversion_differences.values()) - if max_conversion_diff > MAX_ERROR: + tf_outputs = tf_model(**tf_input, output_hidden_states=True) + + conversion_differences = self.find_pt_tf_differences(pt_outputs, tf_outputs) + output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k} + hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k} + max_conversion_output_diff = max(output_differences.values()) + max_conversion_hidden_diff = max(hidden_differences.values()) + if max_conversion_output_diff > MAX_ERROR or max_conversion_hidden_diff > self._max_hidden_error: raise ValueError( - "The converted TensorFlow model has different outputs, something went wrong! Exaustive list of maximum" - f" tensor differences above the error threshold ({MAX_ERROR}):\n" - + "\n".join( - [f"{key}: {value:.3e}" for key, value in conversion_differences.items() if value > MAX_ERROR] - ) + "The converted TensorFlow model has different outputs, something went wrong!\n" + + f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n" + + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR]) + + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n" + + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error]) ) commit_message = "Update TF weights" if self._new_weights else "Add TF weights" @@ -300,16 +327,31 @@ def run(self): self._logger.warn("Uploading the weights into a new PR...") commit_descrition = ( "Model converted by the [`transformers`' `pt_to_tf`" - " CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py)." - "\n\nAll converted model outputs and hidden layers were validated against its Pytorch counterpart." - f" Maximum crossload output difference={max_crossload_diff:.3e}; Maximum converted output" - f" difference={max_conversion_diff:.3e}." + " CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_tf.py). " + "All converted model outputs and hidden layers were validated against its Pytorch counterpart.\n\n" + f"Maximum crossload output difference={max_crossload_output_diff:.3e}; " + f"Maximum crossload hidden layer difference={max_crossload_hidden_diff:.3e};\n" + f"Maximum conversion output difference={max_conversion_output_diff:.3e}; " + f"Maximum conversion hidden layer difference={max_conversion_hidden_diff:.3e};\n" ) if self._extra_commit_description: commit_descrition += "\n\n" + self._extra_commit_description + + # sharded model -> adds all related files (index and .h5 shards) + if os.path.exists(tf_weights_index_path): + operations = [ + CommitOperationAdd(path_in_repo=TF2_WEIGHTS_INDEX_NAME, path_or_fileobj=tf_weights_index_path) + ] + for shard_path in tf.io.gfile.glob(self._local_dir + "/tf_model-*.h5"): + operations += [ + CommitOperationAdd(path_in_repo=os.path.basename(shard_path), path_or_fileobj=shard_path) + ] + else: + operations = [CommitOperationAdd(path_in_repo=TF2_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)] + hub_pr_url = create_commit( repo_id=self._model_name, - operations=[CommitOperationAdd(path_in_repo=TF_WEIGHTS_NAME, path_or_fileobj=tf_weights_path)], + operations=operations, commit_message=commit_message, commit_description=commit_descrition, repo_type="model", diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py index b5f51a45dc18..73d6a7613fda 100644 --- a/src/transformers/modeling_tf_pytorch_utils.py +++ b/src/transformers/modeling_tf_pytorch_utils.py @@ -117,10 +117,17 @@ def load_pytorch_checkpoint_in_tf2_model( ) raise - pt_path = os.path.abspath(pytorch_checkpoint_path) - logger.info(f"Loading PyTorch weights from {pt_path}") + # Treats a single file as a collection of shards with 1 shard. + if isinstance(pytorch_checkpoint_path, str): + pytorch_checkpoint_path = [pytorch_checkpoint_path] + + # Loads all shards into a single state dictionary + pt_state_dict = {} + for path in pytorch_checkpoint_path: + pt_path = os.path.abspath(path) + logger.info(f"Loading PyTorch weights from {pt_path}") + pt_state_dict.update(torch.load(pt_path, map_location="cpu")) - pt_state_dict = torch.load(pt_path, map_location="cpu") logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters") return load_pytorch_weights_in_tf2_model( diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index adfec47e8eb6..f678a4888177 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -50,6 +50,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, WEIGHTS_NAME, EntryNotFoundError, ModelOutput, @@ -2157,11 +2158,15 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): # Load from a PyTorch checkpoint in priority if from_pt archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + is_sharded = True elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)): # Load from a TF 2.0 checkpoint archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME) elif os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME)): - # Load from a sharded PyTorch checkpoint + # Load from a sharded TF 2.0 checkpoint archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_INDEX_NAME) is_sharded = True # At this stage we don't have a weight file so we will raise an error. diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 545dae8fbff8..27e9a7823ac4 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -27,7 +27,7 @@ from datasets import Dataset -from huggingface_hub import HfFolder, delete_repo, set_access_token +from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token from requests.exceptions import HTTPError from transformers import is_tf_available, is_torch_available from transformers.configuration_utils import PretrainedConfig @@ -1966,6 +1966,16 @@ def test_checkpoint_sharding_from_hub(self): for p1, p2 in zip(model.weights, ref_model.weights): assert np.allclose(p1.numpy(), p2.numpy()) + @is_pt_tf_cross_test + def test_checkpoint_sharding_local_from_pt(self): + with tempfile.TemporaryDirectory() as tmp_dir: + _ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded") + model = TFBertModel.from_pretrained(tmp_dir, from_pt=True) + # the model above is the same as the model below, just a sharded pytorch version. + ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert") + for p1, p2 in zip(model.weights, ref_model.weights): + assert np.allclose(p1.numpy(), p2.numpy()) + def test_shard_checkpoint(self): # This is the model we will use, total size 340,000 bytes. model = tf.keras.Sequential(