diff --git a/tests/framework/callbacks/test_torchsnapshot_saver.py b/tests/framework/callbacks/test_torchsnapshot_saver.py
index 5dfa14b93b..96354e253d 100644
--- a/tests/framework/callbacks/test_torchsnapshot_saver.py
+++ b/tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -91,6 +91,32 @@ def test_save_every_n_train_epochs(self) -> None:
                 os.path.exists(expected_path) and os.path.isdir(expected_path)
             )
 
+    def test_save_on_train_end(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+        max_epochs = 3
+        expected_steps_per_epoch = math.ceil(dataset_len / batch_size)
+        save_every_n_train_epochs = 2
+        # epoch 3 is not a multiple of 2, so presumably only on_train_end saves it
+        my_unit = DummyTrainUnit(input_dim=input_dim)
+        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        state = init_train_state(dataloader=dataloader, max_epochs=max_epochs)
+        with tempfile.TemporaryDirectory() as temp_dir:
+            expected_path = os.path.join(
+                temp_dir,
+                f"epoch_{max_epochs}_step_{expected_steps_per_epoch * (max_epochs)}",
+            )
+            snapshot = TorchSnapshotSaver(
+                temp_dir,
+                save_every_n_epochs=save_every_n_train_epochs,
+                replicated=["**"],
+            )
+            train(state, my_unit, callbacks=[snapshot])
+            self.assertTrue(
+                os.path.exists(expected_path) and os.path.isdir(expected_path)
+            )
+
     def test_save_restore(self) -> None:
         input_dim = 2
         dataset_len = 10
diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py
index 2c12e6ebf0..e7c7fe9e58 100644
--- a/torchtnt/framework/callbacks/torchsnapshot_saver.py
+++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -148,6 +148,16 @@ def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
             self._async_snapshot(snapshot_path, app_state, wait=True)
 
     def on_train_end(self, state: State, unit: TTrainUnit) -> None:
+        app_state = _get_app_state(state, unit, self._replicated, intra_epoch=False)
+
+        train_state = none_throws(state.train_state)
+        epoch = train_state.progress.num_epochs_completed
+        global_step = train_state.progress.num_steps_completed
+        # save snapshot to the path derived from dirpath, epoch and global step
+        # TODO: discuss whether this path should be customized
+        # NOTE(review): may repeat the epoch-end save for this epoch/step - confirm
+        snapshot_path = _get_snapshot_save_path(self._dirpath, epoch, global_step)
+        self._async_snapshot(snapshot_path, app_state, wait=False)
         self._wait()
 
     def on_exception(