
Distributed checkpointing with mcore GPT #7116

Merged
merged 121 commits into main from mcore_gpt_dist_ckpt
Aug 28, 2023

Conversation

ericharper (Collaborator) commented Jul 27, 2023

This PR needs the mcore distributed checkpointing for GPT PR to be pushed before it can be merged.

What does this PR do?

Adds distributed checkpointing when using mcore GPT.

Distributed checkpointing enables training runs to restart automatically with a different model-parallel configuration.
The checkpoint is saved to disk according to the sharded_state_dict:

Below is a sample of what the checkpoint looks like on disk.

common.pt                                                     model.decoder.layers.self_attention.linear_qkv.weight                           optimizer.state.exp_avg.model.embedding.word_embeddings.weight                     optimizer.state.fp32_from_fp16.model.decoder.final_layernorm.bias
metadata.json                                                 model.embedding.position_embeddings.weight                                      optimizer.state.exp_avg.model.output_layer.weight                                  optimizer.state.fp32_from_fp16.model.decoder.final_layernorm.weight
model.decoder.final_layernorm.bias                            model.embedding.word_embeddings.weight                                          optimizer.state.exp_avg_sq.model.decoder.final_layernorm.bias                      optimizer.state.fp32_from_fp16.model.decoder.layers.input_layernorm.bias
model.decoder.final_layernorm.weight                          model.output_layer.weight                                                       optimizer.state.exp_avg_sq.model.decoder.final_layernorm.weight                    optimizer.state.fp32_from_fp16.model.decoder.layers.input_layernorm.weight
model.decoder.layers.input_layernorm.bias                     optimizer.state.exp_avg.model.decoder.final_layernorm.bias                      optimizer.state.exp_avg_sq.model.decoder.layers.input_layernorm.bias               optimizer.state.fp32_from_fp16.model.decoder.layers.mlp.linear_fc1.bias
model.decoder.layers.input_layernorm.weight                   optimizer.state.exp_avg.model.decoder.final_layernorm.weight                    optimizer.state.exp_avg_sq.model.decoder.layers.input_layernorm.weight             optimizer.state.fp32_from_fp16.model.decoder.layers.mlp.linear_fc1.weight
model.decoder.layers.mlp.linear_fc1.bias                      optimizer.state.exp_avg.model.decoder.layers.input_layernorm.bias               optimizer.state.exp_avg_sq.model.decoder.layers.mlp.linear_fc1.bias                optimizer.state.fp32_from_fp16.model.decoder.layers.mlp.linear_fc2.bias
model.decoder.layers.mlp.linear_fc1._extra_state              optimizer.state.exp_avg.model.decoder.layers.input_layernorm.weight             optimizer.state.exp_avg_sq.model.decoder.layers.mlp.linear_fc1.weight              optimizer.state.fp32_from_fp16.model.decoder.layers.mlp.linear_fc2.weight
model.decoder.layers.mlp.linear_fc1.weight                    optimizer.state.exp_avg.model.decoder.layers.mlp.linear_fc1.bias                optimizer.state.exp_avg_sq.model.decoder.layers.mlp.linear_fc2.bias                optimizer.state.fp32_from_fp16.model.decoder.layers.post_self_attn_layernorm.bias
model.decoder.layers.mlp.linear_fc2.bias                      optimizer.state.exp_avg.model.decoder.layers.mlp.linear_fc1.weight              optimizer.state.exp_avg_sq.model.decoder.layers.mlp.linear_fc2.weight              optimizer.state.fp32_from_fp16.model.decoder.layers.post_self_attn_layernorm.weight
model.decoder.layers.mlp.linear_fc2._extra_state              optimizer.state.exp_avg.model.decoder.layers.mlp.linear_fc2.bias                optimizer.state.exp_avg_sq.model.decoder.layers.post_self_attn_layernorm.bias      optimizer.state.fp32_from_fp16.model.decoder.layers.self_attention.linear_proj.bias
model.decoder.layers.mlp.linear_fc2.weight                    optimizer.state.exp_avg.model.decoder.layers.mlp.linear_fc2.weight              optimizer.state.exp_avg_sq.model.decoder.layers.post_self_attn_layernorm.weight    optimizer.state.fp32_from_fp16.model.decoder.layers.self_attention.linear_proj.weight
model.decoder.layers.post_self_attn_layernorm.bias            optimizer.state.exp_avg.model.decoder.layers.post_self_attn_layernorm.bias      optimizer.state.exp_avg_sq.model.decoder.layers.self_attention.linear_proj.bias    optimizer.state.fp32_from_fp16.model.decoder.layers.self_attention.linear_qkv.bias
model.decoder.layers.post_self_attn_layernorm.weight          optimizer.state.exp_avg.model.decoder.layers.post_self_attn_layernorm.weight    optimizer.state.exp_avg_sq.model.decoder.layers.self_attention.linear_proj.weight  optimizer.state.fp32_from_fp16.model.decoder.layers.self_attention.linear_qkv.weight
model.decoder.layers.self_attention.linear_proj.bias          optimizer.state.exp_avg.model.decoder.layers.self_attention.linear_proj.bias    optimizer.state.exp_avg_sq.model.decoder.layers.self_attention.linear_qkv.bias     optimizer.state.fp32_from_fp16.model.embedding.position_embeddings.weight
model.decoder.layers.self_attention.linear_proj._extra_state  optimizer.state.exp_avg.model.decoder.layers.self_attention.linear_proj.weight  optimizer.state.exp_avg_sq.model.decoder.layers.self_attention.linear_qkv.weight   optimizer.state.fp32_from_fp16.model.embedding.word_embeddings.weight
model.decoder.layers.self_attention.linear_proj.weight        optimizer.state.exp_avg.model.decoder.layers.self_attention.linear_qkv.bias     optimizer.state.exp_avg_sq.model.embedding.position_embeddings.weight              optimizer.state.fp32_from_fp16.model.output_layer.weight
model.decoder.layers.self_attention.linear_qkv.bias           optimizer.state.exp_avg.model.decoder.layers.self_attention.linear_qkv.weight   optimizer.state.exp_avg_sq.model.embedding.word_embeddings.weight
model.decoder.layers.self_attention.linear_qkv._extra_state   optimizer.state.exp_avg.model.embedding.position_embeddings.weight              optimizer.state.exp_avg_sq.model.output_layer.weight

Then inside a module directory we have the sharded tensor:

ls model.decoder.layers.mlp.linear_fc1.weight/
0.0.0  1.0.0  10.0.0  11.0.0  12.0.0  13.0.0  14.0.0  15.0.0  2.0.0  3.0.0  4.0.0  5.0.0  6.0.0  7.0.0  8.0.0  9.0.0

To implement distributed checkpointing for a model, its sharded_state_dict has to be defined.
This is done in Megatron Core, so in NeMo, if the module comes from mcore, we only have to call module.sharded_state_dict().
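
For reference, here is a minimal sketch (not the NeMo implementation; the function names, the "model." prefix, and the checkpoint_dir argument are illustrative) of how a sharded state dict flows into Megatron Core's dist_checkpointing save/load calls:

import megatron.core.dist_checkpointing as dist_checkpointing


def save_dist_checkpoint(mcore_gpt_model, checkpoint_dir):
    # Each mcore module defines how its weights are sharded across model-parallel
    # ranks, so NeMo only needs to ask the module for its sharded state dict.
    sharded_state_dict = mcore_gpt_model.sharded_state_dict(prefix="model.")

    # Every rank calls save(); each rank writes only the shards it owns, which
    # produces the per-key directories shown in the listing above.
    dist_checkpointing.save(sharded_state_dict, checkpoint_dir)


def load_dist_checkpoint(mcore_gpt_model, checkpoint_dir):
    # Loading is resharding-aware: the same checkpoint can be restored into a
    # different tensor/pipeline-parallel configuration.
    sharded_state_dict = mcore_gpt_model.sharded_state_dict(prefix="model.")
    state_dict = dist_checkpointing.load(sharded_state_dict, checkpoint_dir)
    mcore_gpt_model.load_state_dict(state_dict)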

Collection: NLP

Usage

Usage is automatic when using mcore:

model.mcore_gpt=True
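
For illustration, a minimal sketch of setting the flag programmatically (the config is constructed inline here for brevity; in practice mcore_gpt lives in, or is overridden on, the NeMo GPT YAML config):

from omegaconf import OmegaConf

# Illustrative only: stands in for the real NeMo GPT model config.
cfg = OmegaConf.create({"model": {"mcore_gpt": False}})
cfg.model.mcore_gpt = True  # use the mcore GPT model; checkpoints are then saved in the distributed format shown above
print(OmegaConf.to_yaml(cfg))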

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs for various areas.

Additional Information

  • Related to # (issue)

ericharper and others added 30 commits June 7, 2023 12:00
@github-actions github-actions bot removed the CI label Aug 23, 2023
@github-actions github-actions bot added the CI label Aug 24, 2023
@ericharper ericharper marked this pull request as ready for review August 25, 2023 00:49
mikolajblaz and others added 3 commits August 25, 2023 17:56
* Integrate  new DistOpt state dict

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change optimizer fp32_param key

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: eharper <[email protected]>
Jenkinsfile (review thread resolved)
Comment on lines +28 to +32
from megatron.core.dist_checkpointing.optimizer import (
get_param_id_to_sharded_param_map,
make_sharded_optimizer_tensor,
optim_state_to_sharding_state,
)

Code scanning / CodeQL check notice: Unused import
Import of 'make_sharded_optimizer_tensor' is not used.
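
For context, a rough sketch of how two of these helpers are typically combined to shard an optimizer state dict (the wrapper function, variable names, and the "model." prefix are illustrative, not the code under review):

from itertools import chain

from megatron.core.dist_checkpointing.optimizer import (
    get_param_id_to_sharded_param_map,
    optim_state_to_sharding_state,
)


def sharded_optimizer_state_dict(model, optimizer):
    model_sharded_sd = model.sharded_state_dict(prefix="model.")
    optim_sd = optimizer.state_dict()

    # Map the integer param ids used inside optimizer.state_dict() to the
    # ShardedTensors that describe the corresponding model parameters.
    id_to_sharded_param = get_param_id_to_sharded_param_map(
        model_sharded_sd,
        chain.from_iterable(g["params"] for g in optimizer.param_groups),
    )

    # Rewrite optimizer state entries (exp_avg, exp_avg_sq, ...) in place as
    # ShardedTensors so they can be saved next to the model shards.
    optim_state_to_sharding_state(optim_sd, id_to_sharded_param)
    return optim_sd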

@aklife97 aklife97 (Collaborator) left a comment

LGTM, thank you!

@ericharper ericharper merged commit d6357fd into main Aug 28, 2023
15 checks passed
@ericharper ericharper deleted the mcore_gpt_dist_ckpt branch August 28, 2023 21:52
rohitrango pushed a commit to rohitrango/NeMo that referenced this pull request Jun 25, 2024
* start adding gpt from megatron core path

Signed-off-by: ericharper <[email protected]>

* set model parallel config

Signed-off-by: ericharper <[email protected]>

* use model parallel config object

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update args

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* set vp size to none if it is 1

Signed-off-by: ericharper <[email protected]>

* set vp size to none if it is 1

Signed-off-by: ericharper <[email protected]>

* add TransformerConfig

Signed-off-by: ericharper <[email protected]>

* start updating to TransformerConfig

Signed-off-by: ericharper <[email protected]>

* add todo

Signed-off-by: ericharper <[email protected]>

* revert to model parallel config

Signed-off-by: ericharper <[email protected]>

* add hidden_size to model_parallel_config

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove imports

Signed-off-by: ericharper <[email protected]>

* revert

Signed-off-by: ericharper <[email protected]>

* remove import

Signed-off-by: ericharper <[email protected]>

* small clean up

Signed-off-by: ericharper <[email protected]>

* update hidden size in peft base model, add mcore commit to jenkins

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update module args

Signed-off-by: ericharper <[email protected]>

* add config obj to flash attention tests

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove args

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove sequence parallel arg

Signed-off-by: ericharper <[email protected]>

* update args

Signed-off-by: ericharper <[email protected]>

* add config to self

Signed-off-by: ericharper <[email protected]>

* update args

Signed-off-by: ericharper <[email protected]>

* update args

Signed-off-by: ericharper <[email protected]>

* update args

Signed-off-by: ericharper <[email protected]>

* add config to test

Signed-off-by: ericharper <[email protected]>

* get hidden_size from config

Signed-off-by: ericharper <[email protected]>

* add try except

Signed-off-by: ericharper <[email protected]>

* use default

Signed-off-by: ericharper <[email protected]>

* update config with hidden size

Signed-off-by: ericharper <[email protected]>

* remove arg

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* comment out jenkins test

Signed-off-by: ericharper <[email protected]>

* revert import

Signed-off-by: ericharper <[email protected]>

* build transformer config

Signed-off-by: ericharper <[email protected]>

* add model to provider func

Signed-off-by: ericharper <[email protected]>

* update forward and float16 wrapper

Signed-off-by: ericharper <[email protected]>

* instantiate model parallel config after init model parallel

Signed-off-by: ericharper <[email protected]>

* set virtual rank

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add GQA config to megatron gpt model (NVIDIA#7096)

* Add GQA config in gpt config file

Signed-off-by: jasonwan <[email protected]>

* Verify mcore is enabled when using GQA

Signed-off-by: jasonwan <[email protected]>

---------

Signed-off-by: jasonwan <[email protected]>

* revert

Signed-off-by: ericharper <[email protected]>

* update strategy and exp_manager

Signed-off-by: ericharper <[email protected]>

* update model checkpoint

Signed-off-by: ericharper <[email protected]>

* update megatron gpt model

Signed-off-by: ericharper <[email protected]>

* correct var

Signed-off-by: ericharper <[email protected]>

* check for mcore gpt and use gpt model list

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove model prefix

Signed-off-by: ericharper <[email protected]>

* setup te tp groups

Signed-off-by: ericharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert

Signed-off-by: eharper <[email protected]>

* revert

Signed-off-by: eharper <[email protected]>

* add default

Signed-off-by: eharper <[email protected]>

* add default

Signed-off-by: eharper <[email protected]>

* revert

Signed-off-by: eharper <[email protected]>

* update sharded state dict for interleaved

Signed-off-by: eharper <[email protected]>

* update load for interleaved

Signed-off-by: eharper <[email protected]>

* check sharded state dict is nonempty

Signed-off-by: eharper <[email protected]>

* remove import

Signed-off-by: eharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert comment

Signed-off-by: eharper <[email protected]>

* inject before checking legacy ckpt

Signed-off-by: eharper <[email protected]>

* revert

Signed-off-by: eharper <[email protected]>

* pop arg for now

Signed-off-by: eharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert jenkins change

Signed-off-by: eharper <[email protected]>

* remove device state_dict

Signed-off-by: eharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reduce batch size for max steps

Signed-off-by: eharper <[email protected]>

* update megatron core commit

Signed-off-by: eharper <[email protected]>

* Integrate dist ckpt with new DistOpt state dict v2 (NVIDIA#7281)

* Integrate  new DistOpt state dict

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change optimizer fp32_param key

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>

* update apex commit

Signed-off-by: eharper <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: ericharper <[email protected]>
Signed-off-by: jasonwan <[email protected]>
Signed-off-by: eharper <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jason Wang <[email protected]>
Co-authored-by: mikolajblaz <[email protected]>
Labels: CI, core (Changes to NeMo Core), NLP
Projects: none yet
Linked issues: none yet
4 participants