-
Notifications
You must be signed in to change notification settings - Fork 31.7k
New TF loading weights #8490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New TF loading weights #8490
Changes from all commits
c2010c0
3d842c3
b324000
8a89d8c
e327cfe
afa0411
49669ed
6567037
08f998c
ac6785f
315249e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -236,9 +236,9 @@ def compute_loss(self, labels, logits): | |
| return loss_fn(next_sentence_label, next_sentence_reduced_logits) | ||
|
|
||
|
|
||
| def detect_tf_missing_unexpected_layers(model, resolved_archive_file): | ||
| def load_tf_weights(model, resolved_archive_file): | ||
| """ | ||
| Detect missing and unexpected layers. | ||
| Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes. | ||
|
|
||
| Args: | ||
| model (:obj:`tf.keras.models.Model`): | ||
|
|
@@ -252,62 +252,60 @@ def detect_tf_missing_unexpected_layers(model, resolved_archive_file): | |
| missing_layers = [] | ||
| unexpected_layers = [] | ||
|
|
||
| # Read the H5 file | ||
| with h5py.File(resolved_archive_file, "r") as f: | ||
| saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) | ||
| model_layer_names = set(layer.name for layer in model.layers) | ||
| missing_layers = list(model_layer_names - saved_layer_names) | ||
| unexpected_layers = list(saved_layer_names - model_layer_names) | ||
|
|
||
| for layer in model.layers: | ||
| if layer.name in saved_layer_names: | ||
| g = f[layer.name] | ||
| saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names") | ||
| saved_weight_names_set = set( | ||
| "/".join(weight_name.split("/")[2:]) for weight_name in saved_weight_names | ||
| ) | ||
| symbolic_weights = layer.trainable_weights + layer.non_trainable_weights | ||
| symbolic_weights_names = set( | ||
| "/".join(symbolic_weight.name.split("/")[2:]) for symbolic_weight in symbolic_weights | ||
| ) | ||
| missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) | ||
| unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) | ||
| # Retrieve the name of each layer from the H5 file | ||
| saved_h5_model_layers_name = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) | ||
|
|
||
| return missing_layers, unexpected_layers | ||
|
|
||
|
|
||
| def load_tf_weights(model, resolved_archive_file): | ||
| """ | ||
| Load the TF weights from a H5 file. | ||
| # Find the missing layers from the high level list of layers | ||
| missing_layers = list(set([layer.name for layer in model.layers]) - saved_h5_model_layers_name) | ||
|
|
||
| Args: | ||
| model (:obj:`tf.keras.models.Model`): | ||
| The model to load the weights into. | ||
| resolved_archive_file (:obj:`str`): | ||
| The location of the H5 file. | ||
| """ | ||
| with h5py.File(resolved_archive_file, "r") as f: | ||
| saved_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names")) | ||
| # Find the unexpected layers from the high level list of layers | ||
| unexpected_layers = list(saved_h5_model_layers_name - set([layer.name for layer in model.layers])) | ||
| saved_weight_names_set = set() | ||
| symbolic_weights_names = set() | ||
| weight_value_tuples = [] | ||
|
|
||
| # Compute missing and unexpected sub layers | ||
| # Store the weights in list of tuples that looks like [(weight_object, value_of_weight),...] | ||
| for layer in model.layers: | ||
| if layer.name in saved_layer_names: | ||
| g = f[layer.name] | ||
| saved_weight_names = hdf5_format.load_attributes_from_hdf5_group(g, "weight_names") | ||
| # if layer_name from the H5 file belongs to the layers from the instantiated model | ||
| if layer.name in saved_h5_model_layers_name: | ||
| # Get the H5 layer object from its name | ||
| h5_layer_object = f[layer.name] | ||
| # Get all the weights as a list from the layer object | ||
| symbolic_weights = layer.trainable_weights + layer.non_trainable_weights | ||
| saved_weight_names_values = {} | ||
| saved_weights = {} | ||
|
|
||
| for weight_name in saved_weight_names: | ||
| # Create a dict from the H5 saved model that looks like {"weight_name": weight_value} | ||
| # And a set with only the names | ||
| for weight_name in hdf5_format.load_attributes_from_hdf5_group(h5_layer_object, "weight_names"): | ||
| # TF names always start with the model name so we ignore it | ||
| name = "/".join(weight_name.split("/")[1:]) | ||
| saved_weight_names_values[name] = np.asarray(g[weight_name]) | ||
| saved_weights[name] = np.asarray(h5_layer_object[weight_name]) | ||
|
|
||
| # Add the updated name to the final list for computing missing/unexpected values | ||
| saved_weight_names_set.add(name) | ||
|
|
||
| # Loop over each weights from the instantiated model and compare with the weights from the H5 file | ||
| for symbolic_weight in symbolic_weights: | ||
| splited_layers = symbolic_weight.name.split("/")[1:] | ||
| symbolic_weight_name = "/".join(splited_layers) | ||
| # TF names always start with the model name so we ignore it | ||
| symbolic_weight_name = "/".join(symbolic_weight.name.split("/")[1:]) | ||
|
|
||
| # here we check if the current weight is among the weights from the H5 file | ||
| # If yes, get the weight_value of the corresponding weight from the H5 file | ||
| # If not, make the value to None | ||
| saved_weight_value = saved_weights.get(symbolic_weight_name, None) | ||
|
|
||
| if symbolic_weight_name in saved_weight_names_values: | ||
| saved_weight_value = saved_weight_names_values[symbolic_weight_name] | ||
| # Add the updated name to the final list for computing missing/unexpected values | ||
| symbolic_weights_names.add(symbolic_weight_name) | ||
|
|
||
| # If the current weight is found | ||
| if saved_weight_value is not None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why adding this test and the line before
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree here I think we can just put all the code below
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Alternatively we could do
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Like this? |
||
| # Check if the shape of the current weight and the one from the H5 file are different | ||
| if K.int_shape(symbolic_weight) != saved_weight_value.shape: | ||
| # If yes we reshape the weight from the H5 file accordingly to the current weight | ||
| # If the two shapes are not compatible we raise an issue | ||
| try: | ||
| array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight)) | ||
| except AssertionError as e: | ||
|
|
@@ -316,10 +314,18 @@ def load_tf_weights(model, resolved_archive_file): | |
| else: | ||
| array = saved_weight_value | ||
|
|
||
| # We create the tuple that will be loaded and add it to the final list | ||
| weight_value_tuples.append((symbolic_weight, array)) | ||
|
|
||
| # Load all the weights | ||
| K.batch_set_value(weight_value_tuples) | ||
|
|
||
| # Compute the missing and unexpected layers | ||
| missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set)) | ||
| unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names)) | ||
|
|
||
| return missing_layers, unexpected_layers | ||
|
|
||
|
|
||
| class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): | ||
| r""" | ||
|
|
@@ -728,7 +734,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | |
| # 'by_name' allow us to do transfer learning by skipping/adding layers | ||
| # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357 | ||
| try: | ||
| load_tf_weights(model, resolved_archive_file) | ||
| missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file) | ||
| except OSError: | ||
| raise OSError( | ||
| "Unable to load weights from h5 file. " | ||
|
|
@@ -737,8 +743,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): | |
|
|
||
| model(model.dummy_inputs, training=False) # Make sure restore ops are run | ||
|
|
||
| missing_keys, unexpected_keys = detect_tf_missing_unexpected_layers(model, resolved_archive_file) | ||
|
|
||
| if cls.authorized_missing_keys is not None: | ||
| for pat in cls.authorized_missing_keys: | ||
| missing_keys = [k for k in missing_keys if re.search(pat, k) is None] | ||
|
|
@@ -1034,18 +1038,18 @@ def call(self, inputs, cls_index=None, training=False): | |
| return output | ||
|
|
||
|
|
||
| def shape_list(x: tf.Tensor) -> List[int]: | ||
| def shape_list(tensor: tf.Tensor) -> List[int]: | ||
| """ | ||
| Deal with dynamic shape in tensorflow cleanly. | ||
|
|
||
| Args: | ||
| x (:obj:`tf.Tensor`): The tensor we want the shape of. | ||
| tensor (:obj:`tf.Tensor`): The tensor we want the shape of. | ||
|
|
||
| Returns: | ||
| :obj:`List[int]`: The shape of the tensor as a list. | ||
| """ | ||
| static = x.shape.as_list() | ||
| dynamic = tf.shape(x) | ||
| static = tensor.shape.as_list() | ||
| dynamic = tf.shape(tensor) | ||
| return [dynamic[i] if s is None else s for i, s in enumerate(static)] | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line and the line defining
symbolic_weights_nameswere removed. Instead of being filled bythey will now be filled by
Why the change from
2:to1:?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I took a deeper look and understand the change here. I guess it's because the second index is the the prefix, and since in TF the main layer is named after the prefix, it will remain the same across base models and models with heads.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm really sorry @LysandreJik I completely forgot to answer :( And yes this is exactly for that :)