Support sharded safetensors in TF #29350
Changes from all commits: 080c24a, aaa0212, 43a1b2c, 36014c7, ab90852, 7eae1b4, 0a9cf9a, 73d3d75, b09d757, 8329061, ffc0e3a, 03efce9, f415eb8, e5aace5, e0faed0, a7dbf06, e7a2c24, 7ef18f9, d75d69a
@@ -21,10 +21,24 @@
 import numpy
 
-from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size
+from .utils import (
+    ExplicitEnum,
+    expand_dims,
+    is_numpy_array,
+    is_safetensors_available,
+    is_torch_tensor,
+    logging,
+    reshape,
+    squeeze,
+    tensor_size,
+)
 from .utils import transpose as transpose_func
 
 
+if is_safetensors_available():
+    from safetensors import safe_open
+
+
 logger = logging.get_logger(__name__)
 
 
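For context on the new import guard: safetensors is an optional dependency, so safe_open is only imported when the package is actually installed. A minimal standalone sketch of the same pattern, assuming a hypothetical helper that mirrors what is_safetensors_available() does in transformers.utils:

import importlib.util

def is_safetensors_available() -> bool:
    # Check whether the package can be found without importing it.
    return importlib.util.find_spec("safetensors") is not None

if is_safetensors_available():
    # Only bound when the dependency exists; any code path that uses it
    # must sit behind the same availability check.
    from safetensors import safe_open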
@@ -247,6 +261,47 @@ def load_pytorch_weights_in_tf2_model(
     )
 
 
+def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
+    if len(unexpected_keys) > 0:
+        logger.warning(
+            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
+            f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing"
+            f" {class_name} from a PyTorch model trained on another task or with another architecture"
+            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
+            f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect"
+            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
+            " BertForSequenceClassification model)."
+        )
+    else:
+        logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n")
+    if len(missing_keys) > 0:
+        logger.warning(
+            f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the"
+            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
+            " down-stream task to be able to use it for predictions and inference."
+        )
+    else:
+        logger.warning(
+            f"All the weights of {class_name} were initialized from the PyTorch model.\n"
+            "If your task is similar to the task the model of the checkpoint was trained on, "
+            f"you can already use {class_name} for predictions without further training."
+        )
+
+    if len(mismatched_keys) > 0:
+        mismatched_warning = "\n".join(
+            [
+                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
+                for key, shape1, shape2 in mismatched_keys
+            ]
+        )
+        logger.warning(
+            f"Some weights of {class_name} were not initialized from the model checkpoint"
+            f" are newly initialized because the shapes did not"
+            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
+            " to use it for predictions and inference."
+        )
+
+
 def load_pytorch_state_dict_in_tf2_model(
     tf_model,
     pt_state_dict,
@@ -256,6 +311,7 @@ def load_pytorch_state_dict_in_tf2_model(
     _prefix=None,
     tf_to_pt_weight_rename=None,
     ignore_mismatched_sizes=False,
+    skip_logger_warnings=False,
 ):
     """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
     safetensors archive created with the safe_open() function."""
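As the docstring notes, pt_state_dict may be either an in-memory dict or a lazy safetensors archive. A hypothetical usage sketch of both forms (the file names and model class are placeholders, not from the PR):

import torch
from safetensors import safe_open
from transformers import BertConfig, TFBertModel
from transformers.modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model

tf_model = TFBertModel(BertConfig())

# 1) An eagerly loaded, in-memory state dict:
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
tf_model = load_pytorch_state_dict_in_tf2_model(tf_model, state_dict)

# 2) A lazy-loading safetensors archive: tensors are materialized on demand
#    via safe_open, which keeps peak memory low.
with safe_open("model.safetensors", framework="tf") as archive:
    tf_model = load_pytorch_state_dict_in_tf2_model(tf_model, archive)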
@@ -373,45 +429,53 @@ def load_pytorch_state_dict_in_tf2_model(
     if tf_model._keys_to_ignore_on_load_unexpected is not None:
         for pat in tf_model._keys_to_ignore_on_load_unexpected:
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+    if not skip_logger_warnings:
+        _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)
 
-    if len(unexpected_keys) > 0:
-        logger.warning(
-            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
-            f" {tf_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
-            f" {tf_model.__class__.__name__} from a PyTorch model trained on another task or with another architecture"
-            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
-            f" NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect"
-            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
-            " BertForSequenceClassification model)."
-        )
-    else:
-        logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
-    if len(missing_keys) > 0:
-        logger.warning(
-            f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the"
-            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
-            " down-stream task to be able to use it for predictions and inference."
-        )
-    else:
-        logger.warning(
-            f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
-            "If your task is similar to the task the model of the checkpoint was trained on, "
-            f"you can already use {tf_model.__class__.__name__} for predictions without further training."
-        )
     if output_loading_info:
         loading_info = {
             "missing_keys": missing_keys,
             "unexpected_keys": unexpected_keys,
             "mismatched_keys": mismatched_keys,
         }
         return tf_model, loading_info
 
-    if len(mismatched_keys) > 0:
-        mismatched_warning = "\n".join(
-            [
-                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
-                for key, shape1, shape2 in mismatched_keys
-            ]
-        )
-        logger.warning(
-            f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint"
-            f" are newly initialized because the shapes did not"
-            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
-            " to use it for predictions and inference."
-        )
     return tf_model
 
 
+def load_sharded_pytorch_safetensors_in_tf2_model(
+    tf_model,
+    safetensors_shards,
+    tf_inputs=None,
+    allow_missing_keys=False,
+    output_loading_info=False,
+    _prefix=None,
+    tf_to_pt_weight_rename=None,
+    ignore_mismatched_sizes=False,
+):
+    all_loading_infos = []
+    for shard in safetensors_shards:
+        with safe_open(shard, framework="tf") as safetensors_archive:
Member: Shouldn't this load from the PT framework if we're "loading pytorch shards in tensorflow models"?
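For context on the framework argument: the safetensors file format itself is framework-agnostic, and framework only selects the tensor type that get_tensor returns, so passing "tf" hands back TF tensors ready to assign into the TF model. A small illustrative sketch (the shard filename is a placeholder):

from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="tf") as f:
    for key in f.keys():
        tensor = f.get_tensor(key)  # returned as a TF tensor, no manual conversion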
+            tf_model, loading_info = load_pytorch_state_dict_in_tf2_model(
+                tf_model,
+                safetensors_archive,
+                tf_inputs=tf_inputs,
+                allow_missing_keys=allow_missing_keys,
+                output_loading_info=True,
+                _prefix=_prefix,
+                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
+                ignore_mismatched_sizes=ignore_mismatched_sizes,
+                skip_logger_warnings=True,  # We will emit merged warnings at the end
+            )
+        all_loading_infos.append(loading_info)
+    # Now we just need to merge the loading info
+    # Keys are missing only if they're missing in *every* shard
+    missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
+    # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard
+    unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
+    mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])
+
+    _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)
+
+    if output_loading_info:
+        loading_info = {
Member: Is this necessary? 🤔

Member (Author): It is, unfortunately! The reason is that this function is used both to load shards and non-sharded checkpoints. When it's loading a non-sharded checkpoint, we want to log missing keys immediately. When it's loading a shard, there will always be lots of "missing" keys, but we don't want to log those - instead, we only want to log keys that are missing from every shard, which we will only know after all shards have been loaded. This is handled in the sharded loading function.
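To make the deferred, merged reporting concrete, here is a toy sketch (data invented for illustration) of the merge rule the sharded loader applies: a key counts as missing only if no shard supplied it, while unexpected and mismatched keys accumulate across shards.

shard_infos = [
    {"missing_keys": ["a", "b"], "unexpected_keys": ["x"], "mismatched_keys": []},
    {"missing_keys": ["b", "c"], "unexpected_keys": [], "mismatched_keys": []},
]

# Each shard holds only a slice of the weights, so a key absent from one
# shard may simply live in another: missing = intersection over shards.
missing_keys = sorted(set.intersection(*[set(i["missing_keys"]) for i in shard_infos]))

# Unexpected/mismatched keys are real problems wherever they occur: union.
unexpected_keys = sum([i["unexpected_keys"] for i in shard_infos], [])
mismatched_keys = sum([i["mismatched_keys"] for i in shard_infos], [])

assert missing_keys == ["b"]
assert unexpected_keys == ["x"]
assert mismatched_keys == []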