-
Notifications
You must be signed in to change notification settings - Fork 33.4k
Add LlamaForSequenceClassification #22209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0cd1058
da8b1b9
2da92af
c214270
6737e38
97095bf
c128bfc
61ff387
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -24,19 +24,12 @@ | |||
| import torch | ||||
| import torch.utils.checkpoint | ||||
| from torch import nn | ||||
| from torch.nn import CrossEntropyLoss | ||||
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | ||||
|
|
||||
| from ...activations import ACT2FN | ||||
| from ...modeling_outputs import ( | ||||
| BaseModelOutputWithPast, | ||||
| CausalLMOutputWithPast, | ||||
| ) | ||||
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast | ||||
| from ...modeling_utils import PreTrainedModel | ||||
| from ...utils import ( | ||||
| add_start_docstrings, | ||||
| logging, | ||||
| replace_return_docstrings, | ||||
| ) | ||||
| from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings | ||||
| from .configuration_llama import LlamaConfig | ||||
|
|
||||
|
|
||||
|
|
@@ -357,7 +350,7 @@ def forward( | |||
|
|
||||
|
|
||||
| @add_start_docstrings( | ||||
| "The bare OPT Model outputting raw hidden-states without any specific head on top.", | ||||
| "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", | ||||
| LLAMA_START_DOCSTRING, | ||||
| ) | ||||
| class LlamaPreTrainedModel(PreTrainedModel): | ||||
|
|
@@ -831,3 +824,122 @@ def _reorder_cache(past_key_values, beam_idx): | |||
| for layer_past in past_key_values: | ||||
| reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) | ||||
| return reordered_past | ||||
|
|
||||
|
|
||||
| @add_start_docstrings( | ||||
| """ | ||||
| The LLaMa Model transformer with a sequence classification head on top (linear layer). | ||||
|
|
||||
| [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models | ||||
| (e.g. GPT-2) do. | ||||
|
|
||||
| Since it does classification on the last token, it requires to know the position of the last token. If a | ||||
| `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If | ||||
| no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the | ||||
| padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in | ||||
| each row of the batch). | ||||
| """, | ||||
| LLAMA_START_DOCSTRING, | ||||
| ) | ||||
| class LlamaForSequenceClassification(LlamaPreTrainedModel): | ||||
| _keys_to_ignore_on_load_missing = [r"lm_head.weight"] | ||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not entirely sure which other keys should be ignored for the LLaMa implementation - some guidance here is appreciated!
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that it should be similar to the ones in |
||||
|
|
||||
| def __init__(self, config): | ||||
| super().__init__(config) | ||||
| self.num_labels = config.num_labels | ||||
| self.model = LlamaModel(config) | ||||
| self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) | ||||
|
|
||||
| # Initialize weights and apply final processing | ||||
| self.post_init() | ||||
|
|
||||
| def get_input_embeddings(self): | ||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FYI @younesbelkada @amyeroberts these getter/setters were missing in my original implementation, but now all tests pass locally for me and this should be good to go IMO :) |
||||
| return self.model.embed_tokens | ||||
|
|
||||
| def set_input_embeddings(self, value): | ||||
| self.model.embed_tokens = value | ||||
|
|
||||
| @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) | ||||
| def forward( | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I am not mistaken this function is copied from
If that's the case could you add a copied form statement 🙏 ?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, it's not actually copied from there because LLaMa doesn't accept |
||||
| self, | ||||
| input_ids: torch.LongTensor = None, | ||||
| attention_mask: Optional[torch.Tensor] = None, | ||||
| past_key_values: Optional[List[torch.FloatTensor]] = None, | ||||
| inputs_embeds: Optional[torch.FloatTensor] = None, | ||||
| labels: Optional[torch.LongTensor] = None, | ||||
| use_cache: Optional[bool] = None, | ||||
| output_attentions: Optional[bool] = None, | ||||
| output_hidden_states: Optional[bool] = None, | ||||
| return_dict: Optional[bool] = None, | ||||
| ) -> Union[Tuple, SequenceClassifierOutputWithPast]: | ||||
| r""" | ||||
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | ||||
| Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., | ||||
| config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If | ||||
| `config.num_labels > 1` a classification loss is computed (Cross-Entropy). | ||||
| """ | ||||
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||||
|
|
||||
| transformer_outputs = self.model( | ||||
| input_ids, | ||||
| past_key_values=past_key_values, | ||||
| attention_mask=attention_mask, | ||||
| inputs_embeds=inputs_embeds, | ||||
| use_cache=use_cache, | ||||
| output_attentions=output_attentions, | ||||
| output_hidden_states=output_hidden_states, | ||||
| return_dict=return_dict, | ||||
| ) | ||||
| hidden_states = transformer_outputs[0] | ||||
| logits = self.score(hidden_states) | ||||
|
|
||||
| if input_ids is not None: | ||||
| batch_size = input_ids.shape[0] | ||||
| else: | ||||
| batch_size = inputs_embeds.shape[0] | ||||
|
|
||||
| if self.config.pad_token_id is None and batch_size != 1: | ||||
| raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") | ||||
| if self.config.pad_token_id is None: | ||||
| sequence_lengths = -1 | ||||
| else: | ||||
| if input_ids is not None: | ||||
| sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) | ||||
| else: | ||||
| sequence_lengths = -1 | ||||
|
|
||||
| pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] | ||||
|
|
||||
| loss = None | ||||
| if labels is not None: | ||||
| if self.config.problem_type is None: | ||||
| if self.num_labels == 1: | ||||
| self.config.problem_type = "regression" | ||||
| elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): | ||||
| self.config.problem_type = "single_label_classification" | ||||
| else: | ||||
| self.config.problem_type = "multi_label_classification" | ||||
|
|
||||
| if self.config.problem_type == "regression": | ||||
| loss_fct = MSELoss() | ||||
| if self.num_labels == 1: | ||||
| loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) | ||||
| else: | ||||
| loss = loss_fct(pooled_logits, labels) | ||||
| elif self.config.problem_type == "single_label_classification": | ||||
| loss_fct = CrossEntropyLoss() | ||||
| loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) | ||||
| elif self.config.problem_type == "multi_label_classification": | ||||
| loss_fct = BCEWithLogitsLoss() | ||||
| loss = loss_fct(pooled_logits, labels) | ||||
| if not return_dict: | ||||
| output = (pooled_logits,) + transformer_outputs[1:] | ||||
| return ((loss,) + output) if loss is not None else output | ||||
|
|
||||
| return SequenceClassifierOutputWithPast( | ||||
| loss=loss, | ||||
| logits=pooled_logits, | ||||
| past_key_values=transformer_outputs.past_key_values, | ||||
| hidden_states=transformer_outputs.hidden_states, | ||||
| attentions=transformer_outputs.attentions, | ||||
| ) | ||||
Uh oh!
There was an error while loading. Please reload this page.