Skip to content
This repository has been archived by the owner on Mar 22, 2021. It is now read-only.

Commit

Permalink
fixed lovasz loss, added helpers for loss weighting (#14) (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakubczakon authored Sep 6, 2018
1 parent c9795d8 commit 6da841a
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 50 deletions.
3 changes: 1 addition & 2 deletions augmentations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@
"aug_imgs = []\n",
"for _ in range(AUG_NR):\n",
" aug_img = heng_seq.augment_image(img)\n",
" aug_imgs.append(aug_img)\n",
"plot_list(images=aug_imgs)"
" aug_imgs.append(aug_img)"
]
},
{
Expand Down
62 changes: 42 additions & 20 deletions common_blocks/lovash_losses.py → common_blocks/lovasz_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np

try:
from itertools import ifilterfalse
except ImportError: # py3k
from itertools import filterfalse
from itertools import ifilterfalse
except ImportError: # py3k
from itertools import filterfalse

from .utils import pytorch_where


def lovasz_grad(gt_sorted):
Expand All @@ -22,10 +25,10 @@ def lovasz_grad(gt_sorted):
"""
p = len(gt_sorted)
gts = gt_sorted.sum()
intersection = gts - gt_sorted.float().cumsum(0)
union = gts + (1 - gt_sorted).float().cumsum(0)
intersection = gts.float() - gt_sorted.float().cumsum(0)
union = gts.float() + (1 - gt_sorted).float().cumsum(0)
jaccard = 1. - intersection / union
if p > 1: # cover 1-pixel case
if p > 1: # cover 1-pixel case
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
return jaccard

Expand All @@ -46,7 +49,7 @@ def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
else:
iou = float(intersection) / union
ious.append(iou)
iou = mean(ious) # mean accross images if per_image
iou = mean(ious) # mean accross images if per_image
return 100 * iou


Expand All @@ -60,15 +63,15 @@ def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
for pred, label in zip(preds, labels):
iou = []
for i in range(C):
if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
intersection = ((label == i) & (pred == i)).sum()
union = ((label == i) | ((pred == i) & (label != ignore))).sum()
if not union:
iou.append(EMPTY)
else:
iou.append(float(intersection) / union)
ious.append(iou)
ious = map(mean, zip(*ious)) # mean accross images if per_image
ious = map(mean, zip(*ious)) # mean accross images if per_image
return 100 * np.array(ious)


Expand All @@ -85,7 +88,7 @@ def lovasz_hinge(logits, labels, per_image=True, ignore=None):
"""
if per_image:
loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
for log, lab in zip(logits, labels))
for log, lab in zip(logits, labels))
else:
loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
return loss
Expand All @@ -102,15 +105,31 @@ def lovasz_hinge_flat(logits, labels):
# only void pixels, the gradients should be 0
return logits.sum() * 0.
signs = 2. * labels.float() - 1.
errors = (1. - logits * Variable(signs))
errors = (1. - logits * signs)

errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
perm = perm.data
gt_sorted = labels[perm]
grad = lovasz_grad(gt_sorted)
loss = torch.dot(F.relu(errors_sorted), Variable(grad))
loss = torch.dot(F.elu(errors_sorted), grad)
return loss


def weigh_errors_with_size(labels, errors):
    """Scale the errors of positive pixels by the inverse positive fraction.

    Every position where ``labels`` is 1 has its error multiplied by
    ``errors.size()[0] / labels.sum()``; positions where ``labels`` is 0
    keep a weight of 1.0.

    Args:
        labels: flat tensor of 0/1 ground-truth labels.
        errors: flat tensor of per-pixel errors, same length as ``labels``.

    Returns:
        ``errors`` unchanged when ``labels`` has no positives, otherwise the
        re-weighted errors.
    """
    # ``.cpu()`` is a no-op for CPU tensors, so one path serves both devices.
    # Flatten before indexing: ``sum()`` yields a 1-element tensor on old
    # PyTorch but a 0-dim tensor on >= 0.4, where a bare ``[0]`` raises.
    size = float(labels.sum().data.cpu().numpy().reshape(-1)[0])

    if size == 0:
        # No positive pixels: nothing to up-weight (and avoids dividing by 0).
        return errors

    size_weight = 1. / (size / errors.size()[0])
    size_weights = pytorch_where(labels, size_weight, 1.0)
    return errors * size_weights



def flatten_binary_scores(scores, labels, ignore=None):
"""
Flattens predictions in the batch (binary case)
Expand All @@ -128,11 +147,12 @@ def flatten_binary_scores(scores, labels, ignore=None):

class StableBCELoss(torch.nn.modules.Module):
    """Binary cross-entropy evaluated directly on raw scores.

    Computes ``max(x, 0) - x * t + log(1 + exp(-|x|))`` element-wise and
    averages it. Only ``exp`` of a non-positive number is ever taken, so the
    expression cannot overflow.
    """

    def __init__(self):
        super(StableBCELoss, self).__init__()

    def forward(self, input, target):
        """Return the mean loss over all elements.

        Args:
            input: tensor of raw scores ``x``.
            target: tensor of targets ``t`` that multiplies ``input``
                element-wise.

        Returns:
            The mean of the element-wise loss.
        """
        neg_abs = -input.abs()
        # log1p keeps the tiny exp(-|x|) term that ``(1 + ...).log()`` would
        # round away to exactly 0 for large |x|.
        loss = input.clamp(min=0) - input * target + neg_abs.exp().log1p()
        return loss.mean()


def binary_xloss(logits, labels, ignore=None):
Expand Down Expand Up @@ -160,8 +180,9 @@ def lovasz_softmax(probas, labels, only_present=False, per_image=False, ignore=N
ignore: void class labels
"""
if per_image:
loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present)
for prob, lab in zip(probas, labels))
loss = mean(
lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), only_present=only_present)
for prob, lab in zip(probas, labels))
else:
loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), only_present=only_present)
return loss
Expand All @@ -177,7 +198,7 @@ def lovasz_softmax_flat(probas, labels, only_present=False):
C = probas.size(1)
losses = []
for c in range(C):
fg = (labels == c).float() # foreground for class c
fg = (labels == c).float() # foreground for class c
if only_present and fg.sum() == 0:
continue

Expand All @@ -203,6 +224,7 @@ def flatten_probas(probas, labels, ignore=None):
vlabels = labels[valid]
return vprobas, vlabels


def xloss(logits, labels, ignore=None):
"""
Cross entropy loss
Expand Down Expand Up @@ -230,4 +252,4 @@ def mean(l, ignore_nan=False, empty=0):
acc += v
if n == 1:
return acc
return acc / n
return acc / n
31 changes: 19 additions & 12 deletions common_blocks/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from functools import partial
from toolkit.pytorch_transformers.models import Model

from .utils import sigmoid, softmax, get_list_of_image_predictions
from .utils import sigmoid, softmax, get_list_of_image_predictions, pytorch_where
from . import callbacks as cbk
from .unet_models import UNetResNet, SaltUNet, SaltLinkNet
from .lovash_losses import lovasz_softmax
from .lovasz_losses import lovasz_hinge

PRETRAINED_NETWORKS = {'ResNet34': {'model': UNetResNet,
'model_config': {'encoder_depth': 34,
Expand Down Expand Up @@ -162,15 +162,22 @@ def set_model(self):

def set_loss(self):
if self.activation_func == 'softmax':
loss_function = lovash_loss
elif self.activation_func == 'sigmoid':
loss_function = partial(mixed_dice_bce_loss,
loss_function = partial(mixed_dice_cross_entropy_loss,
dice_loss=multiclass_dice_loss,
bce_loss=nn.BCEWithLogitsLoss(),
dice_activation='sigmoid',
cross_entropy_loss=nn.CrossEntropyLoss(),
dice_activation='softmax',
dice_weight=self.architecture_config['model_params']['dice_weight'],
bce_weight=self.architecture_config['model_params']['bce_weight']
cross_entropy_weight=self.architecture_config['model_params']['bce_weight']
)
elif self.activation_func == 'sigmoid':
loss_function = lovasz_loss
# loss_function = partial(mixed_dice_bce_loss,
# dice_loss=multiclass_dice_loss,
# bce_loss=nn.BCEWithLogitsLoss(),
# dice_activation='sigmoid',
# dice_weight=self.architecture_config['model_params']['dice_weight'],
# bce_weight=self.architecture_config['model_params']['bce_weight']
# )
else:
raise Exception('Only softmax and sigmoid activations are allowed')
self.loss_function = [('mask', loss_function, 1.0)]
Expand Down Expand Up @@ -223,12 +230,12 @@ def __init__(self, smooth=0, eps=1e-7):

def forward(self, output, target):
return 1 - (2 * torch.sum(output * target) + self.smooth) / (
torch.sum(output) + torch.sum(target) + self.smooth + self.eps)
torch.sum(output) + torch.sum(target) + self.smooth + self.eps)


def lovash_loss(output, target):
target = target[:, 1, :, :].long()
return lovasz_softmax(output, target)
def lovasz_loss(output, target):
    """Cast ``target`` to long and return ``lovasz_hinge(output, target)``."""
    return lovasz_hinge(output, target.long())


def mixed_dice_bce_loss(output, target, dice_weight=0.2, dice_loss=None,
Expand Down
5 changes: 5 additions & 0 deletions common_blocks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,3 +467,8 @@ def _cached_fit_transform(self, step_inputs):
.format(self.name, self.exp_dir_outputs_step))
self._persist_output(step_output_data, self.exp_dir_outputs_step)
return step_output_data


def pytorch_where(cond, x_1, x_2):
    """Arithmetic element-wise select between ``x_1`` and ``x_2``.

    ``cond`` is cast to float and used as a blending mask, so the result is
    ``cond * x_1 + (1 - cond) * x_2``: ``x_1`` where ``cond`` is 1 and
    ``x_2`` where it is 0.
    """
    mask = cond.float()
    taken_when_true = mask * x_1
    taken_when_false = (1 - mask) * x_2
    return taken_when_true + taken_when_false
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

EXPERIMENT_DIR = '/output/experiment'
CLONE_EXPERIMENT_DIR_FROM = '' # When running eval in the cloud specify this as for example /input/SAL-14/output/experiment
OVERWRITE_EXPERIMENT_DIR = False
OVERWRITE_EXPERIMENT_DIR = True
DEV_MODE = False

if OVERWRITE_EXPERIMENT_DIR and os.path.isdir(EXPERIMENT_DIR):
Expand Down Expand Up @@ -769,3 +769,4 @@ def save_predictions(train_ids, train_predictions, meta_test, out_of_fold_test_p
if __name__ == '__main__':
prepare_metadata()
train_evaluate_predict_cv()

42 changes: 27 additions & 15 deletions result_exploration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"METADATA_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/files/metadata.csv'\n",
"\n",
"OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH = 'YOUR/validation_results.pkl'\n",
"OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/kuba/experiments/sal_907_cv_812_lb_820/out_of_fold_train_predictions.pkl'"
"OUT_OF_FOLD_TRAIN_RESULTS_FILEPATH = '/mnt/ml-team/minerva/open-solutions/salt/kuba/experiments/sal_986_cv_821_lb_827/out_of_fold_train_predictions.pkl'"
]
},
{
Expand Down Expand Up @@ -82,7 +82,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"@ipy.interact(idx = ipy.IntSlider(min=0,max=4000,value=0,step=1))\n",
Expand All @@ -101,27 +103,37 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Problems\n",
"# Problems idx 1-200\n",
"\n",
"1. Symmetric predictions near boundaries (reflection padding ?)\n",
"\n",
" **idx**: 3, 10, 26, 41, 42, 43, 45, 47, 48, 122, 125, 126, 128, 137\n",
" \n",
"2. Unfilled stuff near boundaries\n",
" 1. Border masks are a bit to small\n",
"\n",
" **idx**: 2, 4, 23, 27\n",
" **idx** 32, 117\n",
" \n",
"3. Disconnected components\n",
" 2. Whole image is salt misspredicted\n",
"\n",
" **idx**: 3\n",
" **idx** 74\n",
" \n",
"4. Weird round predictions\n",
" 3. Border problems\n",
"\n",
" **idx**: 6\n",
" **idx** 39\n",
" \n",
"5. Unexplained stuff close to boundaries\n",
"\n",
" **idx** 32"
" 4. One pixel predicted\n",
" \n",
" **idx** \n",
" \n",
" 5. Model Fails\n",
" \n",
" **idx** 60, 105, 114, 121, 176, 191, 196\n",
" \n",
" 6. Weak Prediction\n",
" \n",
" **idx** 25, 63, 68, 109, 139, 140, 161, 190\n",
" \n",
" \n",
"## IS THAT TRUE:\n",
" \n",
" **idx** 81"
]
},
{
Expand Down

0 comments on commit 6da841a

Please sign in to comment.