
Commit e40e720

JimmyZhang12 authored and jiemingz committed
disable overlap_param_gather_with_optimizer_step (NVIDIA#11102)
* disable overlap_param_gather_with_optimizer_step
  Signed-off-by: Jimmy Zhang <[email protected]>
* fix comment
  Signed-off-by: Jieming Zhang <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: JimmyZhang12 <[email protected]>
* fix typo again
  Signed-off-by: Jieming Zhang <[email protected]>
* Apply isort and black reformatting
  Signed-off-by: JimmyZhang12 <[email protected]>

---------

Signed-off-by: Jimmy Zhang <[email protected]>
Signed-off-by: Jieming Zhang <[email protected]>
Signed-off-by: JimmyZhang12 <[email protected]>
Co-authored-by: Jimmy Zhang <[email protected]>
Co-authored-by: JimmyZhang12 <[email protected]>
1 parent 4b7707c commit e40e720

File tree

8 files changed: +11 −8 lines


nemo/collections/llm/recipes/gpt3_175b.py

+1 −1

@@ -229,7 +229,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
             tp_comm_overlap_cfg=userbuffers_bf16_h100_h12288_tp4_mbs1_seqlen2048,
             defer_embedding_wgrad_compute=True,
             wgrad_deferral_limit=50,
-            overlap_param_gather_with_optimizer_step=True,
+            overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing
             align_param_gather=True,
         )
     )
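
For reference, a minimal sketch (not part of this commit) of the callback config the recipes above end up building after this change. It uses nemo_run's run.Config and the callback's module path from the file list on this page; the tp_comm_overlap_cfg userbuffers object and the exact attachment point into the recipe are omitted or assumed:

    import nemo_run as run
    from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback

    # Flags mirror the diff above. Keep the overlap flag False until the
    # checkpointing issue referenced in this commit is resolved.
    comm_overlap = run.Config(
        MegatronCommOverlapCallback,
        defer_embedding_wgrad_compute=True,
        wgrad_deferral_limit=50,
        overlap_param_gather_with_optimizer_step=False,  # checkpointing issue
        align_param_gather=True,
    )
    recipe.trainer.callbacks.append(comm_overlap)  # hypothetical attachment point

The remaining recipe diffs below apply the same one-line flip, so the same sketch applies with each recipe's own userbuffers config and wgrad_deferral_limit.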

nemo/collections/llm/recipes/llama31_405b.py

+1 −1

@@ -231,7 +231,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
             tp_comm_overlap_cfg=userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192,
             defer_embedding_wgrad_compute=True,
             wgrad_deferral_limit=50,
-            overlap_param_gather_with_optimizer_step=True,
+            overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing
             align_param_gather=True,
         )
     )

nemo/collections/llm/recipes/llama3_70b.py

+1 −1

@@ -232,7 +232,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
             tp_comm_overlap_cfg=userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192,
             defer_embedding_wgrad_compute=True,
             wgrad_deferral_limit=22,
-            overlap_param_gather_with_optimizer_step=True,
+            overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing.
             align_param_gather=True,
         )
     )

nemo/collections/llm/recipes/mixtral_8x22b.py

+3 −1

@@ -226,7 +226,9 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
                 MegatronTokenDropCallback,
             ),
             run.Config(
-                MegatronCommOverlapCallback, overlap_param_gather_with_optimizer_step=True, align_param_gather=True
+                MegatronCommOverlapCallback,
+                overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing
+                align_param_gather=True,
             ),
         ]
     )

nemo/collections/llm/recipes/mixtral_8x7b.py

+1 −1

@@ -222,7 +222,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
             run.Config(MegatronTokenDropCallback),
             run.Config(
                 MegatronCommOverlapCallback,
-                overlap_param_gather_with_optimizer_step=True,
+                overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing.
                 align_param_gather=True,
             ),
         ]

nemo/collections/llm/recipes/nemotron4_22b.py

+1 −1

@@ -209,7 +209,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
             tp_comm_overlap=True,
             defer_embedding_wgrad_compute=True,
             wgrad_deferral_limit=22,
-            overlap_param_gather_with_optimizer_step=True,
+            overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing
             align_param_gather=True,
         )
     )

nemo/collections/llm/recipes/nemotron4_340b.py

+1 −1

@@ -212,7 +212,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial:
             tp_comm_overlap=True,
             defer_embedding_wgrad_compute=True,
             wgrad_deferral_limit=22,
-            overlap_param_gather_with_optimizer_step=True,
+            overlap_param_gather_with_optimizer_step=False,  # Currently disabled due to an issue with checkpointing
             align_param_gather=True,
         )
     )

nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py

+2 −1

@@ -181,7 +181,8 @@ def _get_optimizer_overlap_cfgs(self, parallelism_cfg: ParallelismConfig) -> _Co
         comm_overlap_cfg.overlap_grad_reduce = True
         comm_overlap_cfg.overlap_param_gather = True
         if parallelism_cfg.pipeline_model_parallel_size > 1 and vp_size > 1:
-            comm_overlap_cfg.overlap_param_gather_with_optimizer_step = True
+            # Currently disabled due to an issue with checkpointing
+            # comm_overlap_cfg.overlap_param_gather_with_optimizer_step = True
             comm_overlap_cfg.align_param_gather = True

         comm_overlap_cfg = self._override_user_cfgs(comm_overlap_cfg)
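
Note the final context line: the computed defaults still pass through self._override_user_cfgs, so user-supplied values win over these defaults. A hedged sketch (constructor kwargs inferred from the recipe diffs above; wiring into a trainer is assumed, not shown in this commit) of how a user could still opt back in while accepting the checkpointing caveat:

    from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback

    # Hypothetical explicit opt-in; this overrides the now-commented-out default
    # via _override_user_cfgs. Use only if the checkpointing issue is acceptable.
    callback = MegatronCommOverlapCallback(
        overlap_param_gather_with_optimizer_step=True,  # explicit user override
        align_param_gather=True,
    )
    # trainer = nl.Trainer(..., callbacks=[callback])  # attachment point assumed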
