diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py
index 38a31e0ca8bd..f273148ac914 100644
--- a/src/transformers/models/led/modeling_tf_led.py
+++ b/src/transformers/models/led/modeling_tf_led.py
@@ -29,7 +29,7 @@
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
-from ...modeling_tf_outputs import TFBaseModelOutputWithPast
+from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions
 
 # Public API
 from ...modeling_tf_utils import (
@@ -1220,7 +1220,7 @@ def call(
         encoder_layer_head_mask: Optional[tf.Tensor] = None,
         past_key_value: Optional[Tuple[tf.Tensor]] = None,
         training=False,
-    ) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
+    ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
         """
         Args:
             hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
@@ -1254,12 +1254,13 @@ def call(
 
         # Cross-Attention Block
         cross_attn_present_key_value = None
+        cross_attn_weights = None
         if encoder_hidden_states is not None:
             residual = hidden_states
 
             # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
             cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
-            hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
                 attention_mask=encoder_attention_mask,
@@ -1285,6 +1286,7 @@ def call(
 
         return (
             hidden_states,
             self_attn_weights,
+            cross_attn_weights,
             present_key_value,
         )
@@ -1808,6 +1810,14 @@ def call(
 
             # unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
             hidden_states = self.compute_hidden_states(hidden_states, padding_len)
 
+            # undo padding
+            if inputs["output_attentions"]:
+                all_attentions = (
+                    tuple([state[:, :, :-padding_len, :] for state in all_attentions])
+                    if padding_len > 0
+                    else all_attentions
+                )
+
             if inputs["output_hidden_states"]:
                 encoder_states = encoder_states + (hidden_states,)
@@ -2038,6 +2048,7 @@ def call(
         # decoder layers
         all_hidden_states = ()
         all_self_attns = ()
+        all_cross_attentions = ()
         present_key_values = ()
 
         # check if head_mask has a correct number of layers specified if desired
@@ -2059,7 +2070,7 @@ def call(
 
             past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None
 
-            hidden_states, layer_self_attn, present_key_value = decoder_layer(
+            hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
                 hidden_states,
                 attention_mask=combined_attention_mask,
                 encoder_hidden_states=inputs["encoder_hidden_states"],
@@ -2076,24 +2087,31 @@ def call(
 
             if inputs["output_attentions"]:
                 all_self_attns += (layer_self_attn,)
+                all_cross_attentions += (layer_cross_attn,)
 
         if inputs["output_hidden_states"]:
             all_hidden_states += (hidden_states,)
         else:
             all_hidden_states = None
 
-        all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
+        all_self_attns = all_self_attns if inputs["output_attentions"] else None
+        all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None
 
         present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None
 
         if not inputs["return_dict"]:
-            return hidden_states, present_key_values, all_hidden_states, all_self_attns
+            return tuple(
+                v
+                for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
+                if v is not None
+            )
         else:
-            return TFBaseModelOutputWithPast(
+            return TFBaseModelOutputWithPastAndCrossAttentions(
                 last_hidden_state=hidden_states,
                 past_key_values=present_key_values,
                 hidden_states=all_hidden_states,
                 attentions=all_self_attns,
+                cross_attentions=all_cross_attentions,
             )
@@ -2223,6 +2241,7 @@ def call(
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
             encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
             encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
             encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -2475,6 +2494,7 @@ def call(
             past_key_values=outputs.past_key_values,  # index 1 of d outputs
             decoder_hidden_states=outputs.decoder_hidden_states,  # index 2 of d outputs
             decoder_attentions=outputs.decoder_attentions,  # index 3 of d outputs
+            cross_attentions=outputs.cross_attentions,
             encoder_last_hidden_state=outputs.encoder_last_hidden_state,  # index 0 of encoder outputs
             encoder_hidden_states=outputs.encoder_hidden_states,  # 1 of e out
             encoder_attentions=outputs.encoder_attentions,  # 2 of e out
diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py
index c626662fc155..372f6cf132d7 100644
--- a/src/transformers/models/t5/modeling_tf_t5.py
+++ b/src/transformers/models/t5/modeling_tf_t5.py
@@ -33,7 +33,7 @@
 )
 from ...modeling_tf_outputs import (
     TFBaseModelOutput,
-    TFBaseModelOutputWithPast,
+    TFBaseModelOutputWithPastAndCrossAttentions,
     TFSeq2SeqLMOutput,
     TFSeq2SeqModelOutput,
 )
@@ -771,6 +771,7 @@ def call(
         present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
         all_hidden_states = () if inputs["output_hidden_states"] else None
         all_attentions = () if inputs["output_attentions"] else None
+        all_cross_attentions = () if (inputs["output_attentions"] and self.is_decoder) else None
         position_bias = None
         encoder_decoder_position_bias = None
@@ -814,6 +815,8 @@ def call(
 
             if inputs["output_attentions"]:
                 all_attentions = all_attentions + (layer_outputs[3],)
+                if self.is_decoder:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
 
         hidden_states = self.final_layer_norm(hidden_states)
         hidden_states = self.dropout(hidden_states, training=inputs["training"])
@@ -831,14 +834,17 @@ def call(
                 outputs = outputs + (all_hidden_states,)
             if inputs["output_attentions"]:
                 outputs = outputs + (all_attentions,)
-            return outputs  # last-layer hidden state, (all hidden states), (all attentions)
+                if self.is_decoder:
+                    outputs = outputs + (all_cross_attentions,)
+            return outputs  # last-layer hidden state, (past_key_values), (all hidden states), (all attentions), (all_cross_attentions)
 
         if self.is_decoder:
-            return TFBaseModelOutputWithPast(
+            return TFBaseModelOutputWithPastAndCrossAttentions(
                 last_hidden_state=hidden_states,
                 past_key_values=present_key_value_states,
                 hidden_states=all_hidden_states,
                 attentions=all_attentions,
+                cross_attentions=all_cross_attentions,
             )
         else:
             return TFBaseModelOutput(
@@ -1264,6 +1270,7 @@ def call(
             past_key_values=past,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
             encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
             encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
             encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -1508,6 +1515,7 @@ def call(
             past_key_values=past,
             decoder_hidden_states=decoder_outputs.hidden_states,
             decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
             encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
             encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
             encoder_attentions=inputs["encoder_outputs"].attentions,
diff --git a/tests/test_modeling_tf_led.py b/tests/test_modeling_tf_led.py
index b42e8b538cdc..a77f17f87200 100644
--- a/tests/test_modeling_tf_led.py
+++ b/tests/test_modeling_tf_led.py
@@ -322,7 +322,7 @@ def check_encoder_attentions_output(outputs):
             self.assertEqual(len(global_attentions), self.model_tester.num_hidden_layers)
             self.assertListEqual(
                 list(attentions[0].shape[-3:]),
-                [self.model_tester.num_attention_heads, encoder_seq_length, seq_length],
+                [self.model_tester.num_attention_heads, seq_length, seq_length],
             )
             self.assertListEqual(
                 list(global_attentions[0].shape[-3:]),
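
Not part of the diff, but for reviewers: a minimal sketch of how the new `cross_attentions` output could be inspected once this change is applied. The `t5-small` checkpoint and the input sentences are placeholders; the shape comment assumes the standard attention-weight layout of `(batch_size, num_heads, target_seq_len, source_seq_len)`.

```python
# Hypothetical usage sketch for the `cross_attentions` output added above.
# Assumes this diff is applied; checkpoint and inputs are placeholders.
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

enc = tokenizer("translate English to German: How are you?", return_tensors="tf")
dec = tokenizer("Wie geht es dir?", return_tensors="tf")

outputs = model(
    input_ids=enc["input_ids"],
    decoder_input_ids=dec["input_ids"],
    output_attentions=True,
    return_dict=True,
)

# One tensor per decoder layer; each should have shape
# (batch_size, num_heads, decoder_seq_len, encoder_seq_len).
print(len(outputs.cross_attentions))
print(outputs.cross_attentions[0].shape)
```

An analogous call against the LED model should surface the same field, per the `modeling_tf_led.py` changes above.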