Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the deprecated tuning property and enums #16379

Merged
merged 4 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed the automatic addition of a moving average of the `training_step` loss in the progress bar. Use `self.log("loss", ..., prog_bar=True)` instead. ([#16192](https://github.com/Lightning-AI/lightning/issues/16192))

- Tuner removal
* Removed the deprecated `trainer.tuning` property ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
* Removed the deprecated `TrainerFn.TUNING` and `RunningStage.TUNING` enums ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))

## [unreleased] - 202Y-MM-DD

Expand Down
6 changes: 3 additions & 3 deletions src/pytorch_lightning/callbacks/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ def __init__(
self._duration = duration.total_seconds() if duration is not None else None
self._interval = interval
self._verbose = verbose
self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
self._offset = 0

def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
Expand Down Expand Up @@ -161,7 +161,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -
self._check_time_remaining(trainer)

def state_dict(self) -> Dict[str, Any]:
return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage._without_tune()}}
return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage}}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
time_elapsed = state_dict.get("time_elapsed", {})
Expand Down
58 changes: 3 additions & 55 deletions src/pytorch_lightning/trainer/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from enum import Enum, EnumMeta
from typing import Any, List, Optional
from typing import Optional

from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class _DeprecationManagingEnumMeta(EnumMeta):
"""Enum that calls `deprecate()` whenever a member is accessed.
Adapted from: https://stackoverflow.com/a/62309159/208880
"""

def __getattribute__(cls, name: str) -> Any:
obj = super().__getattribute__(name)
# ignore __dunder__ names -- prevents potential recursion errors
if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
obj.deprecate()
return obj

def __getitem__(cls, name: str) -> Any:
member: _DeprecationManagingEnumMeta = super().__getitem__(name)
member.deprecate()
return member

def __call__(cls, *args: Any, **kwargs: Any) -> Any:
obj = super().__call__(*args, **kwargs)
if isinstance(obj, Enum):
obj.deprecate()
return obj


class TrainerStatus(LightningEnum):
Expand All @@ -58,7 +31,7 @@ def stopped(self) -> bool:
return self in (self.FINISHED, self.INTERRUPTED)


class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
class TrainerFn(LightningEnum):
"""
Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer`
such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
Expand All @@ -69,21 +42,9 @@ class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
VALIDATING = "validate"
TESTING = "test"
PREDICTING = "predict"
TUNING = "tune"

def deprecate(self) -> None:
if self == self.TUNING:
rank_zero_deprecation(
f"`TrainerFn.{self.name}` has been deprecated in v1.8.0 and will be removed in v2.0.0."
)

@classmethod
def _without_tune(cls) -> List["TrainerFn"]:
fns = [fn for fn in cls if fn != "tune"]
return fns


class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
class RunningStage(LightningEnum):
"""Enum for the current running stage.
This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
Expand All @@ -93,15 +54,13 @@ class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
- ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING``
- ``TrainerFn.TESTING`` - ``RunningStage.TESTING``
- ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING``
- ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}``
"""

TRAINING = "train"
SANITY_CHECKING = "sanity_check"
VALIDATING = "validate"
TESTING = "test"
PREDICTING = "predict"
TUNING = "tune"

@property
def evaluating(self) -> bool:
Expand All @@ -115,17 +74,6 @@ def dataloader_prefix(self) -> Optional[str]:
return "val"
return self.value

def deprecate(self) -> None:
if self == self.TUNING:
rank_zero_deprecation(
f"`RunningStage.{self.name}` has been deprecated in v1.8.0 and will be removed in v2.0.0."
)

@classmethod
def _without_tune(cls) -> List["RunningStage"]:
fns = [fn for fn in cls if fn != "tune"]
return fns


@dataclass
class TrainerState:
Expand Down
16 changes: 1 addition & 15 deletions src/pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.seed import isolate_rng
from pytorch_lightning.utilities.types import (
_EVALUATE_OUTPUT,
Expand Down Expand Up @@ -1891,20 +1891,6 @@ def predicting(self, val: bool) -> None:
elif self.predicting:
self.state.stage = None

@property
def tuning(self) -> bool:
rank_zero_deprecation("`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0.")
return self.state.stage == RunningStage.TUNING

@tuning.setter
def tuning(self, val: bool) -> None:
rank_zero_deprecation("Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0.")

if val:
self.state.stage = RunningStage.TUNING
elif self.tuning:
self.state.stage = None

@property
def validating(self) -> bool:
return self.state.stage == RunningStage.VALIDATING
Expand Down
26 changes: 0 additions & 26 deletions tests/tests_pytorch/deprecated_api/test_remove_2-0.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from lightning_utilities.test.warning import no_warning_call
from torch.utils.data import DataLoader

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
Expand All @@ -30,7 +29,6 @@
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies.bagua import LightningBaguaModule
from pytorch_lightning.strategies.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.apply_func import (
apply_to_collection,
apply_to_collections,
Expand Down Expand Up @@ -259,30 +257,6 @@ def test_v1_10_deprecated_accelerator_setup_environment_method():
CPUAccelerator().setup_environment(torch.device("cpu"))


def test_tuning_enum():
with pytest.deprecated_call(
match="`TrainerFn.TUNING` has been deprecated in v1.8.0 and will be removed in v2.0.0."
):
TrainerFn.TUNING

with pytest.deprecated_call(
match="`RunningStage.TUNING` has been deprecated in v1.8.0 and will be removed in v2.0.0."
):
RunningStage.TUNING


def test_tuning_trainer_property():
trainer = Trainer()

with pytest.deprecated_call(match="`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0."):
trainer.tuning

with pytest.deprecated_call(
match="Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0."
):
trainer.tuning = True


def test_v1_8_1_deprecated_rank_zero_only():
from pytorch_lightning.utilities.distributed import rank_zero_only

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def test_loops_restore(tmpdir):
trainer = Trainer(**trainer_args)
trainer.strategy.connect(model)

trainer_fns = [fn for fn in TrainerFn._without_tune()]

trainer_fns = list(TrainerFn)
for fn in trainer_fns:
trainer_fn = getattr(trainer, f"{fn}_loop")
trainer_fn.load_state_dict = mock.Mock()
Expand Down