This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

add pos_weight for SigmoidBinaryCrossEntropyLoss #13612

Merged: 25 commits, merged on Mar 8, 2019
Changes from 5 commits
18 changes: 13 additions & 5 deletions python/mxnet/gluon/loss.py
@@ -184,16 +184,20 @@ class SigmoidBinaryCrossEntropyLoss(Loss):

prob = \frac{1}{1 + \exp(-{pred})}

L = - \sum_i {label}_i * \log({prob}_i) +
L = - \sum_i {label}_i * \log({prob}_i) * {pos\_weight}_i +
(1 - {label}_i) * \log(1 - {prob}_i)

If `from_sigmoid` is True, this loss computes:

.. math::

L = - \sum_i {label}_i * \log({pred}_i) +
L = - \sum_i {label}_i * \log({pred}_i) * {pos\_weight}_i +
(1 - {label}_i) * \log(1 - {pred}_i)

A value of `pos_weight > 1` decreases the false negative count, hence increasing
the recall. Conversely, setting `pos_weight < 1` decreases the false positive
count and increases the precision.

`pred` and `label` can have arbitrary shape as long as they have the same
number of elements.
@@ -218,6 +222,9 @@ class SigmoidBinaryCrossEntropyLoss(Loss):
to the same shape as pred. For example, if pred has shape (64, 10)
and you want to weigh each sample in the batch separately,
sample_weight should have shape (64, 1).
- **pos_weight**: a weighting tensor of positive examples. Must be a vector with length
equal to the number of classes. For example, if pred has shape (64, 10),
pos_weight should have shape (1, 10).

Outputs:
- **loss**: loss tensor with shape (batch_size,). Dimensions other than
@@ -227,13 +234,14 @@ def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs):
super(SigmoidBinaryCrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
self._from_sigmoid = from_sigmoid

def hybrid_forward(self, F, pred, label, sample_weight=None):
def hybrid_forward(self, F, pred, label, pos_weight=1, sample_weight=None):
label = _reshape_like(F, label, pred)
if not self._from_sigmoid:
# We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x)))
loss = F.relu(pred) - pred * label + F.Activation(-F.abs(pred), act_type='softrelu')
log_weight = 1 + (pos_weight - 1) * label
loss = pred - pred*label + log_weight*(F.Activation(-F.abs(pred), act_type='softrelu') + F.relu(-pred))
else:
loss = -(F.log(pred+1e-12)*label + F.log(1.-pred+1e-12)*(1.-label))
loss = -(F.log(pred+1e-12)*label*pos_weight + F.log(1.-pred+1e-12)*(1.-label))
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return F.mean(loss, axis=self._batch_axis, exclude=True)

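As a quick sanity check of the change above, the following is a minimal usage sketch. It assumes an MXNet build that already includes this patch, so that `SigmoidBinaryCrossEntropyLoss` accepts `pos_weight` as the third call argument (mirroring the unit test below), and compares the result against the naive weighted BCE from the docstring, computed in NumPy:

```python
import numpy as np
import mxnet as mx
from mxnet import gluon

# Raw logits for 4 samples x 3 classes, binary multi-label targets,
# and one positive-class weight per class (broadcast over the batch).
pred = mx.nd.array(np.random.uniform(-5, 5, size=(4, 3)))
label = mx.nd.array(np.random.randint(2, size=(4, 3)), dtype='float32')
pos_weight = mx.nd.array([[1.0, 2.0, 0.5]])

loss_fn = gluon.loss.SigmoidBinaryCrossEntropyLoss()  # from_sigmoid=False by default
mx_loss = loss_fn(pred, label, pos_weight).asnumpy()

# Naive reference from the docstring formula, averaged over classes per sample.
p = 1.0 / (1.0 + np.exp(-pred.asnumpy()))
y = label.asnumpy()
w = pos_weight.asnumpy()
ref = (-(w * y * np.log(p) + (1.0 - y) * np.log(1.0 - p))).mean(axis=1)

print(np.allclose(mx_loss, ref, rtol=1e-4, atol=1e-5))  # expected: True
```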
35 changes: 33 additions & 2 deletions tests/python/unittest/test_loss.py
@@ -126,8 +126,8 @@ def test_logistic_loss_equal_bce():
loss_bce = gluon.loss.SigmoidBCELoss(from_sigmoid=False)
data = mx.random.uniform(-10, 10, shape=(N, 1))
label = mx.nd.round(mx.random.uniform(0, 1, shape=(N, 1)))
assert_almost_equal(loss_binary(data, label).asnumpy(), loss_bce(data, label).asnumpy())
assert_almost_equal(loss_signed(data, 2 * label - 1).asnumpy(), loss_bce(data, label).asnumpy())
assert_almost_equal(loss_binary(data, label).asnumpy(), loss_bce(data, label).asnumpy(), atol=1e-6)
assert_almost_equal(loss_signed(data, 2 * label - 1).asnumpy(), loss_bce(data, label).asnumpy(), atol=1e-6)

@with_seed()
def test_kl_loss():
@@ -420,6 +420,37 @@ def test_poisson_nllloss_mod():
initializer=mx.init.Normal(sigma=0.1), eval_metric=mx.metric.Loss(),
optimizer='adam')
assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.05

@with_seed()
def test_bce_loss_with_pos_weight():
# Suppose it's a multi-label classification task
N = 20
data = mx.nd.random.uniform(-1, 1, shape=(N, 20))
label = mx.nd.array(np.random.randint(2, size=(N, 5)), dtype='float32')
pos_weight = mx.nd.random.uniform(0, 10, shape=(1, 5))
pos_weight = mx.nd.repeat(pos_weight, repeats=N, axis=0)
data_iter = mx.io.NDArrayIter(data, {'label': label, 'pos_w': pos_weight}, batch_size=10, label_name='label')
output = get_net(5)
l = mx.symbol.Variable('label')
pos_w = mx.symbol.Variable('pos_w')
Loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()
loss = Loss(output, l, pos_w)
loss = mx.sym.make_loss(loss)
mod = mx.mod.Module(loss, data_names=('data',), label_names=('label', 'pos_w'))
mod.fit(data_iter, num_epoch=200, optimizer_params={'learning_rate': 0.01},
eval_metric=mx.metric.Loss(), optimizer='adam',
initializer=mx.init.Xavier(magnitude=2))
assert mod.score(data_iter, eval_metric=mx.metric.Loss())[0][1] < 0.01
# Test against npy
data = mx.nd.random.uniform(-5, 5, shape=(10, 5))
label = mx.nd.array(np.random.randint(2, size=(10, 5)), dtype='float32')
pos_weight = mx.nd.random.uniform(0, 10, shape=(1, 5))
mx_bce_loss = Loss(data, label, pos_weight).asnumpy()
prob_npy = 1.0 / (1.0 + np.exp(-data.asnumpy()))
label_npy = label.asnumpy()
pos_weight_npy = pos_weight.asnumpy()
npy_bce_loss = (- label_npy * np.log(prob_npy)*pos_weight_npy - (1 - label_npy) * np.log(1 - prob_npy)).mean(axis=1)
assert_almost_equal(mx_bce_loss, npy_bce_loss, rtol=1e-4, atol=1e-5)


if __name__ == '__main__':
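For reference, the numerically stable logit-space form used in `hybrid_forward` above can be checked against the docstring formula in plain NumPy. This is a standalone sketch, independent of MXNet; `softplus` below stands in for `F.Activation(..., act_type='softrelu')` and `np.maximum(-x, 0)` for `F.relu(-pred)`:

```python
import numpy as np

def softplus(x):
    # log(1 + exp(x)); called with non-positive arguments only, so exp() cannot overflow
    return np.log1p(np.exp(x))

rng = np.random.RandomState(0)
x = rng.uniform(-20, 20, size=(8, 4))            # logits, including large magnitudes
y = rng.randint(2, size=(8, 4)).astype(float)    # binary labels
w = rng.uniform(0.5, 5.0, size=(1, 4))           # per-class pos_weight

# Stable form from the patch:
#   pred - pred*label + log_weight * (softrelu(-|pred|) + relu(-pred))
log_weight = 1.0 + (w - 1.0) * y
stable = x - x * y + log_weight * (softplus(-np.abs(x)) + np.maximum(-x, 0.0))

# Naive form from the docstring: -(w*y*log(p) + (1-y)*log(1-p)), with p = sigmoid(x)
p = 1.0 / (1.0 + np.exp(-x))
naive = -(w * y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

print(np.allclose(stable, naive))  # True: same values, but the stable form never takes log(0)
```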