From 0cd1058731ca4b8a1aebf57740fbedf92109975e Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 16 Mar 2023 16:08:05 +0000 Subject: [PATCH 1/7] Add LlamaForSequenceClassification --- docs/source/en/model_doc/llama.mdx | 5 + .../en/tasks/sequence_classification.mdx | 2 +- src/transformers/__init__.py | 12 +- src/transformers/models/auto/modeling_auto.py | 1 + src/transformers/models/llama/__init__.py | 7 +- .../models/llama/modeling_llama.py | 115 ++++++++++++++++-- src/transformers/utils/dummy_pt_objects.py | 7 ++ tests/models/llama/test_modeling_llama.py | 38 ++++-- 8 files changed, 152 insertions(+), 35 deletions(-) diff --git a/docs/source/en/model_doc/llama.mdx b/docs/source/en/model_doc/llama.mdx index 3f6ea3409bb8..ca0cebd7acba 100644 --- a/docs/source/en/model_doc/llama.mdx +++ b/docs/source/en/model_doc/llama.mdx @@ -64,3 +64,8 @@ This model was contributed by [zphang](https://huggingface.co/zphang) with contr [[autodoc]] LlamaForCausalLM - forward + +## LlamaForSequenceClassification + +[[autodoc]] LlamaForSequenceClassification + - forward \ No newline at end of file diff --git a/docs/source/en/tasks/sequence_classification.mdx b/docs/source/en/tasks/sequence_classification.mdx index fb80047dd1e9..20a3f4d14686 100644 --- a/docs/source/en/tasks/sequence_classification.mdx +++ b/docs/source/en/tasks/sequence_classification.mdx @@ -28,7 +28,7 @@ The task illustrated in this tutorial is supported by the following model archit -[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), [DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), 
[ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso) +[ALBERT](../model_doc/albert), [BART](../model_doc/bart), [BERT](../model_doc/bert), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CANINE](../model_doc/canine), [ConvBERT](../model_doc/convbert), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [DeBERTa](../model_doc/deberta), 
[DeBERTa-v2](../model_doc/deberta-v2), [DistilBERT](../model_doc/distilbert), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [ErnieM](../model_doc/ernie_m), [ESM](../model_doc/esm), [FlauBERT](../model_doc/flaubert), [FNet](../model_doc/fnet), [Funnel Transformer](../model_doc/funnel), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPT Neo](../model_doc/gpt_neo), [GPT-J](../model_doc/gptj), [I-BERT](../model_doc/ibert), [LayoutLM](../model_doc/layoutlm), [LayoutLMv2](../model_doc/layoutlmv2), [LayoutLMv3](../model_doc/layoutlmv3), [LED](../model_doc/led), [LiLT](../model_doc/lilt), [LLaMA](../model_doc/llama), [Longformer](../model_doc/longformer), [LUKE](../model_doc/luke), [MarkupLM](../model_doc/markuplm), [mBART](../model_doc/mbart), [Megatron-BERT](../model_doc/megatron-bert), [MobileBERT](../model_doc/mobilebert), [MPNet](../model_doc/mpnet), [MVP](../model_doc/mvp), [Nezha](../model_doc/nezha), [Nyströmformer](../model_doc/nystromformer), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Perceiver](../model_doc/perceiver), [PLBart](../model_doc/plbart), [QDQBert](../model_doc/qdqbert), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [SqueezeBERT](../model_doc/squeezebert), [TAPAS](../model_doc/tapas), [Transformer-XL](../model_doc/transfo-xl), [XLM](../model_doc/xlm), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod), [YOSO](../model_doc/yoso) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index c64febc41da8..c71a60409bbc 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1801,11 +1801,7 @@ ] ) _import_structure["models.llama"].extend( - [ - "LlamaForCausalLM", - "LlamaModel", - 
"LlamaPreTrainedModel", - ] + ["LlamaForCausalLM", "LlamaForSequenceClassification", "LlamaModel", "LlamaPreTrainedModel"] ) _import_structure["models.longformer"].extend( [ @@ -5198,11 +5194,7 @@ LiltModel, LiltPreTrainedModel, ) - from .models.llama import ( - LlamaForCausalLM, - LlamaModel, - LlamaPreTrainedModel, - ) + from .models.llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel from .models.longformer import ( LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, LongformerForMaskedLM, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 6da804733a5f..eff11b45a53f 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -652,6 +652,7 @@ ("layoutlmv3", "LayoutLMv3ForSequenceClassification"), ("led", "LEDForSequenceClassification"), ("lilt", "LiltForSequenceClassification"), + ("llama", "LlamaForSequenceClassification"), ("longformer", "LongformerForSequenceClassification"), ("luke", "LukeForSequenceClassification"), ("markuplm", "MarkupLMForSequenceClassification"), diff --git a/src/transformers/models/llama/__init__.py b/src/transformers/models/llama/__init__.py index 2501c282f0ba..adef4306f287 100644 --- a/src/transformers/models/llama/__init__.py +++ b/src/transformers/models/llama/__init__.py @@ -43,6 +43,7 @@ "LlamaForCausalLM", "LlamaModel", "LlamaPreTrainedModel", + "LlamaForSequenceClassification", ] @@ -63,11 +64,7 @@ except OptionalDependencyNotAvailable: pass else: - from .modeling_llama import ( - LlamaForCausalLM, - LlamaModel, - LlamaPreTrainedModel, - ) + from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel else: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 40a3f0f802a0..b7bc4ecea517 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ 
b/src/transformers/models/llama/modeling_llama.py @@ -24,19 +24,12 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ...modeling_utils import PreTrainedModel -from ...utils import ( - add_start_docstrings, - logging, - replace_return_docstrings, -) +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from .configuration_llama import LlamaConfig @@ -831,3 +824,105 @@ def _reorder_cache(past_key_values, beam_idx): for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past + + +class LlamaForSequenceClassification(LlamaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, 
SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = 
loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 6fb3d2d9d0bb..a80af49e2784 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -3768,6 +3768,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class LlamaForSequenceClassification(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class LlamaModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 23d573454da4..afc03faab8fc 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -27,7 +27,7 @@ if is_torch_available(): import torch - from transformers import LlamaForCausalLM, LlamaModel + from transformers import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel class LlamaModelTester: @@ -255,14 +255,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch class LlamaModelTest(ModelTesterMixin, unittest.TestCase): - all_model_classes = ( - ( - LlamaModel, - LlamaForCausalLM, - ) - if is_torch_available() - else () - ) + 
all_model_classes = (LlamaModel, LlamaForCausalLM, LlamaForSequenceClassification) if is_torch_available() else () all_generative_model_classes = (LlamaForCausalLM,) if is_torch_available() else () test_headmasking = False @@ -283,6 +276,33 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) + def test_llama_sequence_classification_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = LlamaForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + + def test_llama_sequence_classification_model_for_multi_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 3 + config.problem_type = "multi_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor( + [self.model_tester.batch_size, config.num_labels], self.model_tester.type_sequence_label_size + ).to(torch.float) + model = LlamaForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + @unittest.skip("LLaMA does not support head pruning.") def test_head_pruning(self): pass From da8b1b9c043dfb7f74871cfd5118c5c1bc088c49 Mon Sep 17 00:00:00 2001 From: lewtun Date: Thu, 16 Mar 2023 17:22:00 +0100 Subject: [PATCH 
2/7] Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- src/transformers/models/llama/modeling_llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b7bc4ecea517..f40008537264 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -836,8 +836,6 @@ def __init__(self, config): self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Model parallel - self.model_parallel = False - self.device_map = None # Initialize weights and apply final processing self.post_init() From 2da92aff8604e9ca4e51739f8904ddc4d9a35465 Mon Sep 17 00:00:00 2001 From: lewtun Date: Thu, 16 Mar 2023 18:10:26 +0100 Subject: [PATCH 3/7] Update src/transformers/models/llama/modeling_llama.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- src/transformers/models/llama/modeling_llama.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f40008537264..4de0ea4436af 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -835,8 +835,6 @@ def __init__(self, config): self.transformer = LlamaModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - # Model parallel - # Initialize weights and apply final processing self.post_init() From c214270dba7edb01e5ded977fb83a1832bf0e67c Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 16 Mar 2023 17:14:41 +0000 Subject: [PATCH 4/7] Add docstring --- src/transformers/models/llama/modeling_llama.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/llama/modeling_llama.py 
b/src/transformers/models/llama/modeling_llama.py index 4de0ea4436af..d9c0914e73e3 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -350,7 +350,7 @@ def forward( @add_start_docstrings( - "The bare OPT Model outputting raw hidden-states without any specific head on top.", + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", LLAMA_START_DOCSTRING, ) class LlamaPreTrainedModel(PreTrainedModel): @@ -826,6 +826,21 @@ def _reorder_cache(past_key_values, beam_idx): return reordered_past +@add_start_docstrings( + """ + The LLaMA Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it needs to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + LLAMA_START_DOCSTRING, +) class LlamaForSequenceClassification(LlamaPreTrainedModel): _keys_to_ignore_on_load_missing = [r"lm_head.weight"] From 6737e380fc6a4cb73150da4fa821dd463f9a7204 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 16 Mar 2023 17:22:21 +0000 Subject: [PATCH 5/7] Add test --- tests/models/llama/test_modeling_llama.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index afc03faab8fc..f8873feb13a8 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -288,6 +288,19 @@ def test_llama_sequence_classification_model(self): result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) + def test_llama_sequence_classification_model_for_single_label(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.num_labels = 1 + config.problem_type = "single_label_classification" + input_ids = input_dict["input_ids"] + attention_mask = input_ids.ne(1).to(torch_device) + sequence_labels = ids_tensor([self.model_tester.batch_size], self.model_tester.type_sequence_label_size) + model = LlamaForSequenceClassification(config) + model.to(torch_device) + model.eval() + result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, 1)) + def test_llama_sequence_classification_model_for_multi_label(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config.num_labels = 3 From c128bfc8ea94d6d1ed2b0ee23ad7846fc0314206 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 17 Mar 2023 10:19:05 +0000 Subject: [PATCH 6/7] Add input embedding getter and setter --- .../models/llama/modeling_llama.py | 10 ++++-- 
tests/models/llama/test_modeling_llama.py | 33 +++++++++++++++++-- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index d9c0914e73e3..5cdea1b787f6 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -847,12 +847,18 @@ class LlamaForSequenceClassification(LlamaPreTrainedModel): def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.transformer = LlamaModel(config) + self.model = LlamaModel(config) self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) # Initialize weights and apply final processing self.post_init() + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, @@ -874,7 +880,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - transformer_outputs = self.transformer( + transformer_outputs = self.model( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index f8873feb13a8..0d684c00cb71 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -276,6 +276,35 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) + # def test_inputs_embeds(self): + # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # for model_class in (LlamaModel,): + # model = model_class(config) + # model.to(torch_device) + # model.eval() + + # inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + # if not 
self.is_encoder_decoder: + # input_ids = inputs["input_ids"] + # del inputs["input_ids"] + # else: + # encoder_input_ids = inputs["input_ids"] + # decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + # del inputs["input_ids"] + # inputs.pop("decoder_input_ids", None) + + # wte = model.get_input_embeddings() + # if not self.is_encoder_decoder: + # inputs["inputs_embeds"] = wte(input_ids) + # else: + # inputs["inputs_embeds"] = wte(encoder_input_ids) + # inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + # with torch.no_grad(): + # model(**inputs)[0] + def test_llama_sequence_classification_model(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config.num_labels = 3 @@ -290,7 +319,7 @@ def test_llama_sequence_classification_model(self): def test_llama_sequence_classification_model_for_single_label(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.num_labels = 1 + config.num_labels = 3 config.problem_type = "single_label_classification" input_ids = input_dict["input_ids"] attention_mask = input_ids.ne(1).to(torch_device) @@ -299,7 +328,7 @@ def test_llama_sequence_classification_model_for_single_label(self): model.to(torch_device) model.eval() result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels) - self.assertEqual(result.logits.shape, (self.model_tester.batch_size, 1)) + self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels)) def test_llama_sequence_classification_model_for_multi_label(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() From 61ff387b44dffc4772ed41db7b4e0f5d31953f74 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 17 Mar 2023 10:20:19 +0000 Subject: [PATCH 7/7] Remove dead code --- tests/models/llama/test_modeling_llama.py | 29 ----------------------- 1 file changed, 29 deletions(-) diff --git 
a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 0d684c00cb71..dea92d5111fd 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -276,35 +276,6 @@ def test_model_various_embeddings(self): config_and_inputs[0].position_embedding_type = type self.model_tester.create_and_check_model(*config_and_inputs) - # def test_inputs_embeds(self): - # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - # for model_class in (LlamaModel,): - # model = model_class(config) - # model.to(torch_device) - # model.eval() - - # inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) - - # if not self.is_encoder_decoder: - # input_ids = inputs["input_ids"] - # del inputs["input_ids"] - # else: - # encoder_input_ids = inputs["input_ids"] - # decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) - # del inputs["input_ids"] - # inputs.pop("decoder_input_ids", None) - - # wte = model.get_input_embeddings() - # if not self.is_encoder_decoder: - # inputs["inputs_embeds"] = wte(input_ids) - # else: - # inputs["inputs_embeds"] = wte(encoder_input_ids) - # inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) - - # with torch.no_grad(): - # model(**inputs)[0] - def test_llama_sequence_classification_model(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() config.num_labels = 3