Update Mamba types and pass through use_cache attr to MambaModel #29605
Conversation
Thanks for adding this @koayon. Pinging @gante for a first review of the cache logic, as @ArthurZucker is off this week.
gante left a comment:
LGTM, thank you for the PR 🤗
I'd like a final check from @ArthurZucker, though -- there are some terminology updates in the docstrings, and I'm not very familiar with Mamba :)
koayon: Hey @ArthurZucker! Hope you had a great holiday 🙌
ArthurZucker left a comment:
```diff
 is_fast_path_available = all(
-    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+    (
+        selective_state_update,
+        selective_scan_fn,
+        causal_conv1d_fn,
+        causal_conv1d_update,
+        mamba_inner_fn,
+    )
 )
```
ArthurZucker: this is unrelated and is styling, should be reverted!
koayon: Thanks, I've updated the styling 👌
```python
class MambaCache:
    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        self.seqlen_offset = 0
        self.dtype = dtype
        intermediate_size = config.intermediate_size
        ssm_state_size = config.state_size
        conv_kernel_size = config.conv_kernel

        self.conv_states = {
            i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }
        self.ssm_states = {
            i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }
```
ArthurZucker: if moved, let's just keep the styling of this one please
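(For context only, a quick sketch of how a cache built by the constructor above looks when inspected; the import path and values are assumptions, not part of the diff:)

```python
# Hypothetical usage of the MambaCache shown above; not part of the PR diff.
import torch
from transformers import MambaConfig
from transformers.models.mamba.modeling_mamba import MambaCache  # assumed import path

config = MambaConfig()
cache = MambaCache(config, batch_size=2, dtype=torch.float32, device="cpu")

# One conv state and one SSM state per layer, keyed by layer index.
print(cache.conv_states[0].shape)  # (2, config.intermediate_size, config.conv_kernel)
print(cache.ssm_states[0].shape)   # (2, config.intermediate_size, config.state_size)
```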
```diff
-    ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+    ssm_parameters,
+    [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
+    dim=-1,
```
ArthurZucker: same here, unrelated change
```python
)
cache_params.conv_states[self.layer_idx].copy_(conv_states)
hidden_states = causal_conv1d_fn(
    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
```
ArthurZucker: same here
```python
else:
    if cache_params is not None:
        conv_states = nn.functional.pad(
            hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
```
ArthurZucker: same here, unrelated change
```diff
-self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
+self.x_proj = nn.Linear(
+    self.intermediate_size,
+    self.time_step_rank + self.ssm_state_size * 2,
+    bias=False,
+)
```
ArthurZucker: same here
```python
if cache_params is None and use_cache:
    cache_params = MambaCache(
        self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
```
ArthurZucker: unrelated change, let's revert
```python
        return model_kwargs

    def prepare_inputs_for_generation(
        self, input_ids, cache_params=None, inputs_embeds=None, attention_mask=None, **kwargs
```
ArthurZucker: same here
```python
    inputs_embeds=inputs_embeds,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    **kwargs,
```
ArthurZucker: why is this required? It should not be. The cache params are passed right above.
koayon: I believe it's the use_cache argument that needs to be passed in for this to work as expected - we could restrict it to just passing that through?
koayon: Have amended this to only pass through the use_cache argument.
ArthurZucker: perfect
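(For readers following along, a rough sketch of how this thread was resolved: the flag is forwarded explicitly rather than via `**kwargs`. Argument names mirror the surrounding hunks, but this is an illustration, not the exact diff:)

```python
# Sketch of MambaForCausalLM.forward calling into the backbone; illustrative only.
mamba_outputs = self.backbone(
    input_ids,
    cache_params=cache_params,
    inputs_embeds=inputs_embeds,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    use_cache=use_cache,  # forwarded explicitly so MambaModel can build and return a MambaCache
)
```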
koayon: Hey @ArthurZucker, thanks for your review! 🙌

In terms of the image that you were sending, it's unfortunately not showing up for me. But without the change to pass in use_cache I don't see the cache_params being returned. If there's a difference for you, my only thought is that it might be down to how it's running on CUDA vs MPS/CPU? I append the following to the file:

```python
import torch as t
from transformers import AutoTokenizer

if __name__ == "__main__":
    model = MambaForCausalLM(MambaConfig())
    tokeniser = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
    input_ids: t.Tensor = tokeniser("Hey how are you doing?", return_tensors="pt")["input_ids"]  # type: ignore
    out: MambaCausalLMOutput = model(input_ids=input_ids, use_cache=True)
    assert out.cache_params is not None
    print(out.cache_params.ssm_states)
```

If the use_cache argument isn't passed through to the backbone (either with kwargs or separately as in the newer version), there is no cache_params returned and I get the error:

```
$ python src/transformers/models/mamba/modeling_mamba.py
...
Traceback (most recent call last):
  File "/[PATH_TO_TRANSFORMERS]/transformers/src/transformers/models/mamba/modeling_mamba.py", line 688, in <module>
    assert out.cache_params is not None
AssertionError
```

whereas with the use_cache argument being passed through I get a tensor returned:

```
{0: tensor([[[-5.5237e-04,  9.6599e-04,  6.6771e-04,  ..., -5.3982e-04,
              -4.6061e-04, -7.0508e-04],
             [-3.7170e-05, -2.2089e-04, -1.0218e-04,  ...,  7.1232e-05,
...
```

It does seem like this would be required for the expected behaviour. Please let me know if you have any questions! 😄
ArthurZucker: Alright, when using from_pretrained the cache is used and passed on subsequently, but not when using direct initialization.
ArthurZucker left a comment:
Almost good to go!
```python
    cache_params: Optional[MambaCache] = None,
    labels: Optional[torch.LongTensor] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
```
ArthurZucker: let's add use_cache as an arg here
koayon: @ArthurZucker Added! 👍
ArthurZucker left a comment:
Sorry forgot about these!
koayon: @ArthurZucker great suggestion, didn't realise that was an attribute of the Config 👌
The failing test seems new, but it's because use_cache should be disabled by the model when training.
I'll have a look.
ArthurZucker left a comment:
Thanks for iterating!
Squashed commit message:

* Update docstring for RMSNorm
* Update cache_params object to correct MambaCache type
* Update docstrings and type info
* Pass through use_cache
* ruff
* Reformat with 119 char limit per line (thanks Arthur)
* Pass through use_cache specifically to the backbone rather than all keyword arguments
* Update src/transformers/models/mamba/modeling_mamba.py
* Update src/transformers/models/mamba/modeling_mamba.py
* Update src/transformers/models/mamba/modeling_mamba.py (Co-authored-by: Arthur <[email protected]>)
* Update src/transformers/models/mamba/modeling_mamba.py (Co-authored-by: Arthur <[email protected]>)
* Update tab
* Update src/transformers/models/mamba/modeling_mamba.py
* Update src/transformers/models/mamba/modeling_mamba.py (Co-authored-by: Arthur <[email protected]>)

Co-authored-by: Arthur <[email protected]>
What does this PR do?

* Previously, cache_params was typed in places as MambaCache, torch.Tensor or list[torch.Tensor]. This PR updates the type to MambaCache everywhere, which is in line with the attributes that are being accessed in the logic.
* Previously, use_cache was not passed through to MambaModel. This PR fixes this by allowing the use_cache information to be passed through, so that you can run a call like the sketch below and get back the ssm_states, which was not previously possible.
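(The inline example did not survive this capture; the snippet below reconstructs the intended usage from the reproduction script shared earlier in the conversation:)

```python
from transformers import AutoTokenizer, MambaConfig, MambaForCausalLM

model = MambaForCausalLM(MambaConfig())
tokeniser = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
input_ids = tokeniser("Hey how are you doing?", return_tensors="pt")["input_ids"]

out = model(input_ids=input_ids, use_cache=True)
assert out.cache_params is not None
print(out.cache_params.ssm_states)  # previously cache_params came back as None
```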
Who can review?
@ArthurZucker
@gante