src/transformers/models/mamba2/modeling_mamba2.py (33 changes: 16 additions & 17 deletions)
@@ -531,45 +531,44 @@ def torch_forward(self, input_states, cache_params: Optional[Mamba2Cache]=None,
             # This is the analog of a causal mask
             L = torch.exp(segment_sum(A))
 
-            # First, contraction of C and B to get G (attention-weights like)
-            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n)
+            # Contraction of C and B to get G (attention-weights like)
+            # shape: (b, c, l, s, h, n)
+            G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]
             G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
 
-            # Step 2: Compute M, equivalent to applying attention mask to weights
+            # Compute M, equivalent to applying attention mask to weights
             M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
             M = M_intermediate.sum(dim=-1)
 
-            # Step 3: Compute Y_diag (apply to values)
-            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
+            # Compute Y_diag (apply to values)
+            Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(dim=3)
 
             # 2. Compute the state for each intra-chunk
             # (right term of low-rank factorization of off-diagonal blocks; B terms)
-
             decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
-            B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
-            # permute back B * decay states
-            states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
+            B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
+            states = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)
 
             # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
             # (middle term of factorization of off-diag blocks; A terms)
             if cache_params is not None and cache_params.seqlen_offset > 0:
                 previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
             else:
                 previous_states = torch.zeros_like(states[:, :1])
             states = torch.cat([previous_states, states], dim=1)
             decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
-
-            states_permuted = states.permute(0, 2, 1, 3, 4)
-            result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
-            new_states = result.permute(0, 2, 1, 3, 4)
+            decay_chunk = decay_chunk.transpose(1, 3)
+            new_states = (decay_chunk[..., None, None] * states[:, :, None, ...]).sum(dim=1)
             states, ssm_state = new_states[:, :-1], new_states[:, -1]
 
-            # Compute state -> output conversion per chunk
+            # 4. Compute state -> output conversion per chunk
             # (left term of low-rank factorization of off-diagonal blocks; C terms)
             state_decay_out = torch.exp(A_cumsum)
             # compute Yoff
             C_times_states = (C[..., None, :] * states[:, :, None, ...])
             state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
             Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
-            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
+
+            # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
             y = Y_diag + Y_off
             # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
             y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
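Aside (not part of the diff): the intra-chunk state computation rewritten above is a pure simplification, so the old permute-heavy expression and the new broadcast form should agree numerically. A minimal sketch with random tensors, assuming the layouts used in `torch_forward` (`hidden_states`: (b, c, l, h, p), `B`: (b, c, l, h, n), `decay_states`: (b, h, c, l)); all sizes and names below are illustrative only.

```python
import torch

b, c, l, h, p, n = 2, 3, 4, 5, 6, 7  # tiny illustrative sizes
hidden_states = torch.randn(b, c, l, h, p)
B = torch.randn(b, c, l, h, n)
decay_states = torch.rand(b, h, c, l)

# Old formulation: decay B, permute to (b, c, h, l, n), contract over l, permute back.
B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
states_old = (
    B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None]
    * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]
).sum(dim=3).permute(0, 1, 2, 4, 3)

# New formulation: same decay, then one broadcasted product reduced over l (dim=2).
B_decay = B * decay_states.permute(0, -2, -1, 1)[..., None]
states_new = (B_decay[..., None, :] * hidden_states[..., None]).sum(dim=2)

# Both yield (b, c, h, p, n) and agree up to float rounding.
assert states_old.shape == states_new.shape == (b, c, h, p, n)
assert torch.allclose(states_old, states_new, atol=1e-5)
```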
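Aside (not part of the diff): the inter-chunk recurrence is where the rewrite changes behaviour rather than just style. The new `transpose(1, 3)` followed by `.sum(dim=1)` reduces over the source-chunk axis, which matches the reference einsum formulation of the chunked scan (`bhzc,bchpn->bzhpn`), whereas the old permute-based version reduced over a different axis. A hedged sketch of that correspondence, with illustrative shapes (`decay_chunk`: (b, h, c+1, c+1), `states`: (b, c+1, h, p, n)):

```python
import torch

b, c, h, p, n = 2, 3, 5, 6, 7  # tiny illustrative sizes; c+1 chunk-boundary states
states = torch.randn(b, c + 1, h, p, n)
decay_chunk = torch.rand(b, h, c + 1, c + 1)

# Reference formulation: contract decay_chunk with states over the source-chunk axis.
ref = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)

# New slow-path formulation: move the source-chunk axis to dim 1, then reduce it.
decay_chunk_t = decay_chunk.transpose(1, 3)  # (b, c+1, c+1, h)
new_states = (decay_chunk_t[..., None, None] * states[:, :, None, ...]).sum(dim=1)

assert torch.allclose(ref, new_states, atol=1e-5)
```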
tests/models/mamba2/test_modeling_mamba2.py (26 changes: 26 additions & 0 deletions)
@@ -21,6 +21,7 @@
 
 from transformers import AutoTokenizer, Mamba2Config, is_torch_available
 from transformers.testing_utils import require_read_token, require_torch, require_torch_gpu, slow, torch_device
+from transformers.utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 
 from ...generation.test_utils import GenerationTesterMixin
 from ...test_configuration_common import ConfigTester
@@ -158,6 +159,27 @@ def prepare_config_and_inputs_for_common(self):
         inputs_dict = {"input_ids": input_ids}
         return config, inputs_dict
 
+    def create_and_check_mamba2_slow_vs_fast_forward(self, config, input_ids, *args, gradient_checkpointing=False):
+        model = Mamba2Model(config)
+        model.eval()
+
+        if not (is_mamba_2_ssm_available() and is_causal_conv1d_available()):
+            self.parent.skipTest(
+                "This test needs the Mamba2 fast path. Skipping as the necessary packages have not been found."
+            )
+        if torch_device != "cuda":
+            self.parent.skipTest("This test needs the Mamba2 fast path. Skipping as we need a cuda capable device.")
+
+        model.to(torch_device)
+        if gradient_checkpointing:
+            model.gradient_checkpointing_enable()
+
+        token_emb = model.embeddings(input_ids)
+        outputs_fast = model.layers[0].mixer.cuda_kernels_forward(token_emb)
+        outputs_slow = model.layers[0].mixer.torch_forward(token_emb)
+
+        self.parent.assertTrue(torch.allclose(outputs_fast, outputs_slow, atol=1e-3, rtol=1e-3))
+

 @unittest.skipIf(
     not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
@@ -195,6 +217,10 @@ def test_initialization(self):
                         # check if it's a ones like
                         self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
 
+    def test_mamba2_slow_vs_fast_forward(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_mamba2_slow_vs_fast_forward(*config_and_inputs)
+
     @unittest.skip(reason="Mamba 2 weights are not tied")
     def test_tied_weights_keys(self):
         pass
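Aside: the new parity test self-skips unless the mamba-ssm and causal-conv1d packages are installed and a CUDA device is visible. A hypothetical way to run only this test locally (the invocation below is illustrative, not part of the PR):

```python
import pytest

# Select only the new slow-vs-fast parity test; it skips itself when the
# fast-path requirements (mamba-ssm, causal-conv1d, CUDA) are not met.
pytest.main(["tests/models/mamba2/test_modeling_mamba2.py", "-k", "slow_vs_fast_forward", "-q"])
```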