Distributed checkpointing with mcore GPT #7116

Merged on Aug 28, 2023

Commits (121)
29d2e74
start adding gpt from megatron core path
ericharper Jun 7, 2023
f7671dd
set model parallel config
ericharper Jun 9, 2023
910ce35
pull main
ericharper Jun 19, 2023
da87792
use model parallel config object
ericharper Jun 19, 2023
59a008b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2023
7d7a4c3
update args
ericharper Jun 22, 2023
2f1bced
pull
ericharper Jun 22, 2023
066a0df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 22, 2023
a114cb8
set vp size to none if it is 1
ericharper Jun 23, 2023
7992ddd
set vp size to none if it is 1
ericharper Jun 23, 2023
c693224
Merge branch 'mcore_gpt_path' of github.com:NVIDIA/NeMo into mcore_gp…
ericharper Jun 23, 2023
f8960e2
resolve conflict
ericharper Jun 23, 2023
5b3e877
Merge branch 'main' into mcore_gpt_path
ericharper Jun 23, 2023
71c9d04
add TransformerConfig
ericharper Jun 28, 2023
2e03e5e
Merge branch 'main' of github.com:NVIDIA/NeMo into mcore_gpt_path
ericharper Jun 28, 2023
b1211df
start updating to TransformerConfig
ericharper Jun 28, 2023
281d115
add todo
ericharper Jun 29, 2023
4c2768d
revert to model parallel config
ericharper Jun 30, 2023
a65a2ca
add hidden_size to model_parallel_config
ericharper Jun 30, 2023
73e76d2
Merge branch 'mcore_gpt_path' of github.com:NVIDIA/NeMo into mcore_gp…
ericharper Jun 30, 2023
09ab5b0
resolve conflicts
ericharper Jun 30, 2023
4a557ec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2023
efc708e
remove imports
ericharper Jun 30, 2023
9c71633
revert
ericharper Jun 30, 2023
3e47d88
Merge branch 'mcore_gpt_path' of github.com:NVIDIA/NeMo into mcore_gp…
ericharper Jun 30, 2023
de4c6f6
remove import
ericharper Jun 30, 2023
4b281e4
small clean up
ericharper Jun 30, 2023
5309292
update hidden size in peft base model, add mcore commit to jenkins
ericharper Jun 30, 2023
0129f82
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 30, 2023
d9c4a35
update module args
ericharper Jul 1, 2023
873e7df
Merge branch 'mcore_gpt_path' of github.com:NVIDIA/NeMo into mcore_gp…
ericharper Jul 1, 2023
f52cfa3
Merge branch 'main' into mcore_gpt_path
ericharper Jul 1, 2023
e0c8684
add config obj to flash attention tests
ericharper Jul 5, 2023
9cd1aea
Merge branch 'mcore_gpt_path' of github.com:NVIDIA/NeMo into mcore_gp…
ericharper Jul 5, 2023
792f08c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2023
c999d90
remove args
ericharper Jul 5, 2023
b6d2f78
Merge branch 'mcore_gpt_path' of github.com:NVIDIA/NeMo into mcore_gp…
ericharper Jul 5, 2023
d38e812
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2023
63bc4f8
remove sequence parallel arg
ericharper Jul 5, 2023
387bb30
Merge branch 'mcore_gpt_path' of github.com:NVIDIA/NeMo into mcore_gp…
ericharper Jul 5, 2023
01894de
Merge branch 'main' into mcore_gpt_path
ericharper Jul 5, 2023
b44c08a
update args
ericharper Jul 7, 2023
19f14ae
add config to self
ericharper Jul 7, 2023
525860e
update args
ericharper Jul 7, 2023
25418c0
update args
ericharper Jul 7, 2023
a28cb60
Merge branch 'main' into mcore_gpt_path
ericharper Jul 7, 2023
2c04f9c
update args
ericharper Jul 7, 2023
8a4fb28
Merge branch 'mcore_gpt_path' of github.com:NVIDIA/NeMo into mcore_gp…
ericharper Jul 7, 2023
8465932
add config to test
ericharper Jul 10, 2023
574fbbe
Merge branch 'main' into mcore_gpt_path
ericharper Jul 10, 2023
f7c2130
get hidden_size from config
ericharper Jul 10, 2023
06b2cbc
add try except
ericharper Jul 10, 2023
0e303e6
use default
ericharper Jul 10, 2023
162addd
Merge branch 'main' into mcore_gpt_path
ericharper Jul 10, 2023
128d44c
update config with hidden size
ericharper Jul 11, 2023
d6aee26
Merge branch 'mcore_gpt_path' of github.com:NVIDIA/NeMo into mcore_gp…
ericharper Jul 11, 2023
7d106d6
Merge branch 'main' into mcore_gpt_path
ericharper Jul 11, 2023
54b2b3b
remove arg
ericharper Jul 11, 2023
93c9f3e
Merge branch 'mcore_gpt_path' of github.com:NVIDIA/NeMo into mcore_gp…
ericharper Jul 11, 2023
a655dbd
resolve conflict
ericharper Jul 11, 2023
9b751bc
resolve conflict
ericharper Jul 12, 2023
688f205
Merge branch 'main' into mcore_gpt_path
ericharper Jul 12, 2023
f6d3a9f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2023
158359b
comment out jenkins test
ericharper Jul 13, 2023
60728f9
Merge branch 'mcore_gpt_path' of github.com:NVIDIA/NeMo into mcore_gp…
ericharper Jul 13, 2023
60b0309
revert import
ericharper Jul 13, 2023
71694a7
Merge branch 'main' into mcore_gpt_path
ericharper Jul 13, 2023
3d974be
Merge branch 'main' into mcore_gpt_path
ericharper Jul 25, 2023
073bc9c
build transformer config
ericharper Jul 20, 2023
a535e4e
add model to provider func
ericharper Jul 20, 2023
c883e9d
update forward and float16 wrapper
ericharper Jul 21, 2023
120eda3
instantiate model parallel config after init model parallel
ericharper Jul 21, 2023
6fc47b7
set virtual rank
ericharper Jul 25, 2023
2b6cbe7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 21, 2023
2320d50
Add GQA config to megatron gpt model (#7096)
blahBlahhhJ Jul 21, 2023
ab454e6
revert
ericharper Jul 25, 2023
2268ecd
update strategy and exp_manager
ericharper Jul 26, 2023
6ff9bfb
update model checkpoint
ericharper Jul 26, 2023
1a835c3
update megatron gpt model
ericharper Jul 26, 2023
9e2d2a1
correct var
ericharper Jul 27, 2023
b2e4848
check for mcore gpt and use gpt model list
ericharper Jul 27, 2023
4ff7c11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 27, 2023
d0ef515
remove model prefix
ericharper Jul 27, 2023
f4733c5
Merge branch 'mcore_gpt_dist_ckpt' of github.com:NVIDIA/NeMo into mco…
ericharper Jul 27, 2023
5331a32
setup te tp groups
ericharper Jul 28, 2023
7125c4e
pull main
ericharper Aug 17, 2023
fb350db
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2023
2ed7ecd
revert
ericharper Aug 17, 2023
e6b0a9a
revert
ericharper Aug 17, 2023
58ba28a
revert
ericharper Aug 17, 2023
6375eb2
add default
ericharper Aug 17, 2023
deb8a00
add default
ericharper Aug 17, 2023
c42e88b
revert
ericharper Aug 17, 2023
54dd027
update sharded state dict for interleaved
ericharper Aug 17, 2023
c4ad3bf
update load for interleaved
ericharper Aug 17, 2023
66c4a69
check sharded state dict is nonempty
ericharper Aug 21, 2023
7895739
remove import
ericharper Aug 21, 2023
cc3f461
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 21, 2023
44048ee
revert comment
ericharper Aug 21, 2023
71641a1
Merge branch 'mcore_gpt_dist_ckpt' of github.com:NVIDIA/NeMo into mco…
ericharper Aug 21, 2023
cc33389
Merge branch 'main' into mcore_gpt_dist_ckpt
ericharper Aug 21, 2023
b5a38bf
Merge branch 'main' into mcore_gpt_dist_ckpt
ericharper Aug 21, 2023
0338d46
inject before checking legacy ckpt
ericharper Aug 21, 2023
8fa76f2
revert
ericharper Aug 21, 2023
41584e5
Merge branch 'main' into mcore_gpt_dist_ckpt
ericharper Aug 21, 2023
af5b6eb
Merge branch 'main' into mcore_gpt_dist_ckpt
ericharper Aug 22, 2023
7fa2acc
Merge branch 'main' into mcore_gpt_dist_ckpt
ericharper Aug 22, 2023
7538816
pop arg for now
ericharper Aug 22, 2023
4353d68
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2023
1aa58bc
Merge branch 'main' into mcore_gpt_dist_ckpt
ericharper Aug 23, 2023
15bc476
revert jenkins change
ericharper Aug 23, 2023
3664f08
Merge branch 'main' into mcore_gpt_dist_ckpt
ericharper Aug 24, 2023
6ee93ec
remove device state_dict
ericharper Aug 24, 2023
a4e9a2f
Merge branch 'main' into mcore_gpt_dist_ckpt
ericharper Aug 24, 2023
aaf6214
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 24, 2023
c939257
reduce batch size for max steps
ericharper Aug 24, 2023
4487519
update megatron core commit
ericharper Aug 24, 2023
706c834
Integrate dist ckpt with new DistOpt state dict v2 (#7281)
mikolajblaz Aug 25, 2023
1771316
Merge branch 'main' into mcore_gpt_dist_ckpt
ericharper Aug 25, 2023
f4a5e01
update apex commit
ericharper Aug 25, 2023
497e934
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 26, 2023
8 changes: 4 additions & 4 deletions Dockerfile
@@ -45,17 +45,17 @@ RUN apt-get update && \
WORKDIR /workspace/

WORKDIR /tmp/
-# TODO: Remove once this Apex commit (5/12/23) is included in PyTorch
-# container
+# DP independent checkpoint format for distributed adam
RUN git clone https://github.com/NVIDIA/apex.git && \
    cd apex && \
-    git checkout 8b7a1ff183741dd8f9b87e7bafd04cfde99cea28 && \
+    git checkout 7995de18677295c5edeeab082179edbfdb6ee16a && \
    pip3 install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./

# install megatron core, this can be removed once 0.3 pip package is released
RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \
    cd Megatron-LM && \
-    git checkout f24fac4ed0dcf0522056521a93445d9a82f501a9 && \
+    git checkout 84f64880b3651c4f7cf90da337ee4e7d9968acab && \
    pip install -e .

# uninstall stuff from base container
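The two pinned commits above are what the container build pulls in. As a quick sanity check (not part of the PR), a hedged import probe like the following can confirm that both packages are importable inside the built image; the `__version__` attribute is an assumption and is looked up defensively because not every checkout exposes it.

# Minimal sanity-check sketch (not part of the PR): verify that megatron.core and apex
# import inside the built container. The __version__ attribute is an assumption and is
# read defensively because not every checkout exposes it.
import importlib

for pkg in ("megatron.core", "apex"):
    try:
        mod = importlib.import_module(pkg)
        version = getattr(mod, "__version__", "unknown")
        print(f"{pkg}: importable (version: {version})")
    except ImportError as err:
        print(f"{pkg}: NOT importable ({err})")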
6 changes: 3 additions & 3 deletions Jenkinsfile
@@ -62,7 +62,7 @@ pipeline {
        // commit has api fix for TE
        sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \
            cd Megatron-LM && \
-            git checkout f24fac4ed0dcf0522056521a93445d9a82f501a9 && \
+            git checkout 84f64880b3651c4f7cf90da337ee4e7d9968acab && \
            pip install -e .'
      }
    }
@@ -3745,11 +3745,11 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
        model.data.train_ds.concat_sampling_probabilities=[0.3,0.7] \
        model.data.train_ds.num_workers=0 \
        model.data.test_ds.micro_batch_size=1 \
-        model.data.test_ds.global_batch_size=4 \
+        model.data.test_ds.global_batch_size=1 \
        model.data.test_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \
        model.data.test_ds.names=[quarel] \
        model.data.validation_ds.micro_batch_size=1 \
-        model.data.validation_ds.global_batch_size=4 \
+        model.data.validation_ds.global_batch_size=1 \
        model.data.validation_ds.num_workers=0 \
        model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \
        model.data.validation_ds.names=[quarel]"
6 changes: 4 additions & 2 deletions README.rst
@@ -247,8 +247,8 @@ To install Apex, run

    git clone https://github.com/NVIDIA/apex.git
    cd apex
-    git checkout 57057e2fcf1c084c0fcc818f55c0ff6ea1b24ae2
-    pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./
+    git checkout 7995de18677295c5edeeab082179edbfdb6ee16a
+    pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_layer_norm" --global-option="--distributed_adam" --global-option="--deprecated_fused_adam" ./

It is highly recommended to use the NVIDIA PyTorch or NeMo container if having issues installing Apex or any other dependencies.

@@ -267,6 +267,8 @@ packaging is also needed:

    pip install packaging

+With the latest versions of Apex, the `pyproject.toml` file in Apex may need to be deleted in order to install locally.

Transformer Engine
~~~~~~~~~~~~~~~~~~
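As a concrete illustration of the `pyproject.toml` note above, the following sketch is one possible install sequence, not an official recipe: it clones Apex at the pinned commit, removes `pyproject.toml` if present, and then runs the same pip command shown in the README snippet.

# Hedged sketch of the Apex install flow described above. The commit hash and pip flags
# are copied from the README snippet; treat this as one workable sequence rather than
# the officially supported installer.
import subprocess
from pathlib import Path

APEX_COMMIT = "7995de18677295c5edeeab082179edbfdb6ee16a"  # pinned commit from the README

subprocess.run(["git", "clone", "https://github.com/NVIDIA/apex.git"], check=True)
subprocess.run(["git", "checkout", APEX_COMMIT], cwd="apex", check=True)

pyproject = Path("apex/pyproject.toml")
if pyproject.exists():
    pyproject.unlink()  # newer Apex versions may need this removed for a local install

subprocess.run(
    [
        "pip", "install", "-v", "--no-build-isolation", "--disable-pip-version-check",
        "--no-cache-dir",
        "--global-option=--cpp_ext",
        "--global-option=--cuda_ext",
        "--global-option=--fast_layer_norm",
        "--global-option=--distributed_adam",
        "--global-option=--deprecated_fused_adam",
        ".",
    ],
    cwd="apex",
    check=True,
)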
2 changes: 1 addition & 1 deletion examples/nlp/language_modeling/megatron_gpt_pretraining.py
@@ -22,7 +22,7 @@
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

-mp.set_start_method("spawn", force=True)
+# mp.set_start_method("spawn", force=True)


@hydra_runner(config_path="conf", config_name="megatron_gpt_config")
75 changes: 64 additions & 11 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -837,6 +837,9 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
            }
            if not self.mcore_gpt:
                forward_args['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers
+            else:
+                # TODO: @eharper can we add this to mcore?
+                forward_args.pop('loss_mask')
            output_tensor = model(**forward_args)

            def loss_func(output_tensor):
@@ -1243,7 +1246,6 @@ def setup_transformer_engine_tp_groups(self):
        """ This should be called after model parallel groups have been initialized
            and only needs to be called when using Transformer Engine.
        """
-
        for module in self.get_gpt_module_list():
            """Set TP group
            Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py#L398
@@ -1260,21 +1262,72 @@ def on_save_checkpoint(self, checkpoint) -> None:
        """LightningModule hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-save-checkpoint
        """
-        if isinstance(self.model, list):
-            for i in range(len(self.model)):
-                parallel_state.set_virtual_pipeline_model_parallel_rank(i)
-                checkpoint[f'model{i}'] = self.model[i].module.state_dict_for_save_checkpoint()
-            parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+        # mcore uses distributed checkpointing
+        if self.mcore_gpt:
+            checkpoint['sharded_state_dict'] = self.sharded_state_dict()
+
+        # legacy checkpointing for interleaved
+        else:
+            if isinstance(self.model, list):
+                for i in range(len(self.model)):
+                    parallel_state.set_virtual_pipeline_model_parallel_rank(i)
+                    checkpoint[f'model{i}'] = self.model[i].module.state_dict_for_save_checkpoint()
+                parallel_state.set_virtual_pipeline_model_parallel_rank(0)

    def on_load_checkpoint(self, checkpoint) -> None:
        """LightningModule hook:
        https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-load-checkpoint
        """
-        if isinstance(self.model, list):
-            for i in range(len(self.model)):
-                parallel_state.set_virtual_pipeline_model_parallel_rank(i)
-                self.model[i].module.load_state_dict(checkpoint[f'model{i}'], strict=True)
-            parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+        # mcore uses distributed checkpointing
+        if self.mcore_gpt:
+            for index, module in enumerate(self.get_gpt_module_list()):
+                if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+                    checkpoint_state_dict = checkpoint['state_dict'][f'model_{index}']
+                else:
+                    checkpoint_state_dict = checkpoint['state_dict']
+                # keys in checkpoint_state_dict carry a "model." prefix that the module keys do not, so strip it when loading
+                checkpoint_state_dict = {
+                    key.replace('model.', ''): checkpoint_state_dict.pop(key)
+                    for key in list(checkpoint_state_dict.keys())
+                }
+                module.load_state_dict(checkpoint_state_dict, strict=True)
+
+        # legacy checkpointing for interleaved
+        else:
+            if isinstance(self.model, list):
+                for i in range(len(self.model)):
+                    parallel_state.set_virtual_pipeline_model_parallel_rank(i)
+                    self.model[i].module.load_state_dict(checkpoint[f'model{i}'], strict=True)
+                parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+
+    def sharded_state_dict(self, prefix: str = '') -> Dict[str, Any]:
+        """
+        Creates the sharded state dict which is used by dist_checkpoint to save the sharded tensors to disk.
+        When given the sharded_state_dict, dist_checkpoint.load will load the tensors corresponding to
+        self.state_dict().
+        The sharded tensor mapping is defined in the GPTModel class from mcore.
+        """
+        if self.mcore_gpt:
+            module_prefix = f'{prefix}model.'
+            sharded_state_dict = {}
+            for index, module in enumerate(self.get_gpt_module_list()):
+                if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+                    # the virtual pipeline rank must be set so that GPTModel returns the correct sharded state dict
+                    parallel_state.set_virtual_pipeline_model_parallel_rank(index)
+                    module_sharded_state_dict = module.sharded_state_dict(prefix=module_prefix)
+                    sharded_state_dict[f'model_{index}'] = module_sharded_state_dict
+                else:
+                    module_sharded_state_dict = module.sharded_state_dict(prefix=module_prefix)
+                    sharded_state_dict.update(module_sharded_state_dict)
+
+            # reset vp rank
+            if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+                parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+
+            return sharded_state_dict

    def parameters(self):
        if isinstance(self.model, list):
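To make the save/load split in this diff easier to follow, here is a small self-contained sketch of the same pattern using plain dictionaries. Every name in it (`ToyModule`, `save_checkpoint`, `load_checkpoint`) is illustrative, not NeMo or Megatron core API: the mcore path collects one sharded state dict per virtual pipeline chunk under `model_{index}` keys, the legacy interleaved path stores one flat `model{i}` entry per chunk, and loading strips the `model.` prefix just as `on_load_checkpoint` does.

# Illustrative sketch only: mimics the on_save_checkpoint / on_load_checkpoint split
# introduced in this PR with toy objects. Nothing here is the NeMo or Megatron core API;
# it only shows the shape of the logic.
from typing import Any, Dict, List, Optional


class ToyModule:
    """Stands in for one (virtual pipeline) model chunk."""

    def __init__(self, name: str) -> None:
        self.name = name

    def sharded_state_dict(self, prefix: str = "") -> Dict[str, Any]:
        # the real mcore GPTModel returns ShardedTensor objects; plain floats suffice here
        return {f"{prefix}decoder.weight": 1.0, f"{prefix}decoder.bias": 0.0}

    def state_dict_for_save_checkpoint(self) -> Dict[str, Any]:
        return {"decoder.weight": 1.0, "decoder.bias": 0.0}


def save_checkpoint(modules: List[ToyModule], mcore_gpt: bool, vp_size: Optional[int]) -> Dict[str, Any]:
    checkpoint: Dict[str, Any] = {}
    if mcore_gpt:
        # distributed checkpointing path: one sharded dict per virtual pipeline chunk
        sharded: Dict[str, Any] = {}
        for index, module in enumerate(modules):
            chunk = module.sharded_state_dict(prefix="model.")
            if vp_size is not None:
                sharded[f"model_{index}"] = chunk
            else:
                sharded.update(chunk)
        checkpoint["sharded_state_dict"] = sharded
    else:
        # legacy interleaved path: one flat 'model{i}' entry per chunk
        for i, module in enumerate(modules):
            checkpoint[f"model{i}"] = module.state_dict_for_save_checkpoint()
    return checkpoint


def load_checkpoint(checkpoint: Dict[str, Any], modules: List[ToyModule], vp_size: Optional[int]) -> None:
    for index, module in enumerate(modules):
        if vp_size is not None:
            chunk = checkpoint["sharded_state_dict"][f"model_{index}"]
        else:
            chunk = checkpoint["sharded_state_dict"]
        # strip the 'model.' prefix, mirroring on_load_checkpoint above
        stripped = {key.replace("model.", "", 1): value for key, value in chunk.items()}
        print(f"{module.name} would load keys: {sorted(stripped)}")


if __name__ == "__main__":
    chunks = [ToyModule("chunk0"), ToyModule("chunk1")]
    ckpt = save_checkpoint(chunks, mcore_gpt=True, vp_size=2)
    load_checkpoint(ckpt, chunks, vp_size=2)

Running it with `vp_size=2` prints the stripped keys each chunk would load, which mirrors how the real hook walks `get_gpt_module_list()` with the virtual pipeline rank set per chunk.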
@@ -127,7 +127,6 @@ def init_model(self, cfg: DictConfig, trainer: Trainer):
            frozen_model_cfg.activations_checkpoint_method = self.cfg.get("activations_checkpoint_method", None)

        if self.trainer.precision in ['bf16', 'bf16-mixed']:
-            # set hidden size in the model parallel config for pipeline parallel schedules
            self.autocast_dtype = torch.bfloat16
        elif self.trainer.precision in [32, '32', '32-true']:
            self.autocast_dtype = torch.float
@@ -339,6 +339,10 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
            grad_sync_func = self.reduce_overlap_gradients
            param_sync_func = self.sync_overlap_parameters

+        self.model.config.no_sync_func = no_sync_func
+        self.model.config.grad_sync_func = grad_sync_func
+        self.model.config.param_sync_func = param_sync_func
+
        fwd_bwd_function = get_forward_backward_func()

        losses_reduced_per_micro_batch = fwd_bwd_function(
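The three assignments above hand the overlap callbacks to the model's config so the forward-backward schedule, rather than the training step, decides when gradients and parameters are synchronized. The sketch below is a toy illustration of that contract under assumed names (`ToyConfig`, `run_schedule`); it is not the Megatron core schedule, but it shows why a `no_sync_func` context manager plus explicit `grad_sync_func`/`param_sync_func` hooks are enough for the schedule to defer the all-reduce until the last microbatch.

# Toy sketch of why the callbacks are attached to the config: a schedule that receives
# the config can disable gradient sync during accumulation and trigger it once at the end.
# 'ToyConfig' and 'run_schedule' are illustrative names, not Megatron core API.
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Callable, ContextManager, Optional


@dataclass
class ToyConfig:
    no_sync_func: Optional[Callable[[], ContextManager]] = None
    grad_sync_func: Optional[Callable[[], None]] = None
    param_sync_func: Optional[Callable[[], None]] = None


def run_schedule(config: ToyConfig, num_microbatches: int) -> None:
    if config.param_sync_func is not None:
        config.param_sync_func()  # prefetch/sync parameters before the first microbatch
    # suppress gradient synchronization for every microbatch except the last
    no_sync = config.no_sync_func or nullcontext
    for step in range(num_microbatches):
        ctx = no_sync() if step < num_microbatches - 1 else nullcontext()
        with ctx:
            pass  # the forward/backward pass for this microbatch would run here
    if config.grad_sync_func is not None:
        config.grad_sync_func()  # e.g. reduce_overlap_gradients in the PR


if __name__ == "__main__":
    cfg = ToyConfig(grad_sync_func=lambda: print("gradients reduced once, after the last microbatch"))
    run_schedule(cfg, num_microbatches=4)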