From d9b5962d66ed5dbfc61c772bd6cb94ba512b7cf8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 19 Mar 2020 14:24:45 +0100
Subject: [PATCH] nan detection and intervention (#1097)

* check for nan values
* test nan detection on loss
* sys.exit
* whitespace
* detect nan and inf values in loss and params
* update
* added documentation
* moved detect nan to training loop, remove flag for print
* blank line
* test
* rename
* deprecate print_nan_grads
* deprecated print_nan_grads
* remove unused imports
* update changelog
* fix line too long
* correct deprecated version

Co-Authored-By: Jirka Borovec

* raise exception instead of sysexit

Co-Authored-By: Jirka Borovec

* raise exception instead of sysexit

Co-Authored-By: Jirka Borovec

* Update pytorch_lightning/trainer/training_tricks.py

Co-Authored-By: Jirka Borovec

* Update pytorch_lightning/trainer/training_tricks.py

Co-Authored-By: Jirka Borovec

* fix test

Co-authored-by: Jirka Borovec
---
 CHANGELOG.md | 4 +-
 pytorch_lightning/trainer/trainer.py | 15 ++++-
 pytorch_lightning/trainer/training_loop.py | 21 +++++--
 pytorch_lightning/trainer/training_tricks.py | 22 ++++++-
 tests/test_cpu_models.py | 64 ++++++++++++++++++--
 5 files changed, 110 insertions(+), 16 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8489152b4591b8..0af36b4cf5408e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,7 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
 - Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
 - Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
-
+- Added a check that stops training when the loss or weights contain NaN or inf values ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 
 ### Changed
 
@@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Deprecated
 
--
+- Deprecated Trainer argument `print_nan_grads` ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 
 ### Removed
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 602755f9665f6e..3f98bf21859297 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -109,7 +109,7 @@ def __init__(
             distributed_backend: Optional[str] = None,
             use_amp=False,  # backward compatible, todo: remove in v0.9.0
             precision: int = 32,
-            print_nan_grads: bool = False,
+            print_nan_grads: bool = False,  # backward compatible, todo: remove in v0.9.0
             weights_summary: str = 'full',
             weights_save_path: Optional[str] = None,
             amp_level: str = 'O1',
@@ -208,7 +208,10 @@ def __init__(
 
         precision: Full precision (32), half precision (16).
 
-        print_nan_grads: Prints gradients with nan values
+        print_nan_grads:
+            .. warning:: .. deprecated:: 0.7.2
+                Has no effect. When detected, NaN grads will be printed automatically.
+                Will be removed in 0.9.0.
 
         weights_summary: Prints a summary of the weights when training begins.

@@ -296,7 +299,13 @@ def __init__(
                           "`num_sanity_val_steps` since v0.5.0"
                           " and this method will be removed in v0.8.0", DeprecationWarning)
             self.nb_sanity_val_steps = nb_sanity_val_steps
-        self.print_nan_grads = print_nan_grads
+
+        # Backward compatibility, TODO: remove in v0.9.0
+        if print_nan_grads:
+            warnings.warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
+                          " NaN grads will be printed automatically when detected.",
+                          DeprecationWarning)
+
         self.truncated_bptt_steps = truncated_bptt_steps
         self.resume_from_checkpoint = resume_from_checkpoint
         self.shown_warnings = set()
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 972b00ea79ff69..328170c11732b0 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -119,6 +119,17 @@ def training_step(self, batch, batch_idx):
 
     trainer = Trainer(truncated_bptt_steps=2)
 
+NaN detection and intervention
+------------------------------
+In every forward pass during training, Lightning checks that
+
+1. the loss you return in `training_step` is finite (not NaN and not +/-inf)
+2. the model parameters have finite values.
+
+Lightning will terminate the training loop with an error message if NaN or infinite
+values are detected. If this happens, you should investigate numerically unstable operations
+in your model.
+
 """
 
 import copy
@@ -187,7 +198,6 @@ class TrainerTrainLoopMixin(ABC):
     optimizers: ...
     accumulate_grad_batches: int
     use_amp: bool
-    print_nan_grads: ...
     track_grad_norm: ...
     model: LightningModule
     running_loss: ...
@@ -200,7 +210,7 @@ class TrainerTrainLoopMixin(ABC):
     reload_dataloaders_every_epoch: bool
     progress_bar_refresh_rate: ...
     max_steps: int
-    max_steps: int
+    min_steps: int
     total_batch_idx: int
     checkpoint_callback: ...
 
@@ -239,7 +249,7 @@ def clip_gradients(self):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
-    def print_nan_gradients(self):
+    def detect_nan_tensors(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
     @abstractmethod
@@ -556,9 +566,8 @@ def optimizer_closure():
                 # calculate loss
                 loss = optimizer_closure()
 
-                # nan grads
-                if self.print_nan_grads:
-                    self.print_nan_gradients()
+                # check if loss or model weights are nan
+                self.detect_nan_tensors(loss)
 
                 # track total loss for logging (avoid mem leaks)
                 self.batch_loss_value += loss.item()
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 8722b4e99ce2a6..9dd43e193a2be0 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -1,7 +1,9 @@
 import math
+import sys
 from abc import ABC, abstractmethod
 
 import torch
+from torch import Tensor
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
@@ -15,6 +17,7 @@ class TrainerTrainingTricksMixin(ABC):
     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     gradient_clip_val: ...
+    precision: ...
 
     @abstractmethod
     def get_model(self):
@@ -45,12 +48,29 @@ def clip_gradients(self):
             for p in parameters:
                 p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
 
-    def print_nan_gradients(self):
+    def print_nan_gradients(self) -> None:
         model = self.get_model()
         for param in model.parameters():
             if (param.grad is not None) and torch.isnan(param.grad.float()).any():
                 log.info(param, param.grad)
 
+    def detect_nan_tensors(self, loss: Tensor) -> None:
+        model = self.get_model()
+
+        # check if the loss is nan or inf
+        if not torch.isfinite(loss).all():
+            raise ValueError(
+                'The loss returned in `training_step` is nan or inf.'
+            )
+        # check if any network weight is nan or inf
+        for name, param in model.named_parameters():
+            if not torch.isfinite(param).all():
+                self.print_nan_gradients()
+                raise ValueError(
+                    f'Detected nan and/or inf values in `{name}`.'
+                    ' Check your forward pass for numerically unstable operations.'
+                )
+
     def configure_accumulated_gradients(self, accumulate_grad_batches):
         if isinstance(accumulate_grad_batches, dict):
             self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
diff --git a/tests/test_cpu_models.py b/tests/test_cpu_models.py
index 1d0b1d1320f562..38fc790430fd71 100644
--- a/tests/test_cpu_models.py
+++ b/tests/test_cpu_models.py
@@ -1,5 +1,7 @@
+import math
 import warnings
 
+import pytest
 import torch
 
 import tests.models.utils as tutils
@@ -26,7 +28,6 @@ def test_early_stopping_cpu_model(tmpdir):
         gradient_clip_val=1.0,
         overfit_pct=0.20,
         track_grad_norm=2,
-        print_nan_grads=True,
         show_progress_bar=True,
         logger=tutils.get_test_tube_logger(tmpdir),
         train_percent_check=0.1,
@@ -48,7 +49,6 @@ def test_lbfgs_cpu_model(tmpdir):
     trainer_options = dict(
         default_save_path=tmpdir,
         max_epochs=2,
-        print_nan_grads=True,
         show_progress_bar=False,
         weights_summary='top',
         train_percent_check=1.0,
@@ -68,7 +68,6 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
         max_epochs=1,
         gradient_clip_val=1.0,
         overfit_pct=0.20,
-        print_nan_grads=True,
         show_progress_bar=False,
         train_percent_check=0.01,
         val_percent_check=0.01,
@@ -251,7 +250,6 @@ def test_all_features_cpu_model(tmpdir):
         gradient_clip_val=1.0,
         overfit_pct=0.20,
         track_grad_norm=2,
-        print_nan_grads=True,
         show_progress_bar=False,
         logger=tutils.get_test_tube_logger(tmpdir),
         accumulate_grad_batches=2,
@@ -359,5 +357,63 @@ def test_single_gpu_model(tmpdir):
     tutils.run_model_test(trainer_options, model)
 
 
+def test_nan_loss_detection(tmpdir):
+    test_step = 8
+
+    class InfLossModel(LightTrainDataloader, TestModelBase):
+
+        def training_step(self, batch, batch_idx):
+            output = super().training_step(batch, batch_idx)
+            if batch_idx == test_step:
+                if isinstance(output, dict):
+                    output['loss'] *= torch.tensor(math.inf)  # make loss infinite
+                else:
+                    output /= 0
+            return output
+
+    hparams = tutils.get_hparams()
+    model = InfLossModel(hparams)
+
+    # fit model
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_steps=(test_step + 1),
+    )
+
+    with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
+        trainer.fit(model)
+    assert trainer.global_step == test_step
+
+    for param in model.parameters():
+        assert torch.isfinite(param).all()
+
+
+def test_nan_params_detection(tmpdir):
+    test_step = 8
+
+    class NanParamModel(LightTrainDataloader, TestModelBase):
+
+        def on_after_backward(self):
+            if self.global_step == test_step:
+                # simulate parameter that became nan
+                torch.nn.init.constant_(self.c_d1.bias, math.nan)
+
+    hparams = tutils.get_hparams()
+
+    model = NanParamModel(hparams)
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_steps=(test_step + 1),
+    )
+
+    with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
+        trainer.fit(model)
+    assert trainer.global_step == test_step
+
+    # after aborting the training loop, model still has nan-valued params
+    params = torch.cat([param.view(-1) for param in model.parameters()])
+    assert not torch.isfinite(params).all()
+
+
 # if __name__ == '__main__':
 #     pytest.main([__file__])
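
Note for readers (not part of the patch): the snippet below is a minimal standalone sketch of the two checks that the new `detect_nan_tensors` performs, written against plain PyTorch so it can be tried outside a Trainer. The `nn.Linear` model, the dummy loss, and the sketch function itself are illustrative only; they are not code from this repository.

import math

import torch
from torch import nn


def detect_nan_tensors(loss: torch.Tensor, model: nn.Module) -> None:
    # mirrors the checks added in training_tricks.py: the loss first, then every named parameter
    if not torch.isfinite(loss).all():
        raise ValueError('The loss returned in `training_step` is nan or inf.')
    for name, param in model.named_parameters():
        if not torch.isfinite(param).all():
            raise ValueError(f'Detected nan and/or inf values in `{name}`.'
                             ' Check your forward pass for numerically unstable operations.')


model = nn.Linear(4, 1)                       # stand-in for a LightningModule's network
loss = model(torch.randn(2, 4)).sum()
detect_nan_tensors(loss, model)               # finite loss and weights: no exception

nn.init.constant_(model.bias, math.nan)       # simulate a weight turning nan, as the new test does
try:
    detect_nan_tensors(loss, model)
except ValueError as err:
    print(err)                                # Detected nan and/or inf values in `bias`. ...

Running the snippet prints the same error text that `test_nan_params_detection` expects, with `bias` in place of `c_d1.bias`.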