Virtual pipeline parallel support for MegatronGPTSFTModel #7964

Merged (13 commits) on Jan 17, 2024
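For readers new to interleaved scheduling, here is a minimal sketch (not NeMo code; the helper name `for_each_model_chunk` is made up) of the pattern this PR applies throughout: with virtual pipeline parallelism each rank owns a list of model chunks rather than a single module, so per-module setup has to loop over that list and select the matching virtual pipeline rank before touching each chunk.

```python
# Illustrative sketch only: with virtual (interleaved) pipeline parallelism,
# one rank owns several model chunks instead of a single module, so per-module
# setup must iterate over the list and select the matching virtual PP rank.
from megatron.core import parallel_state


def for_each_model_chunk(model, fn):
    """Apply `fn` to every model chunk, handling both the plain-module and the
    virtual-pipeline (list of chunks) cases. `fn` is any per-module hook."""
    chunks = model if isinstance(model, list) else [model]
    for index, chunk in enumerate(chunks):
        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
            # Tell Megatron which interleaved stage this chunk belongs to.
            parallel_state.set_virtual_pipeline_model_parallel_rank(index)
        fn(chunk)
    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
        # Reset to the first virtual stage so later code sees a consistent state.
        parallel_state.set_virtual_pipeline_model_parallel_rank(0)
```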
29 changes: 16 additions & 13 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1204,19 +1204,7 @@ def setup(self, stage=None):
             self.setup_test_data(self.cfg.data)
 
         if stage == 'fit':
-            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
-                if self.cfg.get('share_embeddings_and_output_weights', True):
-                    for index, module in enumerate(self.get_model_module_list()):
-                        if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
-                            parallel_state.set_virtual_pipeline_model_parallel_rank(index)
-                        sync_embeddings = (
-                            module.initialize_last_stage_with_word_embeddings
-                            if self.mcore_gpt
-                            else module.sync_initial_word_embeddings
-                        )
-                        sync_embeddings()
-                    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
-                        parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+            self.initialize_last_rank_embeddings()
 
         if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_gpt', False):
             self.setup_transformer_engine_tp_groups()
@@ -1446,6 +1434,21 @@ def mgpt_wrapper(self):
     def list_export_subnets(self):
         return ['mgpt_wrapper']
 
+    def initialize_last_rank_embeddings(self):
+        if parallel_state.get_pipeline_model_parallel_world_size() > 1:
+            if self.cfg.get('share_embeddings_and_output_weights', True):
+                for index, module in enumerate(self.get_model_module_list()):
+                    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+                        parallel_state.set_virtual_pipeline_model_parallel_rank(index)
+                    sync_embeddings = (
+                        module.initialize_last_stage_with_word_embeddings
+                        if self.mcore_gpt
+                        else module.sync_initial_word_embeddings
+                    )
+                    sync_embeddings()
+                if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+                    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
+
     def _reset_activation_checkpointing_args(self):
         """ Disables activation checkpointing completely and saves the values so that
         _restore_activation_checkpointing_args can restore them later. This function must always be
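The helper exists because tied input/output embeddings (`share_embeddings_and_output_weights`) leave a copy of the word-embedding weights on both the first and the last pipeline stage, and those copies must start out identical. A rough sketch of the underlying idea, assuming an embedding process group that spans those two stages; the real logic lives in the Megatron helpers called above (`sync_initial_word_embeddings` / `initialize_last_stage_with_word_embeddings`):

```python
# Rough sketch only: how tied word embeddings are typically equalized across
# the first and last pipeline stages at initialization. The embedding_group
# argument is assumed to contain exactly those two ranks.
import torch
import torch.distributed as dist


def sync_tied_word_embeddings(word_embeddings: torch.nn.Embedding, embedding_group, is_last_stage: bool):
    if is_last_stage:
        # The last stage starts from zeros so the all-reduce below leaves both
        # copies equal to the first stage's initialized weights.
        word_embeddings.weight.data.zero_()
    dist.all_reduce(word_embeddings.weight.data, group=embedding_group)
```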
Second changed file (the MegatronGPTSFTModel):
@@ -89,10 +89,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         if hasattr(self.cfg.data.test_ds, "metric"):
             self.test_metric_label_key = self.cfg.data.test_ds.metric.get('label_key', 'labels')
 
-        if self.cfg.get('megatron_amp_O2', False):
-            base_module = self.model.module
-        else:
-            base_module = self.model
+        if self.use_peft and self.cfg.get('virtual_pipeline_model_parallel_size', None):
+            raise ValueError('Virtual pipeline model parallel is not supported when using PEFT')
 
         # Set the profile start and end steps in the unit of global batach
         if hasattr(self, '_nsys_profile_enabled'):
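The new `use_peft` guard simply fails fast instead of letting adapter tuning run under an interleaved schedule it does not support. A standalone illustration of the same check (plain dict in place of the OmegaConf config, hypothetical function name):

```python
# Standalone illustration of the fail-fast guard added in __init__ above.
def check_peft_vs_virtual_pp(cfg: dict, use_peft: bool) -> None:
    if use_peft and cfg.get('virtual_pipeline_model_parallel_size', None):
        raise ValueError('Virtual pipeline model parallel is not supported when using PEFT')


# Example: enabling both PEFT and virtual pipeline parallelism raises.
try:
    check_peft_vs_virtual_pp({'virtual_pipeline_model_parallel_size': 2}, use_peft=True)
except ValueError as err:
    print(err)
```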
@@ -197,17 +195,9 @@ def setup(self, stage=None):
             raise NotImplementedError('Lightning 2.0 does not support multiple dataloaders with dataloader_iter')
 
         # when using pipeline model parallel the final stage need to initialize word embeddings
-        if not self.cfg.get('mcore_gpt', False):
-            if parallel_state.get_pipeline_model_parallel_world_size() > 1:
-                if isinstance(self.model, list):
-                    for i, module in enumerate(self.model):
-                        parallel_state.set_virtual_pipeline_model_parallel_rank(i)
-                        module.sync_initial_word_embeddings()
-                    parallel_state.set_virtual_pipeline_model_parallel_rank(0)
-                else:
-                    self.model.sync_initial_word_embeddings()
+        self.initialize_last_rank_embeddings()
 
-        if self.cfg.get('transformer_engine', False):
+        if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_gpt', False):
             self.setup_transformer_engine_tp_groups()
         self.setup_complete = True
 
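`setup_transformer_engine_tp_groups()` now also runs for mcore models. What such a hook typically does is walk every model chunk and hand each Transformer Engine layer the tensor-parallel process group; the sketch below shows that pattern under the assumption that TE layers expose `set_tensor_parallel_group`, and is not NeMo's exact implementation.

```python
# Sketch: distribute the tensor-parallel process group to Transformer Engine
# submodules across all (possibly virtual-pipeline) model chunks. The hasattr
# check duck-types TE layers so plain PyTorch modules are skipped.
from megatron.core import parallel_state


def setup_te_tp_groups(model_chunks):
    tp_group = parallel_state.get_tensor_model_parallel_group()
    for chunk in model_chunks:
        for submodule in chunk.modules():
            if hasattr(submodule, 'set_tensor_parallel_group'):
                submodule.set_tensor_parallel_group(tp_group)
```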
@@ -358,16 +348,17 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
             grad_sync_func = self.reduce_overlap_gradients
             param_sync_func = self.sync_overlap_parameters
 
-        self.model.config.no_sync_func = no_sync_func
-        self.model.config.grad_sync_func = grad_sync_func
-        self.model.config.param_sync_func = param_sync_func
+        for module in self.get_model_module_list():
+            module.config.no_sync_func = no_sync_func
+            module.config.grad_sync_func = grad_sync_func
+            module.config.param_sync_func = param_sync_func
 
         fwd_bwd_function = get_forward_backward_func()
 
         losses_reduced_per_micro_batch = fwd_bwd_function(
             forward_step_func=self.get_forward_output_and_loss_func(),
-            data_iterator=data_iter,
-            model=[self.model],
+            data_iterator=self._make_data_iterator_list(data_iter),
+            model=self.model,
             num_microbatches=get_num_microbatches(),
             forward_only=forward_only,
             seq_length=seq_length,
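Finally, `_make_data_iterator_list(data_iter)` is what lets one dataloader feed several virtual-pipeline model chunks: Megatron's forward/backward schedule expects one iterator per chunk, and each of them must yield the same microbatches. A minimal sketch of that fan-out idea (a caching tee); the names below are illustrative rather than NeMo's implementation:

```python
# Illustrative fan-out of one data iterator into N iterators that replay the
# same items, so every virtual-pipeline model chunk consumes identical
# microbatches. NeMo's _make_data_iterator_list plays a similar caching role
# when virtual pipeline parallelism is enabled.
import itertools
from typing import Iterator, List


def make_data_iterator_list(data_iter: Iterator, num_model_chunks: int) -> List[Iterator]:
    if num_model_chunks <= 1:
        return [data_iter]
    # itertools.tee caches consumed items so each copy yields the same sequence.
    return list(itertools.tee(data_iter, num_model_chunks))


# Usage: one iterator per chunk, all yielding the same microbatches.
iters = make_data_iterator_list(iter(range(4)), num_model_chunks=2)
assert [next(iters[0]), next(iters[0])] == [0, 1]
assert [next(iters[1]), next(iters[1])] == [0, 1]
```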