Skip to content

Commit

Permalink
Move BestCheckpointConfig to utils/checkpoint.py
Browse files Browse the repository at this point in the history
Differential Revision: D56455020
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Apr 26, 2024
1 parent d0a6b88 commit 2e96a32
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 33 deletions.
6 changes: 2 additions & 4 deletions tests/framework/callbacks/test_base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,14 @@
from torchtnt.framework.callbacks.base_checkpointer import (
BaseCheckpointer as BaseCheckpointer,
)
from torchtnt.framework.callbacks.checkpointer_types import (
BestCheckpointConfig,
RestoreOptions,
)
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
from torchtnt.framework.callbacks.lambda_callback import Lambda
from torchtnt.framework.fit import fit
from torchtnt.framework.state import State

from torchtnt.framework.train import train
from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData
from torchtnt.utils.checkpoint import BestCheckpointConfig
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import skip_if_not_distributed
Expand Down
6 changes: 2 additions & 4 deletions torchtnt/framework/callbacks/base_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
import torch.distributed as dist
from pyre_extensions import none_throws
from torchtnt.framework.callback import Callback
from torchtnt.framework.callbacks.checkpointer_types import (
BestCheckpointConfig,
RestoreOptions,
)
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
from torchtnt.framework.state import EntryPoint, State
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
from torchtnt.framework.utils import get_timing_context
Expand All @@ -28,6 +25,7 @@
_metadata_exists,
_sort_by_metric_value,
_sort_by_recency,
BestCheckpointConfig,
get_best_checkpoint_path,
get_checkpoint_dirpaths,
get_latest_checkpoint_path,
Expand Down
16 changes: 1 addition & 15 deletions torchtnt/framework/callbacks/checkpointer_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-strict

from dataclasses import dataclass
from typing import Literal, Optional
from typing import Optional


# TODO: eventually support overriding all knobs
Expand Down Expand Up @@ -39,17 +39,3 @@ class RestoreOptions:
restore_eval_progress: bool = True
restore_optimizers: bool = True
restore_lr_schedulers: bool = True


@dataclass
class BestCheckpointConfig:
    """
    Config for saving the best checkpoints.

    Args:
        monitored_metric: Metric to monitor for saving best checkpoints. Must be a numerical or tensor attribute on the unit.
        mode: One of `min` or `max`. The save file is overwritten based on the max or min of the monitored metric.
    """

    # Name of the attribute on the unit whose value is tracked between checkpoints.
    monitored_metric: str
    # Direction of "best": "min" keeps the lowest observed value, "max" the highest.
    mode: Literal["min", "max"] = "min"
7 changes: 2 additions & 5 deletions torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@
)

from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
from torchtnt.framework.callbacks.checkpointer_types import (
BestCheckpointConfig,
KnobOptions,
RestoreOptions,
)
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
from torchtnt.framework.state import State
from torchtnt.framework.unit import (
AppStateMixin,
Expand All @@ -37,6 +33,7 @@
TTrainUnit,
)
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.checkpoint import BestCheckpointConfig
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
from torchtnt.utils.stateful import MultiStateful, Stateful
Expand Down
7 changes: 2 additions & 5 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@
)

from torchtnt.framework.callbacks.base_checkpointer import BaseCheckpointer
from torchtnt.framework.callbacks.checkpointer_types import (
BestCheckpointConfig,
KnobOptions,
RestoreOptions,
)
from torchtnt.framework.callbacks.checkpointer_types import KnobOptions, RestoreOptions
from torchtnt.framework.state import State
from torchtnt.framework.unit import (
AppStateMixin,
Expand All @@ -36,6 +32,7 @@
TTrainUnit,
)
from torchtnt.framework.utils import get_timing_context
from torchtnt.utils.checkpoint import BestCheckpointConfig
from torchtnt.utils.optimizer import init_optim_state
from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
from torchtnt.utils.stateful import Stateful
Expand Down
2 changes: 2 additions & 0 deletions torchtnt/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict

from .checkpoint import (
BestCheckpointConfig,
CheckpointPath,
get_best_checkpoint_path,
get_checkpoint_dirpaths,
Expand Down Expand Up @@ -160,4 +161,5 @@
"get_best_checkpoint_path",
"get_checkpoint_dirpaths",
"get_latest_checkpoint_path",
"BestCheckpointConfig",
]
14 changes: 14 additions & 0 deletions torchtnt/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ class MetricData:
value: float


@dataclass
class BestCheckpointConfig:
    """
    Config for saving the best checkpoints.

    Args:
        monitored_metric: Metric to monitor for saving best checkpoints. Must be a numerical or tensor attribute on the unit.
        mode: One of `min` or `max`. The save file is overwritten based on the max or min of the monitored metric.
    """

    # Name of the attribute on the unit whose value is tracked between checkpoints.
    monitored_metric: str
    # Direction of "best": "min" keeps the lowest observed value, "max" the highest.
    mode: Literal["min", "max"] = "min"


@total_ordering
class CheckpointPath:
"""
Expand Down

0 comments on commit 2e96a32

Please sign in to comment.