7 changes: 7 additions & 0 deletions src/transformers/models/mamba2/configuration_mamba2.py
@@ -140,6 +140,13 @@ def __init__(
         tie_word_embeddings=False,
         **kwargs,
     ):
+        if (hidden_size * expand) != (num_heads * head_dim):
+            raise ValueError(
+                "Inconsistent configuration: hidden_size * expand "
+                f"({hidden_size * expand}) must equal num_heads * head_dim "
+                f"({num_heads * head_dim})."
+            )
+
         self.vocab_size = vocab_size
         self.hidden_size = hidden_size
         self.state_size = state_size
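For context, a minimal sketch (not part of the diff) of how the new check surfaces to users, assuming Mamba2Config is importable from the transformers top level and using illustrative, non-default sizes:

from transformers import Mamba2Config

# Consistent: hidden_size * expand == num_heads * head_dim (2048 * 2 == 64 * 64)
ok = Mamba2Config(hidden_size=2048, expand=2, num_heads=64, head_dim=64)

# Inconsistent: 2048 * 2 != 32 * 64, so __init__ now raises the new ValueError
try:
    Mamba2Config(hidden_size=2048, expand=2, num_heads=32, head_dim=64)
except ValueError as err:
    print(err)  # Inconsistent configuration: hidden_size * expand (4096) must equal num_heads * head_dim (2048).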
21 changes: 11 additions & 10 deletions src/transformers/models/mamba2/modeling_mamba2.py
@@ -462,13 +462,19 @@ def cuda_kernels_forward(
         return out

     # fmt: off
-    def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
-        batch_size, seq_len, _ = input_states.shape
-        dtype = input_states.dtype
+    def torch_forward(
+        self,
+        hidden_states: torch.Tensor,
+        cache_params: Optional[Mamba2Cache]=None,
+        cache_position:Optional[torch.LongTensor]=None,
+        attention_mask: Optional[torch.Tensor]=None
+    ):
+        batch_size, seq_len, _ = hidden_states.shape
+        dtype = hidden_states.dtype

         # 1. Gated MLP's linear projection
-        input_states = apply_mask_to_padding_states(input_states, attention_mask)
-        projected_states = self.in_proj(input_states)
+        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
+        projected_states = self.in_proj(hidden_states)
         d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
         _, _, gate, hidden_states_B_C, dt = projected_states.split(
             [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
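As an aside on the unchanged d_mlp line above, a rough worked example (not part of the diff) of the split arithmetic, assuming default-style Mamba2 sizes and assuming in_proj projects to intermediate_size + conv_dim + num_heads:

# Illustrative sizes (assumed, roughly the Mamba2 defaults)
hidden_size, expand, n_groups, state_size, num_heads = 4096, 2, 8, 128, 128
intermediate_size = expand * hidden_size                    # 8192
conv_dim = intermediate_size + 2 * n_groups * state_size    # 10240
projection_size = intermediate_size + conv_dim + num_heads  # 18560, assumed in_proj output width

d_mlp = (projection_size - 2 * intermediate_size - 2 * n_groups * state_size - num_heads) // 2
print(d_mlp)  # 0 -> the first two chunks of the split are empty for a config shaped like this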
@@ -662,11 +668,6 @@ def forward(
     ):
         if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
             return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
-        dtype = hidden_states.dtype
-        if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
-            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
-            hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
-
         return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)


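The block removed from forward() duplicated masking that torch_forward already performs via apply_mask_to_padding_states (see the call kept in the hunk above). For reference, a sketch (not part of the diff) of that helper, reconstructed from the removed lines; the actual implementation in modeling_mamba2.py may differ in detail:

from typing import Optional

import torch

def apply_mask_to_padding_states(hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> torch.Tensor:
    # Zero out positions belonging to padding tokens so they cannot leak into the SSM state,
    # see https://github.com/state-spaces/mamba/issues/66; mirrors the logic removed from forward() above.
    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
    return hidden_states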