Implement MambaForSequenceClassification #31155

Open
wants to merge 66 commits into base: main

Conversation

Adibvafa
Contributor

@Adibvafa Adibvafa commented May 31, 2024

What does this PR do?

Adds the MambaForSequenceClassification model based on MambaModel backbone.

We recently published EHRMamba, a state-of-the-art foundation model for Electronic Health Records. This model is built on the same architecture and we will release the trained weights using the MambaForSequenceClassification class.
https://vectorinstitute.github.io/EHRMamba

Fixes #30431

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

As discussed in #30431, @ArthurZucker could you take a look? 😊

Notes

This implementation closely follows the GPT2ForSequenceClassification approach, except that the last hidden states are pooled before being passed to the classifier, so the classifier head runs only on the pooled vectors rather than over every position, which improves efficiency.
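A minimal sketch of that pooling, assuming the backbone returns per-token hidden states; the function, attribute, and argument names here are illustrative, not the PR's exact code:

import torch


def pool_and_classify(hidden_states, input_ids, pad_token_id, classifier):
    # hidden_states: (batch, seq_len, hidden_size) from the Mamba backbone
    # classifier: a linear head mapping hidden_size -> num_labels
    if pad_token_id is None:
        # no padding information: fall back to the last position
        sequence_lengths = -1
    else:
        # index of the last non-padding token in each sequence
        sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
        sequence_lengths = sequence_lengths % input_ids.shape[-1]
    batch_idx = torch.arange(hidden_states.shape[0], device=hidden_states.device)
    pooled_states = hidden_states[batch_idx, sequence_lengths]  # (batch, hidden_size)
    # classify only the pooled vectors instead of every position
    return classifier(pooled_states)  # (batch, num_labels)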

@Adibvafa Adibvafa changed the title from "Implemented MambaForSequenceClassification - Issue #30431" to "Implemented MambaForSequenceClassification" on May 31, 2024
@Adibvafa
Contributor Author

Referring to #29552, "there's a test specific to sequence classification that expects all the unfrozen params to be initialized in the range [0.0, 1.0] and the initialized values for the mixer don't satisfy that assertion."

This results in a test failure even though the classifier head is initialized properly.
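For context, a paraphrased illustration of the kind of assertion that fails; this is a sketch from memory of the common initialization test, not the exact test code, and the parameter-mean check may differ in detail:

import torch


def check_init_means(model):
    # The common test expects the (rounded) mean of every trainable parameter to be
    # 0.0 or 1.0 after init; the MambaMixer's custom init does not satisfy this.
    for name, param in model.named_parameters():
        if param.requires_grad:
            mean = ((param.data.mean() * 1e9).round() / 1e9).item()
            assert mean in [0.0, 1.0], f"Parameter {name} seems not properly initialized"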

Collaborator

@ArthurZucker ArthurZucker left a comment


thanks for opening a PR! 🤗

@ArthurZucker
Collaborator

Could you rebase on main and make sure the CIs are green! 🤗

@Adibvafa
Contributor Author

Adibvafa commented Jun 7, 2024

Could you rebase on main and make sure the CIs are green! 🤗

Of course! It should be good to merge now.
There is a failing test for "MobileViTV2ModelTest" (or similar) which is unrelated to Mamba.

@Adibvafa Adibvafa changed the title from "Implemented MambaForSequenceClassification" to "Implement MambaForSequenceClassification" on Jun 11, 2024
@Adibvafa
Contributor Author

Adibvafa commented Aug 2, 2024

@ArthurZucker Pending review!

Collaborator

@ArthurZucker ArthurZucker left a comment


🤗

""",
MAMBA_START_DOCSTRING,
)
class MambaForSequenceClassification(MambaPreTrainedModel):
Collaborator


Why not just use copied from here? : # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mamba, LLAMA->MAMBA, self.transformer->self.model, transformer_outputs->model_outputs

Contributor Author

@Adibvafa Adibvafa Aug 5, 2024


I'm not sure I understand your comment. The forward methods of Mamba and Llama for sequence classification seem different. Could you please elaborate? 🤗

@Adibvafa
Contributor Author

Adibvafa commented Aug 9, 2024

@ArthurZucker Now that #32080 is merged, can we do a final review for this one too? Also, I would like to add Mamba2ForSequenceClassification to this PR as well so we have both Mamba models with classification capabilities. Then I would be able to release the EHRMamba model on HuggingFace.

@mohith7548

Hey, any update on this?

labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,


I think there is a **kwargs missing in the forward function (lines 842-843).

Contributor Author


Good catch! I suppose it won't be necessary (as opposed to MambaForCausalLM) but having it is good. I will add it in a commit now.
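A minimal sketch of the amended signature, mirroring the arguments shown in the quoted hunk; input_ids is an assumed extra argument and the body is elided, so this is not the PR's exact code:

from typing import Optional

import torch


def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    use_cache: Optional[bool] = None,
    **kwargs,  # absorb extra keyword arguments (e.g. passed through by the Trainer) instead of raising
):
    ...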

@Jellymoon

@Adibvafa have you tried running this? Whenever the model gets to an evaluation step I get the error below.
The code I tried was the Hugging Face sequence classification tutorial (link to tutorial), but I used a GPU, replaced "distilbert/distilbert-base-uncased" with "state-spaces/mamba-130m-hf", and replaced my local modeling_mamba.py with yours so it does load the model.

ERROR:

Traceback (most recent call last):
  File ".../tutorial_script.py", line 71, in <module>
  File ".../lib/python3.11/site-packages/transformers/trainer.py", line 3754, in predict
    output = eval_loop(
             ^^^^^^^^^^
  File ".../lib/python3.11/site-packages/transformers/trainer.py", line 3887, in evaluation_loop
    logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/accelerator.py", line 2508, in pad_across_processes
    return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 411, in wrapper
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 678, in pad_across_processes
    return recursively_apply(
           ^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 107, in recursively_apply
    return honor_type(
           ^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 81, in honor_type
    return type(obj)(generator)
           ^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 110, in <genexpr>
    recursively_apply(
  File ".../lib/python3.11/site-packages/accelerate/utils/operations.py", line 128, in recursively_apply
    raise TypeError(
TypeError: Unsupported types (<class 'transformers.cache_utils.MambaCache'>) passed to `_pad_across_processes`. 
Only nested list/tuple/dicts of objects that are valid for `is_torch_tensor` should be passed.


@mohith7548

Hey @Jellymoon, the Mamba model works as expected during the training loop. However, it fails during the evaluation loop. I found that it is necessary to set use_cache=False when loading the model so that evaluation does not fail.

cc: @Adibvafa

model = MambaForSequenceClassification.from_pretrained(
    model_path,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    # use_cache=False must be passed when training and evaluating Mamba for
    # sequence classification, otherwise the evaluation loop raises the error above
    use_cache=False,
)
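A hedged equivalent for a model that is already loaded: Mamba reads config.use_cache at forward time, so turning it off on the config should likewise keep the MambaCache object out of the outputs that the Trainer tries to pad.

# assumes `model` is the MambaForSequenceClassification instance loaded above
model.config.use_cache = False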

@mohith7548

I noticed that the training speed (fine-tuning) is very slow compared to the other HF transformer models. Can something be improved here?


@mohith7548 mohith7548 left a comment


Ran locally. Loaded, fine-tuned, and did batch inference. Works as expected.

@Adibvafa
Contributor Author

Hey @Jellymoon, the Mamba model works as expected during the training loop. However, it fails during the evaluation loop. I found that it is necessary to set use_cache=False when loading the model so that evaluation does not fail.

cc: @Adibvafa

model = MambaForSequenceClassification.from_pretrained(
    model_path,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    # use_cache=False must be passed when training and evaluating Mamba for
    # sequence classification, otherwise the evaluation loop raises the error above
    use_cache=False,
)

I will take a look. Thank you for bringing this up! @Jellymoon @mohith7548

@Adibvafa
Contributor Author

I noticed that the training speed (fine-tuning) is very slow compared to the other HF transformer models. Can something be improved here?

Do you have mamba-ssm installed? Is it slow for classification specifically, or for Mamba in general?

@mohith7548

@Adibvafa, I have mamba-ssm installed. However, I realized that it also needs the causal-conv1d>=1.4.0 package to train faster; otherwise it shows a warning related to conv1d saying it will use the slow/sequential version. Now that I have installed causal-conv1d>=1.4.0, fine-tuning works as expected.

@Adibvafa
Contributor Author

@Adibvafa, I have mamba-ssm installed. However, I realized that it also needs the causal-conv1d>=1.4.0 package to train faster; otherwise it shows a warning related to conv1d saying it will use the slow/sequential version. Now that I have installed causal-conv1d>=1.4.0, fine-tuning works as expected.

Amazing!
There is currently a bug with the slow training path that either breaks in low-precision training or uses a huge amount of memory at once. I suggest opening an issue for the memory surge. I have opened an issue for the low-precision training error and am currently working on it.

@mohith7548

@Adibvafa, a bug in Mamba or in transformers? Can you elaborate? Please share a link to the issue.

@mohith7548

I successfully ran the Mamba model with the new changes you made to the code. Any chance that this will also support the Mamba2 model?

@vasqu
Contributor

vasqu commented Sep 5, 2024

There is currently a bug with the slow training path that either breaks in low-precision training or uses a huge amount of memory at once. I suggest opening an issue for the memory surge. I have opened an issue for the low-precision training error and am currently working on it.

I think the low-precision bug possibly refers to #32691. The huge amount of memory in the slow path is to be expected though and is one of the reasons why the kernel exists (i.e. to avoid materializing certain tensors etc). Nothing you can really do about this tbh.
cc @mohith7548

test_pruning = False
test_head_masking = False # Mamba does not have attention heads
test_model_parallel = False
test_mismatched_shapes = False # MambaMixer follows a different initialization
Contributor

@vasqu vasqu Sep 5, 2024


This seems a bit weird to me 🤔 Disabling the test_mismatched_shapes flag shouldn't be needed imo.

Could you add get_input_embeddings and set_input_embeddings methods for the ForSeqClassification class and see if it fixes those tests?
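A minimal sketch of the requested accessors, assuming the classification model wraps the base model under an attribute such as self.backbone (the actual attribute name in the PR may differ):

def get_input_embeddings(self):
    # delegate to the wrapped MambaModel so the common tests can reach the input embeddings
    return self.backbone.get_input_embeddings()


def set_input_embeddings(self, new_embeddings):
    self.backbone.set_input_embeddings(new_embeddings)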


Successfully merging this pull request may close these issues: Mamba Models - Missing MambaForSequenceClassification
6 participants