Allow static cache to be larger than sequence length / batch size for encoder-decoder models #35444

@cptspacemanspiff


Feature request

In encoder-decoder models, when generation uses an EncoderDecoderCache built from StaticCache objects:

  1. The cross-attention cache size must equal the encoder sequence length.
  2. The batch size of both the self-attention and cross-attention caches must equal the batch size used for generation.

Motivation

I have been working on ExecuTorch export for encoder-decoder models. As part of that, I have been digging into the implementation of EncoderDecoderCache and StaticCache.

My expectation for static caches is that once the cache is initialized, generation should work as long as the actual batch size, encoder sequence length, and decoder sequence length are each no larger than the corresponding cache dimensions.

Currently, however:

  1. The cross-attention cache must be exactly the same size as the encoder sequence length (a rough illustration of the resulting mismatch follows this list).
  2. The batch size the cache is initialized with must be exactly the batch size the cache is run with.
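To make the first point concrete, here is a plain-PyTorch sketch of the kind of shape mismatch that turns up in cross-attention when the static cache is larger than the encoder output. This is a toy reconstruction, not the actual T5 attention code, and all of the shapes are made up:

import torch

batch, heads, head_dim = 2, 8, 64
enc_len, max_cache_len = 24, 170                      # padded encoder length vs. cache size

q = torch.randn(batch, heads, 1, head_dim)            # query for a single decoding step
k_cache = torch.zeros(batch, heads, max_cache_len, head_dim)
k_new = torch.randn(batch, heads, enc_len, head_dim)  # keys projected from the encoder output
k_cache[:, :, :enc_len] = k_new                       # roughly what a static cache update does

scores = q @ k_cache.transpose(-1, -2)                # (batch, heads, 1, max_cache_len)
mask = torch.zeros(batch, 1, 1, enc_len)              # additive mask built only for enc_len
scores = scores + mask                                # RuntimeError: sizes 170 and 24 do not match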

Your contribution

As I was digging through this, I updated the T5 attention and StaticCache implementations in an attempt to handle both of these cases:

#35445
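To give a rough idea of the direction (heavily simplified, and not a faithful excerpt of the PR): one way to cope with a cross-attention cache that is larger than the encoder output is to extend the additive attention mask so it covers the unused tail of the cache, rather than assuming the cache length equals the encoder length. The helper below is only an illustrative sketch; its name and signature are made up.

import torch

def pad_cross_attention_mask(mask: torch.Tensor, max_cache_len: int) -> torch.Tensor:
    """Extend a (batch, 1, 1, enc_len) additive mask to max_cache_len positions,
    masking out the unused tail of the static cross-attention cache."""
    pad_len = max_cache_len - mask.shape[-1]
    if pad_len <= 0:
        return mask
    tail = torch.full(
        (*mask.shape[:-1], pad_len),
        torch.finfo(mask.dtype).min,
        dtype=mask.dtype,
        device=mask.device,
    )
    return torch.cat([mask, tail], dim=-1)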

That being said, I am just starting to learn transformers (both the HF library and transformer models in general), and have no real idea what I am doing.

Here is the code I have been using to reproduce the issue:

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
from transformers.cache_utils import (
    StaticCache,
    EncoderDecoderCache,
)

model_name = "google-t5/t5-small"

dtype = torch.float16

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
)


# Cross-attention cache: 170 slots, deliberately larger than the padded encoder input.
encoder_cache = StaticCache(
    model.config, max_cache_len=170, max_batch_size=4, dtype=dtype
)
# Decoder self-attention cache: 200 slots.
decoder_cache = StaticCache(
    model.config, max_cache_len=200, max_batch_size=4, dtype=dtype
)
cache = EncoderDecoderCache(decoder_cache, encoder_cache)

# Only two inputs, so the runtime batch size (2) is smaller than max_batch_size (4).
strings_1 = [
    "When the night has come and the land is dark, and the moon is the only light we will see.",
    "Abba is the best",
    # "No lindy is the best",
    # "No Elton john is the absolute best.",
]
inputs = tokenizer(strings_1, return_tensors="pt", padding=True)
tokens = model.generate(**inputs, past_key_values=cache)
text_translated = [tokenizer.decode(t, skip_special_tokens=False) for t in tokens]
print(text_translated)
