diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index 938b1df05e18..508e020e5321 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -140,7 +140,7 @@ def __init__( # The model code was relying on saved configs where `tie_word_embeddings` is # set to `False` in 1.1v and using it as indicator of whether to scale or not # But in fact we tie weights always and force it to be `True` - self.scale_decoder_outputs = kwargs.get("tie_word_embeddings") is not False + self.scale_decoder_outputs = kwargs.get("tie_word_embeddings", True) is True self.tie_word_embeddings = True super().__init__(is_encoder_decoder=is_encoder_decoder, **kwargs) diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py index a8dbf5d8b031..1e5b1b995396 100644 --- a/tests/models/mt5/test_modeling_mt5.py +++ b/tests/models/mt5/test_modeling_mt5.py @@ -46,7 +46,6 @@ ) -# Copied from tests.models.t5.test_modeling_t5.T5ModelTester with T5->MT5 class MT5ModelTester: def __init__( self, @@ -421,20 +420,6 @@ def create_and_check_model_fp16_forward( output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] self.parent.assertFalse(torch.isnan(output).any().item()) - def check_resize_embeddings_t5_v1_1( - self, - config, - ): - prev_vocab_size = config.vocab_size - - config.tie_word_embeddings = False - model = MT5ForConditionalGeneration(config=config).to(torch_device).eval() - model.resize_token_embeddings(prev_vocab_size - 10) - - self.parent.assertEqual(model.get_input_embeddings().weight.shape[0], prev_vocab_size - 10) - self.parent.assertEqual(model.get_output_embeddings().weight.shape[0], prev_vocab_size - 10) - self.parent.assertEqual(model.config.vocab_size, prev_vocab_size - 10) - def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -456,7 +441,6 @@ def 
prepare_config_and_inputs_for_common(self): @require_torch -# Copied from tests.models.t5.test_modeling_t5.T5ModelTest with T5->MT5, google-t5/t5-small->google/mt5-small class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = ( (MT5Model, MT5ForConditionalGeneration, MT5ForSequenceClassification, MT5ForQuestionAnswering) @@ -549,16 +533,10 @@ def test_shift_right(self): def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() + if config_and_inputs[0].__class__.__name__ == "T5Config": + self.assertTrue(config_and_inputs[0].scale_decoder_outputs) self.model_tester.create_and_check_model(*config_and_inputs) - def test_model_v1_1(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - # check that gated gelu feed forward and different word embeddings work - config = config_and_inputs[0] - config.tie_word_embeddings = False - config.feed_forward_proj = "gated-gelu" - self.model_tester.create_and_check_model(config, *config_and_inputs[1:]) - # MT5ForSequenceClassification does not support inputs_embeds def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -652,10 +630,6 @@ def test_model_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) - def test_v1_1_resize_embeddings(self): - config = self.model_tester.prepare_config_and_inputs()[0] - self.model_tester.check_resize_embeddings_t5_v1_1(config) - @slow def test_model_from_pretrained(self): model_name = "google/mt5-small" diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index 0b9be46ec1be..a29845301e02 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -166,6 +166,28 @@ def get_config(self): decoder_start_token_id=self.decoder_start_token_id, ) + def 
get_config_v1_1(self): + return T5Config( + vocab_size=self.vocab_size, + d_model=self.hidden_size, + d_ff=self.d_ff, + d_kv=self.hidden_size // self.num_attention_heads, + num_layers=self.num_hidden_layers, + num_decoder_layers=self.decoder_layers, + num_heads=self.num_attention_heads, + relative_attention_num_buckets=self.relative_attention_num_buckets, + dropout_rate=self.dropout_rate, + initializer_factor=self.initializer_factor, + eos_token_id=self.eos_token_id, + bos_token_id=self.pad_token_id, + pad_token_id=self.pad_token_id, + decoder_start_token_id=self.decoder_start_token_id, + # V1.1 related params: uses gated-gelu and `tie_word_embeddings=False` as an + # indicator to not scale decoder outputs + feed_forward_proj="gated-gelu", + tie_word_embeddings=False, + ) + def check_prepare_lm_labels_via_shift_left( self, config, @@ -436,6 +458,8 @@ def check_resize_embeddings_t5_v1_1( ): prev_vocab_size = config.vocab_size + # V1.1 related params: uses gated-gelu and `tie_word_embeddings=False` + config.feed_forward_proj = "gated-gelu" config.tie_word_embeddings = False model = T5ForConditionalGeneration(config=config).to(torch_device).eval() model.resize_token_embeddings(prev_vocab_size - 10) @@ -557,15 +581,17 @@ def test_shift_right(self): def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.assertTrue(config_and_inputs[0].scale_decoder_outputs) self.model_tester.create_and_check_model(*config_and_inputs) def test_model_v1_1(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() - # check that gated gelu feed forward and different word embeddings work - config = config_and_inputs[0] - config.tie_word_embeddings = False - config.feed_forward_proj = "gated-gelu" - self.model_tester.create_and_check_model(config, *config_and_inputs[1:]) + config_v1 = self.model_tester.get_config_v1_1() + config_and_inputs = list(config_and_inputs) + config_and_inputs[0] = config_v1 + + 
self.assertFalse(config_and_inputs[0].scale_decoder_outputs) + self.model_tester.create_and_check_model(*config_and_inputs) # T5ForSequenceClassification does not support inputs_embeds def test_inputs_embeds(self):