Skip to content

Commit

Permalink
update checkpoint docs to use unit instead of auto unit
Browse files Browse the repository at this point in the history
Reviewed By: ananthsub

Differential Revision: D48245705

fbshipit-source-id: e498780d3bdc49da5be97a764197f87741999047
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Aug 11, 2023
1 parent ca175c5 commit 24ca3b3
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions docs/source/checkpointing.rst
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
Checkpointing
================================

TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver` which uses `TorchSnapshot <https://github.com/pytorch/torchsnapshot>`_ under the hood.
TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver` which uses `TorchSnapshot <https://pytorch.org/torchsnapshot/main/>`_ under the hood.

.. code-block:: python
module = nn.Linear(input_dim, 1)
unit = MyAutoUnit(module=module)
unit = MyUnit(module=module)
tss = TorchSnapshotSaver(
dirpath=your_dirpath_here,
save_every_n_train_steps=100,
Expand All @@ -32,7 +32,8 @@ The state dict type to be used for checkpointing FSDP modules can be specified i
# sets state dict type of FSDP module
state_dict_type=STATE_DICT_TYPE.SHARDED_STATE_DICT
)
unit = MyAutoUnit(module=module, strategy=fsdp_strategy)
module = prepare_fsdp(module, strategy=fsdp_strategy)
unit = MyUnit(module=module)
tss = TorchSnapshotSaver(
dirpath=your_dirpath_here,
save_every_n_epochs=2,
Expand All @@ -49,9 +50,9 @@ Or you can manually set this using `FSDP.set_state_dict_type <https://pytorch.or
.. code-block:: python
module = nn.Linear(input_dim, 1)
fsdp_strategy = FSDPStrategy()
unit = MyAutoUnit(module=module, strategy=fsdp_strategy)
FSDP.set_state_dict_type(unit.module, StateDictType.SHARDED_STATE_DICT)
module = FSDP(module, ....)
FSDP.set_state_dict_type(module, StateDictType.SHARDED_STATE_DICT)
unit = MyUnit(module=module, ...)
tss = TorchSnapshotSaver(
dirpath=your_dirpath_here,
save_every_n_epochs=2,
Expand Down

0 comments on commit 24ca3b3

Please sign in to comment.