
bce loss (#7304)
szha authored and piiswrong committed Aug 14, 2017
1 parent 0142ea0 commit 568b5a2
Showing 2 changed files with 84 additions and 15 deletions.
69 changes: 54 additions & 15 deletions python/mxnet/gluon/loss.py
@@ -20,7 +20,7 @@
""" losses for training neural networks """
from __future__ import absolute_import

-from .. import symbol, ndarray
+from .. import ndarray
from ..base import numeric_types
from .block import HybridBlock

@@ -54,6 +54,11 @@ def _apply_weighting(F, loss, weight=None, sample_weight=None):

    return loss

+def _reshape_label_as_output(F, output, label):
+    # for symbolic output.shape is not available so we reshape
+    # to empty shape and let it be inferred from output's shape
+    # via the '-' operator later.
+    return label.reshape(output.shape) if F is ndarray else label.reshape(())

class Loss(HybridBlock):
"""Base class for loss.
@@ -113,13 +118,8 @@ def __init__(self, weight=1., batch_axis=0, **kwargs):
        super(L2Loss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, output, label, sample_weight=None):
-        if F is ndarray:
-            loss = ndarray.square(output - label.reshape(output.shape))
-        else:
-            # for symbolic output.shape is not available so we reshape
-            # to empty shape and let it be inferred from output's shape
-            # via the '-' operator later.
-            loss = symbol.square(output - label.reshape(()))
+        label = _reshape_label_as_output(F, output, label)
+        loss = F.square(output - label)
        loss = _apply_weighting(F, loss, self._weight/2, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

@@ -148,19 +148,56 @@ def __init__(self, weight=None, batch_axis=0, **kwargs):
        super(L1Loss, self).__init__(weight, batch_axis, **kwargs)

    def hybrid_forward(self, F, output, label, sample_weight=None):
-        if F is ndarray:
-            loss = ndarray.abs(output - label.reshape(output.shape))
-        else:
-            # for symbolic output.shape is not available so we reshape
-            # to empty shape and let it be inferred from output's shape
-            # via the '-' operator later.
-            loss = symbol.abs(output - label.reshape(()))
+        label = _reshape_label_as_output(F, output, label)
+        loss = F.abs(output - label)
        loss = _apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)


+class SigmoidBinaryCrossEntropyLoss(Loss):
+    r"""The cross-entropy loss for binary classification. (alias: SigmoidBCELoss)
+
+    BCE loss is useful when training logistic regression.
+
+    .. math::
+        loss(o, t) = - \frac{1}{n} \sum_i (t[i] \cdot \log(o[i]) + (1 - t[i]) \cdot \log(1 - o[i]))
+
+    Parameters
+    ----------
+    from_sigmoid : bool, default is `False`
+        Whether the input is the output of a sigmoid. Setting this to `False` makes
+        the loss compute the sigmoid and the BCE together, which is more numerically
+        stable through the log-sum-exp trick.
+    weight : float or None
+        Global scalar weight for loss.
+    sample_weight : Symbol or None
+        Per sample weighting. Must be broadcastable to
+        the same shape as loss. For example, if loss has
+        shape (64, 10) and you want to weight each sample
+        in the batch, `sample_weight` should have shape (64, 1).
+    batch_axis : int, default 0
+        The axis that represents mini-batch.
+    """
+    def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs):
+        super(SigmoidBinaryCrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
+        self._from_sigmoid = from_sigmoid
+
+    def hybrid_forward(self, F, output, label, sample_weight=None):
+        label = _reshape_label_as_output(F, output, label)
+        if not self._from_sigmoid:
+            # apply the sigmoid and the BCE in one fused, numerically stable expression
+            max_val = F.maximum(-output, 0)
+            loss = output - output*label + max_val + F.log(F.exp(-max_val)+F.exp(-output-max_val))
+        else:
+            # inputs are already probabilities; the 1e-8 epsilon guards against log(0)
+            loss = -(F.log(output+1e-8)*label + F.log(1.-output+1e-8)*(1.-label))
+        loss = _apply_weighting(F, loss, self._weight, sample_weight)
+        return F.mean(loss, axis=self._batch_axis, exclude=True)
+
+SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss


class SoftmaxCrossEntropyLoss(Loss):
"""Computes the softmax cross entropy loss.
"""Computes the softmax cross entropy loss. (alias: SoftmaxCELoss)
If `sparse_label` is `True`, label should contain integer category indicators:
@@ -216,6 +253,8 @@ def hybrid_forward(self, F, output, label, sample_weight=None):
        loss = _apply_weighting(F, loss, self._weight, sample_weight)
        return F.mean(loss, axis=self._batch_axis, exclude=True)

+SoftmaxCELoss = SoftmaxCrossEntropyLoss


class KLDivLoss(Loss):
"""The Kullback-Leibler divergence loss.
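Note on the fused branch in SigmoidBinaryCrossEntropyLoss: because log(sigmoid(x)) = -log(1 + e^-x) and log(1 - sigmoid(x)) = -x - log(1 + e^-x), the per-element loss -(t*log(sigmoid(x)) + (1 - t)*log(1 - sigmoid(x))) simplifies to x - x*t + log(1 + e^-x), and the max(-x, 0) term rewrites that logarithm so it cannot overflow. A small NumPy check of this identity (not part of the commit; the function names below are illustrative only):

import numpy as np

def bce_stable(x, t):
    # same form as the hybrid_forward branch: x - x*t + max(-x, 0) + log(exp(-max) + exp(-x - max))
    m = np.maximum(-x, 0)
    return x - x * t + m + np.log(np.exp(-m) + np.exp(-x - m))

def bce_naive(x, t):
    # direct form: -(t*log(sigmoid(x)) + (1 - t)*log(1 - sigmoid(x)))
    p = 1.0 / (1.0 + np.exp(-x))
    return -(t * np.log(p) + (1 - t) * np.log(1 - p))

x = np.array([-2.0, 0.0, 3.0])
t = np.array([0.0, 1.0, 1.0])
print(np.allclose(bce_stable(x, t), bce_naive(x, t)))   # True: both forms agree on moderate logits

x_big = np.array([800.0, -800.0])
t_big = np.array([0.0, 1.0])
print(bce_stable(x_big, t_big))   # [800. 800.], still finite
print(bce_naive(x_big, t_big))    # inf values from exp/log overflow, which the fused form avoids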
30 changes: 30 additions & 0 deletions tests/python/unittest/test_loss.py
@@ -18,6 +18,7 @@
import mxnet as mx
import numpy as np
from mxnet import gluon
+from mxnet.test_utils import assert_almost_equal


def test_loss_ndarray():
@@ -81,6 +82,34 @@ def test_ce_loss():
    assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.01


+def test_bce_loss():
+    mx.random.seed(1234)
+    np.random.seed(1234)
+    N = 20
+    data = mx.random.uniform(-1, 1, shape=(N, 20))
+    label = mx.nd.array(np.random.randint(2, size=(N,)), dtype='float32')
+    data_iter = mx.io.NDArrayIter(data, label, batch_size=10, label_name='label')
+    output = get_net(1)
+    fc2 = output.get_internals()['fc2_output']
+    l = mx.symbol.Variable('label')
+    Loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
+    loss = Loss(output, l)
+    loss = mx.sym.make_loss(loss)
+    mod = mx.mod.Module(loss, data_names=('data',), label_names=('label',))
+    mod.fit(data_iter, num_epoch=200, optimizer_params={'learning_rate': 1.},
+            eval_metric=mx.metric.Loss())
+    assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.01
+
+def test_bce_equal_ce2():
+    N = 100
+    loss1 = gluon.loss.SigmoidBCELoss(from_sigmoid=True)
+    loss2 = gluon.loss.SoftmaxCELoss(from_logits=True)
+    out1 = mx.random.uniform(0, 1, shape=(N, 1))
+    out2 = mx.nd.log(mx.nd.concat(1-out1, out1, dim=1) + 1e-8)
+    label = mx.nd.round(mx.random.uniform(0, 1, shape=(N, 1)))
+    assert_almost_equal(loss1(out1, label).asnumpy(), loss2(out2, label).asnumpy())


def test_kl_loss():
mx.random.seed(1234)
np.random.seed(1234)
@@ -117,6 +146,7 @@ def test_l2_loss():
eval_metric=mx.metric.Loss())
assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05


def test_l1_loss():
mx.random.seed(1234)
np.random.seed(1234)
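For reference, a minimal imperative usage sketch of the new loss (not part of the commit; it assumes a Gluon build that includes this change, and the example values are made up):

import mxnet as mx
from mxnet import gluon

# default from_sigmoid=False: inputs are raw scores, the sigmoid is fused into the loss
loss_fn = gluon.loss.SigmoidBinaryCrossEntropyLoss()
logits = mx.nd.array([[1.5], [-0.3], [2.2], [-4.0]])
labels = mx.nd.array([[1.0], [0.0], [1.0], [0.0]])
print(loss_fn(logits, labels))   # one loss value per sample (mean over non-batch axes)

# with from_sigmoid=True the inputs must already be probabilities in (0, 1)
loss_fn_prob = gluon.loss.SigmoidBCELoss(from_sigmoid=True)
print(loss_fn_prob(mx.nd.sigmoid(logits), labels))   # matches the fused result up to the 1e-8 epsilon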
