### Feature request
Currently, T5 cannot take 4D attention masks of shape `(batch_size, num_heads, seq_len, seq_len)` as input. Passing a 4D `attention_mask` and a 4D `decoder_attention_mask` as below raises a shape-related exception:
```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

input_ids = tokenizer("Where is", return_tensors="pt").input_ids
decoder_input_ids = tokenizer("<pad>", return_tensors="pt").input_ids

batch_size, seq_len = input_ids.shape
tgt_len = decoder_input_ids.shape[1]
num_heads = model.config.num_heads

# 4D encoder self-attention mask and 4D causal decoder self-attention mask
attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len)
decoder_attention_mask = torch.ones(batch_size, num_heads, tgt_len, tgt_len).tril(0)

model(
    input_ids,
    decoder_input_ids=decoder_input_ids,
    attention_mask=attention_mask,
    decoder_attention_mask=decoder_attention_mask,
)
```

One of the problems in the current code is the handling of the cross-attention mask. Today it is built from the 1D encoder attention mask when one is supplied, but with a 4D encoder mask it is unclear how the cross-attention mask should be derived from it. The best solution might therefore be to introduce a new 4D mask argument, `cross_attention_mask`, of shape `(batch_size, num_heads, tgt_len, seq_len)`. This lets the user control all attention masks explicitly when necessary.
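For concreteness, and continuing from the snippet above, the call under this proposal would look roughly like the following. Note that `cross_attention_mask` does not exist in transformers today; the argument name and its `(batch_size, num_heads, tgt_len, seq_len)` shape are exactly the proposal, so this only runs against an implementation that adds it.

```python
# Hypothetical call once the proposed cross_attention_mask argument exists.
# Shape convention assumed here: (batch_size, num_heads, tgt_len, seq_len),
# i.e. decoder positions attend over encoder positions, 1 = attend, 0 = masked.
cross_attention_mask = torch.ones(batch_size, num_heads, tgt_len, seq_len)

outputs = model(
    input_ids,
    decoder_input_ids=decoder_input_ids,
    attention_mask=attention_mask,                  # 4D encoder self-attention mask
    decoder_attention_mask=decoder_attention_mask,  # 4D causal decoder self-attention mask
    cross_attention_mask=cross_attention_mask,      # proposed: 4D cross-attention mask
)
```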
### Motivation
4D masks are useful for many purposes, as outlined in #27539 and this blog post, but not all models support them.
### Your contribution
I propose to fix the code so that it handles 4D attention masks, and to add a new `cross_attention_mask` argument so that the cross-attention mask can also be controlled manually. I wrote a version of this code in this fork.
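To make the intended change a bit more concrete, the mask handling I have in mind is roughly the following standalone sketch. It is an illustration of the idea, not the exact code in the fork, and the helper name `to_additive_bias` is made up for this example: a mask that is already 4D is used as-is, a 2D mask is broadcast as today, and in both cases the 1/0 values are converted into an additive bias.

```python
import torch

def to_additive_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Illustrative helper (not transformers API): convert a 2D (batch, kv_len)
    or 4D (batch, num_heads, q_len, kv_len) mask with 1 = attend / 0 = masked
    into an additive bias that can be added to the attention scores."""
    if mask.dim() == 4:
        # Already head- and query-resolved: use it directly.
        expanded = mask
    elif mask.dim() == 2:
        # Usual padding mask: broadcast over heads and query positions.
        expanded = mask[:, None, None, :]
    else:
        raise ValueError(f"Unsupported attention mask rank: {mask.dim()}")
    return (1.0 - expanded.to(dtype)) * torch.finfo(dtype).min
```

With this kind of helper, the encoder and decoder self-attention paths could accept either rank, and the cross-attention path could consume the new `cross_attention_mask` in the same way.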
I'm happy to create a PR with my code, but:
- This is my first transformers contribution, so I need help with a few things, such as handling the "# Copied from" code duplication mechanism of transformers. Should other similar models that copy functions from T5 be changed as well?
- Although I wrote a first test with trivial masks (see the sketch after this list), I am not entirely sure how this should be tested.
- I want to be sure that adding the new `cross_attention_mask` parameter is the right approach and will be approved.
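Regarding testing, below is a sketch of the equivalence check I tried with trivial masks: a 4D mask that simply allows everything (plus an explicit 4D causal decoder mask) should reproduce the logits obtained with the usual 1D mask. It assumes 4D support along the lines of the fork above and the same 1 = attend / 0 = masked convention as the 1D mask; the exact assertion style would follow whatever the test suite prefers.

```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").eval()

enc = tokenizer("Where is", return_tensors="pt")
decoder_input_ids = tokenizer("<pad>", return_tensors="pt").input_ids

batch_size, seq_len = enc.input_ids.shape
tgt_len = decoder_input_ids.shape[1]
num_heads = model.config.num_heads

with torch.no_grad():
    # Reference: plain 1D padding mask, causal decoder mask built internally.
    ref = model(enc.input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=enc.attention_mask)
    # Trivial 4D masks that encode exactly the same attention pattern.
    out = model(enc.input_ids,
                decoder_input_ids=decoder_input_ids,
                attention_mask=torch.ones(batch_size, num_heads, seq_len, seq_len),
                decoder_attention_mask=torch.ones(batch_size, num_heads, tgt_len, tgt_len).tril(0))

torch.testing.assert_close(ref.logits, out.logits)
```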