diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py index 7805d0fea9..9ff78243ef 100644 --- a/torchtnt/framework/callbacks/torchsnapshot_saver.py +++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py @@ -25,8 +25,6 @@ import torch.distributed as dist from pyre_extensions import none_throws -from torchsnapshot.knobs import override_max_per_rank_io_concurrency -from torchsnapshot.snapshot import PendingSnapshot, Snapshot, SNAPSHOT_METADATA_FNAME from torchtnt.framework.callback import Callback from torchtnt.framework.state import EntryPoint, State @@ -46,6 +44,12 @@ try: import torchsnapshot + from torchsnapshot.knobs import override_max_per_rank_io_concurrency + from torchsnapshot.snapshot import ( + PendingSnapshot, + Snapshot, + SNAPSHOT_METADATA_FNAME, + ) _TStateful = torchsnapshot.Stateful _TORCHSNAPSHOT_AVAILABLE = True