Summary:
Pull Request resolved: #495

Added OSS checkpointing docs

Reviewed By: daniellepintz

Differential Revision: D46036738

fbshipit-source-id: 6d40a854f09c597c6a7503229f48423df532f060
1 parent 5150591 · commit a9ad674

Showing 2 changed files with 65 additions and 1 deletion.
@@ -0,0 +1,63 @@
Checkpointing
================================

TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver`, which uses `TorchSnapshot <https://github.com/pytorch/torchsnapshot>`_ under the hood.

.. code-block:: python

    import torch.nn as nn

    from torchtnt.framework import train
    from torchtnt.framework.callbacks import TorchSnapshotSaver

    module = nn.Linear(input_dim, 1)
    # MyAutoUnit is your AutoUnit subclass
    unit = MyAutoUnit(module=module)
    tss = TorchSnapshotSaver(
        dirpath=your_dirpath_here,
        save_every_n_train_steps=100,
        save_every_n_epochs=2,
    )
    # restores from the latest checkpoint under dirpath, if one exists
    tss.restore_from_latest(your_dirpath_here, unit, train_dataloader=dataloader)
    train(
        unit,
        dataloader,
        callbacks=[tss],
    )

There is built-in support for saving and loading distributed models (DDP, FSDP).
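
For example, a module wrapped in DistributedDataParallel is checkpointed the same way. This is a minimal sketch, assuming a ``DDPStrategy`` is available in ``torchtnt.utils.prepare_module`` analogous to the ``FSDPStrategy`` shown below; the callback itself is used unchanged.

.. code-block:: python

    # Sketch only: DDPStrategy is assumed to be importable from
    # torchtnt.utils.prepare_module; the checkpointing callback does not change.
    from torchtnt.utils.prepare_module import DDPStrategy

    module = nn.Linear(input_dim, 1)
    ddp_strategy = DDPStrategy()  # wraps the module in DistributedDataParallel
    unit = MyAutoUnit(module=module, strategy=ddp_strategy)
    tss = TorchSnapshotSaver(
        dirpath=your_dirpath_here,
        save_every_n_epochs=2,
    )
    train(
        unit,
        dataloader,
        callbacks=[tss],
    )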

The state dict type to be used for checkpointing FSDP modules can be specified in the :class:`~torchtnt.utils.prepare_module.FSDPStrategy`'s ``state_dict_type`` argument like so:

.. code-block:: python

    from torch.distributed.fsdp import StateDictType

    from torchtnt.utils.prepare_module import FSDPStrategy

    module = nn.Linear(input_dim, 1)
    fsdp_strategy = FSDPStrategy(
        # sets the state dict type of the FSDP module
        state_dict_type=StateDictType.SHARDED_STATE_DICT
    )
    unit = MyAutoUnit(module=module, strategy=fsdp_strategy)
    tss = TorchSnapshotSaver(
        dirpath=your_dirpath_here,
        save_every_n_epochs=2,
    )
    train(
        unit,
        dataloader,
        # the checkpointer callback will use the state dict type specified in the FSDPStrategy
        callbacks=[tss],
    )

Or you can manually set this using `FSDP.set_state_dict_type <https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.set_state_dict_type>`_.

.. code-block:: python

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType

    module = nn.Linear(input_dim, 1)
    fsdp_strategy = FSDPStrategy()
    unit = MyAutoUnit(module=module, strategy=fsdp_strategy)
    # unit.module is the FSDP-wrapped module created by the AutoUnit
    FSDP.set_state_dict_type(unit.module, StateDictType.SHARDED_STATE_DICT)
    tss = TorchSnapshotSaver(
        dirpath=your_dirpath_here,
        save_every_n_epochs=2,
    )
    train(
        unit,
        dataloader,
        callbacks=[tss],
    )
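
When restoring a sharded checkpoint, the module should be wrapped the same way it was at save time before loading. Below is a minimal sketch, assuming ``restore_from_latest`` can also be called on the class (it is called on an instance in the first example); paths and the unit class are placeholders.

.. code-block:: python

    # Sketch only: re-create the FSDP-wrapped unit first so the sharded
    # state dict can be loaded back into matching shards.
    module = nn.Linear(input_dim, 1)
    unit = MyAutoUnit(module=module, strategy=FSDPStrategy())
    TorchSnapshotSaver.restore_from_latest(
        your_dirpath_here, unit, train_dataloader=dataloader
    )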