34 changes: 27 additions & 7 deletions src/transformers/models/led/modeling_tf_led.py
@@ -29,7 +29,7 @@
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
-from ...modeling_tf_outputs import TFBaseModelOutputWithPast
+from ...modeling_tf_outputs import TFBaseModelOutputWithPastAndCrossAttentions

# Public API
from ...modeling_tf_utils import (
@@ -1220,7 +1220,7 @@ def call(
encoder_layer_head_mask: Optional[tf.Tensor] = None,
past_key_value: Optional[Tuple[tf.Tensor]] = None,
training=False,
-) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
+) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
"""
Args:
hidden_states (`tf.Tensor`): input to the layer of shape *(seq_len, batch, embed_dim)*
@@ -1254,12 +1254,13 @@ def call(

# Cross-Attention Block
cross_attn_present_key_value = None
+cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states

# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
-hidden_states, _, cross_attn_present_key_value = self.encoder_attn(
+hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
@@ -1285,6 +1286,7 @@ def call(
return (
hidden_states,
self_attn_weights,
+cross_attn_weights,
present_key_value,
)

@@ -1808,6 +1810,14 @@ def call(
# unpad `hidden_states` because the calling function is expecting a length == input_ids.size(1)
hidden_states = self.compute_hidden_states(hidden_states, padding_len)

+# undo padding
+if inputs["output_attentions"]:
+all_attentions = (
+tuple([state[:, :, :-padding_len, :] for state in all_attentions])
+if padding_len > 0
+else all_attentions
+)

if inputs["output_hidden_states"]:
encoder_states = encoder_states + (hidden_states,)

@@ -2038,6 +2048,7 @@ def call(
# decoder layers
all_hidden_states = ()
all_self_attns = ()
+all_cross_attentions = ()
present_key_values = ()

# check if head_mask has a correct number of layers specified if desired
@@ -2059,7 +2070,7 @@

past_key_value = inputs["past_key_values"][idx] if inputs["past_key_values"] is not None else None

-hidden_states, layer_self_attn, present_key_value = decoder_layer(
+hidden_states, layer_self_attn, layer_cross_attn, present_key_value = decoder_layer(
hidden_states,
attention_mask=combined_attention_mask,
encoder_hidden_states=inputs["encoder_hidden_states"],
@@ -2076,24 +2087,31 @@ def call(

if inputs["output_attentions"]:
all_self_attns += (layer_self_attn,)
+all_cross_attentions += (layer_cross_attn,)

if inputs["output_hidden_states"]:
all_hidden_states += (hidden_states,)
else:
all_hidden_states = None

-all_self_attns = list(all_self_attns) if inputs["output_attentions"] else None
+all_self_attns = all_self_attns if inputs["output_attentions"] else None
+all_cross_attentions = all_cross_attentions if inputs["output_attentions"] else None

present_key_values = (encoder_hidden_states, present_key_values) if inputs["use_cache"] else None

if not inputs["return_dict"]:
-return hidden_states, present_key_values, all_hidden_states, all_self_attns
+return tuple(
+v
+for v in [hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attentions]
+if v is not None
+)
else:
-return TFBaseModelOutputWithPast(
+return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attns,
+cross_attentions=all_cross_attentions,
)


@@ -2223,6 +2241,7 @@ def call(
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
+cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -2475,6 +2494,7 @@ def call(
past_key_values=outputs.past_key_values, # index 1 of d outputs
decoder_hidden_states=outputs.decoder_hidden_states, # index 2 of d outputs
decoder_attentions=outputs.decoder_attentions, # index 3 of d outputs
+cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state, # index 0 of encoder outputs
encoder_hidden_states=outputs.encoder_hidden_states, # 1 of e out
encoder_attentions=outputs.encoder_attentions, # 2 of e out
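With these changes the TF LED decoder returns its cross-attention weights next to the self-attention weights, and they are exposed as `cross_attentions` on the model outputs. A minimal usage sketch (not part of the PR; the checkpoint name and input text are only illustrative):

```python
# Reading the new cross_attentions field from TF LED.
from transformers import LEDTokenizer, TFLEDModel

tokenizer = LEDTokenizer.from_pretrained("allenai/led-base-16384")
model = TFLEDModel.from_pretrained("allenai/led-base-16384")

inputs = tokenizer("A short test document.", return_tensors="tf")
outputs = model(
    inputs["input_ids"],
    decoder_input_ids=inputs["input_ids"],
    output_attentions=True,
    return_dict=True,
)

# one tensor per decoder layer, shaped (batch, num_heads, target_len, source_len)
print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)
```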
14 changes: 11 additions & 3 deletions src/transformers/models/t5/modeling_tf_t5.py
@@ -33,7 +33,7 @@
)
from ...modeling_tf_outputs import (
TFBaseModelOutput,
-TFBaseModelOutputWithPast,
+TFBaseModelOutputWithPastAndCrossAttentions,
TFSeq2SeqLMOutput,
TFSeq2SeqModelOutput,
)
@@ -771,6 +771,7 @@ def call(
present_key_value_states = () if inputs["use_cache"] and self.is_decoder else None
all_hidden_states = () if inputs["output_hidden_states"] else None
all_attentions = () if inputs["output_attentions"] else None
+all_cross_attentions = () if (inputs["output_attentions"] and self.is_decoder) else None
position_bias = None
encoder_decoder_position_bias = None

@@ -814,6 +815,8 @@ def call(

if inputs["output_attentions"]:
all_attentions = all_attentions + (layer_outputs[3],)
+if self.is_decoder:
+all_cross_attentions = all_cross_attentions + (layer_outputs[5],)

hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=inputs["training"])
@@ -831,14 +834,17 @@ def call(
outputs = outputs + (all_hidden_states,)
if inputs["output_attentions"]:
outputs = outputs + (all_attentions,)
-return outputs # last-layer hidden state, (all hidden states), (all attentions)
+if self.is_decoder:
+outputs = outputs + (all_cross_attentions,)
+return outputs # last-layer hidden state, (past_key_values), (all hidden states), (all attentions), (all_cross_attentions)

if self.is_decoder:
-return TFBaseModelOutputWithPast(
+return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=present_key_value_states,
hidden_states=all_hidden_states,
attentions=all_attentions,
+cross_attentions=all_cross_attentions,
)
else:
return TFBaseModelOutput(
@@ -1264,6 +1270,7 @@ def call(
past_key_values=past,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
+cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
@@ -1508,6 +1515,7 @@ def call(
past_key_values=past,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
+cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=inputs["encoder_outputs"].last_hidden_state,
encoder_hidden_states=inputs["encoder_outputs"].hidden_states,
encoder_attentions=inputs["encoder_outputs"].attentions,
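The TF T5 decoder stack now collects the encoder-decoder attention weights of each block in the same way. A minimal sketch for T5 (not part of the PR; checkpoint and inputs are only illustrative):

```python
# Reading cross_attentions from TF T5.
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

enc = tokenizer("translate English to German: How are you?", return_tensors="tf")
dec = tokenizer("Wie geht es dir?", return_tensors="tf")

outputs = model(
    enc["input_ids"],
    decoder_input_ids=dec["input_ids"],
    output_attentions=True,
    return_dict=True,
)

# one tensor per decoder block, shaped (batch, num_heads, target_len, source_len)
print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)
```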
2 changes: 1 addition & 1 deletion tests/test_modeling_tf_led.py
@@ -322,7 +322,7 @@ def check_encoder_attentions_output(outputs):
self.assertEqual(len(global_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(attentions[0].shape[-3:]),
-[self.model_tester.num_attention_heads, encoder_seq_length, seq_length],
+[self.model_tester.num_attention_heads, seq_length, seq_length],
)
self.assertListEqual(
list(global_attentions[0].shape[-3:]),