diff --git a/.circleci/deploy.sh b/.circleci/deploy.sh index 2bff0102ae07..98151963d815 100755 --- a/.circleci/deploy.sh +++ b/.circleci/deploy.sh @@ -5,8 +5,12 @@ function deploy_doc(){ git checkout $1 if [ ! -z "$2" ] then - echo "Pushing version" $2 - make clean && make html && scp -r -oStrictHostKeyChecking=no _build/html $doc:$dir/$2 + if [ -d "$dir/$2" ]; then + echo "Directory" $2 "already exists" + else + echo "Pushing version" $2 + make clean && make html && scp -r -oStrictHostKeyChecking=no _build/html $doc:$dir/$2 + fi else echo "Pushing master" make clean && make html && scp -r -oStrictHostKeyChecking=no _build/html/* $doc:$dir diff --git a/README.md b/README.md index 17dfea6374f2..40d6acb76e5f 100644 --- a/README.md +++ b/README.md @@ -520,12 +520,12 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch # Parameters: lr = 1e-3 max_grad_norm = 1.0 -num_total_steps = 1000 +num_training_steps = 1000 num_warmup_steps = 100 -warmup_proportion = float(num_warmup_steps) / float(num_total_steps) # 0.1 +warmup_proportion = float(num_warmup_steps) / float(num_training_steps) # 0.1 ### Previously BertAdam optimizer was instantiated like this: -optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_total_steps) +optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_training_steps) ### and used like this: for batch in train_data: loss = model(batch) @@ -534,7 +534,7 @@ for batch in train_data: ### In Transformers, optimizer and schedules are splitted and instantiated like this: optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False) # To reproduce BertAdam specific behavior set correct_bias=False -scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps) # PyTorch scheduler +scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) # PyTorch scheduler ### and used like this: for batch in train_data: model.train() diff --git a/docs/source/main_classes/optimizer_schedules.rst b/docs/source/main_classes/optimizer_schedules.rst index ff0c9e6929c9..b30a2e0e2e16 100644 --- a/docs/source/main_classes/optimizer_schedules.rst +++ b/docs/source/main_classes/optimizer_schedules.rst @@ -18,19 +18,17 @@ Schedules Learning Rate Schedules ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: transformers.ConstantLRSchedule - :members: +.. autofunction:: transformers.get_constant_schedule -.. autoclass:: transformers.WarmupConstantSchedule - :members: +.. autofunction:: transformers.get_constant_schedule_with_warmup .. image:: /imgs/warmup_constant_schedule.png :target: /imgs/warmup_constant_schedule.png :alt: -.. autoclass:: transformers.WarmupCosineSchedule +.. autofunction:: transformers.get_cosine_schedule_with_warmup :members: .. image:: /imgs/warmup_cosine_schedule.png @@ -38,8 +36,7 @@ Learning Rate Schedules :alt: -.. autoclass:: transformers.WarmupCosineWithHardRestartsSchedule - :members: +.. autofunction:: transformers.get_cosine_with_hard_restarts_schedule_with_warmup .. image:: /imgs/warmup_cosine_hard_restarts_schedule.png :target: /imgs/warmup_cosine_hard_restarts_schedule.png @@ -47,8 +44,7 @@ Learning Rate Schedules -.. autoclass:: transformers.WarmupLinearSchedule - :members: +.. autofunction:: transformers.get_linear_schedule_with_warmup .. image:: /imgs/warmup_linear_schedule.png :target: /imgs/warmup_linear_schedule.png diff --git a/docs/source/migration.md b/docs/source/migration.md index 553a79c82b03..d04b66d5e4ad 100644 --- a/docs/source/migration.md +++ b/docs/source/migration.md @@ -84,12 +84,12 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch # Parameters: lr = 1e-3 max_grad_norm = 1.0 -num_total_steps = 1000 +num_training_steps = 1000 num_warmup_steps = 100 -warmup_proportion = float(num_warmup_steps) / float(num_total_steps) # 0.1 +warmup_proportion = float(num_warmup_steps) / float(num_training_steps) # 0.1 ### Previously BertAdam optimizer was instantiated like this: -optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_total_steps) +optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, num_training_steps=num_training_steps) ### and used like this: for batch in train_data: loss = model(batch) @@ -98,7 +98,7 @@ for batch in train_data: ### In Transformers, optimizer and schedules are splitted and instantiated like this: optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False) # To reproduce BertAdam specific behavior set correct_bias=False -scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps) # PyTorch scheduler +scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) # PyTorch scheduler ### and used like this: for batch in train_data: loss = model(batch) diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index edb47e7f1cc6..b0a578fd804b 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -155,5 +155,9 @@ Here is the full list of the currently provided pretrained models together with | CTRL | ``ctrl`` | | 48-layer, 1280-hidden, 16-heads, 1.6B parameters | | | | | Salesforce's Large-sized CTRL English model | +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| CamemBERT | ``camembert-base`` | | 12-layer, 768-hidden, 12-heads, 110M parameters | +| | | | CamemBERT using the BERT-base architecture | +| | | (see `details `__) | ++-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ -.. `__ \ No newline at end of file +.. `__ diff --git a/docs/source/quickstart.md b/docs/source/quickstart.md index ccba75e7c074..530aff8eb011 100644 --- a/docs/source/quickstart.md +++ b/docs/source/quickstart.md @@ -188,3 +188,35 @@ assert predicted_text == 'Who was Jim Henson? Jim Henson was a man' ``` Examples for each model class of each model architecture (Bert, GPT, GPT-2, Transformer-XL, XLNet and XLM) can be found in the [documentation](#documentation). + +#### Using the past + +GPT-2 as well as some other models (GPT, XLNet, Transfo-XL, CTRL) make use of a `past` or `mems` attribute which can be used to prevent re-computing the key/value pairs when using sequential decoding. It is useful when generating sequences as a big part of the attention mechanism benefits from previous computations. + +Here is a fully-working example using the `past` with `GPT2LMHeadModel` and argmax decoding (which should only be used as an example, as argmax decoding introduces a lot of repetition): + +```python +from transformers import GPT2LMHeadModel, GPT2Tokenizer +import torch + +tokenizer = GPT2Tokenizer.from_pretrained("gpt2") +model = GPT2LMHeadModel.from_pretrained('gpt2') + +generated = tokenizer.encode("The Manhattan bridge") +context = torch.tensor([generated]) +past = None + +for i in range(100): + print(i) + output, past = model(context, past=past) + token = torch.argmax(output[0, :]) + + generated += [token.tolist()] + context = token.unsqueeze(0) + +sequence = tokenizer.decode(generated) + +print(sequence) +``` + +The model only requires a single token as input as all the previous tokens' key/value pairs are contained in the `past`. \ No newline at end of file diff --git a/examples/README.md b/examples/README.md index 2b66b92f1ada..abb4cb6e5a65 100644 --- a/examples/README.md +++ b/examples/README.md @@ -554,6 +554,16 @@ On the test dataset the following results could be achieved: 10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085 ``` +### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) + +Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run): + +| Model | F-Score Dev | F-Score Test +| --------------------------------- | ------- | -------- +| `bert-large-cased` | 95.59 | 91.70 +| `roberta-large` | 95.96 | 91.87 +| `distilbert-base-uncased` | 94.34 | 90.32 + ## Abstractive summarization Based on the script diff --git a/examples/contrib/run_camembert.py b/examples/contrib/run_camembert.py new file mode 100644 index 000000000000..28144d516709 --- /dev/null +++ b/examples/contrib/run_camembert.py @@ -0,0 +1,48 @@ +from pathlib import Path +import tarfile +import urllib.request + +import torch + +from transformers.tokenization_camembert import CamembertTokenizer +from transformers.modeling_camembert import CamembertForMaskedLM + + +def fill_mask(masked_input, model, tokenizer, topk=5): + # Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py + assert masked_input.count('') == 1 + input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0) # Batch size 1 + logits = model(input_ids)[0] # The last hidden-state is the first element of the output tuple + masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item() + logits = logits[0, masked_index, :] + prob = logits.softmax(dim=0) + values, indices = prob.topk(k=topk, dim=0) + topk_predicted_token_bpe = ' '.join([tokenizer.convert_ids_to_tokens(indices[i].item()) + for i in range(len(indices))]) + masked_token = tokenizer.mask_token + topk_filled_outputs = [] + for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(' ')): + predicted_token = predicted_token_bpe.replace('\u2581', ' ') + if " {0}".format(masked_token) in masked_input: + topk_filled_outputs.append(( + masked_input.replace( + ' {0}'.format(masked_token), predicted_token + ), + values[index].item(), + predicted_token, + )) + else: + topk_filled_outputs.append(( + masked_input.replace(masked_token, predicted_token), + values[index].item(), + predicted_token, + )) + return topk_filled_outputs + + +tokenizer = CamembertTokenizer.from_pretrained('camembert-base') +model = CamembertForMaskedLM.from_pretrained('camembert-base') +model.eval() + +masked_input = "Le camembert est :)" +print(fill_mask(masked_input, model, tokenizer, topk=3)) diff --git a/examples/contrib/run_openai_gpt.py b/examples/contrib/run_openai_gpt.py index 7eb1b0be7651..2d165a91e326 100644 --- a/examples/contrib/run_openai_gpt.py +++ b/examples/contrib/run_openai_gpt.py @@ -41,7 +41,7 @@ from transformers import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, AdamW, cached_path, WEIGHTS_NAME, CONFIG_NAME, - WarmupLinearSchedule) + get_linear_schedule_with_warmup) ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz" @@ -211,7 +211,7 @@ def tokenize_and_encode(obj): {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) if args.do_train: nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None diff --git a/examples/contrib/run_swag.py b/examples/contrib/run_swag.py index 8494c5fad9a1..5de93db7fe88 100644 --- a/examples/contrib/run_swag.py +++ b/examples/contrib/run_swag.py @@ -42,7 +42,7 @@ from transformers import (WEIGHTS_NAME, BertConfig, BertForMultipleChoice, BertTokenizer) -from transformers import AdamW, WarmupLinearSchedule +from transformers import AdamW, get_linear_schedule_with_warmup logger = logging.getLogger(__name__) @@ -322,7 +322,7 @@ def train(args, train_dataset, model, tokenizer): {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) if args.fp16: try: from apex import amp diff --git a/examples/distillation/distiller.py b/examples/distillation/distiller.py index d51bdae77fef..0442072e84f4 100644 --- a/examples/distillation/distiller.py +++ b/examples/distillation/distiller.py @@ -35,7 +35,7 @@ except: from tensorboardX import SummaryWriter -from transformers import WarmupLinearSchedule +from transformers import get_linear_schedule_with_warmup from utils import logger from lm_seqs_dataset import LmSeqsDataset @@ -137,9 +137,9 @@ def __init__(self, betas=(0.9, 0.98)) warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop) - self.scheduler = WarmupLinearSchedule(self.optimizer, - warmup_steps=warmup_steps, - t_total=num_train_optimization_steps) + self.scheduler = get_linear_schedule_with_warmup(self.optimizer, + num_warmup_steps=warmup_steps, + num_training_steps=num_train_optimization_steps) if self.fp16: try: diff --git a/examples/distillation/run_squad_w_distillation.py b/examples/distillation/run_squad_w_distillation.py index 7c662df0106a..70b65dc1b8fa 100644 --- a/examples/distillation/run_squad_w_distillation.py +++ b/examples/distillation/run_squad_w_distillation.py @@ -46,7 +46,7 @@ XLNetTokenizer, DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer) -from transformers import AdamW, WarmupLinearSchedule +from transformers import AdamW, get_linear_schedule_with_warmup from ..utils_squad import (read_squad_examples, convert_examples_to_features, RawResult, write_predictions, @@ -101,7 +101,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None): {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) if args.fp16: try: from apex import amp diff --git a/examples/run_glue.py b/examples/run_glue.py index 1558a812c3e3..19316cb0ec41 100644 --- a/examples/run_glue.py +++ b/examples/run_glue.py @@ -49,7 +49,7 @@ DistilBertForSequenceClassification, DistilBertTokenizer) -from transformers import AdamW, WarmupLinearSchedule +from transformers import AdamW, get_linear_schedule_with_warmup from transformers import glue_compute_metrics as compute_metrics from transformers import glue_output_modes as output_modes @@ -100,7 +100,7 @@ def train(args, train_dataset, model, tokenizer): {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) if args.fp16: try: from apex import amp @@ -224,6 +224,10 @@ def evaluate(args, model, tokenizer, prefix=""): eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + # multi-gpu eval + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Eval! logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(eval_dataset)) diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py index 2044cfe9e87a..52a1b75a6542 100644 --- a/examples/run_lm_finetuning.py +++ b/examples/run_lm_finetuning.py @@ -42,7 +42,7 @@ from tqdm import tqdm, trange -from transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule, +from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup, BertConfig, BertForMaskedLM, BertTokenizer, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, @@ -63,10 +63,10 @@ class TextDataset(Dataset): - def __init__(self, tokenizer, file_path='train', block_size=512): + def __init__(self, tokenizer, args, file_path='train', block_size=512): assert os.path.isfile(file_path) directory, filename = os.path.split(file_path) - cached_features_file = os.path.join(directory, 'cached_lm_' + str(block_size) + '_' + filename) + cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_' + filename) if os.path.exists(cached_features_file): logger.info("Loading features from cached file %s", cached_features_file) @@ -99,7 +99,7 @@ def __getitem__(self, item): def load_and_cache_examples(args, tokenizer, evaluate=False): - dataset = TextDataset(tokenizer, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size) + dataset = TextDataset(tokenizer, args, file_path=args.eval_data_file if evaluate else args.train_data_file, block_size=args.block_size) return dataset @@ -185,7 +185,7 @@ def train(args, train_dataset, model, tokenizer): {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) if args.fp16: try: from apex import amp @@ -300,6 +300,10 @@ def evaluate(args, model, tokenizer, prefix=""): eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + # multi-gpu evaluate + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Eval! logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(eval_dataset)) diff --git a/examples/run_multiple_choice.py b/examples/run_multiple_choice.py index 638bbe74f180..30c333292921 100644 --- a/examples/run_multiple_choice.py +++ b/examples/run_multiple_choice.py @@ -43,7 +43,7 @@ XLNetTokenizer, RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer) -from transformers import AdamW, WarmupLinearSchedule +from transformers import AdamW, get_linear_schedule_with_warmup from utils_multiple_choice import (convert_examples_to_features, processors) @@ -101,7 +101,7 @@ def train(args, train_dataset, model, tokenizer): {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) if args.fp16: try: from apex import amp @@ -229,6 +229,10 @@ def evaluate(args, model, tokenizer, prefix="", test=False): eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + # multi-gpu evaluate + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Eval! logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(eval_dataset)) diff --git a/examples/run_ner.py b/examples/run_ner.py index b35d8298fe34..4359e587aea1 100644 --- a/examples/run_ner.py +++ b/examples/run_ner.py @@ -33,19 +33,21 @@ from tqdm import tqdm, trange from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file -from transformers import AdamW, WarmupLinearSchedule +from transformers import AdamW, get_linear_schedule_with_warmup from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer +from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer logger = logging.getLogger(__name__) ALL_MODELS = sum( - (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig)), + (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)), ()) MODEL_CLASSES = { "bert": (BertConfig, BertForTokenClassification, BertTokenizer), - "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer) + "roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer), + "distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer) } @@ -80,7 +82,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0} ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) if args.fp16: try: from apex import amp @@ -121,9 +123,10 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id): batch = tuple(t.to(args.device) for t in batch) inputs = {"input_ids": batch[0], "attention_mask": batch[1], - "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None, - # XLM and RoBERTa don"t use segment_ids "labels": batch[3]} + if args.model_type != "distilbert": + inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids + outputs = model(**inputs) loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc) @@ -191,6 +194,10 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix="" eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + # multi-gpu evaluate + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Eval! logger.info("***** Running evaluation %s *****", prefix) logger.info(" Num examples = %d", len(eval_dataset)) @@ -206,9 +213,9 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix="" with torch.no_grad(): inputs = {"input_ids": batch[0], "attention_mask": batch[1], - "token_type_ids": batch[2] if args.model_type in ["bert", "xlnet"] else None, - # XLM and RoBERTa don"t use segment_ids "labels": batch[3]} + if args.model_type != "distilbert": + inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids outputs = model(**inputs) tmp_eval_loss, logits = outputs[:2] @@ -520,3 +527,4 @@ def main(): if __name__ == "__main__": main() + diff --git a/examples/run_squad.py b/examples/run_squad.py index d9dc2abfdec8..d7fdc32ae7e5 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -45,7 +45,7 @@ XLNetTokenizer, DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer) -from transformers import AdamW, WarmupLinearSchedule +from transformers import AdamW, get_linear_schedule_with_warmup from utils_squad import (read_squad_examples, convert_examples_to_features, RawResult, write_predictions, @@ -100,7 +100,7 @@ def train(args, train_dataset, model, tokenizer): {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) if args.fp16: try: from apex import amp @@ -217,6 +217,10 @@ def evaluate(args, model, tokenizer, prefix=""): eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset) eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) + # multi-gpu evaluate + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + # Eval! logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(dataset)) diff --git a/examples/run_summarization_finetuning.py b/examples/run_summarization_finetuning.py index 448505c7273c..f5604c266935 100644 --- a/examples/run_summarization_finetuning.py +++ b/examples/run_summarization_finetuning.py @@ -275,6 +275,10 @@ def evaluate(args, model, tokenizer, prefix=""): eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size ) + # multi-gpu evaluate + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + logger.info("***** Running evaluation {} *****".format(prefix)) logger.info(" Num examples = %d", len(eval_dataset)) logger.info(" Batch size = %d", args.eval_batch_size) diff --git a/templates/adding_a_new_example_script/run_xxx.py b/templates/adding_a_new_example_script/run_xxx.py index 489dcb19c7b9..77ce587a5489 100644 --- a/templates/adding_a_new_example_script/run_xxx.py +++ b/templates/adding_a_new_example_script/run_xxx.py @@ -43,7 +43,7 @@ XLNetTokenizer, DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer) -from transformers import AdamW, WarmupLinearSchedule +from transformers import AdamW, get_linear_schedule_with_warmup from utils_squad import (read_squad_examples, convert_examples_to_features, RawResult, write_predictions, @@ -98,7 +98,7 @@ def train(args, train_dataset, model, tokenizer): {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} ] optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) - scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) if args.fp16: try: from apex import amp diff --git a/transformers/__init__.py b/transformers/__init__.py index d922f52a1d46..cdf0669b391f 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -42,6 +42,7 @@ from .tokenization_xlm import XLMTokenizer from .tokenization_roberta import RobertaTokenizer from .tokenization_distilbert import DistilBertTokenizer +from .tokenization_camembert import CamembertTokenizer # Configurations from .configuration_utils import PretrainedConfig @@ -56,6 +57,7 @@ from .configuration_xlm import XLMConfig, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP +from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP # Modeling if is_torch_available(): @@ -94,12 +96,16 @@ ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, + DistilBertForTokenClassification, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) + from .modeling_camembert import (CamembertForMaskedLM, CamembertModel, + CamembertForSequenceClassification, CamembertForMultipleChoice, + CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP) from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model # Optimization - from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, - WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) + from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup) # TensorFlow diff --git a/transformers/configuration_auto.py b/transformers/configuration_auto.py index edd21a670ce4..2906136139c5 100644 --- a/transformers/configuration_auto.py +++ b/transformers/configuration_auto.py @@ -27,6 +27,7 @@ from .configuration_roberta import RobertaConfig from .configuration_distilbert import DistilBertConfig from .configuration_ctrl import CTRLConfig +from .configuration_camembert import CamembertConfig logger = logging.getLogger(__name__) @@ -50,6 +51,7 @@ class method. - contains `xlnet`: XLNetConfig (XLNet model) - contains `xlm`: XLMConfig (XLM model) - contains `roberta`: RobertaConfig (RoBERTa model) + - contains `camembert`: CamembertConfig (CamemBERT model) - contains `ctrl` : CTRLConfig (CTRL model) This class cannot be instantiated using `__init__()` (throw an error). """ @@ -72,6 +74,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - contains `xlnet`: XLNetConfig (XLNet model) - contains `xlm`: XLMConfig (XLM model) - contains `roberta`: RobertaConfig (RoBERTa model) + - contains `camembert`: CamembertConfig (CamemBERT model) - contains `ctrl` : CTRLConfig (CTRL model) Params: pretrained_model_name_or_path: either: @@ -116,6 +119,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): """ if 'distilbert' in pretrained_model_name_or_path: return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif 'camembert' in pretrained_model_name_or_path: + return CamembertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) elif 'roberta' in pretrained_model_name_or_path: return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) elif 'bert' in pretrained_model_name_or_path: @@ -134,4 +139,4 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " - "'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path)) + "'xlm', 'roberta', 'camembert', 'ctrl'".format(pretrained_model_name_or_path)) diff --git a/transformers/configuration_camembert.py b/transformers/configuration_camembert.py new file mode 100644 index 000000000000..3ff64454e503 --- /dev/null +++ b/transformers/configuration_camembert.py @@ -0,0 +1,33 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" CamemBERT configuration """ + +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import logging + +from .configuration_roberta import RobertaConfig + +logger = logging.getLogger(__name__) + +CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-config.json", +} + + +class CamembertConfig(RobertaConfig): + pretrained_config_archive_map = CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP diff --git a/transformers/modeling_camembert.py b/transformers/modeling_camembert.py new file mode 100644 index 000000000000..05538926e2af --- /dev/null +++ b/transformers/modeling_camembert.py @@ -0,0 +1,257 @@ +# coding=utf-8 +# Copyright 2019 Inria, Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch CamemBERT model. """ + +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import logging + +from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification, RobertaForMultipleChoice +from .configuration_camembert import CamembertConfig +from .file_utils import add_start_docstrings + +logger = logging.getLogger(__name__) + +CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = { + 'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-pytorch_model.bin", +} + + +CAMEMBERT_START_DOCSTRING = r""" The CamemBERT model was proposed in + `CamemBERT: a Tasty French Language Model`_ + by Louis Martin, Benjamin Muller, Pedro Javier Ortiz Suárez, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah, and Benoît Sagot. It is based on Facebook's RoBERTa model released in 2019. + + It is a model trained on 138GB of French text. + + This implementation is the same as RoBERTa. + + This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and + refer to the PyTorch documentation for all matter related to general usage and behavior. + + .. _`CamemBERT: a Tasty French Language Model`: + https://arxiv.org/abs/1911.03894 + + .. _`torch.nn.Module`: + https://pytorch.org/docs/stable/nn.html#module + + Parameters: + config (:class:`~transformers.CamembertConfig`): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +CAMEMBERT_INPUTS_DOCSTRING = r""" + Inputs: + **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Indices of input sequence tokens in the vocabulary. + To match pre-training, CamemBERT input sequence should be formatted with and tokens as follows: + + (a) For sequence pairs: + + ``tokens: Is this Jacksonville ? No it is not . `` + + (b) For single sequences: + + ``tokens: the dog is hairy . `` + + Fully encoded sequences or sequence pairs can be obtained using the CamembertTokenizer.encode function with + the ``add_special_tokens`` parameter set to ``True``. + + CamemBERT is a model with absolute position embeddings so it's usually advised to pad the inputs on + the right rather than the left. + + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + **token_type_ids**: (`optional` need to be trained) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Optional segment token indices to indicate first and second portions of the inputs. + This embedding matrice is not trained (not pretrained during CamemBERT pretraining), you will have to train it + during finetuning. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + (see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details). + **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1[``. + **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. + **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``: + Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. +""" + +@add_start_docstrings("The bare CamemBERT Model transformer outputting raw hidden-states without any specific head on top.", + CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING) +class CamembertModel(RobertaModel): + r""" + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` + Sequence of hidden-states at the output of the last layer of the model. + **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. The Linear + layer weights are trained from the next sentence prediction (classification) + eo match pre-training, CamemBERT input sequence should be formatted with [CLS] and [SEP] tokens as follows: + + (a) For sequence pairs: + + ``tokens: [CLS] is this jack ##son ##ville ? [SEP] [SEP] no it is not . [SEP]`` + + ``token_type_ids: 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1`` + + (b) For single sequences: + + ``tokens: [CLS] the dog is hairy . [SEP]`` + + ``token_type_ids: 0 0 0 0 0 0 0`` + + objective during Bert pretraining. This output is usually *not* a good summary + of the semantic content of the input, you're often better with averaging or pooling + the sequence of hidden-states for the whole input sequence. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + tokenizer = CamembertTokenizer.from_pretrained('camembert-base') + model = CamembertModel.from_pretrained('camembert-base') + input_ids = torch.tensor(tokenizer.encode("J'aime le camembert !")).unsqueeze(0) # Batch size 1 + outputs = model(input_ids) + last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple + + """ + config_class = CamembertConfig + pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP + + +@add_start_docstrings("""CamemBERT Model with a `language modeling` head on top. """, + CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING) +class CamembertForMaskedLM(RobertaForMaskedLM): + r""" + **masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Labels for computing the masked language modeling loss. + Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) + Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels + in ``[0, ..., config.vocab_size]`` + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Masked language modeling loss. + **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)`` + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + tokenizer = CamembertTokenizer.from_pretrained('camembert-base') + model = CamembertForMaskedLM.from_pretrained('camembert-base') + input_ids = torch.tensor(tokenizer.encode("J'aime le camembert !")).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, masked_lm_labels=input_ids) + loss, prediction_scores = outputs[:2] + + """ + config_class = CamembertConfig + pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP + + +@add_start_docstrings("""CamemBERT Model transformer with a sequence classification/regression head on top (a linear layer + on top of the pooled output) e.g. for GLUE tasks. """, + CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING) +class CamembertForSequenceClassification(RobertaForSequenceClassification): + r""" + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the sequence classification/regression loss. + Indices should be in ``[0, ..., config.num_labels]``. + If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), + If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification (or regression if config.num_labels==1) loss. + **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` + Classification (or regression if config.num_labels==1) scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + tokenizer = CamembertTokenizer.from_pretrained('camembert-base') + model = CamembertForSequenceClassification.from_pretrained('camembert-base') + input_ids = torch.tensor(tokenizer.encode("J'aime le camembert !")).unsqueeze(0) # Batch size 1 + labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=labels) + loss, logits = outputs[:2] + + """ + config_class = CamembertConfig + pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP + + +@add_start_docstrings("""CamemBERT Model with a multiple choice classification head on top (a linear layer on top of + the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, + CAMEMBERT_START_DOCSTRING, CAMEMBERT_INPUTS_DOCSTRING) +class CamembertForMultipleChoice(RobertaForMultipleChoice): + r""" + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification loss. + **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension + of the input tensors. (see `input_ids` above). + Classification scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + tokenizer = CamembertTokenizer.from_pretrained('camembert-base') + model = CamembertForMultipleChoice.from_pretrained('camembert-base') + choices = ["J'aime le camembert !", "Je deteste le camembert !"] + input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices + labels = torch.tensor(1).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=labels) + loss, classification_scores = outputs[:2] + + """ + config_class = CamembertConfig + pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP diff --git a/transformers/modeling_distilbert.py b/transformers/modeling_distilbert.py index 00106627a8ab..d30f493c695b 100644 --- a/transformers/modeling_distilbert.py +++ b/transformers/modeling_distilbert.py @@ -30,6 +30,7 @@ import torch import torch.nn as nn +from torch.nn import CrossEntropyLoss from .modeling_utils import PreTrainedModel, prune_linear_layer from .configuration_distilbert import DistilBertConfig @@ -702,3 +703,75 @@ def forward(self, input_ids=None, attention_mask=None, head_mask=None, inputs_em outputs = (total_loss,) + outputs return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) + + +@add_start_docstrings("""DistilBert Model with a token classification head on top (a linear layer on top of + the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, + DISTILBERT_START_DOCSTRING, + DISTILBERT_INPUTS_DOCSTRING) +class DistilBertForTokenClassification(DistilBertPreTrainedModel): + r""" + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Labels for computing the token classification loss. + Indices should be in ``[0, ..., config.num_labels - 1]``. + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification loss. + **scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)`` + Classification scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') + model = DistilBertForTokenClassification.from_pretrained('distilbert-base-uncased') + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=labels) + loss, scores = outputs[:2] + + """ + def __init__(self, config): + super(DistilBertForTokenClassification, self).__init__(config) + self.num_labels = config.num_labels + + self.distilbert = DistilBertModel(config) + self.dropout = nn.Dropout(config.dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def forward(self, input_ids=None, attention_mask=None, head_mask=None, + inputs_embeds=None, labels=None): + + outputs = self.distilbert(input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels)[active_loss] + active_labels = labels.view(-1)[active_loss] + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + + return outputs # (loss), scores, (hidden_states), (attentions) diff --git a/transformers/optimization.py b/transformers/optimization.py index a48b5fea54d4..99e6cc75e402 100644 --- a/transformers/optimization.py +++ b/transformers/optimization.py @@ -23,90 +23,66 @@ logger = logging.getLogger(__name__) -class ConstantLRSchedule(LambdaLR): - """ Constant learning rate schedule. + +def get_constant_schedule(optimizer, last_epoch=-1): + """ Create a schedule with a constant learning rate. """ - def __init__(self, optimizer, last_epoch=-1): - super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch) + return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) -class WarmupConstantSchedule(LambdaLR): - """ Linear warmup and then constant. - Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step. - Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps. - Keeps multiplicative variable equal to 1. after warmup_steps. +def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): + """ Create a schedule with a constant learning rate preceded by a warmup + period during which the learning rate increases linearly between 0 and 1. """ - def __init__(self, optimizer, warmup_steps, last_epoch=-1): - self.warmup_steps = warmup_steps - super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) - - def lr_lambda(self, step): - if step < self.warmup_steps: - return float(step) / float(max(1.0, self.warmup_steps)) + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) return 1. + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) -class WarmupLinearSchedule(LambdaLR): - """ Linear warmup and then linear decay. - Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step. - Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps. - Linearly decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps. - """ - def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): - self.warmup_steps = warmup_steps - self.t_total = t_total - super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) - - def lr_lambda(self, step): - if step < self.warmup_steps: - return float(step) / float(max(1, self.warmup_steps)) - return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) - - -class WarmupCosineSchedule(LambdaLR): - """ Linear warmup and then cosine decay. - Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step. - Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps. - Decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. - If `cycles` (default=0.5) is different from default, then the multiplicative variable follows cosine function after warmup. + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): + """ Create a schedule with a learning rate that decreases linearly after + linearly increasing during a warmup period. """ - def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): - self.warmup_steps = warmup_steps - self.t_total = t_total - self.cycles = cycles - super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) - - def lr_lambda(self, step): - if step < self.warmup_steps: - return float(step) / float(max(1.0, self.warmup_steps)) - # progress after warmup - progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) - return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) - - -class WarmupCosineWithHardRestartsSchedule(LambdaLR): - """ Linear warmup and then cosine cycles with hard restarts. - Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step. - Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps. - If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying - learning rate (with hard restarts). + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + + +def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1): + """ Create a schedule with a learning rate that decreases following the + values of the cosine function between 0 and `pi * cycles` after a warmup + period during which it increases linearly between 0 and 1. """ - def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1): - self.warmup_steps = warmup_steps - self.t_total = t_total - self.cycles = cycles - super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) - - def lr_lambda(self, step): - if step < self.warmup_steps: - return float(step) / float(max(1, self.warmup_steps)) - # progress after warmup - progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) - if progress >= 1.0: - return 0.0 - return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0)))) + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) +def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1): + """ Create a schedule with a learning rate that decreases following the + values of the cosine function with several hard restarts, after a warmup + period during which it increases linearly between 0 and 1. + """ + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) + if progress >= 1.: + return 0. + return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.)))) + + return LambdaLR(optimizer, lr_lambda, last_epoch) + class AdamW(Optimizer): """ Implements Adam algorithm with weight decay fix. diff --git a/transformers/tests/modeling_distilbert_test.py b/transformers/tests/modeling_distilbert_test.py index 937d03396d59..8099c035865d 100644 --- a/transformers/tests/modeling_distilbert_test.py +++ b/transformers/tests/modeling_distilbert_test.py @@ -23,6 +23,7 @@ if is_torch_available(): from transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM, + DistilBertForTokenClassification, DistilBertForQuestionAnswering, DistilBertForSequenceClassification) else: pytestmark = pytest.mark.skip("Require Torch") @@ -180,6 +181,21 @@ def create_and_check_distilbert_for_sequence_classification(self, config, input_ [self.batch_size, self.num_labels]) self.check_loss_output(result) + def create_and_check_distilbert_for_token_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): + config.num_labels = self.num_labels + model = DistilBertForTokenClassification(config=config) + model.eval() + + loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels) + result = { + "loss": loss, + "logits": logits, + } + self.parent.assertListEqual( + list(result["logits"].size()), + [self.batch_size, self.seq_length, self.num_labels]) + self.check_loss_output(result) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs @@ -209,6 +225,10 @@ def test_for_sequence_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs) + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs) + # @pytest.mark.slow # def test_model_from_pretrained(self): # cache_dir = "/tmp/transformers_test/" diff --git a/transformers/tests/optimization_test.py b/transformers/tests/optimization_test.py index 84dbaca52a9c..ab9afbfcf72f 100644 --- a/transformers/tests/optimization_test.py +++ b/transformers/tests/optimization_test.py @@ -25,8 +25,12 @@ if is_torch_available(): import torch - from transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, - WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) + from transformers import (AdamW, + get_constant_schedule, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup) else: pytestmark = pytest.mark.skip("Require Torch") @@ -87,59 +91,60 @@ def assertListAlmostEqual(self, list1, list2, tol): self.assertAlmostEqual(a, b, delta=tol) def test_constant_scheduler(self): - scheduler = ConstantLRSchedule(self.optimizer) + scheduler = get_constant_schedule(self.optimizer) lrs = unwrap_schedule(scheduler, self.num_steps) expected_learning_rates = [10.] * self.num_steps self.assertEqual(len(lrs[0]), 1) self.assertListEqual([l[0] for l in lrs], expected_learning_rates) - scheduler = ConstantLRSchedule(self.optimizer) + scheduler = get_constant_schedule(self.optimizer) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) def test_warmup_constant_scheduler(self): - scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) + scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4) lrs = unwrap_schedule(scheduler, self.num_steps) expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0] self.assertEqual(len(lrs[0]), 1) self.assertListEqual([l[0] for l in lrs], expected_learning_rates) - scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4) + scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) def test_warmup_linear_scheduler(self): - scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) + scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) lrs = unwrap_schedule(scheduler, self.num_steps) expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0] self.assertEqual(len(lrs[0]), 1) self.assertListEqual([l[0] for l in lrs], expected_learning_rates) - scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10) + scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) def test_warmup_cosine_scheduler(self): - scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) + scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) lrs = unwrap_schedule(scheduler, self.num_steps) expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0] self.assertEqual(len(lrs[0]), 1) self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) - scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10) + scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) def test_warmup_cosine_hard_restart_scheduler(self): - scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) + scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10) lrs = unwrap_schedule(scheduler, self.num_steps) expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0] self.assertEqual(len(lrs[0]), 1) self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2) - scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10) + scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10) lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps) self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2]) + if __name__ == "__main__": unittest.main() diff --git a/transformers/tests/tokenization_tests_commons.py b/transformers/tests/tokenization_tests_commons.py index a921696b77d2..fdaf8cc137c7 100644 --- a/transformers/tests/tokenization_tests_commons.py +++ b/transformers/tests/tokenization_tests_commons.py @@ -160,6 +160,26 @@ def test_add_tokens_tokenizer(self): self.assertEqual(tokens[0], tokenizer.eos_token_id) self.assertEqual(tokens[-2], tokenizer.pad_token_id) + def test_add_special_tokens(self): + tokenizer = self.get_tokenizer() + input_text, output_text = self.get_input_output_texts() + + special_token = "[SPECIAL TOKEN]" + + tokenizer.add_special_tokens({"cls_token": special_token}) + encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False) + assert len(encoded_special_token) == 1 + + text = " ".join([input_text, special_token, output_text]) + encoded = tokenizer.encode(text, add_special_tokens=False) + + input_encoded = tokenizer.encode(input_text, add_special_tokens=False) + output_encoded = tokenizer.encode(output_text, add_special_tokens=False) + special_token_id = tokenizer.encode(special_token, add_special_tokens=False) + assert encoded == input_encoded + special_token_id + output_encoded + + decoded = tokenizer.decode(encoded, skip_special_tokens=True) + assert special_token not in decoded def test_required_methods_tokenizer(self): tokenizer = self.get_tokenizer() diff --git a/transformers/tokenization_auto.py b/transformers/tokenization_auto.py index ec056de17fa0..4510159905bc 100644 --- a/transformers/tokenization_auto.py +++ b/transformers/tokenization_auto.py @@ -27,6 +27,7 @@ from .tokenization_xlm import XLMTokenizer from .tokenization_roberta import RobertaTokenizer from .tokenization_distilbert import DistilBertTokenizer +from .tokenization_camembert import CamembertTokenizer logger = logging.getLogger(__name__) @@ -41,6 +42,7 @@ class method. The tokenizer class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `camembert`: CamembertTokenizer (CamemBERT model) - contains `distilbert`: DistilBertTokenizer (DistilBert model) - contains `roberta`: RobertaTokenizer (RoBERTa model) - contains `bert`: BertTokenizer (Bert model) @@ -64,8 +66,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): The tokenizer class to instantiate is selected as the first pattern matching in the `pretrained_model_name_or_path` string (in the following order): + - contains `camembert`: CamembertTokenizer (CamemBERT model) - contains `distilbert`: DistilBertTokenizer (DistilBert model) - - contains `roberta`: RobertaTokenizer (XLM model) + - contains `roberta`: RobertaTokenizer (RoBERTa model) - contains `bert`: BertTokenizer (Bert model) - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model) - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model) @@ -103,6 +106,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): """ if 'distilbert' in pretrained_model_name_or_path: return DistilBertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) + elif 'camembert' in pretrained_model_name_or_path: + return CamembertTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) elif 'roberta' in pretrained_model_name_or_path: return RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) elif 'bert' in pretrained_model_name_or_path: @@ -121,4 +126,4 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): return CTRLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) raise ValueError("Unrecognized model identifier in {}. Should contains one of " "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " - "'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path)) + "'xlm', 'roberta', 'camembert', 'ctrl'".format(pretrained_model_name_or_path)) diff --git a/transformers/tokenization_camembert.py b/transformers/tokenization_camembert.py new file mode 100644 index 000000000000..41d3d74cffaa --- /dev/null +++ b/transformers/tokenization_camembert.py @@ -0,0 +1,137 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +""" Tokenization classes for Camembert model.""" +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import sentencepiece as spm +from transformers.tokenization_utils import PreTrainedTokenizer + + +VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'} + +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': + { + 'camembert-base': "https://s3.amazonaws.com/models.huggingface.co/bert/camembert-base-sentencepiece.bpe.model", + } +} + +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + 'camembert-base': None, +} + +class CamembertTokenizer(PreTrainedTokenizer): + """ + Adapted from RobertaTokenizer and XLNetTokenizer + SentencePiece based tokenizer. Peculiarities: + + - requires `SentencePiece `_ + """ + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__(self, vocab_file, bos_token="", eos_token="", sep_token="", + cls_token="", unk_token="", pad_token='', mask_token='', + additional_special_tokens=['NOTUSED', 'NOTUSED'], **kwargs): + super(CamembertTokenizer, self).__init__(max_len=512, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, + sep_token=sep_token, cls_token=cls_token, pad_token=pad_token, + mask_token=mask_token, additional_special_tokens=additional_special_tokens, + **kwargs) + self.max_len_single_sentence = self.max_len - 2 # take into account special tokens + self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens + self.sp_model = spm.SentencePieceProcessor() + self.sp_model.Load(str(vocab_file)) + # HACK: These tokens were added by fairseq but don't seem to be actually used when duplicated in the actual + # sentencepiece vocabulary (this is the case for and + self.fairseq_tokens_to_ids = {'NOTUSED': 0, '': 1, 'NOTUSED': 2, '': 3} + self.fairseq_offset = len(self.fairseq_tokens_to_ids) + self.fairseq_tokens_to_ids[''] = len(self.sp_model) + len(self.fairseq_tokens_to_ids) + self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. + A RoBERTa sequence has the following format: + single sequence: X + pair of sequences: A B + """ + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. + + Args: + token_ids_0: list of ids (must not contain special tokens) + token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids + for sequence pairs + already_has_special_tokens: (default False) Set to True if the token list is already formated with + special tokens for the model + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError("You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model.") + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. + A RoBERTa sequence pair mask has the following format: + 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence + + if token_ids_1 is None, only returns the first portion of the mask (0's). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep) * [0] + len(token_ids_1 + sep) * [1] + + @property + def vocab_size(self): + return self.fairseq_offset + len(self.sp_model) + + def _tokenize(self, text): + return self.sp_model.EncodeAsPieces(text) + + def _convert_token_to_id(self, token): + """ Converts a token (str/unicode) in an id using the vocab. """ + if token in self.fairseq_tokens_to_ids: + return self.fairseq_tokens_to_ids[token] + return self.fairseq_offset + self.sp_model.PieceToId(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (string/unicode) using the vocab.""" + if index in self.fairseq_ids_to_tokens: + return self.fairseq_ids_to_tokens[index] + return self.sp_model.IdToPiece(index - self.fairseq_offset) diff --git a/transformers/tokenization_utils.py b/transformers/tokenization_utils.py index cd14cc4582db..4fa26a26f8dc 100644 --- a/transformers/tokenization_utils.py +++ b/transformers/tokenization_utils.py @@ -21,6 +21,7 @@ import json import six import copy +import itertools from io import open from .file_utils import cached_path, is_tf_available, is_torch_available @@ -641,9 +642,9 @@ def split_on_tokens(tok_list, text): tokenized_text += [sub_text] text_list = tokenized_text - return sum((self._tokenize(token, **kwargs) if token not \ + return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) if token not \ in self.added_tokens_encoder and token not in self.all_special_tokens \ - else [token] for token in tokenized_text), []) + else [token] for token in tokenized_text))) added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens tokenized_text = split_on_tokens(added_tokens, text) @@ -671,10 +672,6 @@ def convert_tokens_to_ids(self, tokens): ids = [] for token in tokens: ids.append(self._convert_token_to_id_with_added_voc(token)) - if len(ids) > self.max_len: - logger.warning("Token indices sequence length is longer than the specified maximum sequence length " - "for this model ({} > {}). Running this sequence through the model will result in " - "indexing errors".format(len(ids), self.max_len)) return ids def _convert_token_to_id_with_added_voc(self, token): @@ -877,6 +874,11 @@ def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tok encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length] encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length] + if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len: + logger.warning("Token indices sequence length is longer than the specified maximum sequence length " + "for this model ({} > {}). Running this sequence through the model will result in " + "indexing errors".format(len(ids), self.max_len)) + return encoded_inputs def truncate_sequences(self, ids, pair_ids=None, num_tokens_to_remove=0, truncation_strategy='longest_first', stride=0): @@ -1055,7 +1057,7 @@ def all_special_ids(self): class attributes (cls_token, unk_token...). """ all_toks = self.all_special_tokens - all_ids = list(self._convert_token_to_id(t) for t in all_toks) + all_ids = self.convert_tokens_to_ids(all_toks) return all_ids @staticmethod