From 29b3cb6ca15f0e9d5e7284ac7f9418564527af83 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 10:13:06 -0400 Subject: [PATCH 01/19] fixed memory leak from opt return --- pl_examples/basic_examples/cpu_template.py | 2 +- pytorch_lightning/callbacks/early_stopping.py | 9 ++- pytorch_lightning/trainer/logging.py | 8 +-- pytorch_lightning/trainer/trainer.py | 3 +- pytorch_lightning/trainer/training_loop.py | 57 +++++-------------- 5 files changed, 28 insertions(+), 51 deletions(-) diff --git a/pl_examples/basic_examples/cpu_template.py b/pl_examples/basic_examples/cpu_template.py index 1ab195d515379..5929a07be5727 100644 --- a/pl_examples/basic_examples/cpu_template.py +++ b/pl_examples/basic_examples/cpu_template.py @@ -28,7 +28,7 @@ def main(hparams): # ------------------------ # 2 INIT TRAINER # ------------------------ - trainer = pl.Trainer(max_epochs=hparams.epochs) + trainer = pl.Trainer(max_epochs=hparams.epochs, overfit_pct=0.01, early_stop_callback=True) # ------------------------ # 3 START TRAINING diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 70e90a7046595..ea2656b073ba8 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -69,7 +69,12 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: self.monitor_op = mode_dict[mode] self.min_delta *= 1 if self.monitor_op == np.greater else -1 - def check_metrics(self, logs): + def _validate_condition_metric(self, logs): + """ + Checks that the condition metric for early stopping is good + :param logs: + :return: + """ monitor_val = logs.get(self.monitor) error_msg = (f'Early stopping conditioned on metric `{self.monitor}`' f' which is not available. 
Available metrics are:' @@ -94,7 +99,7 @@ def on_train_start(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module): logs = trainer.callback_metrics stop_training = False - if not self.check_metrics(logs): + if not self._validate_condition_metric(logs): return stop_training current = logs.get(self.monitor) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 22d2d42e724f4..dbe05e4aa2818 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -5,6 +5,7 @@ from pytorch_lightning.core import memory from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerBase, LoggerCollection +from pytorch_lightning.utilities import memory_utils class TrainerLoggingMixin(ABC): @@ -173,10 +174,9 @@ def process_output(self, output, train=False): callback_metrics.update(progress_bar_metrics) callback_metrics.update(log_metrics) - # convert tensors to numpy - for k, v in callback_metrics.items(): - if isinstance(v, torch.Tensor): - callback_metrics[k] = v.item() + # detach all metrics for callbacks to prevent memory leaks + # no .item() because it will slow things down + callback_metrics = memory_utils.recursive_detach(callback_metrics) return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0e9bb209c7879..1964af8168e2d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -891,8 +891,9 @@ def run_pretrain_routine(self, model: LightningModule): self.main_progress_bar.close() self.val_progress_bar.close() + # verify that early stop has conditioned on a metric that exists if self.enable_early_stop: - self.early_stop_callback.check_metrics(callback_metrics) + self.early_stop_callback._validate_condition_metric(callback_metrics) # init progress bar pbar = tqdm(leave=True, position=2 * self.process_position, diff --git 
a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5a8442204930c..5df8744423ba2 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -157,6 +157,7 @@ def training_step(self, batch, batch_idx): from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities import memory_utils try: from apex import amp @@ -445,7 +446,7 @@ def run_training_epoch(self): _outputs = self.run_training_batch(batch, batch_idx) batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs # detach tensors in batch_output before appending to outputs - outputs.append(_recursive_detach(batch_output)) + outputs.append(batch_output) # when returning -1 from train_step, we end epoch early early_stop_epoch = batch_result == -1 @@ -463,9 +464,14 @@ def run_training_epoch(self): should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf')) should_check_val = can_check_val and should_check_val + # --------------- + # CHECKPOINTING, EARLY STOPPING + # --------------- # fast_dev_run always forces val checking after train batch if self.fast_dev_run or should_check_val: self.run_evaluation(test_mode=self.testing) + self.call_checkpoint_callback() + self.call_early_stop_callback() # when logs should be saved should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch @@ -479,16 +485,6 @@ def run_training_epoch(self): # logs user requested information to logger self.log_metrics(batch_step_metrics, grad_norm_dic) - # --------------- - # CHECKPOINTING, EARLY STOPPING - # --------------- - # save checkpoint even when no test or val step are defined - if self.fast_dev_run or should_check_val: - self.call_checkpoint_callback() - - if self.enable_early_stop: - 
self.early_stop_callback.check_metrics(self.callback_metrics) - # progress global step according to grads progress if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: self.global_step += 1 @@ -516,12 +512,10 @@ def run_training_epoch(self): self.log_metrics(log_epoch_metrics, {}) self.callback_metrics.update(callback_epoch_metrics) - # in case validation step is missing and you are not running fast-dev to duplicate last batch + # when no val loop is present or fast-dev-run still need to call checkpoints if not self.is_overriden('validation_step') and not (self.fast_dev_run or should_check_val): self.call_checkpoint_callback() - - if self.enable_early_stop: - self.early_stop_callback.check_metrics(self.callback_metrics) + self.call_early_stop_callback() # Epoch end events with self.profiler.profile('on_epoch_end'): @@ -608,7 +602,7 @@ def optimizer_closure(): with self.profiler.profile('on_after_backward'): model_ref.on_after_backward() - return closure_loss, output_dict + return closure_loss, callback_metrics # calculate loss loss, batch_output = optimizer_closure() @@ -800,7 +794,10 @@ def update_learning_rates(self, interval: str): def call_checkpoint_callback(self): if self.checkpoint_callback is not None: self.checkpoint_callback.on_validation_end(self, self.get_model()) - self.on_validation_end() + + def call_early_stop_callback(self): + if self.early_stop_callback: + self.early_stop_callback.on_epoch_end(self, self.get_model()) def _with_is_last(iterable): @@ -814,29 +811,3 @@ def _with_is_last(iterable): last = val # yield last, no longer has next yield last, True - - -def _recursive_detach(in_dict): - """Detach all tensors in `in_dict`. - - May operate recursively if some of the values in `in_dict` are dictionaries - which contain instances of `torch.Tensor`. Other types in `in_dict` are - not affected by this utility function. 
- - Parameters - ---------- - in_dict : dict - - Returns - ------- - out_dict : dict - """ - out_dict = {} - for k, v in in_dict.items(): - if isinstance(v, dict): - out_dict.update({k: _recursive_detach(v)}) - elif callable(getattr(v, 'detach', None)): - out_dict.update({k: v.detach()}) - else: - out_dict.update({k: v}) - return out_dict From 8a4d381e3de15082255b1f79cf7a5a2d6277f01a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 10:15:00 -0400 Subject: [PATCH 02/19] fixed memory leak from opt return --- pytorch_lightning/utilities/memory_utils.py | 24 +++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 pytorch_lightning/utilities/memory_utils.py diff --git a/pytorch_lightning/utilities/memory_utils.py b/pytorch_lightning/utilities/memory_utils.py new file mode 100644 index 0000000000000..154a2f990fd3d --- /dev/null +++ b/pytorch_lightning/utilities/memory_utils.py @@ -0,0 +1,24 @@ +def recursive_detach(in_dict): + """Detach all tensors in `in_dict`. + + May operate recursively if some of the values in `in_dict` are dictionaries + which contain instances of `torch.Tensor`. Other types in `in_dict` are + not affected by this utility function. 
+ + Parameters + ---------- + in_dict : dict + + Returns + ------- + out_dict : dict + """ + out_dict = {} + for k, v in in_dict.items(): + if isinstance(v, dict): + out_dict.update({k: recursive_detach(v)}) + elif callable(getattr(v, 'detach', None)): + out_dict.update({k: v.detach()}) + else: + out_dict.update({k: v}) + return out_dict From 35226079e9972b9ce81e680e57595e0e4eb18d9a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 10:19:30 -0400 Subject: [PATCH 03/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/model_checkpoint.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index a5ac54ac018ef..cb54814843ccf 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -136,7 +136,11 @@ def check_monitor_top_k(self, current): less_than_k_models = len(self.best_k_models) < self.save_top_k if less_than_k_models: return True - return self.monitor_op(current, self.best_k_models[self.kth_best_model]) + try: + d = self.monitor_op(current, self.best_k_models[self.kth_best_model]) + except Exception as e: + print('a') + return d def format_checkpoint_name(self, epoch, metrics, ver=None): """Generate a filename according to the defined template. 
From e24a1978d39e88802d4b1c0070032e1a31c2be4b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 10:21:44 -0400 Subject: [PATCH 04/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/model_checkpoint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index cb54814843ccf..97a7d25c5b530 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -139,6 +139,7 @@ def check_monitor_top_k(self, current): try: d = self.monitor_op(current, self.best_k_models[self.kth_best_model]) except Exception as e: + import pdb; pdb.set_trace() print('a') return d From c9760bf4ba234084382d76b5d1f0592a1969ed61 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 10:29:31 -0400 Subject: [PATCH 05/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/model_checkpoint.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 97a7d25c5b530..8059bb1b5e563 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -14,6 +14,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_warn +import torch class ModelCheckpoint(Callback): @@ -106,11 +107,12 @@ def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = Fal self.best = 0 self.save_function = None + torch_inf = torch.tensor(np.Inf) mode_dict = { - 'min': (np.less, np.Inf, 'min'), - 'max': (np.greater, -np.Inf, 'max'), - 'auto': (np.greater, -np.Inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure') - else (np.less, np.Inf, 'min'), + 'min': (torch.lt, torch_inf, 'min'), + 'max': (torch.gt, 
-torch_inf, 'max'), + 'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure') + else (torch.lt, torch_inf, 'min'), } if mode not in mode_dict: @@ -136,11 +138,13 @@ def check_monitor_top_k(self, current): less_than_k_models = len(self.best_k_models) < self.save_top_k if less_than_k_models: return True + try: + import pdb; pdb.set_trace() d = self.monitor_op(current, self.best_k_models[self.kth_best_model]) except Exception as e: - import pdb; pdb.set_trace() - print('a') + + print('s') return d def format_checkpoint_name(self, epoch, metrics, ver=None): From fd72ee00dda3e5071804b276c52cfbb11180076f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 10:32:52 -0400 Subject: [PATCH 06/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/model_checkpoint.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 8059bb1b5e563..4914d460202d5 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -139,13 +139,7 @@ def check_monitor_top_k(self, current): if less_than_k_models: return True - try: - import pdb; pdb.set_trace() - d = self.monitor_op(current, self.best_k_models[self.kth_best_model]) - except Exception as e: - - print('s') - return d + return self.monitor_op(current, self.best_k_models[self.kth_best_model]) def format_checkpoint_name(self, epoch, metrics, ver=None): """Generate a filename according to the defined template. 
From 30a267567d960e8f40c084b2bba63ceb1ff1592e Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 10:33:57 -0400 Subject: [PATCH 07/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/early_stopping.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index ea2656b073ba8..9951833271382 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -7,6 +7,7 @@ """ import numpy as np +import torch from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback @@ -55,10 +56,12 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: self.wait = 0 self.stopped_epoch = 0 + torch_inf = torch.tensor(np.Inf) mode_dict = { - 'min': np.less, - 'max': np.greater, - 'auto': np.greater if 'acc' in self.monitor else np.less + 'min': (torch.lt, torch_inf, 'min'), + 'max': (torch.gt, -torch_inf, 'max'), + 'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure') + else (torch.lt, torch_inf, 'min'), } if mode not in mode_dict: From 3ada1369c89965846a038988408235ed750f568b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 10:35:11 -0400 Subject: [PATCH 08/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/early_stopping.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 9951833271382..3aa1a813f6fbc 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -13,6 +13,8 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.utilities import rank_zero_warn +torch_inf = torch.tensor(np.Inf) + class EarlyStopping(Callback): r""" @@ -56,7 +58,6 
@@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: self.wait = 0 self.stopped_epoch = 0 - torch_inf = torch.tensor(np.Inf) mode_dict = { 'min': (torch.lt, torch_inf, 'min'), 'max': (torch.gt, -torch_inf, 'max'), @@ -70,7 +71,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: mode = 'auto' self.monitor_op = mode_dict[mode] - self.min_delta *= 1 if self.monitor_op == np.greater else -1 + self.min_delta *= 1 if self.monitor_op == torch.gt else -1 def _validate_condition_metric(self, logs): """ @@ -97,7 +98,7 @@ def on_train_start(self, trainer, pl_module): # Allow instances to be re-used self.wait = 0 self.stopped_epoch = 0 - self.best = np.Inf if self.monitor_op == np.less else -np.Inf + self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf def on_epoch_end(self, trainer, pl_module): logs = trainer.callback_metrics From c8de937cbb0bd0473c875b5469bf40d81d1c011f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 15:20:59 -0400 Subject: [PATCH 09/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/early_stopping.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 3aa1a813f6fbc..80a508bfb74f6 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -59,10 +59,9 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: self.stopped_epoch = 0 mode_dict = { - 'min': (torch.lt, torch_inf, 'min'), - 'max': (torch.gt, -torch_inf, 'max'), - 'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure') - else (torch.lt, torch_inf, 'min'), + 'min': torch.lt, + 'max': torch.gt, + 'auto': torch.gt if 'acc' in self.monitor else torch.lt } if mode not in mode_dict: From 44da135f6bb4188e0c28044e8583a01f51e14038 Mon Sep 17 
00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 15:35:32 -0400 Subject: [PATCH 10/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/early_stopping.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 80a508bfb74f6..06118bc42639a 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -46,8 +46,13 @@ class EarlyStopping(Callback): >>> trainer = Trainer(early_stop_callback=early_stopping) """ - def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0, - verbose: bool = False, mode: str = 'auto', strict: bool = True): + def __init__(self, + monitor: str = 'val_loss', + min_delta: float = 0.0, + patience: int = 3, + verbose: bool = False, + mode: str = 'min', + strict: bool = True): super().__init__() self.monitor = monitor From 2a670a4d292479b03e72c4cd70b18b129171198b Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 15:36:14 -0400 Subject: [PATCH 11/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/early_stopping.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 06118bc42639a..8868663a4ffd7 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -46,13 +46,8 @@ class EarlyStopping(Callback): >>> trainer = Trainer(early_stop_callback=early_stopping) """ - def __init__(self, - monitor: str = 'val_loss', - min_delta: float = 0.0, - patience: int = 3, - verbose: bool = False, - mode: str = 'min', - strict: bool = True): + def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 3, + verbose: bool = False, mode: str = 'auto', strict: bool = True): super().__init__() self.monitor = 
monitor From a2ec7de33ad58f7cdfcd0fb45d20b633c66e2dba Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 15:59:14 -0400 Subject: [PATCH 12/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/early_stopping.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 8868663a4ffd7..4a59b416e0de1 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -106,6 +106,9 @@ def on_epoch_end(self, trainer, pl_module): return stop_training current = logs.get(self.monitor) + if not isinstance(current, torch.Tensor): + import pdb; pdb.set_trace() + current = torch.tensor(current) if self.monitor_op(current - self.min_delta, self.best): self.best = current self.wait = 0 From 90667f7cfc08723a7f75b0b64c2e6f4cad852921 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 16:00:25 -0400 Subject: [PATCH 13/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/callbacks/model_checkpoint.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4a59b416e0de1..2af0521153f88 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -107,8 +107,8 @@ def on_epoch_end(self, trainer, pl_module): current = logs.get(self.monitor) if not isinstance(current, torch.Tensor): - import pdb; pdb.set_trace() current = torch.tensor(current) + if self.monitor_op(current - self.min_delta, self.best): self.best = current self.wait = 0 diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 4914d460202d5..50486623fe437 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ 
b/pytorch_lightning/callbacks/model_checkpoint.py @@ -139,6 +139,9 @@ def check_monitor_top_k(self, current): if less_than_k_models: return True + if not isinstance(current, torch.Tensor): + current = torch.tensor(current) + return self.monitor_op(current, self.best_k_models[self.kth_best_model]) def format_checkpoint_name(self, epoch, metrics, ver=None): From 7801ba4b52cfb5690a27c7fcd29cca1dfedd758c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 16:02:13 -0400 Subject: [PATCH 14/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/model_checkpoint.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 50486623fe437..bfd8f89639a5c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -111,7 +111,8 @@ def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = Fal mode_dict = { 'min': (torch.lt, torch_inf, 'min'), 'max': (torch.gt, -torch_inf, 'max'), - 'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure') + 'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor + or self.monitor.startswith('fmeasure') else (torch.lt, torch_inf, 'min'), } From 958749a6ca56e20043ae84baa524140a3d7e83da Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 16:03:01 -0400 Subject: [PATCH 15/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/model_checkpoint.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index bfd8f89639a5c..feca5058230d4 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -117,7 +117,8 @@ def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: 
bool = Fal } if mode not in mode_dict: - rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, fallback to auto mode.', RuntimeWarning) + rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, ' + f'fallback to auto mode.', RuntimeWarning) mode = 'auto' self.monitor_op, self.kth_value, self.mode = mode_dict[mode] @@ -210,7 +211,8 @@ def on_validation_end(self, trainer, pl_module): current = metrics.get(self.monitor) if current is None: - rank_zero_warn(f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning) + rank_zero_warn(f'Can save best model only with {self.monitor} available, skipping.' + ,RuntimeWarning) elif self.check_monitor_top_k(current): self._do_check_save(filepath, current, epoch) elif self.verbose > 0: From 1267cba5f054f55f86a86da356a365ad4d68c3d7 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 16:04:38 -0400 Subject: [PATCH 16/19] fixed memory leak from opt return --- pytorch_lightning/trainer/evaluation_loop.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 4132a3f70d1aa..862fc5948e806 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -413,6 +413,8 @@ def run_evaluation(self, test_mode: bool = False): # Validation/Test end callbacks if test_mode: self.on_test_end() + else: + self.on_validation_end() def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False): # make dataloader_idx arg in validation_step optional From 3c42fd342a7fc10ebff3df1954c64844d71318e4 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 16:08:05 -0400 Subject: [PATCH 17/19] fixed memory leak from opt return --- pytorch_lightning/callbacks/model_checkpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py 
b/pytorch_lightning/callbacks/model_checkpoint.py index feca5058230d4..4dc73e4bb3205 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -111,8 +111,7 @@ def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = Fal mode_dict = { 'min': (torch.lt, torch_inf, 'min'), 'max': (torch.gt, -torch_inf, 'max'), - 'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor - or self.monitor.startswith('fmeasure') + 'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure') else (torch.lt, torch_inf, 'min'), } @@ -211,8 +210,9 @@ def on_validation_end(self, trainer, pl_module): current = metrics.get(self.monitor) if current is None: - rank_zero_warn(f'Can save best model only with {self.monitor} available, skipping.' - ,RuntimeWarning) + rank_zero_warn( + f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning + ) elif self.check_monitor_top_k(current): self._do_check_save(filepath, current, epoch) elif self.verbose > 0: From 356fae46922f0236e0f8899133717ebda536c143 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 16:16:37 -0400 Subject: [PATCH 18/19] fixed memory leak from opt return --- pytorch_lightning/trainer/training_loop.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5df8744423ba2..538d717afcd63 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -445,8 +445,11 @@ def run_training_epoch(self): # --------------- _outputs = self.run_training_batch(batch, batch_idx) batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs - # detach tensors in batch_output before appending to outputs - outputs.append(batch_output) + + # only track outputs when user implements training_epoch_end + # otherwise we 
will build up unnecessary memory + if self.is_overriden('training_epoch_end', model=self.get_model()): + outputs.append(batch_output) # when returning -1 from train_step, we end epoch early early_stop_epoch = batch_result == -1 @@ -501,9 +504,7 @@ def run_training_epoch(self): break # process epoch outputs - if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)): - model = model.module - + model = self.get_model() if self.is_overriden('training_epoch_end', model=model): epoch_output = model.training_epoch_end(outputs) _processed_outputs = self.process_output(epoch_output) From 24f4c8d62e0b8446c60de05e9d84291372b705f5 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 19 Apr 2020 16:22:46 -0400 Subject: [PATCH 19/19] fixed memory leak from opt return --- tests/base/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/base/utils.py b/tests/base/utils.py index f7b82e60dddbc..f907562a74e2e 100644 --- a/tests/base/utils.py +++ b/tests/base/utils.py @@ -23,7 +23,7 @@ def assert_speed_parity(pl_times, pt_times, num_epochs): # assert speeds - max_diff_per_epoch = 0.9 + max_diff_per_epoch = 0.65 pl_times = np.asarray(pl_times) pt_times = np.asarray(pt_times) diffs = pl_times - pt_times