diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index a81d202c6634..3004e89300fc 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -362,14 +362,21 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] DynamicCache - update - - get_seq_length + - get_past_seen_tokens + - reorder_cache + - to_legacy_cache + - from_legacy_cache + +[[autodoc]] DynamicSlidingWindowCache + - update + - get_past_seen_tokens - reorder_cache - to_legacy_cache - from_legacy_cache [[autodoc]] QuantizedCache - update - - get_seq_length + - get_past_seen_tokens [[autodoc]] QuantoQuantizedCache @@ -377,7 +384,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] SinkCache - update - - get_seq_length + - get_past_seen_tokens - reorder_cache [[autodoc]] OffloadedCache @@ -387,17 +394,17 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] StaticCache - update - - get_seq_length + - get_past_seen_tokens - reset [[autodoc]] OffloadedStaticCache - update - - get_seq_length + - get_past_seen_tokens - reset [[autodoc]] HybridCache - update - - get_seq_length + - get_past_seen_tokens - reset [[autodoc]] SlidingWindowCache @@ -405,7 +412,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens - reset [[autodoc]] EncoderDecoderCache - - get_seq_length + - get_past_seen_tokens - to_legacy_cache - from_legacy_cache - reset diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index b5b1fc6aec85..420fe6d6d2c1 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -906,7 +906,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -997,7 +997,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 49cdd2741620..8b20b43c20b3 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -784,7 +784,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -879,7 +879,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index d91bdb1820c2..71b14bb8051a 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -902,7 +902,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ab829c6894c0..af795e60ba47 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1268,6 +1268,7 @@ "Cache", "CacheConfig", "DynamicCache", + "DynamicSlidingWindowCache", "EncoderDecoderCache", "HQQQuantizedCache", "HybridCache", @@ -6156,6 +6157,7 @@ Cache, CacheConfig, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, HQQQuantizedCache, HybridCache, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 4e4a1ee26c12..ddc9db0d02d0 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -64,6 +64,12 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: # TODO: deprecate this function in favor of `cache_position` raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: + """Returns the number of already processed tokens. For all Cache classes except SlidingWindow caches, this is the same as + `get_seq_length()`. However, with sliding window we can process more tokens than the cache size. A layer index can be optionally passed. + """ + return self.get_seq_length(layer_idx) + # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length" # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so @@ -545,6 +551,133 @@ def batch_select_indices(self, indices: torch.Tensor): self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] +# TODO: (cyril) Make this the default for models with sliding window once `generate` no longer returns Cache as tuples +class DynamicSlidingWindowCache(DynamicCache): + """ + A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. + This will be the default for generative models with sliding window attention (except for assisted decoding where `DynamicCache` is used). + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window-1, head_dim]` if seq_len >= sliding_window-1. + + Note: Since we only keep maximum `sliding_window-1` tokens in the cache, once this value is reached the cache can no + longer be roll-backed to previous states without losing information. For this reason, it should not be used with assisted decoding + (or contrastive search when using `low_memory=True`). + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicSlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicSlidingWindowCache(model.config.sliding_window) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + DynamicSlidingWindowCache() + ``` + """ + + def __init__(self, sliding_window: int) -> None: + super().__init__() + self.sliding_window = sliding_window + # We overwrite the field and maintain a list of size `num_hidden_layers` to accurately reflect the seen tokens at each layer during `update` + self._seen_tokens = [] + + def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: + """This needs to be overriden because the number of processed tokens may be larger than the cache length.""" + if len(self._seen_tokens) <= layer_idx: + return 0 + else: + return self._seen_tokens[layer_idx] + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous + tokens according to the sliding window if needed. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + if len(self.key_cache) <= layer_idx: + # Update the number of seen tokens + self._seen_tokens.append(key_states.shape[-2]) + # Add only up to sliding window size if larger + self.key_cache.append(key_states[..., -self.sliding_window + 1 :, :]) + self.value_cache.append(value_states[..., -self.sliding_window + 1 :, :]) + # We should return full states during prefill even though we only save up to sliding window-1 + return key_states, value_states + else: + self._seen_tokens[layer_idx] += key_states.shape[-2] + # We may need to return longer states (e.g. to continue generation with previous cache, with added tokens), but we only keep + # the last `sliding_window-1` states in the cache for next forward + full_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + full_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window + 1 :, :] + self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window + 1 :, :] + return full_key_states, full_value_states + + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicSlidingWindowCache(self.sliding_window) + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + cache = cls(splits[0].sliding_window) + for idx in range(len(splits[0])): + key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] + value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] + if key_cache != []: + layer_keys = torch.cat(key_cache, dim=0) + layer_values = torch.cat(value_cache, dim=0) + cache.update(layer_keys, layer_values, idx) + + # We need this because _seen_tokens may be bigger than what will be automatically set with `update` (if cache > sliding_window) + cache._seen_tokens = splits[0]._seen_tokens + return cache + + def crop(self, max_length: int): + if self.get_past_seen_tokens() >= self.sliding_window - 1: + raise RuntimeError( + "The current DynamicSlidingWindowCache is full. It cannot be cropped as this would mean losing past states." + ) + else: + super().crop(max_length) + + from_legacy_cache = None + to_legacy_cache = None + + class OffloadedCache(DynamicCache): """ A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5da4878513eb..83a21e590df9 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -28,6 +28,7 @@ from ..cache_utils import ( Cache, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, OffloadedCache, QuantizedCacheConfig, @@ -1532,8 +1533,8 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): past_length = 0 if not isinstance(cache, Cache): past_length = cache[0][0].shape[2] - elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: - past_length = cache.get_seq_length() + elif hasattr(cache, "get_past_seen_tokens") and cache.get_past_seen_tokens() is not None: + past_length = cache.get_past_seen_tokens() # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty, # end-to-end compilation will yield bad results because `cache_position` will be incorrect. @@ -2130,6 +2131,8 @@ def generate( raise ValueError( f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" ) + if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache): + raise ValueError("DynamicSlidingWindowCache cannot be used in assisted generation.") # 11. Get the candidate generator, given the parameterization candidate_generator = self._get_candidate_generator( @@ -2179,6 +2182,12 @@ def generate( raise ValueError( f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" ) + if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr( + generation_config, "low_memory", False + ): + raise ValueError( + "DynamicSlidingWindowCache cannot be used in contrastive generation with `low_memory=True`." + ) result = self._contrastive_search( input_ids, @@ -2764,7 +2773,7 @@ def _contrastive_search( # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step if model_kwargs.get("past_key_values") is None or ( isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache)) - and model_kwargs["past_key_values"].get_seq_length() == 0 + and model_kwargs["past_key_values"].get_past_seen_tokens() == 0 ): # prepare inputs model_kwargs["use_cache"] = True @@ -4166,7 +4175,7 @@ def _assisted_decoding( isinstance(past_key_values, EncoderDecoderCache) and isinstance(past_key_values.self_attention_cache, DynamicCache) ): - if past_key_values.get_seq_length() == 0: + if past_key_values.get_past_seen_tokens() == 0: start_from_empty_dynamic_cache = True this_peer_finished = False @@ -4603,10 +4612,8 @@ def _concat(data): if isinstance(data[0], torch.Tensor): return torch.cat(data, dim=0) # New cache format - elif isinstance(data[0], DynamicCache): - return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) - elif isinstance(data[0], EncoderDecoderCache): - return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) + elif isinstance(data[0], (DynamicCache, EncoderDecoderCache)): + return data[0].__class__.from_batch_splits(data, num_hidden_layers=num_hidden_layers) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 4319c021cb2b..64b64e889962 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -275,13 +275,13 @@ def _ignore_causal_mask_sdpa( if ( (is_training or not is_tracing) and (query_length == 1 or key_value_length == query_length) - and (sliding_window is None or key_value_length < sliding_window) + and (sliding_window is None or key_value_length <= sliding_window) ): ignore_causal_mask = True - elif sliding_window is None or key_value_length < sliding_window: + elif sliding_window is None or key_value_length <= sliding_window: if len(attention_mask.shape) == 4: return False - elif not is_tracing and torch.all(attention_mask == 1): + elif not is_tracing and torch.all(attention_mask[:, -key_value_length:] == 1): if query_length == 1 or key_value_length == query_length: # For query_length == 1, causal attention and bi-directional attention are the same. ignore_causal_mask = True diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 75f8e5830f44..b0bf53901f47 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -646,7 +646,7 @@ def forward( ) batch_size, seq_length, _ = inputs_embeds.shape - past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 seq_length_with_past = seq_length + past_length if cache_position is None: cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device) @@ -747,7 +747,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index fd76c0b11522..f9cf103e6f50 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1295,7 +1295,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1386,7 +1386,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 478745b2c59e..2b51bad8ab7b 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -487,7 +487,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -590,7 +590,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index a5d3721f5bdb..69505b581470 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -868,7 +868,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -959,7 +959,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index ef81e43d0294..d07bc864839d 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1019,7 +1019,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1120,7 +1120,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index f48accab44bf..94c3930f2de7 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -992,7 +992,7 @@ def forward( # Compute alibi tensor: check build_alibi_tensor documentation alibi = None - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 batch_size, seq_length, _ = inputs_embeds.shape if self.use_alibi: mask = ( @@ -1114,7 +1114,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ff206a470bc3..c7cb92fde0ff 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -780,7 +780,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -875,7 +875,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 7130a30dc9be..ce39025cfbdf 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -860,7 +860,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 0b99aa59c65b..be4e5b8e0c97 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -790,7 +790,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index c0f76dbe5bfc..7ea5b3bdd23e 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -627,7 +627,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index c7f9ceafe194..168935b467ab 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -1283,7 +1283,7 @@ def forward( past_key_values_length = ( past_key_values[0][0].shape[2] if not isinstance(past_key_values, Cache) - else past_key_values.get_seq_length() + else past_key_values.get_past_seen_tokens() ) # Prepare head mask if needed @@ -1611,7 +1611,7 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values.get_seq_length() + past_length = past_key_values.get_past_seen_tokens() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 7bba7608e6c1..0099341d5262 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -702,7 +702,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -804,7 +804,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index f4636db0a97b..f0adac7642a7 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -904,7 +904,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -1001,7 +1001,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index b618f531e52f..7215e8e05076 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -624,7 +624,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -705,7 +705,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 5c80485823c1..8145d9b250ab 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -774,7 +774,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) @@ -899,7 +899,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 0eb27d452f08..420d971b2ac4 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -794,7 +794,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -892,7 +892,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index ebdea826fa04..7ac0829509a4 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1020,7 +1020,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1125,7 +1125,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 02de8d61ae20..e757a043e68d 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1153,7 +1153,7 @@ def forward( ) batch_size, seq_length, _ = inputs_embeds.shape - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 seq_length_with_past = seq_length + past_key_values_length if cache_position is None: @@ -1384,7 +1384,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index b53d0722587d..6f3c2ffbbb5a 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1359,7 +1359,7 @@ def forward( "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" ) - past_seen_tokens = past_key_values.get_seq_length() + past_seen_tokens = past_key_values.get_past_seen_tokens() if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") @@ -1664,7 +1664,7 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = past_key_values.get_seq_length() + past_length = past_key_values.get_past_seen_tokens() max_cache_length = past_key_values.get_max_cache_shape() # Keep only the unprocessed tokens: diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 757391175ea6..9c65b2e01b76 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -953,7 +953,7 @@ def forward( past_seen_tokens = 0 if use_cache: - past_seen_tokens = past_key_values.get_seq_length() + past_seen_tokens = past_key_values.get_past_seen_tokens() if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") @@ -1252,7 +1252,7 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - past_length = past_key_values.get_seq_length() + past_length = past_key_values.get_past_seen_tokens() max_cache_length = past_key_values.get_max_cache_shape() # Keep only the unprocessed tokens: diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index bbc70b26d1f8..ca2651871cbe 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -998,7 +998,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1101,7 +1101,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index dde017bbb927..bbcbe5ef96bb 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -904,7 +904,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -995,7 +995,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 514f9de706ec..6120e2576d68 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -23,7 +23,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel @@ -958,7 +958,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device ) @@ -1044,7 +1044,8 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1057,7 +1058,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): @@ -1076,12 +1077,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1109,6 +1116,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1127,6 +1135,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1146,14 +1157,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1164,7 +1180,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b0ffe3e56e59..464172172ae8 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -313,7 +313,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += cache_position[0] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) @@ -775,7 +775,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -872,7 +872,8 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -885,7 +886,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): @@ -904,12 +905,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -936,6 +943,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -954,6 +962,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -973,14 +984,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -991,7 +1007,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype @@ -1178,6 +1196,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 9c7fadbb8f88..05db01c83f3b 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -329,15 +329,15 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_len = past_key_value.get_max_cache_shape() + else: + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -412,22 +412,15 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = ( - max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len - ) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_len = past_key_value.get_max_cache_shape() + else: + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -558,11 +551,14 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_len = past_key_value.get_max_cache_shape() + else: + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -986,7 +982,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1085,7 +1081,8 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1098,7 +1095,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): @@ -1117,12 +1114,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1150,6 +1153,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1168,6 +1172,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1187,14 +1194,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1205,7 +1217,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype @@ -1413,6 +1427,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 0bc77eaeec33..89b9813aa4a8 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1620,7 +1620,7 @@ def forward( hidden_states = inputs_embeds if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1726,7 +1726,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 7d0390adc3c0..54e5ff3698ab 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -872,7 +872,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 7ab54146c974..3b1de2709f9a 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -825,7 +825,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -914,7 +914,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 8c29f89ff3e7..47167a62677b 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -973,7 +973,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1073,7 +1073,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index d75a05bda0e1..e19deac906ea 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -472,7 +472,7 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 7ae3469a4c93..88c2941ed503 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -648,7 +648,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -741,7 +741,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 3f770c9ec00b..e68fb3d07de3 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -938,7 +938,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1032,7 +1032,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 0380c6cd49d6..3dd799f59177 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -377,16 +377,13 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_len = past_key_value.get_max_cache_shape() + else: + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -473,22 +470,15 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_len = past_key_value.get_max_cache_shape() + else: + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = ( - max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len - ) - - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids) - + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -625,11 +615,14 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_len = past_key_value.get_max_cache_shape() + else: + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -965,7 +958,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1052,7 +1045,8 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1065,7 +1059,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): @@ -1084,12 +1078,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1117,6 +1117,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1135,6 +1136,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1154,14 +1158,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1172,7 +1181,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype @@ -1389,6 +1400,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index d1705f04ddb7..298777ecbcca 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1220,7 +1220,8 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1233,7 +1234,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): @@ -1252,12 +1253,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1285,6 +1292,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1303,6 +1311,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1322,14 +1333,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1340,7 +1356,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype @@ -1579,6 +1597,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 50f273ba766c..505a35b8b249 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -395,7 +395,8 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - kv_seq_len = key_states.shape[-2] + cache_position[0] + kv_seq_len = key_states.shape[-2] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window @@ -878,7 +879,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -971,7 +972,8 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -984,7 +986,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): @@ -1003,12 +1005,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1036,6 +1044,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1054,6 +1063,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1073,14 +1085,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1091,7 +1108,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype @@ -1279,6 +1298,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 6422baac5feb..928433c58d53 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -1266,7 +1266,7 @@ def prepare_inputs_for_generation( if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens + past_length = past_key_values.get_past_seen_tokens() else: cache_length = past_length = past_key_values[0][0].shape[2] diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 2ab13b7227ad..2b5aff242375 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -483,7 +483,8 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - kv_seq_len = key_states.shape[-2] + cache_position[0] + kv_seq_len = key_states.shape[-2] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window @@ -1048,7 +1049,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1152,7 +1153,8 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1165,7 +1167,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): @@ -1184,12 +1186,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1217,6 +1225,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1235,6 +1244,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1254,14 +1266,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1272,7 +1289,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype @@ -1483,6 +1502,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 283e38d3a7d5..86e6bfbb5472 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -30,7 +30,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN -from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -549,10 +549,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += cache_position[0] + 1 - if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " @@ -809,9 +805,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " @@ -1139,7 +1132,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1234,7 +1227,8 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1247,7 +1241,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): @@ -1266,12 +1260,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1299,6 +1299,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1317,6 +1318,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1336,14 +1340,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1354,7 +1363,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype @@ -1838,6 +1849,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index fe3ad6498172..779034f51b82 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -923,7 +923,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1016,7 +1016,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index e0fdbef1a3ba..3e6fba51861d 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -375,7 +375,8 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - kv_seq_len = key_states.shape[-2] + cache_position[0] + kv_seq_len = key_states.shape[-2] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window @@ -851,7 +852,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -945,7 +946,8 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -958,7 +960,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): @@ -977,12 +979,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1010,6 +1018,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1028,6 +1037,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1047,14 +1059,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1065,7 +1082,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype @@ -1255,6 +1274,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 079965fc174a..408ff54f5c55 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1248,7 +1248,7 @@ def forward( if cache_position is not None: past_key_values_length = cache_position[0] elif past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() + past_key_values_length = past_key_values.get_past_seen_tokens() if cache_position is None: cache_position = torch.arange( @@ -1383,7 +1383,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward @@ -1824,7 +1824,9 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + past_length = ( + cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() + ) else: past_length = past_key_values[0][0].shape[2] @@ -2105,7 +2107,9 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, (Cache, EncoderDecoderCache)): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + past_length = ( + cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() + ) else: past_length = past_key_values[0][0].shape[2] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 048de1cc8ae7..514277b10766 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -37,6 +37,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class DynamicSlidingWindowCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class EncoderDecoderCache(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1727aed1117b..2aa8d154cd0e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,13 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import ( + DynamicCache, + DynamicSlidingWindowCache, + EncoderDecoderCache, + QuantoQuantizedCache, + StaticCache, + ) from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -2024,6 +2030,118 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) + @parameterized.expand([(False,), (True,)]) + @pytest.mark.generate + def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): + """ + Tests if DynamicSlidingWindowCache works the same as DynamicCache for models that support it. + """ + for model_class in self.all_generative_model_classes: + config, _ = self.prepare_config_and_inputs_for_generate() + if not hasattr(config, "sliding_window"): + self.skipTest(reason="This model does not support sliding window.") + if hasattr(config, "cache_implementation"): + self.skipTest(reason="This model uses a specific cache format.") + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support Cache classes") + + input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) + if left_padding: + attention_mask = torch.tensor( + [ + [0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + ], + device=input_ids.device, + dtype=int, + ) + else: + attention_mask = torch.ones_like(input_ids) + + # Make sure we will go beyond the sliding window + config.sliding_window = 3 + model = model_class(config).to(torch_device).eval() + all_generation_kwargs = { + "do_sample": False, + "max_new_tokens": 20, + "min_new_tokens": 20, + "use_cache": True, + } + + dynamic_cache = DynamicCache() + dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) + + results_dynamic = model.generate( + input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache + ) + results_sliding_dynamic = model.generate( + input_ids, + attention_mask=attention_mask, + **all_generation_kwargs, + past_key_values=dynamic_sliding_cache, + ) + + self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) + + @parameterized.expand([(3, 1), (3, 4), (14, 5)]) + @pytest.mark.generate + def test_generate_continue_from_dynamic_sliding_window_cache(self, sliding_window: int, additional_tokens: int): + """ + Tests if we can correctly continue generation with DynamicSlidingWindowCache. + - First case tests that we can continue if the cache is already full, and we add less tokens than the sliding window + - Second case tests that we can continue if the cache is already full, and we add more tokens that the sliding window + - Third case tests that we can continue if the cache is not full, and we add tokens so that the new input is bigger than the sliding window + """ + for model_class in self.all_generative_model_classes: + config, _ = self.prepare_config_and_inputs_for_generate() + if not hasattr(config, "sliding_window"): + self.skipTest(reason="This model does not support sliding window.") + if hasattr(config, "cache_implementation"): + self.skipTest(reason="This model uses a specific cache format.") + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support Cache classes") + + # We need to be sure to always have shape (2, 7) for the different test assumptions to hold + input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) + + # Make sure we will go beyond the sliding window + config.sliding_window = sliding_window + model = model_class(config).to(torch_device).eval() + all_generation_kwargs = { + "do_sample": False, + "max_new_tokens": 5, + "min_new_tokens": 5, + "use_cache": True, + "return_dict_in_generate": True, + } + + dynamic_cache = DynamicCache() + dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) + + out_dynamic = model.generate(input_ids, **all_generation_kwargs, past_key_values=dynamic_cache) + out_sliding_dynamic = model.generate( + input_ids, **all_generation_kwargs, past_key_values=dynamic_sliding_cache + ) + + results_dynamic, dynamic_cache = out_dynamic.sequences, out_dynamic.past_key_values + results_sliding_dynamic, dynamic_sliding_cache = ( + out_sliding_dynamic.sequences, + out_sliding_dynamic.past_key_values, + ) + + self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) + + bs = results_dynamic.shape[0] + added_tokens = ids_tensor((bs, additional_tokens), vocab_size=config.vocab_size) + input_ids = torch.cat([results_dynamic, added_tokens], dim=-1) + + out_dynamic = model.generate(input_ids, **all_generation_kwargs, past_key_values=dynamic_cache) + out_sliding_dynamic = model.generate( + input_ids, **all_generation_kwargs, past_key_values=dynamic_sliding_cache + ) + + self.assertListEqual(out_dynamic.sequences.tolist(), out_sliding_dynamic.sequences.tolist()) + def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1): batch_size = main_input.shape[0] seq_length = main_input.shape[-1]