140 changes: 102 additions & 38 deletions src/transformers/modeling_tf_pytorch_utils.py
@@ -21,10 +21,24 @@

import numpy

from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze, tensor_size
from .utils import (
    ExplicitEnum,
    expand_dims,
    is_numpy_array,
    is_safetensors_available,
    is_torch_tensor,
    logging,
    reshape,
    squeeze,
    tensor_size,
)
from .utils import transpose as transpose_func


if is_safetensors_available():
    from safetensors import safe_open


logger = logging.get_logger(__name__)


@@ -247,6 +261,47 @@ def load_pytorch_weights_in_tf2_model(
)


def _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name):
    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
            f" {class_name}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {class_name} from a PyTorch model trained on another task or with another architecture"
            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
            f" NOT expected if you are initializing {class_name} from a PyTorch model that you expect"
            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All PyTorch model weights were used when initializing {class_name}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights or buffers of the TF 2.0 model {class_name} were not initialized from the"
            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
            " down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {class_name} were initialized from the PyTorch model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {class_name} for predictions without further training."
        )

    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logger.warning(
            f"Some weights of {class_name} were not initialized from the model checkpoint"
            " and are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
            " to use it for predictions and inference."
        )


def load_pytorch_state_dict_in_tf2_model(
    tf_model,
    pt_state_dict,
@@ -256,6 +311,7 @@ def load_pytorch_state_dict_in_tf2_model(
    _prefix=None,
    tf_to_pt_weight_rename=None,
    ignore_mismatched_sizes=False,
    skip_logger_warnings=False,
):
    """Load a pytorch state_dict in a TF 2.0 model. pt_state_dict can be either an actual dict or a lazy-loading
    safetensors archive created with the safe_open() function."""
@@ -373,45 +429,53 @@ def load_pytorch_state_dict_in_tf2_model(
    if tf_model._keys_to_ignore_on_load_unexpected is not None:
        for pat in tf_model._keys_to_ignore_on_load_unexpected:
            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
    if not skip_logger_warnings:
Member: Is this necessary? 🤔

Member (Author): It is, unfortunately! The reason is that this function is used both to load shards and non-sharded checkpoints. When it's loading a non-sharded checkpoint, we want to log missing keys immediately. When it's loading a shard, there will always be lots of "missing" keys, but we don't want to log those - instead, we only want to log keys that are missing from every shard, which we will only know after all shards have been loaded. This is handled in the sharded loading function.

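For illustration only (not part of the diff), a toy sketch of the merge rule described in the reply, assuming each shard reports its loading info as lists of key names:

```python
# Hypothetical per-shard loading infos, for illustration only
shard_infos = [
    {"missing_keys": ["a", "b"], "unexpected_keys": ["x"]},
    {"missing_keys": ["b", "c"], "unexpected_keys": []},
]

# A key is reported missing only if it is absent from *every* shard (set intersection)
missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in shard_infos]))

# A key is unexpected if *any* shard reports it (simple concatenation)
unexpected_keys = sum([info["unexpected_keys"] for info in shard_infos], [])

print(missing_keys)     # ['b']
print(unexpected_keys)  # ['x']
```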
        _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)

    if len(unexpected_keys) > 0:
        logger.warning(
            "Some weights of the PyTorch model were not used when initializing the TF 2.0 model"
            f" {tf_model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {tf_model.__class__.__name__} from a PyTorch model trained on another task or with another architecture"
            " (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n- This IS"
            f" NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect"
            " to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a"
            " BertForSequenceClassification model)."
        )
    else:
        logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
    if len(missing_keys) > 0:
        logger.warning(
            f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the"
            f" PyTorch model and are newly initialized: {missing_keys}\nYou should probably TRAIN this model on a"
            " down-stream task to be able to use it for predictions and inference."
        )
    else:
        logger.warning(
            f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
            "If your task is similar to the task the model of the checkpoint was trained on, "
            f"you can already use {tf_model.__class__.__name__} for predictions without further training."
        )
    if output_loading_info:
        loading_info = {
            "missing_keys": missing_keys,
            "unexpected_keys": unexpected_keys,
            "mismatched_keys": mismatched_keys,
        }
        return tf_model, loading_info

    if len(mismatched_keys) > 0:
        mismatched_warning = "\n".join(
            [
                f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                for key, shape1, shape2 in mismatched_keys
            ]
        )
        logger.warning(
            f"Some weights of {tf_model.__class__.__name__} were not initialized from the model checkpoint"
            f" are newly initialized because the shapes did not"
            f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
            " to use it for predictions and inference."
        )
    return tf_model


def load_sharded_pytorch_safetensors_in_tf2_model(
    tf_model,
    safetensors_shards,
    tf_inputs=None,
    allow_missing_keys=False,
    output_loading_info=False,
    _prefix=None,
    tf_to_pt_weight_rename=None,
    ignore_mismatched_sizes=False,
):
    all_loading_infos = []
    for shard in safetensors_shards:
        with safe_open(shard, framework="tf") as safetensors_archive:
Member: Shouldn't this load from the PT framework if we're "loading pytorch shards in tensorflow models"?

Member (Author): safe_open(framework="tf") just loads the tensors as tf.Tensor instead of torch.Tensor - the actual value of the tensor is unchanged. However, we still need to handle weight renaming + transposes, so we still need a pt-to-tf function.

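A minimal sketch of the behaviour described in the reply (illustrative only; the shard file name and tensor name are hypothetical): the framework argument only controls the tensor type returned, not the stored values.

```python
from safetensors import safe_open

# Hypothetical shard file and tensor name, for illustration only
with safe_open("model-00001-of-00002.safetensors", framework="tf") as archive:
    tf_weight = archive.get_tensor("bert.embeddings.word_embeddings.weight")  # returned as tf.Tensor

with safe_open("model-00001-of-00002.safetensors", framework="pt") as archive:
    pt_weight = archive.get_tensor("bert.embeddings.word_embeddings.weight")  # returned as torch.Tensor

# Same underlying values either way; only the container type changes, so the usual
# PyTorch -> TF name renaming and transpose logic is still required afterwards.
```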
            tf_model, loading_info = load_pytorch_state_dict_in_tf2_model(
                tf_model,
                safetensors_archive,
                tf_inputs=tf_inputs,
                allow_missing_keys=allow_missing_keys,
                output_loading_info=True,
                _prefix=_prefix,
                tf_to_pt_weight_rename=tf_to_pt_weight_rename,
                ignore_mismatched_sizes=ignore_mismatched_sizes,
                skip_logger_warnings=True,  # We will emit merged warnings at the end
            )
        all_loading_infos.append(loading_info)
    # Now we just need to merge the loading info
    # Keys are missing only if they're missing in *every* shard
    missing_keys = sorted(set.intersection(*[set(info["missing_keys"]) for info in all_loading_infos]))
    # Keys are unexpected/mismatched if they're unexpected/mismatched in *any* shard
    unexpected_keys = sum([info["unexpected_keys"] for info in all_loading_infos], [])
    mismatched_keys = sum([info["mismatched_keys"] for info in all_loading_infos], [])

    _log_key_warnings(missing_keys, unexpected_keys, mismatched_keys, class_name=tf_model.__class__.__name__)

    if output_loading_info:
        loading_info = {