-
Notifications
You must be signed in to change notification settings - Fork 33.5k
Add LlamaForSequenceClassification #22209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
0cd1058
da8b1b9
2da92af
c214270
6737e38
97095bf
c128bfc
61ff387
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -24,19 +24,12 @@ | |||
| import torch | ||||
| import torch.utils.checkpoint | ||||
| from torch import nn | ||||
| from torch.nn import CrossEntropyLoss | ||||
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | ||||
|
|
||||
| from ...activations import ACT2FN | ||||
| from ...modeling_outputs import ( | ||||
| BaseModelOutputWithPast, | ||||
| CausalLMOutputWithPast, | ||||
| ) | ||||
| from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast | ||||
| from ...modeling_utils import PreTrainedModel | ||||
| from ...utils import ( | ||||
| add_start_docstrings, | ||||
| logging, | ||||
| replace_return_docstrings, | ||||
| ) | ||||
| from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings | ||||
| from .configuration_llama import LlamaConfig | ||||
|
|
||||
|
|
||||
|
|
@@ -831,3 +824,105 @@ def _reorder_cache(past_key_values, beam_idx): | |||
| for layer_past in past_key_values: | ||||
| reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) | ||||
| return reordered_past | ||||
|
|
||||
|
|
||||
| class LlamaForSequenceClassification(LlamaPreTrainedModel): | ||||
| _keys_to_ignore_on_load_missing = [r"lm_head.weight"] | ||||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not entirely sure which other keys should be ignored for the LLaMa implementation - some guidance here would be appreciated!
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think that it should be similar to the ones in |
||||
|
|
||||
| def __init__(self, config): | ||||
| super().__init__(config) | ||||
| self.num_labels = config.num_labels | ||||
| self.transformer = LlamaModel(config) | ||||
| self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) | ||||
|
|
||||
| # Model parallel | ||||
|
lewtun marked this conversation as resolved.
Outdated
|
||||
| self.model_parallel = False | ||||
| self.device_map = None | ||||
|
lewtun marked this conversation as resolved.
Outdated
|
||||
|
|
||||
| # Initialize weights and apply final processing | ||||
| self.post_init() | ||||
|
|
||||
| @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) | ||||
| def forward( | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If I am not mistaken, this function is copied from
If that's the case, could you add a "Copied from" statement 🙏?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, it's not actually copied from there because LLaMa doesn't accept |
||||
| self, | ||||
| input_ids: torch.LongTensor = None, | ||||
| attention_mask: Optional[torch.Tensor] = None, | ||||
| past_key_values: Optional[List[torch.FloatTensor]] = None, | ||||
| inputs_embeds: Optional[torch.FloatTensor] = None, | ||||
| labels: Optional[torch.LongTensor] = None, | ||||
| use_cache: Optional[bool] = None, | ||||
| output_attentions: Optional[bool] = None, | ||||
| output_hidden_states: Optional[bool] = None, | ||||
| return_dict: Optional[bool] = None, | ||||
| ) -> Union[Tuple, SequenceClassifierOutputWithPast]: | ||||
| r""" | ||||
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | ||||
| Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., | ||||
| config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If | ||||
| `config.num_labels > 1` a classification loss is computed (Cross-Entropy). | ||||
| """ | ||||
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||||
|
|
||||
| transformer_outputs = self.transformer( | ||||
| input_ids, | ||||
| past_key_values=past_key_values, | ||||
| attention_mask=attention_mask, | ||||
| inputs_embeds=inputs_embeds, | ||||
| use_cache=use_cache, | ||||
| output_attentions=output_attentions, | ||||
| output_hidden_states=output_hidden_states, | ||||
| return_dict=return_dict, | ||||
| ) | ||||
| hidden_states = transformer_outputs[0] | ||||
| logits = self.score(hidden_states) | ||||
|
|
||||
| if input_ids is not None: | ||||
| batch_size = input_ids.shape[0] | ||||
| else: | ||||
| batch_size = inputs_embeds.shape[0] | ||||
|
|
||||
| if self.config.pad_token_id is None and batch_size != 1: | ||||
| raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") | ||||
| if self.config.pad_token_id is None: | ||||
| sequence_lengths = -1 | ||||
| else: | ||||
| if input_ids is not None: | ||||
| sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) | ||||
| else: | ||||
| sequence_lengths = -1 | ||||
|
|
||||
| pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] | ||||
|
|
||||
| loss = None | ||||
| if labels is not None: | ||||
| if self.config.problem_type is None: | ||||
| if self.num_labels == 1: | ||||
| self.config.problem_type = "regression" | ||||
| elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): | ||||
| self.config.problem_type = "single_label_classification" | ||||
| else: | ||||
| self.config.problem_type = "multi_label_classification" | ||||
|
|
||||
| if self.config.problem_type == "regression": | ||||
| loss_fct = MSELoss() | ||||
| if self.num_labels == 1: | ||||
| loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) | ||||
| else: | ||||
| loss = loss_fct(pooled_logits, labels) | ||||
| elif self.config.problem_type == "single_label_classification": | ||||
| loss_fct = CrossEntropyLoss() | ||||
| loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) | ||||
| elif self.config.problem_type == "multi_label_classification": | ||||
| loss_fct = BCEWithLogitsLoss() | ||||
| loss = loss_fct(pooled_logits, labels) | ||||
| if not return_dict: | ||||
| output = (pooled_logits,) + transformer_outputs[1:] | ||||
| return ((loss,) + output) if loss is not None else output | ||||
|
|
||||
| return SequenceClassifierOutputWithPast( | ||||
| loss=loss, | ||||
| logits=pooled_logits, | ||||
| past_key_values=transformer_outputs.past_key_values, | ||||
| hidden_states=transformer_outputs.hidden_states, | ||||
| attentions=transformer_outputs.attentions, | ||||
| ) | ||||
Uh oh!
There was an error while loading. Please reload this page.