106 changes: 55 additions & 51 deletions src/transformers/modeling_tf_utils.py
@@ -236,9 +236,9 @@ def compute_loss(self, labels, logits):
return loss_fn(next_sentence_label, next_sentence_reduced_logits)


def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
def load_tf_weights(model, resolved_archive_file):
"""
Detect missing and unexpected layers.
Detect missing and unexpected layers and load the TF weights according to their names and shapes.

Args:
model (:obj:`tf.keras.models.Model`):
@@ -252,62 +252,60 @@ def detect_tf_missing_unexpected_layers(model, resolved_archive_file):
missing_layers = []
unexpected_layers = []

# Read the H5 file
with h5py.File(resolved_archive_file, "r") as f:
saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
model_layer_names = set(layer.name for layer in model.layers)
missing_layers = list(model_layer_names - saved_layer_names)
unexpected_layers = list(saved_layer_names - model_layer_names)

for layer in model.layers:
if layer.name in saved_layer_names:
g = f[layer.name]
saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
saved_weight_names_set = set(
"/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names
)
Comment on lines -265 to -267

Member:

This line and the line defining symbolic_weights_names were removed. Instead of being filled by

"/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names

they will now be filled by

"/".join(weight_name.split("/")[1:]) for weight_name in saved_weight_names

Why the change from 2: to 1:?

Member:

I took a deeper look and understand the change here. I guess it's because the second index is the prefix, and since in TF the main layer is named after the prefix, it will remain the same across base models and models with heads.

Contributor Author:

I'm really sorry @LysandreJik, I completely forgot to answer :( And yes, this is exactly for that :)
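A toy illustration of the difference (the weight name below is made up, not taken from the PR):

```python
# H5 weight names start with the top-level model name; the main-layer prefix
# (e.g. "bert") comes next and is stable across base and headed models, so
# only the first component needs to be dropped.
weight_name = "tf_bert_for_pre_training/bert/embeddings/word_embeddings/weight:0"
print("/".join(weight_name.split("/")[2:]))  # old: 'embeddings/word_embeddings/weight:0'
print("/".join(weight_name.split("/")[1:]))  # new: 'bert/embeddings/word_embeddings/weight:0'
```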

symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
symbolic_weights_names = set(
"/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights
)
missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
# Retrieve the name of each layer from the H5 file
saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))

return missing_layers, unexpected_layers


def load_tf_weights(model, resolved_archive_file):
"""
Load the TF weights from a H5 file.
# Find the missing layers from the high level list of layers
missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name)

Args:
model (:obj:`tf.keras.models.Model`):
The model to load the weights into.
resolved_archive_file (:obj:`str`):
The location of the H5 file.
"""
with h5py.File(resolved_archive_file, "r") as f:
saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
# Find the unexpected layers from the high level list of layers
unexpected_layers = list(saved_h5_model_layers_name - set([layer.name for layer in model.layers]))
saved_weight_names_set = set()
symbolic_weights_names = set()
weight_value_tuples = []

# Compute the missing and unexpected sublayers
# Store the weights in a list of tuples that looks like [(weight_object, value_of_weight), ...]
for layer in model.layers:
if layer.name in saved_layer_names:
g = f[layer.name]
saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names")
# If the layer name from the H5 file belongs to the layers of the instantiated model
if layer.name in saved_h5_model_layers_name:
# Get the H5 layer object from its name
h5_layer_object = f[layer.name]
# Get all the weights as a list from the layer object
symbolic_weights = layer.trainable_weights + layer.non_trainable_weights
saved_weight_names_values = {}
saved_weights = {}

for weight_name in saved_weight_names:
# Create a dict from the H5 saved model that looks like {"weight_name": weight_value}
# And a set with only the names
for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"):
# TF names always start with the model name so we ignore it
name = "/".join(weight_name.split("/")[1:])
saved_weight_names_values[name] = np.asarray(g[weight_name])
saved_weights[name] = np.asarray(h5_layer_object[weight_name])

# Add the updated name to the final list for computing missing/unexpected values
saved_weight_names_set.add(name)

# Loop over each weight of the instantiated model and compare it with the weights from the H5 file
for symbolic_weight in symbolic_weights:
splited_layers = symbolic_weight.name.split("/")[1:]
symbolic_weight_name = "/".join(splited_layers)
# TF names always start with the model name so we ignore it
symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:])

# Here we check if the current weight is among the weights from the H5 file
# If yes, get the value of the corresponding weight from the H5 file
# If not, set the value to None
saved_weight_value = saved_weights.get(symbolic_weight_name, None)

if symbolic_weight_name in saved_weight_names_values:
saved_weight_value = saved_weight_names_values[symbolic_weight_name]
# Add the updated name to the final list for computing missing/unexpected values
symbolic_weights_names.add(symbolic_weight_name)

# If the current weight is found
if saved_weight_value is not None:
Collaborator:

Why add this test and the preceding line saved_weight_value = None? saved_weight_value is not used anywhere else, so this just adds extra complexity to the code.

Contributor:

Agreed, I think we can just put all the code below if saved_weight_value is not None: directly under if symbolic_weight_name in saved_weight_names_values:, no?

Contributor:

Alternatively we could do saved_weight_value = saved_weights.get(symbolic_weight_name, None)

Contributor Author:

Like this?
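A minimal, self-contained sketch of the dict.get pattern being proposed (the names and values are toy stand-ins, not the PR's data):

```python
# dict.get folds the membership test and the lookup into one step; a missing
# key yields None, which the later `is not None` check handles.
saved_weights = {"bert/pooler/dense/kernel:0": 1.0}  # toy stand-in for H5 values
symbolic_weight_name = "bert/pooler/dense/bias:0"
saved_weight_value = saved_weights.get(symbolic_weight_name, None)
print(saved_weight_value)  # None -> this weight is missing from the H5 file
```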

# Check if the shape of the current weight and the one from the H5 file are different
if K.int_shape(symbolic_weight) != saved_weight_value.shape:
# If yes, we reshape the weight from the H5 file to match the current weight
# If the two shapes are not compatible, we raise an error
try:
array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
except AssertionError as e:
@@ -316,10 +314,18 @@ def load_tf_weights(model, resolved_archive_file):
else:
array = saved_weight_value

# We create the tuple that will be loaded and add it to the final list
weight_value_tuples.append((symbolic_weight, array))

# Load all the weights
K.batch_set_value(weight_value_tuples)

# Compute the missing and unexpected layers
missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))

return missing_layers, unexpected_layers
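For reference, a minimal sketch of the batch-assignment pattern the function ends with; the layer and values here are assumptions for illustration, not part of the PR:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

# Build a tiny layer so it owns concrete variables.
layer = tf.keras.layers.Dense(2)
layer.build((None, 2))

# One (weight_object, value_of_weight) tuple per variable, then a single
# batched assignment instead of one set_value call per weight.
new_kernel = np.ones(K.int_shape(layer.kernel), dtype=np.float32)
weight_value_tuples = [(layer.kernel, new_kernel)]
K.batch_set_value(weight_value_tuples)
```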


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
r"""
@@ -728,7 +734,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# 'by_name' allows us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
try:
load_tf_weights(model, resolved_archive_file)
missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file)
except OSError:
raise OSError(
"Unable to load weights from h5 file. "
@@ -737,8 +743,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

model(model.dummy_inputs, training=False) # Make sure restore ops are run

missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file)

if cls.authorized_missing_keys is not None:
for pat in cls.authorized_missing_keys:
missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
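A toy illustration of this filtering step (the key names and the pattern are invented):

```python
import re

missing_keys = ["bert/pooler/dense/kernel:0", "classifier/kernel:0"]
authorized_missing_keys = [r"pooler"]  # hypothetical pattern

# Drop every missing key that matches an authorized pattern.
for pat in authorized_missing_keys:
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

print(missing_keys)  # ['classifier/kernel:0']
```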
@@ -1034,18 +1038,18 @@ def call(self, inputs, cls_index=None, training=False):
return output


def shape_list(x: tf.Tensor) -> List[int]:
def shape_list(tensor: tf.Tensor) -> List[int]:
"""
Deal with dynamic shape in tensorflow cleanly.

Args:
x (:obj:`tf.Tensor`): The tensor we want the shape of.
tensor (:obj:`tf.Tensor`): The tensor we want the shape of.

Returns:
:obj:`List[int]`: The shape of the tensor as a list.
"""
static = x.shape.as_list()
dynamic = tf.shape(x)
static = tensor.shape.as_list()
dynamic = tf.shape(tensor)
return [dynamic[i] if s is None else s for i, s in enumerate(static)]
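A hypothetical usage sketch (the function and shapes are made up): in graph mode the batch dimension is typically dynamic, so shape_list returns a mix of plain ints and scalar tensors that can be fed straight back into ops such as tf.reshape:

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=(None, 16, 8), dtype=tf.float32)])
def first_token(hidden_states):
    # batch_size comes back as a scalar tensor (dynamic); seq_len and
    # hidden_size come back as plain Python ints (static).
    batch_size, seq_len, hidden_size = shape_list(hidden_states)
    return tf.reshape(hidden_states[:, 0], [batch_size, hidden_size])
```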

