Virtual pipeline parallel support for MegatronGPTSFTModel (NVIDIA#7964)
* Virtual pipeline parallel support for MegatronGPTSFTModel

Signed-off-by: Valerie Sarge <[email protected]>

* Deduplicate word embedding init code in MegatronGPTModel and MegatronGPTSFTModel into one method

Signed-off-by: Valerie Sarge <[email protected]>

* Correct TP group init call in MegatronGPTSFTModel to check for both TE and MCore, as in MegatronGPTModel

Signed-off-by: Valerie Sarge <[email protected]>

* Correct accidental double pipeline parallel size check

Signed-off-by: Valerie Sarge <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Correct get_gpt_module_list -> get_model_module_list from SFT model

Signed-off-by: Valerie Sarge <[email protected]>

---------

Signed-off-by: Valerie Sarge <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people committed Jan 17, 2024
1 parent 89c5411 commit 34753c7
Showing 2 changed files with 26 additions and 32 deletions.
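Before the file-by-file diff, it may help to recall what the change enables: with virtual (interleaved) pipeline parallelism, each pipeline rank holds several model chunks, so `self.model` becomes a list of modules and per-module state has to be handled in a loop. The fragment below is a minimal, hypothetical config sketch for that setup; the key names mirror the `cfg.get()` calls visible in the diff, but the values and surrounding structure are illustrative and not taken from this commit.

# Hypothetical config sketch; key names follow the cfg.get() calls in the diff,
# values are illustrative only.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        # Split the transformer layers across 4 pipeline stages ...
        "pipeline_model_parallel_size": 4,
        # ... and give each pipeline rank 2 interleaved chunks (virtual pipeline
        # parallelism), so the model object on each rank is a list of 2 modules.
        "virtual_pipeline_model_parallel_size": 2,
        # Tied input/output embeddings need the cross-stage sync that the new
        # initialize_last_rank_embeddings() method performs per chunk.
        "share_embeddings_and_output_weights": True,
        # With mcore_gpt (or transformer_engine) enabled, setup() also builds the TP groups.
        "mcore_gpt": True,
    }
)

print(OmegaConf.to_yaml(cfg))

Note that the SFT model now rejects PEFT together with `virtual_pipeline_model_parallel_size` (the `ValueError` added in `__init__` below).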
29 changes: 16 additions & 13 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1204,19 +1204,7 @@ def setup(self, stage=None):
             self.setup_test_data(self.cfg.data)
 
         if stage == 'fit':
-            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
-                if self.cfg.get('share_embeddings_and_output_weights', True):
-                    for index, module in enumerate(self.get_model_module_list()):
-                        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
-                            parallel_state.set_virtual_pipeline_model_parallel_rank(index)
-                        sync_embeddings = (
-                            module.initialize_last_stage_with_word_embeddings
-                            if self.mcore_gpt
-                            else module.sync_initial_word_embeddings
-                        )
-                        sync_embeddings()
-                    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
-                        parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+            self.initialize_last_rank_embeddings()
 
         if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_gpt', False):
             self.setup_transformer_engine_tp_groups()
@@ -1446,6 +1434,21 @@ def mgpt_wrapper(self):
     def list_export_subnets(self):
         return ['mgpt_wrapper']
 
+    def initialize_last_rank_embeddings(self):
+        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+            if self.cfg.get('share_embeddings_and_output_weights', True):
+                for index, module in enumerate(self.get_model_module_list()):
+                    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+                        parallel_state.set_virtual_pipeline_model_parallel_rank(index)
+                    sync_embeddings = (
+                        module.initialize_last_stage_with_word_embeddings
+                        if self.mcore_gpt
+                        else module.sync_initial_word_embeddings
+                    )
+                    sync_embeddings()
+                if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+                    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+
     def _reset_activation_checkpointing_args(self):
         """ Disables activation checkpointing completely and saves the values so that
         _restore_activation_checkpointing_args can restore them later. This function must always be
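The method added above exists because, with tied input and output embeddings under pipeline parallelism, the first stage owns the input embedding while the last stage owns the output projection, and the last stage's copy must start out identical before training. The toy sketch below (plain PyTorch, not NeMo or Megatron code) only illustrates that invariant; in the real implementation the copy is a broadcast/all-reduce over the embedding group, executed once per virtual-pipeline chunk.

# Toy illustration of the tied-embedding invariant that the sync enforces.
import torch

torch.manual_seed(0)
first_stage_embedding = torch.randn(8, 4)      # owned by pipeline stage 0
last_stage_output_weight = torch.empty(8, 4)   # freshly allocated on the last stage

# The sync step makes the last stage's tied weight identical to the first stage's
# embedding before any training step runs.
last_stage_output_weight.copy_(first_stage_embedding)

assert torch.equal(first_stage_embedding, last_stage_output_weight)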
29 changes: 10 additions & 19 deletions in the file defining MegatronGPTSFTModel
@@ -89,10 +89,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         if hasattr(self.cfg.data.test_ds, "metric"):
             self.test_metric_label_key = self.cfg.data.test_ds.metric.get('label_key', 'labels')
 
-        if self.cfg.get('megatron_amp_O2', False):
-            base_module = self.model.module
-        else:
-            base_module = self.model
+        if self.use_peft and self.cfg.get('virtual_pipeline_model_parallel_size', None):
+            raise ValueError('Virtual pipeline model parallel is not supported when using PEFT')
 
         # Set the profile start and end steps in the unit of global batach
         if hasattr(self, '_nsys_profile_enabled'):
@@ -197,17 +195,9 @@ def setup(self, stage=None):
             raise NotImplementedError('Lightning 2.0 does not support multiple dataloaders with dataloader_iter')
 
         # when using pipeline model parallel the final stage need to initialize word embeddings
-        if not self.cfg.get('mcore_gpt', False):
-            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
-                if isinstance(self.model, list):
-                    for i, module in enumerate(self.model):
-                        parallel_state.set_virtual_pipeline_model_parallel_rank(i)
-                        module.sync_initial_word_embeddings()
-                    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
-                else:
-                    self.model.sync_initial_word_embeddings()
+        self.initialize_last_rank_embeddings()
 
-        if self.cfg.get('transformer_engine', False):
+        if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_gpt', False):
             self.setup_transformer_engine_tp_groups()
         self.setup_complete = True
 
@@ -358,16 +348,17 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
             grad_sync_func = self.reduce_overlap_gradients
             param_sync_func = self.sync_overlap_parameters
 
-        self.model.config.no_sync_func = no_sync_func
-        self.model.config.grad_sync_func = grad_sync_func
-        self.model.config.param_sync_func = param_sync_func
+        for module in self.get_model_module_list():
+            module.config.no_sync_func = no_sync_func
+            module.config.grad_sync_func = grad_sync_func
+            module.config.param_sync_func = param_sync_func
 
         fwd_bwd_function = get_forward_backward_func()
 
         losses_reduced_per_micro_batch = fwd_bwd_function(
             forward_step_func=self.get_forward_output_and_loss_func(),
-            data_iterator=data_iter,
-            model=[self.model],
+            data_iterator=self._make_data_iterator_list(data_iter),
+            model=self.model,
             num_microbatches=get_num_microbatches(),
             forward_only=forward_only,
             seq_length=seq_length,
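The fwd_bwd_step changes in the last hunk follow the same list-of-chunks structure: the forward/backward function receives the per-rank list of model chunks directly, and the data iterator is expanded to match. The stand-alone sketch below is illustrative only; `make_iterator_list` and `FakeChunk` are stand-ins, and the real `_make_data_iterator_list` may wrap the iterator differently, but it shows the shape contract that `data_iterator=self._make_data_iterator_list(data_iter)` and `model=self.model` satisfy.

# Stand-alone sketch of the shapes involved; not NeMo code.

def make_iterator_list(data_iter, model):
    # Assumed behaviour: one iterator entry per virtual-pipeline chunk when the
    # model is a list of chunks, otherwise pass the iterator through unchanged.
    if isinstance(model, list) and len(model) > 1:
        return [data_iter for _ in model]
    return data_iter


class FakeChunk:
    """Stand-in for a Megatron GPT model chunk exposing a mutable .config."""

    def __init__(self):
        self.config = type("Cfg", (), {})()


model = [FakeChunk(), FakeChunk()]   # two virtual-pipeline chunks on this rank
data_iter = iter(range(8))           # stand-in for the microbatch iterator

# Per-chunk hooks, mirroring the loop over get_model_module_list() in the diff.
for module in model:
    module.config.no_sync_func = None
    module.config.grad_sync_func = None
    module.config.param_sync_func = None

iterators = make_iterator_list(data_iter, model)
assert len(iterators) == len(model)  # the forward/backward function expects matching lists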
