Pipeline parallelism in Bert (NVIDIA#5293)
* Global batch size support for validation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Global batch size support for bert

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bert batch support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bert batch size support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* O2 support for bert

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update megatron_bert_pretraining.py

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* Update megatron_bert_model.py

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update megatron_bert_config.yaml

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update megatron_bert_model.py

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* Bug fix

* Bug fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Bug fix

* Bug fix

* Bug fix

* Update megatron_bert_config.yaml

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* PPBert

* PPBert

* PPBert

* PPBert

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update megatron_bert_config.yaml

Signed-off-by: Shanmugam Ramasamy <[email protected]>

* bug fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bug fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* bug fix

* bug fix

* bug fix

* bug fix

Signed-off-by: Shanmugam Ramasamy <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
3 people authored and Hainan Xu committed Nov 29, 2022
1 parent 23241c7 commit 4251963
Showing 3 changed files with 244 additions and 50 deletions.
70 changes: 70 additions & 0 deletions Jenkinsfile
@@ -2897,6 +2897,76 @@ pipeline {
// }
// }
// }
stage('L2: Megatron Bert Pretraining and Resume Training with Pipeline Parallelism') {
when {
anyOf {
branch 'main'
changeRequest target: 'main'
}
}
failFast true
steps {
sh "python examples/nlp/language_modeling/megatron_bert_pretraining.py \
trainer.devices=2 \
trainer.accelerator=gpu \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=10 \
trainer.limit_val_batches=2 \
trainer.accumulate_grad_batches=1 \
trainer.max_steps=10 \
trainer.precision=16 \
trainer.gradient_clip_val=1.0 \
exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \
model.pipeline_model_parallel_size=2 \
model.optim.name=fused_adam \
model.optim.lr=2e-4 \
model.optim.sched.warmup_steps=2 \
model.optim.sched.constant_steps=2 \
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
model.data.seq_length=128 \
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \
model.num_layers=8 \
model.hidden_size=256 \
model.num_attention_heads=8 \
model.activations_checkpoint_method='block' \
model.activations_checkpoint_num_layers=1 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \
model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings"
sh "python examples/nlp/language_modeling/megatron_bert_pretraining.py \
trainer.devices=2 \
trainer.accelerator=gpu \
trainer.log_every_n_steps=1 \
trainer.val_check_interval=10 \
trainer.limit_val_batches=2 \
trainer.accumulate_grad_batches=1 \
trainer.max_steps=20 \
trainer.precision=16 \
trainer.gradient_clip_val=1.0 \
exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \
exp_manager.resume_if_exists=True \
model.pipeline_model_parallel_size=2 \
model.optim.name=fused_adam \
model.optim.lr=2e-4 \
model.optim.sched.warmup_steps=2 \
model.optim.sched.constant_steps=2 \
model.optim.sched.min_lr=8e-5 \
model.max_position_embeddings=128 \
model.encoder_seq_length=128 \
model.data.seq_length=128 \
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \
model.num_layers=8 \
model.hidden_size=256 \
model.num_attention_heads=8 \
model.activations_checkpoint_method='block' \
model.activations_checkpoint_num_layers=1 \
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \
model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings"
sh "rm -rf examples/nlp/language_modeling/bert_pretrain_results"
sh "rm -rf examples/nlp/language_modeling/bert_index_mappings"
}
}
stage('L2: Megatron Bert Pretraining and Resume Training') {
when {
anyOf {
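For context on the sizes used in this stage: trainer.devices=2 together with model.pipeline_model_parallel_size=2 (tensor parallelism presumably left at its default of 1) turns the two GPUs into a single two-stage pipeline, so the implied data-parallel size is 1. The sketch below is not part of this commit; it only spells out the usual Megatron-style arithmetic under those assumptions, and the helper names and example batch sizes are hypothetical.

# Hypothetical helpers; assumes the usual Megatron decomposition world = TP x PP x DP.
def data_parallel_size(world_size: int, tensor_parallel: int, pipeline_parallel: int) -> int:
    """Ranks left for data parallelism after carving out TP x PP model-parallel groups."""
    model_parallel = tensor_parallel * pipeline_parallel
    assert world_size % model_parallel == 0, "world size must be divisible by TP * PP"
    return world_size // model_parallel

def micro_batches_per_step(global_batch_size: int, micro_batch_size: int, dp_size: int) -> int:
    """Micro-batches each data-parallel rank pushes through the pipeline per optimizer step."""
    assert global_batch_size % (micro_batch_size * dp_size) == 0
    return global_batch_size // (micro_batch_size * dp_size)

# Mirrors the CI stage above: 2 GPUs, pipeline size 2, TP assumed 1 -> data-parallel size 1.
dp = data_parallel_size(world_size=2, tensor_parallel=1, pipeline_parallel=2)
print(dp)  # 1

# Illustrative numbers only: a global batch of 8 with micro batch 4 would schedule
# 2 micro-batches through the two pipeline stages each optimizer step.
print(micro_batches_per_step(global_batch_size=8, micro_batch_size=4, dp_size=dp))  # 2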
@@ -234,8 +234,13 @@ def set_input_tensor(self, input_tensor):
     def forward(self, bert_model_input, attention_mask, token_type_ids=None, lm_labels=None):
 
         extended_attention_mask = bert_extended_attention_mask(attention_mask)
-        input_ids = bert_model_input
-        position_ids = build_position_ids(input_ids)
+
+        if parallel_state.is_pipeline_first_stage():
+            input_ids = bert_model_input
+            position_ids = build_position_ids(input_ids)
+        else:
+            position_ids = None
+            input_ids = None
 
         lm_output = self.language_model(
             input_ids, position_ids, extended_attention_mask, token_type_ids=token_type_ids
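The hunk above gates the token-id path on the pipeline stage: only the first stage converts bert_model_input into input_ids and builds position ids, while every later stage ignores the ids and runs on the hidden states the pipeline schedule hands it through set_input_tensor (parallel_state.is_pipeline_first_stage() does the gating in the real model). Below is a minimal standalone sketch of that pattern, assuming nothing from NeMo; the ToyPipelineStage class and its fields are hypothetical.

import torch

class ToyPipelineStage(torch.nn.Module):
    """Hypothetical miniature of the pattern above: the first stage embeds token ids,
    every later stage consumes the activations handed to it via set_input_tensor()."""

    def __init__(self, hidden_size, vocab_size, is_first_stage):
        super().__init__()
        self.is_first_stage = is_first_stage  # stands in for parallel_state.is_pipeline_first_stage()
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size) if is_first_stage else None
        self.layer = torch.nn.Linear(hidden_size, hidden_size)
        self.input_tensor = None  # filled by the pipeline schedule on non-first stages

    def set_input_tensor(self, input_tensor):
        self.input_tensor = input_tensor

    def forward(self, input_ids):
        if self.is_first_stage:
            hidden = self.embedding(input_ids)  # token ids only exist on the first stage
        else:
            hidden = self.input_tensor          # activations received from the previous stage
        return self.layer(hidden)

# The first stage embeds real ids; the second stage runs on the first stage's output
# and never sees input_ids, which is why the model can pass None for them.
first = ToyPipelineStage(hidden_size=8, vocab_size=100, is_first_stage=True)
second = ToyPipelineStage(hidden_size=8, vocab_size=100, is_first_stage=False)
ids = torch.randint(0, 100, (2, 4))
second.set_input_tensor(first(ids))
print(second(input_ids=None).shape)  # torch.Size([2, 4, 8])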
