diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py
index c312b9b94351..3083139d7c9d 100644
--- a/src/transformers/models/mamba2/modeling_mamba2.py
+++ b/src/transformers/models/mamba2/modeling_mamba2.py
@@ -531,45 +531,44 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None,
             # This is the analog of a causal mask
             L = torch.exp(segment_sum(A))
 
-            # First, contraction of C and B to get G (attention-weights like)
-            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:]  # shape: (b, c, l, s, h, n)
+            # Contraction of C and B to get G (attention-weights like)
+            # shape: (b, c, l, s, h, n)
+            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]
             G = G_intermediate.sum(dim=-1)  # shape: (b, c, l, s, h)
-
-            # Step 2: Compute M, equivalent to applying attention mask to weights
+            # Compute M, equivalent to applying attention mask to weights
             M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
             M = M_intermediate.sum(dim=-1)
 
-            # Step 3: Compute Y_diag (apply to values)
-            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
+            # Compute Y_diag (apply to values)
+            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
+
+            # 2. Compute the state for each intra-chunk
             # (right term of low-rank factorization of off-diagonal blocks; B terms)
-
             decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
-            B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
-            # permute back B * decay states
-            states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
+            B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
+            states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
+
+            # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
+            # (middle term of factorization of off-diag blocks; A terms)
             if cache_params is not None and cache_params.seqlen_offset > 0:
                 previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
             else:
                 previous_states = torch.zeros_like(states[:, :1])
             states = torch.cat([previous_states, states], dim=1)
             decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
-
-            states_permuted = states.permute(0, 2, 1, 3, 4)
-            result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
-            new_states = result.permute(0, 2, 1, 3, 4)
+            decay_chunk = decay_chunk.transpose(1, 3)
+            new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
             states, ssm_state = new_states[:, :-1], new_states[:, -1]
 
-            # Compute state -> output conversion per chunk
+            # 4. Compute state -> output conversion per chunk
             # (left term of low-rank factorization of off-diagonal blocks; C terms)
             state_decay_out = torch.exp(A_cumsum)
-            # compute Yoff
             C_times_states = (C[..., None, :] * states[:, :, None, ...])
             state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
             Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
-            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
+            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
             y = Y_diag + Y_off
             # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
             y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
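
To sanity-check the reshuffled contraction in step 2 without running the model, here is a minimal standalone sketch (not part of the diff) comparing the old permute-heavy formulation against the new broadcast-and-sum one on random tensors. Shape letters used here: b=batch, c=chunks, l=chunk_size, h=heads, n=state size, p=head_dim; the sizes are arbitrary.

```python
import torch

torch.manual_seed(0)
b, c, l, h, n, p = 2, 3, 4, 5, 16, 8  # arbitrary small sizes
hidden_states = torch.randn(b, c, l, h, p)
B = torch.randn(b, c, l, h, n)
A_cumsum = torch.randn(b, h, c, l)

decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))  # (b, h, c, l)

# Old formulation: permute to (b, c, h, l, ...), contract over chunk_size, permute back
B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
states_old = (
    B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None]
    * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]
).sum(dim=3).permute(0, 1, 2, 4, 3)

# New formulation: broadcast directly and reduce over the chunk_size dimension
B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
states_new = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)

assert states_old.shape == states_new.shape == (b, c, h, p, n)
assert torch.allclose(states_old, states_new, atol=1e-6)
```

Both paths produce the same per-chunk states of shape (b, c, h, p, n); the new one simply avoids the intermediate permutes.
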
diff --git a/tests/models/mamba2/test_modeling_mamba2.py b/tests/models/mamba2/test_modeling_mamba2.py
index 9b3a9563b58d..e52f277eeba5 100644
--- a/tests/models/mamba2/test_modeling_mamba2.py
+++ b/tests/models/mamba2/test_modeling_mamba2.py
@@ -21,6 +21,7 @@
 
 from transformers import AutoTokenizer, Mamba2Config, is_torch_available
 from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
+from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -158,6 +159,27 @@ def prepare_config_and_inputs_for_common(self):
         inputs_dict = {"input_ids": input_ids}
         return config, inputs_dict
 
+    def create_and_check_mamba2_slow_vs_fast_forward(self, config, input_ids, *args, gradient_checkpointing=False):
+        model = Mamba2Model(config)
+        model.eval()
+
+        if not (is_mamba_2_ssm_available() and is_causal_conv1d_available()):
+            self.parent.skipTest(
+                "This test needs the Mamba2 fast path. Skipping as the necessary packages have not been found."
+            )
+        if torch_device != "cuda":
+            self.parent.skipTest("This test needs the Mamba2 fast path. Skipping as we need a cuda capable device.")
+
+        model.to(torch_device)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
+
+        token_emb = model.embeddings(input_ids)
+        outputs_fast = model.layers[0].mixer.cuda_kernels_forward(token_emb)
+        outputs_slow = model.layers[0].mixer.torch_forward(token_emb)
+
+        self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3))
+
 
 @unittest.skipIf(
     not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
@@ -195,6 +217,10 @@ def test_initialization(self):
                         # check if it's a ones like
                         self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
 
+    def test_mamba2_slow_vs_fast_forward(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs)
+
     @unittest.skip(reason="Mamba 2 weights are not tied")
     def test_tied_weights_keys(self):
         pass
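
As a usage note, the new `create_and_check_mamba2_slow_vs_fast_forward` comparison can also be reproduced outside the test suite. The sketch below mirrors it on a randomly initialized model; it assumes a CUDA device plus the `mamba-ssm` and `causal-conv1d` packages (the same conditions the test skips on), and `num_hidden_layers=1` is just an arbitrary choice to keep the model small.

```python
import torch
from transformers import Mamba2Config, Mamba2Model

torch.manual_seed(0)
config = Mamba2Config(num_hidden_layers=1)  # other sizes left at their defaults
model = Mamba2Model(config).to("cuda").eval()

input_ids = torch.randint(0, config.vocab_size, (1, 32), device="cuda")
token_emb = model.embeddings(input_ids)

with torch.no_grad():
    out_fast = model.layers[0].mixer.cuda_kernels_forward(token_emb)
    out_slow = model.layers[0].mixer.torch_forward(token_emb)

# Should print True when the slow and fast paths agree within the test tolerances
print(torch.allclose(out_fast, out_slow, atol=1e-3, rtol=1e-3))
```
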