diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index f7f0616765e27..5bfc63c017769 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Removed the automatic addition of a moving average of the `training_step` loss in the progress bar. Use `self.log("loss", ..., prog_bar=True)` instead. ([#16192](https://github.com/Lightning-AI/lightning/issues/16192)) +- Tuner removal + * Removed the deprecated `trainer.tuning` property ([#16379](https://github.com/Lightning-AI/lightning/pull/16379)) + * Removed the deprecated `TrainerFn.TUNING` and `RunningStage.TUNING` enums ([#16379](https://github.com/Lightning-AI/lightning/pull/16379)) ### Fixed diff --git a/src/pytorch_lightning/callbacks/timer.py b/src/pytorch_lightning/callbacks/timer.py index 75763ae3ac868..f230bcb6fdd22 100644 --- a/src/pytorch_lightning/callbacks/timer.py +++ b/src/pytorch_lightning/callbacks/timer.py @@ -95,8 +95,8 @@ def __init__( self._duration = duration.total_seconds() if duration is not None else None self._interval = interval self._verbose = verbose - self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()} - self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()} + self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} self._offset = 0 def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: @@ -161,7 +161,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) - self._check_time_remaining(trainer) def state_dict(self) -> Dict[str, Any]: - return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage._without_tune()}} + return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage}} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: time_elapsed = state_dict.get("time_elapsed", {}) diff --git a/src/pytorch_lightning/trainer/states.py b/src/pytorch_lightning/trainer/states.py index c7fa12715f119..288024046b944 100644 --- a/src/pytorch_lightning/trainer/states.py +++ b/src/pytorch_lightning/trainer/states.py @@ -12,37 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass, field -from enum import Enum, EnumMeta -from typing import Any, List, Optional +from typing import Optional from pytorch_lightning.utilities import LightningEnum from pytorch_lightning.utilities.enums import _FaultTolerantMode -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation - - -class _DeprecationManagingEnumMeta(EnumMeta): - """Enum that calls `deprecate()` whenever a member is accessed. - - Adapted from: https://stackoverflow.com/a/62309159/208880 - """ - - def __getattribute__(cls, name: str) -> Any: - obj = super().__getattribute__(name) - # ignore __dunder__ names -- prevents potential recursion errors - if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum): - obj.deprecate() - return obj - - def __getitem__(cls, name: str) -> Any: - member: _DeprecationManagingEnumMeta = super().__getitem__(name) - member.deprecate() - return member - - def __call__(cls, *args: Any, **kwargs: Any) -> Any: - obj = super().__call__(*args, **kwargs) - if isinstance(obj, Enum): - obj.deprecate() - return obj class TrainerStatus(LightningEnum): @@ -58,7 +31,7 @@ def stopped(self) -> bool: return self in (self.FINISHED, self.INTERRUPTED) -class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta): +class TrainerFn(LightningEnum): """ Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer` such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and @@ -69,21 +42,9 @@ class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta): VALIDATING = "validate" TESTING = "test" PREDICTING = "predict" - TUNING = "tune" - - def deprecate(self) -> None: - if self == self.TUNING: - rank_zero_deprecation( - f"`TrainerFn.{self.name}` has been deprecated in v1.8.0 and will be removed in v2.0.0." - ) - - @classmethod - def _without_tune(cls) -> List["TrainerFn"]: - fns = [fn for fn in cls if fn != "tune"] - return fns -class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta): +class RunningStage(LightningEnum): """Enum for the current running stage. This stage complements :class:`TrainerFn` by specifying the current running stage for each function. @@ -93,7 +54,6 @@ class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta): - ``TrainerFn.VALIDATING`` - ``RunningStage.VALIDATING`` - ``TrainerFn.TESTING`` - ``RunningStage.TESTING`` - ``TrainerFn.PREDICTING`` - ``RunningStage.PREDICTING`` - - ``TrainerFn.TUNING`` - ``RunningStage.{TUNING,SANITY_CHECKING,TRAINING,VALIDATING}`` """ TRAINING = "train" @@ -101,7 +61,6 @@ class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta): VALIDATING = "validate" TESTING = "test" PREDICTING = "predict" - TUNING = "tune" @property def evaluating(self) -> bool: @@ -115,17 +74,6 @@ def dataloader_prefix(self) -> Optional[str]: return "val" return self.value - def deprecate(self) -> None: - if self == self.TUNING: - rank_zero_deprecation( - f"`RunningStage.{self.name}` has been deprecated in v1.8.0 and will be removed in v2.0.0." - ) - - @classmethod - def _without_tune(cls) -> List["RunningStage"]: - fns = [fn for fn in cls if fn != "tune"] - return fns - @dataclass class TrainerState: diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 38e735df2a5b9..1f06fbd11ecaf 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -96,7 +96,7 @@ from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.model_helpers import is_overridden -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.seed import isolate_rng from pytorch_lightning.utilities.types import ( _EVALUATE_OUTPUT, @@ -1891,20 +1891,6 @@ def predicting(self, val: bool) -> None: elif self.predicting: self.state.stage = None - @property - def tuning(self) -> bool: - rank_zero_deprecation("`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0.") - return self.state.stage == RunningStage.TUNING - - @tuning.setter - def tuning(self, val: bool) -> None: - rank_zero_deprecation("Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0.") - - if val: - self.state.stage = RunningStage.TUNING - elif self.tuning: - self.state.stage = None - @property def validating(self) -> bool: return self.state.stage == RunningStage.VALIDATING diff --git a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py index b830d3c6be551..a70a55ff65d45 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_2-0.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_2-0.py @@ -20,7 +20,6 @@ from lightning_utilities.test.warning import no_warning_call from torch.utils.data import DataLoader -from pytorch_lightning import Trainer from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.core.mixins.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset @@ -30,7 +29,6 @@ from pytorch_lightning.plugins.environments import LightningEnvironment from pytorch_lightning.strategies.bagua import LightningBaguaModule from pytorch_lightning.strategies.utils import on_colab_kaggle -from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities.apply_func import ( apply_to_collection, apply_to_collections, @@ -271,30 +269,6 @@ def test_v1_10_deprecated_accelerator_setup_environment_method(): CPUAccelerator().setup_environment(torch.device("cpu")) -def test_tuning_enum(): - with pytest.deprecated_call( - match="`TrainerFn.TUNING` has been deprecated in v1.8.0 and will be removed in v2.0.0." - ): - TrainerFn.TUNING - - with pytest.deprecated_call( - match="`RunningStage.TUNING` has been deprecated in v1.8.0 and will be removed in v2.0.0." - ): - RunningStage.TUNING - - -def test_tuning_trainer_property(): - trainer = Trainer() - - with pytest.deprecated_call(match="`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0."): - trainer.tuning - - with pytest.deprecated_call( - match="Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v2.0.0." - ): - trainer.tuning = True - - def test_v1_8_1_deprecated_rank_zero_only(): from pytorch_lightning.utilities.distributed import rank_zero_only diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index 0846b7a8a7d82..a5d23bce26719 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -150,8 +150,7 @@ def test_loops_restore(tmpdir): trainer = Trainer(**trainer_args) trainer.strategy.connect(model) - trainer_fns = [fn for fn in TrainerFn._without_tune()] - + trainer_fns = list(TrainerFn) for fn in trainer_fns: trainer_fn = getattr(trainer, f"{fn}_loop") trainer_fn.load_state_dict = mock.Mock()