
Support for 4D attention mask for T5 #40743

@Aethor


Feature request

Currently, T5 cannot take 4D attention masks of shape (batch_size, num_heads, seq_len, seq_len) as input. Passing a 4D attention_mask and a 4D decoder_attention_mask as below leads to a shape-related exception:

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

input_ids = tokenizer("Where is", return_tensors="pt").input_ids
decoder_input_ids = tokenizer("<pad>", return_tensors="pt").input_ids

batch_size, seq_len = input_ids.shape
tgt_len = decoder_input_ids.shape[1]
num_heads = model.config.num_heads

# 4D masks of shape (batch_size, num_heads, query_len, key_len)
attention_mask = torch.ones(batch_size, num_heads, seq_len, seq_len)
decoder_attention_mask = torch.ones(batch_size, num_heads, tgt_len, tgt_len).tril(0)  # causal mask for the decoder

model(
    input_ids,
    decoder_input_ids=decoder_input_ids,
    attention_mask=attention_mask,
    decoder_attention_mask=decoder_attention_mask,
)

One of the problems in the current code is the handling of the cross-attention mask. Currently, it is built from the encoder attention_mask when one is supplied. With a 4D mask, however, it is unclear how the encoder mask should be reused for cross-attention, so the cleanest solution might be to introduce a new 4D mask argument, cross_attention_mask, of shape (batch_size, num_heads, tgt_len, seq_len). This lets the user control every attention mask explicitly if necessary.
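For illustration, a call using the proposed argument could look as follows (cross_attention_mask is the argument proposed here, not something transformers currently accepts; the other tensors are the ones built in the snippet above):

cross_attention_mask = torch.ones(batch_size, num_heads, tgt_len, seq_len)

model(
    input_ids,
    decoder_input_ids=decoder_input_ids,
    attention_mask=attention_mask,                  # (batch_size, num_heads, seq_len, seq_len)
    decoder_attention_mask=decoder_attention_mask,  # (batch_size, num_heads, tgt_len, tgt_len)
    cross_attention_mask=cross_attention_mask,      # (batch_size, num_heads, tgt_len, seq_len)
)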

Motivation

4D masks are useful for many purposes, as outlined in #27539 and this blog post, but not all models support them.

Your contribution

I propose fixing the code to handle 4D attention masks, and adding a new cross_attention_mask argument so that the cross-attention mask can be controlled manually. I wrote a version of this code in this fork.

I'm happy to create a PR with my code, but:

  1. This is my first transformers contribution, so I need help with a few things, such as handling the "# Copied from" code duplication mechanism in transformers. Should other models that copy functions from T5 be changed as well?
  2. Although I wrote a first test with trivial masks, I am not entirely sure how best to test this (a rough sketch of one possible equivalence test is shown after this list).
  3. I want to be sure that adding the new cross_attention_mask parameter is the right approach and would be approved.
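For reference, here is a rough sketch of the kind of equivalence test I have in mind. It assumes the 4D-mask support from my fork, and that cross-attention is left unmasked when no cross_attention_mask is passed; under those assumptions, an all-ones 4D mask should produce the same logits as the standard 2D mask:

import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

def test_4d_attention_mask_matches_2d_mask():
    tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
    model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").eval()

    input_ids = tokenizer("Where is", return_tensors="pt").input_ids
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

    batch_size, seq_len = input_ids.shape
    num_heads = model.config.num_heads

    # Baseline: the usual 2D mask attending to every token
    mask_2d = torch.ones(batch_size, seq_len)
    # The same "attend to everything" mask expressed per head as a 4D tensor
    mask_4d = torch.ones(batch_size, num_heads, seq_len, seq_len)

    with torch.no_grad():
        out_2d = model(input_ids, decoder_input_ids=decoder_input_ids, attention_mask=mask_2d)
        out_4d = model(input_ids, decoder_input_ids=decoder_input_ids, attention_mask=mask_4d)

    torch.testing.assert_close(out_2d.logits, out_4d.logits)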
