From cd3a586fb6bc8b41100b051dc9f095d3ec43dbcc Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 18 Mar 2022 17:12:46 +0000 Subject: [PATCH] Update use of past to match new generate code --- .../models/encoder_decoder/modeling_tf_encoder_decoder.py | 6 ++---- .../modeling_tf_vision_encoder_decoder.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index 4458b9c532e6..07d8e812f257 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -647,19 +647,17 @@ def call( # The starting index of the remaining elements in `decoder_outputs` start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) - past = (encoder_outputs[0], past_key_values) if past_key_values else None - if not decoder_inputs["return_dict"]: if not isinstance(encoder_outputs, tuple): encoder_outputs = encoder_outputs.to_tuple() - output = (loss, logits, past) + decoder_outputs[start_index:] + encoder_outputs + output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs output = tuple([x for x in output if x is not None]) return output return TFSeq2SeqLMOutput( loss=loss, logits=decoder_outputs.logits, - past_key_values=past, + past_key_values=past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions, diff --git a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py index 0f63e343165d..1d63640af039 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py @@ -678,19 +678,17 @@ def call( # The starting index of the remaining elements in `decoder_outputs` start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) - past = (encoder_outputs[0], past_key_values) if past_key_values else None - if not decoder_inputs["return_dict"]: if not isinstance(encoder_outputs, tuple): encoder_outputs = encoder_outputs.to_tuple() - output = (loss, logits, past) + decoder_outputs[start_index:] + encoder_outputs + output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs output = tuple([x for x in output if x is not None]) return output return TFSeq2SeqLMOutput( loss=loss, logits=decoder_outputs.logits, - past_key_values=past, + past_key_values=past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, cross_attentions=decoder_outputs.cross_attentions,