diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index 4458b9c532e6..07d8e812f257 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -647,19 +647,17 @@ def call( # The starting index of the remaining elements in `decoder_outputs` start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) - past = (encoder_outputs[0], past_key_values) if past_key_values else None - if not decoder_inputs["return_dict"]: if not isinstance(encoder_outputs, tuple): encoder_outputs = encoder_outputs.to_tuple() - output = (loss, logits, past) + decoder_outputs[start_index:] + encoder_outputs + output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs output = tuple([x for x in output if x is not None]) return output return TFSeq2SeqLMOutput( loss=loss, logits=decoder_outputs.logits, - past_key_values=past, + past_key_values=past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index 0f63e343165d..1d63640af039 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -678,19 +678,17 @@ def call( # The starting index of the remaining elements in `decoder_outputs` start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) - past = (encoder_outputs[0], past_key_values) if past_key_values else None - if not decoder_inputs["return_dict"]: if not isinstance(encoder_outputs, tuple): encoder_outputs = encoder_outputs.to_tuple() - output = (loss, logits, past) + decoder_outputs[start_index:] + encoder_outputs + output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs output = tuple([x for x in output if x is not None]) return output return TFSeq2SeqLMOutput( loss=loss, logits=decoder_outputs.logits, - past_key_values=past, + past_key_values=past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions,