
Conversation

@koayon (Contributor) commented Mar 12, 2024

What does this PR do?

  1. Previously the cache_params variable was inconsistently typed as MambaCache, torch.Tensor or list[torch.Tensor]. This PR updates it to MambaCache everywhere, which is in line with the attributes that are being accessed in the logic.
  2. Previously, if you ran the forward method with use_cache=True, the cache_params in the output would still be None because this argument wasn't being passed through to the MambaModel. This PR fixes that, as shown below.
  3. It also updates the docstrings in line with these changes.

Allowed the use_cache flag to be passed through so that you can do:

import torch as t

from transformers import AutoTokenizer, MambaConfig, MambaForCausalLM
from transformers.models.mamba.modeling_mamba import MambaCausalLMOutput

if __name__ == "__main__":
    model = MambaForCausalLM(MambaConfig())
    tokeniser = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
    input_ids: t.Tensor = tokeniser("Hey how are you doing?", return_tensors="pt")["input_ids"]  # type: ignore

    out: MambaCausalLMOutput = model(input_ids=input_ids, use_cache=True)
    assert out.cache_params is not None
    print(out.cache_params.ssm_states)

and get back the ssm_states, which was not previously possible.
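Continuing the example above (illustrative only, not part of the PR; it assumes the returned MambaCache can simply be fed back in via the cache_params argument and that the model advances the cache's offset internally), the cache can then be reused for the next decoding step:

    # Hypothetical next step: feed the returned cache back in with the next token.
    next_token = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    out_2: MambaCausalLMOutput = model(
        input_ids=next_token, cache_params=out.cache_params, use_cache=True
    )
    print(out_2.cache_params.ssm_states[0].shape)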

Before submitting

Who can review?

@ArthurZucker
@gante

@amyeroberts (Contributor)

Thanks for adding this @koayon

Pinging @gante for first review of the cache logic, as @ArthurZucker is off this week

@koayon changed the title from "Update Mamba types and allow use_cache to be passed through" to "Update Mamba types and pass through use_cache attr to MambaModel" on Mar 12, 2024
@gante (Contributor) left a comment

LGTM, thank you for the PR 🤗

I'd like a final check from @ArthurZucker, though -- there are some terminology updates in the docstrings, and I'm not very familiar with Mamba :)

@koayon (Contributor, Author) commented Mar 19, 2024

Hey @ArthurZucker! Hope you had a great holiday 🙌
Is it possible for you to take a little look at the docstring here?

@ArthurZucker (Collaborator) left a comment

Thanks for the typing! There are a lot of unrelated changes, and I am not sure I understand how you got cache_params to be None, but adding kwargs to the call should not really be the solution!

use_cache=True returns the cache for me:

[image.png: screenshot did not upload]

unless you are training.

Comment on lines 53 to 60
 is_fast_path_available = all(
-    (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
+    (
+        selective_state_update,
+        selective_scan_fn,
+        causal_conv1d_fn,
+        causal_conv1d_update,
+        mamba_inner_fn,
+    )
 )
Collaborator:

this is unrelated and is styling, should be reverted!

Contributor Author:

Thanks, I've updated the styling 👌

Comment on lines -277 to -293
class MambaCache:
    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        self.seqlen_offset = 0
        self.dtype = dtype
        intermediate_size = config.intermediate_size
        ssm_state_size = config.state_size
        conv_kernel_size = config.conv_kernel

        self.conv_states = {
            i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }
        self.ssm_states = {
            i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }

Collaborator:

if moved, let's just keep the styling of this one please
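For illustration only (not part of the diff, and assuming the MambaCache class quoted in the hunk above together with a default MambaConfig; batch_size=1 and float32 are arbitrary choices here): the cache holds one conv state and one ssm state per layer, keyed by layer index.

import torch
from transformers import MambaConfig

cache = MambaCache(MambaConfig(), batch_size=1, dtype=torch.float32)
print(cache.conv_states[0].shape)  # (batch_size, intermediate_size, conv_kernel)
print(cache.ssm_states[0].shape)   # (batch_size, intermediate_size, state_size)
print(cache.seqlen_offset)         # 0 until the model advances it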

Comment on lines 164 to 210
-            ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+            ssm_parameters,
+            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
+            dim=-1,
Collaborator:

same here, unrelated change

                    )
                    cache_params.conv_states[self.layer_idx].copy_(conv_states)
                hidden_states = causal_conv1d_fn(
                    hidden_states, conv_weights, self.conv1d.bias, activation=self.activation
Collaborator:

same here

            else:
                if cache_params is not None:
                    conv_states = nn.functional.pad(
                        hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)
Collaborator:

same here, unrelated change

Comment on lines 94 to 134
-        self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False)
+        self.x_proj = nn.Linear(
+            self.intermediate_size,
+            self.time_step_rank + self.ssm_state_size * 2,
+            bias=False,
+        )
Collaborator:

same here


        if cache_params is None and use_cache:
            cache_params = MambaCache(
                self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
Collaborator:

unrelated change, let's revert

        return model_kwargs

    def prepare_inputs_for_generation(
        self, input_ids, cache_params=None, inputs_embeds=None, attention_mask=None, **kwargs
Collaborator:

same here

             inputs_embeds=inputs_embeds,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            **kwargs,
Collaborator:

why is this required? It should not be. The cache params are passed right above

Contributor Author:

I believe it's the use_cache argument that needs to be passed in for this to work as expected - we could restrict to just passing that through?

@koayon (Contributor, Author) Mar 20, 2024:

Have amended this to only pass through the use_cache argument
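For illustration, the restricted pass-through inside MambaForCausalLM.forward presumably looks something along these lines (a sketch, not the exact merged diff; the backbone attribute and argument names are taken from the surrounding snippets):

        mamba_outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,  # passed explicitly instead of via **kwargs
        )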

Collaborator:

perfect

@koayon (Contributor, Author) commented Mar 20, 2024

Hey @ArthurZucker, thanks for your review! 🙌
I've reverted the styling changes, thanks 🙏

Regarding the image you sent, it's unfortunately not showing up for me. But without the change to pass in use_cache, I don't see the cache_params being returned. If there's a difference for you, I wonder whether it comes down to running on CUDA vs MPS/CPU?

I append the following to the file:

import torch as t

from transformers import AutoTokenizer

if __name__ == "__main__":
    model = MambaForCausalLM(MambaConfig())
    tokeniser = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
    input_ids: t.Tensor = tokeniser("Hey how are you doing?", return_tensors="pt")["input_ids"]  # type: ignore

    out: MambaCausalLMOutput = model(input_ids=input_ids, use_cache=True)
    assert out.cache_params is not None
    print(out.cache_params.ssm_states)

If the use_cache argument isn't passed through to the backbone (either with kwargs or separately as in the newer version), there is no cache_params returned and I get the error:

python src/transformers/models/mamba/modeling_mamba.py      
...
Traceback (most recent call last):
 File "/[PATH_TO_TRANSFORMERS]/transformers/src/transformers/models/mamba/modeling_mamba.py", line 688, in <module>
   assert out.cache_params is not None
AssertionError

whereas with the use_cache argument being passed through I get a tensor returned:

{0: tensor([[[-5.5237e-04,  9.6599e-04,  6.6771e-04,  ..., -5.3982e-04,
          -4.6061e-04, -7.0508e-04],
         [-3.7170e-05, -2.2089e-04, -1.0218e-04,  ...,  7.1232e-05, 
         ...

It does seem like this would be required for the expected behaviour. Please let me know if you have any questions! 😄

@ArthurZucker (Collaborator)

Alright, when using from_pretrained, the cache is used and passed along subsequently, but not when initializing the model directly.
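For reference, the from_pretrained path being described presumably corresponds to the standard generate flow, where prepare_inputs_for_generation carries cache_params between decoding steps. A hedged sketch of that usage (not taken from the PR):

from transformers import AutoTokenizer, MambaForCausalLM

model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
inputs = tokenizer("Hey how are you doing?", return_tensors="pt")

# generate() goes through prepare_inputs_for_generation, which threads the
# cache between steps, so the cache path is exercised here.
output_ids = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output_ids[0]))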

@ArthurZucker (Collaborator) left a comment

Almost good to go!

        cache_params: Optional[MambaCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
Collaborator:

let's add use_cache as an arg here

@koayon (Contributor, Author) Mar 20, 2024:

@ArthurZucker Added! 👍
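With use_cache added as an argument, the forward signature presumably ends up roughly as follows (a sketch, not necessarily the exact merged code):

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[MambaCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, MambaCausalLMOutput]: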

@ArthurZucker (Collaborator) left a comment

Sorry, forgot about these!

@koayon (Contributor, Author) commented Mar 20, 2024

@ArthurZucker great suggestion, didn't realise that was an attribute of the Config 👌
Think it's ready to merge now 🚀

@ArthurZucker (Collaborator)

The failing test seems new, but it's because use_cache should be disabled by the model when training
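One way to express the fix described here, deriving the default from the config while disabling the cache during training, might look like the line below (a guess at the pattern, not necessarily the exact merged code):

        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)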

@ArthurZucker (Collaborator)

I'll have a look

@ArthurZucker (Collaborator) left a comment

Thanks for iterating!

@ArthurZucker ArthurZucker merged commit 76b3b20 into huggingface:main Mar 20, 2024
ArthurZucker added a commit that referenced this pull request Mar 20, 2024
Update Mamba types and pass through use_cache attr to MambaModel

* Update docstring for RMSNorm

* Update cache_params object to correct MambaCache type

* Update docstrings and type info

* Pass through use_cache

* ruff

* Reformat with 119 char limit per line (thanks Arthur)

* Pass through use_cache specifically to the backbone rather than all keyword arguments

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/mamba/modeling_mamba.py

Co-authored-by: Arthur <[email protected]>

* Update tab

* Update src/transformers/models/mamba/modeling_mamba.py

* Update src/transformers/models/mamba/modeling_mamba.py

Co-authored-by: Arthur <[email protected]>

---------

Co-authored-by: Arthur <[email protected]>