diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index c2be91c7a000..1c59493e1bf7 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -569,13 +569,12 @@ def call( "output_hidden_states": output_hidden_states, "return_dict": return_dict, "training": training, - "kwargs_call": kwargs_encoder, + "kwargs_call": {}, } # Add arguments to encoder from `kwargs_encoder` for k, v in kwargs_encoder.items(): encoder_processing_inputs[k] = v - kwargs_encoder = {} encoder_inputs = input_processing(**encoder_processing_inputs) @@ -622,13 +621,12 @@ def call( "past_key_values": past_key_values, "return_dict": return_dict, "training": training, - "kwargs_call": kwargs_decoder, + "kwargs_call": {}, } # Add arguments to decoder from `kwargs_decoder` for k, v in kwargs_decoder.items(): decoder_processing_inputs[k] = v - kwargs_decoder = {} decoder_inputs = input_processing(**decoder_processing_inputs) decoder_outputs = self.decoder(**decoder_inputs) diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index 965fc51d783b..eeaca58c5a01 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -593,12 +593,11 @@ def call( "output_hidden_states": output_hidden_states, "return_dict": return_dict, "training": training, - "kwargs_call": kwargs_encoder, + "kwargs_call": {}, } # Add arguments to encoder from `kwargs_encoder` encoder_processing_inputs.update(kwargs_encoder) - kwargs_encoder = {} encoder_inputs = input_processing(**encoder_processing_inputs) @@ -654,12 +653,11 @@ def call( "past_key_values": past_key_values, "return_dict": return_dict, "training": training, - "kwargs_call": kwargs_decoder, + "kwargs_call": {}, } # Add arguments to decoder from `kwargs_decoder` decoder_processing_inputs.update(kwargs_decoder) - kwargs_decoder = {} decoder_inputs = input_processing(**decoder_processing_inputs) decoder_outputs = self.decoder(**decoder_inputs) diff --git a/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py b/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py index edcc881f564a..de903c40c26e 100644 --- a/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py +++ b/tests/encoder_decoder/test_modeling_tf_encoder_decoder.py @@ -91,6 +91,7 @@ def check_encoder_decoder_model_from_pretrained_configs( decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( @@ -122,6 +123,7 @@ def check_encoder_decoder_model( decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) @@ -137,6 +139,7 @@ def check_encoder_decoder_model( decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( @@ -167,6 +170,7 @@ def check_encoder_decoder_model_from_pretrained( attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, return_dict=True, + kwargs=kwargs, ) self.assertEqual( @@ -195,6 +199,7 @@ def check_save_and_load( decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) out_2 = np.array(outputs[0]) out_2[np.isnan(out_2)] = 0 @@ -208,6 +213,7 @@ def check_save_and_load( decoder_input_ids=decoder_input_ids, attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) out_1 = np.array(after_outputs[0]) out_1[np.isnan(out_1)] = 0 @@ -235,6 +241,7 @@ def check_encoder_decoder_model_labels( attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, labels=labels, + kwargs=kwargs, ) # Make sure `loss` exist @@ -269,6 +276,7 @@ def check_encoder_decoder_model_output_attentions( attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, output_attentions=True, + kwargs=kwargs, ) encoder_attentions = outputs_encoder_decoder["encoder_attentions"] diff --git a/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py b/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py index a0fcbfaea325..f3a062744f5c 100644 --- a/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py +++ b/tests/vision_encoder_decoder/test_modeling_tf_vision_encoder_decoder.py @@ -96,6 +96,7 @@ def check_encoder_decoder_model_from_pretrained_configs( pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( @@ -124,6 +125,7 @@ def check_encoder_decoder_model( pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) @@ -137,6 +139,7 @@ def check_encoder_decoder_model( encoder_outputs=encoder_outputs, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) self.assertEqual( @@ -164,6 +167,7 @@ def check_encoder_decoder_model_from_pretrained( decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, return_dict=True, + kwargs=kwargs, ) self.assertEqual( @@ -189,6 +193,7 @@ def check_save_and_load( pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) out_2 = np.array(outputs[0]) out_2[np.isnan(out_2)] = 0 @@ -201,6 +206,7 @@ def check_save_and_load( pixel_values=pixel_values, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, + kwargs=kwargs, ) out_1 = np.array(after_outputs[0]) out_1[np.isnan(out_1)] = 0 @@ -226,6 +232,7 @@ def check_encoder_decoder_model_labels( decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, labels=labels, + kwargs=kwargs, ) # Make sure `loss` exist @@ -257,6 +264,7 @@ def check_encoder_decoder_model_output_attentions( decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask, output_attentions=True, + kwargs=kwargs, ) encoder_attentions = outputs_encoder_decoder["encoder_attentions"]