Fixing mcore bert for TP, PP and SP #8336

Merged: 23 commits, Feb 16, 2024

Commits (23):
c897f39  Fixing mcore bert for TP, PP and SP (Feb 5, 2024)
670f0c9  Merge branch 'r1.23.0' into bert_mcore_fix (ericharper, Feb 5, 2024)
38fd5a9  Fixing mcore bert for TP, PP and SP (Feb 5, 2024)
3c2a76f  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 6, 2024)
88b76c1  Fixing mcore version (Feb 6, 2024)
3a8eefd  Fixing mcore version (Feb 6, 2024)
3176d8f  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 6, 2024)
fad047e  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 6, 2024)
74c084e  Update Jenkinsfile (shanmugamr1992, Feb 6, 2024)
d3b5bdf  Update Jenkinsfile (shanmugamr1992, Feb 7, 2024)
f337cca  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 7, 2024)
2510afd  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 8, 2024)
683065e  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 9, 2024)
47ba4fa  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 10, 2024)
f642ac3  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 11, 2024)
361c1fd  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 12, 2024)
819ec70  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 12, 2024)
b2e402b  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 13, 2024)
20eedb8  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 14, 2024)
864388b  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 15, 2024)
6e1226a  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 15, 2024)
08c718d  Update Jenkinsfile (shanmugamr1992, Feb 16, 2024)
5232247  Merge branch 'r1.23.0' into bert_mcore_fix (shanmugamr1992, Feb 16, 2024)
133 changes: 96 additions & 37 deletions nemo/collections/nlp/models/language_modeling/megatron_bert_model.py
@@ -13,7 +13,8 @@
# limitations under the License.

import itertools
from typing import Any, Dict, List, Optional
import queue
from typing import Any, Dict, Iterator, List, Optional

import torch
import torch.nn.functional as F
@@ -340,8 +341,8 @@ def training_step(self, dataloader_iter, batch_idx):

losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(),
data_iterator=dataloader_iter,
model=[self.model],
data_iterator=self._make_data_iterator_list(dataloader_iter),
model=self.model,
num_microbatches=get_num_microbatches(),
forward_only=False,
seq_length=seq_length,
@@ -402,6 +403,65 @@ def training_step(self, dataloader_iter, batch_idx):

return loss_mean[0]

def _make_data_iterator_list(self, data_iterator: Iterator) -> List[Iterator]:
""" Convert data iterator into form expected by Megatron
With interleaved pipeline parallelism, Megatron expects a
list of one data iterator per model chunk. Each model
chunk independently gets data from its data iterator, so
we need to interact with the data iterator multiple times
for each microbatch step. Instead of incorporating this
logic into the data loader, we cache the iterator's output
to the first model chunk and reuse it in the other model
chunks.
"""

if not isinstance(self.model, list) or len(self.model) == 1:
return data_iterator # TODO @tmoon: Remove
# TODO @tmoon: Use once available in Megatron-LM
# return DataIteratorList([data_iterator])

class CachingIterator:
"""Iterator wrapper that caches values"""

class Proxy:
"""Returns values from caching iterator wrapper
Assumed to never advance past the caching iterator.
"""

def __init__(self):
self.cache = queue.Queue()

def __iter__(self):
return self

def __next__(self):
return self.cache.get_nowait()

def __init__(self, iterator: Iterator):
self.iterator = iterator
self.proxies = []

def make_proxy(self):
self.proxies.append(CachingIterator.Proxy())
return self.proxies[-1]

def __iter__(self):
return self

def __next__(self):
val = next(self.iterator)
for proxy in self.proxies:
proxy.cache.put(val)
return val

# Make list of iterator wrappers
iters = [CachingIterator(data_iterator)]
while len(iters) < len(self.model):
iters.append(iters[0].make_proxy())
return iters # TODO @tmoon: Remove
# TODO @tmoon: Use once available in Megatron-LM
# return DataIteratorList(iters)

def allreduce_first_last_embeddings(self):

# Modified from megatron-lm: https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/training.py#L407
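Note on the hunk above: `_make_data_iterator_list` exists because, with interleaved (virtual) pipeline parallelism, Megatron's forward/backward function expects one data iterator per model chunk, while NeMo has a single dataloader. Below is a minimal, standalone sketch of the caching pattern it uses, for illustration only; it is not NeMo code.

```python
# Standalone illustration of the caching-iterator pattern used by
# _make_data_iterator_list (sketch only; the PR defines its class inline).
import queue
from typing import Any, Iterator, List


class CachingIterator:
    """Wraps an iterator and mirrors every value into per-proxy queues."""

    class Proxy:
        """Replays cached values; assumed never to run ahead of the wrapper."""

        def __init__(self) -> None:
            self.cache: "queue.Queue[Any]" = queue.Queue()

        def __iter__(self) -> "CachingIterator.Proxy":
            return self

        def __next__(self) -> Any:
            return self.cache.get_nowait()

    def __init__(self, iterator: Iterator) -> None:
        self.iterator = iterator
        self.proxies: List["CachingIterator.Proxy"] = []

    def make_proxy(self) -> "CachingIterator.Proxy":
        self.proxies.append(CachingIterator.Proxy())
        return self.proxies[-1]

    def __iter__(self) -> "CachingIterator":
        return self

    def __next__(self) -> Any:
        val = next(self.iterator)
        for proxy in self.proxies:
            proxy.cache.put(val)  # fan the microbatch out to the other chunks
        return val


# Two "virtual model chunks" sharing one dataloader:
primary = CachingIterator(iter([{"batch": 0}, {"batch": 1}]))
chunk_iters = [primary, primary.make_proxy()]
print(next(chunk_iters[0]))  # {'batch': 0}, pulled from the real iterator
print(next(chunk_iters[1]))  # {'batch': 0}, replayed from the cache
```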
@@ -413,17 +473,16 @@ def allreduce_first_last_embeddings(self):
parallel_state.is_pipeline_first_stage(ignore_virtual=True)
or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
):
module_list = self.get_model_module_list()
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if isinstance(self.model, list):
module = self.model[0] # only the first virtual rank has the embeddings
else:
module = self.model
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if isinstance(self.model, list):
module = self.model[-1] # only the last virtual rank has the embeddings
else:
module = self.model
if module.share_token_embeddings:
module = module_list[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
module = module_list[-1]

share_embeddings = (
module.share_embeddings_and_output_weights if self.mcore_bert else module.share_token_embeddings
)
if share_embeddings:
word_embeddings_weight = (
module.shared_embedding_or_output_weight() if self.mcore_bert else module.word_embeddings_weight()
)
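For context, the `share_embeddings` flag computed above gates the gradient all-reduce (in the collapsed portion of this method) that keeps the tied input/output embeddings identical on the first and last pipeline stages. A hedged, standalone sketch of that reduction follows; `embedding_group` is an assumed process-group handle (Megatron-LM typically exposes one via `parallel_state.get_embedding_group()`), and nothing in the sketch is copied from this PR.

```python
# Hedged sketch: the gradient all-reduce that tied embeddings rely on.
# `embedding_group` is an assumed handle, not something taken from this PR.
import torch
import torch.distributed as dist


def allreduce_tied_embedding_grad(
    word_embeddings_weight: torch.nn.Parameter,
    embedding_group: dist.ProcessGroup,
    use_main_grad: bool = False,
) -> None:
    """All-reduce the embedding gradient over the first/last-stage embedding
    group so the tied input embedding and output projection get identical updates."""
    # Megatron's O2/distributed-optimizer path keeps grads in `main_grad`.
    grad = word_embeddings_weight.main_grad if use_main_grad else word_embeddings_weight.grad
    if grad is not None:
        dist.all_reduce(grad, group=embedding_group)
```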
@@ -450,8 +509,8 @@ def validation_step(self, dataloader_iter, batch_idx):

losses_reduced_per_micro_batch = fwd_bwd_function(
forward_step_func=self.get_forward_output_and_loss_func(),
data_iterator=dataloader_iter,
model=[self.model],
data_iterator=self._make_data_iterator_list(dataloader_iter),
model=self.model,
num_microbatches=get_num_microbatches(),
forward_only=True,
seq_length=seq_length,
@@ -724,23 +783,17 @@ def setup(self, stage=None):

# when using pipeline model parallel the final stage need to initialize word embeddings
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
if isinstance(self.model, list):
for i, module in enumerate(self.model):
parallel_state.set_virtual_pipeline_model_parallel_rank(i)
for index, module in enumerate(self.get_model_module_list()):
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
parallel_state.set_virtual_pipeline_model_parallel_rank(index)
sync_embeddings = (
module.initialize_last_stage_with_word_embeddings
if self.mcore_bert
else module.sync_initial_word_embeddings
)
sync_embeddings()
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
else:
sync_embeddings = (
self.model.initialize_last_stage_with_word_embeddings
if self.mcore_bert
else self.model.sync_initial_word_embeddings
)
sync_embeddings()
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
parallel_state.set_virtual_pipeline_model_parallel_rank(0)

if self.cfg.get('transformer_engine', False) or self.cfg.get('mcore_bert', False):
self.setup_transformer_engine_tp_groups()
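Several hunks in this diff replace `isinstance(self.model, list)` checks with `self.get_model_module_list()`, whose implementation is not shown here. The following is a hypothetical stand-in, not NeMo's actual helper, illustrating the normalization those call sites rely on.

```python
# Hypothetical stand-in (not NeMo's get_model_module_list) showing the
# normalization these call sites rely on: always return a list of pipeline
# chunks, whether or not virtual pipeline parallelism split the model.
from typing import List, Union

import torch


def normalize_to_module_list(
    model: Union[torch.nn.Module, List[torch.nn.Module]]
) -> List[torch.nn.Module]:
    """Return the model as a list of chunks (length 1 without interleaving)."""
    return model if isinstance(model, list) else [model]
```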
@@ -914,22 +967,28 @@ def configure_optimizers(self):
# Disable overlapped grad sync for embedding grad when
# pipeline parallelism is enabled
if parallel_state.get_pipeline_model_parallel_world_size() > 1:
modules = self.get_model_module_list()
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
if isinstance(self.model, list):
module = self.model[0] # only the first virtual rank has the embeddings
else:
module = self.model
if module.share_token_embeddings:
param = module.word_embeddings_weight()
module = modules[0] # only the first virtual rank has the embeddings
if self.cfg.get('share_embeddings_and_output_weights', True):
param = (
module.shared_embedding_or_output_weight()
if self.mcore_bert
else module.word_embeddings_weight()
)
param._disable_greedy_grad_copy = not self.megatron_amp_O2
param._disable_overlap_grad_sync = True
if parallel_state.is_pipeline_last_stage(ignore_virtual=True):
if isinstance(self.model, list):
module = self.model[-1] # only the last virtual rank has the embeddings
if len(modules) > 1:
module = modules[-1] # only the last virtual rank has the embeddings
else:
module = self.model
if module.share_token_embeddings:
param = module.word_embeddings_weight()
module = modules[0]
if self.cfg.get('share_embeddings_and_output_weights', True):
param = (
module.shared_embedding_or_output_weight()
if self.mcore_bert
else module.word_embeddings_weight()
)
param._disable_greedy_grad_copy = not self.megatron_amp_O2
param._disable_overlap_grad_sync = True
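For reference, the two private attributes set in this hunk are how the parameter is excluded from the distributed optimizer's overlapped gradient sync, since the separate first/last-stage embedding all-reduce covers it. A compact sketch of that flagging pattern; the helper name is hypothetical, while the attribute names are the ones used above.

```python
# Sketch of the flagging pattern above; exclude_from_overlap is a hypothetical
# helper, but the two attribute names are exactly the ones set in this hunk.
import torch


def exclude_from_overlap(param: torch.nn.Parameter, megatron_amp_O2: bool) -> None:
    """Mark a tied-embedding parameter so the distributed optimizer does not
    overlap its grad sync; the embedding all-reduce handles it separately."""
    param._disable_greedy_grad_copy = not megatron_amp_O2
    param._disable_overlap_grad_sync = True
```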
