[Serialization] add argument to pass shared tensors names to drop when saving #2696
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Related to transformers#35080.
Currently, when saving a torch state dict with shared tensors, the key to keep is chosen alphabetically which might not be the desired behavior. This PR simply adds an optional
state_dict_keys_to_discard
argument to bothsave_torch_state_dict()
andsave_torch_model()
that allows users to specify which keys should be discarded in priority when duplicates are found.Note : The logic behind choosing which key should be discarded in priority is framework/model-specific and should be handled by the user. For example, for
transformers.PreTrainedModel
models, this is specified in the_tied_weights_keys
attribute, example usage:or with
save_torch_state_dict()