
Conversation

@ArthurZucker
Collaborator

@ArthurZucker ArthurZucker commented Dec 1, 2022

What does this PR do?

Allows loading sharded checkpoints in TF models. Should fix #19965

  • from_pt=True
  • from_flax=True

cc @sgugger just FYI

@ArthurZucker
Collaborator Author

ArthurZucker commented Dec 1, 2022

Works great for sharded PyTorch checkpoints since a utility was already implemented. Though we are not going to push for Flax, it would still help to have the support already!

from transformers import TFT5ForConditionalGeneration
MODEL_NAME = "google/flan-t5-xl"
m = TFT5ForConditionalGeneration.from_pretrained(MODEL_NAME, from_pt=True)
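
For context, a sharded PyTorch checkpoint ships an index JSON that maps each weight name to the shard file containing it. The sketch below shows one way such a checkpoint could be merged back into a single state dict before the PyTorch-to-TF conversion; the file names follow the usual Hugging Face sharding convention, and this is only an illustration, not the code added by this PR:

import json
import os

import torch


def load_sharded_pytorch_state_dict(checkpoint_dir):
    """Merge all shards referenced by the index file into one state dict (sketch)."""
    index_path = os.path.join(checkpoint_dir, "pytorch_model.bin.index.json")
    with open(index_path) as f:
        index = json.load(f)

    # "weight_map" maps each parameter name to the shard file that contains it.
    shard_files = sorted(set(index["weight_map"].values()))

    state_dict = {}
    for shard_file in shard_files:
        shard = torch.load(os.path.join(checkpoint_dir, shard_file), map_location="cpu")
        state_dict.update(shard)
    return state_dict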

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Dec 1, 2022

The documentation is not available anymore as the PR was closed or merged.

@ArthurZucker ArthurZucker marked this pull request as ready for review December 2, 2022 15:31
@ArthurZucker
Collaborator Author

Just need to remove the # TODOs

@ArthurZucker ArthurZucker requested a review from sgugger December 2, 2022 17:10
Comment on lines 2573 to 2578
elif os.path.isfile(pretrained_model_name_or_path):
archive_file = pretrained_model_name_or_path
is_local = True
elif os.path.isfile(pretrained_model_name_or_path + ".index"):
archive_file = pretrained_model_name_or_path + ".index"
is_local = True
Collaborator

This code shouldn't be removed, to preserve compatibility with PreTrainedModel.from_pretrained(path_to_a_model_path, config=config)
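
For illustration, a hypothetical call that relies on these branches (the local path below is made up; passing a single local weights file is exactly the case the branches above handle):

# Hypothetical local-path usage that the elif branches above keep working.
# "./local_bert/tf_model.h5" is a made-up path to a single local weights file;
# the ".index" branch covers TF1-style checkpoints in the same way.
from transformers import BertConfig, TFBertModel

config = BertConfig()
model = TFBertModel.from_pretrained("./local_bert/tf_model.h5", config=config)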

Collaborator

@sgugger sgugger left a comment

Thanks!

@ArthurZucker ArthurZucker merged commit 84c9bf7 into huggingface:main Dec 5, 2022
mpierrau pushed a commit to mpierrau/transformers that referenced this pull request Dec 15, 2022
* add support for `from_pt`

* add tf_flax utility file

* Update src/transformers/modeling_tf_flax_utils.py

Co-authored-by: Sylvain Gugger <[email protected]>

* remove flax related modifications

* add test

* remove FLAX related commits

* fixup

* remove safetensor todos

* revert deletion

Co-authored-by: Sylvain Gugger <[email protected]>


Development

Successfully merging this pull request may close these issues.

Cannot load TensorFlow model from PyTorch weights split to multiple files
