From 3a816e76b0588e4a7bfb7439d65f0b121c22c009 Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 3 Apr 2023 17:01:29 +0100 Subject: [PATCH 1/8] Fix inverted conditional in TF common test! --- tests/test_modeling_tf_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index dc3bd8d6b887..207f7295255e 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -699,7 +699,7 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=False): # For some models (e.g. base models), there is no label returned. # Set the input dict to `None` to avoid check outputs twice for the same input dicts. - if set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()): + if not set(tf_inputs_dict_with_labels.keys()).symmetric_difference(tf_inputs_dict.keys()): tf_inputs_dict_with_labels = None # Check we can load pt model in tf and vice-versa with model => model functions From ee9b8c5f2c3e856bd4740eff41a3be8352a2eb7b Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 4 Apr 2023 14:05:31 +0100 Subject: [PATCH 2/8] Make the same change in the PT tests file --- tests/test_modeling_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f4270363be68..113ee8f1c182 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -2030,7 +2030,7 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=False): # For some models (e.g. base models), there is no label returned. # Set the input dict to `None` to avoid check outputs twice for the same input dicts. - if set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()): + if not set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()): pt_inputs_dict_with_labels = None # Check we can load pt model in tf and vice-versa with model => model functions From 1ce4421f14fee4f23e241b2946df72bedc4c0169 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 4 Apr 2023 16:31:19 +0100 Subject: [PATCH 3/8] Make sure hidden states for GPT2 have the same output shape in PT/TF --- src/transformers/models/gpt2/modeling_tf_gpt2.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index a80b2d4d33d6..d168ec3593df 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -1051,6 +1051,12 @@ def call( ) hidden_states = transformer_outputs[0] hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) + if output_hidden_states: + # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the + # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged) + all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,) + else: + all_hidden_states = None lm_logits = self.transformer.wte(hidden_states, mode="linear") mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training) mc_logits = tf.squeeze(mc_logits, axis=-1) @@ -1062,7 +1068,7 @@ def call( logits=lm_logits, mc_logits=mc_logits, past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, + hidden_states=all_hidden_states, attentions=transformer_outputs.attentions, ) From 12cbb740ad2a089347130c67f0e7e35bd800040d Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 4 Apr 2023 16:43:36 +0100 Subject: [PATCH 4/8] Minor fix to PT implementation of token classification loss --- src/transformers/models/esm/modeling_esm.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index e53f87f6ce2c..9505a15c3a05 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -1228,16 +1228,7 @@ def forward( loss = None if labels is not None: loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels) - active_labels = torch.where( - active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) - ) - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) if not return_dict: output = (logits,) + outputs[2:] From b9d40b5f63243fc7e04b918d03bb91f97334e7da Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 4 Apr 2023 17:38:59 +0100 Subject: [PATCH 5/8] Skip loss equivalence test for TFHubert because it keeps overflowing to inf --- .../models/hubert/test_modeling_tf_hubert.py | 113 +++++++++++++++++- 1 file changed, 112 insertions(+), 1 deletion(-) diff --git a/tests/models/hubert/test_modeling_tf_hubert.py b/tests/models/hubert/test_modeling_tf_hubert.py index d5164b6069e5..ad6aed4f7496 100644 --- a/tests/models/hubert/test_modeling_tf_hubert.py +++ b/tests/models/hubert/test_modeling_tf_hubert.py @@ -18,12 +18,14 @@ import inspect import math import unittest +import os +import tempfile import numpy as np import pytest from transformers import is_tf_available -from transformers.testing_utils import require_soundfile, require_tf, slow +from transformers.testing_utils import require_soundfile, require_tf, slow, is_pt_tf_cross_test from ...test_configuration_common import ConfigTester from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor @@ -333,6 +335,60 @@ def test_keras_fit(self): # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC pass + @is_pt_tf_cross_test + def test_pt_tf_model_equivalence(self, allow_missing_keys=False): + # We override the base test here to skip loss calculation for Hubert models because the loss is massive with + # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT + import transformers + import torch + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency + # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`. + # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it. + self._make_attention_mask_non_null(inputs_dict) + + pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + tf_model = model_class(config) + pt_model = pt_model_class(config) + + tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + + # Check we can load pt model in tf and vice-versa with model => model functions + tf_model = transformers.load_pytorch_model_in_tf2_model( + tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys + ) + pt_model = transformers.load_tf2_model_in_pytorch_model( + pt_model, tf_model, allow_missing_keys=allow_missing_keys + ) + + # Original test: check without `labels` + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) + + # Check we can load pt model in tf and vice-versa with checkpoint => model functions + with tempfile.TemporaryDirectory() as tmpdirname: + pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin") + torch.save(pt_model.state_dict(), pt_checkpoint_path) + tf_model = transformers.load_pytorch_checkpoint_in_tf2_model( + tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys + ) + + tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") + tf_model.save_weights(tf_checkpoint_path) + pt_model = transformers.load_tf2_checkpoint_in_pytorch_model( + pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys + ) + + # Original test: check without `labels` + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) @require_tf class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase): @@ -458,6 +514,61 @@ def test_keras_fit(self): # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC pass + @is_pt_tf_cross_test + def test_pt_tf_model_equivalence(self, allow_missing_keys=False): + # We override the base test here to skip loss calculation for Hubert models because the loss is massive with + # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT + import transformers + import torch + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency + # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`. + # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it. + self._make_attention_mask_non_null(inputs_dict) + + pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + tf_model = model_class(config) + pt_model = pt_model_class(config) + + tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + + # Check we can load pt model in tf and vice-versa with model => model functions + tf_model = transformers.load_pytorch_model_in_tf2_model( + tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys + ) + pt_model = transformers.load_tf2_model_in_pytorch_model( + pt_model, tf_model, allow_missing_keys=allow_missing_keys + ) + + # Original test: check without `labels` + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) + + # Check we can load pt model in tf and vice-versa with checkpoint => model functions + with tempfile.TemporaryDirectory() as tmpdirname: + pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin") + torch.save(pt_model.state_dict(), pt_checkpoint_path) + tf_model = transformers.load_pytorch_checkpoint_in_tf2_model( + tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys + ) + + tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") + tf_model.save_weights(tf_checkpoint_path) + pt_model = transformers.load_tf2_checkpoint_in_pytorch_model( + pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys + ) + + # Original test: check without `labels` + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) + @require_tf class TFHubertUtilsTest(unittest.TestCase): From 28cc6502c5bb3ea5a52ea7018116f0ed91bfb758 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 4 Apr 2023 18:05:57 +0100 Subject: [PATCH 6/8] Compute LM loss for TF the (weird) way it's computed in PT --- src/transformers/models/xglm/modeling_tf_xglm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/xglm/modeling_tf_xglm.py b/src/transformers/models/xglm/modeling_tf_xglm.py index 1dac55651563..c17770860ce6 100644 --- a/src/transformers/models/xglm/modeling_tf_xglm.py +++ b/src/transformers/models/xglm/modeling_tf_xglm.py @@ -953,9 +953,8 @@ def call( loss = None if labels is not None: # shift labels to the left and cut last logit token - shifted_logits = lm_logits[:, :-1] - labels = labels[:, 1:] - loss = self.hf_compute_loss(labels, shifted_logits) + labels = tf.concat([labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(self.config.pad_token_id, labels.dtype))], axis=-1) + loss = self.hf_compute_loss(labels, lm_logits) if not return_dict: output = (lm_logits,) + outputs[1:] From 65f48a7ef61116864a90f5c411c40fd631a3e582 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 4 Apr 2023 18:20:48 +0100 Subject: [PATCH 7/8] Skip loss equivalence test for Wav2Vec2 for the same reason as Hubert --- .../models/xglm/modeling_tf_xglm.py | 5 +- .../models/hubert/test_modeling_tf_hubert.py | 11 +- .../wav2vec2/test_modeling_tf_wav2vec2.py | 115 ++++++++++++++++++ 3 files changed, 126 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/xglm/modeling_tf_xglm.py b/src/transformers/models/xglm/modeling_tf_xglm.py index c17770860ce6..c07bafe240c6 100644 --- a/src/transformers/models/xglm/modeling_tf_xglm.py +++ b/src/transformers/models/xglm/modeling_tf_xglm.py @@ -953,7 +953,10 @@ def call( loss = None if labels is not None: # shift labels to the left and cut last logit token - labels = tf.concat([labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(self.config.pad_token_id, labels.dtype))], axis=-1) + labels = tf.concat( + [labels[:, 1:], tf.fill((labels.shape[0], 1), tf.cast(self.config.pad_token_id, labels.dtype))], + axis=-1, + ) loss = self.hf_compute_loss(labels, lm_logits) if not return_dict: diff --git a/tests/models/hubert/test_modeling_tf_hubert.py b/tests/models/hubert/test_modeling_tf_hubert.py index ad6aed4f7496..a48ed0634e84 100644 --- a/tests/models/hubert/test_modeling_tf_hubert.py +++ b/tests/models/hubert/test_modeling_tf_hubert.py @@ -17,15 +17,15 @@ import copy import inspect import math -import unittest import os import tempfile +import unittest import numpy as np import pytest from transformers import is_tf_available -from transformers.testing_utils import require_soundfile, require_tf, slow, is_pt_tf_cross_test +from transformers.testing_utils import is_pt_tf_cross_test, require_soundfile, require_tf, slow from ...test_configuration_common import ConfigTester from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor @@ -339,9 +339,10 @@ def test_keras_fit(self): def test_pt_tf_model_equivalence(self, allow_missing_keys=False): # We override the base test here to skip loss calculation for Hubert models because the loss is massive with # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT - import transformers import torch + import transformers + for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -390,6 +391,7 @@ def test_pt_tf_model_equivalence(self, allow_missing_keys=False): # Original test: check without `labels` self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) + @require_tf class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase): all_model_classes = (TFHubertModel, TFHubertForCTC) if is_tf_available() else () @@ -518,9 +520,10 @@ def test_keras_fit(self): def test_pt_tf_model_equivalence(self, allow_missing_keys=False): # We override the base test here to skip loss calculation for Hubert models because the loss is massive with # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT - import transformers import torch + import transformers + for model_class in self.all_model_classes: config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py index 3bb3d36cbfb2..87a174b5b251 100644 --- a/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_tf_wav2vec2.py @@ -19,6 +19,8 @@ import inspect import math import multiprocessing +import os +import tempfile import traceback import unittest @@ -31,6 +33,7 @@ from transformers.testing_utils import ( CaptureLogger, is_flaky, + is_pt_tf_cross_test, require_librosa, require_pyctcdecode, require_tf, @@ -397,6 +400,62 @@ def test_keras_fit(self): # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC pass + @is_pt_tf_cross_test + def test_pt_tf_model_equivalence(self, allow_missing_keys=False): + # We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with + # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT + import torch + + import transformers + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency + # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`. + # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it. + self._make_attention_mask_non_null(inputs_dict) + + pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + tf_model = model_class(config) + pt_model = pt_model_class(config) + + tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + + # Check we can load pt model in tf and vice-versa with model => model functions + tf_model = transformers.load_pytorch_model_in_tf2_model( + tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys + ) + pt_model = transformers.load_tf2_model_in_pytorch_model( + pt_model, tf_model, allow_missing_keys=allow_missing_keys + ) + + # Original test: check without `labels` + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) + + # Check we can load pt model in tf and vice-versa with checkpoint => model functions + with tempfile.TemporaryDirectory() as tmpdirname: + pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin") + torch.save(pt_model.state_dict(), pt_checkpoint_path) + tf_model = transformers.load_pytorch_checkpoint_in_tf2_model( + tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys + ) + + tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") + tf_model.save_weights(tf_checkpoint_path) + pt_model = transformers.load_tf2_checkpoint_in_pytorch_model( + pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys + ) + + # Original test: check without `labels` + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) + @require_tf class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase): @@ -524,6 +583,62 @@ def test_keras_fit(self): # TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC pass + @is_pt_tf_cross_test + def test_pt_tf_model_equivalence(self, allow_missing_keys=False): + # We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with + # the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT + import torch + + import transformers + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + # Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency + # of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`. + # TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it. + self._make_attention_mask_non_null(inputs_dict) + + pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning + pt_model_class = getattr(transformers, pt_model_class_name) + + tf_model = model_class(config) + pt_model = pt_model_class(config) + + tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + + # Check we can load pt model in tf and vice-versa with model => model functions + tf_model = transformers.load_pytorch_model_in_tf2_model( + tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys + ) + pt_model = transformers.load_tf2_model_in_pytorch_model( + pt_model, tf_model, allow_missing_keys=allow_missing_keys + ) + + # Original test: check without `labels` + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) + + # Check we can load pt model in tf and vice-versa with checkpoint => model functions + with tempfile.TemporaryDirectory() as tmpdirname: + pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin") + torch.save(pt_model.state_dict(), pt_checkpoint_path) + tf_model = transformers.load_pytorch_checkpoint_in_tf2_model( + tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys + ) + + tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") + tf_model.save_weights(tf_checkpoint_path) + pt_model = transformers.load_tf2_checkpoint_in_pytorch_model( + pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys + ) + + # Original test: check without `labels` + self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict) + @require_tf class TFWav2Vec2UtilsTest(unittest.TestCase): From 2120e9dd9f47851c435e56d87aa127aa5e329ee2 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 4 Apr 2023 18:43:28 +0100 Subject: [PATCH 8/8] Fix - don't try to access the hidden states property when output is a tuple --- src/transformers/models/gpt2/modeling_tf_gpt2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/gpt2/modeling_tf_gpt2.py b/src/transformers/models/gpt2/modeling_tf_gpt2.py index d168ec3593df..a84fdbd80664 100644 --- a/src/transformers/models/gpt2/modeling_tf_gpt2.py +++ b/src/transformers/models/gpt2/modeling_tf_gpt2.py @@ -1051,7 +1051,7 @@ def call( ) hidden_states = transformer_outputs[0] hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:]) - if output_hidden_states: + if return_dict and output_hidden_states: # We do this to match the slightly odd PT behaviour - the final hidden state is reshaped to rank 4 when the # input is rank 3, but all other hidden states remain at rank-3 (with the first 2 dims merged) all_hidden_states = transformer_outputs.hidden_states[:-1] + (hidden_states,)