[Mamba2] Fix caching, slow path, and multi-gpu
#35154
Conversation
vasqu left a comment
Just some comments for clarification
# Only left padding is valid
attention_mask = torch.ones(size=(self.batch_size, self.seq_length), device=input_ids.device, dtype=torch.long)
attention_mask[0, :1] = 0
Added a mask, maybe for some other tests as well.
Alright, is it intended that it is only masked out for the first element of the batch?
Tbh, that was pretty willy-nilly on my part; it could definitely be changed, I just wanted to debug and see if stuff works
…generate (gives total ids + mask at each step)
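As an aside, a minimal sketch of the alternative being discussed above, i.e. masking a left-padded position for every row of the batch instead of only the first one (shapes and sizes here are illustrative, not taken from the actual test):

```python
import torch

# Illustrative only: one left-padded position per sequence for the whole batch,
# rather than masking a position only in the first batch element.
batch_size, seq_length = 4, 8
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)
attention_mask[:, :1] = 0  # zero out the first (left) position for all rows
```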
vasqu left a comment
Some more comments for the cache
Integration tests will probably need an update but I don't have a GPU for the 7B atm. Edit: If you could update these integration tests, then gladly :D especially since I'm on vacay very soon
Hey @vasqu thanks! Taking a look in a min
molbap left a comment
Hey @vasqu, thanks a bunch! Left a couple of questions/comments, but looks good.
(batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
device=hidden_states.device, dtype=dtype
# 2. Convolution sequence transformation
if cache_params is not None and cache_position is not None and cache_position[0] > 0:
currently will break torch compile, FWIW
I think the triton kernels themselves are not easy to compile atm either way, but this should definitely be handled properly in the future. FYI, you would need to register fake ops for torch to make it work properly, which would entail some separate mamba2 utils for the kernel - see https://github.com/facebookresearch/lingua/tree/main/apps/mamba/component
Yeah, not actionable immediately, but would be nice to have in the near future! Thanks
Definitely! Would love to see it :)
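For context, here is a minimal sketch of the fake-op registration mentioned in this thread, assuming the torch.library API available in PyTorch ≥ 2.4; the op name, signature, and placeholder body are illustrative and not the actual mamba2 kernels:

```python
import torch
from torch.library import custom_op

# Illustrative custom op wrapping an opaque (e.g. Triton-backed) kernel.
# The real implementation would call the Triton kernel inside the body.
@custom_op("mamba2_utils::chunk_scan", mutates_args=())
def mamba2_chunk_scan(hidden_states: torch.Tensor, dt: torch.Tensor) -> torch.Tensor:
    return hidden_states * dt.unsqueeze(-1)  # placeholder computation

# Fake ("meta") implementation: only describes the output shape/dtype so that
# torch.compile can trace through the op without running the real kernel.
@mamba2_chunk_scan.register_fake
def _(hidden_states, dt):
    return torch.empty_like(hidden_states)
```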
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# Gated MLP's linear projection
projected_states = self.in_proj(input_states.squeeze(1))
I'm not sure about this - it seems improvable, yes, but the squeeze is a no-op unless seq_len == 1, so indeed only in the caching situation. In that case we end up with a [batch_size, H] tensor instead of a [batch_size, seq_len, H] tensor. Then, since we split on the last dimension, it should be fine.
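To illustrate the point about the squeeze (a throwaway example with made-up shapes):

```python
import torch

# squeeze(1) only drops the sequence dimension when seq_len == 1,
# i.e. during cached single-step decoding.
prefill = torch.randn(2, 5, 16)  # (batch_size, seq_len, hidden_size)
decode = torch.randn(2, 1, 16)   # cached decoding step

print(prefill.squeeze(1).shape)  # torch.Size([2, 5, 16]) -- no-op
print(decode.squeeze(1).shape)   # torch.Size([2, 16])    -- dim removed
```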
input_states = remove_padding_influence(input_states, attention_mask)
projected_states = self.in_proj(input_states)
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
_, _, gate, hidden_states_B_C, dt = projected_states.split(
Nice, now it's aligned with the cuda kernel forward in naming. TBH, the whole split is the same for cuda and torch, so it could be factored out?
Sounds like a refactor :D I'd leave this to a separate PR and focus on making things work first.
yeah for sure!
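For illustration, a hypothetical sketch of what such a shared split helper could look like; the function name is invented and the split sizes are inferred from the diff above rather than copied from the module:

```python
def split_projected_states(projected_states, intermediate_size, n_groups, ssm_state_size, num_heads):
    # conv_dim and d_mlp mirror the dimension arithmetic shown in the diff above.
    conv_dim = intermediate_size + 2 * n_groups * ssm_state_size
    d_mlp = (projected_states.shape[-1] - 2 * intermediate_size
             - 2 * n_groups * ssm_state_size - num_heads) // 2
    _, _, gate, hidden_states_B_C, dt = projected_states.split(
        [d_mlp, d_mlp, intermediate_size, conv_dim, num_heads], dim=-1
    )
    return gate, hidden_states_B_C, dt
```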
Also, I added the slow label - feel free to launch a commit with the message "[run-slow] mamba2" so we can trigger the slow CI! That way we make sure multi-gpu is indeed fixed.
@molbap Yup, added an empty commit - I'll get to the comments/review a bit later 🫡 (I'd expect some failures on the integration tests, not sure, let's see)
Attempt 2 at multi gpu, at least a different error :p
Things that remain:
Otherwise, ready to go @molbap. Edit: the Hub seems to have some unrelated issues.
Title changed: [Mamba2] Fix Cache and several other small issues → [Mamba2] Fix caching, slow path, and multi-gpu
Hey 👋 I don't need direct credit, I just think that the list given in the docstring is misleading: transformers/src/transformers/models/bamba/modular_bamba.py, lines 214 to 218 at 667ed56.
The changes are mainly because of the cache + dropping some attributes.
ArthurZucker left a comment
d_mlp = (
    projected_states.shape[-1]
    - 2 * self.intermediate_size
    - 2 * self.n_groups * self.ssm_state_size
    - self.num_heads
) // 2
this is less readable, but I mean mamba in general is hard to read 😄
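For what it's worth, one hypothetical way to make that dimension arithmetic easier to follow is with named intermediates (the helper and its names are illustrative, not part of the PR):

```python
def compute_d_mlp(projection_dim, intermediate_size, n_groups, ssm_state_size, num_heads):
    # Everything in the projection that does not belong to the two MLP halves.
    groups_time_state_size = n_groups * ssm_state_size
    non_mlp_dim = 2 * intermediate_size + 2 * groups_time_state_size + num_heads
    return (projection_dim - non_mlp_dim) // 2
```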
What does this PR do?
Kind of a follow-up to [Mamba2] Fix slow path (#34901), as there are some issues in the current code.

Fixes #33567
Fixes #34817
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@molbap @ArthurZucker