Skip to content
This repository has been archived by the owner on Mar 22, 2021. It is now read-only.

added k-fold validation and averaging, added saving oof predictions, … #73

Merged
merged 2 commits into from
Sep 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 81 additions & 206 deletions common_blocks/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

import numpy as np
import torch
from PIL import Image
from deepsense import neptune
import neptune
from torch.autograd import Variable
from torch.optim.lr_scheduler import ExponentialLR
from tempfile import TemporaryDirectory
Expand All @@ -25,6 +24,7 @@
ORIGINAL_SIZE = (101, 101)
THRESHOLD = 0.5


class Callback:
def __init__(self):
self.epoch_id = None
Expand Down Expand Up @@ -159,63 +159,6 @@ def on_batch_end(self, metrics, *args, **kwargs):
self.batch_id += 1


class ValidationMonitor(Callback):
    """Evaluates the model on the validation set on a fixed epoch schedule.

    Each named loss returned by ``get_validation_loss()`` is logged once per
    triggered epoch.
    """

    def __init__(self, epoch_every=None, batch_every=None):
        super().__init__()
        # A frequency of 0 disables the trigger entirely (stored as False);
        # None falls through unchanged and is likewise falsy in the checks below.
        self.epoch_every = False if epoch_every == 0 else epoch_every
        self.batch_every = False if batch_every == 0 else batch_every

    def on_epoch_end(self, *args, **kwargs):
        """Run validation when scheduled, log each loss, then advance the epoch id."""
        scheduled = self.epoch_every and self.epoch_id % self.epoch_every == 0
        if scheduled:
            self.model.eval()
            validation_losses = self.get_validation_loss()
            self.model.train()
            for loss_name, loss_tensor in validation_losses.items():
                # old-torch idiom: pull the scalar out of a 1-element tensor
                loss_value = loss_tensor.data.cpu().numpy()[0]
                logger.info('epoch {0} validation {1}: {2:.5f}'.format(self.epoch_id, loss_name, loss_value))
        self.epoch_id += 1


class EarlyStopping(Callback):
    """Breaks the training loop once the validation loss stops improving.

    Parameters
    ----------
    patience : int
        Number of consecutive non-improving epochs tolerated before the
        training break flag is raised.
    minimize : bool
        If True a lower ``val_loss['sum']`` is better; otherwise higher is better.
    """

    def __init__(self, patience, minimize=True):
        super().__init__()
        self.patience = patience
        self.minimize = minimize
        self.best_score = None
        self.epoch_since_best = 0
        self._training_break = False

    def on_epoch_end(self, *args, **kwargs):
        """Compare this epoch's summed validation loss against the best seen so far."""
        self.model.eval()
        val_loss = self.get_validation_loss()
        loss_sum = val_loss['sum']
        # old-torch idiom: extract the scalar from a 1-element tensor
        loss_sum = loss_sum.data.cpu().numpy()[0]

        self.model.train()

        # BUG FIX: the original `if not self.best_score:` also fires when the
        # best score is exactly 0.0 (falsy), silently resetting it every epoch.
        # `is None` matches the ModelCheckpoint callback's handling.
        if self.best_score is None:
            self.best_score = loss_sum

        if (self.minimize and loss_sum < self.best_score) or (not self.minimize and loss_sum > self.best_score):
            self.best_score = loss_sum
            self.epoch_since_best = 0
        else:
            self.epoch_since_best += 1

        if self.epoch_since_best > self.patience:
            self._training_break = True
        self.epoch_id += 1

    def training_break(self, *args, **kwargs):
        """Return True once patience has been exhausted."""
        return self._training_break


class ExponentialLRScheduler(Callback):
def __init__(self, gamma, epoch_every=1, batch_every=None):
super().__init__()
Expand Down Expand Up @@ -256,50 +199,63 @@ def on_batch_end(self, *args, **kwargs):
self.batch_id += 1


class ModelCheckpoint(Callback):
def __init__(self, filepath, epoch_every=1, minimize=True):
class ExperimentTiming(Callback):
def __init__(self, epoch_every=None, batch_every=None):
super().__init__()
self.filepath = filepath
self.minimize = minimize
self.best_score = None

if epoch_every == 0:
self.epoch_every = False
else:
self.epoch_every = epoch_every
if batch_every == 0:
self.batch_every = False
else:
self.batch_every = batch_every
self.batch_start = None
self.epoch_start = None
self.current_sum = None
self.current_mean = None

def on_train_begin(self, *args, **kwargs):
self.epoch_id = 0
self.batch_id = 0
os.makedirs(os.path.dirname(self.filepath), exist_ok=True)

def on_epoch_end(self, *args, **kwargs):
if self.epoch_every and ((self.epoch_id % self.epoch_every) == 0):
self.model.eval()
val_loss = self.get_validation_loss()
loss_sum = val_loss['sum']
loss_sum = loss_sum.data.cpu().numpy()[0]

self.model.train()
logger.info('starting training...')

if self.best_score is None:
self.best_score = loss_sum
def on_train_end(self, *args, **kwargs):
logger.info('training finished')

if (self.minimize and loss_sum < self.best_score) or (not self.minimize and loss_sum > self.best_score) or (
self.epoch_id == 0):
self.best_score = loss_sum
save_model(self.model, self.filepath)
logger.info('epoch {0} model saved to {1}'.format(self.epoch_id, self.filepath))
def on_epoch_begin(self, *args, **kwargs):
if self.epoch_id > 0:
epoch_time = datetime.now() - self.epoch_start
if self.epoch_every:
if (self.epoch_id % self.epoch_every) == 0:
logger.info('epoch {0} time {1}'.format(self.epoch_id - 1, str(epoch_time)[:-7]))
self.epoch_start = datetime.now()
self.current_sum = timedelta()
self.current_mean = timedelta()
logger.info('epoch {0} ...'.format(self.epoch_id))

self.epoch_id += 1
def on_batch_begin(self, *args, **kwargs):
if self.batch_id > 0:
current_delta = datetime.now() - self.batch_start
self.current_sum += current_delta
self.current_mean = self.current_sum / self.batch_id
if self.batch_every:
if self.batch_id > 0 and (((self.batch_id - 1) % self.batch_every) == 0):
logger.info('epoch {0} average batch time: {1}'.format(self.epoch_id, str(self.current_mean)[:-5]))
if self.batch_every:
if self.batch_id == 0 or self.batch_id % self.batch_every == 0:
logger.info('epoch {0} batch {1} ...'.format(self.epoch_id, self.batch_id))
self.batch_start = datetime.now()


class NeptuneMonitor(Callback):
def __init__(self, model_name):
def __init__(self, image_nr, image_resize, model_name):
super().__init__()
self.model_name = model_name
self.ctx = neptune.Context()
self.epoch_loss_averager = Averager()
self.image_nr = image_nr
self.image_resize = image_resize

def on_train_begin(self, *args, **kwargs):
self.epoch_loss_averagers = {}
Expand Down Expand Up @@ -338,8 +294,8 @@ def _send_numeric_channels(self, *args, **kwargs):
self.ctx.channel_send('{} epoch_val {} loss'.format(self.model_name, name), x=self.epoch_id, y=loss)


class ExperimentTiming(Callback):
def __init__(self, epoch_every=None, batch_every=None):
class ValidationMonitor(Callback):
def __init__(self, data_dir, loader_mode, epoch_every=None, batch_every=None):
super().__init__()
if epoch_every == 0:
self.epoch_every = False
Expand All @@ -349,119 +305,7 @@ def __init__(self, epoch_every=None, batch_every=None):
self.batch_every = False
else:
self.batch_every = batch_every
self.batch_start = None
self.epoch_start = None
self.current_sum = None
self.current_mean = None

def on_train_begin(self, *args, **kwargs):
    """Zero the epoch/batch counters and announce that training has started."""
    self.batch_id = 0
    self.epoch_id = 0
    logger.info('starting training...')

def on_train_end(self, *args, **kwargs):
    """Log that the training loop has completed."""
    logger.info('training finished')

def on_epoch_begin(self, *args, **kwargs):
    """Log the previous epoch's duration (on schedule) and reset per-epoch timers."""
    if self.epoch_id > 0:
        elapsed = datetime.now() - self.epoch_start
        # str(timedelta) ends with 6 microsecond digits plus the dot;
        # [:-7] trims the string down to whole seconds for the log line.
        if self.epoch_every and self.epoch_id % self.epoch_every == 0:
            logger.info('epoch {0} time {1}'.format(self.epoch_id - 1, str(elapsed)[:-7]))
    self.epoch_start = datetime.now()
    self.current_sum = timedelta()
    self.current_mean = timedelta()
    logger.info('epoch {0} ...'.format(self.epoch_id))

def on_batch_begin(self, *args, **kwargs):
    """Accumulate running batch timings and log them on the configured schedule."""
    if self.batch_id > 0:
        self.current_sum += datetime.now() - self.batch_start
        self.current_mean = self.current_sum / self.batch_id
        # (the original re-checked `self.batch_id > 0` here, which is already
        # guaranteed by the enclosing branch)
        if self.batch_every and (self.batch_id - 1) % self.batch_every == 0:
            logger.info('epoch {0} average batch time: {1}'.format(self.epoch_id, str(self.current_mean)[:-5]))
    if self.batch_every and (self.batch_id == 0 or self.batch_id % self.batch_every == 0):
        logger.info('epoch {0} batch {1} ...'.format(self.epoch_id, self.batch_id))
    self.batch_start = datetime.now()


class ReduceLROnPlateau(Callback):  # thank you keras
    """Placeholder callback — the LR-on-plateau logic has not been implemented yet."""

    def __init__(self):
        super().__init__()


class NeptuneMonitorSegmentation(NeptuneMonitor):
    """NeptuneMonitor specialization that can additionally send prediction-mask
    images to Neptune channels (image sending is currently disabled below)."""

    def __init__(self, image_nr, image_resize, model_name):
        # image_nr: index of the last image sent per mask batch (loop breaks at it)
        # image_resize: scale factor applied to the glued image before sending
        super().__init__(model_name)
        self.image_nr = image_nr
        self.image_resize = image_resize

    def on_epoch_end(self, *args, **kwargs):
        """Send numeric loss channels each epoch; image channels are switched off."""
        self._send_numeric_channels()
        # self._send_image_channels()
        self.epoch_id += 1

    def _send_image_channels(self):
        """Glue prediction and ground-truth masks side by side and send each pair
        to a Neptune image channel named '<model_name> <mask name>'."""
        self.model.eval()
        pred_masks = self.get_prediction_masks()
        self.model.train()

        for name, pred_mask in pred_masks.items():
            for i, image_duplet in enumerate(pred_mask):
                h, w = image_duplet.shape[1:]
                # two masks placed side by side with a 10px black separator
                image_glued = np.zeros((h, 2 * w + 10))

                image_glued[:, :w] = image_duplet[0, :, :]
                image_glued[:, (w + 10):] = image_duplet[1, :, :]

                # masks are in [0, 1]; scale to 8-bit grayscale for PIL
                pill_image = Image.fromarray((image_glued * 255.).astype(np.uint8))
                h_, w_ = image_glued.shape
                pill_image = pill_image.resize((int(self.image_resize * w_), int(self.image_resize * h_)),
                                               Image.ANTIALIAS)

                self.ctx.channel_send('{} {}'.format(self.model_name, name), neptune.Image(
                    name='epoch{}_batch{}_idx{}'.format(self.epoch_id, self.batch_id, i),
                    description="true and prediction masks",
                    data=pill_image))

                # only send up to image_nr images per channel
                if i == self.image_nr:
                    break

    def get_prediction_masks(self):
        """Run the model on the FIRST validation batch only (note the trailing
        `break`) and return {output name: array of stacked (prediction, target)
        pairs}, with sigmoid applied to the raw model outputs."""
        prediction_masks = {}
        batch_gen, steps = self.validation_datagen
        for batch_id, data in enumerate(batch_gen):
            # expect one input tensor followed by one target per named output
            if len(data) != len(self.output_names) + 1:
                raise ValueError('incorrect targets provided')
            X = data[0]
            targets_tensors = data[1:]

            if torch.cuda.is_available():
                X = Variable(X).cuda()
            else:
                X = Variable(X)

            outputs_batch = self.model(X)
            # multi-head model: one output tensor per named output;
            # otherwise a single tensor is reused for every name
            if len(outputs_batch) == len(self.output_names):
                for name, output, target in zip(self.output_names, outputs_batch, targets_tensors):
                    prediction = sigmoid(np.squeeze(output.data.cpu().numpy(), axis=1))
                    ground_truth = np.squeeze(target.cpu().numpy(), axis=1)
                    prediction_masks[name] = np.stack([prediction, ground_truth], axis=1)
            else:
                for name, target in zip(self.output_names, targets_tensors):
                    prediction = sigmoid(np.squeeze(outputs_batch.data.cpu().numpy(), axis=1))
                    ground_truth = np.squeeze(target.cpu().numpy(), axis=1)
                    prediction_masks[name] = np.stack([prediction, ground_truth], axis=1)
            break
        return prediction_masks


class ValidationMonitorSegmentation(ValidationMonitor):
def __init__(self, data_dir, loader_mode, *args, **kwargs):
super().__init__(*args, **kwargs)
self.data_dir = data_dir
self.validation_pipeline = postprocessing_pipeline_simplified
self.loader_mode = loader_mode
Expand All @@ -483,6 +327,16 @@ def set_params(self, transformer, validation_datagen, meta_valid=None, *args, **
def get_validation_loss(self):
    """Delegate to the internal pipeline-based loss computation."""
    return self._get_validation_loss()

def on_epoch_end(self, *args, **kwargs):
    """Run validation when the epoch schedule triggers, log every component loss,
    then advance the epoch counter."""
    due = self.epoch_every and self.epoch_id % self.epoch_every == 0
    if due:
        self.model.eval()
        validation_losses = self.get_validation_loss()
        self.model.train()
        for loss_name, loss_tensor in validation_losses.items():
            # old-torch idiom: pull the scalar out of a 1-element tensor
            loss_value = loss_tensor.data.cpu().numpy()[0]
            logger.info('epoch {0} validation {1}: {2:.5f}'.format(self.epoch_id, loss_name, loss_value))
    self.epoch_id += 1

def _get_validation_loss(self):
output, epoch_loss = self._transform()
y_pred = self._generate_prediction(output)
Expand Down Expand Up @@ -565,11 +419,24 @@ def _generate_prediction(self, outputs):
return y_pred


class ModelCheckpointSegmentation(ModelCheckpoint):
def __init__(self, metric_name='sum', *args, **kwargs):
super().__init__(*args, **kwargs)
class ModelCheckpoint(Callback):
def __init__(self, filepath, metric_name='sum', epoch_every=1, minimize=True):
self.filepath = filepath
self.minimize = minimize
self.best_score = None

if epoch_every == 0:
self.epoch_every = False
else:
self.epoch_every = epoch_every

self.metric_name = metric_name

def on_train_begin(self, *args, **kwargs):
    """Make sure the checkpoint directory exists and zero the counters."""
    os.makedirs(os.path.dirname(self.filepath), exist_ok=True)
    self.batch_id = 0
    self.epoch_id = 0

def on_epoch_end(self, *args, **kwargs):
if self.epoch_every and ((self.epoch_id % self.epoch_every) == 0):
self.model.eval()
Expand All @@ -583,19 +450,27 @@ def on_epoch_end(self, *args, **kwargs):
self.best_score = loss_sum

if (self.minimize and loss_sum < self.best_score) or (not self.minimize and loss_sum > self.best_score) or (
self.epoch_id == 0):
self.epoch_id == 0):
self.best_score = loss_sum
persist_torch_model(self.model, self.filepath)
logger.info('epoch {0} model saved to {1}'.format(self.epoch_id, self.filepath))

self.epoch_id += 1


class EarlyStoppingSegmentation(EarlyStopping):
def __init__(self, metric_name='sum', *args, **kwargs):
super().__init__(*args, **kwargs)
class EarlyStopping(Callback):
def __init__(self, metric_name='sum', patience=1000, minimize=True):
super().__init__()
self.patience = patience
self.minimize = minimize
self.best_score = None
self.epoch_since_best = 0
self._training_break = False
self.metric_name = metric_name

def training_break(self, *args, **kwargs):
    """Return True once the early-stopping patience has been exhausted."""
    return self._training_break

def on_epoch_end(self, *args, **kwargs):
self.model.eval()
val_loss = self.get_validation_loss()
Expand All @@ -619,7 +494,7 @@ def on_epoch_end(self, *args, **kwargs):


def postprocessing_pipeline_simplified(cache_dirpath, loader_mode):
if loader_mode == 'crop_and_pad':
if loader_mode == 'resize_and_pad':
size_adjustment_function = partial(crop_image, target_size=ORIGINAL_SIZE)
elif loader_mode == 'resize':
size_adjustment_function = partial(resize_image, target_size=ORIGINAL_SIZE)
Expand Down
Loading