Add knob for load_directly_on_device (NVIDIA#9125)
* Add knob for load_directly_on_device

Signed-off-by: Mikołaj Błaż <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Mikołaj Błaż <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Pablo Garay <[email protected]>
4 people authored May 10, 2024
1 parent 8d5648e commit 22c5851
Showing 4 changed files with 18 additions and 6 deletions.
3 changes: 2 additions & 1 deletion examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -150,8 +150,9 @@ model:
fsdp_grad_reduce_dtype: 32 # Gradient reduction data type.
fsdp_sharded_checkpoint: False # Store and load FSDP shared checkpoint.

- # Distributed checkpoint format
+ # Distributed checkpoint setup
dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
+ dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
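For users, these two keys are the whole interface: dist_ckpt_format picks the checkpoint backend and dist_ckpt_load_on_device decides whether zarr weights are materialised straight on the GPU or staged on the CPU first. Below is a minimal sketch (not part of the commit) of how the knob is consumed by the DistributedCheckpointIO.from_config helper added further down in this diff; the OmegaConf usage and the standalone snippet are assumptions, only the key names and their defaults come from the changes themselves:

from omegaconf import OmegaConf

from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO

# Hypothetical model-config fragment mirroring the YAML keys above.
model_cfg = OmegaConf.create({
    'dist_ckpt_format': 'zarr',        # or 'torch_dist' for the PyTorch distributed format
    'dist_ckpt_load_on_device': True,  # False: stage zarr weights on CPU instead of the GPU
})

# Missing keys fall back to 'zarr' / True, so configs written before this change
# keep their previous behaviour.
checkpoint_io = DistributedCheckpointIO.from_config(model_cfg)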
2 changes: 1 addition & 1 deletion nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -138,7 +138,7 @@ def _plugins(self) -> list:
self.cfg.model.get('mcore_gpt', False) or self.cfg.model.get('mcore_bert', False)
)
if use_dist_ckpt:
- plugins.append(DistributedCheckpointIO(self.cfg.model.get('dist_ckpt_format', 'zarr')))
+ plugins.append(DistributedCheckpointIO.from_config(self.cfg.model))

return plugins

4 changes: 2 additions & 2 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -475,7 +475,7 @@ def use_distributed_checkpointing(self):
logging.warning(
'Distributed checkpoints requires DistributedCheckpointIO plugin to be used. Setting up a default now.'
)
- self.checkpoint_io = DistributedCheckpointIO(self.lightning_module.cfg.get('dist_ckpt_format', 'zarr'))
+ self.checkpoint_io = DistributedCheckpointIO.from_config(self.lightning_module.cfg)
if not has_sharded_state_dict and has_dist_ckpt_io:
logging.warning(
'DistributedCheckpointIO configured but should not be used. Reverting back to TorchCheckpointIO'
@@ -1151,7 +1151,7 @@ def dummy():
tmp_model_weights_ckpt = os.path.join(tmpdir, self.model_weights_ckpt)
tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0]
assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.'
- checkpoint_io = DistributedCheckpointIO(conf.get('dist_ckpt_format', 'zarr'))
+ checkpoint_io = DistributedCheckpointIO.from_config(conf)
checkpoint = checkpoint_io.load_checkpoint(tmp_model_weights_dir, sharded_state_dict=checkpoint)
instance.on_load_checkpoint(checkpoint)
if hasattr(instance, 'setup_transformer_engine_tp_groups'):
15 changes: 13 additions & 2 deletions nemo/utils/callbacks/dist_ckpt_io.py
@@ -15,14 +15,25 @@ class DistributedCheckpointIO(CheckpointIO):
Args:
save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving.
+ load_directly_on_device (bool, optional): if True, loads the weights directly
+     on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed
+     always loads on device). Defaults to True.
"""

- def __init__(self, save_ckpt_format: str):
+ def __init__(self, save_ckpt_format: str, load_directly_on_device: bool = True):
super().__init__()
self.save_ckpt_format = save_ckpt_format
+ self.load_directly_on_device = load_directly_on_device

self.save_sharded_strategy = self.determine_dist_ckpt_save_strategy()

+ @classmethod
+ def from_config(cls, model_cfg):
+     return cls(
+         save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'),
+         load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
+     )

def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
""" Saves a distributed checkpoint. Creates the checkpoint root directory if doesn't exist.
@@ -59,7 +70,7 @@ def load_checkpoint(
if map_location is not None:
raise ValueError('DistributedCheckpointIO doesnt handle map_location argument')

- if self.save_ckpt_format == 'zarr':
+ if self.save_ckpt_format == 'zarr' and self.load_directly_on_device:
sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
else:
sharded_strategy = None
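The knob only affects how zarr checkpoints are materialised at load time; torch_dist checkpoints always load on device, as the docstring above notes. A minimal sketch of the selection that load_checkpoint now performs, paraphrased from the hunk above; the helper function and its name are illustrative, not part of NeMo:

from typing import Optional

from megatron.core.dist_checkpointing.strategies import tensorstore


def pick_load_strategy(save_ckpt_format: str, load_directly_on_device: bool) -> Optional[object]:
    # Hypothetical helper mirroring the branch in DistributedCheckpointIO.load_checkpoint.
    if save_ckpt_format == 'zarr' and load_directly_on_device:
        # zarr checkpoints can stream tensor shards straight into GPU memory.
        return tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
    # None lets megatron.core.dist_checkpointing fall back to its default strategy,
    # which loads the weights to CPU first.
    return None

With dist_ckpt_load_on_device set to False, zarr weights therefore take the default CPU-side path, which can be preferable when GPU memory is tight while restoring a large model.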
