diff --git a/pl_examples/basic_examples/cpu_template.py b/pl_examples/basic_examples/cpu_template.py
index 1ab195d515379..5929a07be5727 100644
--- a/pl_examples/basic_examples/cpu_template.py
+++ b/pl_examples/basic_examples/cpu_template.py
@@ -28,7 +28,7 @@ def main(hparams):
     # ------------------------
     # 2 INIT TRAINER
     # ------------------------
-    trainer = pl.Trainer(max_epochs=hparams.epochs)
+    trainer = pl.Trainer(max_epochs=hparams.epochs, overfit_pct=0.01, early_stop_callback=True)
 
     # ------------------------
     # 3 START TRAINING
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 70e90a7046595..2af0521153f88 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -7,11 +7,14 @@
 """
 
 import numpy as np
+import torch
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
 
+torch_inf = torch.tensor(np.Inf)
+
 
 class EarlyStopping(Callback):
     r"""
@@ -43,7 +46,7 @@ class EarlyStopping(Callback):
         >>> trainer = Trainer(early_stop_callback=early_stopping)
     """
 
-    def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 0,
+    def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 3,
                  verbose: bool = False, mode: str = 'auto', strict: bool = True):
         super().__init__()
 
@@ -56,9 +59,9 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
         self.stopped_epoch = 0
 
         mode_dict = {
-            'min': np.less,
-            'max': np.greater,
-            'auto': np.greater if 'acc' in self.monitor else np.less
+            'min': torch.lt,
+            'max': torch.gt,
+            'auto': torch.gt if 'acc' in self.monitor else torch.lt
         }
 
         if mode not in mode_dict:
@@ -67,9 +70,14 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
             mode = 'auto'
 
         self.monitor_op = mode_dict[mode]
-        self.min_delta *= 1 if self.monitor_op == np.greater else -1
-
-    def check_metrics(self, logs):
+        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+
+    def _validate_condition_metric(self, logs):
+        """
+        Checks that the condition metric for early stopping is available.
+        :param logs: callback metrics from validation output
+        :return: True if the monitored metric is available
+        """
         monitor_val = logs.get(self.monitor)
         error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
                      f' which is not available. Available metrics are:'
@@ -89,15 +97,18 @@ def on_train_start(self, trainer, pl_module):
         # Allow instances to be re-used
         self.wait = 0
         self.stopped_epoch = 0
-        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
+        self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf
 
     def on_epoch_end(self, trainer, pl_module):
         logs = trainer.callback_metrics
         stop_training = False
-        if not self.check_metrics(logs):
+        if not self._validate_condition_metric(logs):
            return stop_training
 
         current = logs.get(self.monitor)
+        if not isinstance(current, torch.Tensor):
+            current = torch.tensor(current)
+
         if self.monitor_op(current - self.min_delta, self.best):
             self.best = current
             self.wait = 0
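Reviewer note: the switch from `np.less` / `np.greater` to `torch.lt` / `torch.gt` above keeps the existing `min_delta` sign trick, where 'min' mode flips the sign so a single `monitor_op(current - min_delta, best)` call serves both modes. A minimal standalone sketch of that logic (made-up values, not part of the diff):

```python
import torch

# 'min' mode: lower is better, so __init__ flips min_delta to a negative value
monitor_op = torch.lt
min_delta = -0.01  # started as 0.01, multiplied by -1 because monitor_op is torch.lt

best = torch.tensor(float('inf'))
for current in [torch.tensor(0.50), torch.tensor(0.495), torch.tensor(0.40)]:
    # current - (-0.01) < best  <=>  current must beat best by more than 0.01
    if monitor_op(current - min_delta, best):
        best = current
        print(f'improved, new best={best.item():.3f}')
    else:
        print('improvement below min_delta, wait counter would increase')
```

Here 0.495 is rejected because it improves on 0.50 by less than the 0.01 threshold, exactly as in the numpy version.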
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index a5ac54ac018ef..4dc73e4bb3205 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -14,6 +14,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
+import torch
 
 
 class ModelCheckpoint(Callback):
@@ -106,15 +107,17 @@ def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = Fal
         self.best = 0
         self.save_function = None
 
+        torch_inf = torch.tensor(np.Inf)
         mode_dict = {
-            'min': (np.less, np.Inf, 'min'),
-            'max': (np.greater, -np.Inf, 'max'),
-            'auto': (np.greater, -np.Inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
-            else (np.less, np.Inf, 'min'),
+            'min': (torch.lt, torch_inf, 'min'),
+            'max': (torch.gt, -torch_inf, 'max'),
+            'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
+            else (torch.lt, torch_inf, 'min'),
         }
 
         if mode not in mode_dict:
-            rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, fallback to auto mode.', RuntimeWarning)
+            rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, '
+                           f'fallback to auto mode.', RuntimeWarning)
             mode = 'auto'
 
         self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
@@ -136,6 +139,10 @@ def check_monitor_top_k(self, current):
         less_than_k_models = len(self.best_k_models) < self.save_top_k
         if less_than_k_models:
             return True
+
+        if not isinstance(current, torch.Tensor):
+            current = torch.tensor(current)
+
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])
 
     def format_checkpoint_name(self, epoch, metrics, ver=None):
@@ -203,7 +210,9 @@ def on_validation_end(self, trainer, pl_module):
 
         current = metrics.get(self.monitor)
         if current is None:
-            rank_zero_warn(f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning)
+            rank_zero_warn(
+                f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning
+            )
         elif self.check_monitor_top_k(current):
             self._do_check_save(filepath, current, epoch)
         elif self.verbose > 0:
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 4132a3f70d1aa..862fc5948e806 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -413,6 +413,8 @@ def run_evaluation(self, test_mode: bool = False):
         # Validation/Test end callbacks
         if test_mode:
             self.on_test_end()
+        else:
+            self.on_validation_end()
 
     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
         # make dataloader_idx arg in validation_step optional
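The `check_monitor_top_k` change mirrors the early-stopping one: metrics that arrive as plain Python floats are wrapped in a tensor first, because `torch.lt` / `torch.gt` require a tensor as their first argument. A hedged usage sketch of the callback itself (the path and filename template are illustrative, not from this PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# format_checkpoint_name fills the template from the metrics dict
checkpoint_callback = ModelCheckpoint(
    filepath='checkpoints/{epoch}-{val_loss:.2f}',
    monitor='val_loss',
    mode='min',     # torch.lt against the current kth-best value
    save_top_k=3,   # keep the three lowest-val_loss checkpoints
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)
```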
diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index 22d2d42e724f4..dbe05e4aa2818 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -5,6 +5,7 @@
 
 from pytorch_lightning.core import memory
 from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerBase, LoggerCollection
+from pytorch_lightning.utilities import memory_utils
 
 
 class TrainerLoggingMixin(ABC):
@@ -173,10 +174,9 @@ def process_output(self, output, train=False):
 
         callback_metrics.update(progress_bar_metrics)
         callback_metrics.update(log_metrics)
 
-        # convert tensors to numpy
-        for k, v in callback_metrics.items():
-            if isinstance(v, torch.Tensor):
-                callback_metrics[k] = v.item()
+        # detach all metrics for callbacks to prevent memory leaks
+        # no .item() because it will slow things down
+        callback_metrics = memory_utils.recursive_detach(callback_metrics)
 
         return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 0e9bb209c7879..1964af8168e2d 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -891,8 +891,9 @@ def run_pretrain_routine(self, model: LightningModule):
             self.main_progress_bar.close()
             self.val_progress_bar.close()
 
+            # verify that early stop has conditioned on a metric that exists
             if self.enable_early_stop:
-                self.early_stop_callback.check_metrics(callback_metrics)
+                self.early_stop_callback._validate_condition_metric(callback_metrics)
 
         # init progress bar
         pbar = tqdm(leave=True, position=2 * self.process_position,
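The `process_output` change is the core of the memory fix: `.item()` not only drops the autograd graph but also copies the value into a Python float, which forces a device synchronization per metric. `detach()` releases the graph without moving data. A small contrast sketch (not from the PR):

```python
import torch

loss = (torch.randn(4, requires_grad=True) ** 2).mean()

# detach() drops the graph reference but keeps the value on its device
detached = loss.detach()
print(detached.requires_grad)  # False: safe to store between batches

# .item() also drops the graph, but copies to a Python float, which
# synchronizes the GPU when the tensor lives there
value = loss.item()
print(type(value))  # <class 'float'>
```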
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 5a8442204930c..538d717afcd63 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -157,6 +157,7 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities import memory_utils
 
 try:
     from apex import amp
@@ -444,8 +445,11 @@ def run_training_epoch(self):
             # ---------------
             _outputs = self.run_training_batch(batch, batch_idx)
             batch_result, grad_norm_dic, batch_step_metrics, batch_output = _outputs
-            # detach tensors in batch_output before appending to outputs
-            outputs.append(_recursive_detach(batch_output))
+
+            # only track outputs when user implements training_epoch_end
+            # otherwise we will build up unnecessary memory
+            if self.is_overriden('training_epoch_end', model=self.get_model()):
+                outputs.append(batch_output)
 
             # when returning -1 from train_step, we end epoch early
             early_stop_epoch = batch_result == -1
@@ -463,9 +467,14 @@ def run_training_epoch(self):
             should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf'))
             should_check_val = can_check_val and should_check_val
 
+            # ---------------
+            # CHECKPOINTING, EARLY STOPPING
+            # ---------------
             # fast_dev_run always forces val checking after train batch
             if self.fast_dev_run or should_check_val:
                 self.run_evaluation(test_mode=self.testing)
+                self.call_checkpoint_callback()
+                self.call_early_stop_callback()
 
             # when logs should be saved
             should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
@@ -479,16 +488,6 @@ def run_training_epoch(self):
                 # logs user requested information to logger
                 self.log_metrics(batch_step_metrics, grad_norm_dic)
 
-            # ---------------
-            # CHECKPOINTING, EARLY STOPPING
-            # ---------------
-            # save checkpoint even when no test or val step are defined
-            if self.fast_dev_run or should_check_val:
-                self.call_checkpoint_callback()
-
-            if self.enable_early_stop:
-                self.early_stop_callback.check_metrics(self.callback_metrics)
-
             # progress global step according to grads progress
             if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
                 self.global_step += 1
@@ -505,9 +504,7 @@ def run_training_epoch(self):
                 break
 
         # process epoch outputs
-        if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
-            model = model.module
-
+        model = self.get_model()
         if self.is_overriden('training_epoch_end', model=model):
             epoch_output = model.training_epoch_end(outputs)
             _processed_outputs = self.process_output(epoch_output)
@@ -516,12 +513,10 @@ def run_training_epoch(self):
             self.log_metrics(log_epoch_metrics, {})
             self.callback_metrics.update(callback_epoch_metrics)
 
-        # in case validation step is missing and you are not running fast-dev to duplicate last batch
+        # when no val loop is present (or in fast-dev-run) we still need to call the checkpoint callbacks
         if not self.is_overriden('validation_step') and not (self.fast_dev_run or should_check_val):
             self.call_checkpoint_callback()
-
-        if self.enable_early_stop:
-            self.early_stop_callback.check_metrics(self.callback_metrics)
+            self.call_early_stop_callback()
 
         # Epoch end events
         with self.profiler.profile('on_epoch_end'):
@@ -608,7 +603,7 @@ def optimizer_closure():
                     with self.profiler.profile('on_after_backward'):
                         model_ref.on_after_backward()
 
-            return closure_loss, output_dict
+            return closure_loss, callback_metrics
 
         # calculate loss
         loss, batch_output = optimizer_closure()
@@ -800,7 +795,10 @@ def update_learning_rates(self, interval: str):
     def call_checkpoint_callback(self):
         if self.checkpoint_callback is not None:
             self.checkpoint_callback.on_validation_end(self, self.get_model())
-        self.on_validation_end()
+
+    def call_early_stop_callback(self):
+        if self.early_stop_callback:
+            self.early_stop_callback.on_epoch_end(self, self.get_model())
 
 
 def _with_is_last(iterable):
@@ -814,29 +812,3 @@ def _with_is_last(iterable):
         last = val
     # yield last, no longer has next
     yield last, True
-
-
-def _recursive_detach(in_dict):
-    """Detach all tensors in `in_dict`.
-
-    May operate recursively if some of the values in `in_dict` are dictionaries
-    which contain instances of `torch.Tensor`. Other types in `in_dict` are
-    not affected by this utility function.
-
-    Parameters
-    ----------
-    in_dict : dict
-
-    Returns
-    -------
-    out_dict : dict
-    """
-    out_dict = {}
-    for k, v in in_dict.items():
-        if isinstance(v, dict):
-            out_dict.update({k: _recursive_detach(v)})
-        elif callable(getattr(v, 'detach', None)):
-            out_dict.update({k: v.detach()})
-        else:
-            out_dict.update({k: v})
-    return out_dict
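Since per-batch outputs are now accumulated only when `training_epoch_end` is overridden, a module that skips the hook pays no memory cost at all. A hedged sketch of a module that opts in; `MyModule` is hypothetical, uses the 0.7-era dict-returning hooks, and assumes the step's `'loss'` value is carried through to the accumulated outputs:

```python
import torch
import pytorch_lightning as pl


class MyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.mse_loss(self.layer(x), y)}

    # overriding this hook is what makes run_training_epoch keep `outputs`
    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        return {'log': {'train_epoch_loss': avg_loss}}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```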
diff --git a/pytorch_lightning/utilities/memory_utils.py b/pytorch_lightning/utilities/memory_utils.py
new file mode 100644
index 0000000000000..154a2f990fd3d
--- /dev/null
+++ b/pytorch_lightning/utilities/memory_utils.py
@@ -0,0 +1,24 @@
+def recursive_detach(in_dict):
+    """Detach all tensors in `in_dict`.
+
+    May operate recursively if some of the values in `in_dict` are dictionaries
+    which contain instances of `torch.Tensor`. Other types in `in_dict` are
+    not affected by this utility function.
+
+    Parameters
+    ----------
+    in_dict : dict
+
+    Returns
+    -------
+    out_dict : dict
+    """
+    out_dict = {}
+    for k, v in in_dict.items():
+        if isinstance(v, dict):
+            out_dict.update({k: recursive_detach(v)})
+        elif callable(getattr(v, 'detach', None)):
+            out_dict.update({k: v.detach()})
+        else:
+            out_dict.update({k: v})
+    return out_dict
diff --git a/tests/base/utils.py b/tests/base/utils.py
index f7b82e60dddbc..f907562a74e2e 100644
--- a/tests/base/utils.py
+++ b/tests/base/utils.py
@@ -23,7 +23,7 @@ def assert_speed_parity(pl_times, pt_times, num_epochs):
 
     # assert speeds
-    max_diff_per_epoch = 0.9
+    max_diff_per_epoch = 0.65
     pl_times = np.asarray(pl_times)
     pt_times = np.asarray(pt_times)
     diffs = pl_times - pt_times
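Finally, a quick usage sketch for the relocated utility: `recursive_detach` walks nested dicts, detaches anything exposing a `.detach` method, and passes every other value through unchanged.

```python
import torch
from pytorch_lightning.utilities import memory_utils

metrics = {
    'loss': (torch.randn(3, requires_grad=True) ** 2).sum(),
    'nested': {'acc': torch.tensor(0.91, requires_grad=True)},
    'epoch': 5,  # non-tensors are returned as-is
}

detached = memory_utils.recursive_detach(metrics)
print(detached['loss'].requires_grad)           # False
print(detached['nested']['acc'].requires_grad)  # False
print(detached['epoch'])                        # 5
```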