diff --git a/CHANGELOG.md b/CHANGELOG.md
index 811f711dcc95e3..3a8fffe59a8159 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -334,6 +334,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))
 
+- Changed default value of the `max_steps` Trainer argument from `None` to -1 ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))
+
+
 - Disable quantization aware training observers by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540))
 
 
@@ -411,6 +414,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924))
 
+- Deprecated setting `Trainer(max_steps=None)`. To turn off the limit, set `Trainer(max_steps=-1)` (default) ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))
+
+
 - Deprecated access to the `AcceleratorConnector.is_slurm_managing_tasks` attribute and marked it as protected ([#10101](https://github.com/PyTorchLightning/pytorch-lightning/pull/10101))
 
 
diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py
index 1fe70d9d4e77cd..c4782db473cfe0 100644
--- a/pytorch_lightning/loops/epoch/training_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -20,7 +20,7 @@
 from pytorch_lightning import loops  # import as loops to avoid circular imports
 from pytorch_lightning.loops.batch import TrainingBatchLoop
 from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE
-from pytorch_lightning.loops.utilities import _get_active_optimizers, _update_dataloader_iter
+from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _update_dataloader_iter
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -28,7 +28,7 @@
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.warnings import WarningCache
+from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache
 
 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]
 
@@ -41,13 +41,20 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         max_steps: The maximum number of steps (batches) to process
     """
 
-    def __init__(self, min_steps: int, max_steps: int):
+    def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None:
         super().__init__()
-        self.min_steps: int = min_steps
-
-        if max_steps and max_steps < -1:
-            raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {max_steps}.")
-        self.max_steps: int = max_steps
+        if max_steps is None:
+            rank_zero_deprecation(
+                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
+                " Use `max_steps = -1` instead."
+            )
+            max_steps = -1
+        elif max_steps < -1:
+            raise MisconfigurationException(
+                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
+            )
+        self.min_steps = min_steps
+        self.max_steps = max_steps
 
         self.global_step: int = 0
         self.batch_progress = BatchProgress()
@@ -79,7 +86,7 @@ def batch_idx(self) -> int:
 
     @property
     def _is_training_done(self) -> bool:
-        max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
+        max_steps_reached = _is_max_limit_reached(self.global_step, self.max_steps)
         return max_steps_reached or self._num_ready_batches_reached()
 
     @property
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index b7004f9436a0f9..024ff36a6f9a4b 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -16,9 +16,11 @@
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.epoch import TrainingEpochLoop
+from pytorch_lightning.loops.utilities import _is_max_limit_reached
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
+from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 log = logging.getLogger(__name__)
 
@@ -29,15 +31,19 @@ class FitLoop(Loop):
 
     Args:
         min_epochs: The minimum number of epochs
-        max_epochs: The maximum number of epochs
+        max_epochs: The maximum number of epochs, can be set -1 to turn this limit off
     """
 
-    def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
+    def __init__(
+        self,
+        min_epochs: Optional[int] = 1,
+        max_epochs: int = 1000,
+    ) -> None:
         super().__init__()
-        # Allow max_epochs or max_steps to be zero, since this will be handled by fit_loop.done
-        if max_epochs and max_epochs < -1:
+        if max_epochs < -1:
+            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
             raise MisconfigurationException(
-                f"`max_epochs` must be a positive integer or -1. You passed in {max_epochs}."
+                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
             )
 
         self.max_epochs = max_epochs
@@ -102,8 +108,16 @@ def max_steps(self) -> int:
     def max_steps(self, value: int) -> None:
         """Sets the maximum number of steps (forwards to epoch_loop)"""
         # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
-        if value and value < -1:
-            raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {value}.")
+        if value is None:
+            rank_zero_deprecation(
+                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
+                " Use `max_steps = -1` instead."
+            )
+            value = -1
+        elif value < -1:
+            raise MisconfigurationException(
+                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
+            )
         self.epoch_loop.max_steps = value
 
     @property
@@ -141,8 +155,8 @@ def done(self) -> bool:
         is reached.
""" # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop - stop_steps = FitLoop._is_max_limit_enabled(self.max_steps) and self.global_step >= self.max_steps - stop_epochs = FitLoop._is_max_limit_enabled(self.max_epochs) and self.current_epoch >= self.max_epochs + stop_steps = _is_max_limit_reached(self.global_step, self.max_steps) + stop_epochs = _is_max_limit_reached(self.current_epoch, self.max_epochs) should_stop = False if self.trainer.should_stop: @@ -249,16 +263,3 @@ def teardown(self) -> None: def _should_accumulate(self) -> bool: """Whether the gradients should be accumulated.""" return self.epoch_loop._should_accumulate() - - @staticmethod - def _is_max_limit_enabled(max_value: Optional[int]) -> bool: - """Checks whether the max_value is enabled. This can be used for checking whether max_epochs or max_steps - is enabled. - - Args: - max_value: the value to check - - Returns: - whether the limit for this value should be enabled - """ - return max_value not in (None, -1) diff --git a/pytorch_lightning/loops/utilities.py b/pytorch_lightning/loops/utilities.py index 84b8893b5c43ee..017945fc37749a 100644 --- a/pytorch_lightning/loops/utilities.py +++ b/pytorch_lightning/loops/utilities.py @@ -168,3 +168,16 @@ def _get_active_optimizers( # find optimizer index by looking for the first {item > current_place} in the cumsum list opt_idx = np.searchsorted(freq_cumsum, current_place_in_loop, side="right") return [(opt_idx, optimizers[opt_idx])] + + +def _is_max_limit_reached(current: int, maximum: int = -1) -> bool: + """Check if the limit has been reached (if enabled). + + Args: + current: the current value + maximum: the maximum value (or -1 to disable limit) + + Returns: + bool: whether the limit has been reached + """ + return maximum != -1 and current >= maximum diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index fca2456e3e2625..921c2e0a7e160b 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -21,7 +21,7 @@ import pytorch_lightning as pl from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.loops.fit_loop import FitLoop +from pytorch_lightning.loops.utilities import _is_max_limit_reached from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem @@ -227,7 +227,7 @@ def restore_loops(self) -> None: # crash if max_epochs is lower then the current epoch from the checkpoint if ( - FitLoop._is_max_limit_enabled(self.trainer.max_epochs) + self.trainer.max_epochs != -1 and self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs ): @@ -358,7 +358,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: # dump epoch/global_step/pytorch-lightning_version current_epoch = self.trainer.current_epoch global_step = self.trainer.global_step - has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step + has_reached_max_steps = _is_max_limit_reached(global_step, self.trainer.max_steps) global_step += 1 if not has_reached_max_steps: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c4aaf630a29e38..52a6f04f545fc1 100644 --- 
+++ b/pytorch_lightning/trainer/trainer.py
@@ -145,7 +145,7 @@ def __init__(
         accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
         max_epochs: Optional[int] = None,
         min_epochs: Optional[int] = None,
-        max_steps: Optional[int] = None,
+        max_steps: int = -1,
         min_steps: Optional[int] = None,
         max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
         limit_train_batches: Union[int, float] = 1.0,
@@ -327,9 +327,9 @@ def __init__(
             min_epochs: Force training for at least these many epochs. Disabled by default (None).
                 If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``.
 
-            max_steps: Stop training after this number of steps. Disabled by default (None). If ``max_steps = None``
-                and ``max_epochs = None``, will default to ``max_epochs = 1000``. To disable this default, set
-                ``max_steps`` to ``-1``.
+            max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
+                and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
+                ``max_epochs`` to ``-1``.
 
             min_steps: Force training for at least these number of steps. Disabled by default (None).
 
@@ -460,10 +460,11 @@ def __init__(
         self.signal_connector = SignalConnector(self)
         self.tuner = Tuner(self)
 
-        # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
         fit_loop = FitLoop(
             min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
-            max_epochs=(1000 if (max_epochs is None and max_steps is None and max_time is None) else max_epochs),
+            max_epochs=(
+                max_epochs if max_epochs is not None else (1000 if (max_steps == -1 and max_time is None) else -1)
+            ),
         )
         training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
         training_batch_loop = TrainingBatchLoop()
@@ -1332,7 +1333,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
 
         if not ckpt_path:
             raise MisconfigurationException(
-                f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
+                f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please"
Please" f" specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`" ) return ckpt_path @@ -1937,7 +1938,7 @@ def current_epoch(self) -> int: return self.fit_loop.current_epoch @property - def max_epochs(self) -> Optional[int]: + def max_epochs(self) -> int: return self.fit_loop.max_epochs @property @@ -1945,7 +1946,7 @@ def min_epochs(self) -> Optional[int]: return self.fit_loop.min_epochs @property - def max_steps(self) -> Optional[int]: + def max_steps(self) -> int: return self.fit_loop.max_steps @property diff --git a/tests/callbacks/test_timer.py b/tests/callbacks/test_timer.py index a307a72bdde3b5..a1a8af06429822 100644 --- a/tests/callbacks/test_timer.py +++ b/tests/callbacks/test_timer.py @@ -49,7 +49,7 @@ def on_fit_start(self): timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0] assert timer._duration == 1 assert trainer.max_epochs == -1 - assert trainer.max_steps is None + assert trainer.max_steps == -1 @pytest.mark.parametrize( diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 4b19e6b22b4225..0fd8a4e2935a6d 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -390,6 +390,15 @@ def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir): _ = XLAStatsMonitor() +def test_v1_7_0_deprecated_max_steps_none(tmpdir): + with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"): + _ = Trainer(max_steps=None) + + trainer = Trainer() + with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"): + trainer.fit_loop.max_steps = None + + def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir): with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"): trainer = Trainer(resume_from_checkpoint="a") diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py index 836c80a49821ef..973d781953b84a 100644 --- a/tests/trainer/flags/test_env_vars.py +++ b/tests/trainer/flags/test_env_vars.py @@ -21,10 +21,12 @@ def test_passing_no_env_variables(): """Testing overwriting trainer arguments.""" trainer = Trainer() assert trainer.logger is not None - assert trainer.max_steps is None + assert trainer.max_steps == -1 + assert trainer.max_epochs == 1000 trainer = Trainer(False, max_steps=42) assert trainer.logger is None assert trainer.max_steps == 42 + assert trainer.max_epochs == -1 @mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"}) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 9043f5eca8df09..b2d88becb1ec70 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -383,7 +383,7 @@ def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch): optimizer = optim.Adam(model.parameters()) scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer) max_epochs = 1 if complete_epoch else None - max_steps = None if complete_epoch else 1 + max_steps = -1 if complete_epoch else 1 trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs, max_steps=max_steps) model.configure_optimizers = lambda: { diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index b2a108da2a7792..a45bf105722cf2 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -498,17 +498,17 @@ def test_trainer_max_steps_and_epochs(tmpdir): @pytest.mark.parametrize( - 
"max_epochs,max_steps,incorrect_variable,incorrect_value", + "max_epochs,max_steps,incorrect_variable", [ - (-100, None, "max_epochs", -100), - (1, -2, "max_steps", -2), + (-100, -1, "max_epochs"), + (1, -2, "max_steps"), ], ) -def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable, incorrect_value): +def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable): """Don't allow max_epochs or max_steps to be less than -1 or a float.""" with pytest.raises( MisconfigurationException, - match=f"`{incorrect_variable}` must be a positive integer or -1. You passed in {incorrect_value}", + match=f"`{incorrect_variable}` must be a non-negative integer or -1", ): Trainer(max_epochs=max_epochs, max_steps=max_steps) @@ -516,13 +516,12 @@ def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrec @pytest.mark.parametrize( "max_epochs,max_steps,is_done,correct_trainer_epochs", [ - (None, None, False, 1000), - (-1, None, False, -1), - (None, -1, False, None), + (None, -1, False, 1000), + (-1, -1, False, -1), (5, -1, False, 5), (-1, 10, False, -1), - (None, 0, True, None), - (0, None, True, 0), + (None, 0, True, -1), + (0, -1, True, 0), (-1, 0, True, -1), (0, -1, True, 0), ], diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index ace232665be124..3218464772fcac 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -127,7 +127,6 @@ def _raise(): # These parameters are marked as Optional[...] in Trainer.__init__, with None as default. # They should not be changed by the argparse interface. "min_steps": None, - "max_steps": None, "accelerator": None, "weights_save_path": None, "profiler": None, diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index ff772630306f20..7a861504547778 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -134,7 +134,6 @@ def _raise(): # with None as default. They should not be changed by the argparse # interface. min_steps=None, - max_steps=None, accelerator=None, weights_save_path=None, profiler=None,