diff --git a/.circleci/config.yml b/.circleci/config.yml
index 2237e39423bb0..1cd6ac7a4d27a 100755
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -22,7 +22,7 @@ references:
       command: |
         python --version ; pip --version ; pip list
         py.test pytorch_lightning tests -v --doctest-modules --junitxml=test-reports/pytest_junit.xml
-      no_output_timeout: 30m
+      no_output_timeout: 15m
 
   examples: &examples
     run:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index b6280a1beeab6..31b829c61c7a2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -36,6 +36,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Run graceful training teardown on interpreter exit ([#1631](https://github.com/PyTorchLightning/pytorch-lightning/pull/1631))
+
 - Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
 
 - Fixed an issue with `Trainer.from_argparse_args` when passing in unknown Trainer args ([#1932](https://github.com/PyTorchLightning/pytorch-lightning/pull/1932))
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 55c63679ae9f6..bd09d8252a5cc 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -141,21 +141,23 @@ def training_step(self, batch, batch_idx):
 
 """
 
+import atexit
+import signal
 from abc import ABC, abstractmethod
 from typing import Callable
 from typing import Union, List
 
 import numpy as np
-from torch.utils.data import DataLoader
 import torch
+from torch.utils.data import DataLoader
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 try:
     from apex import amp
@@ -179,9 +181,11 @@ def training_step(self, batch, batch_idx):
 else:
     HOROVOD_AVAILABLE = True
 
+# constant defining which signals should be caught for graceful trainer shutdown
+SIGNAL_TERMINATE = ('SIGTERM', 'SIGSEGV', 'SIGINT')
 
 
-class TrainerTrainLoopMixin(ABC):
+class TrainerTrainLoopMixin(ABC):
     # this is just a summary on variables used in this abstract class,
     #  the proper values/initialisation should be done in child class
     max_epochs: int
@@ -300,6 +304,15 @@ def has_arg(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
     def train(self):
+        # add signal handlers for process kills
+        def _signal_kill_handler(*args):
+            return TrainerTrainLoopMixin.run_training_teardown(self)
+
+        orig_signal_handlers = {}
+        for sig_name in SIGNAL_TERMINATE:
+            orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name),
+                                                           _signal_kill_handler)
+
         # get model
         model = self.get_model()
 
@@ -371,6 +384,10 @@ def train(self):
 
             self.run_training_teardown()
 
+            # reset signal handlers
+            for sig_name in SIGNAL_TERMINATE:
+                signal.signal(getattr(signal, sig_name), orig_signal_handlers[sig_name])
+
         except KeyboardInterrupt:
             if self.proc_rank == 0:
                 log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
@@ -405,7 +422,7 @@ def run_training_epoch(self):
         # run epoch
         for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
-                enumerate(_with_is_last(train_dataloader)), "get_train_batch"
+            enumerate(_with_is_last(train_dataloader)), "get_train_batch"
         ):
             # stop epoch if we limited the number of training batches
             if batch_idx >= self.num_training_batches:
                 break
@@ -663,7 +680,10 @@ def _get_optimizers_iterable(self):
             opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
             return [(opt_idx, self.optimizers[opt_idx])]
 
+    @atexit.register
     def run_training_teardown(self):
+        if hasattr(self, '_teardown_already_run') and self._teardown_already_run:
+            return
         # Train end events
         with self.profiler.profile('on_train_end'):
             # callbacks
@@ -678,6 +698,8 @@ def run_training_teardown(self):
             # summarize profile results
             self.profiler.describe()
 
+        self._teardown_already_run = True
+
     def training_forward(self, batch, batch_idx, opt_idx, hiddens):
         """
         Handle forward for each training case (distributed, single gpu, etc...)
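
Note on the shutdown mechanism introduced here: using `@atexit.register` as a bare decorator on a method registers the plain, unbound function, so at interpreter exit Python calls it without a `self` argument. Below is a minimal, self-contained sketch of the same signal-plus-atexit teardown pattern; `GracefulTrainer` and its members are hypothetical illustration names, and the atexit hook is registered per instance as a bound method, which sidesteps the missing-`self` issue:

    import atexit
    import signal

    # mirrors SIGNAL_TERMINATE from the patch above
    SIGNAL_TERMINATE = ('SIGTERM', 'SIGSEGV', 'SIGINT')


    class GracefulTrainer:
        def __init__(self):
            self._teardown_already_run = False
            # register the bound method so atexit can call it without arguments
            atexit.register(self.run_training_teardown)

        def train(self):
            # install handlers, keeping the previous ones so they can be restored
            orig_signal_handlers = {}
            for sig_name in SIGNAL_TERMINATE:
                orig_signal_handlers[sig_name] = signal.signal(
                    getattr(signal, sig_name),
                    lambda signum, frame: self.run_training_teardown())
            try:
                pass  # the actual training loop would run here
            finally:
                # tear down once, then hand the signals back to their old owners
                self.run_training_teardown()
                for sig_name, handler in orig_signal_handlers.items():
                    signal.signal(getattr(signal, sig_name), handler)

        def run_training_teardown(self):
            # idempotent: may fire from a signal, from train(), and again at exit
            if self._teardown_already_run:
                return
            self._teardown_already_run = True
            print('on_train_end callbacks, logger shutdown, profiler summary ...')

With this sketch, `kill -TERM <pid>` during `train()` and a plain `sys.exit()` afterwards both run the teardown exactly once. One caveat carried over from the patch: catching SIGSEGV from Python is best-effort at most, since the interpreter may already be in an inconsistent state when that signal arrives.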