Curriculum learning support #132
Merged
Changes from all commits (20 commits):
99d2b37  CL initial commit (conglongli)
4c9c4a3  CL+PP support (conglongli)
82a3198  update (conglongli)
21e91b9  Apply suggestions from code review (conglongli)
6010a3d  apply code review comments (conglongli)
405c7a6  make it easier to read large numbers (stas00)
a90d30e  add a cl test (stas00)
fb04d2b  apply review comments (conglongli)
8e4a466  Update examples/curriculum_learning/README.md (conglongli)
3ed7075  Merge branch 'main' of https://github.com/conglongli/Megatron-DeepSpe… (stas00)
d86a4f4  update (stas00)
e5a335d  fix (stas00)
0c4073b  new requirement (stas00)
d25fa9e  Update megatron/learning_rates.py (conglongli)
7cd53dc  Update megatron/learning_rates.py (conglongli)
5a492b3  fix samples and tokens - thank you Conglong (stas00)
8ca1db7  fix truncation (conglongli)
d7301a1  switch to deepspeed@master (stas00)
dbf8abd  extend the doc (stas00)
b7fd67e  Trigger CI (stas00)
examples/curriculum_learning/README.md (new file, +28 lines):
This is a short tutorial on how to use/tune the curriculum learning (CL) integration. Currently it is only integrated for GPT pre-training. For technical details please refer to our [paper](https://arxiv.org/abs/2108.06084).

# Disable batch size warmup (--rampup-batch-size)
In our [paper](https://arxiv.org/abs/2108.06084), section 5.4, we demonstrate that curriculum learning (seqlen-based) provides much better training stability than the batch size warmup technique. So when using CL you need to remove the `--rampup-batch-size` config from your training script. Using both CL and batch size warmup is not recommended, because both of them reduce the number of tokens in a batch. Another related change you may want is to increase your micro batch size, since without batch size warmup your batch size is now fixed.

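To make the overlap concrete, here is a small back-of-the-envelope calculation of tokens per global batch, using the batch size and sequence lengths from the example script and config in this directory; the early-CL seqlen of 64 is just the starting value recommended in the tuning notes below, not a fixed setting:

```python
# Illustrative only: tokens per global batch with a short CL seqlen vs. full length.
global_batch_size = 512     # --global-batch-size in the example script
full_seqlen = 1024          # --seq-length in the example script
early_cl_seqlen = 64        # a typical CL starting seqlen (see min_difficulty advice below)

tokens_full = global_batch_size * full_seqlen        # 524,288 tokens per batch
tokens_early = global_batch_size * early_cl_seqlen   #  32,768 tokens per batch early on

print(f"full-length batch: {tokens_full:,} tokens")
print(f"early-CL batch:    {tokens_early:,} tokens")
```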
# Token-based training termination

Because CL changes the length of each sequence/sample during training, it is very hard (or impossible) to use the number of steps/samples to terminate training at exactly the desired number of tokens. Thus we add a `--train-tokens` config as an alternative, accurate token-based termination. We recommend increasing your original `--train-samples` or `--train-iters` to a large enough number (e.g., 2X of what you used for the baseline), and setting `--train-tokens` to the exact desired number of training tokens (e.g., 300B for GPT-3-like training).

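The logic this implies is simple: keep a running count of consumed tokens and stop as soon as the token budget is reached, while the step/sample limits act only as a generous upper bound. A minimal sketch, not the PR's actual implementation (the argument and variable names here are illustrative):

```python
# Minimal sketch of token-based termination, assuming a counter of consumed
# training tokens is updated every step. Names are illustrative, not Megatron's API.
def should_stop(consumed_train_tokens, consumed_train_samples, args):
    """Stop on whichever budget is hit first: tokens (if set) or samples."""
    if args.train_tokens is not None and consumed_train_tokens >= args.train_tokens:
        return True
    return consumed_train_samples >= args.train_samples

# During training the token counter grows by (global batch size x current CL seqlen)
# each step, so it advances more slowly while sequences are still short.
```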
# Token-based LR decay

Again, because CL changes the number of tokens per batch, in our [paper](https://arxiv.org/abs/2108.06084) Appendix A.2 we show that it is also necessary to make the LR decay token-based (to avoid decaying the LR too fast). Thus we add a `--lr-decay-tokens` config, which is the number of LR decay tokens. If you were previously using `--lr-decay-samples`, you can calculate your `--lr-decay-tokens` simply by multiplying the former by the full seqlen (e.g., 2K for GPT-3). Then replace `--lr-decay-samples` with `--lr-decay-tokens` in your script. We do not make the LR warmup token-based, because doing so under CL would slow down the LR warmup, which is both unnecessary and harmful.

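For a concrete conversion, using the numbers from the example script below (`SEQLEN=1024`, `LR_DECAY_SAMPLES=126_953_125`), this is the same arithmetic the script performs for `LR_DECAY_TOKENS`:

```python
# Converting an existing --lr-decay-samples value to --lr-decay-tokens,
# using the values from the example script in this directory.
lr_decay_samples = 126_953_125
full_seqlen = 1024
lr_decay_tokens = lr_decay_samples * full_seqlen
print(f"{lr_decay_tokens:,}")   # 130,000,000,000 -> pass as --lr-decay-tokens
```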
# Token-based tensorboard

Because of the above changes, we also add token-based tensorboard scalars. We also add scalars that plot the seqlen at each step.

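Token-based scalars simply use the consumed-token count, rather than the iteration number, as the x-axis value. A hedged sketch of what that looks like with a generic `torch.utils.tensorboard` writer; the tag names and variables are illustrative, not the exact ones this PR adds:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="tensorboard/example")

# Illustrative logging: the same loss plotted against iterations and against
# consumed tokens, plus the current CL seqlen at each step.
def log_step(writer, iteration, consumed_tokens, loss, current_seqlen):
    writer.add_scalar("lm-loss/vs-iterations", loss, iteration)
    writer.add_scalar("lm-loss/vs-tokens", loss, consumed_tokens)
    writer.add_scalar("curriculum/seqlen", current_seqlen, iteration)
```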
# Curriculum learning hyperparameters tuning strategy

The curriculum learning hyperparameters are all located in the DeepSpeed config JSON file (see the example `ds_config_cl.json` in this dir). There are a few config entries that you may need to adjust to your circumstances, two of which require some tuning. In our [paper](https://arxiv.org/abs/2108.06084) Appendix A.1 we give a more detailed description of the tuning strategy.

1. `max_difficulty` should be set to the full seqlen (i.e., your `--seq-length`). No need to tune this.

2. `min_difficulty` is the beginning seqlen used by CL. In general a smaller `min_difficulty` can provide better stability/convergence speed benefits. However, we observe that for a larger model or for different training data, starting from a very small seqlen can lead to significant validation PPL fluctuation (or even divergence) at the very beginning. We recommend starting with `min_difficulty` at 64, and then increasing it if you observe problems at the very beginning. Note that to enable Tensor Core acceleration you should always use a multiple of 8.

3. `total_curriculum_step` is the total number of steps used by CL. In general a larger `total_curriculum_step` can provide better stability/convergence speed benefits. However, we observe that a too-large `total_curriculum_step` can lead to overfitting and significant validation PPL fluctuation (or even divergence) during the first few multiples of the LR warmup steps. In our paper we describe a detailed tuning strategy based on binary search. However, if you want to reduce the tuning effort, we recommend directly setting `total_curriculum_step` to half of the baseline's total number of steps. This may not provide the highest convergence speed benefit, but it should provide enough training stability gains.

4. `difficulty_step` is the change in seqlen per CL step. A smaller value is preferable since it gives a smoother curriculum and better stability. Like `min_difficulty`, it needs to be a multiple of 8 for Tensor Core acceleration, so 8 is a good default.
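Putting these four entries together, the `fixed_linear` schedule grows the seqlen linearly from `min_difficulty` to `max_difficulty` over `total_curriculum_step` steps, kept to a multiple of `difficulty_step`. The sketch below is our reading of that schedule using the values from `ds_config_cl.json`, not DeepSpeed's exact implementation:

```python
# Hedged sketch of a fixed_linear seqlen schedule; values match ds_config_cl.json.
def cl_seqlen(step, min_difficulty=8, max_difficulty=1024,
              total_curriculum_step=60000, difficulty_step=8):
    if step >= total_curriculum_step:
        return max_difficulty
    # Linear interpolation from min to max over the curriculum window.
    seqlen = min_difficulty + (max_difficulty - min_difficulty) * step / total_curriculum_step
    # Round down to a multiple of difficulty_step, never below min_difficulty.
    seqlen = int(seqlen) // difficulty_step * difficulty_step
    return max(seqlen, min_difficulty)

for s in (0, 15000, 30000, 45000, 60000):
    print(s, cl_seqlen(s))   # 0 -> 8, 15000 -> 256, 30000 -> 512, 45000 -> 768, 60000 -> 1024
```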
examples/curriculum_learning/ds_config_cl.json (new file, +37 lines):
{
  "train_batch_size": 512,
  "gradient_accumulation_steps": 1,
  "steps_per_print": 1,
  "zero_optimization": {
    "stage": 0
  },
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.00015,
      "max_grad_norm": 1.0,
      "betas": [0.9, 0.95]
    }
  },
  "gradient_clipping": 1.0,
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "wall_clock_breakdown": false,
  "zero_allow_untested_optimizer": false,
  "curriculum_learning": {
    "enabled": true,
    "curriculum_type": "seqlen",
    "min_difficulty": 8,
    "max_difficulty": 1024,
    "schedule_type": "fixed_linear",
    "schedule_config": {
      "total_curriculum_step": 60000,
      "difficulty_step": 8
    }
  }
}
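As a small companion to the README's tuning notes, a hypothetical helper like the one below can sanity-check a CL config against those recommendations. It is not part of this PR; the checks simply restate the README guidance:

```python
import json

def check_cl_config(ds_config_path, seq_length):
    """Sanity-check curriculum learning settings against the README guidance."""
    with open(ds_config_path) as f:
        cl = json.load(f)["curriculum_learning"]
    sched = cl["schedule_config"]
    assert cl["max_difficulty"] == seq_length, "max_difficulty should equal --seq-length"
    assert cl["min_difficulty"] % 8 == 0, "min_difficulty should be a multiple of 8"
    assert sched["difficulty_step"] % 8 == 0, "difficulty_step should be a multiple of 8"
    print("curriculum_learning config looks consistent")

# Example: check_cl_config("ds_config_cl.json", seq_length=1024)
```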
Example pretraining script (new file, +103 lines):
#!/bin/bash

# This is a dummy train script to show how to use curriculum
# learning; some parameters are not for actual GPT pretraining.

TARGET_GLOBAL_BATCH_SIZE=512
TRAIN_SAMPLES=146_484_375
LR=1.0e-4
MIN_LR=1.0e-5
LR_DECAY_SAMPLES=126_953_125
LR_WARMUP_SAMPLES=183_105
SEQLEN=1024

############################################################
# New configs for curriculum learning, see README.md
TRAIN_TOKENS=10_000_000_000
# Strip the readability underscores before doing bash arithmetic,
# since $(( )) does not accept them.
LR_DECAY_TOKENS=$(( ${LR_DECAY_SAMPLES//_/} * ${SEQLEN} ))
############################################################

LOG_INTERVAL=100
EVAL_ITERS=10
EVAL_INTERVAL=100
SAVE_INTERVAL=1000

VOCAB_PATH=/data/Megatron-LM/data/gpt2-vocab.json
MERGE_PATH=/data/Megatron-LM/data/gpt2-merges.txt
DATA_PATH=/data/Megatron-LM/data/indexed_datasets/megatron

MICRO_BATCH_SIZE=1
MP_SIZE=1
PP_SIZE=1

NUM_GPUS=128
echo ${NUM_GPUS}
if [[ $PP_SIZE -gt 0 ]]; then
    DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) ))
else
    DP_SIZE=$(( ${NUM_GPUS} / ${MP_SIZE} ))
fi
GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${MICRO_BATCH_SIZE} * ${DP_SIZE}) ))

NAME="gpt-117M-pp${PP_SIZE}-mp${MP_SIZE}-bsz${TARGET_GLOBAL_BATCH_SIZE}-mbsz${MICRO_BATCH_SIZE}-cl"
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
host="${HOSTNAME}"
TENSORBOARD_DIR="tensorboard/${NAME}_${host}_${current_time}"
mkdir -p ${TENSORBOARD_DIR}
CHECKPOINT_PATH="checkpoints/${NAME}"

megatron_options=" \
    --data-path ${DATA_PATH} \
    --vocab-file ${VOCAB_PATH} \
    --merge-file ${MERGE_PATH} \
    --data-impl mmap \
    --override-lr-scheduler \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --tensor-model-parallel-size ${MP_SIZE} \
    --init-method-std 0.014 \
    --lr-decay-tokens ${LR_DECAY_TOKENS} \
    --lr-warmup-samples ${LR_WARMUP_SAMPLES} \
    --micro-batch-size ${MICRO_BATCH_SIZE} \
    --global-batch-size ${TARGET_GLOBAL_BATCH_SIZE} \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 16 \
    --seq-length ${SEQLEN} \
    --max-position-embeddings ${SEQLEN} \
    --train-samples ${TRAIN_SAMPLES} \
    --train-tokens ${TRAIN_TOKENS} \
    --lr ${LR} \
    --min-lr ${MIN_LR} \
    --lr-decay-style cosine \
    --split 98,2,0 \
    --log-interval ${LOG_INTERVAL} \
    --eval-interval ${EVAL_INTERVAL} \
    --eval-iters ${EVAL_ITERS} \
    --save-interval ${SAVE_INTERVAL} \
    --weight-decay 0.1 \
    --clip-grad 1.0 \
    --hysteresis 2 \
    --num-workers 0 \
    --checkpoint-activations \
    --fp16 \
    --load ${CHECKPOINT_PATH} \
    --save ${CHECKPOINT_PATH} \
    --tensorboard-queue-size 1 \
    --log-timers-to-tensorboard \
    --log-batch-size-to-tensorboard \
    --log-validation-ppl-to-tensorboard \
    --tensorboard-dir ${TENSORBOARD_DIR}"

config_json="ds_config_cl.json"

deepspeed_options=" \
    --deepspeed \
    --deepspeed_config ${config_json} \
    --pipeline-model-parallel-size ${PP_SIZE} \
    --partition-activations"

run_cmd="deepspeed ../../pretrain_gpt.py ${megatron_options} ${deepspeed_options} &>> ${NAME}.log"
echo ${run_cmd}
eval ${run_cmd}
set +x
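For reference, the batch-size bookkeeping in the script works out as follows; this just reproduces the script's own arithmetic in Python to make the derived values explicit, it introduces no new settings:

```python
# Reproducing the script's derived values with its defaults.
NUM_GPUS, PP_SIZE, MP_SIZE = 128, 1, 1
TARGET_GLOBAL_BATCH_SIZE, MICRO_BATCH_SIZE = 512, 1

dp_size = NUM_GPUS // (PP_SIZE * MP_SIZE)                                  # 128
grad_acc_steps = TARGET_GLOBAL_BATCH_SIZE // (MICRO_BATCH_SIZE * dp_size)  # 4
print(dp_size, grad_acc_steps)
```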