Conversation

@Rocketknight1 Rocketknight1 (Member) commented Feb 28, 2024

Right now our TF safetensors loading doesn't support sharded checkpoints, which is a problem as more and more big models move to safetensors weights only! This is currently blocking @a8nova's PR at #26870.

As sharded safetensors saving for TF was also missing, this PR adds that as well, and expands the tests to cover both.
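
For illustration, the end-to-end behaviour this enables looks roughly like the sketch below (the repo name and shard size are placeholder choices, not taken from this PR's tests):

```python
import numpy as np
from transformers import TFBertModel

# Placeholder model and tiny shard size, chosen only to force sharding on a small model.
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
model.save_pretrained("tiny-bert-sharded", safe_serialization=True, max_shard_size="150KB")

# Loading should transparently reassemble the sharded .safetensors files.
reloaded = TFBertModel.from_pretrained("tiny-bert-sharded")
for p1, p2 in zip(model.weights, reloaded.weights):
    assert np.allclose(p1.numpy(), p2.numpy())
```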

TODO:

  • Check we prioritize safetensors vs. TF weights correctly when both are present (and make sure behaviour matches for sharded vs. unsharded)
  • Add tests for sharded safetensors loading from both PT and TF format
  • Add test for saving sharded safetensors

@Rocketknight1 Rocketknight1 mentioned this pull request Feb 28, 2024
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Rocketknight1 Rocketknight1 marked this pull request as ready for review March 4, 2024 18:02
@Rocketknight1 (Member Author)

This should be ready for review! cc @ArthurZucker @a8nova

@ArthurZucker ArthurZucker self-requested a review March 6, 2024 03:37
@LysandreJik LysandreJik (Member) left a comment

Looks good! When would you envision an eventual switch to safetensors serialization by default?

if tf_model._keys_to_ignore_on_load_unexpected is not None:
    for pat in tf_model._keys_to_ignore_on_load_unexpected:
        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if not skip_logger_warnings:
Member

Is this necessary? 🤔

Member Author

It is, unfortunately! The reason is that this function is used to load both sharded and non-sharded checkpoints. When it's loading a non-sharded checkpoint, we want to log missing keys immediately. When it's loading a shard, there will always be lots of "missing" keys, but we don't want to log those - instead, we only want to log keys that are missing from every shard, which we only know once all shards have been loaded. That aggregation is handled in the sharded loading function.
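
In other words, the sharded path collects each shard's loading info first and only warns about keys that no shard provided. A minimal sketch of that aggregation idea (key names made up for illustration):

```python
# Each shard's "missing keys" necessarily include everything stored in the other
# shards, so only the intersection across all shards is genuinely missing.
shard_missing_keys = [
    {"encoder.layer.1.kernel", "classifier.kernel"},  # not found in shard 1
    {"encoder.layer.0.kernel", "classifier.kernel"},  # not found in shard 2
]
genuinely_missing = set.intersection(*shard_missing_keys)
print(genuinely_missing)  # {'classifier.kernel'} -> the only key worth warning about
```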

):
    all_loading_infos = []
    for shard in safetensors_shards:
        with safe_open(shard, framework="tf") as safetensors_archive:
Member

Shouldn't this load from the PT framework if we're "loading pytorch shards in tensorflow models"?

Member Author

safe_open(framework="tf") just loads the tensors as tf.Tensor instead of torch.Tensor - the actual value of the tensor is unchanged. However, we still need to handle weight renaming + transposes, so we still need a pt-to-tf function.
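
For reference, this is roughly how safetensors' safe_open behaves - the framework argument only decides what tensor type get_tensor() returns; the file path below is a placeholder:

```python
from safetensors import safe_open

# The same file can be opened with framework="pt", "tf", "np", etc.; the bytes on
# disk are identical, only the returned tensor type changes.
with safe_open("model-00001-of-00002.safetensors", framework="tf") as f:
    for key in f.keys():
        tensor = f.get_tensor(key)  # a tf.Tensor, even if the file was saved from PyTorch
        # PyTorch-style names and transposed kernels still have to go through the
        # usual pt-to-tf renaming/transpose logic after this point.
```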

for p1, p2 in zip(model.weights, ref_model.weights):
    assert np.allclose(p1.numpy(), p2.numpy())

@require_safetensors
Member

safetensors is now a base dependency, so maybe we should eventually just remove all of these.

Member Author

Makes sense - do you want me to just do it in this PR?

@Rocketknight1 (Member Author)

@LysandreJik I think now that we have proper support we can switch to safetensors by default immediately, either in this PR or in a follow-up.

@LysandreJik (Member)

I'd wait for a few weeks just to ensure we don't have reports of failure and switch for the next version. WDYT?

@Rocketknight1 (Member Author)

Sounds good to me!

@Rocketknight1 Rocketknight1 force-pushed the supported_sharded_safetensors_loading_in_tf branch from 46e5c56 to efdb604 on March 7, 2024 14:33
@Rocketknight1 (Member Author)

@LysandreJik is there anything else to be resolved before I merge this? (Except for the failing test, but that's not specific to this PR)

@Rocketknight1 Rocketknight1 force-pushed the supported_sharded_safetensors_loading_in_tf branch 3 times, most recently from a1001c8 to a9f240e on March 12, 2024 16:52
@amyeroberts amyeroberts (Contributor) left a comment

Thanks for working on this and enabling sharded support!

Just a few small comments / questions

Comment on lines -534 to -517
# This should not raise even if there are two types of sharded weights
# This should discard the safetensors weights in favor of the .h5 sharded weights
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded")
Contributor

Don't we still want this test to make sure things are backwards compatible for now - i.e. that I can load sharded h5 files even if safetensors weights are available?

Member Author

Re-added!

# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
# the size asked for (since we count parameters)
if size >= max_size_int + 50000:
    with h5py.File(shard_file, "r") as state_file:
Contributor

What does this represent here?

@Rocketknight1 Rocketknight1 (Member Author) Mar 13, 2024

Good catch - that was copied from the h5 test, and wouldn't work for safetensors - we just got lucky that it wasn't called in the tests anyway. I removed it!

mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])

if not skip_logger_warnings:
    if len(unexpected_keys) > 0:
Contributor

AFAICT, all these checks are the same as the ones above. Can we abstract these out to e.g. `validate_keys(unmatched_keys, missing_keys, mismatched_keys)` and call that in both of the functions?

Member Author

Done!
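
For context, the shared helper that landed in the PR follows the shape sketched below (simplified; only the signature is taken from the diff further down, and the message text here is illustrative):

```python
import logging

logger = logging.getLogger(__name__)

def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
    # Simplified sketch of a shared key-warning helper used by both loading paths.
    if unexpected_keys:
        logger.warning(f"Some weights in the checkpoint were not used by {class_name}: {unexpected_keys}")
    if missing_keys:
        logger.warning(f"Some weights of {class_name} were not found in the checkpoint: {missing_keys}")
    if mismatched_keys:
        logger.warning(f"Some weights of {class_name} had shape mismatches and were not loaded: {mismatched_keys}")
```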

unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])

if not skip_logger_warnings:
Contributor

My understanding is that we want to have skip_logger_warnings=True when calling load_pytorch_state_dict_in_tf2_model here, but I don't see why we're enabling skipping here. Silencing the logging warnings should really be hidden functionality (maybe with a param _skip_logger_warnings) and not something people calling either function use.

Member Author

You're right, actually - this function is only called once and warnings are always emitted, so the argument isn't needed at all. I removed it!

@Rocketknight1 Rocketknight1 force-pushed the supported_sharded_safetensors_loading_in_tf branch from 55519b6 to e7a2c24 on March 15, 2024 14:49
@Rocketknight1 (Member Author)

All comments addressed @amyeroberts! I think we should be ready now.

@amyeroberts amyeroberts (Contributor) left a comment

Thanks for adding this support!

    return tf_model


def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
Contributor

nit - definition should go above the lines of code where it's used i.e. before load_pytorch_state_dict_in_tf2_model

Member Author

Done!



def tf_shard_checkpoint(weights, max_shard_size="10GB"):
def tf_shard_checkpoint(weights, max_shard_size="10GB", weights_name: str = TF2_WEIGHTS_NAME):
Contributor

Should we have the default shard size match the one in load_tf_weights?

Member Author

I got here by following the torch code, which also has the same issue! Specifically, save_pretrained() has a default size of 5GB, but the actual checkpoint sharding methods have a default size of 10GB. In general, though, the value passed from save_pretrained() will override those values.

It's a very minor detail either way, since I think both 5GB and 10GB shards work fine! We could consider standardizing everything at some point, but I don't think it's a high priority.
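
To make that precedence concrete: whatever the caller passes wins over both internal defaults. A usage sketch (model repo and shard size are placeholders):

```python
from transformers import TFAutoModel

# Placeholder model; the point is only the explicit max_shard_size, which is
# forwarded down to the sharding helper and so overrides both the 5GB
# save_pretrained default and the helper's 10GB default.
model = TFAutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
model.save_pretrained("my-model-dir", safe_serialization=True, max_shard_size="2GB")
```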

with safe_open(resolved_archive_file, framework="tf") as f:
    safetensors_metadata = f.metadata()
    if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax"]:
    if safetensors_metadata is None or safetensors_metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
Contributor

Looks like we might want a constant e.g. SUPPORTED_SAFE_FORMATS = ["pt", "tf", "flax", "mlx"] so we don't have to update this in several locations here and for PT (for another PR)

Member Author

Agreed!
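
A sketch of what that could look like (SUPPORTED_SAFE_FORMATS is the name suggested in the review, not something in the library yet; the function wrapper is just for illustration):

```python
from safetensors import safe_open

# Hypothetical shared constant, per the suggestion above.
SUPPORTED_SAFE_FORMATS = ["pt", "tf", "flax", "mlx"]

def check_safetensors_metadata(archive_path):
    # archive_path stands in for resolved_archive_file in the quoted diff.
    with safe_open(archive_path, framework="tf") as f:
        metadata = f.metadata()
    if metadata is None or metadata.get("format") not in SUPPORTED_SAFE_FORMATS:
        raise OSError(f"Unsupported or missing safetensors metadata: {metadata}")
```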

@Rocketknight1 (Member Author)

This looks ready to go, cc @a8nova! We have a branch cut + release planned on Monday, though, and since this touches a lot of core code I don't want to merge it right before a release. Instead, I suggest merging it right after the branch cut, and then we can finalize and merge the PRs that are blocked by it: TF-IDEFICS, TF-Gemma and possibly the Mistral/Mixtral PRs if @ariG23498 can have one of them ready by then (no stress, obviously!)

Then we could launch all the new Keras models together in the following release and do a section in the release notes about them, crediting @a8nova and @ariG23498?

@Rocketknight1 Rocketknight1 merged commit 11ef35e into main Mar 20, 2024
@Rocketknight1 Rocketknight1 deleted the supported_sharded_safetensors_loading_in_tf branch March 20, 2024 14:22
@Rocketknight1 (Member Author)

cc @a8nova and @ariG23498, this has now been merged. If you rebase your PRs, that should resolve any issues with sharded safetensors loading!
