Implement MambaForSequenceClassification #31155
base: main
Conversation
Referring to #29552, "there's a test specific to sequence classification that expects all the unfrozen params to be initialized in the range [0.0, 1.0] and the initialized values for the mixer don't satisfy that assertion." This results in a test failure even though the classifier head is initialized properly.
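For context, the check referred to is roughly of the following shape (a paraphrased sketch, not copied from the transformers test suite; `model` is any instance under test and the exact rounding may differ):

```python
# Rough sketch of the common initialization check: every trainable parameter is expected
# to have a mean of 0.0 or 1.0 right after init, which MambaMixer parameters such as
# A_log and the dt projection deliberately do not satisfy.
for name, param in model.named_parameters():
    if param.requires_grad:
        mean = round(param.data.mean().item(), 9)
        assert mean in (0.0, 1.0), f"{name} initialized with mean {mean}"
```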
Thanks for opening a PR! 🤗
Could you rebase on main and make sure the CIs are green? 🤗
Of course! It should be good to merge now.
@ArthurZucker Pending review!
🤗
""", | ||
MAMBA_START_DOCSTRING, | ||
) | ||
class MambaForSequenceClassification(MambaPreTrainedModel): |
Why not just use `# Copied from` here?
`# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mamba, LLAMA->MAMBA, self.transformer->self.model, transformer_outputs->model_outputs`
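For reference, `# Copied from` is the transformers consistency convention: `make fix-copies` (via `utils/check_copies.py`) regenerates the annotated code from the referenced class, applying the listed string replacements, and CI fails if the copy drifts. A minimal sketch of what the suggestion would look like (base class assumed from the quoted diff):

```python
from transformers.models.mamba.modeling_mamba import MambaPreTrainedModel


# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mamba, LLAMA->MAMBA, self.transformer->self.model, transformer_outputs->model_outputs
class MambaForSequenceClassification(MambaPreTrainedModel):
    ...  # the body would then be kept in sync with LlamaForSequenceClassification automatically
```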
I'm not sure if I understand your comment. The forward methods of Mamba and Llama for sequence classification seem different. Could you please elaborate? 🤗
@ArthurZucker Now that #32080 is merged, can we do a final review for this one too? Also, I would like to add Mamba2ForSequenceClassification to this PR as well, so we have both Mamba models with classification capabilities. Then I would be able to release the EHRMamba model on HuggingFace.
Hey, any update on this?
```python
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
use_cache: Optional[bool] = None,
```
I think there is a `**kwargs` missing in the forward function, lines 842-843.
Good catch! I suppose it won't be necessary (as opposed to `MambaForCausalLM`), but having it is good. I will add it in a commit now.
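A minimal sketch of the resulting signature (the parameter list is taken from the quoted diff, with `input_ids` assumed; other parameters and the body are elided):

```python
from typing import Optional

import torch
from transformers.models.mamba.modeling_mamba import MambaPreTrainedModel


class MambaForSequenceClassification(MambaPreTrainedModel):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        **kwargs,  # accept and ignore extra keyword arguments, mirroring MambaForCausalLM
    ):
        ...
```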
@Adibvafa have you tried running this? Whenever the model gets to an evaluation step I get the error below. ERROR:
Hey @Jellymoon, the Mamba model works as expected during the training loop. However, it fails during the evaluation loop. I found that it is necessary to set `use_cache=False`, otherwise it raises an error. cc: @Adibvafa

```python
model = MambaForSequenceClassification.from_pretrained(
    model_path,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
    use_cache=False,  # needed when evaluating/training Mamba for sequence classification, otherwise it raises an error
)
```
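An equivalent workaround, assuming a standard transformers config object, is to flip the flag after loading (same `model_path`/`id2label` names as in the snippet above):

```python
model = MambaForSequenceClassification.from_pretrained(model_path, num_labels=len(id2label))
model.config.use_cache = False  # disables cache handling during the Trainer's evaluation loop
```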
I noticed that the training speed (fine-tuning) is very slow compared to other HF transformer models. Can something be improved here?
Ran locally. Loaded, fine-tuned, and did batch inference. Works as expected.
I will take a look. Thank you for bringing this up! @Jellymoon @mohith7548
Do you have mamba-ssm installed? Is it slow in classification specifically or in Mamba in general?
@Adibvafa, I have
Amazing!
@Adibvafa, a bug in Mamba? or
I successfully ran the Mamba model with the new changes you made to the code. Any chance that this will also support the Mamba2 model?
I think the low-precision bug possibly refers to #32691. The huge amount of memory in the slow path is to be expected though, and is one of the reasons why the kernel exists (i.e., to avoid materializing certain tensors). Nothing you can really do about this, tbh.
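One quick way to tell whether the slow path is the culprit is to check for the optional kernel packages; without `mamba-ssm` and `causal-conv1d`, transformers falls back to the much slower pure-PyTorch implementation (a small sketch, assuming the usual import names for these packages):

```python
def mamba_fast_path_available() -> bool:
    # The fused kernels live in the optional mamba-ssm and causal-conv1d packages;
    # if either import fails, Mamba runs on the slow, memory-hungry fallback path.
    try:
        import mamba_ssm  # noqa: F401
        import causal_conv1d  # noqa: F401
    except ImportError:
        return False
    return True


print("fast Mamba kernels available:", mamba_fast_path_available())
```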
```python
test_pruning = False
test_head_masking = False  # Mamba does not have attention heads
test_model_parallel = False
test_mismatched_shapes = False  # MambaMixer follows a different initialization
```
This seems a bit weird to me 🤔 Disabling the `test_mismatched_shapes` flag shouldn't be needed imo. Could you add `get_input_embeddings` and `set_input_embeddings` methods for the ForSeqClassification class and see if it fixes those tests?
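A minimal sketch of the suggested accessors, assuming the classification model wraps the backbone as `self.backbone = MambaModel(config)` (as `MambaForCausalLM` does):

```python
from transformers.models.mamba.modeling_mamba import MambaPreTrainedModel


class MambaForSequenceClassification(MambaPreTrainedModel):
    ...

    def get_input_embeddings(self):
        # Delegate to the backbone so resize_token_embeddings and related utilities keep working.
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)
```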
What does this PR do?
Adds the `MambaForSequenceClassification` model based on the `MambaModel` backbone. We recently published EHRMamba, a state-of-the-art foundation model for Electronic Health Records. This model is built on the same architecture, and we will release the trained weights using the `MambaForSequenceClassification` class.
https://vectorinstitute.github.io/EHRMamba
Fixes #30431
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case: Mamba Models - Missing MambaForSequenceClassification #30431
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
As discussed in #30431, @ArthurZucker could you take a look? 😊
Notes
This implementation closely follows GPT2ForSequenceClassification, except that the last hidden states are pooled before being passed to the classifier, which improves efficiency.
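A rough sketch of that difference (names are illustrative; `classifier` stands for the classification head): GPT2ForSequenceClassification applies the head to every position and then selects the logit at the last non-padded token, whereas pooling first means the head only runs on one vector per sequence.

```python
import torch
import torch.nn as nn


def pooled_classification_logits(hidden_states, input_ids, pad_token_id, classifier: nn.Linear):
    # hidden_states: (batch, seq_len, hidden_size) from the Mamba backbone.
    # Find the index of the last non-padding token for each sequence.
    if pad_token_id is None:
        sequence_lengths = -1  # no padding info: just take the final position
    else:
        sequence_lengths = (input_ids != pad_token_id).long().sum(-1) - 1
    batch_idx = torch.arange(hidden_states.size(0), device=hidden_states.device)
    pooled = hidden_states[batch_idx, sequence_lengths]  # (batch, hidden_size)
    return classifier(pooled)                            # (batch, num_labels)
```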