From 2ed63ce02e65ee23b837212d1a17b3274708ec8a Mon Sep 17 00:00:00 2001 From: Michael Schock Date: Fri, 8 Mar 2024 17:42:11 -0800 Subject: [PATCH 1/5] Add MambaForSequenceClassification --- src/transformers/__init__.py | 2 + src/transformers/models/mamba/__init__.py | 2 + .../models/mamba/modeling_mamba.py | 134 +++++++++++++++++- src/transformers/utils/dummy_pt_objects.py | 7 + 4 files changed, 144 insertions(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index cd5852924ee099..7de9d5c1178930 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -2610,6 +2610,7 @@ [ "MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST", "MambaForCausalLM", + "MambaForSequenceClassification", "MambaModel", "MambaPreTrainedModel", ] @@ -7266,6 +7267,7 @@ from .models.mamba import ( MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST, MambaForCausalLM, + MambaForSequenceClassification, MambaModel, MambaPreTrainedModel, ) diff --git a/src/transformers/models/mamba/__init__.py b/src/transformers/models/mamba/__init__.py index 7a1c142e05d51e..c7ab7e63db3a7e 100644 --- a/src/transformers/models/mamba/__init__.py +++ b/src/transformers/models/mamba/__init__.py @@ -34,6 +34,7 @@ _import_structure["modeling_mamba"] = [ "MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST", "MambaForCausalLM", + "MambaForSequenceClassification", "MambaModel", "MambaPreTrainedModel", ] @@ -51,6 +52,7 @@ from .modeling_mamba import ( MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST, MambaForCausalLM, + MambaForSequenceClassification, MambaModel, MambaPreTrainedModel, ) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index a3325b3af87c95..47d289b36e1c25 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -21,7 +21,7 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...modeling_utils import PreTrainedModel @@ -441,6 +441,32 @@ class MambaCausalLMOutput(ModelOutput): hidden_states: Optional[Tuple[torch.FloatTensor]] = None +@dataclass +class MambaSequenceClassifierOutput(ModelOutput): + """ + Base class for outputs of sentence classification models. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + cache_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + cache_params: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None + + MAMBA_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -683,3 +709,109 @@ def forward( cache_params=mamba_outputs.cache_params, hidden_states=mamba_outputs.hidden_states, ) + + +class MambaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + self.config = config + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = ACT2FN[self.config.hidden_act](x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@add_start_docstrings( + """Mamba Model backbone with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks.""", + MAMBA_START_DOCSTRING, +) +class MambaForSequenceClassification(MambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.backbone = MambaModel(config) + self.classifier = MambaClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MambaSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_params: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. + Indices should be in `[0, ..., config.num_labels - 1]`. + If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + mamba_outputs = self.backbone( + input_ids, + cache_params=cache_params, + inputs_embeds=inputs_embeds, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = mamba_outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + mamba_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MambaSequenceClassifierOutput( + loss=loss, + logits=logits, + cache_params=mamba_outputs.cache_params, + hidden_states=mamba_outputs.hidden_states, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5b698e0afe50dd..99843b0c08e661 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -5069,6 +5069,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class MambaForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MambaModel(metaclass=DummyObject): _backends = ["torch"] From 46dec9c0bf3459fe22ee5c70ebfd30d74e3ec240 Mon Sep 17 00:00:00 2001 From: Michael Schock Date: Mon, 11 Mar 2024 19:27:26 -0700 Subject: [PATCH 2/5] Update docs and tests for MambaForSequenceClassification --- docs/source/en/model_doc/mamba.md | 5 ++++ .../en/tasks/sequence_classification.md | 2 +- src/transformers/models/auto/modeling_auto.py | 1 + .../models/mamba/modeling_mamba.py | 3 ++- tests/models/mamba/test_modeling_mamba.py | 25 +++++++++++++++++-- 5 files changed, 32 insertions(+), 4 deletions(-) diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md index 94eb2e2c2d528d..960595e784f8a0 100644 --- a/docs/source/en/model_doc/mamba.md +++ b/docs/source/en/model_doc/mamba.md @@ -102,3 +102,8 @@ trainer.train() [[autodoc]] MambaForCausalLM - forward + +## MambaForSequenceClassification + +[[autodoc]] MambaForSequenceClassification + - forward diff --git a/docs/source/en/tasks/sequence_classification.md b/docs/source/en/tasks/sequence_classification.md index 544d24a0bad6d5..0caca21ad24c45 100644 --- a/docs/source/en/tasks/sequence_classification.md +++ b/docs/source/en/tasks/sequence_classification.md @@ -33,7 +33,7 @@ The task illustrated in this tutorial is supported by the following model archit -[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), 
[BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [CodeLlama](../model_doc/code_llama), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [Gemma](../model_doc/gemma), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso) +[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [CodeLlama](../model_doc/code_llama), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [Falcon](../model_doc/falcon), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [Gemma](../model_doc/gemma), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), 
[GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [Mamba](../model_doc/mamba), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MPT](../model_doc/mpt), [MRA](../model_doc/mra), [MT5](../model_doc/mt5), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [StableLm](../model_doc/stablelm), [Starcoder2](../model_doc/starcoder2), [T5](../model_doc/t5), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [UMT5](../model_doc/umt5), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index dcdf54b7d90183..5da5b53c81581a 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -848,6 +848,7 @@ ("llama", "LlamaForSequenceClassification"), ("longformer", "LongformerForSequenceClassification"), ("luke", "LukeForSequenceClassification"), + ("mamba", "MambaForSequenceClassification"), ("markuplm", "MarkupLMForSequenceClassification"), ("mbart", "MBartForSequenceClassification"), ("mega", "MegaForSequenceClassification"), diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 47d289b36e1c25..a733714d937ad5 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -716,6 +716,7 @@ class MambaClassificationHead(nn.Module): def __init__(self, config): super().__init__() + self.activation = ACT2FN[config.hidden_act] self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) @@ -726,7 +727,7 @@ def forward(self, features, **kwargs): x = features[:, 0, :] # take token (equiv. 
to [CLS]) x = self.dropout(x) x = self.dense(x) - x = ACT2FN[self.config.hidden_act](x) + x = self.activation(x) x = self.dropout(x) x = self.out_proj(x) return x diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 8bd121933b8052..013ecc37ed7ead 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -35,6 +35,7 @@ from transformers import ( MambaForCausalLM, + MambaForSequenceClassification, MambaModel, ) from transformers.models.mamba.modeling_mamba import MambaCache @@ -135,6 +136,7 @@ def get_config( pad_token_id=self.pad_token_id, gradient_checkpointing=gradient_checkpointing, tie_word_embeddings=self.tie_word_embeddings, + hidden_dropout_prob=self.hidden_dropout_prob, ) def get_pipeline_config(self): @@ -179,6 +181,15 @@ def create_and_check_causl_lm(self, config, input_ids, *args): self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + def create_and_check_for_sequence_classification(self, config, input_ids, *args): + config.num_labels = self.num_labels + model = MambaForSequenceClassification(config) + model.to(torch_device) + model.eval() + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + result = model(input_ids, labels=sequence_labels) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) + def create_and_check_state_equivalency(self, config, input_ids, *args): model = MambaModel(config=config) model.to(torch_device) @@ -226,7 +237,7 @@ def prepare_config_and_inputs_for_common(self): ) @require_torch class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): - all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else () + all_model_classes = (MambaModel, MambaForCausalLM, MambaForSequenceClassification) if is_torch_available() else () fx_compatible = False # FIXME let's try to support this @ArthurZucker test_torchscript = False # FIXME let's try to support this @ArthurZucker test_missing_keys = False @@ -235,7 +246,13 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi test_head_masking = False # Mamba does not have attention heads test_model_parallel = False pipeline_model_mapping = ( - {"feature-extraction": MambaModel, "text-generation": MambaForCausalLM} if is_torch_available() else {} + { + "feature-extraction": MambaModel, + "text-generation": MambaForCausalLM, + "text-classification": MambaForSequenceClassification, + } + if is_torch_available() + else {} ) def setUp(self): @@ -306,6 +323,10 @@ def test_mamba_lm_head_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_causl_lm(*config_and_inputs) + def test_model_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + def test_state_equivalency(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_state_equivalency(*config_and_inputs) From 87d55c4428fc7bda660ff91bd892a339be2d20b7 Mon Sep 17 00:00:00 2001 From: Michael Schock Date: Wed, 13 Mar 2024 18:07:42 -0700 Subject: [PATCH 3/5] Add hidden_dropout_prob to MambaConfig and pull sequence_labels from prepare_config_and_inputs --- .../models/mamba/configuration_mamba.py | 4 ++++ 
tests/models/mamba/test_modeling_mamba.py | 16 ++++++++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index ec5e615c0bfa70..60650525d7d77d 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -83,6 +83,8 @@ class MambaConfig(PretrainedConfig): Whether or not to rescale `out_proj` weights when initializing. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the cache should be used. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probabilities used in the sequence classification model [`MambaForSequenceClassification`]. Example: @@ -127,6 +129,7 @@ def __init__( time_step_floor=1e-4, rescale_prenorm_residual=False, use_cache=True, + hidden_dropout_prob=0.1, **kwargs, ): self.vocab_size = vocab_size @@ -153,5 +156,6 @@ def __init__( self.rescale_prenorm_residual = rescale_prenorm_residual self.residual_in_fp32 = residual_in_fp32 self.use_cache = use_cache + self.hidden_dropout_prob = hidden_dropout_prob super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 013ecc37ed7ead..3cc93fc8ff027f 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -148,6 +148,7 @@ def prepare_config_and_inputs_for_decoder(self): ( config, input_ids, + _, sequence_labels, token_labels, choice_labels, @@ -181,12 +182,11 @@ def create_and_check_causl_lm(self, config, input_ids, *args): self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_for_sequence_classification(self, config, input_ids, *args): + def create_and_check_for_sequence_classification(self, config, input_ids, sequence_labels, *args): config.num_labels = self.num_labels model = MambaForSequenceClassification(config) model.to(torch_device) model.eval() - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) result = model(input_ids, labels=sequence_labels) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels)) @@ -324,8 +324,16 @@ def test_mamba_lm_head_model(self): self.model_tester.create_and_check_causl_lm(*config_and_inputs) def test_model_for_sequence_classification(self): - config_and_inputs = self.model_tester.prepare_config_and_inputs() - self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + ( + config, + input_ids, + _, + sequence_labels, + _, + _, + ) = self.model_tester.prepare_config_and_inputs() + + self.model_tester.create_and_check_for_sequence_classification(config, input_ids, sequence_labels) def test_state_equivalency(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() From 94f7a1e93e23588d042d5a5661ac3e28af79ccf9 Mon Sep 17 00:00:00 2001 From: Michael Schock Date: Thu, 14 Mar 2024 15:28:18 -0700 Subject: [PATCH 4/5] Freeze base model params for MambaForSequenceClassification and init classifier head linear layer weights --- src/transformers/models/mamba/modeling_mamba.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index a733714d937ad5..d9bffe04fceba2 100644 --- 
a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -362,6 +362,10 @@ def _init_weights(self, module): with torch.no_grad(): module.dt_proj.bias.copy_(inv_dt) module.dt_proj.bias._no_reinit = True + elif isinstance(module, MambaClassificationHead): + for name, p in module.named_parameters(): + if name in ["dense.weight", "out_proj.weight"]: + nn.init.normal_(p, std=self.config.initializer_range) if isinstance(module, nn.Linear): if module.bias is not None: @@ -745,6 +749,9 @@ def __init__(self, config): self.backbone = MambaModel(config) self.classifier = MambaClassificationHead(config) + for param in self.base_model.parameters(): + param.requires_grad = False + # Initialize weights and apply final processing self.post_init() From b567c2bbfdc7499903f4611c99aa5dc713c2bed4 Mon Sep 17 00:00:00 2001 From: Michael Schock Date: Sun, 17 Mar 2024 23:44:52 -0700 Subject: [PATCH 5/5] Attempt at simplifying and conforming to GPT2/GPTNeoX style --- .../models/mamba/modeling_mamba.py | 66 ++++++++++++++----- 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index d9bffe04fceba2..864d184dfa9671 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -362,10 +362,10 @@ def _init_weights(self, module): with torch.no_grad(): module.dt_proj.bias.copy_(inv_dt) module.dt_proj.bias._no_reinit = True - elif isinstance(module, MambaClassificationHead): - for name, p in module.named_parameters(): - if name in ["dense.weight", "out_proj.weight"]: - nn.init.normal_(p, std=self.config.initializer_range) + # elif isinstance(module, MambaClassificationHead): + # for name, p in module.named_parameters(): + # if name in ["dense.weight", "out_proj.weight"]: + # nn.init.normal_(p, std=self.config.initializer_range) if isinstance(module, nn.Linear): if module.bias is not None: @@ -720,19 +720,23 @@ class MambaClassificationHead(nn.Module): def __init__(self, config): super().__init__() - self.activation = ACT2FN[config.hidden_act] - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) + # self.activation = ACT2FN[config.hidden_act] + # self.dense = nn.Linear(config.hidden_size, config.hidden_size) + # self.dropout = nn.Dropout(config.hidden_dropout_prob) self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + # module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + self.out_proj.weight.data.normal_(mean=0.0, std=config.initializer_range) + self.config = config def forward(self, features, **kwargs): - x = features[:, 0, :] # take token (equiv. to [CLS]) - x = self.dropout(x) - x = self.dense(x) - x = self.activation(x) - x = self.dropout(x) + # x = features[:, 0, :] # take token (equiv. 
to [CLS]) + # x = self.dropout(x) + # x = self.dense(x) + # x = self.activation(x) + # x = self.dropout(x) + x = features x = self.out_proj(x) return x @@ -748,6 +752,7 @@ def __init__(self, config): self.num_labels = config.num_labels self.backbone = MambaModel(config) self.classifier = MambaClassificationHead(config) + # self.classifier = nn.Linear(config.hidden_size, config.num_labels) for param in self.base_model.parameters(): param.requires_grad = False @@ -791,6 +796,31 @@ def forward( sequence_output = mamba_outputs[0] logits = self.classifier(sequence_output) + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + print( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + loss = None if labels is not None: if self.config.problem_type is None: @@ -804,22 +834,22 @@ def forward( if self.config.problem_type == "regression": loss_fct = MSELoss() if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) else: - loss = loss_fct(logits, labels) + loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) + loss = loss_fct(pooled_logits, labels) if not return_dict: - output = (logits,) + mamba_outputs[1:] + output = (pooled_logits,) + mamba_outputs[1:] return ((loss,) + output) if loss is not None else output return MambaSequenceClassifierOutput( loss=loss, - logits=logits, + logits=pooled_logits, cache_params=mamba_outputs.cache_params, hidden_states=mamba_outputs.hidden_states, )
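
Note on the pooling introduced in the final patch: it follows the GPT-2/GPT-NeoX sequence-classification convention of taking the logits at the last non-padding position of each sequence. Below is a small self-contained sketch of that index computation, using toy tensors and a hypothetical pad id, for anyone reviewing the modulo trick.

```python
import torch

pad_token_id = 0  # hypothetical pad id for this toy example
input_ids = torch.tensor(
    [
        [5, 8, 3, 0, 0],  # padded sequence -> last real token at index 2
        [7, 2, 9, 4, 6],  # no padding      -> last real token at index 4
    ]
)
logits = torch.randn(2, 5, 3)  # (batch_size, sequence_length, num_labels)

# Position of the first pad token minus one; when no pad token is present,
# argmax returns 0, so 0 - 1 = -1 and the modulo maps it to the final position.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]

pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(sequence_lengths)     # tensor([2, 4])
print(pooled_logits.shape)  # torch.Size([2, 3])
```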
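
For reference, a minimal end-to-end usage sketch of the new head, assuming the patches in this series are applied; the checkpoint name, label count, and example text are illustrative placeholders rather than part of the PR.

```python
import torch

from transformers import AutoTokenizer, MambaForSequenceClassification

checkpoint = "state-spaces/mamba-130m-hf"  # illustrative; any Mamba checkpoint with a tokenizer should work
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = MambaForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=2,  # the classification head is newly initialized
)

inputs = tokenizer("Mamba scales linearly with sequence length.", return_tensors="pt")
labels = torch.tensor([1])

with torch.no_grad():
    outputs = model(**inputs, labels=labels)

print(outputs.loss)          # scalar single-label classification loss
print(outputs.logits.shape)  # torch.Size([1, 2]) -> (batch_size, num_labels)
```

Because patch 2 also registers `("mamba", "MambaForSequenceClassification")` in `modeling_auto.py`, loading through `AutoModelForSequenceClassification.from_pretrained(...)` should resolve to the same head.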