fixed memory leak from opt return #1528

Merged 19 commits on Apr 19, 2020
2 changes: 1 addition & 1 deletion pl_examples/basic_examples/cpu_template.py
@@ -28,7 +28,7 @@ def main(hparams):
# ------------------------
# 2 INIT TRAINER
# ------------------------
trainer = pl.Trainer(max_epochs=hparams.epochs)
trainer = pl.Trainer(max_epochs=hparams.epochs, overfit_pct=0.01, early_stop_callback=True)

# ------------------------
# 3 START TRAINING
29 changes: 20 additions & 9 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -7,11 +7,14 @@
"""

import numpy as np
import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn

torch_inf = torch.tensor(np.Inf)


class EarlyStopping(Callback):
r"""
@@ -43,7 +46,7 @@ class EarlyStopping(Callback):
>>> trainer = Trainer(early_stop_callback=early_stopping)
"""

def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0,
def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 3,
verbose: bool = False, mode: str = 'auto', strict: bool = True):
super().__init__()

@@ -56,9 +59,9 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
self.stopped_epoch = 0

mode_dict = {
'min': np.less,
'max': np.greater,
'auto': np.greater if 'acc' in self.monitor else np.less
'min': torch.lt,
'max': torch.gt,
'auto': torch.gt if 'acc' in self.monitor else torch.lt
}

if mode not in mode_dict:
@@ -67,9 +70,14 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
mode = 'auto'

self.monitor_op = mode_dict[mode]
self.min_delta *= 1 if self.monitor_op == np.greater else -1

def check_metrics(self, logs):
self.min_delta *= 1 if self.monitor_op == torch.gt else -1

def _validate_condition_metric(self, logs):
"""
Checks that the condition metric for early stopping is good
:param logs:
:return:
"""
monitor_val = logs.get(self.monitor)
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
f' which is not available. Available metrics are:'
@@ -89,15 +97,18 @@ def on_train_start(self, trainer, pl_module):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
self.best = np.Inf if self.monitor_op == np.less else -np.Inf
self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf

def on_epoch_end(self, trainer, pl_module):
logs = trainer.callback_metrics
stop_training = False
if not self.check_metrics(logs):
if not self._validate_condition_metric(logs):
return stop_training

current = logs.get(self.monitor)
if not isinstance(current, torch.Tensor):
current = torch.tensor(current)

if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
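For readers skimming the diff, here is a minimal standalone sketch (not the library code itself, names simplified) of how the torch-based comparison introduced above behaves; it assumes 'min' mode, where `min_delta` has its sign flipped exactly as in the `__init__` change:

```python
import torch

torch_inf = torch.tensor(float('inf'))

# 'min' mode: lower values of the monitored metric are better
monitor_op = torch.lt
min_delta = 0.0
min_delta *= 1 if monitor_op == torch.gt else -1   # sign flip from the diff
best = torch_inf if monitor_op == torch.lt else -torch_inf

wait = 0
for current in [torch.tensor(0.50), torch.tensor(0.48), torch.tensor(0.49)]:
    if monitor_op(current - min_delta, best):
        best = current   # improvement seen, patience counter resets
        wait = 0
    else:
        wait += 1        # no improvement, patience counter grows
print(best, wait)        # tensor(0.4800) 1
```

Because the metrics now stay as detached tensors instead of being converted with `.item()`, comparing them with `torch.lt`/`torch.gt` avoids a device-to-host sync on every check.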
21 changes: 15 additions & 6 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -14,6 +14,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
import torch


class ModelCheckpoint(Callback):
@@ -106,15 +107,17 @@ def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = Fal
self.best = 0
self.save_function = None

torch_inf = torch.tensor(np.Inf)
mode_dict = {
'min': (np.less, np.Inf, 'min'),
'max': (np.greater, -np.Inf, 'max'),
'auto': (np.greater, -np.Inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
else (np.less, np.Inf, 'min'),
'min': (torch.lt, torch_inf, 'min'),
'max': (torch.gt, -torch_inf, 'max'),
'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
else (torch.lt, torch_inf, 'min'),
}

if mode not in mode_dict:
rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, fallback to auto mode.', RuntimeWarning)
rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, '
f'fallback to auto mode.', RuntimeWarning)
mode = 'auto'

self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
@@ -136,6 +139,10 @@ def check_monitor_top_k(self, current):
less_than_k_models = len(self.best_k_models) < self.save_top_k
if less_than_k_models:
return True

if not isinstance(current, torch.Tensor):
current = torch.tensor(current)

return self.monitor_op(current, self.best_k_models[self.kth_best_model])

def format_checkpoint_name(self, epoch, metrics, ver=None):
@@ -203,7 +210,9 @@ def on_validation_end(self, trainer, pl_module):
current = metrics.get(self.monitor)

if current is None:
rank_zero_warn(f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning)
rank_zero_warn(
f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning
)
elif self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch)
elif self.verbose > 0:
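As a rough illustration of the updated top-k check, here is a standalone sketch with the callback's attributes passed as hypothetical arguments (the real method only takes `self` and `current`):

```python
import torch

def check_monitor_top_k(best_k_models, kth_best_model, current,
                        monitor_op=torch.lt, save_top_k=3):
    # always keep a checkpoint while fewer than k models are tracked
    if len(best_k_models) < save_top_k:
        return True
    # wrap plain Python numbers so torch.lt/torch.gt compare tensor to tensor,
    # mirroring the isinstance check added in this file
    if not isinstance(current, torch.Tensor):
        current = torch.tensor(current)
    return bool(monitor_op(current, best_k_models[kth_best_model]))

best = {'epoch=1.ckpt': torch.tensor(0.42),
        'epoch=2.ckpt': torch.tensor(0.40),
        'epoch=3.ckpt': torch.tensor(0.39)}
print(check_monitor_top_k(best, 'epoch=1.ckpt', 0.38))  # True: 0.38 < 0.42
```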
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -413,6 +413,8 @@ def run_evaluation(self, test_mode: bool = False):
# Validation/Test end callbacks
if test_mode:
self.on_test_end()
else:
self.on_validation_end()

def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
# make dataloader_idx arg in validation_step optional
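The two added lines make the validation path mirror the test path; a tiny sketch of the dispatch (the class and print statements are illustrative only, not trainer code):

```python
class EvaluationLoopSketch:
    def on_test_end(self):
        print("test-end callbacks fire")

    def on_validation_end(self):
        print("validation-end callbacks fire (checkpoint / early stop hook in here)")

    def run_evaluation(self, test_mode: bool = False):
        # ... evaluation work elided ...
        if test_mode:
            self.on_test_end()
        else:
            self.on_validation_end()

EvaluationLoopSketch().run_evaluation()  # -> validation-end callbacks fire
```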
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/logging.py
@@ -5,6 +5,7 @@

from pytorch_lightning.core import memory
from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerBase, LoggerCollection
from pytorch_lightning.utilities import memory_utils


class TrainerLoggingMixin(ABC):
@@ -173,10 +174,9 @@ def process_output(self, output, train=False):
callback_metrics.update(progress_bar_metrics)
callback_metrics.update(log_metrics)

# convert tensors to numpy
for k, v in callback_metrics.items():
if isinstance(v, torch.Tensor):
callback_metrics[k] = v.item()
# detach all metrics for callbacks to prevent memory leaks
# no .item() because it will slow things down
callback_metrics = memory_utils.recursive_detach(callback_metrics)

return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens

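The comment above is the heart of the fix; the following self-contained snippet (illustrative, not library code) shows why holding an attached tensor in `callback_metrics` leaks memory and why `detach()` is preferred over `.item()`:

```python
import torch

w = torch.randn(1000, requires_grad=True)
loss = (w ** 2).sum()

leaky = {'loss': loss}          # keeps loss.grad_fn, so the whole autograd
                                # graph (and its activations) stays alive
safe = {'loss': loss.detach()}  # same value, no reference to the graph

print(leaky['loss'].grad_fn)    # <SumBackward0 object ...>
print(safe['loss'].grad_fn)     # None

# loss.item() would also drop the graph, but it forces a device-to-host
# sync per metric per batch, which is the slowdown the comment refers to.
```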
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -891,8 +891,9 @@ def run_pretrain_routine(self, model: LightningModule):
self.main_progress_bar.close()
self.val_progress_bar.close()

# verify that early stop has conditioned on a metric that exists
if self.enable_early_stop:
self.early_stop_callback.check_metrics(callback_metrics)
self.early_stop_callback._validate_condition_metric(callback_metrics)

# init progress bar
pbar = tqdm(leave=True, position=2 * self.process_position,
66 changes: 19 additions & 47 deletions pytorch_lightning/trainer/training_loop.py
@@ -157,6 +157,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import memory_utils

try:
from apex import amp
@@ -444,8 +445,11 @@ def run_training_epoch(self):
# ---------------
_outputs = self.run_training_batch(batch, batch_idx)
batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
# detach tensors in batch_output before appending to outputs
outputs.append(_recursive_detach(batch_output))

# only track outputs when user implements training_epoch_end
# otherwise we will build up unnecessary memory
if self.is_overriden('training_epoch_end', model=self.get_model()):
outputs.append(batch_output)

# when returning -1 from train_step, we end epoch early
early_stop_epoch = batch_result == -1
@@ -463,9 +467,14 @@
should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
should_check_val = can_check_val and should_check_val

# ---------------
# CHECKPOINTING, EARLY STOPPING
# ---------------
# fast_dev_run always forces val checking after train batch
if self.fast_dev_run or should_check_val:
self.run_evaluation(test_mode=self.testing)
self.call_checkpoint_callback()
self.call_early_stop_callback()

# when logs should be saved
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
@@ -479,16 +488,6 @@
# logs user requested information to logger
self.log_metrics(batch_step_metrics, grad_norm_dic)

# ---------------
# CHECKPOINTING, EARLY STOPPING
# ---------------
# save checkpoint even when no test or val step are defined
if self.fast_dev_run or should_check_val:
self.call_checkpoint_callback()

if self.enable_early_stop:
self.early_stop_callback.check_metrics(self.callback_metrics)

# progress global step according to grads progress
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
self.global_step += 1
@@ -505,9 +504,7 @@
break

# process epoch outputs
if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
model = model.module

model = self.get_model()
if self.is_overriden('training_epoch_end', model=model):
epoch_output = model.training_epoch_end(outputs)
_processed_outputs = self.process_output(epoch_output)
@@ -516,12 +513,10 @@
self.log_metrics(log_epoch_metrics, {})
self.callback_metrics.update(callback_epoch_metrics)

# in case validation step is missing and you are not running fast-dev to duplicate last batch
# when no val loop is present or fast-dev-run still need to call checkpoints
if not self.is_overriden('validation_step') and not (self.fast_dev_run or should_check_val):
self.call_checkpoint_callback()

if self.enable_early_stop:
self.early_stop_callback.check_metrics(self.callback_metrics)
self.call_early_stop_callback()

# Epoch end events
with self.profiler.profile('on_epoch_end'):
@@ -608,7 +603,7 @@ def optimizer_closure():
with self.profiler.profile('on_after_backward'):
model_ref.on_after_backward()

return closure_loss, output_dict
return closure_loss, callback_metrics

# calculate loss
loss, batch_output = optimizer_closure()
@@ -800,7 +795,10 @@ def update_learning_rates(self, interval: str):
def call_checkpoint_callback(self):
if self.checkpoint_callback is not None:
self.checkpoint_callback.on_validation_end(self, self.get_model())
self.on_validation_end()

def call_early_stop_callback(self):
if self.early_stop_callback:
self.early_stop_callback.on_epoch_end(self, self.get_model())


def _with_is_last(iterable):
@@ -814,29 +812,3 @@ def _with_is_last(iterable):
last = val
# yield last, no longer has next
yield last, True


def _recursive_detach(in_dict):
"""Detach all tensors in `in_dict`.

May operate recursively if some of the values in `in_dict` are dictionaries
which contain instances of `torch.Tensor`. Other types in `in_dict` are
not affected by this utility function.

Parameters
----------
in_dict : dict

Returns
-------
out_dict : dict
"""
out_dict = {}
for k, v in in_dict.items():
if isinstance(v, dict):
out_dict.update({k: _recursive_detach(v)})
elif callable(getattr(v, 'detach', None)):
out_dict.update({k: v.detach()})
else:
out_dict.update({k: v})
return out_dict
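A minimal sketch of the gating added at the top of run_training_epoch; the `training_epoch_end` argument stands in for the `is_overriden('training_epoch_end')` check and is not the trainer's real API:

```python
def run_epoch_sketch(batches, training_step, training_epoch_end=None):
    outputs = []
    for batch in batches:
        batch_output = training_step(batch)
        # only accumulate when an epoch-end hook will actually consume it,
        # otherwise the list just grows for nothing
        if training_epoch_end is not None:
            outputs.append(batch_output)
    if training_epoch_end is not None:
        return training_epoch_end(outputs)

# with no epoch-end hook, nothing is retained across the epoch
run_epoch_sketch(range(3), training_step=lambda b: {'loss': float(b)})
```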
24 changes: 24 additions & 0 deletions pytorch_lightning/utilities/memory_utils.py
@@ -0,0 +1,24 @@
def recursive_detach(in_dict):
"""Detach all tensors in `in_dict`.

May operate recursively if some of the values in `in_dict` are dictionaries
which contain instances of `torch.Tensor`. Other types in `in_dict` are
not affected by this utility function.

Parameters
----------
in_dict : dict

Returns
-------
out_dict : dict
"""
out_dict = {}
for k, v in in_dict.items():
if isinstance(v, dict):
out_dict.update({k: recursive_detach(v)})
elif callable(getattr(v, 'detach', None)):
out_dict.update({k: v.detach()})
else:
out_dict.update({k: v})
return out_dict
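A quick usage example for the new helper (import path taken from this PR; the metric names are made up):

```python
import torch
from pytorch_lightning.utilities.memory_utils import recursive_detach

w = torch.ones(3, requires_grad=True)
metrics = {
    'loss': (w * 2).sum(),                          # attached to the autograd graph
    'log': {'lr': 1e-3, 'acc': torch.tensor(0.9)},  # nested dict with a non-tensor
}

detached = recursive_detach(metrics)
print(detached['loss'].requires_grad)   # False, graph reference dropped
print(detached['log']['lr'])            # 0.001, non-tensors pass through unchanged
```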
2 changes: 1 addition & 1 deletion tests/base/utils.py
@@ -23,7 +23,7 @@
def assert_speed_parity(pl_times, pt_times, num_epochs):

# assert speeds
max_diff_per_epoch = 0.9
max_diff_per_epoch = 0.65
pl_times = np.asarray(pl_times)
pt_times = np.asarray(pt_times)
diffs = pl_times - pt_times