Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MambaForSequenceClassification #29552

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/en/model_doc/mamba.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,8 @@ trainer.train()

[[autodoc]] MambaForCausalLM
- forward

## MambaForSequenceClassification

[[autodoc]] MambaForSequenceClassification
- forward
2 changes: 1 addition & 1 deletion docs/source/en/tasks/sequence_classification.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ The task illustrated in this tutorial is supported by the following model archit
<!--This tip is automatically generated by `make fix-copies`, do not fill manually!-->


[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [CodeLlama](../model_doc/code_llama), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [Gemma](../model_doc/gemma), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)
[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [CodeLlama](../model_doc/code_llama), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [Gemma](../model_doc/gemma), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [Mamba](../model_doc/mamba), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso)



Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2610,6 +2610,7 @@
[
"MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
"MambaForCausalLM",
"MambaForSequenceClassification",
"MambaModel",
"MambaPreTrainedModel",
]
Expand Down Expand Up @@ -7266,6 +7267,7 @@
from .models.mamba import (
MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
MambaForCausalLM,
MambaForSequenceClassification,
MambaModel,
MambaPreTrainedModel,
)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@
("llama", "LlamaForSequenceClassification"),
("longformer", "LongformerForSequenceClassification"),
("luke", "LukeForSequenceClassification"),
("mamba", "MambaForSequenceClassification"),
("markuplm", "MarkupLMForSequenceClassification"),
("mbart", "MBartForSequenceClassification"),
("mega", "MegaForSequenceClassification"),
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/mamba/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
_import_structure["modeling_mamba"] = [
"MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST",
"MambaForCausalLM",
"MambaForSequenceClassification",
"MambaModel",
"MambaPreTrainedModel",
]
Expand All @@ -51,6 +52,7 @@
from .modeling_mamba import (
MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST,
MambaForCausalLM,
MambaForSequenceClassification,
MambaModel,
MambaPreTrainedModel,
)
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/mamba/configuration_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class MambaConfig(PretrainedConfig):
Whether or not to rescale `out_proj` weights when initializing.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the cache should be used.
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
The dropout probabilities used in the sequence classification model [`MambaForSequenceClassification`].


Example:
Expand Down Expand Up @@ -127,6 +129,7 @@ def __init__(
time_step_floor=1e-4,
rescale_prenorm_residual=False,
use_cache=True,
hidden_dropout_prob=0.1,
**kwargs,
):
self.vocab_size = vocab_size
Expand All @@ -153,5 +156,6 @@ def __init__(
self.rescale_prenorm_residual = rescale_prenorm_residual
self.residual_in_fp32 = residual_in_fp32
self.use_cache = use_cache
self.hidden_dropout_prob = hidden_dropout_prob

super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs)
172 changes: 171 additions & 1 deletion src/transformers/models/mamba/modeling_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...modeling_utils import PreTrainedModel
Expand Down Expand Up @@ -362,6 +362,10 @@ def _init_weights(self, module):
with torch.no_grad():
module.dt_proj.bias.copy_(inv_dt)
module.dt_proj.bias._no_reinit = True
# elif isinstance(module, MambaClassificationHead):
# for name, p in module.named_parameters():
# if name in ["dense.weight", "out_proj.weight"]:
# nn.init.normal_(p, std=self.config.initializer_range)

if isinstance(module, nn.Linear):
if module.bias is not None:
Expand Down Expand Up @@ -441,6 +445,32 @@ class MambaCausalLMOutput(ModelOutput):
hidden_states: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class MambaSequenceClassifierOutput(ModelOutput):
"""
Base class for outputs of sentence classification models.

Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Classification (or regression if config.num_labels==1) loss.
logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
cache_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`):
The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
avoid providing the old `input_ids`.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
cache_params: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None


MAMBA_START_DOCSTRING = r"""

This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
Expand Down Expand Up @@ -683,3 +713,143 @@ def forward(
cache_params=mamba_outputs.cache_params,
hidden_states=mamba_outputs.hidden_states,
)


class MambaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""

def __init__(self, config):
super().__init__()
# self.activation = ACT2FN[config.hidden_act]
# self.dense = nn.Linear(config.hidden_size, config.hidden_size)
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

# module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
self.out_proj.weight.data.normal_(mean=0.0, std=config.initializer_range)

self.config = config

def forward(self, features, **kwargs):
# x = features[:, 0, :] # take <s> token (equiv. to [CLS])
# x = self.dropout(x)
# x = self.dense(x)
# x = self.activation(x)
# x = self.dropout(x)
x = features
x = self.out_proj(x)
return x


@add_start_docstrings(
"""Mamba Model backbone with a sequence classification/regression head on top (a linear layer on top of
the pooled output) e.g. for GLUE tasks.""",
MAMBA_START_DOCSTRING,
)
class MambaForSequenceClassification(MambaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.backbone = MambaModel(config)
amyeroberts marked this conversation as resolved.
Show resolved Hide resolved
self.classifier = MambaClassificationHead(config)
# self.classifier = nn.Linear(config.hidden_size, config.num_labels)

for param in self.base_model.parameters():
param.requires_grad = False
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure whether we actually want to freeze the params for the base model here, but I found there's a test specific to sequence classification that expects all the unfrozen params to be initialized in the range [0.0, 1.0] and the initialized values for the mixer don't satisfy that assertion.

So... I froze them and made sure the classification head params were initialized to satisfiy the test. It makes intuitive sense to me to freeze them in the case of transfer learning for this task and I did confirm that running LoRA PEFT with target_modules=["x_proj", "embeddings", "in_proj", "out_proj"] and task_type=TaskType.SEQ_CLS does unfreeze the target modules so it appears to work fine, but not sure if we want to force them to be frozen by default.

Anyway, happy to adjust if there's a better practice to follow here.


# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=MambaSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_params: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss.
Indices should be in `[0, ..., config.num_labels - 1]`.
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss),
If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

mamba_outputs = self.backbone(
input_ids,
cache_params=cache_params,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

sequence_output = mamba_outputs[0]
logits = self.classifier(sequence_output)

if input_ids is not None:
batch_size, sequence_length = input_ids.shape[:2]
else:
batch_size, sequence_length = inputs_embeds.shape[:2]

assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
print(
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
)

pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"

if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + mamba_outputs[1:]
return ((loss,) + output) if loss is not None else output

return MambaSequenceClassifierOutput(
loss=loss,
logits=pooled_logits,
cache_params=mamba_outputs.cache_params,
hidden_states=mamba_outputs.hidden_states,
)
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -5069,6 +5069,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class MambaForSequenceClassification(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class MambaModel(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
Loading