Skip to content

Commit

Permalink
Remove the deprecated tuning property and enums (#16379)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Jan 19, 2023
1 parent 1cc52ab commit 5d79508
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 101 deletions.
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed the automatic addition of a moving average of the `training_step` loss in the progress bar. Use `self.log("loss", ..., prog_bar=True)` instead. ([#16192](https://github.com/Lightning-AI/lightning/issues/16192))

- Tuner removal
* Removed the deprecated `trainer.tuning` property ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
* Removed the deprecated `TrainerFn.TUNING` and `RunningStage.TUNING` enums ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))

### Fixed

Expand Down
6 changes: 3 additions & 3 deletions src/pytorch_lightning/callbacks/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def __init__(
self._duration = duration.total_seconds() if duration is not None else None
self._interval = interval
self._verbose = verbose
self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
self._offset = 0

def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
Expand Down Expand Up @@ -161,7 +161,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -
self._check_time_remaining(trainer)

def state_dict(self) -> Dict[str, Any]:
return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage._without_tune()}}
return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage}}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
time_elapsed = state_dict.get("time_elapsed", {})
Expand Down
58 changes: 3 additions & 55 deletions src/pytorch_lightning/trainer/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from enum import Enum, EnumMeta
from typing import Any, List, Optional
from typing import Optional

from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class _DeprecationManagingEnumMeta(EnumMeta):
"""Enum that calls `deprecate()` whenever a member is accessed.
Adapted from: https://stackoverflow.com/a/62309159/208880
"""

def __getattribute__(cls, name: str) -> Any:
obj = super().__getattribute__(name)
# ignore __dunder__ names -- prevents potential recursion errors
if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
obj.deprecate()
return obj

def __getitem__(cls, name: str) -> Any:
member: _DeprecationManagingEnumMeta = super().__getitem__(name)
member.deprecate()
return member

def __call__(cls, *args: Any, **kwargs: Any) -> Any:
obj = super().__call__(*args, **kwargs)
if isinstance(obj, Enum):
obj.deprecate()
return obj


class TrainerStatus(LightningEnum):
Expand All @@ -58,7 +31,7 @@ def stopped(self) -> bool:
return self in (self.FINISHED, self.INTERRUPTED)


class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
class TrainerFn(LightningEnum):
"""
Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer`
such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
Expand All @@ -69,21 +42,9 @@ class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
VALIDATING = "validate"
TESTING = "test"
PREDICTING = "predict"
TUNING = "tune"

def deprecate(self) -> None:
if self == self.TUNING:
rank_zero_deprecation(
f"`TrainerFn.{self.name}` has been deprecated in v1.8.0 and will be removed in v2.0.0."
)

@classmethod
def _without_tune(cls) -> List["TrainerFn"]:
fns = [fn for fn in cls if fn != "tune"]
return fns


class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
class RunningStage(LightningEnum):
"""Enum for the current running stage.
This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
Expand All @@ -93,15 +54,13 @@ class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
- ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
- ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
- ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
- ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
"""

TRAINING = "train"
SANITY_CHECKING = "sanity_check"
VALIDATING = "validate"
TESTING = "test"
PREDICTING = "predict"
TUNING = "tune"

@property
def evaluating(self) -> bool:
Expand All @@ -115,17 +74,6 @@ def dataloader_prefix(self) -> Optional[str]:
return "val"
return self.value

def deprecate(self) -> None:
if self == self.TUNING:
rank_zero_deprecation(
f"`RunningStage.{self.name}` has been deprecated in v1.8.0 and will be removed in v2.0.0."
)

@classmethod
def _without_tune(cls) -> List["RunningStage"]:
fns = [fn for fn in cls if fn != "tune"]
return fns


@dataclass
class TrainerState:
Expand Down
16 changes: 1 addition & 15 deletions src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.seed import isolate_rng
from pytorch_lightning.utilities.types import (
_EVALUATE_OUTPUT,
Expand Down Expand Up @@ -1891,20 +1891,6 @@ def predicting(self, val: bool) -> None:
elif self.predicting:
self.state.stage = None

@property
def tuning(self) -> bool:
rank_zero_deprecation("`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0.")
return self.state.stage == RunningStage.TUNING

@tuning.setter
def tuning(self, val: bool) -> None:
rank_zero_deprecation("Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0.")

if val:
self.state.stage = RunningStage.TUNING
elif self.tuning:
self.state.stage = None

@property
def validating(self) -> bool:
return self.state.stage == RunningStage.VALIDATING
Expand Down
26 changes: 0 additions & 26 deletions tests/tests_pytorch/deprecated_api/test_remove_2-0.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from lightning_utilities.test.warning import no_warning_call
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
Expand All @@ -30,7 +29,6 @@
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies.bagua import LightningBaguaModule
from pytorch_lightning.strategies.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.apply_func import (
apply_to_collection,
apply_to_collections,
Expand Down Expand Up @@ -271,30 +269,6 @@ def test_v1_10_deprecated_accelerator_setup_environment_method():
CPUAccelerator().setup_environment(torch.device("cpu"))


def test_tuning_enum():
with pytest.deprecated_call(
match="`TrainerFn.TUNING` has been deprecated in v1.8.0 and will be removed in v2.0.0."
):
TrainerFn.TUNING

with pytest.deprecated_call(
match="`RunningStage.TUNING` has been deprecated in v1.8.0 and will be removed in v2.0.0."
):
RunningStage.TUNING


def test_tuning_trainer_property():
trainer = Trainer()

with pytest.deprecated_call(match="`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0."):
trainer.tuning

with pytest.deprecated_call(
match="Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0."
):
trainer.tuning = True


def test_v1_8_1_deprecated_rank_zero_only():
from pytorch_lightning.utilities.distributed import rank_zero_only

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def test_loops_restore(tmpdir):
trainer = Trainer(**trainer_args)
trainer.strategy.connect(model)

trainer_fns = [fn for fn in TrainerFn._without_tune()]

trainer_fns = list(TrainerFn)
for fn in trainer_fns:
trainer_fn = getattr(trainer, f"{fn}_loop")
trainer_fn.load_state_dict = mock.Mock()
Expand Down

0 comments on commit 5d79508

Please sign in to comment.