diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index f273148ac914..f11ef017003a 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -2338,6 +2338,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None @@ -2347,6 +2348,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, @@ -2494,7 +2496,7 @@ def call( past_key_values=outputs.past_key_values, # index 1 of d outputs decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs - cross_attentions=outputs.cross_attentions, + cross_attentions=outputs.cross_attentions, # index 4 of d outputs encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out encoder_attentions=outputs.encoder_attentions, # 2 of e out @@ -2505,6 +2507,7 @@ def serving_output(self, output): pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None enc_g_attns = tf.convert_to_tensor(output.encoder_global_attentions) if self.config.output_attentions else None @@ -2514,6 +2517,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 372f6cf132d7..af09161a7a91 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -1280,6 +1280,7 @@ def serving_output(self, output): pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1289,6 +1290,7 @@ def serving_output(self, output): decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, + cross_attentions=cross_attns, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, ) @@ -1525,6 +1527,7 @@ def serving_output(self, output): pkv = tf.convert_to_tensor(output.past_key_values[1:]) if self.config.use_cache else None dec_hs = tf.convert_to_tensor(output.decoder_hidden_states) if self.config.output_hidden_states else None dec_attns = tf.convert_to_tensor(output.decoder_attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if self.config.output_attentions else None enc_hs = tf.convert_to_tensor(output.encoder_hidden_states) if self.config.output_hidden_states else None enc_attns = tf.convert_to_tensor(output.encoder_attentions) if self.config.output_attentions else None @@ -1533,6 +1536,7 @@ def serving_output(self, output): past_key_values=pkv, decoder_hidden_states=dec_hs, decoder_attentions=dec_attns, + cross_attentions=cross_attns, encoder_last_hidden_state=output.encoder_last_hidden_state, encoder_hidden_states=enc_hs, encoder_attentions=enc_attns, diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py index a493ee1ebf83..e7469fd89611 100644 --- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py @@ -1514,7 +1514,12 @@ def serving_output(self, output): hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None - return TFBaseModelOutput(last_hidden_state=output.last_hidden_state, hidden_states=hs, attentions=attns) + return TFWav2Vec2BaseModelOutput( + last_hidden_state=output.last_hidden_state, + extract_features=output.extract_features, + hidden_states=hs, + attentions=attns, + ) @add_start_docstrings(