
Add MambaForSequenceClassification #29552

Closed · wants to merge 5 commits

Conversation

@mjschock commented on Mar 9, 2024

What does this PR do?

Adds MambaForSequenceClassification for sequence classification with the MambaModel backbone.
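
For illustration, usage mirrors the other *ForSequenceClassification classes. Here is a minimal sketch assuming the class added in this PR; the checkpoint name is only a placeholder, and the classification head would be randomly initialized until fine-tuned:

    import torch
    from transformers import AutoTokenizer, MambaForSequenceClassification  # class added by this PR

    checkpoint = "state-spaces/mamba-130m-hf"  # placeholder checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = MambaForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

    inputs = tokenizer("This movie was great!", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_id = logits.argmax(dim=-1).item()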

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker thanks for your work bringing in Mamba! I'm wondering if there's any objection to adding a MambaForSequenceClassification model? I followed the template example as best I could and am happy to continue by adding the new test and finishing it up.

@amyeroberts (Collaborator) left a comment:

Hi @mjschock, thanks for opening this PR and adding this!

AFAICT, these changes seem reasonable to me. @ArthurZucker is off for a week, so let's wait for him to come back to confirm whether there's any reason not to add this to Mamba.

A few things that will need to be added:

  • Tests for the model, i.e. an equivalent of create_and_check_mamba_model, and the model should be added to all_model_classes (a sketch follows this list)
  • The model needs to be documented in mamba.md
  • All the tests in the CI should be passing
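
For the tests, a minimal sketch of what the helper could look like, modeled on the existing create_and_check_mamba_model; the helper name, tester attributes, and label handling are assumptions, not code from this PR:

    # Inside MambaModelTester in tests/models/mamba/test_modeling_mamba.py (sketch).
    # Assumes the PR's forward accepts `labels` like other *ForSequenceClassification heads.
    def create_and_check_mamba_for_sequence_classification(self, config, input_ids, sequence_labels):
        config.num_labels = self.num_labels
        model = MambaForSequenceClassification(config)
        model.to(torch_device)  # torch_device comes from transformers.testing_utils
        model.eval()
        result = model(input_ids, labels=sequence_labels)
        # One row of logits per sequence and a scalar loss for the batch
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
        self.parent.assertIsNotNone(result.loss)

MambaForSequenceClassification would then also be appended to all_model_classes in the test file so the common model tests run against it.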

    x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
    x = self.dropout(x)
    x = self.dense(x)
    x = ACT2FN[self.config.hidden_act](x)
A collaborator commented on the excerpt above:

The activation layer should be set in the init and then called here, i.e.:

    def __init__(...):
        ...
        self.activation = ACT2FN[config.hidden_act]

    def forward(...):
        ...
        x = self.activation(x)
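
Put together, the whole head might look like the sketch below. This is an illustration of the suggestion rather than the PR's exact code, and the dropout rate is an assumption since MambaConfig does not define a standard dropout field:

    import torch.nn as nn
    from transformers.activations import ACT2FN

    class MambaClassificationHead(nn.Module):
        """Head for sequence-level classification (sketch)."""

        def __init__(self, config):
            super().__init__()
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
            self.dropout = nn.Dropout(0.1)  # assumed rate; not a MambaConfig field
            # Resolve the activation once in __init__, per the review comment above
            self.activation = ACT2FN[config.hidden_act]
            self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

        def forward(self, features, **kwargs):
            x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
            x = self.dropout(x)
            x = self.dense(x)
            x = self.activation(x)
            x = self.dropout(x)
            x = self.out_proj(x)
            return x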

@mjschock changed the title from [WIP] Add MambaForSequenceClassification to Add MambaForSequenceClassification on Mar 15, 2024
    self.classifier = MambaClassificationHead(config)

    for param in self.base_model.parameters():
        param.requires_grad = False
@mjschock (Author) commented on the excerpt above:

I'm not sure whether we actually want to freeze the params for the base model here, but I found there's a test specific to sequence classification that expects all the unfrozen params to be initialized in the range [0.0, 1.0], and the initialized values for the mixer don't satisfy that assertion.

So... I froze them and made sure the classification head params were initialized to satisfy the test. It makes intuitive sense to me to freeze them in the case of transfer learning for this task, and I did confirm that running LoRA PEFT with target_modules=["x_proj", "embeddings", "in_proj", "out_proj"] and task_type=TaskType.SEQ_CLS does unfreeze the target modules, so it appears to work fine. I'm just not sure whether we want to force them to be frozen by default.

Anyway, happy to adjust if there's a better practice to follow here.
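
For reference, the LoRA setup described above looks roughly like this. It is a sketch assuming the peft library and an already-constructed model instance of the class from this PR:

    from peft import LoraConfig, TaskType, get_peft_model

    # `model` is assumed to be a MambaForSequenceClassification instance from this PR.
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
        r=8,            # illustrative rank
        lora_alpha=16,  # illustrative scaling factor
        lora_dropout=0.1,
    )
    model = get_peft_model(model, lora_config)
    # The LoRA adapters on the target modules are trainable even though the
    # base model's parameters were frozen above.
    model.print_trainable_parameters()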

@ArthurZucker added the Feature request label on Mar 19, 2024
@ArthurZucker (Collaborator) left a comment:

Hey! We usually do the following before merging such PRs:

  1. Create a feature request issue to add this new class
  2. Wait until the community picks this up
  3. Wait until there are actually pretrained checkpoints released by the community or the authors.

As is, this does not really help anyone, since it can be easily implemented by anyone who wants to train a model, no?

@mjschock (Author) commented:

Thanks for the feedback @ArthurZucker - I'll close this since I went in another direction, using prompt tuning instead. I'll keep the process you laid out in mind for the future. =)

@mjschock closed this on Mar 27, 2024
@scottfleming commented:

It would be helpful to have this class. Looking forward to #30431.
