Support sharded safetensors in TF #29350
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

This should be ready for review! cc @ArthurZucker @a8nova
LysandreJik
left a comment
Looks good! When do you have in mind for an eventual switch to safetensors serialization by default?
```python
if tf_model._keys_to_ignore_on_load_unexpected is not None:
    for pat in tf_model._keys_to_ignore_on_load_unexpected:
        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if not skip_logger_warnings:
```
Is this necessary? 🤔
It is, unfortunately! The reason is that this function is used to load both sharded and non-sharded checkpoints. When it's loading a non-sharded checkpoint, we want to log missing keys immediately. When it's loading a shard, there will always be lots of "missing" keys, but we don't want to log those; instead, we only want to log keys that are missing from every shard, which we only know after all shards have been loaded. That case is handled in the sharded loading function.
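To illustrate the bookkeeping described above, here is a minimal sketch (not the PR's actual code; the helper name and the shape of the loading-info dict are assumptions based on this comment):

```python
from typing import Callable, Dict, List, Optional, Set


def collect_truly_missing_keys(
    shard_files: List[str],
    load_shard: Callable[[str], Dict[str, List[str]]],
) -> Set[str]:
    """Return only the keys that are absent from *every* shard.

    `load_shard` stands in for the per-shard loading call, which is assumed
    to return a loading-info dict containing a "missing_keys" list.
    """
    missing: Optional[Set[str]] = None
    for shard_file in shard_files:
        shard_missing = set(load_shard(shard_file)["missing_keys"])
        # Intersect across shards: a key stops being "missing" as soon as any shard provides it.
        missing = shard_missing if missing is None else missing & shard_missing
    return missing or set()
```

Only the keys that survive the intersection across all shards are worth warning about, which is why per-shard warnings are skipped.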
```python
):
    all_loading_infos = []
    for shard in safetensors_shards:
        with safe_open(shard, framework="tf") as safetensors_archive:
```
Shouldn't this load from the PT framework if we're "loading pytorch shards in tensorflow models"?
`safe_open(framework="tf")` just loads the tensors as `tf.Tensor` instead of `torch.Tensor` - the actual value of the tensor is unchanged. However, we still need to handle weight renaming + transposes, so we still need a pt-to-tf function.
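For reference, a small sketch of the behaviour described here (the file path is illustrative):

```python
from safetensors import safe_open

# framework="tf" only controls the returned tensor type: get_tensor() yields tf.Tensor
# objects with the same stored values, even for checkpoints saved from PyTorch.
with safe_open("model.safetensors", framework="tf") as f:
    for name in f.keys():
        tensor = f.get_tensor(name)
        # The raw values are identical to the PT weights, but key names and layouts
        # (e.g. transposed Dense kernels) still need a pt-to-tf conversion step
        # before they can be assigned to a TF model.
```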
```python
for p1, p2 in zip(model.weights, ref_model.weights):
    assert np.allclose(p1.numpy(), p2.numpy())


@require_safetensors
```
safetensors is now a base dependency, so maybe we should eventually just remove all of these.
Makes sense - do you want me to just do it in this PR?
@LysandreJik I think now that we have proper support we can switch to safetensors by default immediately, either in this PR or in a follow-up.

I'd wait for a few weeks just to ensure we don't have reports of failure and switch for the next version. WDYT?

Sounds good to me!
@LysandreJik is there anything else to be resolved before I merge this? (Except for the failing test, but that's not specific to this PR)
amyeroberts
left a comment
Thanks for working on this and enabling sharded support!
Just a few small comments / questions
```python
# This should not raise even if there are two types of sharded weights
# This should discard the safetensors weights in favor of the .h5 sharded weights
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
```
Don't we still want this test to make sure things are backwards compatible for now, i.e. that I can load sharded h5 files even if safetensors weights are available?
Re-added!
tests/test_modeling_tf_utils.py
```python
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
# the size asked for (since we count parameters)
if size >= max_size_int + 50000:
    with h5py.File(shard_file, "r") as state_file:
```
What does this represent here?
Good catch - that was copied from the h5 test and wouldn't work for safetensors; we just got lucky that it wasn't actually called in the tests. I removed it!
```python
mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])

if not skip_logger_warnings:
    if len(unexpected_keys) > 0:
```
AFAICT, all these checks are the same as the ones above. Can we abstract these out to e.g. `validate_keys(unmatched_keys, missing_keys, mismatched_keys)` and call that in both of the functions?
Done!
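For context, a rough sketch of what such a shared helper could look like; the PR ultimately factors this into `_log_key_warnings` (whose signature appears in a diff further down), and the exact messages here are illustrative:

```python
import logging

logger = logging.getLogger(__name__)


def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
    # Sketch only: emit one warning per category of problematic keys.
    if unexpected_keys:
        logger.warning(
            f"Some weights in the checkpoint were not used when initializing {class_name}: {unexpected_keys}"
        )
    if missing_keys:
        logger.warning(
            f"Some weights of {class_name} were not initialized from the checkpoint: {missing_keys}"
        )
    if mismatched_keys:
        logger.warning(
            f"Some weights of {class_name} had shapes that did not match the checkpoint: {mismatched_keys}"
        )
```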
```python
unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])

if not skip_logger_warnings:
```
My understanding is that we want to have `skip_logger_warnings=True` when calling `load_pytorch_state_dict_in_tf2_model` here, but I don't see why we're enabling skipping here? Silencing the logging warnings should really be hidden functionality (maybe with a param `_skip_logger_warnings`), not something people calling either function use.
You're right, actually - this function is only called once and warnings are always emitted, so the argument isn't needed at all. I removed it!
All comments addressed @amyeroberts! I think we should be ready now.
amyeroberts
left a comment
Thanks for adding this support!
```python
    return tf_model


def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
```
nit - the definition should go above the lines of code where it's used, i.e. before `load_pytorch_state_dict_in_tf2_model`
Done!
```diff
-def tf_shard_checkpoint(weights, max_shard_size="10GB"):
+def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME):
```
Should we have the default shard size match the one in load_tf_weights?
I got here by following the torch code, which also has the same issue! Specifically, save_pretrained() has a default size of 5GB, but the actual checkpoint sharding methods have a default size of 10GB. In general, though, the value passed from save_pretrained() will override those values.
It's a very minor detail either way, since I think both 5GB and 10GB shards work fine! We could consider standardizing everything at some point, but I don't think it's a high priority.
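A minimal sketch of why the differing defaults are harmless in practice (not the actual transformers code, just the call pattern described above):

```python
def tf_shard_checkpoint(weights, max_shard_size="10GB"):
    # Helper default: only applies if this function is called directly.
    ...


def save_pretrained(model, save_directory, max_shard_size="5GB"):
    # Caller-facing default: whatever is passed to save_pretrained is forwarded,
    # overriding the helper's own default.
    return tf_shard_checkpoint(model.weights, max_shard_size=max_shard_size)
```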
```diff
 with safe_open(resolved_archive_file, framework="tf") as f:
     safetensors_metadata = f.metadata()
-    if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
+    if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
```
Looks like we might want a constant, e.g. `SUPPORTED_SAFE_FORMATS = ["pt", "tf", "flax", "mlx"]`, so we don't have to update this in several locations here and for PT (for another PR)
Agreed!
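Sketching the suggestion (the constant name comes from the review comment above; the check itself mirrors the diff and is otherwise illustrative):

```python
from typing import Optional

# Single place to update when a new safetensors-producing framework is supported.
SUPPORTED_SAFE_FORMATS = ["pt", "tf", "flax", "mlx"]


def check_safetensors_metadata(metadata: Optional[dict]) -> None:
    # Mirrors the metadata check in the diff above, using the shared constant.
    if metadata is None or metadata.get("format") not in SUPPORTED_SAFE_FORMATS:
        raise OSError("The safetensors checkpoint metadata is missing or has an unsupported format.")
```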
Co-authored-by: amyeroberts <[email protected]>
This looks ready to go, cc @a8nova! We have a branch cut + release planned on Monday, though, and since this touches a lot of core code I don't want to merge it right before a release. Instead, I suggest merging it right after the branch cut, and then we can finalize and merge the PRs that are blocked by it: TF-IDEFICS, TF-Gemma and possibly the Mistral/Mixtral PRs if @ariG23498 can have one of them ready by then (no stress, obviously!). Then we could launch all the new Keras models together in the following release and do a section in the release notes about them, crediting @a8nova and @ariG23498?

cc @a8nova and @ariG23498, this has now been merged. If you rebase your PRs, that should resolve any issues with sharded safetensors loading!
Right now our TF safetensors loading doesn't support sharded checkpoints, which is a problem as more and more big models move to safetensors weights only! This is currently blocking @a8nova's PR at #26870.
As sharded safetensors saving for TF was also missing, this PR adds that as well, and expands the tests to cover both.
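A rough end-to-end sketch of what this enables (the model id and shard size are illustrative; a small max_shard_size is used just to force sharding):

```python
from transformers import TFBertModel

model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")

# Save as sharded safetensors: with safe_serialization=True the weights are split into
# multiple .safetensors shards plus an index file once they exceed max_shard_size.
model.save_pretrained("./tiny-bert-sharded", safe_serialization=True, max_shard_size="200KB")

# Loading transparently picks up the sharded safetensors checkpoint.
reloaded = TFBertModel.from_pretrained("./tiny-bert-sharded")
```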
TODO: