diff --git a/tests/models/t5/test_modeling_tf_t5.py b/tests/models/t5/test_modeling_tf_t5.py index 91bc63feda1a..5ad746e34fc8 100644 --- a/tests/models/t5/test_modeling_tf_t5.py +++ b/tests/models/t5/test_modeling_tf_t5.py @@ -295,6 +295,13 @@ def test_t5_decoder_model_past_with_attn_mask(self): def test_t5_decoder_model_past_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() + + # `create_and_check_t5_decoder_model_past_large_inputs` has special inputs: + # (config, input_ids, decoder_input_ids, attention_mask) + # and we have to prepare it correctly here. + config, input_ids, input_mask, token_labels = config_and_inputs + config_and_inputs = (config, input_ids, None, input_mask) + self.model_tester.create_and_check_t5_decoder_model_past_large_inputs(*config_and_inputs) def test_t5_model_xla_generate_fast(self):