diff --git a/README.md b/README.md
index 40b08583b1d3..a22a3c83f50d 100644
--- a/README.md
+++ b/README.md
@@ -521,12 +521,12 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch
 # Parameters:
 lr = 1e-3
 max_grad_norm = 1.0
-num_total_steps = 1000
+num_training_steps = 1000
 num_warmup_steps = 100
-warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1
+warmup_proportion = float(num_warmup_steps) / float(num_training_steps)  # 0.1

 ### Previously BertAdam optimizer was instantiated like this:
-optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_total_steps)
+optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_training_steps)
 ### and used like this:
 for batch in train_data:
     loss = model(batch)
@@ -535,7 +535,7 @@ for batch in train_data:

 ### In Transformers, optimizer and schedules are splitted and instantiated like this:
 optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
-scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)  # PyTorch scheduler
+scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)  # PyTorch scheduler
 ### and used like this:
 for batch in train_data:
     model.train()
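
The README hunks above only show the top of the converted snippet. For orientation, here is a minimal sketch of the full new-style loop they document; it is illustrative only (not part of the patch) and assumes `model` and `train_data` are defined elsewhere.

# (illustrative sketch, not part of the patch)
import torch
from transformers import AdamW, get_linear_schedule_with_warmup

num_training_steps = 1000
num_warmup_steps = 100
max_grad_norm = 1.0

# correct_bias=False reproduces BertAdam's behavior; gradient clipping is now done explicitly
optimizer = AdamW(model.parameters(), lr=1e-3, correct_bias=False)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps,
                                            num_training_steps=num_training_steps)

for batch in train_data:
    model.train()
    loss = model(batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    scheduler.step()  # update the learning rate once per optimizer step
    optimizer.zero_grad()
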
diff --git a/docs/source/main_classes/optimizer_schedules.rst b/docs/source/main_classes/optimizer_schedules.rst
index ff0c9e6929c9..b30a2e0e2e16 100644
--- a/docs/source/main_classes/optimizer_schedules.rst
+++ b/docs/source/main_classes/optimizer_schedules.rst
@@ -18,19 +18,16 @@ Schedules
 Learning Rate Schedules
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-.. autoclass:: transformers.ConstantLRSchedule
-    :members:
+.. autofunction:: transformers.get_constant_schedule

-.. autoclass:: transformers.WarmupConstantSchedule
-    :members:
+.. autofunction:: transformers.get_constant_schedule_with_warmup

 .. image:: /imgs/warmup_constant_schedule.png
     :target: /imgs/warmup_constant_schedule.png
     :alt:


-.. autoclass:: transformers.WarmupCosineSchedule
-    :members:
+.. autofunction:: transformers.get_cosine_schedule_with_warmup

 .. image:: /imgs/warmup_cosine_schedule.png
@@ -38,8 +35,7 @@ Learning Rate Schedules
     :alt:


-.. autoclass:: transformers.WarmupCosineWithHardRestartsSchedule
-    :members:
+.. autofunction:: transformers.get_cosine_with_hard_restarts_schedule_with_warmup

 .. image:: /imgs/warmup_cosine_hard_restarts_schedule.png
     :target: /imgs/warmup_cosine_hard_restarts_schedule.png
     :alt:



-.. autoclass:: transformers.WarmupLinearSchedule
-    :members:
+.. autofunction:: transformers.get_linear_schedule_with_warmup

 .. image:: /imgs/warmup_linear_schedule.png
     :target: /imgs/warmup_linear_schedule.png
diff --git a/docs/source/migration.md b/docs/source/migration.md
index 553a79c82b03..d04b66d5e4ad 100644
--- a/docs/source/migration.md
+++ b/docs/source/migration.md
@@ -84,12 +84,12 @@ Here is a conversion examples from `BertAdam` with a linear warmup and decay sch
 # Parameters:
 lr = 1e-3
 max_grad_norm = 1.0
-num_total_steps = 1000
+num_training_steps = 1000
 num_warmup_steps = 100
-warmup_proportion = float(num_warmup_steps) / float(num_total_steps)  # 0.1
+warmup_proportion = float(num_warmup_steps) / float(num_training_steps)  # 0.1

 ### Previously BertAdam optimizer was instantiated like this:
-optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_total_steps)
+optimizer = BertAdam(model.parameters(), lr=lr, schedule='warmup_linear', warmup=warmup_proportion, t_total=num_training_steps)
 ### and used like this:
 for batch in train_data:
     loss = model(batch)
@@ -98,7 +98,7 @@ for batch in train_data:

 ### In Transformers, optimizer and schedules are splitted and instantiated like this:
 optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
-scheduler = WarmupLinearSchedule(optimizer, warmup_steps=num_warmup_steps, t_total=num_total_steps)  # PyTorch scheduler
+scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)  # PyTorch scheduler
 ### and used like this:
 for batch in train_data:
     loss = model(batch)
diff --git a/examples/contrib/run_openai_gpt.py b/examples/contrib/run_openai_gpt.py
index 7eb1b0be7651..2d165a91e326 100644
--- a/examples/contrib/run_openai_gpt.py
+++ b/examples/contrib/run_openai_gpt.py
@@ -41,7 +41,7 @@
 from transformers import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
                           AdamW, cached_path, WEIGHTS_NAME, CONFIG_NAME,
-                          WarmupLinearSchedule)
+                          get_linear_schedule_with_warmup)

 ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"

@@ -211,7 +211,7 @@ def tokenize_and_encode(obj):
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

     if args.do_train:
         nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
diff --git a/examples/contrib/run_swag.py b/examples/contrib/run_swag.py
index 8494c5fad9a1..5de93db7fe88 100644
--- a/examples/contrib/run_swag.py
+++ b/examples/contrib/run_swag.py
@@ -42,7 +42,7 @@
 from transformers import (WEIGHTS_NAME, BertConfig,
                           BertForMultipleChoice, BertTokenizer)
-from transformers import AdamW, WarmupLinearSchedule
+from transformers import AdamW, get_linear_schedule_with_warmup

 logger = logging.getLogger(__name__)

@@ -322,7 +322,7 @@ def train(args, train_dataset, model, tokenizer):
         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

     if args.fp16:
         try:
             from apex import amp
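
In all of the example scripts touched by this patch, the locally computed `t_total` is simply passed on as the new `num_warmup_steps` / `num_training_steps` keyword arguments. As a rough, illustrative sketch (not an exact excerpt; the flag names follow the run_glue.py-style scripts), that total is typically derived from the dataloader length:

# (illustrative sketch, not part of the patch; flag names assumed)
if args.max_steps > 0:
    t_total = args.max_steps
else:
    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=args.warmup_steps,
                                            num_training_steps=t_total)
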
diff --git a/examples/distillation/distiller.py b/examples/distillation/distiller.py
index d51bdae77fef..0442072e84f4 100644
--- a/examples/distillation/distiller.py
+++ b/examples/distillation/distiller.py
@@ -35,7 +35,7 @@ except:
     from tensorboardX import SummaryWriter

-from transformers import WarmupLinearSchedule
+from transformers import get_linear_schedule_with_warmup

 from utils import logger
 from lm_seqs_dataset import LmSeqsDataset
@@ -137,9 +137,9 @@ def __init__(self,
                                betas=(0.9, 0.98))

         warmup_steps = math.ceil(num_train_optimization_steps * params.warmup_prop)
-        self.scheduler = WarmupLinearSchedule(self.optimizer,
-                                              warmup_steps=warmup_steps,
-                                              t_total=num_train_optimization_steps)
+        self.scheduler = get_linear_schedule_with_warmup(self.optimizer,
+                                                         num_warmup_steps=warmup_steps,
+                                                         num_training_steps=num_train_optimization_steps)

         if self.fp16:
             try:
diff --git a/examples/distillation/run_squad_w_distillation.py b/examples/distillation/run_squad_w_distillation.py
index 7c662df0106a..70b65dc1b8fa 100644
--- a/examples/distillation/run_squad_w_distillation.py
+++ b/examples/distillation/run_squad_w_distillation.py
@@ -46,7 +46,7 @@
                          XLNetTokenizer, DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)

-from transformers import AdamW, WarmupLinearSchedule
+from transformers import AdamW, get_linear_schedule_with_warmup

 from ..utils_squad import (read_squad_examples, convert_examples_to_features,
                            RawResult, write_predictions,
@@ -101,7 +101,7 @@ def train(args, train_dataset, model, tokenizer, teacher=None):
         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

     if args.fp16:
         try:
             from apex import amp
diff --git a/examples/run_glue.py b/examples/run_glue.py
index 1558a812c3e3..27048ad565b6 100644
--- a/examples/run_glue.py
+++ b/examples/run_glue.py
@@ -49,7 +49,7 @@
                           DistilBertForSequenceClassification,
                           DistilBertTokenizer)

-from transformers import AdamW, WarmupLinearSchedule
+from transformers import AdamW, get_linear_schedule_with_warmup

 from transformers import glue_compute_metrics as compute_metrics
 from transformers import glue_output_modes as output_modes
@@ -100,7 +100,7 @@ def train(args, train_dataset, model, tokenizer):
         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

     if args.fp16:
         try:
             from apex import amp
diff --git a/examples/run_lm_finetuning.py b/examples/run_lm_finetuning.py
index 2044cfe9e87a..0085aee727a1 100644
--- a/examples/run_lm_finetuning.py
+++ b/examples/run_lm_finetuning.py
@@ -42,7 +42,7 @@
 from tqdm import tqdm, trange

-from transformers import (WEIGHTS_NAME, AdamW, WarmupLinearSchedule,
+from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
                           BertConfig, BertForMaskedLM, BertTokenizer,
                           GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
                           OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
@@ -185,7 +185,7 @@ def train(args, train_dataset, model, tokenizer):
         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
-    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
+    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

     if args.fp16:
         try:
             from apex import amp
""" - def __init__(self, optimizer, warmup_steps, last_epoch=-1): - self.warmup_steps = warmup_steps - super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) - - def lr_lambda(self, step): - if step < self.warmup_steps: - return float(step) / float(max(1.0, self.warmup_steps)) + def lr_lambda(current_step): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) return 1. + return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) -class WarmupLinearSchedule(LambdaLR): - """ Linear warmup and then linear decay. - Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step. - Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps. - Linearly decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps. - """ - def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): - self.warmup_steps = warmup_steps - self.t_total = t_total - super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) - - def lr_lambda(self, step): - if step < self.warmup_steps: - return float(step) / float(max(1, self.warmup_steps)) - return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) - - -class WarmupCosineSchedule(LambdaLR): - """ Linear warmup and then cosine decay. - Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step. - Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps. - Decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve. - If `cycles` (default=0.5) is different from default, then the multiplicative variable follows cosine function after warmup. + +def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): + """ Create a schedule with a learning rate that decreases linearly after + linearly increasing during a warmup period. """ - def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1): - self.warmup_steps = warmup_steps - self.t_total = t_total - self.cycles = cycles - super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) - - def lr_lambda(self, step): - if step < self.warmup_steps: - return float(step) / float(max(1.0, self.warmup_steps)) - # progress after warmup - progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) - return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) - - -class WarmupCosineWithHardRestartsSchedule(LambdaLR): - """ Linear warmup and then cosine cycles with hard restarts. - Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step. - Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps. - If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying - learning rate (with hard restarts). 
diff --git a/transformers/optimization.py b/transformers/optimization.py
index a48b5fea54d4..99e6cc75e402 100644
--- a/transformers/optimization.py
+++ b/transformers/optimization.py
@@ -23,90 +23,66 @@
 logger = logging.getLogger(__name__)

-class ConstantLRSchedule(LambdaLR):
-    """ Constant learning rate schedule.
+
+def get_constant_schedule(optimizer, last_epoch=-1):
+    """ Create a schedule with a constant learning rate.
     """
-    def __init__(self, optimizer, last_epoch=-1):
-        super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)
+    return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)


-class WarmupConstantSchedule(LambdaLR):
-    """ Linear warmup and then constant.
-        Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
-        Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
-        Keeps multiplicative variable equal to 1. after warmup_steps.
+def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
+    """ Create a schedule with a constant learning rate preceded by a warmup
+    period during which the learning rate increases linearly between 0 and 1.
     """
-    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
-        self.warmup_steps = warmup_steps
-        super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
-
-    def lr_lambda(self, step):
-        if step < self.warmup_steps:
-            return float(step) / float(max(1.0, self.warmup_steps))
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1.0, num_warmup_steps))
         return 1.
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)


-class WarmupLinearSchedule(LambdaLR):
-    """ Linear warmup and then linear decay.
-        Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
-        Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
-        Linearly decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps.
-    """
-    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
-        self.warmup_steps = warmup_steps
-        self.t_total = t_total
-        super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
-
-    def lr_lambda(self, step):
-        if step < self.warmup_steps:
-            return float(step) / float(max(1, self.warmup_steps))
-        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
-
-
-class WarmupCosineSchedule(LambdaLR):
-    """ Linear warmup and then cosine decay.
-        Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
-        Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
-        Decreases the multiplicative variable from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
-        If `cycles` (default=0.5) is different from default, then the multiplicative variable follows cosine function after warmup.
+
+def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
+    """ Create a schedule with a learning rate that decreases linearly after
+    linearly increasing during a warmup period.
     """
-    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
-        self.warmup_steps = warmup_steps
-        self.t_total = t_total
-        self.cycles = cycles
-        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
-
-    def lr_lambda(self, step):
-        if step < self.warmup_steps:
-            return float(step) / float(max(1.0, self.warmup_steps))
-        # progress after warmup
-        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
-        return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
-
-
-class WarmupCosineWithHardRestartsSchedule(LambdaLR):
-    """ Linear warmup and then cosine cycles with hard restarts.
-        Multiplies the learning rate defined in the optimizer by a dynamic variable determined by the current step.
-        Linearly increases the multiplicative variable from 0. to 1. over `warmup_steps` training steps.
-        If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
-        learning rate (with hard restarts).
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1):
+    """ Create a schedule with a learning rate that decreases following the
+    values of the cosine function between 0 and `pi * num_cycles * 2` after a warmup
+    period during which it increases linearly between 0 and 1.
     """
-    def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
-        self.warmup_steps = warmup_steps
-        self.t_total = t_total
-        self.cycles = cycles
-        super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
-
-    def lr_lambda(self, step):
-        if step < self.warmup_steps:
-            return float(step) / float(max(1, self.warmup_steps))
-        # progress after warmup
-        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
-        if progress >= 1.0:
-            return 0.0
-        return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0))))
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+        return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. * progress)))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+
+
+def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1):
+    """ Create a schedule with a learning rate that decreases following the
+    values of the cosine function with several hard restarts, after a warmup
+    period during which it increases linearly between 0 and 1.
+    """
+    def lr_lambda(current_step):
+        if current_step < num_warmup_steps:
+            return float(current_step) / float(max(1, num_warmup_steps))
+        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+        if progress >= 1.:
+            return 0.
+        return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.))))
+
+    return LambdaLR(optimizer, lr_lambda, last_epoch)
+

 class AdamW(Optimizer):
     """ Implements Adam algorithm with weight decay fix.
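
The new helpers return ordinary `torch.optim.lr_scheduler.LambdaLR` objects, so they are stepped like any PyTorch scheduler. The following small sketch (illustrative only, not part of the patch) shows the learning-rate sequence produced by `get_linear_schedule_with_warmup` on a dummy parameter with a base learning rate of 10:

# (illustrative sketch, not part of the patch)
import torch
from transformers import AdamW, get_linear_schedule_with_warmup

param = torch.nn.Parameter(torch.zeros(1))
optimizer = AdamW([param], lr=10.0)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=2, num_training_steps=10)

lrs = []
for _ in range(10):
    optimizer.step()
    scheduler.step()  # multiplier = step / num_warmup_steps during warmup, then linear decay to 0
    lrs.append(optimizer.param_groups[0]["lr"])
print(lrs)  # [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]

Swapping in `get_cosine_schedule_with_warmup` with the same arguments yields the cosine-shaped decay instead of the linear one.
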
diff --git a/transformers/tests/optimization_test.py b/transformers/tests/optimization_test.py
index 84dbaca52a9c..ab9afbfcf72f 100644
--- a/transformers/tests/optimization_test.py
+++ b/transformers/tests/optimization_test.py
@@ -25,8 +25,12 @@
 if is_torch_available():
     import torch

-    from transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
-                              WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)
+    from transformers import (AdamW,
+                              get_constant_schedule,
+                              get_constant_schedule_with_warmup,
+                              get_cosine_schedule_with_warmup,
+                              get_cosine_with_hard_restarts_schedule_with_warmup,
+                              get_linear_schedule_with_warmup)
 else:
     pytestmark = pytest.mark.skip("Require Torch")

@@ -87,59 +91,60 @@ def assertListAlmostEqual(self, list1, list2, tol):
             self.assertAlmostEqual(a, b, delta=tol)

     def test_constant_scheduler(self):
-        scheduler = ConstantLRSchedule(self.optimizer)
+        scheduler = get_constant_schedule(self.optimizer)
         lrs = unwrap_schedule(scheduler, self.num_steps)
         expected_learning_rates = [10.] * self.num_steps
         self.assertEqual(len(lrs[0]), 1)
         self.assertListEqual([l[0] for l in lrs], expected_learning_rates)

-        scheduler = ConstantLRSchedule(self.optimizer)
+        scheduler = get_constant_schedule(self.optimizer)
         lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
         self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])

     def test_warmup_constant_scheduler(self):
-        scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
+        scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
         lrs = unwrap_schedule(scheduler, self.num_steps)
         expected_learning_rates = [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
         self.assertEqual(len(lrs[0]), 1)
         self.assertListEqual([l[0] for l in lrs], expected_learning_rates)

-        scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
+        scheduler = get_constant_schedule_with_warmup(self.optimizer, num_warmup_steps=4)
         lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
         self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])

     def test_warmup_linear_scheduler(self):
-        scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
+        scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
         lrs = unwrap_schedule(scheduler, self.num_steps)
         expected_learning_rates = [5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25, 0.0]
         self.assertEqual(len(lrs[0]), 1)
         self.assertListEqual([l[0] for l in lrs], expected_learning_rates)

-        scheduler = WarmupLinearSchedule(self.optimizer, warmup_steps=2, t_total=10)
+        scheduler = get_linear_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
         lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
         self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])

     def test_warmup_cosine_scheduler(self):
-        scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
+        scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
         lrs = unwrap_schedule(scheduler, self.num_steps)
         expected_learning_rates = [5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38, 0.0]
         self.assertEqual(len(lrs[0]), 1)
         self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)

-        scheduler = WarmupCosineSchedule(self.optimizer, warmup_steps=2, t_total=10)
+        scheduler = get_cosine_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_training_steps=10)
         lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
         self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])

     def test_warmup_cosine_hard_restart_scheduler(self):
-        scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
+        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10)
         lrs = unwrap_schedule(scheduler, self.num_steps)
         expected_learning_rates = [5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46, 0.0]
         self.assertEqual(len(lrs[0]), 1)
         self.assertListAlmostEqual([l[0] for l in lrs], expected_learning_rates, tol=1e-2)

-        scheduler = WarmupCosineWithHardRestartsSchedule(self.optimizer, warmup_steps=2, cycles=2, t_total=10)
+        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(self.optimizer, num_warmup_steps=2, num_cycles=2, num_training_steps=10)
         lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
         self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
__name__ == "__main__": unittest.main()