Closed
60 commits
96f0100
Add new dynamic cache
Cyrilvallez Sep 19, 2024
568f807
Add cache by default in generate for models supporting it
Cyrilvallez Sep 19, 2024
be62f53
Add to __init__ and correct typo
Cyrilvallez Sep 19, 2024
52b920d
Correct output if prefill larger than sliding window + compatibility
Cyrilvallez Sep 19, 2024
0c0836e
Add legacy format handling
Cyrilvallez Sep 19, 2024
b087839
style
Cyrilvallez Sep 20, 2024
91a1fee
add docs
Cyrilvallez Sep 20, 2024
a662de2
fix import
Cyrilvallez Sep 20, 2024
ce203ea
Update dummy_pt_objects.py
Cyrilvallez Sep 20, 2024
1719984
Update test
Cyrilvallez Sep 20, 2024
05d1053
style
Cyrilvallez Sep 20, 2024
cb2de20
update cache conversion in test
Cyrilvallez Sep 20, 2024
7dfc86d
style
Cyrilvallez Sep 23, 2024
cc13c4c
Allow the cache to support new states of more than 1 token, even afte…
Cyrilvallez Sep 24, 2024
1887bda
Update cache_utils.py
Cyrilvallez Sep 24, 2024
ad86620
maybe change test
Cyrilvallez Sep 24, 2024
a6f0d8d
revert tests diffs
Cyrilvallez Oct 2, 2024
5394c77
define get_seen_tokens
Cyrilvallez Oct 2, 2024
25b8f80
Modify all current .get_seq_length names
Cyrilvallez Oct 2, 2024
daae19a
style
Cyrilvallez Oct 2, 2024
6b7cb5a
trigger CIs
Cyrilvallez Oct 2, 2024
9cc7077
Add tests
Cyrilvallez Oct 2, 2024
e217859
Update test_utils.py
Cyrilvallez Oct 2, 2024
cd7002e
Update test_utils.py
Cyrilvallez Oct 2, 2024
86bbd86
Update test_utils.py
Cyrilvallez Oct 2, 2024
851c8c8
Update causal mask generation in case of DynamicSlidingCache (only Mi…
Cyrilvallez Oct 3, 2024
a95ca15
Improve tests
Cyrilvallez Oct 3, 2024
94861ea
improve cache
Cyrilvallez Oct 8, 2024
82aecd8
add exceptions
Cyrilvallez Oct 8, 2024
8f48003
Update utils.py
Cyrilvallez Oct 8, 2024
c0c07f9
Update test_utils.py
Cyrilvallez Oct 8, 2024
c3fe380
Update test_utils.py
Cyrilvallez Oct 8, 2024
12a7576
Update test_utils.py
Cyrilvallez Oct 8, 2024
22abbb6
Update test_utils.py
Cyrilvallez Oct 8, 2024
d82001e
Update test_utils.py
Cyrilvallez Oct 9, 2024
005a187
Update 4d mask creation in Mistral
Cyrilvallez Oct 10, 2024
d240402
fix missed conflict
Cyrilvallez Oct 10, 2024
a7ae24d
Apply to other models
Cyrilvallez Oct 10, 2024
e0de263
Add required arg in prepare_inoput
Cyrilvallez Oct 10, 2024
8bc872a
Update test_utils.py
Cyrilvallez Oct 10, 2024
5ae270b
Update test_utils.py
Cyrilvallez Oct 10, 2024
5e248fa
Fix kv_seq_length and rotary_seq_length
Cyrilvallez Oct 10, 2024
4c04f89
up
Cyrilvallez Oct 10, 2024
cba5ae4
up
Cyrilvallez Oct 10, 2024
d200e77
up
Cyrilvallez Oct 10, 2024
2e40fd3
up
Cyrilvallez Oct 10, 2024
50731b5
CIs
Cyrilvallez Oct 11, 2024
bbd6069
improve sdpa is_causal escape
Cyrilvallez Oct 11, 2024
6ae7ec0
make fix-copies
Cyrilvallez Oct 23, 2024
efc1131
add check for models with sliding window
Cyrilvallez Oct 23, 2024
6478126
Update modeling_git.py
Cyrilvallez Oct 24, 2024
2055eda
style
Cyrilvallez Oct 24, 2024
35dd895
Update modeling_mimi.py
Cyrilvallez Oct 24, 2024
fb231cf
Update utils.py
Cyrilvallez Oct 24, 2024
8affc79
replace get_seq_length
Cyrilvallez Oct 24, 2024
bc5b036
Update test_utils.py
Cyrilvallez Oct 24, 2024
09bab35
CIs
Cyrilvallez Oct 24, 2024
6c80731
CIs
Cyrilvallez Oct 24, 2024
5efa057
Update modeling_longt5.py
Cyrilvallez Nov 1, 2024
0f83f21
Update skip test for moshi
Cyrilvallez Nov 1, 2024
21 changes: 14 additions & 7 deletions docs/source/en/internal/generation_utils.md
@@ -365,22 +365,29 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] DynamicCache
- update
- get_seq_length
- get_past_seen_tokens
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] DynamicSlidingWindowCache
- update
- get_past_seen_tokens
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] QuantizedCache
- update
- get_seq_length
- get_past_seen_tokens

[[autodoc]] QuantoQuantizedCache

[[autodoc]] HQQQuantizedCache

[[autodoc]] SinkCache
- update
- get_seq_length
- get_past_seen_tokens
- reorder_cache

[[autodoc]] OffloadedCache
@@ -390,25 +397,25 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] StaticCache
- update
- get_seq_length
- get_past_seen_tokens
- reset

[[autodoc]] OffloadedStaticCache
- update
- get_seq_length
- get_past_seen_tokens
- reset

[[autodoc]] HybridCache
- update
- get_seq_length
- get_past_seen_tokens
- reset

[[autodoc]] SlidingWindowCache
- update
- reset

[[autodoc]] EncoderDecoderCache
- get_seq_length
- get_past_seen_tokens
- to_legacy_cache
- from_legacy_cache
- reset
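A hedged usage sketch of the rename documented above (assumes this PR's `get_past_seen_tokens`; checkpoint and prompt follow the new class docstring):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

inputs = tokenizer("My name is Mistral", return_tensors="pt")
cache = DynamicCache()
model(**inputs, past_key_values=cache, use_cache=True)

# For DynamicCache the two accessors agree; only sliding-window caches may report
# more seen tokens than they physically store.
assert cache.get_past_seen_tokens() == cache.get_seq_length()
```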
4 changes: 2 additions & 2 deletions examples/modular-transformers/modeling_dummy.py
@@ -854,7 +854,7 @@ def forward(
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
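Why `get_past_seen_tokens` seeds `cache_position` here: with a sliding-window cache, fewer entries may be stored than tokens were processed, and positions must continue from the processed count. A minimal sketch with illustrative numbers:

```python
import torch

sliding_window = 4096
past_seen_tokens = 5000                                   # tokens already processed
cached_len = min(past_seen_tokens, sliding_window - 1)    # entries actually kept: 4095

new_tokens = 1                                            # one decoding step
# Correct: positions continue from the number of *seen* tokens...
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)  # tensor([5000])
# ...whereas seeding from the cached length would reuse an old position.
wrong_position = torch.arange(cached_len, cached_len + new_tokens)              # tensor([4095])
```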
@@ -946,7 +946,7 @@ def _update_causal_mask(
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
4 changes: 2 additions & 2 deletions examples/modular-transformers/modeling_my_new_model2.py
@@ -726,7 +726,7 @@ def forward(
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@@ -821,7 +821,7 @@ def _update_causal_mask(
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
2 changes: 1 addition & 1 deletion examples/modular-transformers/modeling_super.py
@@ -850,7 +850,7 @@ def _update_causal_mask(
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1275,6 +1275,7 @@
"Cache",
"CacheConfig",
"DynamicCache",
"DynamicSlidingWindowCache",
"EncoderDecoderCache",
"HQQQuantizedCache",
"HybridCache",
@@ -6194,6 +6195,7 @@
Cache,
CacheConfig,
DynamicCache,
DynamicSlidingWindowCache,
EncoderDecoderCache,
HQQQuantizedCache,
HybridCache,
133 changes: 133 additions & 0 deletions src/transformers/cache_utils.py
@@ -64,6 +64,12 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# TODO: deprecate this function in favor of `cache_position`
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the number of already processed tokens. For all Cache classes except SlidingWindow caches, this is the same as
`get_seq_length()`. However, with sliding window we can process more tokens than the cache size. A layer index can be optionally passed.
"""
return self.get_seq_length(layer_idx)

# Deprecate in favor of max-cache-shape because we want to be specific about what we mean by "max_length"
# Previously, some cache objects didn't have a "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles
# an infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so
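A minimal sketch of the contract this default establishes, using a hypothetical toy class (not part of the PR): the base implementation delegates to `get_seq_length`, and only caches that evict tokens need to override it.

```python
from typing import Optional

class ToyWindowCache:
    """Toy stand-in mirroring the contract: tracks seen tokens separately from stored ones."""

    def __init__(self, sliding_window: int):
        self.sliding_window = sliding_window
        self._seen_tokens = 0
        self._stored = 0

    def update(self, num_new_tokens: int):
        self._seen_tokens += num_new_tokens
        # Never keep more than sliding_window - 1 entries, mirroring DynamicSlidingWindowCache.
        self._stored = min(self._seen_tokens, self.sliding_window - 1)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        return self._stored          # what is physically cached

    def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int:
        return self._seen_tokens     # what has actually been processed

cache = ToyWindowCache(sliding_window=8)
cache.update(20)
assert cache.get_seq_length() == 7
assert cache.get_past_seen_tokens() == 20
```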
@@ -545,6 +551,133 @@ def batch_select_indices(self, indices: torch.Tensor):
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]


# TODO: (cyril) Make this the default for models with sliding window once `generate` no longer returns Cache as tuples
class DynamicSlidingWindowCache(DynamicCache):
"""
A cache that grows dynamically as more tokens are generated, but stops growing once the sequence length exceeds the sliding window.
This will be the default for generative models with sliding-window attention (except for assisted decoding, where `DynamicCache` is used).

It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`, capped at `[batch_size, num_heads, sliding_window-1, head_dim]` once `seq_len >= sliding_window - 1`.

Note: Since we keep at most `sliding_window-1` tokens in the cache, once this limit is reached the cache can no
longer be rolled back to previous states without losing information. For this reason, it should not be used with assisted decoding
(or contrastive search when using `low_memory=True`).

Example:

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicSlidingWindowCache

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

>>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt")

>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = DynamicSlidingWindowCache(model.config.sliding_window)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
DynamicSlidingWindowCache()
```
"""

def __init__(self, sliding_window: int) -> None:
super().__init__()
self.sliding_window = sliding_window
# We overwrite the field and maintain a list of size `num_hidden_layers` to accurately reflect the seen tokens at each layer during `update`
self._seen_tokens = []

def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int:
"""This needs to be overriden because the number of processed tokens may be larger than the cache length."""
if len(self._seen_tokens) <= layer_idx:
return 0
else:
return self._seen_tokens[layer_idx]

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`, discarding previous
tokens according to the sliding window if needed.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicSlidingWindowCache`.

Return:
A tuple containing the updated key and value states.
"""
if len(self.key_cache) <= layer_idx:
# Update the number of seen tokens
self._seen_tokens.append(key_states.shape[-2])
# Add only up to sliding window size if larger
self.key_cache.append(key_states[..., -self.sliding_window + 1 :, :])
self.value_cache.append(value_states[..., -self.sliding_window + 1 :, :])
# We should return full states during prefill even though we only save up to sliding window-1
return key_states, value_states
else:
self._seen_tokens[layer_idx] += key_states.shape[-2]
# We may need to return longer states (e.g. to continue generation with previous cache, with added tokens), but we only keep
# the last `sliding_window-1` states in the cache for next forward
full_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
full_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window + 1 :, :]
self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window + 1 :, :]
return full_key_states, full_value_states

def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
out = []
for i in range(0, full_batch_size, split_size):
current_split = DynamicSlidingWindowCache(self.sliding_window)
current_split._seen_tokens = self._seen_tokens.copy()  # copy so the splits do not share (and mutate) the same counter list
current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
out.append(current_split)
return out

@classmethod
def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
cache = cls(splits[0].sliding_window)
for idx in range(len(splits[0])):
key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
if key_cache != []:
layer_keys = torch.cat(key_cache, dim=0)
layer_values = torch.cat(value_cache, dim=0)
cache.update(layer_keys, layer_values, idx)

# We need this because _seen_tokens may be bigger than what will be automatically set with `update` (if cache > sliding_window)
cache._seen_tokens = splits[0]._seen_tokens
return cache

def crop(self, max_length: int):
if self.get_past_seen_tokens() >= self.sliding_window - 1:
raise RuntimeError(
"The current DynamicSlidingWindowCache is full. It cannot be cropped as this would mean losing past states."
)
else:
super().crop(max_length)

from_legacy_cache = None
to_legacy_cache = None


class OffloadedCache(DynamicCache):
"""
A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
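A hedged sketch of the `update` shape behaviour of the new `DynamicSlidingWindowCache` defined above (assumes this PR's constructor and `update` signature; tensors are random placeholders):

```python
import torch
from transformers import DynamicSlidingWindowCache  # added by this PR

sliding_window = 8
cache = DynamicSlidingWindowCache(sliding_window)

# Prefill with more tokens than the window: full states are returned to the caller,
# but only the last sliding_window - 1 positions are kept in the cache.
k = torch.randn(1, 4, 20, 64)   # [batch_size, num_heads, seq_len, head_dim]
v = torch.randn(1, 4, 20, 64)
k_out, v_out = cache.update(k, v, layer_idx=0)
assert k_out.shape[-2] == 20                 # attention still sees the whole prefill
assert cache.key_cache[0].shape[-2] == 7     # sliding_window - 1 entries kept
assert cache.get_past_seen_tokens() == 20

# One decoding step: the returned states are cached + new, then the cache is trimmed again.
k_new, v_new = torch.randn(1, 4, 1, 64), torch.randn(1, 4, 1, 64)
k_out, v_out = cache.update(k_new, v_new, layer_idx=0)
assert k_out.shape[-2] == 8                  # 7 cached + 1 new
assert cache.key_cache[0].shape[-2] == 7
assert cache.get_past_seen_tokens() == 21
```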
22 changes: 15 additions & 7 deletions src/transformers/generation/utils.py
@@ -28,6 +28,7 @@
from ..cache_utils import (
Cache,
DynamicCache,
DynamicSlidingWindowCache,
EncoderDecoderCache,
OffloadedCache,
QuantizedCacheConfig,
@@ -455,6 +456,7 @@ def prepare_inputs_for_generation(
attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_cache_shape(),
initial_mask_position=0,
dtype=self.dtype,
device=device,
cache_position=cache_position,
@@ -1557,8 +1559,8 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
past_length = 0
if not isinstance(cache, Cache):
past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
past_length = cache.get_seq_length()
elif hasattr(cache, "get_past_seen_tokens") and cache.get_past_seen_tokens() is not None:
past_length = cache.get_past_seen_tokens()

# TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
# end-to-end compilation will yield bad results because `cache_position` will be incorrect.
@@ -2158,6 +2160,8 @@ def generate(
raise ValueError(
f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
)
if isinstance(model_kwargs.get("past_key_values", None), DynamicSlidingWindowCache):
raise ValueError("DynamicSlidingWindowCache cannot be used in assisted generation.")

# 11. Get the candidate generator, given the parameterization
candidate_generator = self._get_candidate_generator(
@@ -2207,6 +2211,12 @@
raise ValueError(
f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
)
if isinstance(model_kwargs.get("past_key_values", None), DynamicSlidingWindowCache) and getattr(
generation_config, "low_memory", False
):
raise ValueError(
"DynamicSlidingWindowCache cannot be used in contrastive generation with `low_memory=True`."
)

result = self._contrastive_search(
input_ids,
@@ -2817,7 +2827,7 @@ def _contrastive_search(
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past_key_values") is None or (
isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
and model_kwargs["past_key_values"].get_seq_length() == 0
and model_kwargs["past_key_values"].get_past_seen_tokens() == 0
):
# prepare inputs
model_kwargs["use_cache"] = True
@@ -4655,10 +4665,8 @@ def _concat(data):
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
# New cache format
elif isinstance(data[0], DynamicCache):
return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], EncoderDecoderCache):
return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], (DynamicCache, EncoderDecoderCache)):
return data[0].__class__.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):
6 changes: 3 additions & 3 deletions src/transformers/modeling_attn_mask_utils.py
@@ -275,13 +275,13 @@ def _ignore_causal_mask_sdpa(
if (
(is_training or not is_tracing)
and (query_length == 1 or key_value_length == query_length)
and (sliding_window is None or key_value_length < sliding_window)
and (sliding_window is None or key_value_length <= sliding_window)
):
ignore_causal_mask = True
elif sliding_window is None or key_value_length < sliding_window:
elif sliding_window is None or key_value_length <= sliding_window:
if len(attention_mask.shape) == 4:
return False
elif not is_tracing and torch.all(attention_mask == 1):
elif not is_tracing and torch.all(attention_mask[:, -key_value_length:] == 1):
if query_length == 1 or key_value_length == query_length:
# For query_length == 1, causal attention and bi-directional attention are the same.
ignore_causal_mask = True
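A small sketch of the boundary relaxed above, with illustrative numbers: when exactly `sliding_window` keys are visible nothing has been evicted yet, so the causal mask can still be skipped; and with a sliding-window cache the 2D padding mask can be longer than the keys actually kept, hence checking only its last `key_value_length` columns.

```python
import torch

sliding_window = 4096
seen_tokens = 5000                                         # tokens processed so far
query_length = 1                                           # one decoding step
key_value_length = (sliding_window - 1) + query_length     # 4095 cached keys + the new token = 4096

# Previously `key_value_length < sliding_window` excluded this boundary; `<=` now allows it,
# since no key the query may attend to has been evicted and causal masking alone suffices.
can_skip = key_value_length <= sliding_window and (query_length == 1 or key_value_length == query_length)
assert can_skip

# The padding mask still covers all 5000 seen tokens, but only its last
# key_value_length columns correspond to keys that are actually kept.
attention_mask = torch.ones(1, seen_tokens, dtype=torch.long)
assert torch.all(attention_mask[:, -key_value_length:] == 1)
```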
4 changes: 2 additions & 2 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -646,7 +646,7 @@ def forward(
)

batch_size, seq_length, _ = inputs_embeds.shape
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
past_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
seq_length_with_past = seq_length + past_length
if cache_position is None:
cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
@@ -747,7 +747,7 @@ def _update_causal_mask(
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward