graceful training cleanup after Keyboard Interrupt #856

Merged
Changes from 3 commits
3 changes: 0 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -997,9 +997,6 @@ def run_pretrain_routine(self, model):
# CORE TRAINING LOOP
self.train()

# summarize profile results
self.profiler.describe()

def test(self, model=None):
r"""

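For context, the removed self.profiler.describe() call is not dropped entirely: it moves into the new run_training_teardown() method added to training_loop.py below, so the profiler summary is still produced on every exit path (normal completion, early stopping, or a KeyboardInterrupt). A minimal sketch of the resulting call order, simplified from this diff and not the complete method body:

    # hedged sketch -- simplified from the diff, not the full library code
    def run_pretrain_routine(self, model):
        # ... setup omitted ...
        # CORE TRAINING LOOP
        self.train()  # train() now ends in run_training_teardown(), which closes the
                      # progress bar, finalizes the logger, and calls self.profiler.describe()
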
176 changes: 93 additions & 83 deletions pytorch_lightning/trainer/training_loop.py
@@ -155,6 +155,7 @@ def training_step(self, batch, batch_idx):
import copy
import warnings
from abc import ABC, abstractmethod
import logging as log

import numpy as np

@@ -307,91 +308,86 @@ def process_output(self, output, train):
def train(self):
warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
# get model
model = self.get_model()
# run all epochs
for epoch in range(self.current_epoch, self.max_epochs):
# set seed for distributed sampler (enables shuffling for each epoch)
if (self.use_ddp or self.use_tpu) \
and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
self.get_train_dataloader().sampler.set_epoch(epoch)

# get model
model = self.get_model()

# update training progress in trainer and model
model.current_epoch = epoch
self.current_epoch = epoch

total_val_batches = 0
is_val_epoch = False
if not self.disable_validation:
# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
total_val_batches = self.num_val_batches * val_checks_per_epoch

# total batches includes multiple val checks
self.total_batches = self.num_training_batches + total_val_batches
self.batch_loss_value = 0 # accumulated grads

if self.fast_dev_run:
# limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
num_iterations = 2
elif self.is_iterable_train_dataloader:
# for iterable train loader, the progress bar never ends
num_iterations = None
else:
num_iterations = self.total_batches

# reset progress bar
# .reset() doesn't work on disabled progress bar so we should check
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(num_iterations)
desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
self.main_progress_bar.set_description(desc)

# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_begin()

# -----------------
# RUN TNG EPOCH
# -----------------
self.run_training_epoch()

# update LR schedulers
if self.lr_schedulers is not None:
for lr_scheduler in self.lr_schedulers:
lr_scheduler.step(epoch=self.current_epoch)
if self.reduce_lr_on_plateau_scheduler is not None:
val_loss = self.callback_metrics.get('val_loss')
if val_loss is None:
avail_metrics = ','.join(list(self.callback_metrics.keys()))
m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
f'which is not available. Available metrics are: {avail_metrics}'
raise MisconfigurationException(m)
self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
(met_min_epochs or self.fast_dev_run)):
should_stop = self.early_stop_callback.on_epoch_end()
# stop training
stop = should_stop and met_min_epochs
if stop:
self.main_progress_bar.close()
with self.profiler.profile('on_train_end'):
model.on_train_end()
return

self.main_progress_bar.close()

with self.profiler.profile('on_train_end'):
model.on_train_end()

if self.logger is not None:
self.logger.finalize("success")
try:
# run all epochs
for epoch in range(self.current_epoch, self.max_epochs):
# set seed for distributed sampler (enables shuffling for each epoch)
if (self.use_ddp or self.use_tpu) \
and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
self.get_train_dataloader().sampler.set_epoch(epoch)

# update training progress in trainer and model
model.current_epoch = epoch
self.current_epoch = epoch

total_val_batches = 0
is_val_epoch = False
if not self.disable_validation:
# val can be checked multiple times in epoch
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
val_checks_per_epoch = self.num_training_batches // self.val_check_batch
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
total_val_batches = self.num_val_batches * val_checks_per_epoch

# total batches includes multiple val checks
self.total_batches = self.num_training_batches + total_val_batches
self.batch_loss_value = 0 # accumulated grads

if self.fast_dev_run:
# limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
num_iterations = 2
elif self.is_iterable_train_dataloader:
# for iterable train loader, the progress bar never ends
num_iterations = None
else:
num_iterations = self.total_batches

# reset progress bar
# .reset() doesn't work on disabled progress bar so we should check
if not self.main_progress_bar.disable:
self.main_progress_bar.reset(num_iterations)
desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
self.main_progress_bar.set_description(desc)

# changing gradient according accumulation_scheduler
self.accumulation_scheduler.on_epoch_begin()

# -----------------
# RUN TNG EPOCH
# -----------------
self.run_training_epoch()

# update LR schedulers
if self.lr_schedulers is not None:
for lr_scheduler in self.lr_schedulers:
lr_scheduler.step(epoch=self.current_epoch)
if self.reduce_lr_on_plateau_scheduler is not None:
val_loss = self.callback_metrics.get('val_loss')
if val_loss is None:
avail_metrics = ','.join(list(self.callback_metrics.keys()))
m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
f'which is not available. Available metrics are: {avail_metrics}'
raise MisconfigurationException(m)
self.reduce_lr_on_plateau_scheduler.step(val_loss, epoch=self.current_epoch)

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
if (self.enable_early_stop and not self.disable_validation and is_val_epoch and
(met_min_epochs or self.fast_dev_run)):
should_stop = self.early_stop_callback.on_epoch_end()
# stop training
stop = should_stop and met_min_epochs
if stop:
self.run_training_teardown()
return

self.run_training_teardown()
except KeyboardInterrupt:
log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
self.run_training_teardown()

def run_training_epoch(self):
# before epoch hook
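
The structural change above amounts to wrapping the existing epoch loop in a try/except so that Ctrl-C takes the same cleanup path as a normal run. A condensed, hedged sketch of the new control flow (method bodies elided; _early_stop_requested is a hypothetical placeholder for the early-stopping block shown in the diff):

    # hedged control-flow sketch -- not the actual library code
    import logging as log

    def train(self):
        try:
            for epoch in range(self.current_epoch, self.max_epochs):
                self.run_training_epoch()
                # placeholder for the early-stopping check in the diff above
                if self._early_stop_requested(epoch):
                    self.run_training_teardown()
                    return
            # all epochs finished normally
            self.run_training_teardown()
        except KeyboardInterrupt:
            log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
            self.run_training_teardown()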
@@ -611,6 +607,20 @@ def optimizer_closure():

return 0, grad_norm_dic, all_log_metrics

def run_training_teardown(self):
model = self.get_model()

self.main_progress_bar.close()

with self.profiler.profile('on_train_end'):
model.on_train_end()

if self.logger is not None:
self.logger.finalize("success")

# summarize profile results
self.profiler.describe()

def training_forward(self, batch, batch_idx, opt_idx, hiddens):
"""
Handle forward for each training case (distributed, single gpu, etc...)
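
The net effect for users is that interrupting training with Ctrl-C still closes the progress bar, calls the model's on_train_end hook, finalizes the logger, and prints the profiler summary, instead of leaving them in a half-finished state. A minimal usage sketch, where MyModel stands in for any LightningModule and is not part of this PR:

    # hedged usage sketch -- MyModel is a hypothetical LightningModule
    import pytorch_lightning as pl

    model = MyModel()
    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model)  # pressing Ctrl-C mid-training now runs run_training_teardown()
                        # (progress bar closed, logger finalized, profiler summarized)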