Megatron GPT model finetuning #6210
Merged: 87 commits, Apr 6, 2023
Changes from all commits

Commits (87)
b63dcee
copy from sft_from_gpt
soares-f Dec 19, 2022
d05d632
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2022
0cb5907
Changed tokenization and example
soares-f Dec 29, 2022
e57114c
Merge branch 'GPT_SFT' of https://github.com/soares-f/NeMo into GPT_SFT
soares-f Dec 29, 2022
0785902
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 30, 2022
8f11a14
maybe remove (got from upstream)
soares-f Jan 2, 2023
b2dd38d
merge and commit
MaximumEntropy Jan 12, 2023
e8f1924
Eval metrics while finetuning
MaximumEntropy Jan 12, 2023
b49d37b
Add missing args
MaximumEntropy Jan 13, 2023
2de9931
Add arg
MaximumEntropy Jan 13, 2023
7636372
Fix
MaximumEntropy Jan 13, 2023
6b30660
Fix
MaximumEntropy Jan 13, 2023
2e9ab6c
Wrap in try except
MaximumEntropy Jan 13, 2023
7f5eba1
Try fix
MaximumEntropy Jan 13, 2023
4387574
Fix
MaximumEntropy Jan 13, 2023
8bdeff4
Add separate validation and test batch sizes
MaximumEntropy Jan 13, 2023
983f6e3
Fix
MaximumEntropy Jan 13, 2023
78ab97f
Fix
MaximumEntropy Jan 13, 2023
6e19953
Fix
MaximumEntropy Jan 13, 2023
63c81fe
Add assert
MaximumEntropy Jan 13, 2023
63d6489
Fix
MaximumEntropy Jan 13, 2023
ed45634
Fix checkpoint name
MaximumEntropy Jan 14, 2023
19c1a1c
Explict sampling args
MaximumEntropy Jan 15, 2023
7fa203f
Update t0 script
MaximumEntropy Jan 18, 2023
1258436
Add niv2 script
MaximumEntropy Jan 18, 2023
3651097
Change workers
MaximumEntropy Jan 18, 2023
406f773
Merge branch 'main' of github.com:NVIDIA/NeMo into sandeepsub/gpt_sft
MaximumEntropy Jan 19, 2023
102c2a3
Fix labels
MaximumEntropy Jan 19, 2023
54b9a77
Merge branch 'main' of github.com:NVIDIA/NeMo into sandeepsub/gpt_sft
MaximumEntropy Jan 19, 2023
50f7160
Ignore download
MaximumEntropy Jan 20, 2023
6ba0d1e
Minor fixes
MaximumEntropy Jan 21, 2023
36ac0b1
Add dist opt support
MaximumEntropy Jan 21, 2023
c1395e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2023
f2512f3
Merge branch 'main' into sandeepsub/gpt_sft
MaximumEntropy Jan 29, 2023
2ce6986
Minor
MaximumEntropy Jan 31, 2023
a0188eb
Merge and fix
MaximumEntropy Jan 31, 2023
90f5065
Merge branch 'main' of github.com:NVIDIA/NeMo into sandeepsub/gpt_sft
MaximumEntropy Feb 1, 2023
c8087e9
Allow skipping validation
MaximumEntropy Feb 22, 2023
d39cd0a
Fix tokenization and padding to max batch
MaximumEntropy Feb 22, 2023
5e38f29
Adds several configurable flags for Megatron GPT models (#5991)
MaximumEntropy Feb 18, 2023
30db4fa
Fast glu activations (#6058)
MaximumEntropy Feb 23, 2023
1bf4e77
Explicitly check for united embeddings when logging params (#6085)
MaximumEntropy Feb 26, 2023
d87dbea
Option for model extracted dir
MaximumEntropy Mar 5, 2023
b000071
Fix
MaximumEntropy Mar 5, 2023
0c15b0c
Fix
MaximumEntropy Mar 5, 2023
edfc740
Add index mapping dir
MaximumEntropy Mar 5, 2023
13a354c
Assistant prompt
MaximumEntropy Mar 7, 2023
696106e
Fix
MaximumEntropy Mar 7, 2023
7d08a62
Remove ipdb
MaximumEntropy Mar 7, 2023
ed5984c
Fix
MaximumEntropy Mar 7, 2023
971a683
Override dropout
MaximumEntropy Mar 7, 2023
2903d00
Fix and merge
MaximumEntropy Mar 9, 2023
e135870
Change sampler
MaximumEntropy Mar 10, 2023
d0252ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2023
55e1892
Roll back again
MaximumEntropy Mar 15, 2023
31079e0
Revert TTS
MaximumEntropy Mar 15, 2023
373e2f9
Reset TTS
MaximumEntropy Mar 15, 2023
d43a5b1
Revert further
MaximumEntropy Mar 15, 2023
e18a670
Merge branch 'main' of github.com:NVIDIA/NeMo into sandeepsub/gpt_sft…
MaximumEntropy Mar 15, 2023
2b5d09f
Revert more to main
MaximumEntropy Mar 16, 2023
15f61e0
Merge branch 'main' into sandeepsub/gpt_sft_stable_rebase_main
MaximumEntropy Mar 16, 2023
0bfda75
Fix Test DS
MaximumEntropy Mar 20, 2023
49033a6
Address PR comments
MaximumEntropy Mar 21, 2023
d0a3393
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 21, 2023
04938ad
Merge branch 'main' into sandeepsub/gpt_sft_stable_rebase_main
MaximumEntropy Mar 31, 2023
efa57f5
Add the option to provide a prompt template via fstrings
MaximumEntropy Apr 3, 2023
7224f67
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2023
9a971d7
Add CI test
MaximumEntropy Apr 4, 2023
63a187f
Merge branch 'sandeepsub/gpt_sft_stable_rebase_main' of github.com:NV…
MaximumEntropy Apr 4, 2023
3fdf1d4
fix ci test
MaximumEntropy Apr 4, 2023
86efeee
Merge branch 'main' into sandeepsub/gpt_sft_stable_rebase_main
MaximumEntropy Apr 4, 2023
2f3efd2
Fix CI test
MaximumEntropy Apr 4, 2023
78513c6
Merge branch 'sandeepsub/gpt_sft_stable_rebase_main' of github.com:NV…
MaximumEntropy Apr 4, 2023
a06f6cc
Merge branch 'main' into sandeepsub/gpt_sft_stable_rebase_main
MaximumEntropy Apr 4, 2023
dea00db
Minor
MaximumEntropy Apr 4, 2023
1d96011
Merge branch 'sandeepsub/gpt_sft_stable_rebase_main' of github.com:NV…
MaximumEntropy Apr 4, 2023
d6d9837
Fix CI
MaximumEntropy Apr 5, 2023
7749ede
Fix CI
MaximumEntropy Apr 5, 2023
7fd0c85
Merge branch 'main' into sandeepsub/gpt_sft_stable_rebase_main
MaximumEntropy Apr 5, 2023
ad69891
Fix
MaximumEntropy Apr 5, 2023
5df8955
Merge branch 'sandeepsub/gpt_sft_stable_rebase_main' of github.com:NV…
MaximumEntropy Apr 5, 2023
791402f
Merge branch 'main' into sandeepsub/gpt_sft_stable_rebase_main
MaximumEntropy Apr 5, 2023
6c003b0
Fix CI
MaximumEntropy Apr 6, 2023
d99e276
Fix workers issue
MaximumEntropy Apr 6, 2023
9951a19
Merge branch 'main' into sandeepsub/gpt_sft_stable_rebase_main
MaximumEntropy Apr 6, 2023
6443e69
Fix workers
MaximumEntropy Apr 6, 2023
7062845
Merge branch 'main' into sandeepsub/gpt_sft_stable_rebase_main
MaximumEntropy Apr 6, 2023
68 changes: 68 additions & 0 deletions Jenkinsfile
@@ -3298,6 +3298,74 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings"
}
}
stage('L2: Megatron GPT Finetuning PP=2') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps {
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_sft.py \
trainer.devices=2 \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=2 \
+trainer.limit_val_batches=2 \
trainer.max_steps=3 \
trainer.precision=16 \
trainer.gradient_clip_val=1.0 \
exp_manager.exp_dir=examples/nlp/language_modeling/gpt_sft_results \
model.pipeline_model_parallel_size=2 \
model.tensor_model_parallel_size=1 \
model.restore_from_path=/home/TestData/nlp/megatron_gpt/PP2/gpt_pp2_tp1.nemo \
model.optim.name=fused_adam \
model.optim.lr=2e-4 \
model.data.train_ds.micro_batch_size=1 \
model.data.train_ds.global_batch_size=4 \
model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \
model.data.train_ds.concat_sampling_probabilities=[0.3,0.7] \
model.data.train_ds.num_workers=0 \
model.data.test_ds.micro_batch_size=1 \
model.data.test_ds.global_batch_size=4 \
model.data.test_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \
model.data.test_ds.names=[quarel,trec] \
model.data.validation_ds.micro_batch_size=1 \
model.data.validation_ds.global_batch_size=4 \
model.data.validation_ds.num_workers=0 \
model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \
model.data.validation_ds.names=[quarel,trec]"
sh "python examples/nlp/language_modeling/tuning/megatron_gpt_sft.py \
trainer.devices=2 \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=2 \
+trainer.limit_val_batches=2 \
trainer.max_steps=3 \
trainer.precision=16 \
trainer.gradient_clip_val=1.0 \
exp_manager.exp_dir=examples/nlp/language_modeling/gpt_sft_results \
model.pipeline_model_parallel_size=2 \
model.tensor_model_parallel_size=1 \
model.restore_from_path=/home/TestData/nlp/megatron_gpt/PP2/gpt_pp2_tp1.nemo \
model.optim.name=fused_adam \
model.optim.lr=2e-4 \
model.data.train_ds.micro_batch_size=1 \
model.data.train_ds.global_batch_size=4 \
model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \
model.data.train_ds.concat_sampling_probabilities=[0.3,0.7] \
model.data.train_ds.num_workers=0 \
model.data.test_ds.micro_batch_size=1 \
model.data.test_ds.global_batch_size=4 \
model.data.test_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \
model.data.test_ds.names=[quarel,trec] \
model.data.validation_ds.micro_batch_size=1 \
model.data.validation_ds.global_batch_size=4 \
model.data.validation_ds.num_workers=0 \
model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl,/home/TestData/nlp/megatron_sft/trec.jsonl] \
model.data.validation_ds.names=[quarel,trec]"
sh "rm -rf examples/nlp/language_modeling/gpt_sft_results"
}
}
stage('L2: Megatron GPT Eval') {
when {
anyOf {
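The two test commands above read /home/TestData/nlp/megatron_sft/quarel.jsonl and trec.jsonl, which are not part of this PR. Going by the dataset format documented in the new config below (context_key: 'input', label_key: 'output'), each line of such a file is assumed to be a single JSON object, for example:

{"input": "John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?", "output": "smoothed the shock transition without sacrificing basic physics"}

The concat_sampling_probabilities=[0.3,0.7] override then weights how often training samples are drawn from each of the two files.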
164 changes: 164 additions & 0 deletions examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml
@@ -0,0 +1,164 @@
name: megatron_gpt_sft

trainer:
devices: 1
accelerator: gpu
num_nodes: 1
precision: 16
logger: False # logger provided by exp_manager
enable_checkpointing: False
replace_sampler_ddp: False
max_epochs: 9999
max_steps: 20000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
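# e.g. with this file's defaults (micro_batch_size=4, global_batch_size=128, devices=1, num_nodes=1, TP=PP=1,
# so data_parallel_size=1), accumulate_grad_batches = 128 / (4 * 1) = 32 and each step consumes
# 4 * 1 * 32 = 128 samples, so 20000 steps consume 2,560,000 samples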
log_every_n_steps: 10 # frequency with which training steps are logged
val_check_interval: 200 # If an int n > 1, runs validation every n training steps; if a float in 0.0 - 1.0, runs validation at that fraction of each epoch, e.g. 0.25 runs validation every quarter epoch
gradient_clip_val: 1.0

exp_manager:
explicit_log_dir: null
exp_dir: null
name: ${name}
create_wandb_logger: False
wandb_logger_kwargs:
project: null
name: null
resume_if_exists: True
resume_ignore_no_checkpoint: True
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: validation_${model.data.validation_ds.metric.name}
save_top_k: 2
mode: max
save_nemo_on_train_end: False # Should be False; an inference-ready .nemo file is saved via model.save_nemo_on_validation_end instead
filename: 'megatron_gpt_sft--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}'
model_parallel_size: ${model.tensor_model_parallel_size}
save_best_model: True

model:
seed: 1234
tensor_model_parallel_size: 1 # intra-layer model parallelism
pipeline_model_parallel_size: 1 # inter-layer model parallelism
global_batch_size: 128
micro_batch_size: 4
restore_from_path: ??? # Path to an existing pretrained GPT .nemo model to finetune
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training.
sync_batch_comm: False
megatron_amp_O2: False

## Sequence Parallelism
# Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout along the sequence dimension
# See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
sequence_parallel: False

## Activation Checkpoint
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
# 'uniform' divides the total number of transformer layers and checkpoints the input activation
# of each chunk at the specified granularity
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
activations_checkpoint_num_layers: null # not used with 'selective'
answer_only_loss: False # not used right now
gradient_as_bucket_view: False

hidden_dropout: 0.0
attention_dropout: 0.0
ffn_dropout: 0.0

data:
train_ds:
# Example of how to specify paths to multiple datasets
# file_names:
# - /path/to/squad.jsonl
# - /path/to/mnli.jsonl
# - /path/to/boolq.jsonl
# Example of how each dataset is formatted
# {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'}
file_names: ??? # List of paths to JSONL files containing the source data.
global_batch_size: ${model.global_batch_size}
micro_batch_size: ${model.micro_batch_size}
shuffle: True
num_workers: 4
pin_memory: True
max_seq_length: 2048
min_seq_length: 1
drop_last: True
# Example of how to specify concat_sampling_probabilities
# concat_sampling_probabilities:
# - 0.5
# - 0.25
# - 0.25
concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random'
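# Typically one probability per entry in file_names, summing to 1 (e.g. [0.3,0.7] for the two files used in the CI test above)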
context_key: 'input'
label_key: 'output'
add_eos: True
add_sep: False
add_bos: False
separate_prompt_and_response_with_newline: False
truncation_field: "context" # Options: ['context', 'answer']
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: null # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"

validation_ds:
file_names: ??? # List of paths to JSONL files containing the source data. Data format is identical to train_ds.
names: null # Names of the corresponding datasets used to log metrics.
global_batch_size: ${model.global_batch_size}
micro_batch_size: ${model.micro_batch_size}
shuffle: True
num_workers: 4
pin_memory: True
max_seq_length: 2048
min_seq_length: 1
drop_last: True
context_key: 'input'
label_key: 'output'
add_eos: ${model.data.train_ds.add_eos}
add_sep: ${model.data.train_ds.add_sep}
add_bos: ${model.data.train_ds.add_bos}
separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline}
write_predictions_to_file: False
output_file_path_prefix: null # Prefix of the file to write predictions to.
truncation_field: "context" # Options: ['context', 'answer']
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"

metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
num_classes: null

test_ds:
file_names: ??? # List of paths to JSONL files containing the source data. Data format is identical to train_ds.
names: null # Names of the corresponding datasets used to log metrics.
global_batch_size: ${model.global_batch_size}
micro_batch_size: ${model.micro_batch_size}
shuffle: True
num_workers: 4
pin_memory: True
max_seq_length: 2048
min_seq_length: 1
drop_last: True
context_key: 'input'
label_key: 'output'
add_eos: ${model.data.train_ds.add_eos}
add_sep: ${model.data.train_ds.add_sep}
add_bos: ${model.data.train_ds.add_bos}
separate_prompt_and_response_with_newline: ${model.data.train_ds.separate_prompt_and_response_with_newline}
write_predictions_to_file: False
output_file_path_prefix: null # Prefix of the file to write predictions to.
truncation_field: "context" # Options: ['context', 'answer']
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"

metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
num_classes: null

optim:
name: fused_adam # Supports distributed optimizer for memory savings. To enable, set to 'distributed_fused_adam'. Needs Apex to be built with specific args to work.
lr: 3e-5
weight_decay: 0.01
betas:
- 0.9
- 0.98
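
For reference, the following is a minimal sketch, not NeMo's actual dataset code, of how context_key, label_key, prompt_template and separate_prompt_and_response_with_newline are assumed to combine one JSONL record into a single training string; tokenization, truncation_field handling and loss masking are omitted.

import json
from typing import Optional

def build_example(
    line: str,
    context_key: str = "input",
    label_key: str = "output",
    prompt_template: Optional[str] = "Q: {input}\nA: {output}",
    separate_prompt_and_response_with_newline: bool = False,
) -> str:
    """Turn one JSONL record into the text an SFT example is assumed to be built from."""
    record = json.loads(line)
    context, label = record[context_key], record[label_key]
    if prompt_template is not None:
        # The config documents prompt_template as an fstring over the context/label keys,
        # e.g. "Q: {input}\nA: {output}".
        return prompt_template.format(**{context_key: context, label_key: label})
    sep = "\n" if separate_prompt_and_response_with_newline else " "
    return context + sep + label

record = '{"input": "What did the math of artificial viscosity do?", "output": "smoothed the shock transition without sacrificing basic physics"}'
print(build_example(record))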