10 changes: 10 additions & 0 deletions examples/README.md
@@ -554,6 +554,16 @@ On the test dataset the following results could be achieved:
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
```

### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)

Here is a small comparison of BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased), each fine-tuned with the hyperparameters specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (single run):

| Model                     | F-Score (Dev) | F-Score (Test) |
| ------------------------- | ------------- | -------------- |
| `bert-large-cased`        | 95.59         | 91.70          |
| `roberta-large`           | 95.96         | 91.87          |
| `distilbert-base-uncased` | 94.34         | 90.32          |
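
The dev/test F-scores above are entity-level metrics as reported by the example script via `seqeval`. A toy sketch of how such scores are computed (the gold and predicted tag sequences below are made up and are not part of this PR):

```python
from seqeval.metrics import f1_score, precision_score, recall_score

# Made-up gold and predicted tag sequences, one inner list per sentence.
y_true = [["B-PER", "I-PER", "O", "B-LOC"], ["O", "B-ORG", "O"]]
y_pred = [["B-PER", "I-PER", "O", "B-LOC"], ["O", "B-PER", "O"]]

print("precision =", precision_score(y_true, y_pred))
print("recall    =", recall_score(y_true, y_pred))
print("f1        =", f1_score(y_true, y_pred))
```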

## Abstractive summarization

Based on the script
16 changes: 10 additions & 6 deletions examples/run_ner.py
@@ -36,16 +36,18 @@
from transformers import AdamW, WarmupLinearSchedule
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer

logger = logging.getLogger(__name__)

ALL_MODELS = sum(
    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig)),
    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
    ())

MODEL_CLASSES = {
    "bert": (BertConfig, BertForTokenClassification, BertTokenizer),
    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer)
    "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
    "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
}


@@ -121,9 +123,10 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {"input_ids": batch[0],
                      "attention_mask": batch[1],
                      "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
                      # XLM and RoBERTa don't use segment_ids
                      "labels": batch[3]}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None  # XLM and RoBERTa don't use segment_ids

            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)

@@ -206,9 +209,9 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
        with torch.no_grad():
            inputs = {"input_ids": batch[0],
                      "attention_mask": batch[1],
                      "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None,
                      # XLM and RoBERTa don't use segment_ids
                      "labels": batch[3]}
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None  # XLM and RoBERTa don't use segment_ids
            outputs = model(**inputs)
            tmp_eval_loss, logits = outputs[:2]

@@ -520,3 +523,4 @@ def main():

if __name__ == "__main__":
    main()
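
Why the new `model_type != "distilbert"` guard exists: DistilBERT's `forward()` takes no `token_type_ids` argument, so the segment ids from the batch must be dropped entirely for it, while BERT/XLNet still receive them and XLM/RoBERTa get `None`. A minimal sketch of that dispatch (the `build_inputs` helper is hypothetical and not part of run_ner.py):

```python
def build_inputs(batch, model_type):
    # Hypothetical helper illustrating the input handling in train()/evaluate() above.
    inputs = {"input_ids": batch[0],
              "attention_mask": batch[1],
              "labels": batch[3]}
    if model_type != "distilbert":
        # DistilBERT's forward() has no token_type_ids parameter at all;
        # XLM and RoBERTa accept the argument but don't use segment ids, hence None.
        inputs["token_type_ids"] = batch[2] if model_type in ["bert", "xlnet"] else None
    return inputs
```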

1 change: 1 addition & 0 deletions transformers/__init__.py
@@ -93,6 +93,7 @@
                               ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
                                  DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
                                  DistilBertForTokenClassification,
                                  DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model

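With the export above, the new head becomes importable from the package root. A minimal check, assuming a transformers build that includes this change (the 9-label setup mirrors CoNLL-2003 NER and is only an example):

```python
from transformers import DistilBertConfig, DistilBertForTokenClassification

config = DistilBertConfig(num_labels=9)           # e.g. the 9 CoNLL-2003 NER tags
model = DistilBertForTokenClassification(config)  # randomly initialised classifier head
print(model.num_labels)                           # -> 9
```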
73 changes: 73 additions & 0 deletions transformers/modeling_distilbert.py
@@ -30,6 +30,7 @@

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from .modeling_utils import PreTrainedModel, prune_linear_layer
from .configuration_distilbert import DistilBertConfig
@@ -702,3 +703,75 @@ def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_em
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)


@add_start_docstrings("""DistilBert Model with a token classification head on top (a linear layer on top of
                         the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
                      DISTILBERT_START_DOCSTRING,
                      DISTILBERT_INPUTS_DOCSTRING)
class DistilBertForTokenClassification(DistilBertPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification loss.
        **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = DistilBertForTokenClassification.from_pretrained('distilbert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, scores = outputs[:2]

    """
    def __init__(self, config):
        super(DistilBertForTokenClassification, self).__init__(config)
        self.num_labels = config.num_labels

        self.distilbert = DistilBertModel(config)
        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, head_mask=None,
                inputs_embeds=None, labels=None):

        outputs = self.distilbert(input_ids,
                                  attention_mask=attention_mask,
                                  head_mask=head_mask,
                                  inputs_embeds=inputs_embeds)

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), scores, (hidden_states), (attentions)
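
The masking in the loss above restricts the cross-entropy to positions where `attention_mask == 1`, so padding tokens never contribute to the gradient. A standalone toy illustration of that flatten-and-mask step (shapes and tensors below are made up, not taken from the diff):

```python
import torch
from torch.nn import CrossEntropyLoss

batch_size, seq_len, num_labels = 2, 4, 3
logits = torch.randn(batch_size, seq_len, num_labels)        # as produced by the classifier
labels = torch.randint(0, num_labels, (batch_size, seq_len))  # token-level gold labels
attention_mask = torch.tensor([[1, 1, 1, 0],                  # last position is padding
                               [1, 1, 0, 0]])                 # last two positions are padding

loss_fct = CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1                # (batch_size * seq_len,) bool mask
active_logits = logits.view(-1, num_labels)[active_loss]  # keep only real tokens
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)             # padded positions are excluded
print(loss.item())
```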
20 changes: 20 additions & 0 deletions transformers/tests/modeling_distilbert_test.py
@@ -23,6 +23,7 @@

if is_torch_available():
    from transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM,
                              DistilBertForTokenClassification,
                              DistilBertForQuestionAnswering, DistilBertForSequenceClassification)
else:
    pytestmark = pytest.mark.skip("Require Torch")
@@ -180,6 +181,21 @@ def create_and_check_distilbert_for_sequence_classification(self, config, input_
                [self.batch_size, self.num_labels])
            self.check_loss_output(result)

        def create_and_check_distilbert_for_token_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
            config.num_labels = self.num_labels
            model = DistilBertForTokenClassification(config=config)
            model.eval()

            loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels)
            result = {
                "loss": loss,
                "logits": logits,
            }
            self.parent.assertListEqual(
                list(result["logits"].size()),
                [self.batch_size, self.seq_length, self.num_labels])
            self.check_loss_output(result)

        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
            (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
@@ -209,6 +225,10 @@ def test_for_sequence_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)

    # @pytest.mark.slow
    # def test_model_from_pretrained(self):
    #     cache_dir = "/tmp/transformers_test/"
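
Beyond the shape test above, the new head can be exercised end to end. A hedged inference sketch: `distilbert-base-uncased` ships no fine-tuned NER classifier, so the predictions come from a randomly initialised head, and the label list is an assumption (CoNLL-2003 style), not something defined in this PR:

```python
import torch
from transformers import DistilBertTokenizer, DistilBertForTokenClassification

LABELS = ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForTokenClassification.from_pretrained("distilbert-base-uncased",
                                                         num_labels=len(LABELS))
model.eval()

input_ids = torch.tensor(tokenizer.encode("Hugging Face is based in New York")).unsqueeze(0)
with torch.no_grad():
    logits = model(input_ids)[0]                 # (1, seq_len, num_labels)

predictions = logits.argmax(dim=-1).squeeze(0)   # predicted label id per (sub)token
tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze(0).tolist())
for token, label_id in zip(tokens, predictions.tolist()):
    print(token, LABELS[label_id])
```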