Skip to content

Commit

Permalink
Handle KeyboardInterrupt during training (#2134)
Browse files Browse the repository at this point in the history
* Handle KeyboardInterrupt during training

Fixes #2079.

* chlog

* Fix whitespace

* Update callback_hook.py

* Update base.py

* Update training_loop.py

* Update test_trainer.py

* Update CHANGELOG.md

Co-authored-by: Adrian Wälchli <[email protected]>

* Update CHANGELOG.md

* on_keyboard_interrupt

Co-authored-by: Jirka <[email protected]>
Co-authored-by: William Falcon <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
5 people authored Jun 15, 2020
1 parent bd3a1f7 commit fd1693e
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 5 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))
- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877))
- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327))
- Added Native torch metrics ([#1488](https://github.com/PyTorchLightning/pytorch-lightning/pull/1488))
Expand All @@ -35,6 +34,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
- Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
- Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/issues/1667))
- Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134))

### Changed

Expand All @@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default value of the Trainer argument `weights_summary` from `full` to `top` ([#2029](https://github.com/PyTorchLightning/pytorch-lightning/pull/2029))
- Raise an error when lightning replaces an existing sampler ([#2020](https://github.com/PyTorchLightning/pytorch-lightning/pull/2020))
- Enabled prepare_data from correct processes - clarify local vs global rank ([#2166](https://github.com/PyTorchLightning/pytorch-lightning/pull/2166))
- Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126))

### Deprecated

Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,6 @@ def on_test_start(self, trainer, pl_module):
def on_test_end(self, trainer, pl_module):
"""Called when the test ends."""
pass

def on_keyboard_interrupt(self, trainer, pl_module):
"""Called when the training is interrupted by KeyboardInterrupt."""
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,8 @@ def on_test_end(self):
"""Called when the test ends."""
for callback in self.callbacks:
callback.on_test_end(self, self.get_model())

def on_keyboard_interrupt(self):
"""Called when the training is interrupted by KeyboardInterrupt."""
for callback in self.callbacks:
callback.on_keyboard_interrupt(self, self.get_model())
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ class TrainerTrainLoopMixin(ABC):
checkpoint_callback: ...
terminate_on_nan: bool
tpu_id: int
interactive_ddp_procs: ...

# Callback system
callbacks: List[Callback]
Expand All @@ -247,6 +248,7 @@ class TrainerTrainLoopMixin(ABC):
on_epoch_start: Callable
on_epoch_end: Callable
on_validation_end: Callable
on_keyboard_interrupt: Callable

@abstractmethod
def get_model(self) -> LightningModule:
Expand Down Expand Up @@ -395,6 +397,7 @@ def train(self):
# user could press ctrl+c many times... only shutdown once
if not self.interrupted:
self.interrupted = True
self.on_keyboard_interrupt()

for proc in self.interactive_ddp_procs:
subprocess.Popen.kill(proc)
Expand Down
20 changes: 16 additions & 4 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,18 @@
import os
import pickle
import types
import sys
from argparse import Namespace

import cloudpickle
import pytest
import torch

import tests.base.utils as tutils
from pytorch_lightning import Callback, LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.core.saving import (
load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv)
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.io import load as pl_load
Expand Down Expand Up @@ -660,10 +661,19 @@ def __init__(self):
def on_batch_start(self, trainer, pl_module):
raise KeyboardInterrupt

class HandleInterruptCallback(Callback):
def __init__(self):
super().__init__()
self.exc_info = None

def on_keyboard_interrupt(self, trainer, pl_module):
self.exc_info = sys.exc_info()

interrupt_callback = InterruptCallback()
handle_interrupt_callback = HandleInterruptCallback()

trainer = Trainer(
callbacks=[interrupt_callback],
callbacks=[interrupt_callback, handle_interrupt_callback],
max_epochs=1,
val_percent_check=0.1,
train_percent_check=0.2,
Expand All @@ -672,8 +682,10 @@ def on_batch_start(self, trainer, pl_module):
default_root_dir=tmpdir,
)
assert not trainer.interrupted
assert handle_interrupt_callback.exc_info is None
trainer.fit(model)
assert trainer.interrupted
assert isinstance(handle_interrupt_callback.exc_info[1], KeyboardInterrupt)


def test_gradient_clipping(tmpdir):
Expand Down

0 comments on commit fd1693e

Please sign in to comment.