This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

add pos_weight for SigmoidBinaryCrossEntropyLoss (#13612)
* add pos_weight for SigmoidBinaryCrossEntropyLoss in gluon.loss

* Update loss.py

* add test

* set the default value of pos_weight to be 1

* fix unittest

* set N to be a random number

* fix issues

* test without random number

* test with random N

* fix

* fix errors

* fix errors

* fix order

* Update loss.py

* Update loss.py

* fix pylint

* default pos_weight=None

* add broadcast_mul and fix pylint

* fix unittest

* Update loss.py

* Update loss.py

* Update loss.py
eureka7mt authored and wkcn committed Mar 8, 2019
1 parent 8668db7 commit ce9e3cf
Showing 3 changed files with 99 additions and 22 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -217,6 +217,7 @@ List of Contributors
* [Ming Yang](http://ufoym.com)
* [Satya Krishna Gorti](https://github.com/satyakrishnagorti)
* [Neo Chien](https://github.com/cchung100m)
* [Wujie Zhou](https://github.com/eureka7mt)

Label Bot
---------
85 changes: 65 additions & 20 deletions python/mxnet/gluon/loss.py
@@ -30,6 +30,7 @@
from ..base import numeric_types
from .block import HybridBlock


def _apply_weighting(F, loss, weight=None, sample_weight=None):
"""Apply weighting to loss.
@@ -60,10 +61,12 @@ def _apply_weighting(F, loss, weight=None, sample_weight=None):

return loss


def _reshape_like(F, x, y):
"""Reshapes x to the same shape as y."""
return x.reshape(y.shape) if F is ndarray else F.reshape_like(x, y)


class Loss(HybridBlock):
"""Base class for loss.
@@ -74,6 +77,7 @@ class Loss(HybridBlock):
batch_axis : int, default 0
The axis that represents mini-batch.
"""

def __init__(self, weight, batch_axis, **kwargs):
super(Loss, self).__init__(**kwargs)
self._weight = weight
@@ -126,13 +130,14 @@ class L2Loss(Loss):
- **loss**: loss tensor with shape (batch_size,). Dimensions other than
batch_axis are averaged out.
"""

def __init__(self, weight=1., batch_axis=0, **kwargs):
super(L2Loss, self).__init__(weight, batch_axis, **kwargs)

def hybrid_forward(self, F, pred, label, sample_weight=None):
label = _reshape_like(F, label, pred)
loss = F.square(label - pred)
- loss = _apply_weighting(F, loss, self._weight/2, sample_weight)
+ loss = _apply_weighting(F, loss, self._weight / 2, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)


@@ -164,6 +169,7 @@ class L1Loss(Loss):
- **loss**: loss tensor with shape (batch_size,). Dimensions other than
batch_axis are averaged out.
"""

def __init__(self, weight=None, batch_axis=0, **kwargs):
super(L1Loss, self).__init__(weight, batch_axis, **kwargs)

@@ -184,18 +190,22 @@ class SigmoidBinaryCrossEntropyLoss(Loss):
prob = \frac{1}{1 + \exp(-{pred})}
- L = - \sum_i {label}_i * \log({prob}_i) +
+ L = - \sum_i {label}_i * \log({prob}_i) * pos\_weight +
(1 - {label}_i) * \log(1 - {prob}_i)
If `from_sigmoid` is True, this loss computes:
.. math::
- L = - \sum_i {label}_i * \log({pred}_i) +
+ L = - \sum_i {label}_i * \log({pred}_i) * pos\_weight +
(1 - {label}_i) * \log(1 - {pred}_i)
Setting `pos_weight > 1` decreases the false negative count, hence increasing
the recall.
Conversely, setting `pos_weight < 1` decreases the false positive count and
increases the precision.
- `label` and `pred` can have arbitrary shape as long as they have the same
+ `pred` and `label` can have arbitrary shape as long as they have the same
number of elements.
Parameters
@@ -218,25 +228,45 @@ class SigmoidBinaryCrossEntropyLoss(Loss):
to the same shape as pred. For example, if pred has shape (64, 10)
and you want to weigh each sample in the batch separately,
sample_weight should have shape (64, 1).
- **pos_weight**: a weighting tensor of positive examples. Must be a vector with length
equal to the number of classes. For example, if pred has shape (64, 10),
pos_weight should have shape (1, 10).
Outputs:
- **loss**: loss tensor with shape (batch_size,). Dimensions other than
batch_axis are averaged out.
"""

def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs):
- super(SigmoidBinaryCrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
+ super(SigmoidBinaryCrossEntropyLoss, self).__init__(
+     weight, batch_axis, **kwargs)
self._from_sigmoid = from_sigmoid

- def hybrid_forward(self, F, pred, label, sample_weight=None):
+ def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None):
label = _reshape_like(F, label, pred)
if not self._from_sigmoid:
- # We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x)))
- loss = F.relu(pred) - pred * label + F.Activation(-F.abs(pred), act_type='softrelu')
+ if pos_weight is None:
+     # We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x)))
+     loss = F.relu(pred) - pred * label + \
+         F.Activation(-F.abs(pred), act_type='softrelu')
+ else:
+     # We use the stable formula: x - x * z + (1 + z * pos_weight - z) * \
+     #     (log(1 + exp(-abs(x))) + max(-x, 0))
+     log_weight = 1 + F.broadcast_mul(pos_weight - 1, label)
+     loss = pred - pred * label + log_weight * \
+         (F.Activation(-F.abs(pred), act_type='softrelu') + F.relu(-pred))
else:
- loss = -(F.log(pred+1e-12)*label + F.log(1.-pred+1e-12)*(1.-label))
+ eps = 1e-12
+ if pos_weight is None:
+     loss = -(F.log(pred + eps) * label
+              + F.log(1. - pred + eps) * (1. - label))
+ else:
+     loss = -(F.broadcast_mul(F.log(pred + eps) * label, pos_weight)
+              + F.log(1. - pred + eps) * (1. - label))
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)
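Editor's note, not part of the diff: the stable expression in the `pos_weight` branch above can be checked against the weighted cross-entropy in the docstring. Writing x = pred, z = label, w = pos_weight and substituting prob = 1 / (1 + exp(-x)) gives

.. math::
    - w z \log({prob}) - (1 - z) \log(1 - {prob})
    = (1 - z) x + (1 + (w - 1) z) \log(1 + e^{-x})
    = x - x z + (1 + (w - 1) z) (\log(1 + e^{-|x|}) + \max(-x, 0)),

which is exactly `pred - pred * label + log_weight * (softrelu(-abs(pred)) + relu(-pred))` with `log_weight = 1 + (pos_weight - 1) * label`.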


SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss
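For orientation, here is a minimal usage sketch of the new argument (an editor's addition, not part of the commit); it follows the positional call pattern `loss_fn(pred, label, sample_weight, pos_weight)` used in the unit test below, and the tensor values are purely illustrative:

import mxnet as mx
from mxnet import gluon

loss_fn = gluon.loss.SigmoidBinaryCrossEntropyLoss()  # from_sigmoid=False, so pred holds raw logits

pred = mx.nd.random.uniform(-3, 3, shape=(4, 5))               # batch of 4 samples, 5 classes
label = mx.nd.round(mx.nd.random.uniform(0, 1, shape=(4, 5)))  # multi-label targets in {0, 1}
pos_weight = mx.nd.array([[1., 2., 5., 1., 0.5]])              # shape (1, 5): one weight per class

weighted = loss_fn(pred, label, None, pos_weight)   # per-sample loss, shape (4,)
unweighted = loss_fn(pred, label)                   # pos_weight=None keeps the previous behaviour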


@@ -301,9 +331,11 @@ class SoftmaxCrossEntropyLoss(Loss):
- **loss**: loss tensor with shape (batch_size,). Dimensions other than
batch_axis are averaged out.
"""

def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
batch_axis=0, **kwargs):
- super(SoftmaxCrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
+ super(SoftmaxCrossEntropyLoss, self).__init__(
+     weight, batch_axis, **kwargs)
self._axis = axis
self._sparse_label = sparse_label
self._from_logits = from_logits
@@ -315,10 +347,11 @@ def hybrid_forward(self, F, pred, label, sample_weight=None):
loss = -F.pick(pred, label, axis=self._axis, keepdims=True)
else:
label = _reshape_like(F, label, pred)
- loss = -F.sum(pred*label, axis=self._axis, keepdims=True)
+ loss = -F.sum(pred * label, axis=self._axis, keepdims=True)
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)


SoftmaxCELoss = SoftmaxCrossEntropyLoss


@@ -382,6 +415,7 @@ class KLDivLoss(Loss):
`Kullback-Leibler divergence
<https://en.wikipedia.org/wiki/Kullback-Leibler_divergence>`_
"""

def __init__(self, from_logits=True, axis=-1, weight=None, batch_axis=0,
**kwargs):
super(KLDivLoss, self).__init__(weight, batch_axis, **kwargs)
@@ -391,7 +425,7 @@ def __init__(self, from_logits=True, axis=-1, weight=None, batch_axis=0,
def hybrid_forward(self, F, pred, label, sample_weight=None):
if not self._from_logits:
pred = F.log_softmax(pred, self._axis)
- loss = label * (F.log(label+1e-12) - pred)
+ loss = label * (F.log(label + 1e-12) - pred)
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)

@@ -453,11 +487,12 @@ class CTCLoss(Loss):
Sequence Data with Recurrent Neural Networks
<http://www.cs.toronto.edu/~graves/icml_2006.pdf>`_
"""

def __init__(self, layout='NTC', label_layout='NT', weight=None, **kwargs):
assert layout in ['NTC', 'TNC'],\
"Only 'NTC' and 'TNC' layouts for pred are supported. Got: %s"%layout
"Only 'NTC' and 'TNC' layouts for pred are supported. Got: %s" % layout
assert label_layout in ['NT', 'TN'],\
"Only 'NT' and 'TN' layouts for label are supported. Got: %s"%label_layout
"Only 'NT' and 'TN' layouts for label are supported. Got: %s" % label_layout
self._layout = layout
self._label_layout = label_layout
batch_axis = label_layout.find('N')
@@ -512,6 +547,7 @@ class HuberLoss(Loss):
- **loss**: loss tensor with shape (batch_size,). Dimensions other than
batch_axis are averaged out.
"""

def __init__(self, rho=1, weight=None, batch_axis=0, **kwargs):
super(HuberLoss, self).__init__(weight, batch_axis, **kwargs)
self._rho = rho
@@ -520,7 +556,7 @@ def hybrid_forward(self, F, pred, label, sample_weight=None):
label = _reshape_like(F, label, pred)
loss = F.abs(label - pred)
loss = F.where(loss > self._rho, loss - 0.5 * self._rho,
- (0.5/self._rho) * F.square(loss))
+ (0.5 / self._rho) * F.square(loss))
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)

@@ -558,6 +594,7 @@ class HingeLoss(Loss):
- **loss**: loss tensor with shape (batch_size,). Dimensions other than
batch_axis are averaged out.
"""

def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs):
super(HingeLoss, self).__init__(weight, batch_axis, **kwargs)
self._margin = margin
@@ -602,6 +639,7 @@ class SquaredHingeLoss(Loss):
- **loss**: loss tensor with shape (batch_size,). Dimensions other than
batch_axis are averaged out.
"""

def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs):
super(SquaredHingeLoss, self).__init__(weight, batch_axis, **kwargs)
self._margin = margin
@@ -647,6 +685,7 @@ class LogisticLoss(Loss):
- **loss**: loss tensor with shape (batch_size,). Dimensions other than
batch_axis are averaged out.
"""

def __init__(self, weight=None, batch_axis=0, label_format='signed', **kwargs):
super(LogisticLoss, self).__init__(weight, batch_axis, **kwargs)
self._label_format = label_format
@@ -659,7 +698,8 @@ def hybrid_forward(self, F, pred, label, sample_weight=None):
if self._label_format == 'signed':
label = (label + 1.0) / 2.0 # Transform label to be either 0 or 1
# Use a stable formula in computation
- loss = F.relu(pred) - pred * label + F.Activation(-F.abs(pred), act_type='softrelu')
+ loss = F.relu(pred) - pred * label + \
+     F.Activation(-F.abs(pred), act_type='softrelu')
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)

@@ -696,14 +736,15 @@ class TripletLoss(Loss):
Outputs:
- **loss**: loss tensor with shape (batch_size,).
"""

def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs):
super(TripletLoss, self).__init__(weight, batch_axis, **kwargs)
self._margin = margin

def hybrid_forward(self, F, pred, positive, negative):
positive = _reshape_like(F, positive, pred)
negative = _reshape_like(F, negative, pred)
- loss = F.sum(F.square(positive-pred) - F.square(negative-pred),
+ loss = F.sum(F.square(positive - pred) - F.square(negative - pred),
axis=self._batch_axis, exclude=True)
loss = F.relu(loss + self._margin)
return _apply_weighting(F, loss, self._weight, None)
@@ -748,6 +789,7 @@ class PoissonNLLLoss(Loss):
Outputs:
- **loss**: Average loss (shape=(1,1)) of the loss tensor with shape (batch_size,).
"""

def __init__(self, weight=None, from_logits=True, batch_axis=0, compute_full=False, **kwargs):
super(PoissonNLLLoss, self).__init__(weight, batch_axis, **kwargs)
self._from_logits = from_logits
@@ -761,7 +803,8 @@ def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
loss = pred - target * F.log(pred + epsilon)
if self._compute_full:
# Using numpy's pi value
- stirling_factor = target * F.log(target)- target + 0.5 * F.log(2 * target * np.pi)
+ stirling_factor = target * \
+     F.log(target) - target + 0.5 * F.log(2 * target * np.pi)
target_gt_1 = target > 1
stirling_factor *= target_gt_1
loss += stirling_factor
@@ -804,6 +847,7 @@ class CosineEmbeddingLoss(Loss):
Outputs:
- **loss**: The loss tensor with shape (batch_size,).
"""

def __init__(self, weight=None, batch_axis=0, margin=0, **kwargs):
super(CosineEmbeddingLoss, self).__init__(weight, batch_axis, **kwargs)
self._margin = margin
@@ -820,7 +864,8 @@ def hybrid_forward(self, F, input1, input2, label, sample_weight=None):
z_array = F.array([0])
else:
z_array = F.zeros((1, 1))
- cos_sim_b = F.broadcast_maximum(z_array, y_minus_1 * (cos_sim - self._margin), axis=1)
+ cos_sim_b = F.broadcast_maximum(
+     z_array, y_minus_1 * (cos_sim - self._margin), axis=1)
loss = cos_sim_a + cos_sim_b
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return loss
@@ -829,7 +874,7 @@ def _cosine_similarity(self, F, x, y, axis=-1):
# Calculates the cosine similarity between 2 vectors
x_norm = F.norm(x, axis=axis).reshape(-1, 1)
y_norm = F.norm(y, axis=axis).reshape(-1, 1)
- x_dot_y = F.sum(x*y, axis=axis).reshape(-1, 1)
+ x_dot_y = F.sum(x * y, axis=axis).reshape(-1, 1)
if F is ndarray:
eps_arr = F.array([1e-12])
else:
35 changes: 33 additions & 2 deletions tests/python/unittest/test_loss.py
@@ -126,8 +126,8 @@ def test_logistic_loss_equal_bce():
loss_bce = gluon.loss.SigmoidBCELoss(from_sigmoid=False)
data = mx.random.uniform(-10, 10, shape=(N, 1))
label = mx.nd.round(mx.random.uniform(0, 1, shape=(N, 1)))
- assert_almost_equal(loss_binary(data, label).asnumpy(), loss_bce(data, label).asnumpy())
- assert_almost_equal(loss_signed(data, 2 * label - 1).asnumpy(), loss_bce(data, label).asnumpy())
+ assert_almost_equal(loss_binary(data, label).asnumpy(), loss_bce(data, label).asnumpy(), atol=1e-6)
+ assert_almost_equal(loss_signed(data, 2 * label - 1).asnumpy(), loss_bce(data, label).asnumpy(), atol=1e-6)

@with_seed()
def test_kl_loss():
@@ -421,6 +421,37 @@ def test_poisson_nllloss_mod():
optimizer='adam')
assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05

@with_seed()
def test_bce_loss_with_pos_weight():
# Suppose it's a multi-label classification
N = np.random.randint(5, 30)
data = mx.nd.random.uniform(-1, 1, shape=(N, 20))
label = mx.nd.array(np.random.randint(2, size=(N, 5)), dtype='float32')
pos_weight = mx.nd.random.uniform(0, 10, shape=(1, 5))
pos_weight = mx.nd.repeat(pos_weight, repeats=N, axis=0)
data_iter = mx.io.NDArrayIter(data, {'label': label, 'pos_w': pos_weight}, batch_size=10, label_name='label')
output = get_net(5)
l = mx.symbol.Variable('label')
pos_w = mx.symbol.Variable('pos_w')
Loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
loss = Loss(output, l, None, pos_w)
loss = mx.sym.make_loss(loss)
mod = mx.mod.Module(loss, data_names=('data',), label_names=('label', 'pos_w'))
mod.fit(data_iter, num_epoch=200, optimizer_params={'learning_rate': 0.01},
eval_metric=mx.metric.Loss(), optimizer='adam',
initializer=mx.init.Xavier(magnitude=2))
assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.01
# Test against npy
data = mx.nd.random.uniform(-5, 5, shape=(N, 5))
label = mx.nd.array(np.random.randint(2, size=(N, 5)), dtype='float32')
pos_weight = mx.nd.random.uniform(0, 10, shape=(1, 5))
mx_bce_loss = Loss(data, label, None, pos_weight).asnumpy()
prob_npy = 1.0 / (1.0 + np.exp(-data.asnumpy()))
label_npy = label.asnumpy()
pos_weight_npy = pos_weight.asnumpy()
npy_bce_loss = (- label_npy * np.log(prob_npy)*pos_weight_npy - (1 - label_npy) * np.log(1 - prob_npy)).mean(axis=1)
assert_almost_equal(mx_bce_loss, npy_bce_loss, rtol=1e-4, atol=1e-5)
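As a further editor's sketch (not part of the commit), the stable formula used in `hybrid_forward` can be checked against the naive weighted BCE that the assertion above computes, using plain NumPy; the helper names are illustrative:

import numpy as np

def naive_weighted_bce(x, z, w):
    # -(w * z * log(p) + (1 - z) * log(1 - p)) with p = sigmoid(x)
    p = 1.0 / (1.0 + np.exp(-x))
    return -(w * z * np.log(p) + (1 - z) * np.log(1 - p))

def stable_weighted_bce(x, z, w):
    # x - x * z + (1 + (w - 1) * z) * (log(1 + exp(-|x|)) + max(-x, 0)), as in loss.py
    log_weight = 1 + (w - 1) * z
    return x - x * z + log_weight * (np.log1p(np.exp(-np.abs(x))) + np.maximum(-x, 0))

x = np.random.uniform(-5, 5, size=(8, 5))
z = np.random.randint(2, size=(8, 5)).astype('float64')
w = np.random.uniform(0.5, 10.0, size=(1, 5))
assert np.allclose(naive_weighted_bce(x, z, w), stable_weighted_bce(x, z, w))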


if __name__ == '__main__':
import nose
