diff --git a/docs/source/main_classes/output.rst b/docs/source/main_classes/output.rst index f1e8e01b0da2..5ccd29209094 100644 --- a/docs/source/main_classes/output.rst +++ b/docs/source/main_classes/output.rst @@ -65,12 +65,34 @@ BaseModelOutputWithPooling :members: +BaseModelOutputWithCrossAttentions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithCrossAttentions + :members: + + +BaseModelOutputWithPoolingAndCrossAttentions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions + :members: + + BaseModelOutputWithPast ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPast :members: + +BaseModelOutputWithPastAndCrossAttentions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions + :members: + + Seq2SeqModelOutput ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -85,6 +107,20 @@ CausalLMOutput :members: +CausalLMOutputWithCrossAttentions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithCrossAttentions + :members: + + +CausalLMOutputWithPastAndCrossAttentions +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.modeling_outputs.CausalLMOutputWithPastAndCrossAttentions + :members: + + CausalLMOutputWithPast ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/src/transformers/modeling_bart.py b/src/transformers/modeling_bart.py index 475ba0f386bb..b43d6b4d8aa8 100644 --- a/src/transformers/modeling_bart.py +++ b/src/transformers/modeling_bart.py @@ -35,7 +35,7 @@ ) from .modeling_outputs import ( BaseModelOutput, - BaseModelOutputWithPast, + BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, Seq2SeqQuestionAnsweringModelOutput, @@ -451,11 +451,12 @@ def forward( assert self.encoder_attn.cache_key != self.self_attn.cache_key if self.normalize_before: x = self.encoder_attn_layer_norm(x) - x, _ = self.encoder_attn( + x, cross_attn_weights = self.encoder_attn( query=x, key=encoder_hidden_states, key_padding_mask=encoder_attn_mask, layer_state=layer_state, # mutates layer state + output_attentions=output_attentions, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x @@ -477,7 +478,8 @@ def forward( x, self_attn_weights, layer_state, - ) # just self_attn weights for now, following t5, layer_state = cache for decoding + cross_attn_weights, + ) # layer_state = cache for decoding class BartDecoder(nn.Module): @@ -590,6 +592,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + all_cross_attentions = () if output_attentions else None next_decoder_cache: List[Dict] = [] for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -601,7 +604,7 @@ def forward( layer_state = past_key_values[idx] if past_key_values is not None else None - x, layer_self_attn, layer_past = decoder_layer( + x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer( x, encoder_hidden_states, encoder_attn_mask=encoder_padding_mask, @@ -616,6 +619,7 @@ def forward( if output_attentions: all_self_attns += (layer_self_attn,) + all_cross_attentions += (layer_cross_attn,) if self.layer_norm: # if config.add_final_layer_norm (mBART) x = self.layer_norm(x) @@ -628,9 +632,15 @@ def forward( next_cache = next_decoder_cache if use_cache else None if not return_dict: - return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns + return tuple( + v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=x, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, ) @@ -934,6 +944,7 @@ def forward( past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, @@ -1078,6 +1089,7 @@ def forward( past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, + 
cross_attentions=outputs.cross_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, @@ -1207,6 +1219,7 @@ def forward( past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, @@ -1317,6 +1330,7 @@ def forward( past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 587e33695b6a..35b03b73e47d 100755 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -37,9 +37,9 @@ replace_return_docstrings, ) from .modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPooling, - CausalLMOutput, + BaseModelOutputWithCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, MaskedLMOutput, MultipleChoiceModelOutput, NextSentencePredictorOutput, @@ -449,7 +449,8 @@ def forward( return_dict=False, ): all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -483,15 +484,24 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, ) @@ -752,7 +762,7 @@ class PreTrainedModel @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="bert-base-uncased", - output_type=BaseModelOutputWithPooling, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def forward( @@ -843,11 +853,12 @@ def forward( if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( + return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, 
attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, ) @@ -984,7 +995,7 @@ def get_output_embeddings(self): return self.cls.predictions.decoder @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids=None, @@ -1063,11 +1074,12 @@ def forward( output = (prediction_scores,) + outputs[2:] return ((lm_loss,) + output) if lm_loss is not None else output - return CausalLMOutput( + return CausalLMOutputWithCrossAttentions( loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): diff --git a/src/transformers/modeling_bert_generation.py b/src/transformers/modeling_bert_generation.py index f201c1bd8556..8366f182bd74 100755 --- a/src/transformers/modeling_bert_generation.py +++ b/src/transformers/modeling_bert_generation.py @@ -28,7 +28,7 @@ replace_return_docstrings, ) from .modeling_bert import BertEncoder -from .modeling_outputs import BaseModelOutput, CausalLMOutput +from .modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions from .modeling_utils import PreTrainedModel from .utils import logging @@ -297,7 +297,7 @@ class PreTrainedModel @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder", - output_type=BaseModelOutput, + output_type=BaseModelOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def forward( @@ -381,10 +381,11 @@ def forward( if not return_dict: return (sequence_output,) + encoder_outputs[1:] - return BaseModelOutput( + return BaseModelOutputWithCrossAttentions( last_hidden_state=sequence_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, ) @@ -422,7 +423,7 @@ def get_output_embeddings(self): return self.lm_head.decoder @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids=None, @@ -499,11 +500,12 @@ def forward( output = (prediction_scores,) + outputs[1:] return ((lm_loss,) + output) if lm_loss is not None else output - return CausalLMOutput( + return CausalLMOutputWithCrossAttentions( loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): diff --git a/src/transformers/modeling_electra.py b/src/transformers/modeling_electra.py index 3d5a161460b0..e244ac1c55e3 100644 --- a/src/transformers/modeling_electra.py +++ b/src/transformers/modeling_electra.py @@ -34,7 +34,7 @@ replace_return_docstrings, ) from .modeling_outputs import ( - BaseModelOutput, + BaseModelOutputWithCrossAttentions, MaskedLMOutput, MultipleChoiceModelOutput, QuestionAnsweringModelOutput, @@ -445,7 +445,8 @@ def forward( return_dict=False, ): all_hidden_states = 
() if output_hidden_states else None - all_attentions = () if output_attentions else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -479,15 +480,24 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, ) @@ -697,7 +707,7 @@ class PreTrainedModel @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="google/electra-small-discriminator", - output_type=BaseModelOutput, + output_type=BaseModelOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def forward( diff --git a/src/transformers/modeling_encoder_decoder.py b/src/transformers/modeling_encoder_decoder.py index 50381ed7c6a7..5080b1cea59d 100644 --- a/src/transformers/modeling_encoder_decoder.py +++ b/src/transformers/modeling_encoder_decoder.py @@ -426,6 +426,7 @@ def forward( past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... 
before this works decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, diff --git a/src/transformers/modeling_fsmt.py b/src/transformers/modeling_fsmt.py index ba7f18cbf387..fba900b1370c 100644 --- a/src/transformers/modeling_fsmt.py +++ b/src/transformers/modeling_fsmt.py @@ -46,7 +46,12 @@ add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput +from .modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) from .modeling_utils import PreTrainedModel from .utils import logging @@ -543,11 +548,12 @@ def forward( # Cross attention residual = x assert self.encoder_attn.cache_key != self.self_attn.cache_key - x, _ = self.encoder_attn( + x, cross_attn_weights = self.encoder_attn( query=x, key=encoder_hidden_states, key_padding_mask=encoder_attn_mask, layer_state=layer_state, # mutates layer state + output_attentions=output_attentions, ) x = F.dropout(x, p=self.dropout, training=self.training) x = residual + x @@ -565,7 +571,8 @@ def forward( x, self_attn_weights, layer_state, - ) # just self_attn weights for now, following t5, layer_state = cache for decoding + cross_attn_weights, + ) # layer_state = cache for decoding class FSMTDecoder(nn.Module): @@ -669,6 +676,7 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None + all_cross_attns = () if output_attentions else None next_decoder_cache = [] for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -680,7 +688,7 @@ def forward( layer_state = past_key_values[idx] if past_key_values is not None else None - x, layer_self_attn, layer_past = decoder_layer( + x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer( x, encoder_hidden_states, encoder_attn_mask=encoder_padding_mask, @@ -695,6 +703,7 @@ def forward( if output_attentions: all_self_attns += (layer_self_attn,) + all_cross_attns += (layer_cross_attn,) # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim) if output_hidden_states: @@ -707,9 +716,15 @@ def forward( next_cache = next_decoder_cache if use_cache else None if not return_dict: - return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns + return tuple( + v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attns] if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=x, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attns, ) @@ -903,7 +918,7 @@ def __init__(self, config: FSMTConfig): @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="facebook/wmt19-ru-en", - output_type=BaseModelOutputWithPast, + output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC, ) def forward( @@ -989,6 +1004,7 @@ def forward( past_key_values=decoder_outputs.past_key_values, 
decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, @@ -1101,6 +1117,7 @@ def forward( past_key_values=outputs.past_key_values, decoder_hidden_states=outputs.decoder_hidden_states, decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py index 030ac24edbae..838ea248f172 100644 --- a/src/transformers/modeling_gpt2.py +++ b/src/transformers/modeling_gpt2.py @@ -33,7 +33,11 @@ add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from .modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithPastAndCrossAttentions, + SequenceClassifierOutputWithPast, +) from .modeling_utils import ( Conv1D, PreTrainedModel, @@ -311,14 +315,14 @@ def forward( attn_output = cross_attn_outputs[0] # residual connection hidden_states = hidden_states + attn_output - outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states)) # residual connection hidden_states = hidden_states + feed_forward_hidden_states outputs = [hidden_states] + outputs - return outputs # hidden_states, present, (cross_attentions, attentions) + return outputs # hidden_states, present, (attentions, cross_attentions) class GPT2PreTrainedModel(PreTrainedModel): @@ -506,7 +510,7 @@ def _prune_heads(self, heads_to_prune): @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="gpt2", - output_type=BaseModelOutputWithPast, + output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def forward( @@ -618,7 +622,8 @@ def forward( output_shape = input_shape + (hidden_states.size(-1),) presents = () if use_cache else None - all_attentions = () if output_attentions else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: @@ -659,7 +664,9 @@ def custom_forward(*inputs): presents = presents + (present,) if output_attentions: - all_attentions = all_attentions + (outputs[2],) + all_self_attentions = all_self_attentions + (outputs[2],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3],) hidden_states = self.ln_f(hidden_states) @@ -669,13 +676,14 @@ def custom_forward(*inputs): all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) - return BaseModelOutputWithPast( + return 
BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, - attentions=all_attentions, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, ) @@ -727,7 +735,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="gpt2", - output_type=CausalLMOutputWithPast, + output_type=CausalLMOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def forward( @@ -795,12 +803,13 @@ def forward( output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output - return CausalLMOutputWithPast( + return CausalLMOutputWithPastAndCrossAttentions( loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, ) diff --git a/src/transformers/modeling_layoutlm.py b/src/transformers/modeling_layoutlm.py index 29ff2ce77c0e..24126c0c0074 100644 --- a/src/transformers/modeling_layoutlm.py +++ b/src/transformers/modeling_layoutlm.py @@ -24,7 +24,12 @@ from .activations import ACT2FN from .configuration_layoutlm import LayoutLMConfig from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward -from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, TokenClassifierOutput +from .modeling_outputs import ( + BaseModelOutputWithCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + MaskedLMOutput, + TokenClassifierOutput, +) from .modeling_utils import ( PreTrainedModel, apply_chunking_to_forward, @@ -374,7 +379,8 @@ def forward( return_dict=False, ): all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -408,15 +414,24 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, ) @@ -611,7 +626,7 @@ class PreTrainedModel @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="layoutlm-base-uncased", - output_type=BaseModelOutputWithPooling, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) def forward( @@ -716,11 +731,12 @@ def forward( if 
not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( + return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, ) diff --git a/src/transformers/modeling_outputs.py b/src/transformers/modeling_outputs.py index 26a41fc9d1a5..1519ac9ae810 100644 --- a/src/transformers/modeling_outputs.py +++ b/src/transformers/modeling_outputs.py @@ -99,6 +99,120 @@ class BaseModelOutputWithPast(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class BaseModelOutputWithCrossAttentions(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) further processed by a + Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence + prediction (classification) objective during pretraining. 
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + + Args: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + + If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, + 1, hidden_size)` is output. + past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, + batch_size, num_heads, sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + :obj:`past_key_values` input) to speed up sequential decoding. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + """ + + last_hidden_state: torch.FloatTensor + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + @dataclass class Seq2SeqModelOutput(ModelOutput): """ @@ -128,6 +242,12 @@ class Seq2SeqModelOutput(ModelOutput): Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder of the model. encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -147,6 +267,7 @@ class Seq2SeqModelOutput(ModelOutput): past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -217,6 +338,85 @@ class CausalLMOutputWithPast(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class CausalLMOutputWithCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Language modeling loss (for next-token prediction). + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Cross attentions weights after the attention softmax, used to compute the weighted average in the + cross-attention heads. + """ + + loss: Optional[torch.FloatTensor] + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class CausalLMOutputWithPastAndCrossAttentions(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Language modeling loss (for next-token prediction). + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2, + batch_size, num_heads, sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + :obj:`past_key_values` input) to speed up sequential decoding. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Cross attentions weights after the attention softmax, used to compute the weighted average in the + cross-attention heads. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + + @dataclass class SequenceClassifierOutputWithPast(ModelOutput): """ @@ -309,6 +509,12 @@ class Seq2SeqLMOutput(ModelOutput): Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder of the model. encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -329,6 +535,7 @@ class Seq2SeqLMOutput(ModelOutput): past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -420,6 +627,12 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput): Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder of the model. 
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -440,6 +653,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput): past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -566,6 +780,12 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads, + sequence_length, sequence_length)`. + + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Sequence of hidden-states at the output of the last layer of the encoder of the model. encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): @@ -587,6 +807,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): past_key_values: Optional[List[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None diff --git a/src/transformers/modeling_prophetnet.py b/src/transformers/modeling_prophetnet.py index 0a0c6b1be2fc..96508c667da5 100644 --- a/src/transformers/modeling_prophetnet.py +++ b/src/transformers/modeling_prophetnet.py @@ -16,6 +16,7 @@ import copy import math +import warnings from dataclasses import dataclass from typing import Dict, Optional, Tuple @@ -261,7 +262,7 @@ class ProphetNetSeq2SeqLMOutput(ModelOutput): Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the weighted average in the self-attention heads. - decoder_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_attn_heads, encoder_sequence_length, decoder_sequence_length)`. 
@@ -288,11 +289,19 @@ class ProphetNetSeq2SeqLMOutput(ModelOutput): decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None - decoder_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + @property + def decoder_cross_attentions(self): + warnings.warn( + "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.", + FutureWarning, + ) + return self.cross_attentions + @dataclass class ProphetNetSeq2SeqModelOutput(ModelOutput): @@ -337,7 +346,7 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput): Attentions weights of the predict stream of the decoder, after the attention softmax, used to compute the weighted average in the - decoder_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_attn_heads, encoder_sequence_length, decoder_sequence_length)`. @@ -365,11 +374,19 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput): decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None - decoder_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None + @property + def decoder_cross_attentions(self): + warnings.warn( + "`decoder_cross_attentions` is deprecated and will be removed soon. 
Please use `cross_attentions` instead.", + FutureWarning, + ) + return self.cross_attentions + @dataclass class ProphetNetDecoderModelOutput(ModelOutput): @@ -1651,7 +1668,7 @@ def forward( decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram, decoder_attentions=decoder_outputs.attentions, decoder_ngram_attentions=decoder_outputs.ngram_attentions, - decoder_cross_attentions=decoder_outputs.cross_attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, @@ -1766,7 +1783,7 @@ def forward( decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states, decoder_attentions=outputs.decoder_attentions, decoder_ngram_attentions=outputs.decoder_ngram_attentions, - decoder_cross_attentions=outputs.decoder_cross_attentions, + cross_attentions=outputs.cross_attentions, encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_hidden_states=outputs.encoder_hidden_states, encoder_attentions=outputs.encoder_attentions, @@ -1986,6 +2003,7 @@ def forward( hidden_states_ngram=outputs.hidden_states_ngram, attentions=outputs.attentions, ngram_attentions=outputs.ngram_attentions, + cross_attentions=outputs.cross_attentions, ) def _compute_loss(self, logits, labels): diff --git a/src/transformers/modeling_roberta.py b/src/transformers/modeling_roberta.py index 1f676b9fefca..3bb3a79a2326 100644 --- a/src/transformers/modeling_roberta.py +++ b/src/transformers/modeling_roberta.py @@ -31,9 +31,9 @@ replace_return_docstrings, ) from .modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPooling, - CausalLMOutput, + BaseModelOutputWithCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, MaskedLMOutput, MultipleChoiceModelOutput, QuestionAnsweringModelOutput, @@ -393,7 +393,8 @@ def forward( return_dict=False, ): all_hidden_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -427,15 +428,24 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, ) @@ -599,7 +609,7 @@ class PreTrainedModel @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="roberta-base", - output_type=BaseModelOutputWithPooling, + 
output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) # Copied from transformers.modeling_bert.BertModel.forward @@ -689,11 +699,12 @@ def forward( if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( + return BaseModelOutputWithPoolingAndCrossAttentions( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, ) @@ -719,7 +730,7 @@ def get_output_embeddings(self): return self.lm_head.decoder @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) - @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids=None, @@ -799,11 +810,12 @@ def forward( output = (prediction_scores,) + outputs[2:] return ((lm_loss,) + output) if lm_loss is not None else output - return CausalLMOutput( + return CausalLMOutputWithCrossAttentions( loss=lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, ) def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index d31524b31a6f..51ea2560b393 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -33,7 +33,12 @@ add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput +from .modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from .utils import logging @@ -503,6 +508,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + return_dict=False, ): if past_key_value is not None: @@ -533,7 +539,8 @@ def forward( hidden_states, present_key_value_state = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights - if self.is_decoder and encoder_hidden_states is not None: + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: # the actual query length is unknown for cross attention # if using past key value states.
Need to inject it here if present_key_value_state is not None: @@ -564,7 +571,6 @@ def forward( hidden_states = self.layer[-1](hidden_states) outputs = (hidden_states,) - # Add attentions if we output them outputs = outputs + (present_key_value_state,) + attention_outputs return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) @@ -743,6 +749,7 @@ def forward( present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None position_bias = None encoder_decoder_position_bias = None @@ -779,7 +786,9 @@ def forward( present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: - all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now + all_attentions = all_attentions + (layer_outputs[2],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[4 if i == 0 else 3],) hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.dropout(hidden_states) @@ -791,14 +800,21 @@ def forward( if not return_dict: return tuple( v - for v in [hidden_states, present_key_value_states, all_hidden_states, all_attentions] + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] if v is not None ) - return BaseModelOutputWithPast( + return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, past_key_values=present_key_value_states, hidden_states=all_hidden_states, attentions=all_attentions, + cross_attentions=all_cross_attentions, ) @@ -1038,6 +1054,7 @@ def forward( past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, @@ -1227,6 +1244,7 @@ def forward( past_key_values=decoder_outputs.past_key_values, decoder_hidden_states=decoder_outputs.hidden_states, decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 597be84ede72..53f45a26f388 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -253,9 +253,7 @@ def test_attention_outputs(self): out_len = len(outputs) if self.is_encoder_decoder: - correct_outlen = ( - self.model_tester.base_model_out_len if hasattr(self.model_tester, "base_model_out_len") else 4 - ) + correct_outlen = 5 # loss is at first position if "labels" in inputs_dict: @@ -266,6 +264,7 @@ def test_attention_outputs(self): self.assertEqual(out_len, correct_outlen) + # decoder attentions decoder_attentions = outputs.decoder_attentions self.assertIsInstance(decoder_attentions, (list, tuple)) self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) @@ -274,6 +273,19 @@ def test_attention_outputs(self): [self.model_tester.num_attention_heads, decoder_seq_length, 
decoder_key_length], ) + # cross attentions + cross_attentions = outputs.cross_attentions + self.assertIsInstance(cross_attentions, (list, tuple)) + self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [ + self.model_tester.num_attention_heads, + decoder_seq_length, + encoder_key_length, + ], + ) + # Check attention is always last and order is fine inputs_dict["output_attentions"] = True inputs_dict["output_hidden_states"] = True diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index 118e40761a71..0dbab4634d04 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -292,6 +292,62 @@ def check_encoder_decoder_model_labels( outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) ) + def check_encoder_decoder_model_output_attentions( + self, + config, + input_ids, + attention_mask, + encoder_hidden_states, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + labels, + **kwargs + ): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.to(torch_device) + outputs_encoder_decoder = enc_dec_model( + input_ids=input_ids, + decoder_input_ids=decoder_input_ids, + attention_mask=attention_mask, + decoder_attention_mask=decoder_attention_mask, + output_attentions=True, + return_dict=True, + ) + + encoder_attentions = outputs_encoder_decoder["encoder_attentions"] + self.assertEqual(len(encoder_attentions), config.num_hidden_layers) + + self.assertListEqual( + list(encoder_attentions[0].shape[-3:]), + [config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]], + ) + + decoder_attentions = outputs_encoder_decoder["decoder_attentions"] + num_decoder_layers = ( + decoder_config.num_decoder_layers + if hasattr(decoder_config, "num_decoder_layers") + else decoder_config.num_hidden_layers + ) + self.assertEqual(len(decoder_attentions), num_decoder_layers) + + self.assertListEqual( + list(decoder_attentions[0].shape[-3:]), + [decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]], + ) + + cross_attentions = outputs_encoder_decoder["cross_attentions"] + self.assertEqual(len(cross_attentions), num_decoder_layers) + + cross_attention_input_seq_len = decoder_input_ids.shape[-1] * ( + 1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0) + ) + self.assertListEqual( + list(cross_attentions[0].shape[-3:]), + [decoder_config.num_attention_heads, cross_attention_input_seq_len, input_ids.shape[-1]], + ) + def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs): encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) @@ -413,6 +469,10 @@ def test_encoder_decoder_model_labels(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_labels(**input_ids_dict) + def test_encoder_decoder_model_output_attentions(self): + input_ids_dict = self.prepare_config_and_inputs() + self.check_encoder_decoder_model_output_attentions(**input_ids_dict) + def test_encoder_decoder_model_generate(self): input_ids_dict = self.prepare_config_and_inputs() self.check_encoder_decoder_model_generate(**input_ids_dict) diff --git
diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py
index f9a1cce4a817..0a967ce287a7 100644
--- a/tests/test_modeling_prophetnet.py
+++ b/tests/test_modeling_prophetnet.py
@@ -916,6 +916,116 @@ def test_fp16_forward(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
 
+    # methods overwrite method in `test_modeling_common.py`
+    def test_attention_outputs(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        config.return_dict = True
+
+        seq_len = getattr(self.model_tester, "seq_length", None)
+        decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
+        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
+        decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
+        encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+        chunk_length = getattr(self.model_tester, "chunk_length", None)
+        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
+            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
+
+        for model_class in self.all_model_classes:
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = False
+            config.return_dict = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            # check that output_attentions also work using config
+            del inputs_dict["output_attentions"]
+            config.output_attentions = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
+
+            if chunk_length is not None:
+                self.assertListEqual(
+                    list(attentions[0].shape[-4:]),
+                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
+                )
+            else:
+                self.assertListEqual(
+                    list(attentions[0].shape[-3:]),
+                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                )
+            out_len = len(outputs)
+
+            correct_outlen = 7
+
+            # loss is at first position
+            if "labels" in inputs_dict:
+                correct_outlen += 1  # loss is added to beginning
+
+            self.assertEqual(out_len, correct_outlen)
+
+            # decoder attentions
+            decoder_attentions = outputs.decoder_attentions
+            self.assertIsInstance(decoder_attentions, (list, tuple))
+            self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(decoder_attentions[0].shape[-3:]),
+                [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
+            )
+
+            # cross attentions
+            cross_attentions = outputs.cross_attentions
+            self.assertIsInstance(cross_attentions, (list, tuple))
+            self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
+            self.assertListEqual(
+                list(cross_attentions[0].shape[-3:]),
+                [
+                    self.model_tester.num_attention_heads,
+                    (self.model_tester.ngram + 1) * decoder_seq_length,
+                    encoder_key_length,
+                ],
+            )
+
+            # Check attention is always last and order is fine
+            inputs_dict["output_attentions"] = True
+            inputs_dict["output_hidden_states"] = True
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+            with torch.no_grad():
+                outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+            if hasattr(self.model_tester, "num_hidden_states_types"):
+                added_hidden_states = self.model_tester.num_hidden_states_types
+            elif self.is_encoder_decoder:
+                added_hidden_states = 2
+            else:
+                added_hidden_states = 1
+            self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+
+            self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
+            if chunk_length is not None:
+                self.assertListEqual(
+                    list(self_attentions[0].shape[-4:]),
+                    [self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
+                )
+            else:
+                self.assertListEqual(
+                    list(self_attentions[0].shape[-3:]),
+                    [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+                )
+
 
 @require_torch
 class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
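ProphetNet overrides `test_attention_outputs` because its decoder runs `ngram + 1` prediction streams, which scales the query dimension of the cross-attention weights. A back-of-the-envelope sketch of the expected shapes, mirroring the assertions above (all numbers are hypothetical tester values)::

    # Sketch: expected attention shapes for a ProphetNet-style decoder with n-gram streams.
    num_attention_heads = 4
    decoder_seq_length = 7
    encoder_key_length = 7
    ngram = 2  # the decoder carries 1 main stream plus `ngram` predicting streams

    # Decoder self-attention is checked per stream, so no scaling here.
    decoder_attn_shape = (num_attention_heads, decoder_seq_length, decoder_seq_length)

    # Every stream attends to the encoder states, hence the (ngram + 1) factor on the query axis.
    cross_attn_shape = (num_attention_heads, (ngram + 1) * decoder_seq_length, encoder_key_length)

    print(decoder_attn_shape)  # (4, 7, 7)
    print(cross_attn_shape)    # (4, 21, 7)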