Conversation

@vasqu
Contributor

@vasqu vasqu commented Nov 24, 2024

What does this PR do?

  • Fixes a wrong contraction in the mamba2 slow path, discovered by @HanGuo97
  • Simplifies one step (1 permutation instead of 4; see the sketch right after this list)
  • Cleans up some comments to follow the original ssd minimal reference more closely again
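
A minimal, self-contained sketch (not the actual diff in this PR) of the kind of simplification the second bullet refers to: a single torch.einsum contraction replacing a permute/matmul chain. All shapes and tensor names below are illustrative placeholders, loosely following the chunked layout used in the slow path.

import torch

b, c, l, h, d, n = 2, 3, 4, 8, 16, 32  # hypothetical sizes: batch, chunks, chunk_len, heads, head_dim, state
B = torch.randn(b, c, l, h, n)          # per-position "B" projection
x = torch.randn(b, c, l, h, d)          # per-position hidden states

# Variant 1: explicit permutations plus a batched matmul
states_ref = torch.matmul(
    B.permute(0, 1, 3, 4, 2),           # (b, c, h, n, l)
    x.permute(0, 1, 3, 2, 4),           # (b, c, h, l, d)
)                                       # (b, c, h, n, d)

# Variant 2: one einsum, no intermediate permutes
states = torch.einsum("bclhn,bclhd->bchnd", B, x)

assert torch.allclose(states_ref, states, atol=1e-5)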

Verified it locally; see the test over here

Fixes #34817

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @molbap

@molbap
Contributor

molbap commented Nov 25, 2024

Thanks @vasqu for the investigation and fix - and @HanGuo97 for reporting!
Indeed, it looks like this is a better reduction; pen & paper helped me understand it. One thing: could you move the test you ran locally into an actual test, with a dependency check to protect the mamba_chunk_scan_combined import? That way, either on a test image with this dependency installed or locally, one can verify that this always holds.
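
A hedged sketch of what such a guarded test could look like; the test name, the comparison, and the use of pytest.importorskip are assumptions about the shape of the check, not the exact test added in this PR.

import pytest
import torch

def test_mamba2_torch_forward_matches_triton_kernel():
    # Skip cleanly when the optional mamba_ssm dependency is not installed.
    ssd_combined = pytest.importorskip("mamba_ssm.ops.triton.ssd_combined")
    mamba_chunk_scan_combined = ssd_combined.mamba_chunk_scan_combined

    torch.manual_seed(0)
    # Build identical inputs, run the torch slow path and mamba_chunk_scan_combined
    # on the same data, then compare the outputs, e.g.:
    # torch.testing.assert_close(slow_path_out, kernel_out, rtol=1e-3, atol=1e-3)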

@molbap
Contributor

molbap commented Nov 25, 2024

Also, can you please push this empty commit to trigger the slow tests workflow?
git commit --allow-empty -m "[run-slow] mamba2"

@vasqu
Contributor Author

vasqu commented Nov 25, 2024

@molbap I think there might be a missing cache initialization for the conv in the CUDA forward, i.e. see:

hidden_states_B_C = causal_conv1d_fn(
    x=hidden_states_B_C.transpose(1, 2),
    weight=self.conv1d.weight.squeeze(1),
    bias=self.conv1d.bias,
    activation=self.activation,
).transpose(1, 2)[:, :seq_len]
where a copy into the cache is missing. Not sure when I'll get to it, but I would leave that to a separate PR.
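
For context, a hedged sketch of the kind of copy that seems to be missing around that call; the cache attribute (conv_states), the per-layer indexing, and the padding are assumptions based on how the cached path stores state, not the actual fix (which is left to the separate PR).

import torch.nn.functional as F

# Assumption: cache_params.conv_states is indexed per layer like ssm_states,
# and self.conv_kernel_size holds the conv kernel width.
conv_states = F.pad(
    hidden_states_B_C.transpose(1, 2),                        # (batch, channels, seq_len)
    (self.conv_kernel_size - hidden_states_B_C.shape[1], 0),  # left-pad the seq dim to kernel width
)
cache_params.conv_states[self.layer_idx].copy_(conv_states)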

Added a test now. It takes 3-4s on my machine due to the warm-up of the Triton kernel; should I make it a slow test?

@molbap
Contributor

molbap commented Nov 25, 2024

@vasqu 3-4s is OK, no need to flag it as slow I think!
And for sure, ping me on the separate PR when you have time 🫡
Slow tests have one failure that seems to be independent of this PR:

             # Discretize x into dB
            # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
            hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
            dBx = dB * hidden_states[..., None]
    
            # State calculation
            cache_params.ssm_states[self.layer_idx].copy_(
>               cache_params.ssm_states[self.layer_idx] * dA + dBx
            )
E           RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

It's related to multi-GPU runs and is separate from this issue, so I think we're OK to go here! cc @ArthurZucker
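
As an aside, a hedged sketch of how this kind of device mismatch is usually resolved (not something this PR changes, just context for the traceback above): move the update terms onto the device that holds the cached state before the in-place copy.

# Assumption: dA and dBx follow the names in the traceback above; aligning them
# with the device of the cached ssm_states avoids the cross-device copy_ error.
target_device = cache_params.ssm_states[self.layer_idx].device
cache_params.ssm_states[self.layer_idx].copy_(
    cache_params.ssm_states[self.layer_idx] * dA.to(target_device) + dBx.to(target_device)
)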

Contributor

@molbap molbap left a comment

LGTM as discussed above :) to be reviewed by core maintainer

@vasqu
Contributor Author

vasqu commented Nov 25, 2024

I found an existing issue for that failure, so yeah, it does seem to be separate ^^ #33567

@vasqu
Contributor Author

vasqu commented Dec 8, 2024

gentle ping @ArthurZucker

@vasqu
Contributor Author

vasqu commented Dec 10, 2024

Closing in favor of #35154

@vasqu vasqu closed this Dec 10, 2024
@vasqu vasqu deleted the fix-mamba2-slow-path branch December 24, 2024 17:50

Development

Successfully merging this pull request may close these issues.

Mamba2 torch_forward reduction dimension possibly incorrect?
