From 0a0d875711b5f50136e97eb8437a0e4162094b70 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 15:03:26 +0200 Subject: [PATCH 01/12] initial commit --- .../modeling_flax_pytorch_utils.py | 72 ++++++++++++++++--- src/transformers/modeling_flax_utils.py | 17 ++++- src/transformers/utils/hub.py | 5 ++ 3 files changed, 84 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index a91d41b9d6d9..7120fad64103 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -27,7 +27,8 @@ from flax.serialization import from_bytes from flax.traverse_util import flatten_dict, unflatten_dict -from .utils import logging +from .utils import FLAX_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, logging +from .utils.hub import get_checkpoint_shard_files logger = logging.get_logger(__name__) @@ -38,7 +39,9 @@ ##################### -def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False): +def load_pytorch_checkpoint_in_flax_state_dict( + flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False +): """Load pytorch checkpoints in a flax model""" try: import torch # noqa: F401 @@ -50,14 +53,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa ) raise - pt_path = os.path.abspath(pytorch_checkpoint_path) - logger.info(f"Loading PyTorch weights from {pt_path}") + if not is_sharded: + pt_path = os.path.abspath(pytorch_checkpoint_path) + logger.info(f"Loading PyTorch weights from {pt_path}") - pt_state_dict = torch.load(pt_path, map_location="cpu") - logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") - - flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) + pt_state_dict = torch.load(pt_path, map_location="cpu") + logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") + flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) + else: + # model is sharded and pytorch_checkpoint_path already contains the list of shard files + flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model) return flax_state_dict @@ -156,6 +162,56 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): return unflatten_dict(flax_state_dict) +def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): + import torch + + # Load the index + flax_state_dict = {} + for shard_file in shard_filenames: + # load using msgpack utils + pt_state_dict = torch.load(shard_file) + pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} + + model_prefix = flax_model.base_model_prefix + random_flax_state_dict = flatten_dict(flax_model.params) + + load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and ( + model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()]) + ) + load_base_model_into_model_with_head = (model_prefix in flax_model.params) and ( + model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()]) + ) + # Need to change some parameters name to match Flax names + for pt_key, pt_tensor in pt_state_dict.items(): + + pt_tuple_key = tuple(pt_key.split(".")) + + # remove base model prefix if necessary + has_base_model_prefix = pt_tuple_key[0] == model_prefix + if load_model_with_head_into_base_model and has_base_model_prefix: + 
pt_tuple_key = pt_tuple_key[1:] + + # Correctly rename weight parameters + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix + ) + # add model prefix if necessary + require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict + if load_base_model_into_model_with_head and require_base_model_prefix: + flax_key = (model_prefix,) + flax_key + + if flax_key in random_flax_state_dict: + if flax_tensor.shape != random_flax_state_dict[flax_key].shape: + raise ValueError( + f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " + f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + + # also add unexpected weight so that warning is thrown + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + return unflatten_dict(flax_state_dict) + + ##################### # Flax => PyTorch # ##################### diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 77eaa900de62..2aa69691abef 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -42,6 +42,7 @@ FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME, HUGGINGFACE_CO_RESOLVE_ENDPOINT, + WEIGHTS_INDEX_NAME, WEIGHTS_NAME, EntryNotFoundError, PushToHubMixin, @@ -58,7 +59,7 @@ logging, replace_return_docstrings, ) -from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files +from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files, hf_url_exists logger = logging.get_logger(__name__) @@ -639,6 +640,10 @@ def from_pretrained( if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): # Load from a PyTorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)): + # Load from a sharded pytorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + is_sharded = True elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): # Load from a Flax checkpoint archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) @@ -660,6 +665,14 @@ def from_pretrained( ) elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path + # check if an index file exists + elif hf_url_exists(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME): + archive_file = hf_bucket_url( + pretrained_model_name_or_path, + filename=WEIGHTS_INDEX_NAME, + revision=revision, + ) + is_sharded = True else: filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME archive_file = hf_bucket_url( @@ -780,7 +793,7 @@ def from_pretrained( model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) if from_pt: - state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file) + state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded) else: if is_sharded: diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 4e46298e28a9..91d7ea6f65b9 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -118,6 +118,11 @@ def is_remote_url(url_or_filename): return parsed.scheme in ("http", "https") +def hf_url_exists(path, file): + r = requests.head(os.path.join("https://huggingface.co/", path, "raw/main/", file)) + return 
r.status_code == requests.codes.ok + + def hf_bucket_url( model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None ) -> str: From 4dcc32ec0e0b9a20dcc72582225bf6b9f0426eb1 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 15:13:10 +0200 Subject: [PATCH 02/12] add small test --- tests/test_modeling_flax_common.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index f90615efea36..b4506c4438a9 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1099,6 +1099,12 @@ def test_checkpoint_sharding_local(self): for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()): self.assertTrue(np.allclose(np.array(p1), np.array(p2))) + def test_from_sharded_pt(self): + model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) + ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()): + assert np.allclose(np.array(p1), np.array(p2)) + def test_gradient_checkpointing(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 556fa9a3921cc0a4d3e14072b63f75bc94a60393 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 17:38:47 +0200 Subject: [PATCH 03/12] add cross pt tf flag to test --- tests/test_modeling_flax_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index b4506c4438a9..f33e0fb5d9e7 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1099,6 +1099,7 @@ def test_checkpoint_sharding_local(self): for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()): self.assertTrue(np.allclose(np.array(p1), np.array(p2))) + @is_pt_flax_cross_test def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") From 2e8b241c5eb3eaea711de1e1084164fc8974b71d Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 17:41:12 +0200 Subject: [PATCH 04/12] fix quality --- src/transformers/modeling_flax_pytorch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 7120fad64103..dadc98fc9cda 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -27,8 +27,8 @@ from flax.serialization import from_bytes from flax.traverse_util import flatten_dict, unflatten_dict -from .utils import FLAX_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, logging -from .utils.hub import get_checkpoint_shard_files +from .utils import logging + logger = logging.get_logger(__name__) From cfda646e93059b0f791a9835009475f9ed7956c3 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 17:41:47 +0200 Subject: [PATCH 05/12] style --- src/transformers/modeling_flax_pytorch_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index dadc98fc9cda..249f754db726 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ 
b/src/transformers/modeling_flax_pytorch_utils.py @@ -30,7 +30,6 @@ from .utils import logging - logger = logging.get_logger(__name__) From 4e49a48d225e5d0010c9a074cdbc1e93fffb93da Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 22:31:44 +0200 Subject: [PATCH 06/12] update test with new repo --- tests/test_modeling_flax_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index f33e0fb5d9e7..8e9b27c51530 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1102,7 +1102,7 @@ def test_checkpoint_sharding_local(self): @is_pt_flax_cross_test def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) - ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-flax-only") for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()): assert np.allclose(np.array(p1), np.array(p2)) From 74a451b8b333989c797198d72a98f098ccdd7e47 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 2 Aug 2022 09:44:44 +0200 Subject: [PATCH 07/12] fix failing test --- tests/test_modeling_flax_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 8e9b27c51530..3f0a8c6c8079 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1102,7 +1102,7 @@ def test_checkpoint_sharding_local(self): @is_pt_flax_cross_test def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) - ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-flax-only") + ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-fx-only") for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()): assert np.allclose(np.array(p1), np.array(p2)) From 708bcb6f4de4eb1af405f08cf043c12e1cb9788d Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 2 Aug 2022 09:45:31 +0200 Subject: [PATCH 08/12] update --- tests/test_modeling_flax_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 3f0a8c6c8079..5bc7d97d46bc 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1102,7 +1102,7 @@ def test_checkpoint_sharding_local(self): @is_pt_flax_cross_test def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) - ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-fx-only") + ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only") for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()): assert np.allclose(np.array(p1), np.array(p2)) From c49bb18d588e21c4e18e68030396e3e0cb384f78 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 2 Aug 2022 10:07:15 +0200 Subject: [PATCH 09/12] fix wrong param ordering --- tests/test_modeling_flax_common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 5e9c061bc5b7..49505856954c 100644 --- a/tests/test_modeling_flax_common.py +++ 
b/tests/test_modeling_flax_common.py @@ -1103,8 +1103,9 @@ def test_checkpoint_sharding_local(self): def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only") - for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()): - assert np.allclose(np.array(p1), np.array(p2)) + for key,ref_val in flatten_dict(ref_model.params).items(): + val = flatten_dict(model.params)[key] + assert np.allclose(np.array(val), np.array(ref_val)) def test_gradient_checkpointing(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From e20f971f6c83a770dea9ff17ec597587891e8690 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 2 Aug 2022 10:07:37 +0200 Subject: [PATCH 10/12] style --- tests/test_modeling_flax_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 49505856954c..837f874889ae 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1103,7 +1103,7 @@ def test_checkpoint_sharding_local(self): def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only") - for key,ref_val in flatten_dict(ref_model.params).items(): + for key, ref_val in flatten_dict(ref_model.params).items(): val = flatten_dict(model.params)[key] assert np.allclose(np.array(val), np.array(ref_val)) From 49b2787225a28a7437f8c167151ad4cd15357da2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 20 Oct 2022 09:15:12 +0000 Subject: [PATCH 11/12] clean --- src/transformers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 4929a90cf877..c53539934e0f 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -52,7 +52,7 @@ logging, replace_return_docstrings, ) -from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files, hf_url_exists +from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files logger = logging.get_logger(__name__) From 8baabeb518c50c0d92c140c121839ca9e551d619 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 11 Aug 2022 14:17:12 +0200 Subject: [PATCH 12/12] add pt_fo_flax cmd --- src/transformers/commands/pt_to_flax.py | 364 ++++++++++++++++++++++++ 1 file changed, 364 insertions(+) create mode 100644 src/transformers/commands/pt_to_flax.py diff --git a/src/transformers/commands/pt_to_flax.py b/src/transformers/commands/pt_to_flax.py new file mode 100644 index 000000000000..c6d8792d87d1 --- /dev/null +++ b/src/transformers/commands/pt_to_flax.py @@ -0,0 +1,364 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import os
+from argparse import ArgumentParser, Namespace
+from glob import glob
+from importlib import import_module
+
+import numpy as np
+from datasets import load_dataset
+from packaging import version
+
+import huggingface_hub
+
+from .. import (
+    FEATURE_EXTRACTOR_MAPPING,
+    PROCESSOR_MAPPING,
+    TOKENIZER_MAPPING,
+    AutoConfig,
+    AutoFeatureExtractor,
+    AutoProcessor,
+    AutoTokenizer,
+    is_flax_available,
+    is_torch_available,
+)
+from ..utils import FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME, logging
+from . import BaseTransformersCLICommand
+
+
+if is_flax_available():
+    import jax.numpy as jnp
+if is_torch_available():
+    import torch
+
+
+MAX_ERROR = 5e-5  # larger error tolerance than in our internal tests, to avoid flaky user-facing errors
+
+
+def convert_command_factory(args: Namespace):
+    """
+    Factory function used to convert a PyTorch model checkpoint into a Flax checkpoint.
+
+    Returns: PTtoFXCommand
+    """
+    return PTtoFXCommand(
+        args.model_name,
+        args.local_dir,
+        args.max_hidden_error,
+        args.new_weights,
+        args.no_pr,
+        args.push,
+        args.extra_commit_description,
+    )
+
+
+class PTtoFXCommand(BaseTransformersCLICommand):
+    @staticmethod
+    def register_subcommand(parser: ArgumentParser):
+        """
+        Register this command to argparse so it's available for the transformers-cli
+
+        Args:
+            parser: Root parser to register command-specific arguments
+        """
+        train_parser = parser.add_parser(
+            "pt-to-fx",
+            help=(
+                "CLI tool to convert a transformers model from a PyTorch checkpoint to a Flax checkpoint."
+                " Can also be used to validate existing weights without opening PRs, with --no-pr."
+            ),
+        )
+        train_parser.add_argument(
+            "--model-name",
+            type=str,
+            required=True,
+            help="The model name, including owner/organization, as seen on the hub.",
+        )
+        train_parser.add_argument(
+            "--local-dir",
+            type=str,
+            default="",
+            help="Optional local directory of the model repository. Defaults to /tmp/{model_name}",
+        )
+        train_parser.add_argument(
+            "--max-hidden-error",
+            type=float,
+            default=MAX_ERROR,
+            help=(
+                f"Maximum error tolerance for hidden layer outputs. Defaults to {MAX_ERROR}. If you suspect the hidden"
+                " layer outputs will be used for downstream applications, avoid increasing this tolerance."
+            ),
+        )
+        train_parser.add_argument(
+            "--new-weights",
+            action="store_true",
+            help="Optional flag to create new Flax weights, even if they already exist.",
+        )
+        train_parser.add_argument(
+            "--no-pr", action="store_true", help="Optional flag to NOT open a PR with converted weights."
+        )
+        train_parser.add_argument(
+            "--push",
+            action="store_true",
+            help="Optional flag to push the weights directly to `main` (requires permissions)",
+        )
+        train_parser.add_argument(
+            "--extra-commit-description",
+            type=str,
+            default="",
+            help="Optional additional commit description to use when opening a PR (e.g. to tag the owner).",
+        )
+        train_parser.set_defaults(func=convert_command_factory)
+
+    @staticmethod
+    def find_pt_fx_differences(pt_outputs, fx_outputs):
+        """
+        Compares the Flax and PyTorch outputs, returning a dictionary with all tensor differences.
+        """
+        # 1. All output attributes must be the same
+        pt_out_attrs = set(pt_outputs.keys())
+        fx_out_attrs = set(fx_outputs.keys())
+        if pt_out_attrs != fx_out_attrs:
+            raise ValueError(
+                f"The model outputs have different attributes, aborting. 
(Pytorch: {pt_out_attrs}, Flax:" + f" {fx_out_attrs})" + ) + + # 2. For each output attribute, computes the difference + def _find_pt_fx_differences(pt_out, fx_out, differences, attr_name=""): + + # If the current attribute is a tensor, it is a leaf and we make the comparison. Otherwise, we will dig in + # recursivelly, keeping the name of the attribute. + if isinstance(pt_out, torch.Tensor): + tensor_difference = np.max(np.abs(pt_out.numpy() - fx_out.numpy())) + differences[attr_name] = tensor_difference + else: + root_name = attr_name + for i, pt_item in enumerate(pt_out): + # If it is a named attribute, we keep the name. Otherwise, just its index. + if isinstance(pt_item, str): + branch_name = root_name + pt_item + fx_item = fx_out[pt_item] + pt_item = pt_out[pt_item] + else: + branch_name = root_name + f"[{i}]" + fx_item = fx_out[i] + differences = _find_pt_fx_differences(pt_item, fx_item, differences, branch_name) + + return differences + + return _find_pt_fx_differences(pt_outputs, fx_outputs, {}) + + def __init__( + self, + model_name: str, + local_dir: str, + max_hidden_error: float, + new_weights: bool, + no_pr: bool, + push: bool, + extra_commit_description: str, + *args + ): + self._logger = logging.get_logger("transformers-cli/pt_to_fx") + self._model_name = model_name + self._local_dir = local_dir if local_dir else os.path.join("/tmp", model_name) + self._max_hidden_error = max_hidden_error + self._new_weights = new_weights + self._no_pr = no_pr + self._push = push + self._extra_commit_description = extra_commit_description + + def get_inputs(self, pt_model, config): + """ + Returns the right inputs for the model, based on its signature. + """ + + def _get_audio_input(): + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + speech_samples = ds.sort("id").select(range(2))[:2]["audio"] + raw_samples = [x["array"] for x in speech_samples] + return raw_samples + + model_forward_signature = set(inspect.signature(pt_model.forward).parameters.keys()) + processor_inputs = {} + if "input_ids" in model_forward_signature: + processor_inputs.update( + { + "text": ["Hi there!", "I am a batch with more than one row and different input lengths."], + "padding": True, + "truncation": True, + } + ) + if "pixel_values" in model_forward_signature: + sample_images = load_dataset("cifar10", "plain_text", split="test")[:2]["img"] + processor_inputs.update({"images": sample_images}) + if "input_features" in model_forward_signature: + processor_inputs.update({"raw_speech": _get_audio_input(), "padding": True}) + if "input_values" in model_forward_signature: # Wav2Vec2 audio input + processor_inputs.update({"raw_speech": _get_audio_input(), "padding": True}) + + model_config_class = type(pt_model.config) + if model_config_class in PROCESSOR_MAPPING: + processor = AutoProcessor.from_pretrained(self._local_dir) + if model_config_class in TOKENIZER_MAPPING and processor.tokenizer.pad_token is None: + processor.tokenizer.pad_token = processor.tokenizer.eos_token + elif model_config_class in FEATURE_EXTRACTOR_MAPPING: + processor = AutoFeatureExtractor.from_pretrained(self._local_dir) + elif model_config_class in TOKENIZER_MAPPING: + processor = AutoTokenizer.from_pretrained(self._local_dir) + if processor.pad_token is None: + processor.pad_token = processor.eos_token + else: + raise ValueError(f"Unknown data processing type (model config type: {model_config_class})") + + pt_input = processor(**processor_inputs, return_tensors="pt") + fx_input = 
processor(**processor_inputs, return_tensors="np") + + # Extra input requirements, in addition to the input modality + if config.is_encoder_decoder or (hasattr(pt_model, "encoder") and hasattr(pt_model, "decoder")): + decoder_input_ids = np.asarray([[1], [1]], dtype=int) * (pt_model.config.decoder_start_token_id or 0) + pt_input.update({"decoder_input_ids": torch.tensor(decoder_input_ids)}) + fx_input.update({"decoder_input_ids": jnp.array(decoder_input_ids)}) + + return pt_input, fx_input + + def run(self): + if version.parse(huggingface_hub.__version__) < version.parse("0.8.1"): + raise ImportError( + "The huggingface_hub version must be >= 0.8.1 to use this command. Please update your huggingface_hub" + " installation." + ) + else: + from huggingface_hub import Repository, create_commit + from huggingface_hub._commit_api import CommitOperationAdd + + # Fetch remote data + repo = Repository(local_dir=self._local_dir, clone_from=self._model_name) + + # Load config and get the appropriate architecture -- the latter is needed to convert the head's weights + config = AutoConfig.from_pretrained(self._local_dir) + architectures = config.architectures + if architectures is None: # No architecture defined -- use auto classes + pt_class = getattr(import_module("transformers"), "AutoModel") + fx_class = getattr(import_module("transformers"), "FlaxAutoModel") + self._logger.warn("No detected architecture, using AutoModel/FlaxAutoModel") + else: # Architecture defined -- use it + if len(architectures) > 1: + raise ValueError(f"More than one architecture was found, aborting. (architectures = {architectures})") + self._logger.warn(f"Detected architecture: {architectures[0]}") + pt_class = getattr(import_module("transformers"), architectures[0]) + try: + fx_class = getattr(import_module("transformers"), "Flax" + architectures[0]) + except AttributeError: + raise AttributeError(f"The Flax equivalent of {architectures[0]} doesn't exist in transformers.") + + # Load models and acquire a basic input compatible with the model. + pt_model = pt_class.from_pretrained(self._local_dir) + pt_model.eval() + + fx_from_pt_model = fx_class.from_pretrained( + self._local_dir, from_pt=True + ) # now also works for sharded checkpoints + pt_input, fx_input = self.get_inputs(pt_model, config) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_input, output_hidden_states=True) + del pt_model # will no longer be used, and may have a large memory footprint + + fx_from_pt_model = fx_class.from_pretrained(self._local_dir, from_pt=True) + fx_from_pt_outputs = fx_from_pt_model(**fx_input, output_hidden_states=True) + + # Confirms that cross loading PT weights into FX worked. 
+        crossload_differences = self.find_pt_fx_differences(pt_outputs, fx_from_pt_outputs)
+        output_differences = {k: v for k, v in crossload_differences.items() if "hidden" not in k}
+        hidden_differences = {k: v for k, v in crossload_differences.items() if "hidden" in k}
+        max_crossload_output_diff = max(output_differences.values())
+        max_crossload_hidden_diff = max(hidden_differences.values())
+        if max_crossload_output_diff > MAX_ERROR or max_crossload_hidden_diff > self._max_hidden_error:
+            raise ValueError(
+                "The cross-loaded Flax model has different outputs, something went wrong!\n"
+                + f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n"
+                + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR])
+                + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n"
+                + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error])
+            )
+
+        # Save the weights in FX format (if needed) and confirm that the results are still good
+        fx_weights_path = os.path.join(self._local_dir, FLAX_WEIGHTS_NAME)
+        fx_weights_index_path = os.path.join(self._local_dir, FLAX_WEIGHTS_INDEX_NAME)
+        if (not os.path.exists(fx_weights_path) and not os.path.exists(fx_weights_index_path)) or self._new_weights:
+            fx_from_pt_model.save_pretrained(self._local_dir)
+        del fx_from_pt_model  # will no longer be used, and may have a large memory footprint
+
+        fx_model = fx_class.from_pretrained(self._local_dir)
+        fx_outputs = fx_model(**fx_input, output_hidden_states=True)
+
+        conversion_differences = self.find_pt_fx_differences(pt_outputs, fx_outputs)
+        output_differences = {k: v for k, v in conversion_differences.items() if "hidden" not in k}
+        hidden_differences = {k: v for k, v in conversion_differences.items() if "hidden" in k}
+        max_conversion_output_diff = max(output_differences.values())
+        max_conversion_hidden_diff = max(hidden_differences.values())
+        if max_conversion_output_diff > MAX_ERROR or max_conversion_hidden_diff > self._max_hidden_error:
+            raise ValueError(
+                "The converted Flax model has different outputs, something went wrong!\n"
+                + f"\nList of maximum output differences above the threshold ({MAX_ERROR}):\n"
+                + "\n".join([f"{k}: {v:.3e}" for k, v in output_differences.items() if v > MAX_ERROR])
+                + f"\n\nList of maximum hidden layer differences above the threshold ({self._max_hidden_error}):\n"
+                + "\n".join([f"{k}: {v:.3e}" for k, v in hidden_differences.items() if v > self._max_hidden_error])
+            )
+
+        commit_message = "Update FX weights" if self._new_weights else "Add FX weights"
+        if self._push:
+            repo.git_add(auto_lfs_track=True)
+            repo.git_commit(commit_message)
+            repo.git_push(blocking=True)  # this prints a progress bar with the upload
+            self._logger.warn(f"FX weights pushed into {self._model_name}")
+        elif not self._no_pr:
+            self._logger.warn("Uploading the weights into a new PR...")
+            commit_description = (
+                "Model converted by the [`transformers`' `pt_to_fx`"
+                " CLI](https://github.com/huggingface/transformers/blob/main/src/transformers/commands/pt_to_fx.py). "
+                "All converted model outputs and hidden layers were validated against its PyTorch counterpart.\n\n"
+                f"Maximum crossload output difference={max_crossload_output_diff:.3e}; "
+                f"Maximum crossload hidden layer difference={max_crossload_hidden_diff:.3e};\n"
+                f"Maximum conversion output difference={max_conversion_output_diff:.3e}; "
+                f"Maximum conversion hidden layer difference={max_conversion_hidden_diff:.3e};\n"
+            )
+            if self._extra_commit_description:
+                commit_description += "\n\n" + self._extra_commit_description
+
+            # sharded model -> adds all related files (index and .msgpack shards)
+            if os.path.exists(fx_weights_index_path):
+                operations = [
+                    CommitOperationAdd(path_in_repo=FLAX_WEIGHTS_INDEX_NAME, path_or_fileobj=fx_weights_index_path)
+                ]
+                for shard_path in glob(self._local_dir + "/flax_model-*.msgpack"):
+                    operations += [
+                        CommitOperationAdd(path_in_repo=os.path.basename(shard_path), path_or_fileobj=shard_path)
+                    ]
+            else:
+                operations = [CommitOperationAdd(path_in_repo=FLAX_WEIGHTS_NAME, path_or_fileobj=fx_weights_path)]
+
+            hub_pr_url = create_commit(
+                repo_id=self._model_name,
+                operations=operations,
+                commit_message=commit_message,
+                commit_description=commit_description,
+                repo_type="model",
+                create_pr=True,
+            )
+            self._logger.warn(f"PR open in {hub_pr_url}")