7 changes: 7 additions & 0 deletions src/transformers/models/mamba2/configuration_mamba2.py
@@ -140,6 +140,13 @@ def __init__(
tie_word_embeddings=False,
**kwargs,
):
if (hidden_size * expand) != (num_heads * head_dim):
raise ValueError(
"Inconsistent configuration: hidden_size * expand "
f"({hidden_size * expand}) must equal num_heads * head_dim "
f"({num_heads * head_dim})."
)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.state_size = state_size
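As an aside (not part of the diff), a minimal sketch of how the new consistency check behaves from a user's perspective; the dimension values below are arbitrary illustrations:

```python
from transformers import Mamba2Config

# Consistent: hidden_size * expand == num_heads * head_dim (256 * 2 == 8 * 64)
config = Mamba2Config(hidden_size=256, expand=2, num_heads=8, head_dim=64)

# Inconsistent: 256 * 2 != 8 * 32, so __init__ now raises a ValueError
try:
    Mamba2Config(hidden_size=256, expand=2, num_heads=8, head_dim=32)
except ValueError as err:
    print(err)  # Inconsistent configuration: hidden_size * expand (512) must equal num_heads * head_dim (256).
```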
33 changes: 13 additions & 20 deletions src/transformers/models/mamba2/modeling_mamba2.py
@@ -21,7 +21,6 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
@@ -457,13 +456,19 @@ def cuda_kernels_forward(
return out

# fmt: off
def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
def torch_forward(
self,
hidden_states: torch.Tensor,
cache_params: Optional[Mamba2Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
):
batch_size, seq_len, _ = hidden_states.shape
dtype = hidden_states.dtype

# 1. Gated MLP's linear projection
input_states = apply_mask_to_padding_states(input_states, attention_mask)
projected_states = self.in_proj(input_states)
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
projected_states = self.in_proj(hidden_states)
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
_, _, gate, hidden_states_B_C, dt = projected_states.split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
@@ -657,11 +662,6 @@ def forward(
):
if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
dtype = hidden_states.dtype
if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
# tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)

return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
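The padding masking dropped from forward() above is not lost: torch_forward now applies it via apply_mask_to_padding_states before the input projection (see the earlier hunk). A minimal sketch of that helper, assuming it keeps the behavior of the removed lines:

```python
from typing import Optional

import torch


def apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
    # Tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
    return hidden_states
```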


@@ -1018,7 +1018,7 @@ def forward(
use_cache: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs, # for now we need this for generation
**kwargs, # for now we need this for generation and loss_function
) -> Union[Tuple, Mamba2CausalLMOutput]:
r"""
cache_params (`Mamba2Cache`, *optional*):
@@ -1052,14 +1052,7 @@

loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

if not return_dict:
output = (logits,) + mamba2_outputs[1:]
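For reference, the removed manual loss block is equivalent to the following shift-and-flatten cross-entropy; self.loss_function (the shared causal-LM loss helper) computes the same quantity, so training behavior should be unchanged. A minimal standalone sketch:

```python
import torch
import torch.nn.functional as F


def shifted_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Move labels to the logits device to support model parallelism
    labels = labels.to(logits.device)
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens and compute the standard cross-entropy
    return F.cross_entropy(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
```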
29 changes: 28 additions & 1 deletion tests/models/mamba2/test_modeling_mamba2.py
@@ -44,6 +44,33 @@
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache, Mamba2Mixer


class Mamba2ConfigTester(ConfigTester):
def _create_config(self, hidden_size: int, num_heads: int, expand: int, head_dim: int):
_input_dict = self.inputs_dict.copy()
_input_dict["hidden_size"] = hidden_size
_input_dict["num_heads"] = num_heads
_input_dict["expand"] = expand
_input_dict["head_dim"] = head_dim
return self.config_class(**_input_dict)

def test_consistence(self):
self._create_config(hidden_size=2, num_heads=2, expand=2, head_dim=2)
self._create_config(hidden_size=4, num_heads=4, expand=2, head_dim=2)
self._create_config(hidden_size=2, num_heads=4, expand=4, head_dim=2)
with self.parent.assertRaises(ValueError):
self._create_config(hidden_size=2, num_heads=2, expand=4, head_dim=2)
with self.parent.assertRaises(ValueError):
self._create_config(hidden_size=4, num_heads=2, expand=4, head_dim=2)

def run_common_tests(self):
self.test_consistence()
return super().run_common_tests()

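As a quick sanity check on the combinations used in test_consistence (an illustrative aside, not part of the test file), the constraint hidden_size * expand == num_heads * head_dim works out as follows:

```python
# (hidden_size, num_heads, expand, head_dim, should_be_valid)
cases = [
    (2, 2, 2, 2, True),   # 2 * 2 == 2 * 2  (4 == 4)
    (4, 4, 2, 2, True),   # 4 * 2 == 4 * 2  (8 == 8)
    (2, 4, 4, 2, True),   # 2 * 4 == 4 * 2  (8 == 8)
    (2, 2, 4, 2, False),  # 2 * 4 != 2 * 2  (8 != 4)
    (4, 2, 4, 2, False),  # 4 * 4 != 2 * 2  (16 != 4)
]
for hidden_size, num_heads, expand, head_dim, valid in cases:
    assert ((hidden_size * expand) == (num_heads * head_dim)) is valid
```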

class Mamba2ModelTester:
def __init__(
self,
@@ -233,7 +260,7 @@ class Mamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix

def setUp(self):
self.model_tester = Mamba2ModelTester(self)
self.config_tester = ConfigTester(
self.config_tester = Mamba2ConfigTester(
self, config_class=Mamba2Config, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"]
)
