Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -569,13 +569,12 @@ def call(
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_encoder,
"kwargs_call": {},
}

# Add arguments to encoder from `kwargs_encoder`
for k, v in kwargs_encoder.items():
encoder_processing_inputs[k] = v
Comment on lines 575 to 577
Copy link
Contributor Author

@gante gante Mar 28, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can see here that all the contents of `kwargs_encoder` are copied into `encoder_processing_inputs`, so they were effectively being passed twice.

`input_processing` expects `kwargs_call` to be empty, except under very special circumstances (when deprecated arguments are used).

kwargs_encoder = {}

encoder_inputs = input_processing(**encoder_processing_inputs)

Expand Down Expand Up @@ -622,13 +621,12 @@ def call(
"past_key_values": past_key_values,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_decoder,
"kwargs_call": {},
}

# Add arguments to decoder from `kwargs_decoder`
for k, v in kwargs_decoder.items():
decoder_processing_inputs[k] = v
kwargs_decoder = {}

decoder_inputs = input_processing(**decoder_processing_inputs)
decoder_outputs = self.decoder(**decoder_inputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,12 +593,11 @@ def call(
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_encoder,
"kwargs_call": {},
}

# Add arguments to encoder from `kwargs_encoder`
encoder_processing_inputs.update(kwargs_encoder)
kwargs_encoder = {}

encoder_inputs = input_processing(**encoder_processing_inputs)

Expand Down Expand Up @@ -654,12 +653,11 @@ def call(
"past_key_values": past_key_values,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_decoder,
"kwargs_call": {},
}

# Add arguments to decoder from `kwargs_decoder`
decoder_processing_inputs.update(kwargs_decoder)
kwargs_decoder = {}

decoder_inputs = input_processing(**decoder_processing_inputs)
decoder_outputs = self.decoder(**decoder_inputs)
Expand Down
8 changes: 8 additions & 0 deletions tests/encoder_decoder/test_modeling_tf_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def check_encoder_decoder_model_from_pretrained_configs(
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)

self.assertEqual(
Expand Down Expand Up @@ -122,6 +123,7 @@ def check_encoder_decoder_model(
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
Expand All @@ -137,6 +139,7 @@ def check_encoder_decoder_model(
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)

self.assertEqual(
Expand Down Expand Up @@ -167,6 +170,7 @@ def check_encoder_decoder_model_from_pretrained(
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
kwargs=kwargs,
)

self.assertEqual(
Expand Down Expand Up @@ -195,6 +199,7 @@ def check_save_and_load(
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
Expand All @@ -208,6 +213,7 @@ def check_save_and_load(
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
Expand Down Expand Up @@ -235,6 +241,7 @@ def check_encoder_decoder_model_labels(
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
labels=labels,
kwargs=kwargs,
)

# Make sure `loss` exist
Expand Down Expand Up @@ -269,6 +276,7 @@ def check_encoder_decoder_model_output_attentions(
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
kwargs=kwargs,
)

encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def check_encoder_decoder_model_from_pretrained_configs(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)

self.assertEqual(
Expand Down Expand Up @@ -124,6 +125,7 @@ def check_encoder_decoder_model(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
Expand All @@ -137,6 +139,7 @@ def check_encoder_decoder_model(
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)

self.assertEqual(
Expand Down Expand Up @@ -164,6 +167,7 @@ def check_encoder_decoder_model_from_pretrained(
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
kwargs=kwargs,
)

self.assertEqual(
Expand All @@ -189,6 +193,7 @@ def check_save_and_load(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
Expand All @@ -201,6 +206,7 @@ def check_save_and_load(
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
Expand All @@ -226,6 +232,7 @@ def check_encoder_decoder_model_labels(
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
labels=labels,
kwargs=kwargs,
)

# Make sure `loss` exist
Expand Down Expand Up @@ -257,6 +264,7 @@ def check_encoder_decoder_model_output_attentions(
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
kwargs=kwargs,
)

encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
Expand Down