1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -333,6 +333,7 @@
ElectraForMaskedLM,
ElectraForTokenClassification,
ElectraPreTrainedModel,
ElectraForMultipleChoice,
ElectraForSequenceClassification,
ElectraForQuestionAnswering,
ElectraModel,
30 changes: 30 additions & 0 deletions src/transformers/configuration_electra.py
@@ -76,6 +76,27 @@ class ElectraConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
The epsilon used by the layer normalization layers.
summary_type (:obj:`string`, optional, defaults to "first"):
Argument used when doing sequence summary. Used for the multiple choice head in
:class:`~transformers.ElectraForMultipleChoice`.
Is one of the following options:
- 'last' => take the last token hidden state (like XLNet)
- 'first' => take the first token hidden state (like BERT)
- 'mean' => take the mean of all tokens' hidden states
- 'cls_index' => supply a Tensor of classification token positions (GPT/GPT-2)
- 'attn' => not implemented now, use multi-head attention
summary_use_proj (:obj:`boolean`, optional, defaults to :obj:`True`):
Argument used when doing sequence summary. Used for the multiple choice head in
:class:`~transformers.ElectraForMultipleChoice`.
Whether to add a projection after the vector extraction.
summary_activation (:obj:`string` or :obj:`None`, optional, defaults to "gelu"):
Argument used when doing sequence summary. Used for the multiple choice head in
:class:`~transformers.ElectraForMultipleChoice`.
'gelu' => add a gelu activation to the output, Other => no activation.
summary_last_dropout (:obj:`float`, optional, defaults to 0.1):
Argument used when doing sequence summary. Used for the multiple choice head in
:class:`~transformers.ElectraForMultipleChoice`.
Dropout to add after the projection and activation.

Example::

@@ -107,6 +128,10 @@ def __init__(
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
summary_type="first",
summary_use_proj=True,
summary_activation="gelu",
summary_last_dropout=0.1,
pad_token_id=0,
**kwargs
):
@@ -125,3 +150,8 @@ def __init__(
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps

self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_last_dropout = summary_last_dropout
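
As a quick illustration of what the new options above do (not part of the diff): the four summary_* values are read by SequenceSummary, which pools the discriminator's sequence output before the multiple-choice classifier. A minimal sketch, assuming the defaults added in this file:

    from transformers import ElectraConfig

    # Illustrative configuration mirroring the new defaults; SequenceSummary
    # consumes these fields to build the pooling head used by
    # ElectraForMultipleChoice.
    config = ElectraConfig(
        summary_type="first",        # pool the first ([CLS]) token's hidden state
        summary_use_proj=True,       # add a linear projection after pooling
        summary_activation="gelu",   # activation applied to the projected vector
        summary_last_dropout=0.1,    # dropout after projection and activation
    )
    print(config.summary_type, config.summary_last_dropout)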
2 changes: 2 additions & 0 deletions src/transformers/modeling_auto.py
@@ -87,6 +87,7 @@
)
from .modeling_electra import (
ElectraForMaskedLM,
ElectraForMultipleChoice,
ElectraForPreTraining,
ElectraForQuestionAnswering,
ElectraForSequenceClassification,
@@ -315,6 +316,7 @@
MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
[
(CamembertConfig, CamembertForMultipleChoice),
(ElectraConfig, ElectraForMultipleChoice),
(XLMRobertaConfig, XLMRobertaForMultipleChoice),
(LongformerConfig, LongformerForMultipleChoice),
(RobertaConfig, RobertaForMultipleChoice),
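
With ElectraConfig registered in MODEL_FOR_MULTIPLE_CHOICE_MAPPING, the auto factory can resolve an ELECTRA checkpoint to the new head. A small sketch of the expected dispatch, assuming AutoModelForMultipleChoice is exposed in this version of the library:

    from transformers import AutoModelForMultipleChoice

    # The mapping added above lets the auto class pick ElectraForMultipleChoice
    # whenever the checkpoint's configuration is an ElectraConfig.
    model = AutoModelForMultipleChoice.from_pretrained("google/electra-base-discriminator")
    print(type(model).__name__)  # expected: ElectraForMultipleChoice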
113 changes: 113 additions & 0 deletions src/transformers/modeling_electra.py
@@ -10,6 +10,7 @@
from .configuration_electra import ElectraConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_bert import BertEmbeddings, BertEncoder, BertLayerNorm, BertPreTrainedModel
from .modeling_utils import SequenceSummary


logger = logging.getLogger(__name__)
@@ -860,3 +861,115 @@ def forward(
outputs = (total_loss,) + outputs

return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)


@add_start_docstrings(
"""ELECTRA Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
ELECTRA_INPUTS_DOCSTRING,
)
class ElectraForMultipleChoice(ElectraPreTrainedModel):
def __init__(self, config):
super().__init__(config)

self.electra = ElectraModel(config)
self.summary = SequenceSummary(config)
self.classifier = nn.Linear(config.hidden_size, 1)

self.init_weights()

@add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices-1]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)

Returns:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.ElectraConfig`) and inputs:
loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when :obj:`labels` is provided):
Classification loss.
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).

Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.

Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

Attention weights after the attention softmax, used to compute the weighted average in the self-attention
heads.

Examples::

from transformers import ElectraTokenizer, ElectraForMultipleChoice
import torch

tokenizer = ElectraTokenizer.from_pretrained('google/electra-base-discriminator')
model = ElectraForMultipleChoice.from_pretrained('google/electra-base-discriminator')

prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
choice0 = "It is eaten with a fork and a knife."
choice1 = "It is eaten while held in the hand."
labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size is 1

encoding = tokenizer.batch_encode_plus([[prompt, choice0], [prompt, choice1]], return_tensors='pt', pad_to_max_length=True)
outputs = model(**{k: v.unsqueeze(0) for k,v in encoding.items()}, labels=labels) # batch size is 1

# the linear classifier still needs to be trained
loss, logits = outputs[:2]
"""
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (
inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
if inputs_embeds is not None
else None
)

discriminator_hidden_states = self.electra(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
)

sequence_output = discriminator_hidden_states[0]

pooled_output = self.summary(sequence_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, num_choices)

outputs = (reshaped_logits,) + discriminator_hidden_states[1:] # add hidden states and attention if they are here

if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
outputs = (loss,) + outputs

return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
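
The reshaping in the forward pass above is the usual multiple-choice trick: choices are folded into the batch dimension, each (example, choice) pair is scored with a single output unit, and the scores are unfolded so CrossEntropyLoss can compare choices per example. A shape walk-through under assumed toy dimensions (the random sequence_output stands in for ELECTRA's real output):

    import torch

    batch_size, num_choices, seq_len, hidden_size = 2, 4, 16, 256
    input_ids = torch.zeros(batch_size, num_choices, seq_len, dtype=torch.long)

    flat_input_ids = input_ids.view(-1, input_ids.size(-1))                       # (8, 16): choices folded into the batch
    sequence_output = torch.randn(flat_input_ids.size(0), seq_len, hidden_size)   # (8, 16, 256): stand-in for the discriminator output
    pooled_output = sequence_output[:, 0]                                         # (8, 256): summary_type="first" keeps token 0
    logits = torch.nn.Linear(hidden_size, 1)(pooled_output)                       # (8, 1): one score per (example, choice)
    reshaped_logits = logits.view(-1, num_choices)                                # (2, 4): scores regrouped per example
    loss = torch.nn.CrossEntropyLoss()(reshaped_logits, torch.tensor([0, 2]))     # labels index the correct choice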
36 changes: 36 additions & 0 deletions tests/test_modeling_electra.py
@@ -30,6 +30,7 @@
ElectraForMaskedLM,
ElectraForTokenClassification,
ElectraForPreTraining,
ElectraForMultipleChoice,
ElectraForSequenceClassification,
ElectraForQuestionAnswering,
)
@@ -266,6 +267,37 @@ def create_and_check_electra_for_question_answering(
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)

def create_and_check_electra_for_multiple_choice(
self,
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
fake_token_labels,
):
config.num_choices = self.num_choices
model = ElectraForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)

def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -329,6 +361,10 @@ def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_question_answering(*config_and_inputs)

def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_electra_for_multiple_choice(*config_and_inputs)

@slow
def test_model_from_pretrained(self):
for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
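
To exercise only the new test locally, a pytest keyword filter along these lines should work (pytest is what the repository's test suite uses; the -k expression is just a convenience filter):

    python -m pytest tests/test_modeling_electra.py -k "multiple_choice" -v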