[Bug]: Mamba2 incorrect inference time behavior #63

@zhixuan-lin

Describe the bug

The current Mamba2 inference branch here gives incorrect results. There are two issues:

  • norm_before_gate: the training branch uses norm_before_gate = False (here; this flag is passed to the fused mamba_split_conv1d_scan_combined kernel, which handles both the gating and the normalization). However, the inference branch uses FusedRMSNormSwishGate, which applies the RMSNorm before gating. I don't have time at the moment to add and test norm_before_gate in FusedRMSNormSwishGate, so I just used the gated RMSNorm from the Mamba2 repo, and it passes the test pasted below. The first sketch after this list illustrates the difference between the two orderings.
  • Computation of the timestep: the following line should be removed, since mamba_chunk_scan_combined expects the raw dt, i.e. before dt_bias is added and softplus is applied:
    https://github.com/sustcsonglin/flash-linear-attention/blob/e7f57746d6a0cfbf9200228f24a898f4e904ad8d/fla/models/mamba2/modeling_mamba2.py#L366
    We should also pass dt_bias=self.dt_bias and dt_softplus=True to mamba_chunk_scan_combined (see the second sketch after this list). In principle the current implementation (i.e., passing the transformed dt with dt_bias=None and dt_softplus=False to mamba_chunk_scan_combined) should yield the same result as this fix, but it fails my test... I'm not sure why.
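
To make the first point concrete, here is a minimal PyTorch reference for the two gating orders (a sketch only; the names rmsnorm_ref, rmsnorm_gated_ref and the eps value are illustrative and not taken from either codebase):

import torch
import torch.nn.functional as F

def rmsnorm_ref(x, weight, eps=1e-5):
    # plain RMSNorm: x / sqrt(mean(x^2) + eps), scaled by a learned weight
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

def rmsnorm_gated_ref(x, z, weight, norm_before_gate, eps=1e-5):
    if norm_before_gate:
        # what FusedRMSNormSwishGate computes: normalize first, then gate
        return rmsnorm_ref(x, weight, eps) * F.silu(z)
    else:
        # what the training branch (norm_before_gate=False) expects:
        # gate first, then normalize the gated activations
        return rmsnorm_ref(x * F.silu(z), weight, eps)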
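
And a sketch of the dt fix (not a drop-in patch; the import path and the argument names x, dt, A, B, C, dt_bias, chunk_size are assumptions based on the usual Mamba2 mixer layout):

import torch.nn.functional as F
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

def scan_current(x, dt, A, B, C, dt_bias, chunk_size):
    # current inference path: transform dt in the module and tell the kernel not to
    dt = F.softplus(dt + dt_bias)
    return mamba_chunk_scan_combined(
        x, dt, A, B, C, chunk_size=chunk_size,
        dt_bias=None, dt_softplus=False,
    )

def scan_proposed(x, dt, A, B, C, dt_bias, chunk_size):
    # proposed fix: pass the raw dt and let the kernel add dt_bias and apply
    # softplus, matching what the training-time fused kernel does
    return mamba_chunk_scan_combined(
        x, dt, A, B, C, chunk_size=chunk_size,
        dt_bias=dt_bias, dt_softplus=True,
    )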

Steps to reproduce the bug

Simple test script:

import torch
from fla.models.mamba2.configuration_mamba2 import Mamba2Config
from fla.models.mamba2.modeling_mamba2 import Mamba2Mixer


def test_mamba2_eval():
    B, T, D = 4, 512, 768
    dtype = torch.bfloat16
    config = Mamba2Config(
        num_heads=24,
        head_dim=64,
        hidden_size=768,
        expand=2,
        n_groups=1
    )

    torch.manual_seed(42)
    with torch.amp.autocast(device_type="cuda", dtype=dtype):
        with torch.no_grad():
            mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
            hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")

            # training branch: fused mamba_split_conv1d_scan_combined kernel
            mixer.train()
            out_train = mixer(hidden_states)

            # inference branch: separate scan + gated norm
            mixer.eval()
            out_eval = mixer(hidden_states)

            # the two branches should agree up to numerical tolerance
            assert torch.allclose(out_train, out_eval, atol=1e-3)

if __name__ == "__main__":
    test_mamba2_eval()

Expected behavior

The allclose assertion should pass, i.e., the training-mode and eval-mode outputs should match.

Environment info

  1. torch: 2.4.0
  2. triton: 3.0.0
