Skip to content

Commit

Permalink
Mark Trainer.terminate_on_nan protected and deprecate public proper…
Browse files Browse the repository at this point in the history
…ty (Lightning-AI#9849)

Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
2 people authored and rohitgr7 committed Oct 18, 2021
1 parent 7a1e967 commit 2dd6b97
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly` ([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))


- Deprecated `Trainer.terminate_on_nan` public attribute access ([#9849](https://github.com/PyTorchLightning/pytorch-lightning/pull/9849))


- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`


Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def backward_fn(loss: Tensor) -> None:
self._backward(loss, optimizer, opt_idx)

# check if model weights are nan
if self.trainer.terminate_on_nan:
if self.trainer._terminate_on_nan:
detect_nan_parameters(self.trainer.lightning_module)

return backward_fn
Expand Down Expand Up @@ -460,7 +460,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos

result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)

if self.trainer.terminate_on_nan:
if self.trainer._terminate_on_nan:
check_finite_loss(result.closure_loss)

if self.trainer.move_metrics_to_cpu:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def on_trainer_init(
f"`track_grad_norm` should be an int, a float or 'inf' (infinity norm). Got {track_grad_norm}."
)

self.trainer.terminate_on_nan = terminate_on_nan
self.trainer._terminate_on_nan = terminate_on_nan
self.trainer.gradient_clip_val = gradient_clip_val
self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower())
self.trainer.track_grad_norm = float(track_grad_norm)
27 changes: 20 additions & 7 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1990,13 +1990,6 @@ def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop
if self.predicting:
return self.predict_loop

@property
def train_loop(self) -> FitLoop:
    """Deprecated alias of :attr:`fit_loop`; emits a deprecation warning on access."""
    rank_zero_deprecation(
        "`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6."
    )
    return self.fit_loop

@property
def _ckpt_path(self) -> Optional[str]:
if self.state.fn == TrainerFn.VALIDATING:
Expand Down Expand Up @@ -2046,3 +2039,23 @@ def __getstate__(self):

def __setstate__(self, state):
    """Restore unpickled state by replacing the instance ``__dict__`` wholesale."""
    self.__dict__ = state

@property
def train_loop(self) -> FitLoop:
    """Deprecated alias of :attr:`fit_loop`; warns on every access."""
    message = "`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6."
    rank_zero_deprecation(message)
    return self.fit_loop

@property
def terminate_on_nan(self) -> bool:
    """Deprecated public accessor for the private ``_terminate_on_nan`` flag; warns on read."""
    rank_zero_deprecation("`Trainer.terminate_on_nan` is deprecated in v1.5 and will be removed in 1.7.")
    return self._terminate_on_nan

@terminate_on_nan.setter
def terminate_on_nan(self, value: bool) -> None:
    """Deprecated public setter; warns, then forwards to the private ``_terminate_on_nan`` flag."""
    message = (
        f"Setting `Trainer.terminate_on_nan = {value}` is deprecated in v1.5 and will be removed in 1.7."
        f" Please set `Trainer(detect_anomaly={value})` instead."
    )
    rank_zero_deprecation(message)
    self._terminate_on_nan = value
7 changes: 7 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ def test_v1_7_0_trainer_terminate_on_nan(tmpdir, terminate_on_nan):
assert trainer.terminate_on_nan is terminate_on_nan
assert trainer._detect_anomaly is False

trainer = Trainer()
with pytest.deprecated_call(match=r"`Trainer.terminate_on_nan` is deprecated in v1.5"):
_ = trainer.terminate_on_nan

with pytest.deprecated_call(match=r"Setting `Trainer.terminate_on_nan = True` is deprecated in v1.5"):
trainer.terminate_on_nan = True


def test_v1_7_0_deprecated_on_task_dataloader(tmpdir):
class CustomBoringModel(BoringModel):
Expand Down

0 comments on commit 2dd6b97

Please sign in to comment.