
Support tp pp conversion #6218

Merged - 44 commits merged into NVIDIA:main on Mar 25, 2023

Conversation

@titu1994 (Collaborator) commented Mar 16, 2023

What does this PR do?

Adds support for changing the pipeline parallel size of GPT models after construction

Collection: [Core, NLP]

Changelog

  • Add a new script for PP conversion to avoid breaking the old script (the old script should probably be deprecated once the new one fully supports all of its functionality)
  • Add modifications that allow partially loading model-parallel models into memory to extract their parameters (a general sketch of this pattern follows this list)
  • Add PP and TP conversion support for both Megatron GPT and Megatron T5 (the latter only when embeddings are shared between encoder and decoder)
  • Add support for forcing CPU construction of T5 and GPT models
  • Add support for loading a model based on AppState parameters instead of Trainer parameters.
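The partial-loading and CPU-construction items above share one underlying idea: inspect checkpoint shards without materializing the full model on GPU. The following is a minimal, hypothetical sketch of that general pattern in plain PyTorch; the function name and key prefix are assumptions for illustration, not this script's actual internals.

import torch

# Hypothetical sketch (not the script's actual API): load one checkpoint
# shard onto CPU and pull out only the tensors needed for repartitioning.
def extract_parameters(shard_path: str, prefix: str = "model.") -> dict:
    # map_location="cpu" keeps large checkpoints out of GPU memory entirely;
    # this assumes the file stores a plain state dict.
    state_dict = torch.load(shard_path, map_location="cpu")
    # Keep only the parameters under the given prefix.
    return {k: v for k, v in state_dict.items() if k.startswith(prefix)}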

Usage


# Megatron GPT
python megatron_change_num_partitions.py \
    --model_file=PATH_TO_SRC_FILE \
    --target_file=PATH_TO_TGT_FILE \
    --tensor_model_parallel_size=1 \
    --target_tensor_model_parallel_size=1 \
    --pipeline_model_parallel_size=1 \
    --target_pipeline_model_parallel_size=1 \
    --precision=bf16

# Megatron T5
python megatron_change_num_partitions.py \
    --model_file=PATH_TO_SRC_FILE \
    --target_file=PATH_TO_TGT_FILE \
    --model_class="nemo.collections.nlp.models.language_modeling.megatron_t5_model.MegatronT5Model" \
    --tensor_model_parallel_size=1 \
    --target_tensor_model_parallel_size=1 \
    --pipeline_model_parallel_size=1 \
    --target_pipeline_model_parallel_size=1 \
    --target_pipeline_model_parallel_split_rank=0 \
    --precision=bf16

# NOTE: When converting large models, always pre-extract the .nemo file first and only then perform the conversion

$ mkdir "unpacked_nemo_file"
$ tar -xvf "<path to nemo file>" -C "<absolute path to pwd>/unpacked_nemo_file/"

python megatron_change_num_partitions.py \
    ...
    --model_extracted_dir="<Absolute path to pwd>/unpacked_nemo_file/"

# NOTE: Converting other model types.
# The default model type is MegatronGPTModel; to convert another model, pass the classpath of that model.
# For example, for MegatronT5Model:

python megatron_change_num_partitions.py \
    ...
    --model_class="nemo.collections.nlp.models.language_modeling.megatron_t5_model.MegatronT5Model"

# Additional arguments:

--num_gpu_per_node: Number of GPUs per node. Default is 8.
--megatron_legacy: Whether the model is a legacy Megatron model. Default is False. May be unsupported for
    pipeline parallelism changes.
--tokenizer_model_path: Path to tokenizer model. Default is None. When not None, overrides the tokenizer model path
    in the model config.
--tokenizer_vocab_file: Path to tokenizer vocab file. Default is None. When not None, overrides the tokenizer vocab
    file in the model config.
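The tokenizer flags above override paths in the model config. A rough, hypothetical sketch of how such an override can be applied with OmegaConf (the config key layout here is an assumption for illustration; the commit history below mentions open_dict hooks, but this is not the script's actual code):

from omegaconf import OmegaConf, open_dict

# Hypothetical sketch of a tokenizer-path override, not the script's actual code.
# The "tokenizer.model" key is an assumption for illustration.
cfg = OmegaConf.create({"tokenizer": {"model": "old/tokenizer.model"}})

tokenizer_model_path = "new/tokenizer.model"  # e.g. the --tokenizer_model_path value
if tokenizer_model_path is not None:
    # open_dict unlocks a struct-mode config so the key can be overwritten.
    with open_dict(cfg):
        cfg.tokenizer.model = tokenizer_model_path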

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

@github-actions bot added the NLP label Mar 16, 2023
@titu1994 marked this pull request as draft March 16, 2023 04:16
@titu1994 marked this pull request as ready for review March 17, 2023 07:07
@yidong72 (Collaborator) left a comment:

pretty cool PR. just have one minor comment.

@ericharper previously approved these changes Mar 24, 2023
@ericharper (Collaborator) left a comment:

LGTM. Thanks!

@aklife97 (Collaborator) left a comment:

LGTM! just a couple of minor comments, everything else looks great!

@aklife97 (Collaborator) commented:

also, can we add a PP change CI as well? would be helpful to keep testing that since the PR brings in global overrides that may cause issues if someone changes it

@titu1994 (Collaborator, Author) commented Mar 24, 2023

Good point about the Jenkins test - updated the old one from a TP reduce-and-increase test to also jointly increase PP by an even or odd number, though this only tests GPT. We need nightly tests that cover the whole matrix of TP (inc x dec) x PP (inc x dec) x {GPT, T5} - but that is super expensive w.r.t. time and storage space on this CI. Will need to look into how to set that up.
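As a rough illustration of the size of that matrix, a count of the combinations (the labels here are purely illustrative):

from itertools import product

# Rough count of the nightly test matrix described above; names are illustrative.
tp_directions = ["tp_increase", "tp_decrease"]
pp_directions = ["pp_increase", "pp_decrease"]
models = ["GPT", "T5"]

matrix = list(product(tp_directions, pp_directions, models))
print(len(matrix))  # 8 configurations, each needing its own converted checkpoints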

@aklife97 previously approved these changes Mar 24, 2023

@aklife97 (Collaborator) left a comment:

LGTM, thanks!

@aklife97 merged commit aaa0cca into NVIDIA:main Mar 25, 2023
@wdykas mentioned this pull request Mar 29, 2023
@titu1994 deleted the support_tp_pp_conversion branch March 31, 2023 22:05
hsiehjackson pushed a commit to hsiehjackson/NeMo that referenced this pull request Jun 2, 2023
* Add required flags to partially load model
* Add cleaned up script for tp pp change
* Add cleaned up script for tp pp change
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Add support to change parameter dtypes during conversion
* Add Debug Prints flag
* Improve error logs
* Fix issues with TP > 1 for Megatron T5
* Finalize splitting of T5 models
* Update docstrings
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Finalize pp tp change for T5 models
* Fix CodeQL issue
* Fix dtype cast of num_gpu_per_node
* Update config
* Remove block for config checks
* Reduce shared embedding check for older configs
* Add support for extracted directory path
* Force CPU init for TP 1 PP 1 temp model
* Patch T5 models to init fully on CPU
* Update docstring
* Update docstring
* Update prints to logging
* Patch apex code
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Patch typo
* Fix import test of ModelType
* Add docstring comment for nlp override
* Merge new file with old file
* Update script call signature
* Remove comments
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update jenkins test
* Fix formatting
* Add open_dict hooks
* Fix unit test
* Fix unit test
* Retry in another directory
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Revert second test because of shutil.rename error on CI

---------

Signed-off-by: smajumdar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adi Renduchintala <[email protected]>
Signed-off-by: hsiehjackson <[email protected]>