Skip to content

Commit

Permalink
nan detection and intervention (#1097)
Browse files Browse the repository at this point in the history
* check for nan values

* test nan detection on loss

* sys.exit

* whitespace

* detect nan and inf values in loss and params

* update

* added documentation

* moved detect nan to training loop, remove flag for print

* blank line

* test

* rename

* deprecate print_nan_grads

* deprecated print_nan_grads

* remove unused imports

* update changelog

* fix line too long

* correct deprecated version

Co-Authored-By: Jirka Borovec <[email protected]>

* raise exception instead of sysexit

Co-Authored-By: Jirka Borovec <[email protected]>

* raise exception instead of sysexit

Co-Authored-By: Jirka Borovec <[email protected]>

* Update pytorch_lightning/trainer/training_tricks.py

Co-Authored-By: Jirka Borovec <[email protected]>

* Update pytorch_lightning/trainer/training_tricks.py

Co-Authored-By: Jirka Borovec <[email protected]>

* fix test

Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
Adrian Wälchli and Borda authored Mar 19, 2020
1 parent 36274be commit 732eaee
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 16 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
- Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))

- Added a check that stops the training when loss or weights contain NaN or inf values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))

### Changed

-

### Deprecated

-
- Deprecated Trainer argument `print_nan_grads` ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))

### Removed

Expand Down
15 changes: 12 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
distributed_backend: Optional[str] = None,
use_amp=False, # backward compatible, todo: remove in v0.9.0
precision: int = 32,
print_nan_grads: bool = False,
print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0
weights_summary: str = 'full',
weights_save_path: Optional[str] = None,
amp_level: str = 'O1',
Expand Down Expand Up @@ -208,7 +208,10 @@ def __init__(
precision: Full precision (32), half precision (16).
print_nan_grads: Prints gradients with nan values
print_nan_grads:
.. warning:: .. deprecated:: 0.7.2
Has no effect. When detected, NaN grads will be printed automatically.
Will remove 0.9.0.
weights_summary: Prints a summary of the weights when training begins.
Expand Down Expand Up @@ -296,7 +299,13 @@ def __init__(
"`num_sanity_val_steps` since v0.5.0"
" and this method will be removed in v0.8.0", DeprecationWarning)
self.nb_sanity_val_steps = nb_sanity_val_steps
self.print_nan_grads = print_nan_grads

# Backward compatibility, TODO: remove in v0.9.0
if print_nan_grads:
warnings.warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
" NaN grads will be printed automatically when detected.",
DeprecationWarning)

self.truncated_bptt_steps = truncated_bptt_steps
self.resume_from_checkpoint = resume_from_checkpoint
self.shown_warnings = set()
Expand Down
21 changes: 15 additions & 6 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@ def training_step(self, batch, batch_idx):
trainer = Trainer(truncated_bptt_steps=2)
NaN detection and intervention
------------------------------
In every forward pass in training, Lightning will check that
1. the loss you return in `training_step` is finite (not NaN and not +/-inf)
2. the model parameters have finite values.
Lightning will terminate the training loop with an error message if NaN or infinite
values are detected. If this happens, you should investigate numerically unstable operations
in your model.
"""

import copy
Expand Down Expand Up @@ -187,7 +198,6 @@ class TrainerTrainLoopMixin(ABC):
optimizers: ...
accumulate_grad_batches: int
use_amp: bool
print_nan_grads: ...
track_grad_norm: ...
model: LightningModule
running_loss: ...
Expand All @@ -200,7 +210,7 @@ class TrainerTrainLoopMixin(ABC):
reload_dataloaders_every_epoch: bool
progress_bar_refresh_rate: ...
max_steps: int
max_steps: int
min_steps: int
total_batch_idx: int
checkpoint_callback: ...

Expand Down Expand Up @@ -239,7 +249,7 @@ def clip_gradients(self):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def print_nan_gradients(self):
def detect_nan_tensors(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
Expand Down Expand Up @@ -556,9 +566,8 @@ def optimizer_closure():
# calculate loss
loss = optimizer_closure()

# nan grads
if self.print_nan_grads:
self.print_nan_gradients()
# check if loss or model weights are nan
self.detect_nan_tensors(loss)

# track total loss for logging (avoid mem leaks)
self.batch_loss_value += loss.item()
Expand Down
22 changes: 21 additions & 1 deletion pytorch_lightning/trainer/training_tricks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import math
import sys
from abc import ABC, abstractmethod

import torch
from torch import Tensor

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import GradientAccumulationScheduler
Expand All @@ -15,6 +17,7 @@ class TrainerTrainingTricksMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
gradient_clip_val: ...
precision: ...

@abstractmethod
def get_model(self):
Expand Down Expand Up @@ -45,12 +48,29 @@ def clip_gradients(self):
for p in parameters:
p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))

def print_nan_gradients(self):
def print_nan_gradients(self) -> None:
model = self.get_model()
for param in model.parameters():
if (param.grad is not None) and torch.isnan(param.grad.float()).any():
log.info(param, param.grad)

def detect_nan_tensors(self, loss: Tensor) -> None:
model = self.get_model()

# check if loss is nan
if not torch.isfinite(loss).all():
raise ValueError(
'The loss returned in `training_step` is nan or inf.'
)
# check if a network weight is nan
for name, param in model.named_parameters():
if not torch.isfinite(param).all():
self.print_nan_gradients()
raise ValueError(
f'Detected nan and/or inf values in `{name}`.'
' Check your forward pass for numerically unstable operations.'
)

def configure_accumulated_gradients(self, accumulate_grad_batches):
if isinstance(accumulate_grad_batches, dict):
self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
Expand Down
64 changes: 60 additions & 4 deletions tests/test_cpu_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import math
import warnings

import pytest
import torch

import tests.models.utils as tutils
Expand All @@ -26,7 +28,6 @@ def test_early_stopping_cpu_model(tmpdir):
gradient_clip_val=1.0,
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,
show_progress_bar=True,
logger=tutils.get_test_tube_logger(tmpdir),
train_percent_check=0.1,
Expand All @@ -48,7 +49,6 @@ def test_lbfgs_cpu_model(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
max_epochs=2,
print_nan_grads=True,
show_progress_bar=False,
weights_summary='top',
train_percent_check=1.0,
Expand All @@ -68,7 +68,6 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
max_epochs=1,
gradient_clip_val=1.0,
overfit_pct=0.20,
print_nan_grads=True,
show_progress_bar=False,
train_percent_check=0.01,
val_percent_check=0.01,
Expand Down Expand Up @@ -251,7 +250,6 @@ def test_all_features_cpu_model(tmpdir):
gradient_clip_val=1.0,
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,
show_progress_bar=False,
logger=tutils.get_test_tube_logger(tmpdir),
accumulate_grad_batches=2,
Expand Down Expand Up @@ -359,5 +357,63 @@ def test_single_gpu_model(tmpdir):
tutils.run_model_test(trainer_options, model)


def test_nan_loss_detection(tmpdir):
test_step = 8

class InfLossModel(LightTrainDataloader, TestModelBase):

def training_step(self, batch, batch_idx):
output = super().training_step(batch, batch_idx)
if batch_idx == test_step:
if isinstance(output, dict):
output['loss'] *= torch.tensor(math.inf) # make loss infinite
else:
output /= 0
return output

hparams = tutils.get_hparams()
model = InfLossModel(hparams)

# fit model
trainer = Trainer(
default_save_path=tmpdir,
max_steps=(test_step + 1),
)

with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
trainer.fit(model)
assert trainer.global_step == test_step

for param in model.parameters():
assert torch.isfinite(param).all()


def test_nan_params_detection(tmpdir):
test_step = 8

class NanParamModel(LightTrainDataloader, TestModelBase):

def on_after_backward(self):
if self.global_step == test_step:
# simulate parameter that became nan
torch.nn.init.constant_(self.c_d1.bias, math.nan)

hparams = tutils.get_hparams()

model = NanParamModel(hparams)
trainer = Trainer(
default_save_path=tmpdir,
max_steps=(test_step + 1),
)

with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
trainer.fit(model)
assert trainer.global_step == test_step

# after aborting the training loop, model still has nan-valued params
params = torch.cat([param.view(-1) for param in model.parameters()])
assert not torch.isfinite(params).all()


# if __name__ == '__main__':
# pytest.main([__file__])

0 comments on commit 732eaee

Please sign in to comment.