Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nan detection and intervention #1097

Merged
merged 24 commits into from
Mar 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added type hints to `pytorch_lightning.core` ([#946](https://github.com/PyTorchLightning/pytorch-lightning/pull/946))
- Added support for IterableDataset in validation and testing ([#1104](https://github.com/PyTorchLightning/pytorch-lightning/pull/1104))
- Added support for non-primitive types in hparams for TensorboardLogger ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))

- Added a check that stops the training when loss or weights contain NaN or inf values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))

### Changed

-

### Deprecated

-
- Deprecated Trainer argument `print_nan_grads` ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))

### Removed

Expand Down
15 changes: 12 additions & 3 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def __init__(
distributed_backend: Optional[str] = None,
use_amp=False, # backward compatible, todo: remove in v0.9.0
precision: int = 32,
print_nan_grads: bool = False,
print_nan_grads: bool = False, # backward compatible, todo: remove in v0.9.0
weights_summary: str = 'full',
weights_save_path: Optional[str] = None,
amp_level: str = 'O1',
Expand Down Expand Up @@ -208,7 +208,10 @@ def __init__(

precision: Full precision (32), half precision (16).

print_nan_grads: Prints gradients with nan values
print_nan_grads:
.. warning:: .. deprecated:: 0.7.2
Has no effect. When detected, NaN grads will be printed automatically.
Will remove 0.9.0.

weights_summary: Prints a summary of the weights when training begins.

Expand Down Expand Up @@ -296,7 +299,13 @@ def __init__(
"`num_sanity_val_steps` since v0.5.0"
" and this method will be removed in v0.8.0", DeprecationWarning)
self.nb_sanity_val_steps = nb_sanity_val_steps
self.print_nan_grads = print_nan_grads

# Backward compatibility, TODO: remove in v0.9.0
if print_nan_grads:
warnings.warn("Argument `print_nan_grads` has no effect and will be removed in v0.9.0."
" NaN grads will be printed automatically when detected.",
DeprecationWarning)

self.truncated_bptt_steps = truncated_bptt_steps
self.resume_from_checkpoint = resume_from_checkpoint
self.shown_warnings = set()
Expand Down
21 changes: 15 additions & 6 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@ def training_step(self, batch, batch_idx):
trainer = Trainer(truncated_bptt_steps=2)


NaN detection and intervention
------------------------------
In every forward pass in training, Lightning will check that

1. the loss you return in `training_step` is finite (not NaN and not +/-inf)
2. the model parameters have finite values.

Lightning will terminate the training loop with an error message if NaN or infinite
values are detected. If this happens, you should investigate numerically unstable operations
in your model.

"""

import copy
Expand Down Expand Up @@ -187,7 +198,6 @@ class TrainerTrainLoopMixin(ABC):
optimizers: ...
accumulate_grad_batches: int
use_amp: bool
print_nan_grads: ...
track_grad_norm: ...
model: LightningModule
running_loss: ...
Expand All @@ -200,7 +210,7 @@ class TrainerTrainLoopMixin(ABC):
reload_dataloaders_every_epoch: bool
progress_bar_refresh_rate: ...
max_steps: int
max_steps: int
min_steps: int
total_batch_idx: int
checkpoint_callback: ...

Expand Down Expand Up @@ -239,7 +249,7 @@ def clip_gradients(self):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
def print_nan_gradients(self):
def detect_nan_tensors(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

@abstractmethod
Expand Down Expand Up @@ -556,9 +566,8 @@ def optimizer_closure():
# calculate loss
loss = optimizer_closure()

# nan grads
if self.print_nan_grads:
self.print_nan_gradients()
# check if loss or model weights are nan
self.detect_nan_tensors(loss)

# track total loss for logging (avoid mem leaks)
self.batch_loss_value += loss.item()
Expand Down
22 changes: 21 additions & 1 deletion pytorch_lightning/trainer/training_tricks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import math
import sys
from abc import ABC, abstractmethod

import torch
from torch import Tensor

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import GradientAccumulationScheduler
Expand All @@ -15,6 +17,7 @@ class TrainerTrainingTricksMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
gradient_clip_val: ...
precision: ...

@abstractmethod
def get_model(self):
Expand Down Expand Up @@ -45,12 +48,29 @@ def clip_gradients(self):
for p in parameters:
p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))

def print_nan_gradients(self):
def print_nan_gradients(self) -> None:
model = self.get_model()
for param in model.parameters():
if (param.grad is not None) and torch.isnan(param.grad.float()).any():
log.info(param, param.grad)

def detect_nan_tensors(self, loss: Tensor) -> None:
model = self.get_model()

# check if loss is nan
if not torch.isfinite(loss).all():
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
'The loss returned in `training_step` is nan or inf.'
)
# check if a network weight is nan
for name, param in model.named_parameters():
if not torch.isfinite(param).all():
self.print_nan_gradients()
raise ValueError(
f'Detected nan and/or inf values in `{name}`.'
' Check your forward pass for numerically unstable operations.'
)

def configure_accumulated_gradients(self, accumulate_grad_batches):
if isinstance(accumulate_grad_batches, dict):
self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
Expand Down
64 changes: 60 additions & 4 deletions tests/test_cpu_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import math
import warnings

import pytest
import torch

import tests.models.utils as tutils
Expand All @@ -26,7 +28,6 @@ def test_early_stopping_cpu_model(tmpdir):
gradient_clip_val=1.0,
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,
show_progress_bar=True,
logger=tutils.get_test_tube_logger(tmpdir),
train_percent_check=0.1,
Expand All @@ -48,7 +49,6 @@ def test_lbfgs_cpu_model(tmpdir):
trainer_options = dict(
default_save_path=tmpdir,
max_epochs=2,
print_nan_grads=True,
show_progress_bar=False,
weights_summary='top',
train_percent_check=1.0,
Expand All @@ -68,7 +68,6 @@ def test_default_logger_callbacks_cpu_model(tmpdir):
max_epochs=1,
gradient_clip_val=1.0,
overfit_pct=0.20,
print_nan_grads=True,
show_progress_bar=False,
train_percent_check=0.01,
val_percent_check=0.01,
Expand Down Expand Up @@ -251,7 +250,6 @@ def test_all_features_cpu_model(tmpdir):
gradient_clip_val=1.0,
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,
show_progress_bar=False,
logger=tutils.get_test_tube_logger(tmpdir),
accumulate_grad_batches=2,
Expand Down Expand Up @@ -359,5 +357,63 @@ def test_single_gpu_model(tmpdir):
tutils.run_model_test(trainer_options, model)


def test_nan_loss_detection(tmpdir):
test_step = 8

class InfLossModel(LightTrainDataloader, TestModelBase):

def training_step(self, batch, batch_idx):
output = super().training_step(batch, batch_idx)
if batch_idx == test_step:
if isinstance(output, dict):
output['loss'] *= torch.tensor(math.inf) # make loss infinite
else:
output /= 0
return output

hparams = tutils.get_hparams()
model = InfLossModel(hparams)

# fit model
trainer = Trainer(
default_save_path=tmpdir,
max_steps=(test_step + 1),
)

with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'):
trainer.fit(model)
assert trainer.global_step == test_step

for param in model.parameters():
assert torch.isfinite(param).all()


def test_nan_params_detection(tmpdir):
test_step = 8

class NanParamModel(LightTrainDataloader, TestModelBase):

def on_after_backward(self):
if self.global_step == test_step:
# simulate parameter that became nan
torch.nn.init.constant_(self.c_d1.bias, math.nan)

hparams = tutils.get_hparams()

model = NanParamModel(hparams)
trainer = Trainer(
default_save_path=tmpdir,
max_steps=(test_step + 1),
)

with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'):
trainer.fit(model)
assert trainer.global_step == test_step

# after aborting the training loop, model still has nan-valued params
params = torch.cat([param.view(-1) for param in model.parameters()])
assert not torch.isfinite(params).all()


# if __name__ == '__main__':
# pytest.main([__file__])