Skip to content

Commit 2e8796a

Browse files
authored
slym/sharp (#7391)
1 parent 2baef81 commit 2e8796a

File tree

4 files changed

+21
-0
lines changed

4 files changed

+21
-0
lines changed

nemo/collections/nlp/models/language_modeling/megatron_base_model.py

+1
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
132132
global_batch_size=cfg.get('global_batch_size'),
133133
rampup_batch_size=cfg.get('rampup_batch_size'),
134134
use_fp8=cfg.get('fp8', False),
135+
use_sharp=cfg.get('sharp', False),
135136
init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False),
136137
seed=self.cfg.get('seed', 1234),
137138
apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),

nemo/collections/nlp/modules/common/megatron/megatron_init.py

+2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def initialize_model_parallel_for_nemo(
6767
global_batch_size=None,
6868
rampup_batch_size=None,
6969
use_fp8=False,
70+
use_sharp=False,
7071
init_mpi_proc_group=False,
7172
seed=1234,
7273
apex_transformer_log_level=30,
@@ -84,6 +85,7 @@ def initialize_model_parallel_for_nemo(
8485
app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
8586
app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
8687
app_state.use_fp8 = use_fp8
88+
app_state.use_sharp = use_sharp
8789
app_state.init_mpi_proc_group = init_mpi_proc_group
8890
(
8991
app_state.tensor_model_parallel_rank,

nemo/collections/nlp/parts/nlp_overrides.py

+1
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None:
169169
virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
170170
pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
171171
use_fp8=app_state.use_fp8,
172+
use_sharp=app_state.use_sharp,
172173
)
173174

174175
# assert that fake tp and pp rank match after model parallel init

nemo/utils/app_state.py

+17
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(self):
5555
self._data_parallel_group = None
5656
self._megatron_checkpoint_version = None
5757
self._use_fp8 = False
58+
self._use_sharp = False
5859
self._init_mpi_proc_gruop = False
5960

6061
self._random_seed = None
@@ -364,6 +365,22 @@ def use_fp8(self, use_fp8):
364365
"""
365366
self._use_fp8 = use_fp8
366367

368+
@property
369+
def use_sharp(self):
370+
""" Property returns the use of sharp.
371+
Returns:
372+
Use of sharp.
373+
"""
374+
return self._use_sharp
375+
376+
@use_sharp.setter
377+
def use_sharp(self, use_sharp):
378+
""" Property sets the use of sharp.
379+
Args:
380+
use_sharp: Use of sharp.
381+
"""
382+
self._use_sharp = use_sharp
383+
367384
@property
368385
def init_mpi_proc_group(self):
369386
""" Property sets the initialization of mpi process group.

0 commit comments

Comments
 (0)