From fd79aa7f4ac967a38b2577244ae9176eaca9b3c7 Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Fri, 11 Aug 2023 10:50:28 -0700
Subject: [PATCH] update checkpoint docs to use unit instead of auto unit

Reviewed By: ananthsub

Differential Revision: D48245705

fbshipit-source-id: 891b65beb19d7e45d32dbd2933f06254d466ab6a
---
 docs/source/checkpointing.rst | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/docs/source/checkpointing.rst b/docs/source/checkpointing.rst
index 6a7dc746d0..9d22f2d044 100644
--- a/docs/source/checkpointing.rst
+++ b/docs/source/checkpointing.rst
@@ -1,12 +1,12 @@
 Checkpointing
 ================================
 
-TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver` which uses `TorchSnapshot `_ under the hood.
+TorchTNT offers checkpointing via the :class:`~torchtnt.framework.callbacks.TorchSnapshotSaver` which uses `TorchSnapshot `_ under the hood.
 
 .. code-block:: python
 
     module = nn.Linear(input_dim, 1)
-    unit = MyAutoUnit(module=module)
+    unit = MyUnit(module=module)
     tss = TorchSnapshotSaver(
         dirpath=your_dirpath_here,
         save_every_n_train_steps=100,
@@ -32,7 +32,8 @@ The state dict type to be used for checkpointing FSDP modules can be specified i
         # sets state dict type of FSDP module
         state_dict_type=STATE_DICT_TYPE.SHARDED_STATE_DICT
     )
-    unit = MyAutoUnit(module=module, strategy=fsdp_strategy)
+    module = prepare_fsdp(module, strategy=fsdp_strategy)
+    unit = MyUnit(module=module)
     tss = TorchSnapshotSaver(
         dirpath=your_dirpath_here,
         save_every_n_epochs=2,
@@ -49,9 +50,9 @@ Or you can manually set this using `FSDP.set_state_dict_type
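
For reference, the end-to-end pattern the updated docs point at might look like the following sketch: a hand-written ``TrainUnit`` subclass checkpointed by attaching ``TorchSnapshotSaver`` as a callback to ``train``. This is an illustrative sketch, not part of the patch: ``MyUnit``, ``Batch``, ``input_dim``, the dataloader, and the checkpoint directory are placeholders, and the ``TrainUnit``/``train`` import paths are assumed from ``torchtnt.framework``.

.. code-block:: python

    from typing import Tuple

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from torchtnt.framework.callbacks import TorchSnapshotSaver
    from torchtnt.framework.state import State
    from torchtnt.framework.train import train
    from torchtnt.framework.unit import TrainUnit

    Batch = Tuple[torch.Tensor, torch.Tensor]  # illustrative batch type


    class MyUnit(TrainUnit[Batch]):
        def __init__(self, module: nn.Module) -> None:
            super().__init__()
            # attributes assigned here become part of the unit's tracked
            # application state, which is what TorchSnapshotSaver snapshots
            self.module = module
            self.optimizer = torch.optim.SGD(module.parameters(), lr=0.01)

        def train_step(self, state: State, data: Batch) -> None:
            inputs, targets = data
            self.optimizer.zero_grad()
            loss = nn.functional.mse_loss(self.module(inputs), targets)
            loss.backward()
            self.optimizer.step()


    input_dim = 32  # placeholder, as in the docs
    unit = MyUnit(module=nn.Linear(input_dim, 1))
    tss = TorchSnapshotSaver(
        dirpath="/tmp/checkpoints",  # stand-in for your_dirpath_here
        save_every_n_train_steps=100,
    )
    # synthetic data so the sketch is self-contained
    dataloader = DataLoader(
        TensorDataset(torch.randn(512, input_dim), torch.randn(512, 1)),
        batch_size=16,
    )
    train(unit, dataloader, max_epochs=2, callbacks=[tss])

Assigning the module and optimizer as attributes on the unit is what registers them in its application state, so the saver can snapshot and restore them without the AutoUnit wiring the patch removes from the examples.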