diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index b8d66c4247d4..472ab2ffbfb4 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -46,6 +46,7 @@ from ..integrations.deepspeed import is_deepspeed_zero3_enabled from ..integrations.fsdp import is_fsdp_managed_module from ..masking_utils import create_masks_for_generate +from ..modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput from ..pytorch_utils import isin_mps_friendly from ..tokenization_utils import ExtensionsTrie @@ -57,7 +58,6 @@ is_torchdynamo_exporting, logging, ) -from ..modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .candidate_generator import ( @@ -1811,7 +1811,8 @@ def _get_initial_cache_position(self, seq_length, device, model_kwargs): if model_kwargs.get("past_key_values") is not None: cache = model_kwargs["past_key_values"] past_length = 0 - if not isinstance(cache, Cache): + # Support for BC tuple cache format + if isinstance(cache, tuple): past_length = cache[0][0].shape[2] elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: past_length = cache.get_seq_length() diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index ef75f254cc20..806f1eb97a35 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -31,7 +31,7 @@ from transformers.activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, DynamicLayer +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub from ...modeling_attn_mask_utils import AttentionMaskConverter @@ -85,7 +85,7 @@ class BambaFlashAttentionKwargs(TypedDict, total=False): seq_idx: torch.IntTensor -class HybridMambaAttentionDynamicCache(Cache): +class HybridMambaAttentionDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -104,7 +104,6 @@ class HybridMambaAttentionDynamicCache(Cache): is_compileable = False def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): - super().__init__(layer_classes=DynamicLayer) self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba conv_kernel_size = config.mamba_d_conv diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index be58fd3abd42..6ed4a79bebf2 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -42,7 +42,6 @@ segment_sum, ) -from ...cache_utils import DynamicLayer from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel @@ -114,7 +113,6 @@ class HybridMambaAttentionDynamicCache(HybridMambaAttentionDynamicCache): """ def __init__(self, config: BambaConfig, batch_size, dtype=torch.float16, device=None): - HybridMambaAttentionDynamicCache.__init__(layer_classes=DynamicLayer) self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba conv_kernel_size = config.mamba_d_conv diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 8b099342f6ee..591e41b785d4 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -62,7 +62,7 @@ logger = logging.get_logger(__name__) -class FalconHybridMambaAttentionDynamicCache(Cache): +class FalconHybridMambaAttentionDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index c727d40f448b..4010501397c3 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -27,7 +27,7 @@ from transformers.activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, DynamicLayer +from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_layers import GradientCheckpointingLayer @@ -222,7 +222,7 @@ def forward( return attn_output, attn_weights -class HybridMambaAttentionDynamicCache(Cache): +class HybridMambaAttentionDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -241,7 +241,6 @@ class HybridMambaAttentionDynamicCache(Cache): is_compileable = False def __init__(self, config: GraniteMoeHybridConfig, batch_size, dtype=torch.float16, device=None): - super().__init__(layer_classes=DynamicLayer) self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba conv_kernel_size = config.mamba_d_conv diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 34af2f2f2e54..b88133a1dea0 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -28,7 +28,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, DynamicLayer +from ...cache_utils import DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available @@ -189,7 +189,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class HybridMambaAttentionDynamicCache(Cache): +class HybridMambaAttentionDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). @@ -208,7 +208,6 @@ class HybridMambaAttentionDynamicCache(Cache): is_compileable = False def __init__(self, config, batch_size, dtype=torch.float16, device=None): - super().__init__(layer_classes=DynamicLayer) self.dtype = dtype self.layers_block_type = config.layers_block_type self.has_previous_state = False # only used by mamba diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 5a60fed7eb27..a39b4de4dde4 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -119,7 +119,7 @@ def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) -class Lfm2HybridConvCache(DynamicCache): +class Lfm2HybridConvCache: """ Attention and conv cache for Lfm2. @@ -251,6 +251,9 @@ def crop(self, max_length: int): self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + def __len__(self) -> int: + return len(self.key_cache) + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: return self.key_cache[layer_idx], self.value_cache[layer_idx] diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index c3c39e46776f..6a34426792aa 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -80,7 +80,7 @@ def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) -class Lfm2HybridConvCache(DynamicCache): +class Lfm2HybridConvCache: """ Attention and conv cache for Lfm2. @@ -212,6 +212,9 @@ def crop(self, max_length: int): self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + def __len__(self) -> int: + return len(self.key_cache) + def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: return self.key_cache[layer_idx], self.value_cache[layer_idx] diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 16290ea4e1b7..f67868134d0f 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -93,7 +93,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class ZambaHybridDynamicCache(Cache): +class ZambaHybridDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index de6c9c9b96df..e33534d57166 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -97,7 +97,7 @@ def extra_repr(self): return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" -class Zamba2HybridDynamicCache(Cache): +class Zamba2HybridDynamicCache: """ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache (which has a constant shape regardless of seq_len). diff --git a/tests/models/bamba/test_modeling_bamba.py b/tests/models/bamba/test_modeling_bamba.py index 245163d672c3..653a8254616b 100644 --- a/tests/models/bamba/test_modeling_bamba.py +++ b/tests/models/bamba/test_modeling_bamba.py @@ -297,6 +297,26 @@ class BambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi # This is because we are hitting edge cases with the causal_mask buffer model_split_percents = [0.5, 0.7, 0.8] + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, HybridMambaAttentionDynamicCache) + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) + def setUp(self): self.model_tester = self.model_tester_class(self) self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class, hidden_size=64) diff --git a/tests/models/falcon_h1/test_modeling_falcon_h1.py b/tests/models/falcon_h1/test_modeling_falcon_h1.py index 37afc2cceba1..efcf00798de0 100644 --- a/tests/models/falcon_h1/test_modeling_falcon_h1.py +++ b/tests/models/falcon_h1/test_modeling_falcon_h1.py @@ -38,7 +38,7 @@ if is_torch_available(): import torch - from transformers import AutoTokenizer, Cache, FalconH1ForCausalLM, FalconH1Model + from transformers import AutoTokenizer, FalconH1ForCausalLM, FalconH1Model from transformers.models.falcon_h1.modeling_falcon_h1 import ( FalconHybridMambaAttentionDynamicCache, ) @@ -273,7 +273,7 @@ class FalconH1ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM ) def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): - self.assertIsInstance(decoder_past_key_values, (tuple, Cache)) + self.assertIsInstance(decoder_past_key_values, FalconHybridMambaAttentionDynamicCache) # (batch, head, seq_length, head_features) expected_shape = ( @@ -283,31 +283,14 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value config.hidden_size // config.num_attention_heads, ) - if isinstance(decoder_past_key_values, Cache): - self.assertListEqual( - [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], - [expected_shape] * len(decoder_past_key_values.key_cache), - ) - self.assertListEqual( - [value_cache.shape for value_cache in decoder_past_key_values.value_cache], - [expected_shape] * len(decoder_past_key_values.value_cache), - ) - - # Legacy cache format checks. This branch should be removed when all models use `Cache` by default - else: - self.assertListEqual( - [isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values], - [True] * len(decoder_past_key_values), - ) - # check shape key, value - self.assertListEqual( - [layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values], - [expected_shape] * len(decoder_past_key_values), - ) - self.assertListEqual( - [layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values], - [expected_shape] * len(decoder_past_key_values), - ) + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) def setUp(self): self.model_tester = FalconH1ModelTester(self) diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index 98ccf21e59b3..c1627fc59f2f 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -342,6 +342,26 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi test_headmasking = False test_pruning = False + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, HybridMambaAttentionDynamicCache) + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) + def setUp(self): self.model_tester = JambaModelTester(self) self.config_tester = JambaConfigTester(self, config_class=JambaConfig, hidden_size=37) diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py index 7140373081bb..431417f4c18b 100644 --- a/tests/models/zamba/test_modeling_zamba.py +++ b/tests/models/zamba/test_modeling_zamba.py @@ -304,6 +304,26 @@ class ZambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi test_headmasking = False test_pruning = False + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, ZambaHybridDynamicCache) + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) + def setUp(self): self.model_tester = ZambaModelTester(self) self.config_tester = ConfigTester(self, config_class=ZambaConfig, hidden_size=37) diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py index 3f35a54acb66..cb742707d713 100644 --- a/tests/models/zamba2/test_modeling_zamba2.py +++ b/tests/models/zamba2/test_modeling_zamba2.py @@ -315,6 +315,26 @@ class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix test_headmasking = False test_pruning = False + def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config): + self.assertIsInstance(decoder_past_key_values, Zamba2HybridDynamicCache) + + # (batch, head, seq_length, head_features) + expected_shape = ( + batch_size, + config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads, + cache_length, + config.hidden_size // config.num_attention_heads, + ) + + self.assertListEqual( + [key_tensor.shape for key_tensor in decoder_past_key_values.key_cache], + [expected_shape] * len(decoder_past_key_values.key_cache), + ) + self.assertListEqual( + [value_cache.shape for value_cache in decoder_past_key_values.value_cache], + [expected_shape] * len(decoder_past_key_values.value_cache), + ) + def setUp(self): self.model_tester = Zamba2ModelTester(self) self.config_tester = ConfigTester(self, config_class=Zamba2Config, hidden_size=37)