From a199d4e5c76498b67d707822df2d7d5237259544 Mon Sep 17 00:00:00 2001
From: Cyrile
Date: Thu, 17 Apr 2025 18:15:56 +0200
Subject: [PATCH 1/5] Add config validation and style tweaks

---
 .../models/mamba2/configuration_mamba2.py |  8 +++++++
 .../models/mamba2/modeling_mamba2.py      | 21 ++++++++++---------
 2 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/src/transformers/models/mamba2/configuration_mamba2.py b/src/transformers/models/mamba2/configuration_mamba2.py
index ae6ea5cfaced..215f8abd4c27 100644
--- a/src/transformers/models/mamba2/configuration_mamba2.py
+++ b/src/transformers/models/mamba2/configuration_mamba2.py
@@ -140,6 +140,13 @@ def __init__(
         tie_word_embeddings=False,
         **kwargs,
     ):
+        if (hidden_size * expand) != (num_heads * head_dim):
+            raise AttributeError(
+                "Inconsistent configuration: hidden_size * expand "
+                f"({hidden_size * expand}) must equal num_heads * head_dim "
+                f"({num_heads * head_dim})."
+            )
+
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.state_size = state_size
@@ -171,6 +178,7 @@ def __init__(
         self.time_step_limit = time_step_limit
         self.tie_word_embeddings = tie_word_embeddings
 
+
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index a1ca8d095c25..c4d16546abfe 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -462,13 +462,19 @@ def cuda_kernels_forward(
         return out
 
     # fmt: off
-    def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
-        batch_size, seq_len, _ = input_states.shape
-        dtype = input_states.dtype
+    def torch_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params: Optional[Mamba2Cache]=None,
+        cache_position:Optional[torch.LongTensor]=None,
+        attention_mask: Optional[torch.Tensor]=None
+    ):
+        batch_size, seq_len, _ = hidden_states.shape
+        dtype = hidden_states.dtype
 
         # 1. Gated MLP's linear projection
-        input_states = apply_mask_to_padding_states(input_states, attention_mask)
-        projected_states = self.in_proj(input_states)
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
+        projected_states = self.in_proj(hidden_states)
         d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
         _, _, gate, hidden_states_B_C, dt = projected_states.split(
             [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
@@ -662,11 +668,6 @@ def forward(
     ):
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
             return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
-        dtype = hidden_states.dtype
-        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
-            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
-            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
         return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)

From fada993b287cd92b02478d9a468b1a126503dd58 Mon Sep 17 00:00:00 2001
From: Cyrile
Date: Wed, 23 Apr 2025 11:50:38 +0200
Subject: [PATCH 2/5] Fix style issues

---
 src/transformers/models/mamba2/configuration_mamba2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/mamba2/configuration_mamba2.py b/src/transformers/models/mamba2/configuration_mamba2.py
index 215f8abd4c27..53054c00503b 100644
--- a/src/transformers/models/mamba2/configuration_mamba2.py
+++ b/src/transformers/models/mamba2/configuration_mamba2.py
@@ -141,7 +141,7 @@ def __init__(
         **kwargs,
     ):
         if (hidden_size * expand) != (num_heads * head_dim):
-            raise AttributeError(
+            raise ValueError(
                 "Inconsistent configuration: hidden_size * expand "
                 f"({hidden_size * expand}) must equal num_heads * head_dim "
                 f"({num_heads * head_dim})."
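A minimal sketch (not part of the patches above) of how the consistency check introduced in PATCH 1/5, and switched to ValueError in PATCH 2/5, behaves; it assumes a transformers build with these patches applied, and the tiny sizes are illustrative only:

    from transformers import Mamba2Config

    # Consistent: hidden_size * expand (8 * 2 = 16) == num_heads * head_dim (4 * 4 = 16)
    ok = Mamba2Config(hidden_size=8, expand=2, num_heads=4, head_dim=4)

    # Inconsistent: 8 * 2 = 16 != 4 * 8 = 32, so the constructor now raises ValueError
    try:
        Mamba2Config(hidden_size=8, expand=2, num_heads=4, head_dim=8)
    except ValueError as err:
        print(err)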
From 3cb7dda73460d1b295ed09c9b56c87e9c1004ec6 Mon Sep 17 00:00:00 2001
From: Cyrile
Date: Wed, 23 Apr 2025 12:03:26 +0200
Subject: [PATCH 3/5] Fix style issues

---
 src/transformers/models/mamba2/configuration_mamba2.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/models/mamba2/configuration_mamba2.py b/src/transformers/models/mamba2/configuration_mamba2.py
index 53054c00503b..3b1b1177c0a5 100644
--- a/src/transformers/models/mamba2/configuration_mamba2.py
+++ b/src/transformers/models/mamba2/configuration_mamba2.py
@@ -178,7 +178,6 @@ def __init__(
         self.time_step_limit = time_step_limit
         self.tie_word_embeddings = tie_word_embeddings
 
-
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,

From d3e5f8cd00c8392e8bd164e813ab535ba94f1f3b Mon Sep 17 00:00:00 2001
From: Cyrile
Date: Tue, 13 May 2025 15:49:09 +0200
Subject: [PATCH 4/5] style

---
 .../models/mamba2/modeling_mamba2.py        | 12 ++------
 tests/models/mamba2/test_modeling_mamba2.py | 29 ++++++++++++++++++-
 2 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index c4d16546abfe..a2a9c75a4048 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -21,7 +21,6 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
 from ...generation import GenerationMixin
@@ -1087,7 +1086,7 @@ def forward(
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        **kwargs,  # for now we need this for generation
+        **kwargs,  # for now we need this for generation and loss_function
     ) -> Union[Tuple, Mamba2CausalLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1113,14 +1112,7 @@ def forward(
 
         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
-            labels = labels.to(logits.device)
-            # Shift so that tokens < n predict n
-            shift_logits = logits[..., :-1, :].contiguous()
-            shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
 
         if not return_dict:
             output = (logits,) + mamba2_outputs[1:]
diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index 31565bf23d89..f637fbb61172 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -37,6 +37,33 @@
     from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer
 
 
+class Mamba2ConfigTester(ConfigTester):
+    def _create_config(self, hidden_size: int, num_heads: int, expand: int, head_dim: int):
+        _input_dict = self.inputs_dict.copy()
+        _input_dict["hidden_size"] = hidden_size
+        _input_dict["num_heads"] = num_heads
+        _input_dict["expand"] = expand
+        _input_dict["head_dim"] = head_dim
+        return self.config_class(**_input_dict)
+
+    def test_consistence(self):
+        self._create_config(hidden_size=2, num_heads=2, expand=2, head_dim=2)
+        self._create_config(hidden_size=4, num_heads=4, expand=2, head_dim=2)
+        self._create_config(hidden_size=2, num_heads=4, expand=4, head_dim=2)
+        with self.parent.assertRaises(ValueError):
+            self._create_config(hidden_size=2, num_heads=2, expand=4, head_dim=4)
+        with self.parent.assertRaises(ValueError):
+            self._create_config(hidden_size=4, num_heads=2, expand=4, head_dim=2)
+
+    def test_mamba2_offset_properties(self):
+        self.test_attn_offsets()
+        self.test_expert_offsets()
+
+    def run_common_tests(self):
+        self.test_mamba2_offset_properties()
+        return super().run_common_tests()
+
+
 class Mamba2ModelTester:
     def __init__(
@@ -226,7 +253,7 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
 
     def setUp(self):
         self.model_tester = Mamba2ModelTester(self)
-        self.config_tester = ConfigTester(
+        self.config_tester = Mamba2ConfigTester(
             self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
         )

From 462f8cf43422d49cb9d73528ceac9cdd3d67a62b Mon Sep 17 00:00:00 2001
From: Cyrile
Date: Wed, 14 May 2025 12:01:30 +0200
Subject: [PATCH 5/5] Small fixes for copy/paste errors

---
 tests/models/mamba2/test_modeling_mamba2.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index f637fbb61172..8bf2d760a292 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -46,21 +46,17 @@ def _create_config(self, hidden_size: int, num_heads: int, expand: int, head_dim
         _input_dict["head_dim"] = head_dim
         return self.config_class(**_input_dict)
 
-    def test_consistence(self):
+    def test_hidden_size_compatibility(self):
         self._create_config(hidden_size=2, num_heads=2, expand=2, head_dim=2)
         self._create_config(hidden_size=4, num_heads=4, expand=2, head_dim=2)
         self._create_config(hidden_size=2, num_heads=4, expand=4, head_dim=2)
         with self.parent.assertRaises(ValueError):
-            self._create_config(hidden_size=2, num_heads=2, expand=4, head_dim=4)
+            self._create_config(hidden_size=2, num_heads=4, expand=2, head_dim=4)
         with self.parent.assertRaises(ValueError):
             self._create_config(hidden_size=4, num_heads=2, expand=4, head_dim=2)
 
-    def test_mamba2_offset_properties(self):
-        self.test_attn_offsets()
-        self.test_expert_offsets()
-
     def run_common_tests(self):
-        self.test_mamba2_offset_properties()
+        self.test_hidden_size_compatibility()
         return super().run_common_tests()
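For reference, a minimal sketch (not part of the patches) of the computation that the `self.loss_function(...)` delegation in PATCH 4/5 replaces; it reproduces the shift-and-flatten cross entropy that the removed lines inlined, with arbitrary demo tensor shapes:

    import torch
    from torch.nn import CrossEntropyLoss

    def shifted_cross_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Move labels to the logits device (model parallelism), then shift so tokens < n predict n
        labels = labels.to(logits.device)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens and average the per-token cross entropy
        loss_fct = CrossEntropyLoss()
        return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    logits = torch.randn(2, 6, 32)           # (batch, seq_len, vocab_size), demo values
    labels = torch.randint(0, 32, (2, 6))    # (batch, seq_len)
    print(shifted_cross_entropy(logits, labels))

The shared helper is also passed `vocab_size` and the forwarded `**kwargs`, which is why the patch widens the `**kwargs` comment to mention `loss_function`.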