Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving. Should be one of
'torch_dist' or 'zarr'. Defaults to 'torch_dist'.
ckpt_async_save (bool): Whether to save checkpoints asynchronously to reduce checkpointing overhead.
Defaults to True.
Defaults to False.
ckpt_torch_dist_multiproc (int): Number of extra processes per rank used during ckpt save
with PyTorch distributed format. Defaults to None.
ckpt_assume_constant_structure (bool): Allows caching some computation across checkpoint saves.
Expand Down Expand Up @@ -292,7 +292,7 @@ def __init__(
use_te_rng_tracker: bool = False,
use_sharp: bool = False,
save_ckpt_format: str = "torch_dist",
ckpt_async_save: bool = True,
ckpt_async_save: bool = False,
ckpt_torch_dist_multiproc: int = None, ## TODO(ashors): put elsewhere?
ckpt_assume_constant_structure: bool = False,
ckpt_parallel_save: bool = True,
Expand Down
4 changes: 3 additions & 1 deletion nemo/lightning/pytorch/strategies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def create_checkpoint_io(wrapping_ckpt_io=None, **kwargs):

if wrapping_ckpt_io:
checkpoint_io = wrapping_ckpt_io(checkpoint_io)
if kwargs.get("async_save", False):

async_save = kwargs.get("async_save", False)
if async_save:
checkpoint_io = AsyncFinalizableCheckpointIO(checkpoint_io)

return checkpoint_io
Expand Down
6 changes: 5 additions & 1 deletion nemo/utils/callbacks/dist_ckpt_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,16 @@ class AsyncFinalizableCheckpointIO(_WrappingCheckpointIO):
AsyncCheckpointIO does). Allows performing a (synchronous) finalization
function after all ranks finish checkpoint saving.

This wrapper always creates the AsyncCallsQueue with persistent workers. This is known
to increase memory usage, and sometimes leads to out of memory errors.

NOTE: for correctness, this plugin must be used together with the
AsyncFinalizerCallback callback which performs the finalization checks.

Args:
checkpoint_io (CheckpointIO): wrapped checkpoint_io object. Must be
of type AsyncCompatibleCheckpointIO.
Persistent workers for checkpoint writing are always enabled by this wrapper; the constructor takes no `persistent_workers` argument.
Requires the underlying checkpoint_io.save_checkpoint to return save_fn, save_args, finalize_fn.
"""

Expand All @@ -108,7 +112,7 @@ def __init__(self, checkpoint_io: AsyncCompatibleCheckpointIO) -> None:
raise ValueError(f'Incompatible wrapped checkpoint_io type: {type(checkpoint_io)}')

super().__init__(checkpoint_io)
self.async_calls_queue = AsyncCallsQueue()
self.async_calls_queue = AsyncCallsQueue(persistent=True)

def save_checkpoint(
self,
Expand Down
Loading