From 99d2b37d24d005b43ed9745fd6fdb6a6c26911cf Mon Sep 17 00:00:00 2001 From: Conglong Li Date: Thu, 7 Oct 2021 00:56:34 -0700 Subject: [PATCH 01/19] CL initial commit --- .../curriculum_learning/ds_config_cl.json | 37 +++++++ .../curriculum_learning/pretrain_gpt_cl.sh | 104 ++++++++++++++++++ megatron/arguments.py | 7 ++ megatron/checkpointing.py | 2 + megatron/learning_rates.py | 46 ++++++-- megatron/model/gpt_model.py | 16 ++- megatron/training.py | 44 +++++++- pretrain_gpt.py | 2 + 8 files changed, 243 insertions(+), 15 deletions(-) create mode 100644 examples/curriculum_learning/ds_config_cl.json create mode 100644 examples/curriculum_learning/pretrain_gpt_cl.sh diff --git a/examples/curriculum_learning/ds_config_cl.json b/examples/curriculum_learning/ds_config_cl.json new file mode 100644 index 000000000..05eb441b2 --- /dev/null +++ b/examples/curriculum_learning/ds_config_cl.json @@ -0,0 +1,37 @@ +{ + "train_batch_size": 512, + "gradient_accumulation_steps": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": 0 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015, + "max_grad_norm": 1.0, + "betas": [0.9, 0.95] + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "wall_clock_breakdown": false, + "zero_allow_untested_optimizer": false, + "curriculum_learning": { + "enabled": true, + "curriculum_type": "seqlen", + "min_difficulty": 8, + "max_difficulty": 1024, + "schedule_type": "fixed_linear", + "schedule_config": { + "total_curriculum_step": 60000, + "difficulty_step": 8 + } + } +} diff --git a/examples/curriculum_learning/pretrain_gpt_cl.sh b/examples/curriculum_learning/pretrain_gpt_cl.sh new file mode 100644 index 000000000..c11067b63 --- /dev/null +++ b/examples/curriculum_learning/pretrain_gpt_cl.sh @@ -0,0 +1,104 @@ +#!/bin/bash +sudo pip install pybind11 + +# This is a dummy train script to show how to use curriculum +# learning, some parameters are not for actual GPT pretraining. 
+ +############################################################ +# New configs for curriculum learning, see README.md +TRAIN_TOKENS=10000000000 +LR_DECAY_TOKENS=10000000000 +############################################################ + +TARGET_GLOBAL_BATCH_SIZE=512 +TRAIN_SAMPLES=146484375 +LR=1.0e-4 +MIN_LR=1.0e-5 +LR_DECAY_SAMPLES=126953125 +LR_WARMUP_SAMPLES=183105 + +LOG_INTERVAL=100 +EVAL_ITERS=10 +EVAL_INTERVAL=100 +SAVE_INTERVAL=1000 + +VOCAB_PATH=/data/Megatron-LM/data/gpt2-vocab.json +MERGE_PATH=/data/Megatron-LM/data/gpt2-merges.txt +DATA_PATH=/data/Megatron-LM/data/indexed_datasets/megatron + +MICRO_BATCH_SIZE=1 +MP_SIZE=1 +PP_SIZE=1 + +NUM_GPUS=128 +echo ${NUM_GPUS} +if [[ $PP_SIZE -gt 0 ]]; then + DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) )) +else + DP_SIZE=$(( ${NUM_GPUS} / ${MP_SIZE} )) +fi +GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${MICRO_BATCH_SIZE} * ${DP_SIZE}) )) + +NAME="gpt-117M-pp${PP_SIZE}-mp${MP_SIZE}-bsz${TARGET_GLOBAL_BATCH_SIZE}-mbsz${MICRO_BATCH_SIZE}-cl" +current_time=$(date "+%Y.%m.%d-%H.%M.%S") +host="${HOSTNAME}" +TENSORBOARD_DIR="tensorboard/${NAME}_${host}_${current_time}" +mkdir -p ${TENSORBOARD_DIR} +CHECKPOINT_PATH="checkpoints/${NAME}" + +megatron_options=" \ + --data-path ${DATA_PATH} \ + --vocab-file ${VOCAB_PATH} \ + --merge-file ${MERGE_PATH} \ + --data-impl mmap \ + --override-lr-scheduler \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --tensor-model-parallel-size ${MP_SIZE} \ + --init-method-std 0.014 \ + --lr-decay-samples ${LR_DECAY_SAMPLES} \ + --lr-decay-tokens ${LR_DECAY_TOKENS} \ + --lr-warmup-samples ${LR_WARMUP_SAMPLES} \ + --micro-batch-size ${MICRO_BATCH_SIZE} \ + --global-batch-size ${TARGET_GLOBAL_BATCH_SIZE} \ + --num-layers 12 \ + --hidden-size 768 \ + --num-attention-heads 12 \ + --seq-length 1024 \ + --max-position-embeddings 1024 \ + --train-samples ${TRAIN_SAMPLES} \ + --train-tokens ${TRAIN_TOKENS} \ + --lr ${LR} \ + --min-lr ${MIN_LR} \ + --lr-decay-style cosine \ + --split 98,2,0 \ + --log-interval ${LOG_INTERVAL} \ + --eval-interval ${EVAL_INTERVAL} \ + --eval-iters ${EVAL_ITERS} \ + --save-interval ${SAVE_INTERVAL} \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --hysteresis 2 \ + --num-workers 0 \ + --checkpoint-activations \ + --fp16 \ + --load ${CHECKPOINT_PATH} \ + --save ${CHECKPOINT_PATH} \ + --tensorboard-queue-size 1 \ + --log-timers-to-tensorboard \ + --log-batch-size-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + --tensorboard-dir ${TENSORBOARD_DIR}" + +config_json="ds_config_cl.json" + +deepspeed_options=" \ + --deepspeed \ + --deepspeed_config ${config_json} \ + --pipeline-model-parallel-size ${PP_SIZE} \ + --partition-activations" + +run_cmd="deepspeed ../../pretrain_gpt.py ${megatron_options} ${deepspeed_options} &>> ${NAME}.log" +echo ${run_cmd} +eval ${run_cmd} +set +x \ No newline at end of file diff --git a/megatron/arguments.py b/megatron/arguments.py index 787cb62e9..998bc6580 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -170,6 +170,7 @@ def parse_args(extra_args_provider=None, defaults={}, # Consumed tokens. args.consumed_train_samples = 0 args.consumed_valid_samples = 0 + args.consumed_train_tokens = 0 args.gigaflos_no_embeds = 0 # Iteration-based training. @@ -428,6 +429,9 @@ def _add_training_args(parser): help='Total number of samples to train over all ' 'training runs. 
Note that either train-iters or ' 'train-samples should be provided.') + group.add_argument('--train-tokens', type=int, default=None, + help='Total number of tokens to train over all ' + 'training runs.') group.add_argument('--log-interval', type=int, default=100, help='Report loss and timing interval.') group.add_argument('--exit-interval', type=int, default=None, @@ -495,6 +499,9 @@ def _add_learning_rate_args(parser): group.add_argument('--lr-decay-samples', type=int, default=None, help='number of samples to decay learning rate over,' ' If None defaults to `--train-samples`') + group.add_argument('--lr-decay-tokens', type=int, default=None, + help='number of tokens to decay learning rate over,' + ' If not None will override iter/sample-based decay') group.add_argument('--lr-warmup-fraction', type=float, default=None, help='fraction of lr-warmup-(iters/samples) to use ' 'for warmup (as a float)') diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index e72d2ede6..a24fd8947 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -127,6 +127,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler): state_dict['args'] = args state_dict['checkpoint_version'] = 3.0 state_dict['iteration'] = iteration + state_dict['tokens'] = args.consumed_train_tokens # DeepSpeed saves the model/optimizer/scheduler if not args.deepspeed: @@ -339,6 +340,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True else: try: iteration = state_dict['iteration'] + args.consumed_train_tokens = state_dict['tokens'] except KeyError: try: # Backward compatible with older checkpoints iteration = state_dict['total_iters'] diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py index d200bdb17..2f8f3bbd4 100644 --- a/megatron/learning_rates.py +++ b/megatron/learning_rates.py @@ -17,7 +17,7 @@ import math -from megatron import print_rank_0 +from megatron import print_rank_0, get_args class AnnealingLR(object): """Anneals the learning rate.""" @@ -26,7 +26,7 @@ def __init__(self, optimizer, max_lr, min_lr, warmup_steps, decay_steps, decay_style, use_checkpoint_lr_scheduler=True, override_lr_scheduler=False): - + args = get_args() # Class values. self.optimizer = optimizer @@ -41,6 +41,10 @@ def __init__(self, optimizer, max_lr, min_lr, assert self.decay_steps > 0 assert self.warmup_steps < self.decay_steps + self.decay_tokens = args.lr_decay_tokens + self.num_tokens = 0 + self.warmup_tokens = 0 + self.decay_style = decay_style self.override_lr_scheduler = override_lr_scheduler @@ -61,6 +65,9 @@ def get_lr(self): # Use linear warmup for the initial part. if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps: + if self.num_steps == self.warmup_steps and \ + self.decay_tokens is not None: + self.warmup_tokens = self.num_tokens return self.max_lr * float(self.num_steps) / \ float(self.warmup_steps) @@ -68,14 +75,21 @@ def get_lr(self): if self.decay_style == 'constant': return self.max_lr - # For any steps larger than `self.decay_steps`, use `self.min_lr`. - if self.num_steps > self.decay_steps: - return self.min_lr - - # If we are done with the warmup period, use the decay style. - num_steps_ = self.num_steps - self.warmup_steps - decay_steps_ = self.decay_steps - self.warmup_steps - decay_ratio = float(num_steps_) / float(decay_steps_) + if self.decay_tokens is None: + # For any steps larger than `self.decay_steps`, use `self.min_lr`. 
+ if self.num_steps > self.decay_steps: + return self.min_lr + + # If we are done with the warmup period, use the decay style. + num_steps_ = self.num_steps - self.warmup_steps + decay_steps_ = self.decay_steps - self.warmup_steps + decay_ratio = float(num_steps_) / float(decay_steps_) + else: + if self.num_tokens > self.decay_tokens: + return self.min_lr + num_tokens_ = self.num_tokens - self.warmup_tokens + decay_tokens_ = self.decay_tokens - self.warmup_tokens + decay_ratio = float(num_tokens_) / float(decay_tokens_) assert decay_ratio >= 0.0 assert decay_ratio <= 1.0 delta_lr = self.max_lr - self.min_lr @@ -91,8 +105,12 @@ def get_lr(self): return self.min_lr + coeff * delta_lr - def step(self, increment): + def step(self, increment, token_num=None): """Set lr for all parameters groups.""" + if token_num is None: + args = get_args() + token_num = args.consumed_train_tokens + self.num_tokens = token_num self.num_steps += increment new_lr = self.get_lr() for group in self.optimizer.param_groups: @@ -104,6 +122,8 @@ def state_dict(self): 'max_lr': self.max_lr, 'warmup_steps': self.warmup_steps, 'num_steps': self.num_steps, + 'warmup_tokens': self.warmup_tokens, + 'num_tokens': self.num_tokens, 'decay_style': self.decay_style, 'decay_steps': self.decay_steps, 'min_lr': self.min_lr @@ -161,4 +181,6 @@ def load_state_dict(self, sd): num_steps = sd['num_iters'] else: num_steps = sd['num_steps'] - self.step(increment=num_steps) + self.warmup_tokens = sd['warmup_tokens'] + self.num_tokens = sd['num_tokens'] + self.step(num_steps, self.num_tokens) diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index f76cf1347..c8c50510a 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -102,8 +102,20 @@ def set_input_tensor(self, input_tensor): def forward(self, input_ids, position_ids, attention_mask, labels=None, tokentype_ids=None, layer_past=None, get_key_value=False, - forward_method_parallel_output=None): - + forward_method_parallel_output=None, curriculum_seqlen=None): + if curriculum_seqlen is not None: + args = get_args() + args.curriculum_seqlen = curriculum_seqlen + if curriculum_seqlen < input_ids.size()[1]: + # seqlen-based curriculum learning + # input_ids, position_ids, labels have size [batch size, seqlen] + input_ids = input_ids[:, :curriculum_seqlen].contiguous() + position_ids = position_ids[:, :curriculum_seqlen].contiguous() + labels = labels[:, :curriculum_seqlen].contiguous() + + # attention_mask has size [1, 1, seqlen, seqlen] + attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous() + lm_output = self.language_model( input_ids, position_ids, diff --git a/megatron/training.py b/megatron/training.py index b90dab018..a37a359ba 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -19,6 +19,7 @@ import math import sys import time +import json # The earliest we can measure the start time. _TRAIN_START_TIME = time.time() @@ -113,6 +114,14 @@ def pretrain(train_valid_test_dataset_provider, args = get_args() timers = get_timers() + args.curriculum_learning = False + if args.deepspeed: + args.deepspeed_configuration = json.load( + open(args.deepspeed_config, 'r', encoding='utf-8')) + if "curriculum_learning" in args.deepspeed_configuration: + if "enabled" in args.deepspeed_configuration["curriculum_learning"]: + args.curriculum_learning = args.deepspeed_configuration["curriculum_learning"]["enabled"] + # Model, optimizer, and learning rate. 
timers('model-and-optimizer-setup').start() model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider) @@ -557,10 +566,14 @@ def add_to_logging(name): is_last_rank(): writer.add_scalar('steps-vs-samples/y=steps,x=samples', iteration, args.consumed_train_samples) writer.add_scalar('steps-vs-samples/y=samples,x=steps', args.consumed_train_samples, iteration) + writer.add_scalar('steps-vs-tokens/y=steps,x=tokens', iteration, args.consumed_train_tokens) + writer.add_scalar('steps-vs-tokens/y=tokens,x=steps', args.consumed_train_tokens, iteration) if args.log_learning_rate_to_tensorboard: writer.add_scalar('learning-rate/learning-rate', learning_rate, iteration) writer.add_scalar('learning-rate/learning-rate vs samples', learning_rate, args.consumed_train_samples) + writer.add_scalar('learning-rate/learning-rate vs tokens', learning_rate, + args.consumed_train_tokens) if args.log_batch_size_to_tensorboard: writer.add_scalar('batch-size/batch-size', batch_size, iteration) writer.add_scalar('batch-size/batch-size vs samples', batch_size, @@ -569,24 +582,37 @@ def add_to_logging(name): writer.add_scalar(f"lm-loss-training/{key}", loss_dict[key], iteration) writer.add_scalar(f"lm-loss-training/{key}" + ' vs samples', loss_dict[key], args.consumed_train_samples) + writer.add_scalar(f"lm-loss-training/{key}" + ' vs tokens', loss_dict[key], + args.consumed_train_tokens) writer.add_scalar(f"lm-loss-training/{key}" + ' vs gigaflos (without embeddings)', loss_dict[key], args.gigaflos_no_embeds) if args.log_loss_scale_to_tensorboard: writer.add_scalar('loss-scale/loss-scale', loss_scale, iteration) writer.add_scalar('loss-scale/loss-scale vs samples', loss_scale, args.consumed_train_samples) + writer.add_scalar('loss-scale/loss-scale vs tokens', loss_scale, + args.consumed_train_tokens) if grad_norm is not None: writer.add_scalar('grad-norm/grad-norm', grad_norm, iteration) writer.add_scalar('grad-norm/grad-norm vs samples', grad_norm, args.consumed_train_samples) + writer.add_scalar('grad-norm/grad-norm vs tokens', grad_norm, + args.consumed_train_tokens) if num_zeros_in_grad is not None: writer.add_scalar('num-zeros/num-zeros', num_zeros_in_grad, iteration) writer.add_scalar('num-zeros/num-zeros vs samples', num_zeros_in_grad, args.consumed_train_samples) + writer.add_scalar('num-zeros/num-zeros vs tokens', num_zeros_in_grad, + args.consumed_train_tokens) if params_norm is not None: writer.add_scalar('params-norm/params-norm', params_norm, iteration) writer.add_scalar('params-norm/params-norm vs samples', params_norm, args.consumed_train_samples) + writer.add_scalar('params-norm/params-norm vs tokens', params_norm, + args.consumed_train_tokens) + if args.curriculum_learning: + writer.add_scalar('curriculum_seqlen', args.curriculum_seqlen, + iteration) if args.log_timers_to_tensorboard: timers.write(timers_to_log, writer, iteration, normalizer=total_iterations) @@ -601,10 +627,14 @@ def add_to_logging(name): elapsed_time_per_iteration, iteration) writer.add_scalar('iteration-time/iteration-time vs samples', elapsed_time_per_iteration, args.consumed_train_samples) + writer.add_scalar('iteration-time/iteration-time vs tokens', + elapsed_time_per_iteration, args.consumed_train_tokens) log_string = ' iteration {:8d}/{:8d} |'.format( iteration, args.train_iters) log_string += ' consumed samples: {:12d} |'.format( args.consumed_train_samples) + log_string += ' consumed tokens: {:12d} |'.format( + args.consumed_train_tokens) log_string += ' elapsed time per iteration (ms): {:.1f} |'.format( 
elapsed_time_per_iteration * 1000.0) log_string += ' learning rate: {:.3E} |'.format(learning_rate) @@ -624,6 +654,8 @@ def add_to_logging(name): log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad) if params_norm is not None: log_string += ' params norm: {:.3f} |'.format(params_norm) + if args.curriculum_learning: + log_string += ' curriculum seqlen: {:5d} |'.format(args.curriculum_seqlen) log_string += ' number of skipped iterations: {:3d} |'.format( total_loss_dict[skipped_iters_key]) log_string += ' number of nan iterations: {:3d} |'.format( @@ -678,7 +710,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler, timers('interval-time').start() print_datetime('before the start of training step') report_memory_flag = True - while iteration < args.train_iters: + while iteration < args.train_iters and (args.train_tokens is None or \ + args.consumed_train_tokens < args.train_tokens): update_num_microbatches(args.consumed_train_samples) if args.deepspeed: # inform deepspeed of any batch size changes @@ -699,6 +732,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler, args.micro_batch_size * \ get_num_microbatches() args.consumed_train_samples += new_samples + if args.curriculum_learning: + args.consumed_train_tokens += new_samples * args.curriculum_seqlen + else: + args.consumed_train_tokens += new_samples * args.seq_length args.gigaflos_no_embeds += (6 * new_samples * args.seq_length * get_parameters_in_billions(model, exclude_embeddings=True)) # Logging. @@ -841,6 +878,9 @@ def evaluate_and_print_results(prefix, forward_step_func, writer.add_scalar(f'lm-loss-validation/{key} validation vs samples', total_loss_dict[key].item(), args.consumed_train_samples) + writer.add_scalar(f'lm-loss-validation/{key} validation vs tokens', + total_loss_dict[key].item(), + args.consumed_train_tokens) writer.add_scalar(f'lm-loss-validation/{key} validation vs gigaflos (without embeddings)', total_loss_dict[key].item(), args.gigaflos_no_embeds) @@ -849,6 +889,8 @@ def evaluate_and_print_results(prefix, forward_step_func, iteration) writer.add_scalar(f'lm-loss-validation/{key} validation ppl vs samples', ppl, args.consumed_train_samples) + writer.add_scalar(f'lm-loss-validation/{key} validation ppl vs tokens', + ppl, args.consumed_train_tokens) writer.add_scalar(f'lm-loss-validation/{key} validation ppl vs gigaflos (without embeddings)', ppl, args.gigaflos_no_embeds) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 0137cad5e..c3efed8cb 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -174,6 +174,8 @@ def forward_step(data_iterator, model): output_tensor = model(tokens, position_ids, attention_mask, labels=labels) + if args.curriculum_learning and args.curriculum_seqlen < args.seq_length: + loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() return output_tensor, partial(loss_func, loss_mask) From 4c9c4a39a411313b738e642bd28237cfca54a17f Mon Sep 17 00:00:00 2001 From: Conglong Li Date: Thu, 7 Oct 2021 21:36:02 -0700 Subject: [PATCH 02/19] CL+PP support --- megatron/training.py | 15 ++++++++++++--- pretrain_gpt.py | 13 +++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/megatron/training.py b/megatron/training.py index a37a359ba..802a7b096 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -118,9 +118,17 @@ def pretrain(train_valid_test_dataset_provider, if args.deepspeed: args.deepspeed_configuration = json.load( open(args.deepspeed_config, 'r', encoding='utf-8')) - if "curriculum_learning" in 
args.deepspeed_configuration: - if "enabled" in args.deepspeed_configuration["curriculum_learning"]: - args.curriculum_learning = args.deepspeed_configuration["curriculum_learning"]["enabled"] + if "curriculum_learning" in args.deepspeed_configuration and \ + "enabled" in args.deepspeed_configuration["curriculum_learning"]: + args.curriculum_learning = args.deepspeed_configuration[ \ + "curriculum_learning"]["enabled"] + if args.curriculum_learning and \ + args.pipeline_model_parallel_size >= 1: + from deepspeed.runtime.data_pipeline.curriculum_scheduler \ + import CurriculumScheduler + args.curriculum_scheduler = CurriculumScheduler( \ + args.deepspeed_configuration["curriculum_learning"]) + # Model, optimizer, and learning rate. timers('model-and-optimizer-setup').start() @@ -728,6 +736,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler, optimizer, lr_scheduler) iteration += 1 + args.iteration = iteration new_samples = mpu.get_data_parallel_world_size() * \ args.micro_batch_size * \ get_num_microbatches() diff --git a/pretrain_gpt.py b/pretrain_gpt.py index c3efed8cb..e4f4e8f14 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -146,6 +146,19 @@ def get_batch_pipe(data): prefix_indices=None, loss_on_targets_only=args.loss_on_targets_only ) + if args.curriculum_learning: + args.curriculum_seqlen = args.curriculum_scheduler.update_difficulty( \ + args.iteration + 1) + if args.curriculum_seqlen < tokens.size()[1]: + # seqlen-based curriculum learning + # tokens, position_ids, labels, loss_mask have size [batch size, seqlen] + tokens = tokens[:, :args.curriculum_seqlen].contiguous() + position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() + labels = labels[:, :args.curriculum_seqlen].contiguous() + loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() + + # attention_mask has size [1, 1, seqlen, seqlen] + attention_mask = attention_mask[:, :, :args.curriculum_seqlen, :args.curriculum_seqlen].contiguous() return (tokens, position_ids, attention_mask), (labels, loss_mask) From 82a319843ec041d0965273d6d5fafb8b68f0ce74 Mon Sep 17 00:00:00 2001 From: Conglong Li Date: Fri, 8 Oct 2021 02:32:15 -0700 Subject: [PATCH 03/19] update --- examples/curriculum_learning/README.md | 26 +++++++++++++++++++ .../curriculum_learning/pretrain_gpt_cl.sh | 2 +- pretrain_gpt.py | 3 ++- 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 examples/curriculum_learning/README.md diff --git a/examples/curriculum_learning/README.md b/examples/curriculum_learning/README.md new file mode 100644 index 000000000..60aa1290f --- /dev/null +++ b/examples/curriculum_learning/README.md @@ -0,0 +1,26 @@ +This is a short tutorial of how to use/tune the curriculum learning (CL) integration. Currently it is only integrated for GPT pre-training. For technical details please refer our [paper](https://arxiv.org/abs/2108.06084). + +# Disable batch size warmup (--rampup-batch-size) +In our [paper](https://arxiv.org/abs/2108.06084) section 5.4 we demonstrate that curriculum learning (seqlen-based) provides much better training stability than the batch size warmup technique. So you shall remove the `--rampup-batch-size` config in your training script. It's not recommended to use both CL and batch size warmup, because both of them will reduce the number of tokens in a batch. Another related change you might want is to increase your micro batch size, since without batch size warmup your batch size will be fixed now. 
+ +# Token-based training termination + +Because CL changes length of each sequence/sample during training, it is very hard/impossible to use number of steps/samples to terminate the training exactly at the desired number of tokens. Thus we add a `--train-tokens` config as an alternative accurate token-based termination. We recommend increase your original `--train-samples` or `--train-iters` to a large enough number (e.g., 2X of what you used for baseline), and set `--train-tokens` at the exact desired number of training tokens (e.g., 300B for GPT-3 like training). + +# Token-based LR decay + +Again because CL changes token per batch, in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.2 we show that it is also necessary to change the LR decay to token-based (to avoid decaying LR too fast). Thus we add a `--lr-decay-tokens` which will be the number of LR decay tokens. If previously you are using `--lr-warmup-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the full seqlen (e.g. 2K for GPT-3) to it. If `--lr-decay-tokens` is given, it will override `--lr-warmup-samples` so you can keep both in script. For LR warmup we don't change it to token-based, because doing so for CL means slowing down the LR warmup which is both unnecessary and harmful. + +# Token-based tensorboard + +Because of the above changes, we also add token-based tensorboard scalars. We also add scalar that plot the seqlen at each step. + +# Curriculum learning hyperparameters tuning strategy + +The curriculum learning hyperparameters are all located in the deepspeed config json file (see the example `ds_config_cl.json` in this dir). There are three configs that you need to change, and two of them need some tuning. In our [paper](https://arxiv.org/abs/2108.06084) Appendix A.1 we have a more detailed tuning strategy description. + +First, the `max_difficulty` should be set as the full seqlen (i.e., your `--seq-length`). No need to tune this. + +Second, the `min_difficulty` is the beginning seqlen used by CL. In general smaller `min_difficulty` could provide better stability/convergence speed benefit. However we observe that for larger model or for different training data, starting from a very small seqlen could lead to significant validation PPL fluctuation (or even divergence) at the very beginning. We recommend to start with `min_difficulty` at 64, and then increase it if you observe problems at the very beginning. Note that to enable Tensor Core acceleration you should always use a multiple of 8. + +Third, the `total_curriculum_step` is the total number of steps used by CL. In general larger `total_curriculum_step` could provide better stability/convergence speed benefit. However we observe that a too large `total_curriculum_step` could lead to overfitting and significant validation PPL fluctuation (or even divergence) at the first few multiple of LR warmup steps. In our paper we have a detailed tuning strategy based on binary search. However, if you want to reduce the tuning effort we recommend directly setting `total_curriculum_step` as half of baseline's total number of steps. This may not provide the highest convergence speed benefit, but should provide enough training stability gain. 
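To make the schedule concrete, here is a rough sketch of how a seqlen-based `fixed_linear` curriculum could map a training step to a sequence length, using the fields from the example `ds_config_cl.json` in this directory. This is an illustration only, not DeepSpeed's `CurriculumScheduler` implementation; the exact rounding/clamping details there may differ.

```python
# Illustrative sketch of a seqlen-based "fixed_linear" schedule, parameterized by
# the fields in the example ds_config_cl.json. The rounding down to a multiple of
# `difficulty_step` is an assumption; DeepSpeed's CurriculumScheduler is authoritative.
def fixed_linear_seqlen(step, min_difficulty=8, max_difficulty=1024,
                        total_curriculum_step=60000, difficulty_step=8):
    if step >= total_curriculum_step:
        return max_difficulty
    # Linearly interpolate the difficulty (seqlen) over the curriculum window.
    seqlen = min_difficulty + (max_difficulty - min_difficulty) * step / total_curriculum_step
    # Keep the seqlen a multiple of difficulty_step (multiples of 8 enable Tensor Cores).
    return max(min_difficulty, int(seqlen // difficulty_step) * difficulty_step)

print(fixed_linear_seqlen(0))       # 8
print(fixed_linear_seqlen(30000))   # 512 -- roughly halfway through the curriculum
print(fixed_linear_seqlen(60000))   # 1024 -- full --seq-length from here on
```

With the example config, the seqlen thus grows from 8 to the full 1024 over the first 60000 steps, in increments of 8.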
diff --git a/examples/curriculum_learning/pretrain_gpt_cl.sh b/examples/curriculum_learning/pretrain_gpt_cl.sh index c11067b63..45d48a9b7 100644 --- a/examples/curriculum_learning/pretrain_gpt_cl.sh +++ b/examples/curriculum_learning/pretrain_gpt_cl.sh @@ -63,7 +63,7 @@ megatron_options=" \ --global-batch-size ${TARGET_GLOBAL_BATCH_SIZE} \ --num-layers 12 \ --hidden-size 768 \ - --num-attention-heads 12 \ + --num-attention-heads 16 \ --seq-length 1024 \ --max-position-embeddings 1024 \ --train-samples ${TRAIN_SAMPLES} \ diff --git a/pretrain_gpt.py b/pretrain_gpt.py index e4f4e8f14..7c0bb13bd 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -72,7 +72,7 @@ def model_provider(pre_process=True, post_process=True): # must be bool or the training crashes expecting bool, but getting Half args.attn_mask = attention_mask.to(torch.bool) - + args.attn_mask_original = attention_mask.to(torch.bool) else: model = GPTModel( num_tokentypes=0, @@ -159,6 +159,7 @@ def get_batch_pipe(data): # attention_mask has size [1, 1, seqlen, seqlen] attention_mask = attention_mask[:, :, :args.curriculum_seqlen, :args.curriculum_seqlen].contiguous() + args.attn_mask = args.attn_mask_original[:, :, :args.curriculum_seqlen, :args.curriculum_seqlen].contiguous() return (tokens, position_ids, attention_mask), (labels, loss_mask) From 21e91b97d69fbd02b4173c18c0370c1f9cb39f50 Mon Sep 17 00:00:00 2001 From: Conglong Li Date: Sat, 9 Oct 2021 11:31:47 -0700 Subject: [PATCH 04/19] Apply suggestions from code review Co-authored-by: Stas Bekman --- examples/curriculum_learning/README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/curriculum_learning/README.md b/examples/curriculum_learning/README.md index 60aa1290f..9859fe616 100644 --- a/examples/curriculum_learning/README.md +++ b/examples/curriculum_learning/README.md @@ -1,7 +1,7 @@ -This is a short tutorial of how to use/tune the curriculum learning (CL) integration. Currently it is only integrated for GPT pre-training. For technical details please refer our [paper](https://arxiv.org/abs/2108.06084). +This is a short tutorial of how to use/tune the curriculum learning (CL) integration. Currently it is only integrated for GPT pre-training. For technical details please refer to our [paper](https://arxiv.org/abs/2108.06084). # Disable batch size warmup (--rampup-batch-size) -In our [paper](https://arxiv.org/abs/2108.06084) section 5.4 we demonstrate that curriculum learning (seqlen-based) provides much better training stability than the batch size warmup technique. So you shall remove the `--rampup-batch-size` config in your training script. It's not recommended to use both CL and batch size warmup, because both of them will reduce the number of tokens in a batch. Another related change you might want is to increase your micro batch size, since without batch size warmup your batch size will be fixed now. +In our [paper](https://arxiv.org/abs/2108.06084) section 5.4 we demonstrate that curriculum learning (seqlen-based) provides much better training stability than the batch size warmup technique. So when using CL you need to remove the `--rampup-batch-size` config in your training script. It's not recommended to use both CL and batch size warmup, because both of them will reduce the number of tokens in a batch. Another related change you might want is to increase your micro batch size, since without batch size warmup your batch size will be fixed now. 
# Token-based training termination @@ -9,18 +9,18 @@ Because CL changes length of each sequence/sample during training, it is very ha # Token-based LR decay -Again because CL changes token per batch, in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.2 we show that it is also necessary to change the LR decay to token-based (to avoid decaying LR too fast). Thus we add a `--lr-decay-tokens` which will be the number of LR decay tokens. If previously you are using `--lr-warmup-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the full seqlen (e.g. 2K for GPT-3) to it. If `--lr-decay-tokens` is given, it will override `--lr-warmup-samples` so you can keep both in script. For LR warmup we don't change it to token-based, because doing so for CL means slowing down the LR warmup which is both unnecessary and harmful. +Again because CL changes the number of tokens per batch, in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.2 we show that it is also necessary to change the LR decay to token-based (to avoid decaying LR too fast). Thus we add a `--lr-decay-tokens` which will be the number of LR decay tokens. If previously you were using `--lr-warmup-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the former by full seqlen (e.g. 2K for GPT-3). If `--lr-decay-tokens` is given, it will override `--lr-warmup-samples` so you can keep both in the script. For LR warmup we don't change it to token-based, because doing so for CL means slowing down the LR warmup, which is both unnecessary and harmful. # Token-based tensorboard -Because of the above changes, we also add token-based tensorboard scalars. We also add scalar that plot the seqlen at each step. +Because of the above changes, we also add token-based tensorboard scalars. We also add scalars that plot the seqlen at each step. # Curriculum learning hyperparameters tuning strategy -The curriculum learning hyperparameters are all located in the deepspeed config json file (see the example `ds_config_cl.json` in this dir). There are three configs that you need to change, and two of them need some tuning. In our [paper](https://arxiv.org/abs/2108.06084) Appendix A.1 we have a more detailed tuning strategy description. +The curriculum learning hyperparameters are all located in the deepspeed config json file (see the example `ds_config_cl.json` in this dir). There are three config entries that you need to change, and two of which require some tuning. In our [paper](https://arxiv.org/abs/2108.06084) Appendix A.1 we have a more detailed tuning strategy description. First, the `max_difficulty` should be set as the full seqlen (i.e., your `--seq-length`). No need to tune this. -Second, the `min_difficulty` is the beginning seqlen used by CL. In general smaller `min_difficulty` could provide better stability/convergence speed benefit. However we observe that for larger model or for different training data, starting from a very small seqlen could lead to significant validation PPL fluctuation (or even divergence) at the very beginning. We recommend to start with `min_difficulty` at 64, and then increase it if you observe problems at the very beginning. Note that to enable Tensor Core acceleration you should always use a multiple of 8. +Second, the `min_difficulty` is the beginning seqlen used by CL. In general smaller `min_difficulty` could provide better stability/convergence speed benefit. 
However we observe that for a larger model or for different training data, starting from a very small seqlen could lead to significant validation PPL fluctuation (or even divergence) at the very beginning. We recommend to start with `min_difficulty` at 64, and then increase it if you observe problems at the very beginning. Note that to enable Tensor Core acceleration you should always use a multiple of 8. -Third, the `total_curriculum_step` is the total number of steps used by CL. In general larger `total_curriculum_step` could provide better stability/convergence speed benefit. However we observe that a too large `total_curriculum_step` could lead to overfitting and significant validation PPL fluctuation (or even divergence) at the first few multiple of LR warmup steps. In our paper we have a detailed tuning strategy based on binary search. However, if you want to reduce the tuning effort we recommend directly setting `total_curriculum_step` as half of baseline's total number of steps. This may not provide the highest convergence speed benefit, but should provide enough training stability gain. +Third, the `total_curriculum_step` is the total number of steps used by CL. In general larger `total_curriculum_step` could provide better stability/convergence speed benefit. However we observe that a too large `total_curriculum_step` could lead to overfitting and significant validation PPL fluctuation (or even divergence) at the first few multiple of LR warmup steps. In our paper we have a detailed tuning strategy based on binary search. However, if you want to reduce the tuning effort we recommend directly setting `total_curriculum_step` as half of baseline's total number of steps. This may not provide the highest convergence speed benefit, but should provide enough training stability gains. From 6010a3dd23ac50c300301dc5dd6d67f27152206b Mon Sep 17 00:00:00 2001 From: Conglong Li Date: Sat, 9 Oct 2021 11:34:41 -0700 Subject: [PATCH 05/19] apply code review comments --- examples/curriculum_learning/README.md | 2 +- .../curriculum_learning/pretrain_gpt_cl.sh | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/curriculum_learning/README.md b/examples/curriculum_learning/README.md index 9859fe616..43e533cbd 100644 --- a/examples/curriculum_learning/README.md +++ b/examples/curriculum_learning/README.md @@ -9,7 +9,7 @@ Because CL changes length of each sequence/sample during training, it is very ha # Token-based LR decay -Again because CL changes the number of tokens per batch, in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.2 we show that it is also necessary to change the LR decay to token-based (to avoid decaying LR too fast). Thus we add a `--lr-decay-tokens` which will be the number of LR decay tokens. If previously you were using `--lr-warmup-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the former by full seqlen (e.g. 2K for GPT-3). If `--lr-decay-tokens` is given, it will override `--lr-warmup-samples` so you can keep both in the script. For LR warmup we don't change it to token-based, because doing so for CL means slowing down the LR warmup, which is both unnecessary and harmful. +Again because CL changes the number of tokens per batch, in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.2 we show that it is also necessary to change the LR decay to token-based (to avoid decaying LR too fast). Thus we add a `--lr-decay-tokens` which will be the number of LR decay tokens. 
If previously you were using `--lr-decay-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the former by full seqlen (e.g. 2K for GPT-3). If `--lr-decay-tokens` is given, it will override `--lr-decay-samples` so you can keep both in the script. For LR warmup we don't change it to token-based, because doing so for CL means slowing down the LR warmup, which is both unnecessary and harmful. # Token-based tensorboard diff --git a/examples/curriculum_learning/pretrain_gpt_cl.sh b/examples/curriculum_learning/pretrain_gpt_cl.sh index 45d48a9b7..ce96ea1ad 100644 --- a/examples/curriculum_learning/pretrain_gpt_cl.sh +++ b/examples/curriculum_learning/pretrain_gpt_cl.sh @@ -1,21 +1,21 @@ #!/bin/bash -sudo pip install pybind11 # This is a dummy train script to show how to use curriculum # learning, some parameters are not for actual GPT pretraining. -############################################################ -# New configs for curriculum learning, see README.md -TRAIN_TOKENS=10000000000 -LR_DECAY_TOKENS=10000000000 -############################################################ - TARGET_GLOBAL_BATCH_SIZE=512 TRAIN_SAMPLES=146484375 LR=1.0e-4 MIN_LR=1.0e-5 LR_DECAY_SAMPLES=126953125 LR_WARMUP_SAMPLES=183105 +SEQLEN=1024 + +############################################################ +# New configs for curriculum learning, see README.md +TRAIN_TOKENS=10000000000 +LR_DECAY_TOKENS=$(($LR_DECAY_SAMPLES*$SEQLEN)) +############################################################ LOG_INTERVAL=100 EVAL_ITERS=10 @@ -64,8 +64,8 @@ megatron_options=" \ --num-layers 12 \ --hidden-size 768 \ --num-attention-heads 16 \ - --seq-length 1024 \ - --max-position-embeddings 1024 \ + --seq-length ${SEQLEN} \ + --max-position-embeddings ${SEQLEN} \ --train-samples ${TRAIN_SAMPLES} \ --train-tokens ${TRAIN_TOKENS} \ --lr ${LR} \ From 405c7a69b98666bf7e7453169e17a57cb7e9dbf2 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 9 Oct 2021 11:44:36 -0700 Subject: [PATCH 06/19] make it easier to read large numbers --- .../curriculum_learning/pretrain_gpt_cl.sh | 12 +- tests/test_training.py | 183 ++++++++++++------ 2 files changed, 132 insertions(+), 63 deletions(-) diff --git a/examples/curriculum_learning/pretrain_gpt_cl.sh b/examples/curriculum_learning/pretrain_gpt_cl.sh index ce96ea1ad..aa84ecec3 100644 --- a/examples/curriculum_learning/pretrain_gpt_cl.sh +++ b/examples/curriculum_learning/pretrain_gpt_cl.sh @@ -4,16 +4,16 @@ # learning, some parameters are not for actual GPT pretraining. 
TARGET_GLOBAL_BATCH_SIZE=512 -TRAIN_SAMPLES=146484375 +TRAIN_SAMPLES=146_484_375 LR=1.0e-4 MIN_LR=1.0e-5 -LR_DECAY_SAMPLES=126953125 -LR_WARMUP_SAMPLES=183105 +LR_DECAY_SAMPLES=126_953_125 +LR_WARMUP_SAMPLES=183_105 SEQLEN=1024 ############################################################ # New configs for curriculum learning, see README.md -TRAIN_TOKENS=10000000000 +TRAIN_TOKENS=10_000_000_000 LR_DECAY_TOKENS=$(($LR_DECAY_SAMPLES*$SEQLEN)) ############################################################ @@ -43,7 +43,7 @@ NAME="gpt-117M-pp${PP_SIZE}-mp${MP_SIZE}-bsz${TARGET_GLOBAL_BATCH_SIZE}-mbsz${MI current_time=$(date "+%Y.%m.%d-%H.%M.%S") host="${HOSTNAME}" TENSORBOARD_DIR="tensorboard/${NAME}_${host}_${current_time}" -mkdir -p ${TENSORBOARD_DIR} +mkdir -p ${TENSORBOARD_DIR} CHECKPOINT_PATH="checkpoints/${NAME}" megatron_options=" \ @@ -101,4 +101,4 @@ deepspeed_options=" \ run_cmd="deepspeed ../../pretrain_gpt.py ${megatron_options} ${deepspeed_options} &>> ${NAME}.log" echo ${run_cmd} eval ${run_cmd} -set +x \ No newline at end of file +set +x diff --git a/tests/test_training.py b/tests/test_training.py index 6b3ea9189..753ca4357 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -19,6 +19,7 @@ import glob import unittest from pathlib import Path +from parameterized import parameterized from megatron.testing_utils import ( CaptureStdout, @@ -71,70 +72,138 @@ def setUp(self): if os.path.exists(meg_lock_file_path): os.unlink(meg_lock_file_path) - - def test_training_all(self): - # all in one test - src_dir = self.src_dir + def get_variation_config(self, variation, output_dir): data_dir = f"{self.data_dir}/gpt2" - output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) pp_size, tp_size, dp_size = get_3d_dimensions() num_gpus = pp_size * tp_size * dp_size n_samples = 200 # about 37 iterations exit_interval = 20 # some samples in the first half and then some more in the 2nd half after resume - args = f""" - --tensor-model-parallel-size {tp_size} - --pipeline-model-parallel-size {pp_size} - --distributed-backend nccl - - --num-layers 2 - --hidden-size 64 - --num-attention-heads 2 - --seq-length 128 - --max-position-embeddings 1024 - --micro-batch-size 1 - --rampup-batch-size 2 2 {n_samples} - --global-batch-size 16 - --train-samples {n_samples} - - --optimizer adam - --adam-beta1 0.9 - --adam-beta2 0.95 - --adam-eps 1e-8 - --lr 1e-4 - --lr-warmup-samples 5 - --clip-grad 1.0 - --weight-decay 1e-1 - --fp16 - - --log-interval 5 - --save-interval 10 - --eval-interval 10 - --eval-iters 5 - --checkpoint-activations - --glu-activation geglu - --exit-interval {exit_interval} - - --merge-file {data_dir}/gpt2-tiny-merges.txt - --vocab-file {data_dir}/gpt2-tiny-vocab.json - --save {output_dir}/checkpoints - --load {output_dir}/checkpoints - --data-path {data_dir}/meg-gpt2-openwebtext_text_document - --codecarbon-dir {output_dir}/codecarbon - --tensorboard-dir {output_dir}/tensorboard - --tensorboard-queue-size 5 - --log-timers-to-tensorboard - --log-batch-size-to-tensorboard - --log-validation-ppl-to-tensorboard - """.split() - - ds_args = f""" - --deepspeed - --deepspeed_config {self.test_file_dir_str}/ds_config.json - --zero-stage 1 - --deepspeed-activation-checkpointing - """.split() + + if variation == "base": + # XXX: refactor repeated elements + args = f""" + --tensor-model-parallel-size {tp_size} + --pipeline-model-parallel-size {pp_size} + --distributed-backend nccl + + --num-layers 2 + --hidden-size 64 + --num-attention-heads 2 + --seq-length 128 + 
--max-position-embeddings 1024 + --micro-batch-size 1 + --rampup-batch-size 2 2 {n_samples} + --global-batch-size 16 + --train-samples {n_samples} + + --optimizer adam + --adam-beta1 0.9 + --adam-beta2 0.95 + --adam-eps 1e-8 + --lr 1e-4 + --lr-warmup-samples 5 + --clip-grad 1.0 + --weight-decay 1e-1 + --fp16 + + --log-interval 5 + --save-interval 10 + --eval-interval 10 + --eval-iters 5 + --checkpoint-activations + --glu-activation geglu + --exit-interval {exit_interval} + + --merge-file {data_dir}/gpt2-tiny-merges.txt + --vocab-file {data_dir}/gpt2-tiny-vocab.json + --save {output_dir}/checkpoints + --load {output_dir}/checkpoints + --data-path {data_dir}/meg-gpt2-openwebtext_text_document + --codecarbon-dir {output_dir}/codecarbon + --tensorboard-dir {output_dir}/tensorboard + --tensorboard-queue-size 5 + --log-timers-to-tensorboard + --log-batch-size-to-tensorboard + --log-validation-ppl-to-tensorboard + """.split() + + ds_args = f""" + --deepspeed + --deepspeed_config {self.test_file_dir_str}/ds_config.json + --zero-stage 1 + --deepspeed-activation-checkpointing + """.split() + + elif variation == "cl": + args = f""" + --tensor-model-parallel-size {tp_size} + --pipeline-model-parallel-size {pp_size} + --distributed-backend nccl + + --num-layers 2 + --hidden-size 64 + --num-attention-heads 2 + --seq-length 128 + --max-position-embeddings 1024 + --micro-batch-size 1 + --rampup-batch-size 2 2 {n_samples} + --global-batch-size 16 + --train-samples {n_samples} + + --optimizer adam + --adam-beta1 0.9 + --adam-beta2 0.95 + --adam-eps 1e-8 + --lr 1e-4 + --lr-warmup-samples 5 + --clip-grad 1.0 + --weight-decay 1e-1 + --fp16 + + --log-interval 5 + --save-interval 10 + --eval-interval 10 + --eval-iters 5 + --checkpoint-activations + --glu-activation geglu + --exit-interval {exit_interval} + + --merge-file {data_dir}/gpt2-tiny-merges.txt + --vocab-file {data_dir}/gpt2-tiny-vocab.json + --save {output_dir}/checkpoints + --load {output_dir}/checkpoints + --data-path {data_dir}/meg-gpt2-openwebtext_text_document + --codecarbon-dir {output_dir}/codecarbon + --tensorboard-dir {output_dir}/tensorboard + --tensorboard-queue-size 5 + --log-timers-to-tensorboard + --log-batch-size-to-tensorboard + --log-validation-ppl-to-tensorboard + """.split() + + ds_args = f""" + --deepspeed + --deepspeed_config {self.test_file_dir_str}/ds_config_cl.json + --zero-stage 1 + --deepspeed-activation-checkpointing + """.split() + + + else: + raise ValueError(f"Don't know of variation {variation}") + + return args, ds_args, num_gpus + + + @parameterized.expand(["base", "cl"]) + def test_training_all(self, variation): + # all in one test + src_dir = self.src_dir + output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) + + args, ds_args, num_gpus = self.get_variation_config(variation, output_dir) script = [f"{src_dir}/pretrain_gpt.py"] launcher = get_launcher(num_gpus) From a90d30eba0cb3bedcc5e73a72c4c1070b79470c6 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 9 Oct 2021 11:57:40 -0700 Subject: [PATCH 07/19] add a cl test --- tests/ds_config_cl.json | 29 +++++++++++++++++++++++++++++ tests/test_training.py | 19 +++++++++++++++---- 2 files changed, 44 insertions(+), 4 deletions(-) create mode 100644 tests/ds_config_cl.json diff --git a/tests/ds_config_cl.json b/tests/ds_config_cl.json new file mode 100644 index 000000000..58e957d7b --- /dev/null +++ b/tests/ds_config_cl.json @@ -0,0 +1,29 @@ +{ + "train_micro_batch_size_per_gpu": 1, + "train_batch_size": 16, + "gradient_clipping": 1.0, + "zero_optimization": 
{ + "stage": 1 + }, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 12 + }, + "curriculum_learning": { + "enabled": true, + "curriculum_type": "seqlen", + "min_difficulty": 8, + "max_difficulty": 128, + "schedule_type": "fixed_linear", + "schedule_config": { + "total_curriculum_step": 25, + "difficulty_step": 2 + } + }, + "steps_per_print": 2000, + "wall_clock_breakdown": false +} diff --git a/tests/test_training.py b/tests/test_training.py index 753ca4357..000d57077 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -80,6 +80,8 @@ def get_variation_config(self, variation, output_dir): n_samples = 200 # about 37 iterations exit_interval = 20 # some samples in the first half and then some more in the 2nd half after resume + seq_len = 128 + if variation == "base": # XXX: refactor repeated elements @@ -91,7 +93,7 @@ def get_variation_config(self, variation, output_dir): --num-layers 2 --hidden-size 64 --num-attention-heads 2 - --seq-length 128 + --seq-length {seq_len} --max-position-embeddings 1024 --micro-batch-size 1 --rampup-batch-size 2 2 {n_samples} @@ -104,6 +106,8 @@ def get_variation_config(self, variation, output_dir): --adam-eps 1e-8 --lr 1e-4 --lr-warmup-samples 5 + --lr-decay-samples 5 + --lr-decay-tokens 5 --clip-grad 1.0 --weight-decay 1e-1 --fp16 @@ -137,6 +141,11 @@ def get_variation_config(self, variation, output_dir): """.split() elif variation == "cl": + # CurriculumLearning + + lr_decay_samples = 6 + lr_decay_tokens = lr_decay_samples * seq_len + args = f""" --tensor-model-parallel-size {tp_size} --pipeline-model-parallel-size {pp_size} @@ -145,12 +154,12 @@ def get_variation_config(self, variation, output_dir): --num-layers 2 --hidden-size 64 --num-attention-heads 2 - --seq-length 128 + --seq-length {seq_len} --max-position-embeddings 1024 --micro-batch-size 1 - --rampup-batch-size 2 2 {n_samples} --global-batch-size 16 - --train-samples {n_samples} + --train-samples {n_samples*2} + --train-tokens {n_samples} --optimizer adam --adam-beta1 0.9 @@ -158,6 +167,8 @@ def get_variation_config(self, variation, output_dir): --adam-eps 1e-8 --lr 1e-4 --lr-warmup-samples 5 + --lr-decay-samples {lr_decay_samples} + --lr-decay-tokens {lr_decay_tokens} --clip-grad 1.0 --weight-decay 1e-1 --fp16 From fb04d2bed1eab9d0c2439be32161a35076401424 Mon Sep 17 00:00:00 2001 From: Conglong Li Date: Sat, 9 Oct 2021 12:01:36 -0700 Subject: [PATCH 08/19] apply review comments --- examples/curriculum_learning/README.md | 2 +- examples/curriculum_learning/pretrain_gpt_cl.sh | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/curriculum_learning/README.md b/examples/curriculum_learning/README.md index 43e533cbd..adf9a831d 100644 --- a/examples/curriculum_learning/README.md +++ b/examples/curriculum_learning/README.md @@ -9,7 +9,7 @@ Because CL changes length of each sequence/sample during training, it is very ha # Token-based LR decay -Again because CL changes the number of tokens per batch, in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.2 we show that it is also necessary to change the LR decay to token-based (to avoid decaying LR too fast). Thus we add a `--lr-decay-tokens` which will be the number of LR decay tokens. If previously you were using `--lr-decay-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the former by full seqlen (e.g. 2K for GPT-3). 
If `--lr-decay-tokens` is given, it will override `--lr-decay-samples` so you can keep both in the script. For LR warmup we don't change it to token-based, because doing so for CL means slowing down the LR warmup, which is both unnecessary and harmful. +Again because CL changes the number of tokens per batch, in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.2 we show that it is also necessary to change the LR decay to token-based (to avoid decaying LR too fast). Thus we add a `--lr-decay-tokens` which will be the number of LR decay tokens. If previously you were using `--lr-decay-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the former by full seqlen (e.g. 2K for GPT-3). The you need to replace `--lr-decay-samples` with `--lr-decay-tokens` in your script. For LR warmup we don't change it to token-based, because doing so for CL means slowing down the LR warmup, which is both unnecessary and harmful. # Token-based tensorboard diff --git a/examples/curriculum_learning/pretrain_gpt_cl.sh b/examples/curriculum_learning/pretrain_gpt_cl.sh index ce96ea1ad..727031366 100644 --- a/examples/curriculum_learning/pretrain_gpt_cl.sh +++ b/examples/curriculum_learning/pretrain_gpt_cl.sh @@ -56,7 +56,6 @@ megatron_options=" \ --adam-beta2 0.95 \ --tensor-model-parallel-size ${MP_SIZE} \ --init-method-std 0.014 \ - --lr-decay-samples ${LR_DECAY_SAMPLES} \ --lr-decay-tokens ${LR_DECAY_TOKENS} \ --lr-warmup-samples ${LR_WARMUP_SAMPLES} \ --micro-batch-size ${MICRO_BATCH_SIZE} \ From 8e4a46601ebe027185ff3d9d840dbcc7a95081fd Mon Sep 17 00:00:00 2001 From: Conglong Li Date: Sat, 9 Oct 2021 12:03:24 -0700 Subject: [PATCH 09/19] Update examples/curriculum_learning/README.md Co-authored-by: Stas Bekman --- examples/curriculum_learning/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/curriculum_learning/README.md b/examples/curriculum_learning/README.md index adf9a831d..3ff6033a5 100644 --- a/examples/curriculum_learning/README.md +++ b/examples/curriculum_learning/README.md @@ -9,7 +9,7 @@ Because CL changes length of each sequence/sample during training, it is very ha # Token-based LR decay -Again because CL changes the number of tokens per batch, in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.2 we show that it is also necessary to change the LR decay to token-based (to avoid decaying LR too fast). Thus we add a `--lr-decay-tokens` which will be the number of LR decay tokens. If previously you were using `--lr-decay-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the former by full seqlen (e.g. 2K for GPT-3). The you need to replace `--lr-decay-samples` with `--lr-decay-tokens` in your script. For LR warmup we don't change it to token-based, because doing so for CL means slowing down the LR warmup, which is both unnecessary and harmful. +Again because CL changes the number of tokens per batch, in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.2 we show that it is also necessary to change the LR decay to token-based (to avoid decaying LR too fast). Thus we add a `--lr-decay-tokens` which will be the number of LR decay tokens. If previously you were using `--lr-decay-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the former by full seqlen (e.g. 2K for GPT-3). Then you need to replace `--lr-decay-samples` with `--lr-decay-tokens` in your script. 
For LR warmup we don't change it to token-based, because doing so for CL means slowing down the LR warmup, which is both unnecessary and harmful. # Token-based tensorboard From d86a4f426323399a8ca05a68744253c69bb5c32d Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 9 Oct 2021 12:05:13 -0700 Subject: [PATCH 10/19] update --- tests/test_training.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_training.py b/tests/test_training.py index 000d57077..d73ad164a 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -167,7 +167,6 @@ def get_variation_config(self, variation, output_dir): --adam-eps 1e-8 --lr 1e-4 --lr-warmup-samples 5 - --lr-decay-samples {lr_decay_samples} --lr-decay-tokens {lr_decay_tokens} --clip-grad 1.0 --weight-decay 1e-1 From e5a335dd3594dae0f6f112efd2e7e3fabdc05c8e Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 9 Oct 2021 12:15:15 -0700 Subject: [PATCH 11/19] fix --- tests/test_training.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_training.py b/tests/test_training.py index d73ad164a..cfd4d6341 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -106,8 +106,7 @@ def get_variation_config(self, variation, output_dir): --adam-eps 1e-8 --lr 1e-4 --lr-warmup-samples 5 - --lr-decay-samples 5 - --lr-decay-tokens 5 + --lr-decay-samples 6 --clip-grad 1.0 --weight-decay 1e-1 --fp16 From 0c4073b62054df23dadda89195ee1f23c6df1eef Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 9 Oct 2021 12:17:00 -0700 Subject: [PATCH 12/19] new requirement --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index b8217cf72..6dc1385fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ datasets nltk numpy +parameterized pybind11 regex six From d25fa9e08d3f4f02a79b3e430860f43feb3d46c0 Mon Sep 17 00:00:00 2001 From: Conglong Li Date: Sat, 9 Oct 2021 13:55:55 -0700 Subject: [PATCH 13/19] Update megatron/learning_rates.py Co-authored-by: Stas Bekman --- megatron/learning_rates.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py index 2f8f3bbd4..74004b66b 100644 --- a/megatron/learning_rates.py +++ b/megatron/learning_rates.py @@ -76,6 +76,8 @@ def get_lr(self): return self.max_lr if self.decay_tokens is None: + # step-based decay + # For any steps larger than `self.decay_steps`, use `self.min_lr`. 
if self.num_steps > self.decay_steps: return self.min_lr From 7cd53dc0be48871bd5bef22ffd14a524e54670e0 Mon Sep 17 00:00:00 2001 From: Conglong Li Date: Sat, 9 Oct 2021 13:56:05 -0700 Subject: [PATCH 14/19] Update megatron/learning_rates.py Co-authored-by: Stas Bekman --- megatron/learning_rates.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py index 74004b66b..5435b60b4 100644 --- a/megatron/learning_rates.py +++ b/megatron/learning_rates.py @@ -87,6 +87,8 @@ def get_lr(self): decay_steps_ = self.decay_steps - self.warmup_steps decay_ratio = float(num_steps_) / float(decay_steps_) else: + # token-based decay + if self.num_tokens > self.decay_tokens: return self.min_lr num_tokens_ = self.num_tokens - self.warmup_tokens From 5a492b34597878df065a451db2e8d0a6460c0ae4 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sat, 9 Oct 2021 15:41:11 -0700 Subject: [PATCH 15/19] fix samples and tokens - thank you Conglong --- tests/test_training.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_training.py b/tests/test_training.py index cfd4d6341..4f1edd786 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -78,7 +78,7 @@ def get_variation_config(self, variation, output_dir): pp_size, tp_size, dp_size = get_3d_dimensions() num_gpus = pp_size * tp_size * dp_size - n_samples = 200 # about 37 iterations + n_samples = 300 # about 56 iterations exit_interval = 20 # some samples in the first half and then some more in the 2nd half after resume seq_len = 128 @@ -145,6 +145,14 @@ def get_variation_config(self, variation, output_dir): lr_decay_samples = 6 lr_decay_tokens = lr_decay_samples * seq_len + train_tokens = n_samples * seq_len + + # XXX: if changing seq_len from 128, must adjust ds config to: + # curriculum_learning.max_difficulty: $SEQLEN + + # XXX: probably we should write the ds config on the fly to keep everything in sync, + # rather than using the pre-saved config + args = f""" --tensor-model-parallel-size {tp_size} --pipeline-model-parallel-size {pp_size} @@ -158,7 +166,7 @@ def get_variation_config(self, variation, output_dir): --micro-batch-size 1 --global-batch-size 16 --train-samples {n_samples*2} - --train-tokens {n_samples} + --train-tokens {train_tokens} --optimizer adam --adam-beta1 0.9 From 8ca1db7f48c1a42746661e0a5e35c5ec9aa2d198 Mon Sep 17 00:00:00 2001 From: Conglong Li Date: Sat, 9 Oct 2021 15:55:58 -0700 Subject: [PATCH 16/19] fix truncation --- pretrain_gpt.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 7c0bb13bd..443910192 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -156,10 +156,11 @@ def get_batch_pipe(data): position_ids = position_ids[:, :args.curriculum_seqlen].contiguous() labels = labels[:, :args.curriculum_seqlen].contiguous() loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous() - + actual_seqlen = tokens.size()[1] + if actual_seqlen != args.attn_mask.size()[2]: # attention_mask has size [1, 1, seqlen, seqlen] - attention_mask = attention_mask[:, :, :args.curriculum_seqlen, :args.curriculum_seqlen].contiguous() - args.attn_mask = args.attn_mask_original[:, :, :args.curriculum_seqlen, :args.curriculum_seqlen].contiguous() + attention_mask = attention_mask[:, :, :actual_seqlen, :actual_seqlen].contiguous() + args.attn_mask = args.attn_mask_original[:, :, :actual_seqlen, :actual_seqlen].contiguous() return (tokens, position_ids, attention_mask), (labels, loss_mask) From 
From d7301a1bac9aed0e04497272d4ef8667ef6d6d0c Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Sat, 9 Oct 2021 16:33:49 -0700
Subject: [PATCH 17/19] switch to deepspeed@master

---
 requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements.txt b/requirements.txt
index 6dc1385fb..5e9bd3c16 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,6 +8,6 @@ six
 tensorboard
 torch
 transformers
-DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git@big-science
+DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git
 # edit to a higher SHA or future release if needed
 codecarbon @ git+https://github.com/mlco2/codecarbon.git@e6c3863

From dbf8abdd80f9cd6b8f6d295c877f013e392b65d6 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Sat, 9 Oct 2021 16:53:50 -0700
Subject: [PATCH 18/19] extend the doc

---
 examples/curriculum_learning/README.md | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/examples/curriculum_learning/README.md b/examples/curriculum_learning/README.md
index 3ff6033a5..2bf96139a 100644
--- a/examples/curriculum_learning/README.md
+++ b/examples/curriculum_learning/README.md
@@ -17,10 +17,12 @@ Because of the above changes, we also add token-based tensorboard scalars. We al
 
 # Curriculum learning hyperparameters tuning strategy
 
-The curriculum learning hyperparameters are all located in the deepspeed config json file (see the example `ds_config_cl.json` in this dir). There are three config entries that you need to change, and two of which require some tuning. In our [paper](https://arxiv.org/abs/2108.06084) Appendix A.1 we have a more detailed tuning strategy description.
+The curriculum learning hyperparameters are all located in the deepspeed config json file (see the example `ds_config_cl.json` in this dir). There are a few config entries that you may need to adjust to your circumstances, two of which require some tuning. In our [paper](https://arxiv.org/abs/2108.06084) Appendix A.1 we have a more detailed tuning strategy description.
 
-First, the `max_difficulty` should be set as the full seqlen (i.e., your `--seq-length`). No need to tune this.
+1. `max_difficulty` should be set to the full seqlen (i.e., your `--seq-length`). No need to tune this.
 
-Second, the `min_difficulty` is the beginning seqlen used by CL. In general smaller `min_difficulty` could provide better stability/convergence speed benefit. However we observe that for a larger model or for different training data, starting from a very small seqlen could lead to significant validation PPL fluctuation (or even divergence) at the very beginning. We recommend to start with `min_difficulty` at 64, and then increase it if you observe problems at the very beginning. Note that to enable Tensor Core acceleration you should always use a multiple of 8.
+2. `min_difficulty` is the starting seqlen used by CL. In general, a smaller `min_difficulty` can provide a larger stability/convergence-speed benefit. However, we observe that for a larger model or for different training data, starting from a very small seqlen can lead to significant validation PPL fluctuation (or even divergence) at the very beginning. We recommend starting with `min_difficulty` at 64 and increasing it if you observe problems at the very beginning. Note that to enable Tensor Core acceleration you should always use a multiple of 8.
 
-Third, the `total_curriculum_step` is the total number of steps used by CL. In general larger `total_curriculum_step` could provide better stability/convergence speed benefit. However we observe that a too large `total_curriculum_step` could lead to overfitting and significant validation PPL fluctuation (or even divergence) at the first few multiple of LR warmup steps. In our paper we have a detailed tuning strategy based on binary search. However, if you want to reduce the tuning effort we recommend directly setting `total_curriculum_step` as half of baseline's total number of steps. This may not provide the highest convergence speed benefit, but should provide enough training stability gains.
+3. `total_curriculum_step` is the total number of steps used by CL. In general, a larger `total_curriculum_step` can provide a larger stability/convergence-speed benefit. However, we observe that too large a `total_curriculum_step` can lead to overfitting and significant validation PPL fluctuation (or even divergence) during the first few multiples of the LR warmup steps. In our paper we describe a detailed tuning strategy based on binary search. However, if you want to reduce the tuning effort we recommend simply setting `total_curriculum_step` to half of the baseline's total number of steps. This may not provide the highest convergence speed benefit, but should provide enough training stability gains.
+
+4. `difficulty_step` is the change in seqlen per CL step. A smaller value is preferable since it gives a smoother curriculum and better stability. Like `min_difficulty`, it needs to be a multiple of 8 for Tensor Core acceleration, so 8 is a good default.
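
The four numbered entries in the README above map directly onto the `curriculum_learning` block of the DeepSpeed config. The helper below is illustrative only (it is not part of the repo), but the keys it emits mirror `ds_config_cl.json` and the values restate the tuning notes: the full `--seq-length` as `max_difficulty`, 64 as a safe `min_difficulty`, half of the baseline's total steps as `total_curriculum_step`, and 8 as `difficulty_step`.

```python
# A rough helper for picking starting values; the heuristics restate the README
# items above, and the keys mirror ds_config_cl.json. The function itself is
# an assumption for illustration, not part of this repo.
def suggest_curriculum_config(seq_length, baseline_total_steps, min_difficulty=64):
    def round_up_to_8(x):
        return ((x + 7) // 8) * 8  # Tensor Cores prefer multiples of 8
    return {
        "enabled": True,
        "curriculum_type": "seqlen",
        "min_difficulty": round_up_to_8(min_difficulty),   # item 2: start around 64
        "max_difficulty": seq_length,                       # item 1: the full --seq-length
        "schedule_type": "fixed_linear",
        "schedule_config": {
            "total_curriculum_step": baseline_total_steps // 2,  # item 3: half of the baseline steps
            "difficulty_step": 8,                                 # item 4: small and Tensor-Core friendly
        },
    }

# e.g. a --seq-length 1024 run planned for 120k steps
print(suggest_curriculum_config(seq_length=1024, baseline_total_steps=120_000))
```
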
From b7fd67edd6eb610c28e6af2b882f7837a14104d0 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Sat, 9 Oct 2021 19:43:29 -0700
Subject: [PATCH 19/19] Trigger CI