diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 4e2718970e99..12ee720a0ac3 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -217,6 +217,7 @@ List of Contributors * [Ming Yang](http://ufoym.com) * [Satya Krishna Gorti](https://github.com/satyakrishnagorti) * [Neo Chien](https://github.com/cchung100m) +* [Wujie Zhou](https://github.com/eureka7mt) Label Bot --------- diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py index 29d0105ae8dd..e6d4c5bab852 100644 --- a/python/mxnet/gluon/loss.py +++ b/python/mxnet/gluon/loss.py @@ -30,6 +30,7 @@ from ..base import numeric_types from .block import HybridBlock + def _apply_weighting(F, loss, weight=None, sample_weight=None): """Apply weighting to loss. @@ -60,10 +61,12 @@ def _apply_weighting(F, loss, weight=None, sample_weight=None): return loss + def _reshape_like(F, x, y): """Reshapes x to the same shape as y.""" return x.reshape(y.shape) if F is ndarray else F.reshape_like(x, y) + class Loss(HybridBlock): """Base class for loss. @@ -74,6 +77,7 @@ class Loss(HybridBlock): batch_axis : int, default 0 The axis that represents mini-batch. """ + def __init__(self, weight, batch_axis, **kwargs): super(Loss, self).__init__(**kwargs) self._weight = weight @@ -126,13 +130,14 @@ class L2Loss(Loss): - **loss**: loss tensor with shape (batch_size,). Dimenions other than batch_axis are averaged out. """ + def __init__(self, weight=1., batch_axis=0, **kwargs): super(L2Loss, self).__init__(weight, batch_axis, **kwargs) def hybrid_forward(self, F, pred, label, sample_weight=None): label = _reshape_like(F, label, pred) loss = F.square(label - pred) - loss = _apply_weighting(F, loss, self._weight/2, sample_weight) + loss = _apply_weighting(F, loss, self._weight / 2, sample_weight) return F.mean(loss, axis=self._batch_axis, exclude=True) @@ -164,6 +169,7 @@ class L1Loss(Loss): - **loss**: loss tensor with shape (batch_size,). Dimenions other than batch_axis are averaged out. """ + def __init__(self, weight=None, batch_axis=0, **kwargs): super(L1Loss, self).__init__(weight, batch_axis, **kwargs) @@ -184,18 +190,22 @@ class SigmoidBinaryCrossEntropyLoss(Loss): prob = \frac{1}{1 + \exp(-{pred})} - L = - \sum_i {label}_i * \log({prob}_i) + + L = - \sum_i {label}_i * \log({prob}_i) * pos\_weight + (1 - {label}_i) * \log(1 - {prob}_i) If `from_sigmoid` is True, this loss computes: .. math:: - L = - \sum_i {label}_i * \log({pred}_i) + + L = - \sum_i {label}_i * \log({pred}_i) * pos\_weight + (1 - {label}_i) * \log(1 - {pred}_i) + A tensor `pos_weight > 1` decreases the false negative count, hence increasing + the recall. + Conversely setting `pos_weight < 1` decreases the false positive count and + increases the precision. - `label` and `pred` can have arbitrary shape as long as they have the same + `pred` and `label` can have arbitrary shape as long as they have the same number of elements. Parameters @@ -218,25 +228,45 @@ class SigmoidBinaryCrossEntropyLoss(Loss): to the same shape as pred. For example, if pred has shape (64, 10) and you want to weigh each sample in the batch separately, sample_weight should have shape (64, 1). + - **pos_weight**: a weighting tensor of positive examples. Must be a vector with length + equal to the number of classes.For example, if pred has shape (64, 10), + pos_weight should have shape (1, 10). Outputs: - **loss**: loss tensor with shape (batch_size,). Dimenions other than batch_axis are averaged out. """ + def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs): - super(SigmoidBinaryCrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs) + super(SigmoidBinaryCrossEntropyLoss, self).__init__( + weight, batch_axis, **kwargs) self._from_sigmoid = from_sigmoid - def hybrid_forward(self, F, pred, label, sample_weight=None): + def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None): label = _reshape_like(F, label, pred) if not self._from_sigmoid: - # We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x))) - loss = F.relu(pred) - pred * label + F.Activation(-F.abs(pred), act_type='softrelu') + if pos_weight is None: + # We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x))) + loss = F.relu(pred) - pred * label + \ + F.Activation(-F.abs(pred), act_type='softrelu') + else: + # We use the stable formula: x - x * z + (1 + z * pos_weight - z) * \ + # (log(1 + exp(-abs(x))) + max(-x, 0)) + log_weight = 1 + F.broadcast_mul(pos_weight - 1, label) + loss = pred - pred * label + log_weight * \ + (F.Activation(-F.abs(pred), act_type='softrelu') + F.relu(-pred)) else: - loss = -(F.log(pred+1e-12)*label + F.log(1.-pred+1e-12)*(1.-label)) + eps = 1e-12 + if pos_weight is None: + loss = -(F.log(pred + eps) * label + + F.log(1. - pred + eps) * (1. - label)) + else: + loss = -(F.broadcast_mul(F.log(pred + eps) * label, pos_weight) + + F.log(1. - pred + eps) * (1. - label)) loss = _apply_weighting(F, loss, self._weight, sample_weight) return F.mean(loss, axis=self._batch_axis, exclude=True) + SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss @@ -301,9 +331,11 @@ class SoftmaxCrossEntropyLoss(Loss): - **loss**: loss tensor with shape (batch_size,). Dimenions other than batch_axis are averaged out. """ + def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None, batch_axis=0, **kwargs): - super(SoftmaxCrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs) + super(SoftmaxCrossEntropyLoss, self).__init__( + weight, batch_axis, **kwargs) self._axis = axis self._sparse_label = sparse_label self._from_logits = from_logits @@ -315,10 +347,11 @@ def hybrid_forward(self, F, pred, label, sample_weight=None): loss = -F.pick(pred, label, axis=self._axis, keepdims=True) else: label = _reshape_like(F, label, pred) - loss = -F.sum(pred*label, axis=self._axis, keepdims=True) + loss = -F.sum(pred * label, axis=self._axis, keepdims=True) loss = _apply_weighting(F, loss, self._weight, sample_weight) return F.mean(loss, axis=self._batch_axis, exclude=True) + SoftmaxCELoss = SoftmaxCrossEntropyLoss @@ -382,6 +415,7 @@ class KLDivLoss(Loss): `Kullback-Leibler divergence `_ """ + def __init__(self, from_logits=True, axis=-1, weight=None, batch_axis=0, **kwargs): super(KLDivLoss, self).__init__(weight, batch_axis, **kwargs) @@ -391,7 +425,7 @@ def __init__(self, from_logits=True, axis=-1, weight=None, batch_axis=0, def hybrid_forward(self, F, pred, label, sample_weight=None): if not self._from_logits: pred = F.log_softmax(pred, self._axis) - loss = label * (F.log(label+1e-12) - pred) + loss = label * (F.log(label + 1e-12) - pred) loss = _apply_weighting(F, loss, self._weight, sample_weight) return F.mean(loss, axis=self._batch_axis, exclude=True) @@ -453,11 +487,12 @@ class CTCLoss(Loss): Sequence Data with Recurrent Neural Networks `_ """ + def __init__(self, layout='NTC', label_layout='NT', weight=None, **kwargs): assert layout in ['NTC', 'TNC'],\ - "Only 'NTC' and 'TNC' layouts for pred are supported. Got: %s"%layout + "Only 'NTC' and 'TNC' layouts for pred are supported. Got: %s" % layout assert label_layout in ['NT', 'TN'],\ - "Only 'NT' and 'TN' layouts for label are supported. Got: %s"%label_layout + "Only 'NT' and 'TN' layouts for label are supported. Got: %s" % label_layout self._layout = layout self._label_layout = label_layout batch_axis = label_layout.find('N') @@ -512,6 +547,7 @@ class HuberLoss(Loss): - **loss**: loss tensor with shape (batch_size,). Dimenions other than batch_axis are averaged out. """ + def __init__(self, rho=1, weight=None, batch_axis=0, **kwargs): super(HuberLoss, self).__init__(weight, batch_axis, **kwargs) self._rho = rho @@ -520,7 +556,7 @@ def hybrid_forward(self, F, pred, label, sample_weight=None): label = _reshape_like(F, label, pred) loss = F.abs(label - pred) loss = F.where(loss > self._rho, loss - 0.5 * self._rho, - (0.5/self._rho) * F.square(loss)) + (0.5 / self._rho) * F.square(loss)) loss = _apply_weighting(F, loss, self._weight, sample_weight) return F.mean(loss, axis=self._batch_axis, exclude=True) @@ -558,6 +594,7 @@ class HingeLoss(Loss): - **loss**: loss tensor with shape (batch_size,). Dimenions other than batch_axis are averaged out. """ + def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs): super(HingeLoss, self).__init__(weight, batch_axis, **kwargs) self._margin = margin @@ -602,6 +639,7 @@ class SquaredHingeLoss(Loss): - **loss**: loss tensor with shape (batch_size,). Dimenions other than batch_axis are averaged out. """ + def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs): super(SquaredHingeLoss, self).__init__(weight, batch_axis, **kwargs) self._margin = margin @@ -647,6 +685,7 @@ class LogisticLoss(Loss): - **loss**: loss tensor with shape (batch_size,). Dimenions other than batch_axis are averaged out. """ + def __init__(self, weight=None, batch_axis=0, label_format='signed', **kwargs): super(LogisticLoss, self).__init__(weight, batch_axis, **kwargs) self._label_format = label_format @@ -659,7 +698,8 @@ def hybrid_forward(self, F, pred, label, sample_weight=None): if self._label_format == 'signed': label = (label + 1.0) / 2.0 # Transform label to be either 0 or 1 # Use a stable formula in computation - loss = F.relu(pred) - pred * label + F.Activation(-F.abs(pred), act_type='softrelu') + loss = F.relu(pred) - pred * label + \ + F.Activation(-F.abs(pred), act_type='softrelu') loss = _apply_weighting(F, loss, self._weight, sample_weight) return F.mean(loss, axis=self._batch_axis, exclude=True) @@ -696,6 +736,7 @@ class TripletLoss(Loss): Outputs: - **loss**: loss tensor with shape (batch_size,). """ + def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs): super(TripletLoss, self).__init__(weight, batch_axis, **kwargs) self._margin = margin @@ -703,7 +744,7 @@ def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs): def hybrid_forward(self, F, pred, positive, negative): positive = _reshape_like(F, positive, pred) negative = _reshape_like(F, negative, pred) - loss = F.sum(F.square(positive-pred) - F.square(negative-pred), + loss = F.sum(F.square(positive - pred) - F.square(negative - pred), axis=self._batch_axis, exclude=True) loss = F.relu(loss + self._margin) return _apply_weighting(F, loss, self._weight, None) @@ -748,6 +789,7 @@ class PoissonNLLLoss(Loss): Outputs: - **loss**: Average loss (shape=(1,1)) of the loss tensor with shape (batch_size,). """ + def __init__(self, weight=None, from_logits=True, batch_axis=0, compute_full=False, **kwargs): super(PoissonNLLLoss, self).__init__(weight, batch_axis, **kwargs) self._from_logits = from_logits @@ -761,7 +803,8 @@ def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08): loss = pred - target * F.log(pred + epsilon) if self._compute_full: # Using numpy's pi value - stirling_factor = target * F.log(target)- target + 0.5 * F.log(2 * target * np.pi) + stirling_factor = target * \ + F.log(target) - target + 0.5 * F.log(2 * target * np.pi) target_gt_1 = target > 1 stirling_factor *= target_gt_1 loss += stirling_factor @@ -804,6 +847,7 @@ class CosineEmbeddingLoss(Loss): Outputs: - **loss**: The loss tensor with shape (batch_size,). """ + def __init__(self, weight=None, batch_axis=0, margin=0, **kwargs): super(CosineEmbeddingLoss, self).__init__(weight, batch_axis, **kwargs) self._margin = margin @@ -820,7 +864,8 @@ def hybrid_forward(self, F, input1, input2, label, sample_weight=None): z_array = F.array([0]) else: z_array = F.zeros((1, 1)) - cos_sim_b = F.broadcast_maximum(z_array, y_minus_1 * (cos_sim - self._margin), axis=1) + cos_sim_b = F.broadcast_maximum( + z_array, y_minus_1 * (cos_sim - self._margin), axis=1) loss = cos_sim_a + cos_sim_b loss = _apply_weighting(F, loss, self._weight, sample_weight) return loss @@ -829,7 +874,7 @@ def _cosine_similarity(self, F, x, y, axis=-1): # Calculates the cosine similarity between 2 vectors x_norm = F.norm(x, axis=axis).reshape(-1, 1) y_norm = F.norm(y, axis=axis).reshape(-1, 1) - x_dot_y = F.sum(x*y, axis=axis).reshape(-1, 1) + x_dot_y = F.sum(x * y, axis=axis).reshape(-1, 1) if F is ndarray: eps_arr = F.array([1e-12]) else: diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py index 18d1ebf8fb11..3b9b46b16f93 100644 --- a/tests/python/unittest/test_loss.py +++ b/tests/python/unittest/test_loss.py @@ -126,8 +126,8 @@ def test_logistic_loss_equal_bce(): loss_bce = gluon.loss.SigmoidBCELoss(from_sigmoid=False) data = mx.random.uniform(-10, 10, shape=(N, 1)) label = mx.nd.round(mx.random.uniform(0, 1, shape=(N, 1))) - assert_almost_equal(loss_binary(data, label).asnumpy(), loss_bce(data, label).asnumpy()) - assert_almost_equal(loss_signed(data, 2 * label - 1).asnumpy(), loss_bce(data, label).asnumpy()) + assert_almost_equal(loss_binary(data, label).asnumpy(), loss_bce(data, label).asnumpy(), atol=1e-6) + assert_almost_equal(loss_signed(data, 2 * label - 1).asnumpy(), loss_bce(data, label).asnumpy(), atol=1e-6) @with_seed() def test_kl_loss(): @@ -421,6 +421,37 @@ def test_poisson_nllloss_mod(): optimizer='adam') assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05 +@with_seed() +def test_bce_loss_with_pos_weight(): + # Suppose it's a multi-label classification + N = np.random.randint(5, 30) + data = mx.nd.random.uniform(-1, 1, shape=(N, 20)) + label = mx.nd.array(np.random.randint(2, size=(N, 5)), dtype='float32') + pos_weight = mx.nd.random.uniform(0, 10, shape=(1, 5)) + pos_weight = mx.nd.repeat(pos_weight, repeats=N, axis=0) + data_iter = mx.io.NDArrayIter(data, {'label': label, 'pos_w': pos_weight}, batch_size=10, label_name='label') + output = get_net(5) + l = mx.symbol.Variable('label') + pos_w = mx.symbol.Variable('pos_w') + Loss = gluon.loss.SigmoidBinaryCrossEntropyLoss() + loss = Loss(output, l, None, pos_w) + loss = mx.sym.make_loss(loss) + mod = mx.mod.Module(loss, data_names=('data',), label_names=('label', 'pos_w')) + mod.fit(data_iter, num_epoch=200, optimizer_params={'learning_rate': 0.01}, + eval_metric=mx.metric.Loss(), optimizer='adam', + initializer=mx.init.Xavier(magnitude=2)) + assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.01 + # Test against npy + data = mx.nd.random.uniform(-5, 5, shape=(N, 5)) + label = mx.nd.array(np.random.randint(2, size=(N, 5)), dtype='float32') + pos_weight = mx.nd.random.uniform(0, 10, shape=(1, 5)) + mx_bce_loss = Loss(data, label, None, pos_weight).asnumpy() + prob_npy = 1.0 / (1.0 + np.exp(-data.asnumpy())) + label_npy = label.asnumpy() + pos_weight_npy = pos_weight.asnumpy() + npy_bce_loss = (- label_npy * np.log(prob_npy)*pos_weight_npy - (1 - label_npy) * np.log(1 - prob_npy)).mean(axis=1) + assert_almost_equal(mx_bce_loss, npy_bce_loss, rtol=1e-4, atol=1e-5) + if __name__ == '__main__': import nose