Fixing mcore bert for TP, PP and SP (NVIDIA#8336) (NVIDIA#8443)
* Fixing mcore bert for TP, PP and SP

* Fixing mcore bert for TP, PP and SP

* Fixing mcore version

* Fixing mcore version

* Update Jenkinsfile



* Update Jenkinsfile



* Update Jenkinsfile



---------

Signed-off-by: Shanmugam Ramasamy <[email protected]>
Co-authored-by: Shanmugam Ramasamy <[email protected]>
Co-authored-by: Shanmugam Ramasamy <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
4 people authored Feb 20, 2024
1 parent 345525e commit 200a2d8
Showing 1 changed file with 96 additions and 37 deletions.
133 changes: 96 additions & 37 deletions nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
@@ -13,7 +13,8 @@
# limitations under the License.

import itertools
from typing import Any, Dict, List, Optional
import queue
from typing import Any, Dict, Iterator, List, Optional

import torch
import torch.nn.functional as F
@@ -343,8 +344,8 @@ def training_step(self, dataloader_iter, batch_idx):

losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(),
data_iterator=dataloader_iter,
model=[self.model],
data_iterator=self._make_data_iterator_list(dataloader_iter),
model=self.model,
num_microbatches=get_num_microbatches(),
forward_only=False,
seq_length=seq_length,
@@ -405,6 +406,65 @@ def training_step(self, dataloader_iter, batch_idx):

return loss_mean[0]

def _make_data_iterator_list(self, data_iterator: Iterator) -> List[Iterator]:
""" Convert data iterator into form expected by Megatron
With interleaved pipeline parallelism, Megatron expects a
list of one data iterator per model chunk. Each model
chunk independently gets data from its data iterator, so
we need to interact with the data iterator multiple times
for each microbatch step. Instead of incorporating this
logic into the data loader, we cache the iterator's output
to the first model chunk and reuse it in the other model
chunks.
"""

if not isinstance(self.model, list) or len(self.model) == 1:
return data_iterator # TODO @tmoon: Remove
# TODO @tmoon: Use once available in Megatron-LM
# return DataIteratorList([data_iterator])

class CachingIterator:
"""Iterator wrapper that caches values"""

class Proxy:
"""Returns values from caching iterator wrapper
Assumed to never advance past the caching iterator.
"""

def __init__(self):
self.cache = queue.Queue()

def __iter__(self):
return self

def __next__(self):
return self.cache.get_nowait()

def __init__(self, iterator: Iterator):
self.iterator = iterator
self.proxies = []

def make_proxy(self):
self.proxies.append(CachingIterator.Proxy())
return self.proxies[-1]

def __iter__(self):
return self

def __next__(self):
val = next(self.iterator)
for proxy in self.proxies:
proxy.cache.put(val)
return val

# Make list of iterator wrappers
iters = [CachingIterator(data_iterator)]
while len(iters) < len(self.model):
iters.append(iters[0].make_proxy())
return iters # TODO @tmoon: Remove
# TODO @tmoon: Use once available in Megatron-LM
# return DataIteratorList(iters)
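
For context on the hunk above: with interleaved pipeline parallelism, Megatron expects one data iterator per model chunk, yet every chunk must consume the same microbatch. Below is a standalone sketch of the caching pattern (illustrative names, not the code in this commit), assuming the proxies are never advanced past the primary iterator:

    # Standalone sketch of the caching-iterator pattern (illustrative names).
    import queue
    from typing import Iterator, List

    class _CachingIterator:
        """Reads from the real data iterator and replays each value to its proxies."""

        class _Proxy:
            def __init__(self):
                self.cache = queue.Queue()

            def __iter__(self):
                return self

            def __next__(self):
                # Assumed never to advance past the primary iterator, so the cache is non-empty.
                return self.cache.get_nowait()

        def __init__(self, iterator: Iterator):
            self.iterator = iterator
            self.proxies: List["_CachingIterator._Proxy"] = []

        def make_proxy(self):
            self.proxies.append(self._Proxy())
            return self.proxies[-1]

        def __iter__(self):
            return self

        def __next__(self):
            val = next(self.iterator)
            for proxy in self.proxies:
                proxy.cache.put(val)
            return val

    # One iterator per model chunk: chunk 0 reads real data, later chunks replay it.
    primary = _CachingIterator(iter([{"batch": 0}, {"batch": 1}]))
    chunk_iters = [primary, primary.make_proxy()]
    first = next(chunk_iters[0])   # fetches {"batch": 0} from the real iterator and caches it
    second = next(chunk_iters[1])  # replays the cached {"batch": 0} for the second model chunk
    assert first is second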

def allreduce_first_last_embeddings(self):

# Modified from megatron-lm: https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/training.py#L407
@@ -416,17 +476,16 @@ def allreduce_first_last_embeddings(self):
parallel_state.is_pipeline_first_stage(ignore_virtual=True)
or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
):
module_list = self.get_model_module_list()
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if isinstance(self.model, list):
module = self.model[0] # only the first virtual rank has the embeddings
else:
module = self.model
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if isinstance(self.model, list):
module = self.model[-1] # only the last virtual rank has the embeddings
else:
module = self.model
if module.share_token_embeddings:
module = module_list[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
module = module_list[-1]

share_embeddings = (
module.share_embeddings_and_output_weights if self.mcore_bert else module.share_token_embeddings
)
if share_embeddings:
word_embeddings_weight = (
module.shared_embedding_or_output_weight() if self.mcore_bert else module.word_embeddings_weight()
)
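
For context: when the word embedding is tied between the first and last pipeline stages, its gradients must be summed across the two copies after backward. A minimal sketch of that all-reduce, assuming Megatron-LM's parallel_state.get_embedding_group() and an optional main_grad buffer on the weight; this is not the exact NeMo code path:

    import torch
    import torch.distributed as dist
    from megatron.core import parallel_state

    def allreduce_tied_embedding_grad(word_embeddings_weight: torch.nn.Parameter) -> None:
        # Sum the embedding gradient over the embedding group (first + last pipeline
        # stages) so both copies of the tied weight see the same update.
        grad = getattr(word_embeddings_weight, "main_grad", None)  # main-grad buffer is an assumption
        if grad is None:
            grad = word_embeddings_weight.grad
        if grad is not None:
            dist.all_reduce(grad, group=parallel_state.get_embedding_group())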
Expand All @@ -453,8 +512,8 @@ def validation_step(self, dataloader_iter, batch_idx):

losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(),
data_iterator=dataloader_iter,
model=[self.model],
data_iterator=self._make_data_iterator_list(dataloader_iter),
model=self.model,
num_microbatches=get_num_microbatches(),
forward_only=True,
seq_length=seq_length,
@@ -727,23 +786,17 @@ def setup(self, stage=None):

# when using pipeline model parallel the final stage need to initialize word embeddings
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if isinstance(self.model, list):
for i, module in enumerate(self.model):
parallel_state.set_virtual_pipeline_model_parallel_rank(i)
for index, module in enumerate(self.get_model_module_list()):
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
parallel_state.set_virtual_pipeline_model_parallel_rank(index)
sync_embeddings = (
module.initialize_last_stage_with_word_embeddings
if self.mcore_bert
else module.sync_initial_word_embeddings
)
sync_embeddings()
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
else:
sync_embeddings = (
self.model.initialize_last_stage_with_word_embeddings
if self.mcore_bert
else self.model.sync_initial_word_embeddings
)
sync_embeddings()
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
parallel_state.set_virtual_pipeline_model_parallel_rank(0)

if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_bert', False):
self.setup_transformer_engine_tp_groups()
@@ -917,22 +970,28 @@ def configure_optimizers(self):
# Disable overlapped grad sync for embedding grad when
# pipeline parallelism is enabled
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
modules = self.get_model_module_list()
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if isinstance(self.model, list):
module = self.model[0] # only the first virtual rank has the embeddings
else:
module = self.model
if module.share_token_embeddings:
param = module.word_embeddings_weight()
module = modules[0] # only the first virtual rank has the embeddings
if self.cfg.get('share_embeddings_and_output_weights', True):
param = (
module.shared_embedding_or_output_weight()
if self.mcore_bert
else module.word_embeddings_weight()
)
param._disable_greedy_grad_copy = not self.megatron_amp_O2
param._disable_overlap_grad_sync = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if isinstance(self.model, list):
module = self.model[-1] # only the last virtual rank has the embeddings
if len(modules) > 1:
module = modules[-1] # only the last virtual rank has the embeddings
else:
module = self.model
if module.share_token_embeddings:
param = module.word_embeddings_weight()
module = modules[0]
if self.cfg.get('share_embeddings_and_output_weights', True):
param = (
module.shared_embedding_or_output_weight()
if self.mcore_bert
else module.word_embeddings_weight()
)
param._disable_greedy_grad_copy = not self.megatron_amp_O2
param._disable_overlap_grad_sync = True
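
For context: the _disable_greedy_grad_copy and _disable_overlap_grad_sync attributes set here are flags consumed by NeMo's distributed-optimizer path. A sketch with a hypothetical helper name, assuming the intent is to keep the tied embedding gradient out of overlapped grad sync until the first/last-stage all-reduce has run:

    import torch

    def mark_embedding_param_for_eager_sync(param: torch.nn.Parameter, megatron_amp_O2: bool) -> None:
        # Flag names come from the diff above; this hypothetical helper just applies them.
        # Overlapped grad sync is turned off for the tied embedding weight so its gradient
        # is fully reduced before allreduce_first_last_embeddings runs (assumed intent).
        param._disable_greedy_grad_copy = not megatron_amp_O2
        param._disable_overlap_grad_sync = True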
