From 9b52d1ddf2c6ed3a1c78743941582e96e79d962a Mon Sep 17 00:00:00 2001 From: Sea-Snell Date: Sun, 17 Jul 2022 22:53:02 -0400 Subject: [PATCH 1/3] load from sharded pytorch checkpoint for flax model --- src/transformers/modeling_flax_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 77eaa900de62..82e05e5ef324 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -639,6 +639,10 @@ def from_pretrained( if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): # Load from a PyTorch checkpoint archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME) + is_sharded = True elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)): # Load from a Flax checkpoint archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME) From 06fc890200f57ad3adbaf3ec7d83435e0992419b Mon Sep 17 00:00:00 2001 From: Sea-Snell Date: Thu, 21 Jul 2022 10:39:46 -0400 Subject: [PATCH 2/3] added import from utils --- src/transformers/modeling_flax_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 82e05e5ef324..b915ecc63dee 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -43,6 +43,7 @@ FLAX_WEIGHTS_NAME, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, EntryNotFoundError, PushToHubMixin, RepositoryNotFoundError, From ce558a89f878180c97874d8b4e0b0063c0af7071 Mon Sep 17 00:00:00 2001 From: Sea-Snell Date: Thu, 21 Jul 2022 09:42:40 -0700 Subject: [PATCH 3/3] ran make style --- src/transformers/modeling_flax_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index b915ecc63dee..01a836191df2 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -42,8 +42,8 @@ FLAX_WEIGHTS_INDEX_NAME, FLAX_WEIGHTS_NAME, HUGGINGFACE_CO_RESOLVE_ENDPOINT, + WEIGHTS_INDEX_NAME, WEIGHTS_NAME, - WEIGHTS_INDEX_NAME, EntryNotFoundError, PushToHubMixin, RepositoryNotFoundError,