Merged
32 commits
1c7253c
modeling: add DistilBertForTokenClassification implementation
stefan-it Nov 11, 2019
1806eab
module: add DistilBertForTokenClassification import
stefan-it Nov 11, 2019
2b07b9e
examples: add DistilBert support for NER fine-tuning
stefan-it Nov 11, 2019
94e5525
tests: add test case for DistilBertForTokenClassification implementation
stefan-it Nov 11, 2019
2e31176
fix multi-gpu eval
ronakice Nov 12, 2019
74d0bcb
Fix special tokens addition in decoder
LysandreJik Nov 12, 2019
7627dde
sum() is the leanest method to flatten a string list, so it's been re…
iedmrc Nov 14, 2019
022525b
replace LambdaLR scheduler wrappers by function
rlouf Nov 12, 2019
e18f786
Quickstart example showcasing past
LysandreJik Nov 14, 2019
a67e747
Reorganized max_len warning
LysandreJik Nov 14, 2019
d792989
Specify checkpoint in saved file for run_lm_finetuning.py
LysandreJik Nov 14, 2019
2276bf6
update the examples, docs and template
rlouf Nov 14, 2019
8f8d697
[CI][DOC] Don't rebuild if folder exists.
LysandreJik Nov 14, 2019
be7f2aa
[CI][DOC] Don't rebuild if folder exists - Correct directory.
LysandreJik Nov 14, 2019
0be9ae7
Merge pull request #1833 from huggingface/max-length-warning
thomwolf Nov 14, 2019
df99f8c
Merge pull request #1832 from huggingface/memory-leak-schedulers
thomwolf Nov 14, 2019
1a237d7
Merge pull request #1831 from iedmrc/gpt2-tokenization-sum-func-repla…
thomwolf Nov 14, 2019
5b322a3
Merge pull request #1811 from huggingface/special-tokens
thomwolf Nov 14, 2019
9629e2c
Merge pull request #1804 from ronakice/master
thomwolf Nov 14, 2019
05db5bc
added small comparison between BERT, RoBERTa and DistilBERT
thomwolf Nov 14, 2019
74ce8de
Merge pull request #1792 from stefan-it/distilbert-for-token-classifi…
thomwolf Nov 14, 2019
14b3aa3
Add tokenization_camembert.py
louismartin Nov 9, 2019
6e72fd0
Add demo_camembert.py
louismartin Nov 9, 2019
e44b939
Add configuration_camembert.py and modeling_camembert.py
louismartin Nov 13, 2019
fb6c70a
Update tokenization_camembert.py with urls
louismartin Nov 13, 2019
f12e4d8
Move demo_camembert.py to examples/contrib
louismartin Nov 13, 2019
3e20c2e
Update demo_camembert.py with new classes
louismartin Nov 13, 2019
694d4fc
Add CamemBERT classes to __init__.py
louismartin Nov 13, 2019
035fea5
Add CamemBERT to auto files and docs
louismartin Nov 13, 2019
26858f2
[camembert] Upload to s3 + rename script
julien-c Nov 16, 2019
f9abf73
[camembert] realign w/ recent changes
julien-c Nov 16, 2019
0477b30
[camembert] tokenizer: use additional_special_tokens
julien-c Nov 16, 2019
8 changes: 6 additions & 2 deletions .circleci/deploy.sh
@@ -5,8 +5,12 @@ function deploy_doc(){
git checkout $1
if [ ! -z "$2" ]
then
echo "Pushing version" $2
make clean && make html && scp -r -oStrictHostKeyChecking=no _build/html $doc:$dir/$2
if [ -d "$dir/$2" ]; then
echo "Directory" $2 "already exists"
else
echo "Pushing version" $2
make clean && make html && scp -r -oStrictHostKeyChecking=no _build/html $doc:$dir/$2
fi
else
echo "Pushing master"
make clean && make html && scp -r -oStrictHostKeyChecking=no _build/html/* $doc:$dir
8 changes: 4 additions & 4 deletions README.md
@@ -520,12 +520,12 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch
# Parameters:
lr = 1e-3
max_grad_norm = 1.0
num_total_steps = 1000
num_training_steps = 1000
num_warmup_steps = 100
warmup_proportion = float(num_warmup_steps) / float(num_total_steps) # 0.1
warmup_proportion = float(num_warmup_steps) / float(num_training_steps) # 0.1

### Previously BertAdam optimizer was instantiated like this:
optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_total_steps)
optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_training_steps)
### and used like this:
for batch in train_data:
loss = model(batch)
@@ -534,7 +534,7 @@ for batch in train_data:

### In Transformers, optimizer and schedules are splitted and instantiated like this:
optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False) # To reproduce BertAdam specific behavior set correct_bias=False
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps) # PyTorch scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) # PyTorch scheduler
### and used like this:
for batch in train_data:
model.train()
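
For reference, here is a minimal, self-contained sketch of the new-style setup shown in the migration example above. The `torch.nn.Linear` model and the random batches are placeholders so the snippet runs on its own; they are not part of this PR.

```python
import torch
from transformers import AdamW, get_linear_schedule_with_warmup

# Placeholder model and data, only so the loop below is runnable.
model = torch.nn.Linear(10, 2)
train_data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(20)]

lr = 1e-3
max_grad_norm = 1.0
num_training_steps = len(train_data)
num_warmup_steps = 2

optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # correct_bias=False reproduces BertAdam behavior
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

for inputs, targets in train_data:
    model.train()
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    loss.backward()
    # Gradient clipping is no longer handled inside the optimizer.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    scheduler.step()  # update the learning rate schedule
    optimizer.zero_grad()
```
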
14 changes: 5 additions & 9 deletions docs/source/main_classes/optimizer_schedules.rst
@@ -18,37 +18,33 @@ Schedules
Learning Rate Schedules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: transformers.ConstantLRSchedule
:members:
.. autofunction:: transformers.get_constant_schedule


.. autoclass:: transformers.WarmupConstantSchedule
:members:
.. autofunction:: transformers.get_constant_schedule_with_warmup

.. image:: /imgs/warmup_constant_schedule.png
:target: /imgs/warmup_constant_schedule.png
:alt:


.. autoclass:: transformers.WarmupCosineSchedule
:members:
.. autofunction:: transformers.get_cosine_schedule_with_warmup

.. image:: /imgs/warmup_cosine_schedule.png
:target: /imgs/warmup_cosine_schedule.png
:alt:


.. autoclass:: transformers.WarmupCosineWithHardRestartsSchedule
:members:
.. autofunction:: transformers.get_cosine_with_hard_restarts_schedule_with_warmup

.. image:: /imgs/warmup_cosine_hard_restarts_schedule.png
:target: /imgs/warmup_cosine_hard_restarts_schedule.png
:alt:



.. autoclass:: transformers.WarmupLinearSchedule
:members:
.. autofunction:: transformers.get_linear_schedule_with_warmup

.. image:: /imgs/warmup_linear_schedule.png
:target: /imgs/warmup_linear_schedule.png
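
Since the schedule classes are now plain functions that return standard PyTorch `LambdaLR` schedulers, a small illustrative sketch of stepping one of them follows; the dummy model and the step counts are assumptions chosen for the example, not values from the docs.

```python
import torch
from transformers import AdamW, get_cosine_schedule_with_warmup

# Dummy parameters so the optimizer has something to schedule.
model = torch.nn.Linear(4, 4)
optimizer = AdamW(model.parameters(), lr=1e-3)

# Linear warmup over 10 steps, then cosine decay towards 0 over the remaining 90.
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=10, num_training_steps=100)

lrs = []
for _ in range(100):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs[0], lrs[9], lrs[-1])  # ramping up, at the peak learning rate, back near zero
```
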
8 changes: 4 additions & 4 deletions docs/source/migration.md
@@ -84,12 +84,12 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch
# Parameters:
lr = 1e-3
max_grad_norm = 1.0
num_total_steps = 1000
num_training_steps = 1000
num_warmup_steps = 100
warmup_proportion = float(num_warmup_steps) / float(num_total_steps) # 0.1
warmup_proportion = float(num_warmup_steps) / float(num_training_steps) # 0.1

### Previously BertAdam optimizer was instantiated like this:
optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_total_steps)
optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_training_steps)
### and used like this:
for batch in train_data:
loss = model(batch)
@@ -98,7 +98,7 @@ for batch in train_data:

### In Transformers, optimizer and schedules are splitted and instantiated like this:
optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False) # To reproduce BertAdam specific behavior set correct_bias=False
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps) # PyTorch scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) # PyTorch scheduler
### and used like this:
for batch in train_data:
loss = model(batch)
6 changes: 5 additions & 1 deletion docs/source/pretrained_models.rst
@@ -155,5 +155,9 @@ Here is the full list of the currently provided pretrained models together with
| CTRL | ``ctrl`` | | 48-layer, 1280-hidden, 16-heads, 1.6B parameters |
| | | | Salesforce's Large-sized CTRL English model |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
| CamemBERT | ``camembert-base`` | | 12-layer, 768-hidden, 12-heads, 110M parameters |
| | | | CamemBERT using the BERT-base architecture |
| | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/camembert>`__) |
+-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+

.. <https://huggingface.co/transformers/examples.html>`__
32 changes: 32 additions & 0 deletions docs/source/quickstart.md
@@ -188,3 +188,35 @@ assert predicted_text == 'Who was Jim Henson? Jim Henson was a man'
```

Examples for each model class of each model architecture (Bert, GPT, GPT-2, Transformer-XL, XLNet and XLM) can be found in the [documentation](#documentation).

#### Using the past

GPT-2, as well as some other models (GPT, XLNet, Transfo-XL, CTRL), makes use of a `past` or `mems` attribute that can be used to avoid re-computing the key/value pairs during sequential decoding. This is useful when generating sequences, since a large part of the attention computation can be reused from previous steps.

Here is a fully-working example using the `past` with `GPT2LMHeadModel` and argmax decoding (which should only be used as an example, as argmax decoding introduces a lot of repetition):

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained('gpt2')

generated = tokenizer.encode("The Manhattan bridge")
context = torch.tensor([generated])
past = None

for i in range(100):
    print(i)
    output, past = model(context, past=past)
    token = torch.argmax(output[..., -1, :])  # greedily pick the most likely next token

    generated += [token.tolist()]
    context = token.unsqueeze(0)

sequence = tokenizer.decode(generated)

print(sequence)
```

The model only requires a single token as input as all the previous tokens' key/value pairs are contained in the `past`.
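
As a quick sanity check (not part of the PR, but using the same 2.x call signature as the example above), the next-token logits are identical whether the whole sequence is fed at once or only the last token is fed together with `past`:

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

input_ids = torch.tensor([tokenizer.encode("The Manhattan bridge is")])

with torch.no_grad():
    # Without the cache: feed the whole sequence, keep the logits of the last position.
    full_logits = model(input_ids)[0][:, -1, :]

    # With the cache: feed the prefix once, then only the last token plus `past`.
    _, past = model(input_ids[:, :-1])
    cached_logits = model(input_ids[:, -1:], past=past)[0][:, -1, :]

print(torch.allclose(full_logits, cached_logits, atol=1e-4))  # expected: True
```
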
10 changes: 10 additions & 0 deletions examples/README.md
@@ -554,6 +554,16 @@ On the test dataset the following results could be achieved:
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
```

### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)

Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run):

| Model                     | F-Score Dev | F-Score Test |
| ------------------------- | ----------- | ------------ |
| `bert-large-cased`        | 95.59       | 91.70        |
| `roberta-large`           | 95.96       | 91.87        |
| `distilbert-base-uncased` | 94.34       | 90.32        |
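
Since this PR also introduces `DistilBertForTokenClassification`, here is a minimal sketch of loading it for token-level predictions. The CoNLL-style label list is illustrative only, and a freshly initialized classification head produces essentially random tags until it has been fine-tuned with `run_ner.py`:

```python
import torch
from transformers import DistilBertTokenizer, DistilBertForTokenClassification

# Illustrative CoNLL-2003-style label set; the real list comes from the NER data.
labels = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForTokenClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=len(labels))
model.eval()

input_ids = torch.tensor([tokenizer.encode("Hugging Face is based in New York City",
                                           add_special_tokens=True)])
with torch.no_grad():
    logits = model(input_ids)[0]        # (batch, seq_len, num_labels)

predictions = logits.argmax(dim=-1)[0]  # predicted label id for every (sub)token
print([labels[i] for i in predictions.tolist()])
```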

## Abstractive summarization

Based on the script
48 changes: 48 additions & 0 deletions examples/contrib/run_camembert.py
@@ -0,0 +1,48 @@
from pathlib import Path
import tarfile
import urllib.request

import torch

from transformers.tokenization_camembert import CamembertTokenizer
from transformers.modeling_camembert import CamembertForMaskedLM


def fill_mask(masked_input, model, tokenizer, topk=5):
    # Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
    assert masked_input.count('<mask>') == 1
    input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0)  # Batch size 1
    logits = model(input_ids)[0]  # The prediction scores are the first element of the output tuple
    masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
    logits = logits[0, masked_index, :]
    prob = logits.softmax(dim=0)
    values, indices = prob.topk(k=topk, dim=0)
    topk_predicted_token_bpe = ' '.join([tokenizer.convert_ids_to_tokens(indices[i].item())
                                         for i in range(len(indices))])
    masked_token = tokenizer.mask_token
    topk_filled_outputs = []
    for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(' ')):
        predicted_token = predicted_token_bpe.replace('\u2581', ' ')
        if " {0}".format(masked_token) in masked_input:
            topk_filled_outputs.append((
                masked_input.replace(
                    ' {0}'.format(masked_token), predicted_token
                ),
                values[index].item(),
                predicted_token,
            ))
        else:
            topk_filled_outputs.append((
                masked_input.replace(masked_token, predicted_token),
                values[index].item(),
                predicted_token,
            ))
    return topk_filled_outputs


tokenizer = CamembertTokenizer.from_pretrained('camembert-base')
model = CamembertForMaskedLM.from_pretrained('camembert-base')
model.eval()

masked_input = "Le camembert est <mask> :)"
print(fill_mask(masked_input, model, tokenizer, topk=3))
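
Continuing from the script above (an illustrative aside, not part of the file), `fill_mask` returns a list of `(filled_sentence, probability, predicted_token)` tuples ordered by decreasing probability, so the result can be unpacked directly:

```python
# Assumes `model`, `tokenizer` and `fill_mask` from the script above are already defined.
for sentence, prob, token in fill_mask("Le camembert est <mask> :)", model, tokenizer, topk=3):
    print("{:.3f}  {:<12}  {}".format(prob, token, sentence))
```
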
4 changes: 2 additions & 2 deletions examples/contrib/run_openai_gpt.py
@@ -41,7 +41,7 @@

from transformers import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
AdamW, cached_path, WEIGHTS_NAME, CONFIG_NAME,
WarmupLinearSchedule)
get_linear_schedule_with_warmup)

ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"

@@ -211,7 +211,7 @@ def tokenize_and_encode(obj):
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

if args.do_train:
nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
4 changes: 2 additions & 2 deletions examples/contrib/run_swag.py
@@ -42,7 +42,7 @@
from transformers import (WEIGHTS_NAME, BertConfig,
BertForMultipleChoice, BertTokenizer)

from transformers import AdamW, WarmupLinearSchedule
from transformers import AdamW, get_linear_schedule_with_warmup

logger = logging.getLogger(__name__)

@@ -322,7 +322,7 @@ def train(args, train_dataset, model, tokenizer):
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
if args.fp16:
try:
from apex import amp
8 changes: 4 additions & 4 deletions examples/distillation/distiller.py
@@ -35,7 +35,7 @@
except:
from tensorboardX import SummaryWriter

from transformers import WarmupLinearSchedule
from transformers import get_linear_schedule_with_warmup

from utils import logger
from lm_seqs_dataset import LmSeqsDataset
@@ -137,9 +137,9 @@ def __init__(self,
betas=(0.9, 0.98))

warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
self.scheduler = WarmupLinearSchedule(self.optimizer,
warmup_steps=warmup_steps,
t_total=num_train_optimization_steps)
self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
num_warmup_steps=warmup_steps,
num_training_steps=num_train_optimization_steps)

if self.fp16:
try:
4 changes: 2 additions & 2 deletions examples/distillation/run_squad_w_distillation.py
@@ -46,7 +46,7 @@
XLNetTokenizer,
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)

from transformers import AdamW, WarmupLinearSchedule
from transformers import AdamW, get_linear_schedule_with_warmup

from ..utils_squad import (read_squad_examples, convert_examples_to_features,
RawResult, write_predictions,
@@ -101,7 +101,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
if args.fp16:
try:
from apex import amp
8 changes: 6 additions & 2 deletions examples/run_glue.py
@@ -49,7 +49,7 @@
DistilBertForSequenceClassification,
DistilBertTokenizer)

from transformers import AdamW, WarmupLinearSchedule
from transformers import AdamW, get_linear_schedule_with_warmup

from transformers import glue_compute_metrics as compute_metrics
from transformers import glue_output_modes as output_modes
@@ -100,7 +100,7 @@ def train(args, train_dataset, model, tokenizer):
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
if args.fp16:
try:
from apex import amp
@@ -224,6 +224,10 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

# multi-gpu eval
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)

# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
14 changes: 9 additions & 5 deletions examples/run_lm_finetuning.py
@@ -42,7 +42,7 @@

from tqdm import tqdm, trange

from transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
BertConfig, BertForMaskedLM, BertTokenizer,
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
@@ -63,10 +63,10 @@


class TextDataset(Dataset):
def __init__(self, tokenizer, file_path='train', block_size=512):
def __init__(self, tokenizer, args, file_path='train', block_size=512):
assert os.path.isfile(file_path)
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(directory, 'cached_lm_' + str(block_size) + '_' + filename)
cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_' + filename)

if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
@@ -99,7 +99,7 @@ def __getitem__(self, item):


def load_and_cache_examples(args, tokenizer, evaluate=False):
dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
dataset = TextDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size)
return dataset


@@ -185,7 +185,7 @@ def train(args, train_dataset, model, tokenizer):
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
if args.fp16:
try:
from apex import amp
@@ -300,6 +300,10 @@ def evaluate(args, model, tokenizer, prefix=""):
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

# multi-gpu evaluate
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)

# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))