Describe the bug
Mamba2's current inference branch here gives incorrect results. There are two issues:
- `norm_before_gate`: the training branch uses `norm_before_gate = False` (here, fed into the fused `mamba_split_conv1d_scan_combined` kernel, which includes the gating and normalization). However, the inference branch uses `FusedRMSNormSwishGate`, which applies the RMSNorm *before* gating. I don't have much time to add and test a `norm_before_gate` option in `FusedRMSNormSwishGate` at the moment, so I just used the gated RMSNorm from the Mamba2 repo (a minimal sketch follows this list), and it seems to pass the test I pasted below.
- Computation of `timestep`: the following line should be removed, since `mamba_chunk_scan_combined` expects the raw `dt` before `dt_bias` is added and softplus is applied:
  https://github.com/sustcsonglin/flash-linear-attention/blob/e7f57746d6a0cfbf9200228f24a898f4e904ad8d/fla/models/mamba2/modeling_mamba2.py#L366
  We should also pass `dt_bias=self.dt_bias` and `dt_softplus=True` to `mamba_chunk_scan_combined` (see the call sketch after this list). In principle, the current implementation (i.e., passing the transformed `dt` with `dt_bias=None` and `dt_softplus=False` to `mamba_chunk_scan_combined`) should yield the same result as the fix above, but it fails my test... I'm not sure why, though.
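For reference, a minimal unfused sketch of the gated RMSNorm with `norm_before_gate = False` semantics (the class and variable names here are mine, not the exact code; the fused Triton kernel in the Mamba2 repo computes the same thing):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNormGated(nn.Module):
    """Gated RMSNorm matching norm_before_gate=False: gate first, then norm,
    i.e. rmsnorm(x * silu(z)). FusedRMSNormSwishGate instead computes
    rmsnorm(x) * silu(z), which is why train/eval outputs diverge."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # Apply the SiLU gate *before* normalization (norm_before_gate=False).
        x = x * F.silu(z)
        # RMSNorm in fp32 for numerical stability, then cast back.
        x = x.float()
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * x).to(input_dtype)
```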
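And a hedged sketch of the corrected call site for the `dt` fix. The keyword arguments (`dt_bias`, `dt_softplus`, `chunk_size`, `D`, `z`, `return_final_states`) follow the `mamba_chunk_scan_combined` signature in `mamba_ssm`; the surrounding tensor and attribute names (`x`, `A`, `B`, `C`, `self.chunk_size`, `self.D`) are placeholders rather than the exact code in `modeling_mamba2.py`:

```python
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

# dt = F.softplus(dt + self.dt_bias)  # <- remove this pre-transform

# Pass the raw dt and let the kernel add the bias and apply softplus itself.
y, ssm_state = mamba_chunk_scan_combined(
    x,                      # (batch, seqlen, nheads, headdim)
    dt,                     # raw dt, straight from the input projection
    A, B, C,
    chunk_size=self.chunk_size,
    D=self.D,
    z=None,                 # gating is handled outside, by the gated RMSNorm
    dt_bias=self.dt_bias,   # the kernel adds the bias...
    dt_softplus=True,       # ...and applies softplus internally
    return_final_states=True,
)
```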
Steps to reproduce the bug
Simple test script:
```python
import torch

from fla.models.mamba2.configuration_mamba2 import Mamba2Config
from fla.models.mamba2.modeling_mamba2 import Mamba2Mixer


def test_mamba2_eval():
    B, T, D = 4, 512, 768
    dtype = torch.bfloat16
    config = Mamba2Config(
        num_heads=24,
        head_dim=64,
        hidden_size=768,
        expand=2,
        n_groups=1
    )
    torch.manual_seed(42)
    with torch.amp.autocast(device_type="cuda", dtype=dtype):
        with torch.no_grad():
            mixer = Mamba2Mixer(config, layer_idx=0).to("cuda")
            hidden_states = torch.rand(size=(B, T, D), dtype=dtype, device="cuda")

            mixer.train()
            out_train = mixer(hidden_states)

            mixer.eval()
            out_eval = mixer(hidden_states)

            assert torch.allclose(out_train, out_eval, atol=1e-3)


if __name__ == "__main__":
    test_mamba2_eval()
```

Expected behavior
The `allclose` assertion should pass.
Environment info
- torch: 2.4.0
- triton: 3.0.0