33 changes: 29 additions & 4 deletions docs/source/en/kv_cache.md
@@ -120,7 +120,7 @@ To enable quantization of the key-value cache, one needs to indicate `cache_impl
Quantization-related arguments should be passed to the `generation_config` either as a `dict` or as an instance of the [`~QuantizedCacheConfig`] class.
One has to indicate which quantization backend to use in the [`~QuantizedCacheConfig`]; the default is `quanto`.

It is recommended to set `axis-key/axis-value` parameters in the cache config to `0` if you're using the `quanto` backend and to `1` if you're using the `HQQ` backend. For other config values, please use the defaults unless you're running out of memory. In that case, you may consider decreasing the residual length.

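For example, here is a minimal sketch of passing these arguments as a `dict` when calling `generate()`. It assumes the `quanto` backend with 4-bit keys/values; the checkpoint and the `nbits` value are only illustrative, and the dict keys follow the [`~QuantizedCacheConfig`] parameter names (`axis_key`/`axis_value`):

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

>>> # With the `quanto` backend, `axis_key`/`axis_value` are left at 0, as recommended above
>>> out = model.generate(**inputs, max_new_tokens=20, cache_implementation="quantized",
...                      cache_config={"backend": "quanto", "nbits": 4, "axis_key": 0, "axis_value": 0})
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
```
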
<Tip warning={true}>

@@ -308,7 +308,7 @@ Unlike other cache classes, this one can't be used directly by indicating a `cac

### Encoder-Decoder Cache

The [`~EncoderDecoderCache`] is a wrapper designed to handle the caching needs of encoder-decoder models. This cache type is specifically built to manage both self-attention and cross-attention caches, ensuring storage and retrieval of past key/values required for these complex models. A nice feature of the Encoder-Decoder Cache is that you can set different cache types for the encoder and the decoder, depending on your use case. Currently, this cache is only supported in [Whisper](./model_doc/whisper) models, but we will be adding more models soon.

In terms of usage, there is nothing special to be done and calling `generate()` or `forward()` will handle everything for you.
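
As a rough illustration (not required in practice, since `generate()` builds this cache for you), here is a sketch of constructing an [`~EncoderDecoderCache`] by hand for a Whisper model. The dummy input features and the choice of `DynamicCache` for both sub-caches are assumptions made to keep the example self-contained:

```python
>>> import torch
>>> from transformers import WhisperForConditionalGeneration, DynamicCache, EncoderDecoderCache

>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

>>> # Dummy log-mel features (80 mel bins x 3000 frames for whisper-small), just to keep the sketch runnable
>>> input_features = torch.zeros(1, 80, 3000)

>>> # One cache for the decoder self-attention, one for the cross-attention over the encoder outputs
>>> past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
>>> generated_ids = model.generate(input_features, past_key_values=past_key_values, max_new_tokens=10)
```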

@@ -379,7 +379,7 @@ Sometimes you would want to first fill-in cache object with key/values for certa
>>> model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
>>> tokenizer = AutoTokenizer.from_pretrained(model_id)

>>> # Init StaticCache with big enough max-length (1024 tokens for the below example)
>>> # You can also init a DynamicCache, if that suits you better
>>> prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)

@@ -394,10 +394,35 @@ Sometimes you would want to first fill-in cache object with key/values for certa
>>> for prompt in prompts:
... new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
... past_key_values = copy.deepcopy(prompt_cache)
... outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=20)
... response = tokenizer.batch_decode(outputs)[0]
... responses.append(response)

>>> print(responses)
['<s> You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTitle: The Ultimate Guide to Travelling: Tips, Tricks, and', '<s> You are a helpful assistant. What is the capital of France?\n\nYes, the capital of France is Paris.</s>']
```


## Legacy cache format

Prior to the introduction of the `Cache` object, the cache of LLMs used to be a tuple of tuples of tensors. The legacy
format has a dynamic size, growing as we generate text -- very similar to `DynamicCache`. If your project depends on
this legacy format, you can seamlessly convert it to a `DynamicCache` and back.

```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
>>> inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

>>> # `return_dict_in_generate=True` is required to return the cache. `return_legacy_cache` forces the returned cache
>>> # to be of the legacy type
>>> generation_outputs = model.generate(**inputs, return_dict_in_generate=True, return_legacy_cache=True, max_new_tokens=5)

>>> # We can convert a legacy cache to a DynamicCache -- and the other way around. This is helpful if you have custom
>>> # logic to manipulate a cache in a specific format.
>>> cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
>>> legacy_format_cache = cache.to_legacy_cache()
```
26 changes: 15 additions & 11 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -687,14 +687,18 @@ def forward(
inputs_embeds = self.word_embeddings(input_ids)

# kept for BC (non `Cache` `past_key_values` inputs)
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"Using `past_key_values` as a tuple is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
Contributor Author: Note: `not self.training` was removed.

If we are training and we pass `past_key_values` as a tuple of tuples, we definitely want to see the warning -- the code will break in the near future.

return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)

batch_size, seq_length, _ = inputs_embeds.shape
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -765,9 +769,9 @@ def forward(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
Contributor Author (on lines +772 to +774): copy/paste from llama

(on some models, this pattern was slightly different)


if not return_dict:
return tuple(
22 changes: 13 additions & 9 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -526,14 +526,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)

use_legacy_cache = False
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)

seq_length = inputs_embeds.shape[1]
@@ -608,9 +612,9 @@ def forward(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(
19 changes: 11 additions & 8 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -910,16 +910,19 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
19 changes: 11 additions & 8 deletions src/transformers/models/dbrx/modeling_dbrx.py
@@ -1059,16 +1059,19 @@ def forward(

inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
24 changes: 14 additions & 10 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -1031,17 +1031,21 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)

# Compute alibi tensor: check build_alibi_tensor documentation
use_legacy_cache = False
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)

# Compute alibi tensor: check build_alibi_tensor documentation
alibi = None
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
batch_size, seq_length, _ = inputs_embeds.shape
@@ -1126,9 +1130,9 @@ def forward(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(
15 changes: 11 additions & 4 deletions src/transformers/models/gemma/diff_gemma.py
@@ -476,12 +476,19 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False # noqa: F841
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
28 changes: 13 additions & 15 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -828,12 +828,19 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

return_legacy_cache = False # noqa: F841
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -856,15 +863,6 @@ def forward(
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)

# decoder layers
all_hidden_states = () if output_hidden_states else None