From 0a0d875711b5f50136e97eb8437a0e4162094b70 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 15:03:26 +0200 Subject: [PATCH 01/10] initial commit --- .../modeling_flax_pytorch_utils.py | 72 ++++++++++++++++--- src/transformers/modeling_flax_utils.py | 17 ++++- src/transformers/utils/hub.py | 5 ++ 3 files changed, 84 insertions(+), 10 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index a91d41b9d6d9..7120fad64103 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -27,7 +27,8 @@ from flax.serialization import from_bytes from flax.traverse_util import flatten_dict, unflatten_dict -from .utils import logging +from .utils import FLAX_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, logging +from .utils.hub import get_checkpoint_shard_files logger = logging.get_logger(__name__) @@ -38,7 +39,9 @@ ##################### -def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, allow_missing_keys=False): +def load_pytorch_checkpoint_in_flax_state_dict( + flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False +): """Load pytorch checkpoints in a flax model""" try: import torch # noqa: F401 @@ -50,14 +53,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_pa ) raise - pt_path = os.path.abspath(pytorch_checkpoint_path) - logger.info(f"Loading PyTorch weights from {pt_path}") + if not is_sharded: + pt_path = os.path.abspath(pytorch_checkpoint_path) + logger.info(f"Loading PyTorch weights from {pt_path}") - pt_state_dict = torch.load(pt_path, map_location="cpu") - logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") - - flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) + pt_state_dict = torch.load(pt_path, map_location="cpu") + logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.") + flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model) + else: + # model is sharded and pytorch_checkpoint_path already contains the list of shard files + flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model) return flax_state_dict @@ -156,6 +162,56 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): return unflatten_dict(flax_state_dict) +def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model): + import torch + + # Load the index + flax_state_dict = {} + for shard_file in shard_filenames: + # load using msgpack utils + pt_state_dict = torch.load(shard_file) + pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} + + model_prefix = flax_model.base_model_prefix + random_flax_state_dict = flatten_dict(flax_model.params) + + load_model_with_head_into_base_model = (model_prefix not in flax_model.params) and ( + model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()]) + ) + load_base_model_into_model_with_head = (model_prefix in flax_model.params) and ( + model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()]) + ) + # Need to change some parameters name to match Flax names + for pt_key, pt_tensor in pt_state_dict.items(): + + pt_tuple_key = tuple(pt_key.split(".")) + + # remove base model prefix if necessary + has_base_model_prefix = pt_tuple_key[0] == model_prefix + if load_model_with_head_into_base_model and has_base_model_prefix: + pt_tuple_key = pt_tuple_key[1:] + + # Correctly rename weight parameters + flax_key, flax_tensor = rename_key_and_reshape_tensor( + pt_tuple_key, pt_tensor, random_flax_state_dict, model_prefix + ) + # add model prefix if necessary + require_base_model_prefix = (model_prefix,) + flax_key in random_flax_state_dict + if load_base_model_into_model_with_head and require_base_model_prefix: + flax_key = (model_prefix,) + flax_key + + if flax_key in random_flax_state_dict: + if flax_tensor.shape != random_flax_state_dict[flax_key].shape: + raise ValueError( + f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape " + f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}." + ) + + # also add unexpected weight so that warning is thrown + flax_state_dict[flax_key] = jnp.asarray(flax_tensor) + return unflatten_dict(flax_state_dict) + + ##################### # Flax => PyTorch # ##################### diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 77eaa900de62..2aa69691abef 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -42,6 +42,7 @@ FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME, HUGGINGFACE_CO_RESOLVE_ENDPOINT, + WEIGHTS_INDEX_NAME, WEIGHTS_NAME, EntryNotFoundError, PushToHubMixin, @@ -58,7 +59,7 @@ logging, replace_return_docstrings, ) -from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files +from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files, hf_url_exists logger = logging.get_logger(__name__) @@ -639,6 +640,10 @@ def from_pretrained( if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): # Load from a PyTorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)): + # Load from a sharded pytorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + is_sharded = True elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): # Load from a Flax checkpoint archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) @@ -660,6 +665,14 @@ def from_pretrained( ) elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): archive_file = pretrained_model_name_or_path + # check if an index file exists + elif hf_url_exists(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME): + archive_file = hf_bucket_url( + pretrained_model_name_or_path, + filename=WEIGHTS_INDEX_NAME, + revision=revision, + ) + is_sharded = True else: filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME archive_file = hf_bucket_url( @@ -780,7 +793,7 @@ def from_pretrained( model = cls(config, *model_args, _do_init=_do_init, **model_kwargs) if from_pt: - state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file) + state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded) else: if is_sharded: diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 4e46298e28a9..91d7ea6f65b9 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -118,6 +118,11 @@ def is_remote_url(url_or_filename): return parsed.scheme in ("http", "https") +def hf_url_exists(path, file): + r = requests.head(os.path.join("https://huggingface.co/", path, "raw/main/", file)) + return r.status_code == requests.codes.ok + + def hf_bucket_url( model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None ) -> str: From 4dcc32ec0e0b9a20dcc72582225bf6b9f0426eb1 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 15:13:10 +0200 Subject: [PATCH 02/10] add small test --- tests/test_modeling_flax_common.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index f90615efea36..b4506c4438a9 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1099,6 +1099,12 @@ def test_checkpoint_sharding_local(self): for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()): self.assertTrue(np.allclose(np.array(p1), np.array(p2))) + def test_from_sharded_pt(self): + model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) + ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()): + assert np.allclose(np.array(p1), np.array(p2)) + def test_gradient_checkpointing(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From 556fa9a3921cc0a4d3e14072b63f75bc94a60393 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 17:38:47 +0200 Subject: [PATCH 03/10] add cross pt tf flag to test --- tests/test_modeling_flax_common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index b4506c4438a9..f33e0fb5d9e7 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1099,6 +1099,7 @@ def test_checkpoint_sharding_local(self): for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()): self.assertTrue(np.allclose(np.array(p1), np.array(p2))) + @is_pt_flax_cross_test def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") From 2e8b241c5eb3eaea711de1e1084164fc8974b71d Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 17:41:12 +0200 Subject: [PATCH 04/10] fix quality --- src/transformers/modeling_flax_pytorch_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index 7120fad64103..dadc98fc9cda 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -27,8 +27,8 @@ from flax.serialization import from_bytes from flax.traverse_util import flatten_dict, unflatten_dict -from .utils import FLAX_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, logging -from .utils.hub import get_checkpoint_shard_files +from .utils import logging + logger = logging.get_logger(__name__) From cfda646e93059b0f791a9835009475f9ed7956c3 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 17:41:47 +0200 Subject: [PATCH 05/10] style --- src/transformers/modeling_flax_pytorch_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/modeling_flax_pytorch_utils.py b/src/transformers/modeling_flax_pytorch_utils.py index dadc98fc9cda..249f754db726 100644 --- a/src/transformers/modeling_flax_pytorch_utils.py +++ b/src/transformers/modeling_flax_pytorch_utils.py @@ -30,7 +30,6 @@ from .utils import logging - logger = logging.get_logger(__name__) From 4e49a48d225e5d0010c9a074cdbc1e93fffb93da Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 5 Jul 2022 22:31:44 +0200 Subject: [PATCH 06/10] update test with new repo --- tests/test_modeling_flax_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index f33e0fb5d9e7..8e9b27c51530 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1102,7 +1102,7 @@ def test_checkpoint_sharding_local(self): @is_pt_flax_cross_test def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) - ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only") + ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-flax-only") for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()): assert np.allclose(np.array(p1), np.array(p2)) From 74a451b8b333989c797198d72a98f098ccdd7e47 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 2 Aug 2022 09:44:44 +0200 Subject: [PATCH 07/10] fix failing test --- tests/test_modeling_flax_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 8e9b27c51530..3f0a8c6c8079 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1102,7 +1102,7 @@ def test_checkpoint_sharding_local(self): @is_pt_flax_cross_test def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) - ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-flax-only") + ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-fx-only") for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()): assert np.allclose(np.array(p1), np.array(p2)) From 708bcb6f4de4eb1af405f08cf043c12e1cb9788d Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 2 Aug 2022 09:45:31 +0200 Subject: [PATCH 08/10] update --- tests/test_modeling_flax_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 3f0a8c6c8079..5bc7d97d46bc 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1102,7 +1102,7 @@ def test_checkpoint_sharding_local(self): @is_pt_flax_cross_test def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) - ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-fx-only") + ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only") for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()): assert np.allclose(np.array(p1), np.array(p2)) From c49bb18d588e21c4e18e68030396e3e0cb384f78 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 2 Aug 2022 10:07:15 +0200 Subject: [PATCH 09/10] fix wrong param ordering --- tests/test_modeling_flax_common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 5e9c061bc5b7..49505856954c 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1103,8 +1103,9 @@ def test_checkpoint_sharding_local(self): def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only") - for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()): - assert np.allclose(np.array(p1), np.array(p2)) + for key,ref_val in flatten_dict(ref_model.params).items(): + val = flatten_dict(model.params)[key] + assert np.allclose(np.array(val), np.array(ref_val)) def test_gradient_checkpointing(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From e20f971f6c83a770dea9ff17ec597587891e8690 Mon Sep 17 00:00:00 2001 From: Arthur Date: Tue, 2 Aug 2022 10:07:37 +0200 Subject: [PATCH 10/10] style --- tests/test_modeling_flax_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 49505856954c..837f874889ae 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1103,7 +1103,7 @@ def test_checkpoint_sharding_local(self): def test_from_sharded_pt(self): model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True) ref_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-fx-only") - for key,ref_val in flatten_dict(ref_model.params).items(): + for key, ref_val in flatten_dict(ref_model.params).items(): val = flatten_dict(model.params)[key] assert np.allclose(np.array(val), np.array(ref_val))