Graceful shutdown on Python interpreter exit #1631

Merged · 12 commits · May 29, 2020
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -22,7 +22,7 @@ references:
command: |
python --version ; pip --version ; pip list
py.test pytorch_lightning tests -v --doctest-modules --junitxml=test-reports/pytest_junit.xml
no_output_timeout: 30m
no_output_timeout: 15m

examples: &examples
run:
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Run graceful training teardown on interpreter exit ([#1631](https://github.com/PyTorchLightning/pytorch-lightning/pull/1631))

- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))

- Fixed an issue with `Trainer.from_argparse_args` when passing in unknown Trainer args ([#1932](https://github.com/PyTorchLightning/pytorch-lightning/pull/1932))
30 changes: 26 additions & 4 deletions pytorch_lightning/trainer/training_loop.py
@@ -141,21 +141,23 @@ def training_step(self, batch, batch_idx):

"""

import atexit
import signal
from abc import ABC, abstractmethod
from typing import Callable
from typing import Union, List

import numpy as np
from torch.utils.data import DataLoader
import torch
from torch.utils.data import DataLoader

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

try:
from apex import amp
@@ -179,9 +181,11 @@ def training_step(self, batch, batch_idx):
else:
HOROVOD_AVAILABLE = True

# constant listing the signals that should be caught for a graceful trainer shutdown
SIGNAL_TERMINATE = ('SIGTERM', 'SIGSEGV', 'SIGINT')

class TrainerTrainLoopMixin(ABC):

class TrainerTrainLoopMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
max_epochs: int
@@ -300,6 +304,15 @@ def has_arg(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

def train(self):
# add signal handlers for process kills
def _signal_kill_handler(*args):
Contributor: won't this interfere with the HPC auto-save signal?

Member Author: I don't know, but I think it shouldn't, because this will only be triggered on Python interpreter exit (with SIGTERM); see the chaining sketch after this hunk.

return TrainerTrainLoopMixin.run_training_teardown(self)

orig_signal_handlers = {}
for sig_name in SIGNAL_TERMINATE:
orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name),
_signal_kill_handler)

# get model
model = self.get_model()
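The question above about the HPC auto-save signal comes down to whether installing these handlers clobbers one that was registered earlier. A minimal sketch of a chaining variant, using assumed helper names (`install_teardown_handlers`, `run_teardown`) rather than the code merged in this PR, saves the previous handler and invokes it after the teardown:

```python
import signal

# Same signal names as in the diff above.
SIGNAL_TERMINATE = ('SIGTERM', 'SIGSEGV', 'SIGINT')


def install_teardown_handlers(run_teardown):
    """Install handlers that run the teardown, then defer to the previous handler."""
    orig_handlers = {}
    for sig_name in SIGNAL_TERMINATE:
        signum = getattr(signal, sig_name)
        orig_handlers[sig_name] = signal.getsignal(signum)

        def _handler(received_signum, frame, _previous=orig_handlers[sig_name]):
            run_teardown()
            # Chain to whatever handler was installed before (for example a
            # cluster auto-resubmit hook) instead of silently replacing it.
            if callable(_previous):
                _previous(received_signum, frame)

        signal.signal(signum, _handler)
    return orig_handlers
```

Returning the original handlers still allows the reset step shown further down; the `callable` check skips `SIG_DFL`, `SIG_IGN`, and `None`, which are not plain Python callables.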

@@ -371,6 +384,10 @@ def train(self):

self.run_training_teardown()

# reset signal handlers
for sig_name in SIGNAL_TERMINATE:
signal.signal(getattr(signal, sig_name), orig_signal_handlers[sig_name])

except KeyboardInterrupt:
if self.proc_rank == 0:
log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
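Restoring the handlers only on the successful path means an exception other than `KeyboardInterrupt` could leave them installed. A hypothetical alternative, not part of this PR, is to wrap the install/restore pair in a context manager so the restore runs in a `finally` block:

```python
import signal
from contextlib import contextmanager

SIGNAL_TERMINATE = ('SIGTERM', 'SIGSEGV', 'SIGINT')


@contextmanager
def trap_signals(handler):
    """Install `handler` for the signals above and restore the originals on exit."""
    originals = {name: signal.getsignal(getattr(signal, name)) for name in SIGNAL_TERMINATE}
    for name in SIGNAL_TERMINATE:
        signal.signal(getattr(signal, name), handler)
    try:
        yield
    finally:
        for name, orig in originals.items():
            # getsignal() can return None for handlers not set from Python;
            # skip those instead of passing None back to signal.signal().
            if orig is not None:
                signal.signal(getattr(signal, name), orig)
```

Used as `with trap_signals(my_handler): ...` around the training loop, this removes the need for a separate reset step on the happy path.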
@@ -405,7 +422,7 @@ def run_training_epoch(self):

# run epoch
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
enumerate(_with_is_last(train_dataloader)), "get_train_batch"
enumerate(_with_is_last(train_dataloader)), "get_train_batch"
):
# stop epoch if we limited the number of training batches
if batch_idx >= self.num_training_batches:
@@ -663,7 +680,10 @@ def _get_optimizers_iterable(self):
opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
return [(opt_idx, self.optimizers[opt_idx])]

@atexit.register
def run_training_teardown(self):
if hasattr(self, '_teardown_already_run') and self._teardown_already_run:
return
# Train end events
with self.profiler.profile('on_train_end'):
# callbacks
@@ -678,6 +698,8 @@ def run_training_teardown(self):
# summarize profile results
self.profiler.describe()

self._teardown_already_run = True

def training_forward(self, batch, batch_idx, opt_idx, hiddens):
"""
Handle forward for each training case (distributed, single gpu, etc...)
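The `_teardown_already_run` flag is what makes it safe to reach `run_training_teardown` from several paths: normal loop completion, a caught signal, or interpreter exit. A reduced, runnable sketch of that idempotent-teardown idea, assuming a toy `TrainLoop` class and registering the bound method at construction time rather than decorating it as in the diff:

```python
import atexit


class TrainLoop:
    """Toy stand-in for TrainerTrainLoopMixin, reduced to the teardown logic."""

    def __init__(self):
        self._teardown_already_run = False
        # Register the *bound* method so atexit can call it without arguments
        # when the interpreter exits.
        atexit.register(self.run_training_teardown)

    def run_training_teardown(self):
        if self._teardown_already_run:
            return
        # Stand-in for the real work: on_train_end callbacks, profiler summary, ...
        print('running training teardown')
        self._teardown_already_run = True


loop = TrainLoop()
loop.run_training_teardown()  # normal end of training: does the work once
loop.run_training_teardown()  # later call from a signal handler or atexit: no-op
```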