
Conversation

@vasqu (Contributor) commented Dec 8, 2024

What does this PR do?

Kind of a follow-up to #34901 as there are some issues in the current code:

Fixes #33567
Fixes #34817

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@molbap @ArthurZucker

@vasqu (Contributor, Author) left a comment

Just some comments for clarification

Comment on lines +106 to +108
# Only left padding is valid
attention_mask = torch.ones(size=(self.batch_size, self.seq_length), device=input_ids.device, dtype=torch.long)
attention_mask[0, :1] = 0

@vasqu (Contributor, Author) commented:

Added a mask, maybe for some other tests as well.

@molbap (Contributor) commented:

Alright, is it intended that it is only masked out for the first element of the batch?

@vasqu (Contributor, Author) commented:

Tbh, that was a pretty arbitrary choice on my part; it could definitely be changed, I just wanted to debug and see if stuff works.
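
For comparison, a mask that left-pads every sequence in the batch (rather than only the first element) could look like the sketch below; the batch size, sequence length, and padding length are arbitrary illustrative values, not the test's actual configuration:

import torch

batch_size, seq_length, pad_len = 2, 10, 3  # illustrative sizes only
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
attention_mask[:, :pad_len] = 0  # left padding: zero out the first pad_len positions of every row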

@vasqu (Contributor, Author) left a comment

Some more comments for the cache

@vasqu (Contributor, Author) commented Dec 8, 2024

Integration tests will probably need an update but I don't have a GPU for the 7B atm.

Edit: If you could update these integration tests, that would be great :D especially since I'm going on vacation very soon.

@molbap (Contributor) commented Dec 9, 2024

Hey @vasqu thanks! Taking a look in a min

@molbap (Contributor) left a comment

Hey @vasqu, thanks a bunch! Left a couple of questions/comments, but it looks good.

(batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
device=hidden_states.device, dtype=dtype
# 2. Convolution sequence transformation
if cache_params is not None and cache_position is not None and cache_position[0] > 0:

@molbap (Contributor) commented:

This will currently break torch.compile, FWIW.
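
For illustration, a minimal standalone sketch of the problem: branching in Python on a tensor value such as cache_position[0] > 0 is data-dependent control flow, which torch.compile cannot trace through (the function and shapes below are made up, not code from this PR):

import torch

def step(cache_position: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Python-level branch on a tensor's *value*: Dynamo cannot guard on data,
    # so this causes a graph break (or an error under fullgraph=True).
    if cache_position[0] > 0:
        return x + 1
    return x

compiled = torch.compile(step, fullgraph=True)
# compiled(torch.tensor([1]), torch.ones(2))  # would raise a dynamic-control-flow error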

@vasqu (Contributor, Author) commented:

I think the triton kernels themselves are not easy to compile atm either way, but this should definitely be handled properly in the future. FYI, you would need to register fake ops for torch to make it work properly, which would entail some separate mamba2 utils for the kernel - see https://github.com/facebookresearch/lingua/tree/main/apps/mamba/component
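
For reference, a minimal sketch of what registering a fake (meta) implementation could look like with the torch.library custom-op API (PyTorch >= 2.4); the op name, signature, and kernel wrapper are hypothetical placeholders, not code from this PR or from lingua:

import torch

# Wrap the (hypothetical) triton kernel as a custom op so torch.compile can trace calls to it.
@torch.library.custom_op("mamba2_utils::ssd_scan", mutates_args=())
def ssd_scan(hidden_states: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
    # A real implementation would dispatch to the triton kernel here.
    raise NotImplementedError("placeholder for the triton kernel call")

# The fake implementation only describes output shapes/dtypes, which is all
# torch.compile needs in order to trace through the op without running the kernel.
@ssd_scan.register_fake
def _(hidden_states: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(hidden_states)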

@molbap (Contributor) commented:

Yeah, not actionable immediately, but it would be nice to have in the near future! Thanks.

@vasqu (Contributor, Author) commented:

Definitely! Would love to see it :)

batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# Gated MLP's linear projection
projected_states = self.in_proj(input_states.squeeze(1))

@molbap (Contributor) commented:

I'm not sure about this - it seems improvable, yes, but the squeeze is a no-op unless seq_len == 1, i.e. exactly the caching situation. So we end up with a [batch_size, H] tensor instead of a [batch_size, seq_len, H] tensor. Then, since we split this on the last dimension, it should be fine.
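
A tiny shape check of the point above (plain torch, nothing specific to this PR):

import torch

decode_step = torch.randn(2, 1, 4)   # [batch_size, seq_len=1, hidden]
print(decode_step.squeeze(1).shape)  # torch.Size([2, 4]): the length-1 dim is dropped
prefill = torch.randn(2, 5, 4)       # [batch_size, seq_len=5, hidden]
print(prefill.squeeze(1).shape)      # torch.Size([2, 5, 4]): squeeze(1) is a no-op here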


input_states = remove_padding_influence(input_states, attention_mask)
projected_states = self.in_proj(input_states)
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size-self.num_heads) // 2
_, _, gate, hidden_states_B_C, dt = projected_states.split(

@molbap (Contributor) commented:

Nice, now it's aligned with the cuda kernel forward in naming. TBH, the whole split is the same for cuda and torch, so it could be factored out?

@vasqu (Contributor, Author) commented:

Sounds like a refactor :D I'd leave this to a separate PR and focus on making things work first.

@molbap (Contributor) commented:

yeah for sure!
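
For reference, a factored-out split could look roughly like the sketch below; the helper and argument names are hypothetical, and the chunk sizes simply mirror the d_mlp arithmetic in the snippet above:

def split_in_proj(projected_states, d_mlp, intermediate_size, n_groups, ssm_state_size, num_heads):
    # Split the in_proj output into its Mamba2 components along the last dim:
    # two d_mlp chunks, the gate, the conv input (x, B, C concatenated), and dt.
    return projected_states.split(
        [
            d_mlp,
            d_mlp,
            intermediate_size,
            intermediate_size + 2 * n_groups * ssm_state_size,
            num_heads,
        ],
        dim=-1,
    )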

@molbap added the State space models and run-slow labels Dec 9, 2024

@molbap (Contributor) commented Dec 9, 2024

Also, I added the run-slow label - feel free to push a commit with the message "[run-slow] mamba2" so we can trigger the slow CI! That way we make sure multi-gpu is indeed fixed.

@vasqu (Contributor, Author) commented Dec 9, 2024

@molbap Yup, added an empty commit - will get to the comments/review a bit later 🫡

(I'd expect some failures on the integration tests; not sure, let's see.)

@vasqu (Contributor, Author) commented Dec 9, 2024

Attempt 2 at multi-GPU, at least it's a different error :p

@vasqu (Contributor, Author) commented Dec 9, 2024

Things that remain:

  • Multi-GPU, sigh
  • Supersede slow path fix with same tests included?
  • Compile compatibility (future)
  • Refactor some stuff (future)

Otherwise, ready to go @molbap

Edit: Hub seems to have some unrelated issues

@vasqu changed the title from "[Mamba2] Fix Cache and several other small issues" to "[Mamba2] Fix caching, slow path, and multi-gpu" on Dec 10, 2024

@vasqu (Contributor, Author) commented Dec 19, 2024

Hey 👋 I don't need direct credit, I just think that the list given in the docstring is misleading:

There are a few differences between this and Mamba2Mixer:
- The variable use_precomputed_states is slightly different due to the HybridCache structure
- There's a few non-obvious bugs fixed with batching in the slow path that exist in main
- Some extra variables that our layer doesn't need have been removed
- We ported most of the refactors in https://github.com/huggingface/transformers/pull/35154, which is (as of Dec 18, 2024) unmerged

The changes are mainly because of the cache + dropping some attributes.

@ArthurZucker (Collaborator) left a comment

Thanks a lot @vasqu 😉
Your last comment is completely aligned with our philosophy: if indeed bamba is now the same, we shall add bamba with modular, isolating the differences if there are any left! cc @molbap on this! Merry Christmas as well!

Comment on lines +316 to +322
d_mlp = (
projected_states.shape[-1]
- 2 * self.intermediate_size
- 2 * self.n_groups * self.ssm_state_size
- self.num_heads
) // 2

@ArthurZucker (Collaborator) commented:

this is less readable, but I mean mamba in general is hard to read 😄
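
As a sanity check of the formula, a toy example with made-up sizes (not the model's real configuration) showing that the split widths add back up to the projection width:

# Hypothetical toy sizes, purely to illustrate the d_mlp arithmetic above
intermediate_size, n_groups, ssm_state_size, num_heads, d_mlp = 8, 2, 4, 4, 3
proj_dim = 2 * d_mlp + 2 * intermediate_size + 2 * n_groups * ssm_state_size + num_heads  # 42
assert (proj_dim - 2 * intermediate_size - 2 * n_groups * ssm_state_size - num_heads) // 2 == d_mlp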

@ArthurZucker merged commit 5a2aedc into huggingface:main Dec 20, 2024
6 checks passed
@vasqu deleted the fix-mamba2-caching branch December 20, 2024 15:17

Labels

run-slow, State space models (Issues or PRs related to state space models such as mamba, mamba2)


Development

Successfully merging this pull request may close these issues.

  • Mamba2 torch_forward reduction dimension possibly incorrect?
  • Mamba 2 Multi-GPU errors out on generation with parallel beam search

4 participants