diff --git a/tests/framework/callbacks/test_base_checkpointer.py b/tests/framework/callbacks/test_base_checkpointer.py index fe97de1d99..105d6052ea 100644 --- a/tests/framework/callbacks/test_base_checkpointer.py +++ b/tests/framework/callbacks/test_base_checkpointer.py @@ -31,14 +31,16 @@ from torchtnt.framework.callbacks.base_checkpointer import ( BaseCheckpointer as BaseCheckpointer, ) -from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions +from torchtnt.framework.callbacks.checkpointer_types import ( + BestCheckpointConfig, + RestoreOptions, +) from torchtnt.framework.callbacks.lambda_callback import Lambda from torchtnt.framework.fit import fit from torchtnt.framework.state import State from torchtnt.framework.train import train from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData -from torchtnt.utils.checkpoint import BestCheckpointConfig from torchtnt.utils.distributed import get_global_rank, spawn_multi_process from torchtnt.utils.env import init_from_env from torchtnt.utils.test_utils import skip_if_not_distributed diff --git a/torchtnt/framework/callbacks/base_checkpointer.py b/torchtnt/framework/callbacks/base_checkpointer.py index 3cf4cd1746..ded257c412 100644 --- a/torchtnt/framework/callbacks/base_checkpointer.py +++ b/torchtnt/framework/callbacks/base_checkpointer.py @@ -16,7 +16,10 @@ import torch.distributed as dist from pyre_extensions import none_throws from torchtnt.framework.callback import Callback -from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions +from torchtnt.framework.callbacks.checkpointer_types import ( + BestCheckpointConfig, + RestoreOptions, +) from torchtnt.framework.state import EntryPoint, State from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit from torchtnt.framework.utils import get_timing_context @@ -25,7 +28,6 @@ _metadata_exists, _sort_by_metric_value, _sort_by_recency, - BestCheckpointConfig, get_best_checkpoint_path, get_checkpoint_dirpaths, get_latest_checkpoint_path, diff --git a/torchtnt/framework/callbacks/checkpointer_types.py b/torchtnt/framework/callbacks/checkpointer_types.py index d7ab2693cf..5ccdf7862d 100644 --- a/torchtnt/framework/callbacks/checkpointer_types.py +++ b/torchtnt/framework/callbacks/checkpointer_types.py @@ -7,7 +7,7 @@ # pyre-strict from dataclasses import dataclass -from typing import Optional +from typing import Literal, Optional # TODO: eventually support overriding all knobs @@ -39,3 +39,17 @@ class RestoreOptions: restore_eval_progress: bool = True restore_optimizers: bool = True restore_lr_schedulers: bool = True + + +@dataclass +class BestCheckpointConfig: + """ + Config for saving the best checkpoints. + + Args: + monitored_metric: Metric to monitor for saving best checkpoints. Must be an numerical or tensor attribute on the unit. + mode: One of `min` or `max`. The save file is overwritten based the max or min of the monitored metric. + """ + + monitored_metric: str + mode: Literal["min", "max"] = "min" diff --git a/torchtnt/framework/callbacks/dcp_saver.py b/torchtnt/framework/callbacks/dcp_saver.py index 3b9f274d92..66cd699f50 100644 --- a/torchtnt/framework/callbacks/dcp_saver.py +++ b/torchtnt/framework/callbacks/dcp_saver.py @@ -22,7 +22,11 @@ ) from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer -from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions +from torchtnt.framework.callbacks.checkpointer_types import ( + BestCheckpointConfig, + KnobOptions, + RestoreOptions, +) from torchtnt.framework.state import State from torchtnt.framework.unit import ( AppStateMixin, @@ -32,7 +36,6 @@ TTrainUnit, ) from torchtnt.framework.utils import get_timing_context -from torchtnt.utils.checkpoint import BestCheckpointConfig from torchtnt.utils.optimizer import init_optim_state from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn from torchtnt.utils.stateful import MultiStateful, Stateful diff --git a/torchtnt/framework/callbacks/torchsnapshot_saver.py b/torchtnt/framework/callbacks/torchsnapshot_saver.py index ab2a0515d6..b5138cd681 100644 --- a/torchtnt/framework/callbacks/torchsnapshot_saver.py +++ b/torchtnt/framework/callbacks/torchsnapshot_saver.py @@ -22,7 +22,11 @@ ) from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer -from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions +from torchtnt.framework.callbacks.checkpointer_types import ( + BestCheckpointConfig, + KnobOptions, + RestoreOptions, +) from torchtnt.framework.state import State from torchtnt.framework.unit import ( AppStateMixin, @@ -32,7 +36,6 @@ TTrainUnit, ) from torchtnt.framework.utils import get_timing_context -from torchtnt.utils.checkpoint import BestCheckpointConfig from torchtnt.utils.optimizer import init_optim_state from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn from torchtnt.utils.stateful import Stateful diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py index eac6b113e4..07ed6cfc52 100644 --- a/torchtnt/utils/__init__.py +++ b/torchtnt/utils/__init__.py @@ -7,7 +7,6 @@ # pyre-strict from .checkpoint import ( - BestCheckpointConfig, CheckpointPath, get_best_checkpoint_path, get_checkpoint_dirpaths, @@ -92,7 +91,6 @@ "get_best_checkpoint_path", "get_checkpoint_dirpaths", "get_latest_checkpoint_path", - "BestCheckpointConfig", "copy_data_to_device", "CPUStats", "get_device_from_env", diff --git a/torchtnt/utils/checkpoint.py b/torchtnt/utils/checkpoint.py index 5438afc830..f8e1817ef5 100644 --- a/torchtnt/utils/checkpoint.py +++ b/torchtnt/utils/checkpoint.py @@ -31,20 +31,6 @@ class MetricData: value: float -@dataclass -class BestCheckpointConfig: - """ - Config for saving the best checkpoints. - - Args: - monitored_metric: Metric to monitor for saving best checkpoints. Must be an numerical or tensor attribute on the unit. - mode: One of `min` or `max`. The save file is overwritten based the max or min of the monitored metric. - """ - - monitored_metric: str - mode: Literal["min", "max"] = "min" - - @total_ordering class CheckpointPath: """