From 2dd6b97e63d48b207982c8359b6981b0a2fdc99d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 12 Oct 2021 13:23:22 -0700 Subject: [PATCH] Mark `Trainer.terminate_on_nan` protected and deprecate public property (#9849) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 3 +++ .../loops/optimization/optimizer_loop.py | 4 +-- .../connectors/training_trick_connector.py | 2 +- pytorch_lightning/trainer/trainer.py | 27 ++++++++++++++----- tests/deprecated_api/test_remove_1-7.py | 7 +++++ 5 files changed, 33 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c7be481f20eb2..e8011d175d6a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -284,6 +284,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175)) +- Deprecated `Trainer.terminate_on_nan` public attribute access ([#9849](https://github.com/PyTorchLightning/pytorch-lightning/pull/9849)) + + - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()` diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index 590160c645afc..f0ab8b915b29f 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -344,7 +344,7 @@ def backward_fn(loss: Tensor) -> None: self._backward(loss, optimizer, opt_idx) # check if model weights are nan - if self.trainer.terminate_on_nan: + if self.trainer._terminate_on_nan: detect_nan_parameters(self.trainer.lightning_module) return backward_fn @@ -460,7 +460,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos result = ClosureResult.from_training_step_output(training_step_output, 
self.trainer.accumulate_grad_batches) - if self.trainer.terminate_on_nan: + if self.trainer._terminate_on_nan: check_finite_loss(result.closure_loss) if self.trainer.move_metrics_to_cpu: diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py index 5165056d95391..ffa11ef1985a8 100644 --- a/pytorch_lightning/trainer/connectors/training_trick_connector.py +++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py @@ -52,7 +52,7 @@ def on_trainer_init( f"`track_grad_norm` should be an int, a float or 'inf' (infinity norm). Got {track_grad_norm}." ) - self.trainer.terminate_on_nan = terminate_on_nan + self.trainer._terminate_on_nan = terminate_on_nan self.trainer.gradient_clip_val = gradient_clip_val self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower()) self.trainer.track_grad_norm = float(track_grad_norm) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f2fb6bff4db18..3c50cfc504621 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1990,13 +1990,6 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop if self.predicting: return self.predict_loop - @property - def train_loop(self) -> FitLoop: - rank_zero_deprecation( - "`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6." - ) - return self.fit_loop - @property def _ckpt_path(self) -> Optional[str]: if self.state.fn == TrainerFn.VALIDATING: @@ -2046,3 +2039,23 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__ = state + + @property + def train_loop(self) -> FitLoop: + rank_zero_deprecation( + "`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6." 
+ ) + return self.fit_loop + + @property + def terminate_on_nan(self) -> bool: + rank_zero_deprecation("`Trainer.terminate_on_nan` is deprecated in v1.5 and will be removed in 1.7.") + return self._terminate_on_nan + + @terminate_on_nan.setter + def terminate_on_nan(self, val: bool) -> None: + rank_zero_deprecation( + f"Setting `Trainer.terminate_on_nan = {val}` is deprecated in v1.5 and will be removed in 1.7." + f" Please set `Trainer(detect_anomaly={val})` instead." + ) + self._terminate_on_nan = val diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 51502ef0f195f..af6781f3b609e 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -131,6 +131,13 @@ def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan): assert trainer.terminate_on_nan is terminate_on_nan assert trainer._detect_anomaly is False + trainer = Trainer() + with pytest.deprecated_call(match=r"`Trainer.terminate_on_nan` is deprecated in v1.5"): + _ = trainer.terminate_on_nan + + with pytest.deprecated_call(match=r"Setting `Trainer.terminate_on_nan = True` is deprecated in v1.5"): + trainer.terminate_on_nan = True + def test_v1_7_0_deprecated_on_task_dataloader(tmpdir): class CustomBoringModel(BoringModel):