diff --git a/docs/source/en/model_doc/llama.mdx b/docs/source/en/model_doc/llama.mdx index 28e39b11ce60..0057bab98d3a 100644 --- a/docs/source/en/model_doc/llama.mdx +++ b/docs/source/en/model_doc/llama.mdx @@ -64,3 +64,8 @@ This model was contributed by [zphang](https://huggingface.co/zphang) with contr [[autodoc]] LlamaForCausalLM - forward + +## LlamaForSequenceClassification + +[[autodoc]] LlamaForSequenceClassification + - forward \ No newline at end of file diff --git a/docs/source/en/tasks/sequence_classification.mdx b/docs/source/en/tasks/sequence_classification.mdx index fb80047dd1e9..20a3f4d14686 100644 --- a/docs/source/en/tasks/sequence_classification.mdx +++ b/docs/source/en/tasks/sequence_classification.mdx @@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit -[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso) +[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c64febc41da8..c71a60409bbc 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1801,11 +1801,7 @@ ] ) _import_structure["models.llama"].extend( - [ - "LlamaForCausalLM", - "LlamaModel", - "LlamaPreTrainedModel", - ] + ["LlamaForCausalLM", "LlamaForSequenceClassification", "LlamaModel", "LlamaPreTrainedModel"] ) _import_structure["models.longformer"].extend( [ @@ -5198,11 +5194,7 @@ LiltModel, LiltPreTrainedModel, ) - from .models.llama import ( - LlamaForCausalLM, - LlamaModel, - LlamaPreTrainedModel, - ) + from .models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel from .models.longformer import ( LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, LongformerForMaskedLM, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 6da804733a5f..eff11b45a53f 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -652,6 +652,7 @@ ("layoutlmv3", "LayoutLMv3ForSequenceClassification"), ("led", "LEDForSequenceClassification"), ("lilt", "LiltForSequenceClassification"), + ("llama", "LlamaForSequenceClassification"), ("longformer", "LongformerForSequenceClassification"), ("luke", "LukeForSequenceClassification"), ("markuplm", "MarkupLMForSequenceClassification"), diff --git a/src/transformers/models/llama/__init__.py b/src/transformers/models/llama/__init__.py index 2501c282f0ba..adef4306f287 100644 --- a/src/transformers/models/llama/__init__.py +++ b/src/transformers/models/llama/__init__.py @@ -43,6 +43,7 @@ "LlamaForCausalLM", "LlamaModel", "LlamaPreTrainedModel", + "LlamaForSequenceClassification", ] @@ -63,11 +64,7 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_llama import ( - LlamaForCausalLM, - LlamaModel, - LlamaPreTrainedModel, - ) + from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel else: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 40a3f0f802a0..5cdea1b787f6 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -24,19 +24,12 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - logging, - replace_return_docstrings, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_llama import LlamaConfig @@ -357,7 +350,7 @@ def forward( @add_start_docstrings( - "The bare OPT Model outputting raw hidden-states without any specific head on top.", + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) class LlamaPreTrainedModel(PreTrainedModel): @@ -831,3 +824,122 @@ def _reorder_cache(past_key_values, beam_idx): for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 6fb3d2d9d0bb..a80af49e2784 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3768,6 +3768,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class LlamaForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LlamaModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 23d573454da4..dea92d5111fd 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -27,7 +27,7 @@ if is_torch_available(): import torch - from transformers import LlamaForCausalLM, LlamaModel + from transformers import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel class LlamaModelTester: @@ -255,14 +255,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch class LlamaModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = ( - ( - LlamaModel, - LlamaForCausalLM, - ) - if is_torch_available() - else () - ) + all_model_classes = (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification) if is_torch_available() else () all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else () test_headmasking = False @@ -283,6 +276,46 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) + def test_llama_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = LlamaForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_llama_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = LlamaForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_llama_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + model = LlamaForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + @unittest.skip("LLaMA does not support head pruning.") def test_head_pruning(self): pass