src/transformers/modeling_flax_utils.py (5 additions, 0 deletions)
@@ -42,6 +42,7 @@
 FLAX_WEIGHTS_INDEX_NAME,
 FLAX_WEIGHTS_NAME,
 HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+WEIGHTS_INDEX_NAME,
 WEIGHTS_NAME,
 EntryNotFoundError,
 PushToHubMixin,
@@ -639,6 +640,10 @@ def from_pretrained(
 if from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
     # Load from a PyTorch checkpoint
     archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
+elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
Collaborator:

BTW, this will only work if the WEIGHTS_INDEX_NAME file is present locally; it does not cover loading from the Hub.
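(For illustration, a minimal sketch of what resolving the index from the Hub could additionally involve, using huggingface_hub's hf_hub_download; the repo id is the sharded test repo mentioned below, and how this would actually be wired into from_pretrained is an assumption:)

    # Hypothetical sketch: resolve the sharded-checkpoint index file from
    # the Hub instead of requiring it to exist locally.
    from huggingface_hub import hf_hub_download

    index_file = hf_hub_download(
        repo_id="hf-internal-testing/tiny-random-bert-sharded",
        filename="pytorch_model.bin.index.json",
    )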

Author:

Yeah, let's just finalize yours. What's left to do?

Collaborator:

Maybe just fixing the tests, and making sure that the tests are actually good. Should be quite straightforward ☺️🙌

Collaborator:

@Sea-Snell we just need to fix test_from_sharded_pt, which is failing because the models used for comparison are not the same! Simply using the same model should do the trick: upload a new model with the same config, but shard it with save_pretrained, setting max_shard_size to 150KB (see the sketch below).
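A minimal sketch of that suggestion, assuming a tiny PyTorch BertModel; the config values, local directory, and repo id are placeholders:

    # Hypothetical sketch: create a tiny model, shard the PyTorch checkpoint
    # into ~150KB pieces with save_pretrained, then upload it so both sides
    # of the test load the same weights.
    from transformers import BertConfig, BertModel

    config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
    model = BertModel(config)
    model.save_pretrained("tiny-random-bert-sharded", max_shard_size="150KB")
    # model.push_to_hub("your-namespace/tiny-random-bert-sharded")  # placeholder repo id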

+    # Load from a sharded PyTorch checkpoint
+    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+    is_sharded = True
Comment on lines +643 to +646
Collaborator:

LGTM, just wondering if we could add a small test?
You can use hf-internal-testing/tiny-random-bert-sharded.
Also, I opened #18026, which is really similar and adds:

    # (assumes the test module imports numpy as np, FlaxBertModel from
    #  transformers, is_pt_flax_cross_test from transformers.testing_utils,
    #  and flatten_dict from flax.traverse_util)
    @is_pt_flax_cross_test
    def test_from_sharded_pt(self):
        # Load the sharded PyTorch checkpoint through the Flax class, then
        # compare parameter by parameter against a Flax-only reference.
        model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
        ref_model = FlaxBertModel.from_pretrained("ArthurZ/tiny-random-bert-flax-only")
        for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(ref_model.params).values()):
            assert np.allclose(np.array(p1), np.array(p2))

Author:

Was not really aware that the conversion would be that straightforward, let me have a look.

 elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
     # Load from a Flax checkpoint
     archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
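For reference, WEIGHTS_INDEX_NAME is the sharded checkpoint's index file (pytorch_model.bin.index.json). A minimal sketch of inspecting one, assuming the two-key layout ("metadata" / "weight_map") that transformers writes; the directory name is a placeholder:

    # Sketch: read the index file of a sharded PyTorch checkpoint and list
    # the shard files it maps each parameter to.
    import json
    import os

    checkpoint_dir = "tiny-random-bert-sharded"
    with open(os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")) as f:
        index = json.load(f)

    print(index["metadata"]["total_size"])            # total checkpoint size in bytes
    print(sorted(set(index["weight_map"].values())))  # e.g. pytorch_model-00001-of-00002.bin, ...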