Commit c2daa91
Implement async distributed checkpoint save (#9028)
* Prevent duplicated checkpoints Signed-off-by: Mikołaj Błaż <[email protected]>
* Introduce DistributedCheckpointIO Signed-off-by: Mikołaj Błaż <[email protected]>
* Fix DistCkptIO usage Signed-off-by: Mikołaj Błaż <[email protected]>
* Use NeMo logger Signed-off-by: Mikołaj Błaż <[email protected]>
* [DCIO] Fix save_to dist ckpt path Signed-off-by: Mikołaj Błaż <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Add versioning to save_to Signed-off-by: Mikołaj Błaż <[email protected]>
* Add versioning logic to all .nemo files Signed-off-by: Mikołaj Błaż <[email protected]>
* Add versioning test Signed-off-by: Mikołaj Błaż <[email protected]>
* Add dist-ckpt test Signed-off-by: Mikołaj Błaż <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Mikołaj Błaż <[email protected]>
* Rename existing ckpts instead of using different name Signed-off-by: Mikołaj Błaż <[email protected]>
* Add comment Signed-off-by: Mikołaj Błaż <[email protected]>
* Use dist ckpt flag in all methods Signed-off-by: Mikołaj Błaż <[email protected]>
* Improve error msg Signed-off-by: Mikołaj Błaż <[email protected]>
* Add dist ckpt unit tests Signed-off-by: Mikołaj Błaż <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Fix load_checkpoint Signed-off-by: Mikołaj Błaż <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Mikołaj Błaż <[email protected]>
* Fix auto-issues Signed-off-by: Mikołaj Błaż <[email protected]>
* Fix ckpt_dir var Signed-off-by: Mikołaj Błaż <[email protected]>
* Restore skipping behavior. The fix from prevent-duplicated-checkpoints is required to skip the checkpoints Signed-off-by: Mikołaj Błaż <[email protected]>
* Fix steps on single-GPU machine Signed-off-by: Mikołaj Błaż <[email protected]>
* Run dist-ckpt test on GPU Signed-off-by: Mikołaj Błaż <[email protected]>
* Add docs Signed-off-by: Mikołaj Błaż <[email protected]>
* Apply black Signed-off-by: Mikołaj Błaż <[email protected]>
* Prevent saving last for non-equal val intervals Signed-off-by: Mikołaj Błaż <[email protected]>
* Move checkpoint on rank 0 Signed-off-by: Mikołaj Błaż <[email protected]>
* Fix num steps in tests Signed-off-by: Mikołaj Błaż <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Mikołaj Błaż <[email protected]>
* Add async ckpt implementation Signed-off-by: Mikołaj Błaż <[email protected]>
* Abstract AsyncFinalizableCheckpointIO away Signed-off-by: Mikołaj Błaż <[email protected]>
* Change async_save flag location Signed-off-by: Mikołaj Błaż <[email protected]>
* Add debug info Signed-off-by: Mikołaj Błaż <[email protected]>
* Apply formatting Signed-off-by: Mikołaj Błaż <[email protected]>
* Handle multiple async saves Signed-off-by: Mikołaj Błaż <[email protected]>
* Apply formatting Signed-off-by: Mikołaj Błaż <[email protected]>
* Move finalization calls to a callback Signed-off-by: Mikołaj Błaż <[email protected]>
* Avoid deadlock in teardown Signed-off-by: Mikołaj Błaż <[email protected]>
* Adjust to MCore implementation Signed-off-by: Mikołaj Błaż <[email protected]>
* Add notes and copyrights Signed-off-by: Mikołaj Błaż <[email protected]>
* Apply formatting Signed-off-by: Mikołaj Błaż <[email protected]>
* Fix async_request attribute Signed-off-by: Mikołaj Błaż <[email protected]>
* Add MCore import guards Signed-off-by: Mikołaj Błaż <[email protected]>
* Add async test Signed-off-by: Mikołaj Błaż <[email protected]>
* Fix finalize_fn arg Signed-off-by: Mikołaj Błaż <[email protected]>
* Add docs Signed-off-by: Mikołaj Błaż <[email protected]>
* Remove checkpoints from accurate steps Signed-off-by: Mikołaj Błaż <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Fix MCore class usage Signed-off-by: Mikołaj Błaż <[email protected]>
* Update docs Signed-off-by: Mikołaj Błaż <[email protected]>
* Fix logger usage Signed-off-by: Mikołaj Błaż <[email protected]>
* Fix rebase Signed-off-by: Mikołaj Błaż <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Fix code scan issues Signed-off-by: Mikołaj Błaż <[email protected]>
* Remove unused import Signed-off-by: Mikołaj Błaż <[email protected]>
* Use dist-ckpt for Bert Signed-off-by: Mikołaj Błaż <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Fix load checkpoint return val Signed-off-by: Mikołaj Błaż <[email protected]>
* Use dist-ckpt based on sharded_state_dict Signed-off-by: Mikołaj Błaż <[email protected]>
* Add async logging Signed-off-by: Mikołaj Błaż <[email protected]>
* Remove deprecated argument Signed-off-by: Mikołaj Błaż <[email protected]>
* Use correct checkpoint_io Signed-off-by: Mikołaj Błaż <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* Fix bad merge Signed-off-by: Mikołaj Błaż <[email protected]>
* Improve debug msg Signed-off-by: Mikołaj Błaż <[email protected]>
* Run async test on GPU Signed-off-by: Mikołaj Błaż <[email protected]>
* Fix async ckpt unit test Signed-off-by: Mikołaj Błaż <[email protected]>
* Apply isort and black reformatting Signed-off-by: mikolajblaz <[email protected]>
* Clarify async logs Signed-off-by: Mikołaj Błaż <[email protected]>
* Add schema print Signed-off-by: Mikołaj Błaż <[email protected]>

---------

Signed-off-by: Mikołaj Błaż <[email protected]>
Signed-off-by: mikolajblaz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5df8e11 commit c2daa91

File tree

9 files changed: +806 -109 lines


examples/nlp/language_modeling/conf/megatron_gpt_config.yaml (+1)

@@ -52,6 +52,7 @@ exp_manager:
     save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits
     filename: 'megatron_gpt--{val_loss:.2f}-{step}-{consumed_samples}'
     model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}
+    async_save: False # Set to True to enable async checkpoint save. Currently works only with distributed checkpoints

 model:
   # use GPTModel from megatron.core
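
The new flag sits under exp_manager.checkpoint_callback_params in the GPT config. As a minimal illustration (not part of the commit), assuming OmegaConf is available, the same setting can be toggled on a config object in memory; the nested keys mirror megatron_gpt_config.yaml and the values are placeholders:

# Illustrative sketch only: toggling the new async_save flag on an in-memory config.
# The structure mirrors exp_manager.checkpoint_callback_params in megatron_gpt_config.yaml.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"exp_manager": {"checkpoint_callback_params": {"async_save": False}}}
)
# Enable async checkpoint save; per the config comment above, this currently works
# only together with distributed checkpoints.
cfg.exp_manager.checkpoint_callback_params.async_save = True
print(cfg.exp_manager.checkpoint_callback_params.get("async_save", False))  # True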

nemo/collections/nlp/parts/megatron_trainer_builder.py (+47 -15)

@@ -13,8 +13,9 @@
 # limitations under the License.

 import sys
-from typing import Union
+from typing import Optional, Union

+from lightning_fabric.utilities.exceptions import MisconfigurationException
 from omegaconf import DictConfig
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelSummary
@@ -31,7 +32,11 @@
     PipelineMixedPrecisionPlugin,
 )
 from nemo.utils import logging
-from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO
+from nemo.utils.callbacks.dist_ckpt_io import (
+    AsyncFinalizableCheckpointIO,
+    AsyncFinalizerCallback,
+    DistributedCheckpointIO,
+)


 class MegatronTrainerBuilder:
@@ -51,7 +56,10 @@ def _training_strategy(self) -> Union[NLPDDPStrategy, NLPFSDPStrategy]:
         _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive)
         if _IS_INTERACTIVE and self.cfg.trainer.devices == 1:
             logging.info("Detected interactive environment, using NLPDDPStrategyNotebook")
-            return NLPDDPStrategyNotebook(no_ddp_communication_hook=True, find_unused_parameters=False,)
+            return NLPDDPStrategyNotebook(
+                no_ddp_communication_hook=True,
+                find_unused_parameters=False,
+            )

         if self.cfg.model.get('fsdp', False):
             assert (
@@ -89,7 +97,7 @@ def _grad_scaler(self) -> GradScaler:
         Returns a scaler for precision plugins.
         """
         return GradScaler(
-            init_scale=self.cfg.model.get('native_amp_init_scale', 2 ** 32),
+            init_scale=self.cfg.model.get('native_amp_init_scale', 2**32),
             growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000),
             hysteresis=self.cfg.model.get('hysteresis', 2),
         )
@@ -137,19 +145,41 @@ def _plugins(self) -> list:
         use_dist_ckpt = not self.cfg.model.get('fsdp', False) and (
             self.cfg.model.get('mcore_gpt', False) or self.cfg.model.get('mcore_bert', False)
         )
+        async_save = self.cfg.exp_manager.checkpoint_callback_params.get('async_save', False)
         if use_dist_ckpt:
-            plugins.append(DistributedCheckpointIO.from_config(self.cfg.model))
+            checkpoint_io = DistributedCheckpointIO.from_config(self.cfg.model, async_save)
+            if async_save:
+                checkpoint_io = AsyncFinalizableCheckpointIO(checkpoint_io)
+            plugins.append(checkpoint_io)
+        elif async_save:
+            raise MisconfigurationException(
+                'exp_manager.checkpoint_callback_params.async_save=True without'
+                'distributed checkpoints is currently not supported'
+            )

         return plugins

+    def _callbacks(self, callbacks: Optional[list]) -> list:
+        """
+        Returns:
+            callbacks: list of callbacks passed to Trainer.callbacks.
+        """
+        if callbacks is None:
+            callbacks = []
+        # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
+        if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar:
+            callbacks.append(CustomProgressBar())
+
+        if self.cfg.exp_manager.checkpoint_callback_params.get('async_save', False):
+            callbacks.append(AsyncFinalizerCallback())
+        return callbacks
+
     def create_trainer(self, callbacks=None) -> Trainer:
         # cfg.trainer.precision becomes None in Trainer if precision_plugins exist since both precision plugins and precision
         precision = self.cfg.trainer.precision
         strategy = self._training_strategy()
         plugins = self._plugins()
-        # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
-        if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar:
-            callbacks = [CustomProgressBar()]
+        callbacks = self._callbacks(callbacks)
         trainer = Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks)
         # Restore the precision value after Trainer is built.
         self.cfg.trainer.precision = precision
@@ -161,21 +191,23 @@ class MegatronBertTrainerBuilder(MegatronTrainerBuilder):

     def _grad_scaler(self) -> GradScaler:
         return GradScaler(
-            init_scale=self.cfg.model.get('native_amp_init_scale', 2 ** 32),
+            init_scale=self.cfg.model.get('native_amp_init_scale', 2**32),
             growth_interval=self.cfg.model.get('native_amp_growth_interval', 1000),
         )


 class MegatronT5TrainerBuilder(MegatronTrainerBuilder):
     """Builder for T5 model Trainer with overrides."""

-    def create_trainer(self) -> Trainer:
+    def _callbacks(self, callbacks: Optional[list]) -> list:
+        callbacks = super()._callbacks(callbacks)
+        callbacks.append(ModelSummary(max_depth=3))
+        return callbacks
+
+    def create_trainer(self, callbacks=None) -> Trainer:
         strategy = self._training_strategy()
         plugins = self._plugins()
-        callbacks = [ModelSummary(max_depth=3)]
-        # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks
-        if 'enable_progress_bar' not in self.cfg.trainer or self.cfg.trainer.enable_progress_bar:
-            callbacks.append(CustomProgressBar())
+        callbacks = self._callbacks(callbacks)
         return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks)


@@ -207,7 +239,7 @@ class MegatronLMPPTrainerBuilder(MegatronTrainerBuilder):

     def _grad_scaler(self) -> GradScaler:
         return GradScaler(
-            init_scale=self.cfg.model.get("native_amp_init_scale", 2 ** 32),
+            init_scale=self.cfg.model.get("native_amp_init_scale", 2**32),
             growth_interval=self.cfg.model.get("native_amp_growth_interval", 1000),
             hysteresis=self.cfg.model.get("hysteresis", 2),
             enabled=False if self.cfg.model.pipeline_model_parallel_size > 1 else True,
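
Taken together, the builder changes wire async checkpointing in two places: _plugins() wraps DistributedCheckpointIO in AsyncFinalizableCheckpointIO when async_save is set, and _callbacks() registers an AsyncFinalizerCallback to finalize pending saves during training. The condensed sketch below restates that wiring outside the builder class; it assumes a NeMo installation where the imports from this diff resolve, and build_ckpt_plugins_and_callbacks is a hypothetical helper, not part of the commit.

# Condensed sketch of the wiring introduced above (hypothetical helper, not NeMo API).
from nemo.utils.callbacks.dist_ckpt_io import (
    AsyncFinalizableCheckpointIO,
    AsyncFinalizerCallback,
    DistributedCheckpointIO,
)


def build_ckpt_plugins_and_callbacks(model_cfg, async_save: bool):
    # DistributedCheckpointIO handles Megatron-Core distributed (sharded) checkpoints.
    checkpoint_io = DistributedCheckpointIO.from_config(model_cfg, async_save)
    callbacks = []
    if async_save:
        # Wrapping lets save calls return without blocking on finalization;
        # the callback later finalizes any in-flight async saves as training proceeds.
        checkpoint_io = AsyncFinalizableCheckpointIO(checkpoint_io)
        callbacks.append(AsyncFinalizerCallback())
    return [checkpoint_io], callbacks

As in the real _plugins(), requesting async_save without distributed checkpoints is rejected in the builder with a MisconfigurationException.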
