From 000316c8c3c8a10a0bbbdd54cac5981719e53db2 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 16 Feb 2024 22:23:18 +0000 Subject: [PATCH 1/6] MoE parameter passing Signed-off-by: Alexandros Koumparoulis --- .../language_modeling/megatron_base_model.py | 7 +++- .../modules/common/megatron/megatron_init.py | 27 ++++++++++++++ nemo/collections/nlp/parts/nlp_overrides.py | 1 + nemo/utils/app_state.py | 36 +++++++++++++++++++ 4 files changed, 70 insertions(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 27669ea2f31c..0c4e87e43b34 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -160,7 +160,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): # Overrides used when converting checkpoints if os.environ.get(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, "false").lower() == "true": app_state = AppState() - init_world_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size + init_world_size = ( + app_state.tensor_model_parallel_size + * app_state.pipeline_model_parallel_size + * app_state.expert_model_parallel_size + ) init_global_rank = app_state.global_rank init_local_rank = app_state.local_rank else: @@ -185,6 +189,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): global_rank=init_global_rank, local_rank=init_local_rank, tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1), + expert_model_parallel_size=cfg.get('expert_model_parallel_size', 1), pipeline_model_parallel_size=cfg.get('pipeline_model_parallel_size', 1), virtual_pipeline_model_parallel_size=vp_size, pipeline_model_parallel_split_rank=cfg.get('pipeline_model_parallel_split_rank', 0), diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py 
b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 013838e7688e..33e59acd5739 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -33,6 +33,8 @@ from megatron.core import tensor_parallel from megatron.core.parallel_state import ( get_pipeline_model_parallel_rank, + set_expert_model_parallel_rank, + set_expert_model_parallel_world_size, set_pipeline_model_parallel_rank, set_pipeline_model_parallel_split_rank, set_pipeline_model_parallel_world_size, @@ -60,6 +62,7 @@ def initialize_model_parallel_for_nemo( global_rank, local_rank, tensor_model_parallel_size=1, + expert_model_parallel_size=1, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None, pipeline_model_parallel_split_rank=None, @@ -81,6 +84,7 @@ def initialize_model_parallel_for_nemo( app_state.global_rank = global_rank app_state.world_size = world_size app_state.local_rank = local_rank + app_state.expert_model_parallel_size = expert_model_parallel_size app_state.tensor_model_parallel_size = tensor_model_parallel_size app_state.pipeline_model_parallel_size = pipeline_model_parallel_size app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size @@ -90,6 +94,7 @@ def initialize_model_parallel_for_nemo( ( app_state.tensor_model_parallel_rank, app_state.pipeline_model_parallel_rank, + app_state.expert_model_parallel_rank, app_state.model_parallel_size, app_state.data_parallel_size, app_state.pipeline_model_parallel_split_rank, @@ -102,12 +107,16 @@ def initialize_model_parallel_for_nemo( virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank_=pipeline_model_parallel_split_rank, context_parallel_size_=context_parallel_size, + expert_model_parallel_size_=expert_model_parallel_size, ) # update apex.transformer globals set_tensor_model_parallel_world_size(app_state.tensor_model_parallel_size) 
set_tensor_model_parallel_rank(app_state.tensor_model_parallel_rank) + set_expert_model_parallel_world_size(app_state.expert_model_parallel_size) + set_expert_model_parallel_rank(app_state.expert_model_parallel_rank) + set_pipeline_model_parallel_rank(app_state.pipeline_model_parallel_rank) if HAVE_INTERLEAVED: set_virtual_pipeline_model_parallel_world_size(app_state.virtual_pipeline_model_parallel_size) @@ -179,6 +188,7 @@ def fake_initialize_model_parallel( pipeline_model_parallel_size_, pipeline_model_parallel_split_rank_=None, virtual_pipeline_model_parallel_size_=None, + expert_model_parallel_size_=None, context_parallel_size_=1, ): """ @@ -302,6 +312,22 @@ def fake_initialize_model_parallel( logging.info(f'All tensor model parallel group ranks: {all_tensor_model_parallel_group_ranks}') logging.info(f'Rank {rank} has tensor model parallel rank: {tensor_model_parallel_rank}') + # EP rank + expert_model_parallel_rank = 0 + if expert_model_parallel_size_ is not None and expert_model_parallel_size_ > 1: + tensor_and_data_group_size: int = tensor_model_parallel_size * data_parallel_size + num_tensor_and_data_groups: int = world_size // tensor_and_data_group_size + tensor_and_expert_group_size: int = tensor_model_parallel_size * expert_model_parallel_size_ + num_expert_groups: int = data_parallel_size // expert_model_parallel_size_ + for i in range(num_tensor_and_data_groups): + for j in range(num_expert_groups): + start_rank = i * tensor_and_data_group_size + j * tensor_and_expert_group_size + end_rank = i * tensor_and_data_group_size + (j + 1) * tensor_and_expert_group_size + ranks = range(start_rank, end_rank) + if rank in ranks: + expert_model_parallel_rank = list(ranks).index(rank) + + # Build the pipeline model-parallel groups and embedding groups # (first and last rank in each pipeline model-parallel group). 
all_pipeline_model_parallel_group_ranks = [] @@ -340,6 +366,7 @@ def fake_initialize_model_parallel( return ( tensor_model_parallel_rank, pipeline_model_parallel_rank, + expert_model_parallel_rank, model_parallel_size, data_parallel_size, pipeline_model_parallel_split_rank_, diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index d2655d3ea60e..50386afc1893 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -129,6 +129,7 @@ def init_model_parallel(sharp: bool, nccl_communicator_config_path: str = None) context_parallel_size=app_state.context_parallel_size, nccl_communicator_config_path=nccl_communicator_config_path, use_sharp=sharp, + expert_model_parallel_size=app_state.expert_model_parallel_size, ) # assert that fake tp and pp rank match after model parallel init diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py index eb6b6d91ba5e..1d28d2940afb 100644 --- a/nemo/utils/app_state.py +++ b/nemo/utils/app_state.py @@ -39,6 +39,7 @@ def __init__(self): self._local_rank = None self._global_rank = None self._tensor_model_parallel_rank = None + self._expert_model_parallel_rank = None self._pipeline_model_parallel_rank = None self._data_parallel_rank = None @@ -46,6 +47,7 @@ def __init__(self): self._model_parallel_size = None self._tensor_model_parallel_size = None self._tensor_model_parallel_group = None + self._expert_model_parallel_size = None self._pipeline_model_parallel_size = None self._virtual_pipeline_model_parallel_size = None self._pipeline_model_parallel_group = None @@ -141,6 +143,40 @@ def tensor_model_parallel_size(self, size): """ self._tensor_model_parallel_size = size + @property + def expert_model_parallel_rank(self): + """ Property returns the expert model parallel rank. + Returns: + Expert model parallel rank.
+ """ + return self._expert_model_parallel_rank + + @expert_model_parallel_rank.setter + def expert_model_parallel_rank(self, rank): + """ Property sets the expert model parallel rank. + Args: + rank (int): Tensor model parallel rank. + """ + self._expert_model_parallel_rank = rank + + + @property + def expert_model_parallel_size(self): + """ Property returns the number of GPUs in each expert parallel group. + Returns: + Number of GPUs in each expert parallel group. + """ + return self._expert_model_parallel_size + + @expert_model_parallel_size.setter + def expert_model_parallel_size(self, size): + """ Property sets the number of GPUs in each expert parallel group. + Args: + size (int): Number of GPUs in each expert parallel group. + """ + self._expert_model_parallel_size = size + + @property def pipeline_model_parallel_size(self): """ Property returns the number of GPUs in each model parallel group. From cd2fd16b67e18c2912eb0fc198ca6a1da57649ee Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Fri, 16 Feb 2024 22:28:46 +0000 Subject: [PATCH 2/6] Pass EP/MoE params in consumer scripts. 
Signed-off-by: Alexandros Koumparoulis --- .../nlp/language_modeling/megatron_gpt_eval.py | 17 ++++++++++++++--- .../tuning/megatron_gpt_sft.py | 1 + 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py index e31c80dedee6..54b23a6664dc 100644 --- a/examples/nlp/language_modeling/megatron_gpt_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -199,7 +199,7 @@ def main(cfg) -> None: assert ( cfg.trainer.devices * cfg.trainer.num_nodes - == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size + == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size * max(1, cfg.expert_model_parallel_size) ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size" if cfg.gpt_model_file: @@ -224,6 +224,8 @@ def main(cfg) -> None: # with dist checkpointing we can use the model parallel config specified by the user pretrained_cfg.tensor_model_parallel_size = cfg.tensor_model_parallel_size pretrained_cfg.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size + pretrained_cfg.expert_model_parallel_size = cfg.expert_model_parallel_size + pretrained_cfg.micro_batch_size = 1 if trainer.precision == "16": pretrained_cfg.megatron_amp_O2 = False elif trainer.precision in ['bf16', 'bf16-mixed'] and cfg.get('megatron_amp_O2', False): @@ -237,13 +239,21 @@ def main(cfg) -> None: ) elif cfg.checkpoint_dir: app_state = AppState() - if cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1: - app_state.model_parallel_size = cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size + if ( + cfg.tensor_model_parallel_size > 1 + or cfg.pipeline_model_parallel_size > 1 + or cfg.expert_model_parallel_size > 1 + ): + app_state.model_parallel_size = ( + cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size * cfg.expert_model_parallel_size + ) 
app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size + app_state.expert_model_parallel_size = cfg.expert_model_parallel_size ( app_state.tensor_model_parallel_rank, app_state.pipeline_model_parallel_rank, + app_state.expert_model_parallel_rank, app_state.model_parallel_size, app_state.data_parallel_size, app_state.pipeline_model_parallel_split_rank, @@ -254,6 +264,7 @@ def main(cfg) -> None: tensor_model_parallel_size_=cfg.tensor_model_parallel_size, pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size, pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank, + expert_model_parallel_size_=cfg.expert_model_parallel_size, ) checkpoint_path = os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name) # checkpoint_path is a dir in case of distributed checkpointing diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py index 295685aacb97..44d0737ad44e 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py @@ -73,6 +73,7 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): gpt_cfg.ffn_dropout = cfg.model.ffn_dropout gpt_cfg.use_flash_attention = cfg.model.get('use_flash_attention', False) gpt_cfg.tensor_model_parallel_size = cfg.model.get('tensor_model_parallel_size', 1) + gpt_cfg.expert_model_parallel_size = cfg.model.get('expert_model_parallel_size', 1) gpt_cfg.pipeline_model_parallel_size = cfg.model.get('pipeline_model_parallel_size', 1) gpt_cfg.pipeline_model_parallel_split_rank = cfg.model.get('pipeline_model_parallel_split_rank', 0) From f26b639c6af016eb0c66e69486b3f5b127ac450c Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Wed, 21 Feb 2024 14:25:23 -0800 Subject: [PATCH 3/6] PR fixes Signed-off-by: Alexandros Koumparoulis --- Jenkinsfile | 2 +- 
examples/nlp/language_modeling/megatron_gpt_eval.py | 12 ++++++------ .../nlp/modules/common/megatron/megatron_init.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 1f8f8662c72d..bf71a044568c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -91,7 +91,7 @@ pipeline { steps { sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout 240a8ef7a21df201e47b5b2ae33cc5f4c5486849 && \ + git checkout 5f9c870f9f24b482509699d206a9dbb00958f6fc && \ pip install .' } } diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py index 54b23a6664dc..9d012a66e906 100644 --- a/examples/nlp/language_modeling/megatron_gpt_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -199,7 +199,7 @@ def main(cfg) -> None: assert ( cfg.trainer.devices * cfg.trainer.num_nodes - == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size * max(1, cfg.expert_model_parallel_size) + == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size * max(1, cfg.get('expert_model_parallel_size', 1)) ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size" if cfg.gpt_model_file: @@ -224,7 +224,7 @@ def main(cfg) -> None: # with dist checkpointing we can use the model parallel config specified by the user pretrained_cfg.tensor_model_parallel_size = cfg.tensor_model_parallel_size pretrained_cfg.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size - pretrained_cfg.expert_model_parallel_size = cfg.expert_model_parallel_size + pretrained_cfg.expert_model_parallel_size = cfg.get('expert_model_parallel_size', 1) pretrained_cfg.micro_batch_size = 1 if trainer.precision == "16": pretrained_cfg.megatron_amp_O2 = False @@ -242,14 +242,14 @@ def main(cfg) -> None: if ( cfg.tensor_model_parallel_size > 1 or cfg.pipeline_model_parallel_size > 1 - or cfg.expert_model_parallel_size > 1 + or 
cfg.get('expert_model_parallel_size', 1) > 1 ): app_state.model_parallel_size = ( - cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size * cfg.expert_model_parallel_size + cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size * cfg.get('expert_model_parallel_size', 1) ) app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size - app_state.expert_model_parallel_size = cfg.expert_model_parallel_size + app_state.expert_model_parallel_size = cfg.get('expert_model_parallel_size', 1) ( app_state.tensor_model_parallel_rank, app_state.pipeline_model_parallel_rank, @@ -264,7 +264,7 @@ def main(cfg) -> None: tensor_model_parallel_size_=cfg.tensor_model_parallel_size, pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size, pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank, - expert_model_parallel_size_=cfg.expert_model_parallel_size, + expert_model_parallel_size_=cfg.get('expert_model_parallel_size', 1), ) checkpoint_path = os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name) # checkpoint_path is a dir in case of distributed checkpointing diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 33e59acd5739..132d9e1325d1 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -188,7 +188,7 @@ def fake_initialize_model_parallel( pipeline_model_parallel_size_, pipeline_model_parallel_split_rank_=None, virtual_pipeline_model_parallel_size_=None, - expert_model_parallel_size_=None, + expert_model_parallel_size_=1, context_parallel_size_=1, ): """ From e4abac163fdf970866212a4dae974f8e8cbd73c6 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Wed, 21 Feb 2024 23:22:41 +0000 Subject: [PATCH 4/6] Use latest commit of mcore-0.5 Signed-off-by: Alexandros 
Koumparoulis --- Jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index bf71a044568c..e929c1d8ffd7 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -91,7 +91,7 @@ pipeline { steps { sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout 5f9c870f9f24b482509699d206a9dbb00958f6fc && \ + git checkout 98da3792f53c80ac9e865eab49a6fa5ccc293d22 && \ pip install .' } } From c9576710fbcae694816a507369fa054f1e475891 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Thu, 22 Feb 2024 12:39:33 +0000 Subject: [PATCH 5/6] CI fix Signed-off-by: Alexandros Koumparoulis --- .../nlp/models/language_modeling/megatron_base_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 0c4e87e43b34..8aa1ecf26240 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -163,7 +163,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): init_world_size = ( app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size - * app_state.expert_model_parallel_size + * (app_state.expert_model_parallel_size or 1) ) init_global_rank = app_state.global_rank init_local_rank = app_state.local_rank From 8d946f5067eb5a8e8aec177917339dd06316a16c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 22 Feb 2024 14:57:00 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- examples/nlp/language_modeling/megatron_gpt_eval.py | 8 ++++++-- .../nlp/modules/common/megatron/megatron_init.py | 1 - nemo/utils/app_state.py | 2 -- 3 files changed, 6 insertions(+), 5 deletions(-) 
diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py index 9d012a66e906..96cd75b546c1 100644 --- a/examples/nlp/language_modeling/megatron_gpt_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -199,7 +199,9 @@ def main(cfg) -> None: assert ( cfg.trainer.devices * cfg.trainer.num_nodes - == cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size * max(1, cfg.get('expert_model_parallel_size', 1)) + == cfg.tensor_model_parallel_size + * cfg.pipeline_model_parallel_size + * max(1, cfg.get('expert_model_parallel_size', 1)) ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size" if cfg.gpt_model_file: @@ -245,7 +247,9 @@ def main(cfg) -> None: or cfg.get('expert_model_parallel_size', 1) > 1 ): app_state.model_parallel_size = ( - cfg.tensor_model_parallel_size * cfg.pipeline_model_parallel_size * cfg.get('expert_model_parallel_size', 1) + cfg.tensor_model_parallel_size + * cfg.pipeline_model_parallel_size + * cfg.get('expert_model_parallel_size', 1) ) app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 132d9e1325d1..5f402707fb59 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -327,7 +327,6 @@ def fake_initialize_model_parallel( if rank in ranks: expert_model_parallel_rank = list(ranks).index(rank) - # Build the pipeline model-parallel groups and embedding groups # (first and last rank in each pipeline model-parallel group). 
all_pipeline_model_parallel_group_ranks = [] diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py index 1d28d2940afb..8ba9880219ec 100644 --- a/nemo/utils/app_state.py +++ b/nemo/utils/app_state.py @@ -159,7 +159,6 @@ def expert_model_parallel_rank(self, rank): """ self._expert_model_parallel_rank = rank - @property def expert_model_parallel_size(self): """ Property returns the number of GPUs in each expert parallel group. @@ -176,7 +175,6 @@ def expert_model_parallel_size(self, size): """ self._expert_model_parallel_size = size - @property def pipeline_model_parallel_size(self): """ Property returns the number of GPUs in each model parallel group.