-
Notifications
You must be signed in to change notification settings - Fork 6
/
loss.py
136 lines (108 loc) · 4.65 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
""" losses for training neural networks """
from mxnet import ndarray
from mxnet.base import numeric_types
from mxnet.gluon.block import HybridBlock
def _apply_weighting(F, loss, weight=None, sample_weight=None):
    """Scale a loss by an optional per-sample and an optional global weight.

    Parameters
    ----------
    F : module-like
        Operator namespace; only ``F.broadcast_mul`` is used, and only
        when `sample_weight` is given.
    loss : Symbol
        The loss to be weighted.
    weight : float or None
        Global scalar multiplier. Skipped when None.
    sample_weight : Symbol or None
        Per-sample multiplier. Must be broadcastable to the shape of
        `loss`; e.g. for a loss of shape (64, 10), a `sample_weight` of
        shape (64, 1) weights each sample of the batch separately.
        Skipped when None.

    Returns
    -------
    Symbol
        The weighted loss. When both weights are None, `loss` is returned
        unchanged.
    """
    weighted = loss

    # Per-sample weighting is applied first, through a broadcasting multiply.
    if sample_weight is not None:
        weighted = F.broadcast_mul(weighted, sample_weight)

    if weight is None:
        return weighted

    # The global weight must be a plain number, not a tensor.
    assert isinstance(weight, numeric_types), "weight must be a number"
    return weighted * weight
def _reshape_like(F, x, y):
    """Return `x` reshaped to the shape of `y`.

    When `F` is the imperative ``ndarray`` module the shape of `y` is read
    directly; otherwise the ``F.reshape_like`` operator is used.
    """
    if F is ndarray:
        return x.reshape(y.shape)
    return F.reshape_like(x, y)
class Loss(HybridBlock):
    """Common base for loss blocks.

    Stores the global loss weight and the batch axis so that subclasses
    can use them inside ``hybrid_forward``.

    Parameters
    ----------
    weight : float or None
        Global scalar weight for loss.
    batch_axis : int, default 0
        The axis that represents mini-batch.
    """

    def __init__(self, weight, batch_axis, **kwargs):
        super(Loss, self).__init__(**kwargs)
        self._weight = weight
        self._batch_axis = batch_axis

    def __repr__(self):
        """Return ``ClassName(batch_axis=<axis>, w=<weight>)``."""
        return '{name}(batch_axis={axis}, w={weight})'.format(
            name=self.__class__.__name__,
            axis=self._batch_axis,
            weight=self._weight,
        )

    def hybrid_forward(self, F, x, *args, **kwargs):
        """Forward computation; must be overridden by subclasses.

        Parameters
        ----------
        x : Symbol or NDArray
            The first input tensor.
        *args : list of Symbol or list of NDArray
            Additional input tensors.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        # pylint: disable= invalid-name
        raise NotImplementedError
class wSigmoidBinaryCrossEntropyLoss(Loss):
    r"""Class-weighted cross-entropy loss for binary classification.

    If `from_sigmoid` is False (default), `pred` is treated as a logit and
    the element-wise loss is the unweighted sigmoid cross-entropy:

    .. math::

        prob = \frac{1}{1 + \exp(-{pred})}

        L_i = - \left[ {label}_i \log({prob}_i) +
              (1 - {label}_i) \log(1 - {prob}_i) \right]

    NOTE: on this path the class weight `w` is accepted but not used.

    If `from_sigmoid` is True, `pred` is treated as a probability and the
    two terms are weighted by `w` and `1 - w`:

    .. math::

        L_i = - \left[ w \, {label}_i \log({pred}_i + \epsilon) +
              (1 - w) (1 - {label}_i) \log(1 - {pred}_i + \epsilon) \right]

    with :math:`\epsilon = 10^{-12}` guarding against ``log(0)``.

    In both cases the element-wise loss is then scaled by the optional
    global `weight` and `sample_weight`, and averaged over every axis
    except `batch_axis`.

    `pred` and `label` can have arbitrary shape as long as they have the
    same number of elements.

    Parameters
    ----------
    from_sigmoid : bool, default is `False`
        Whether the input is from the output of sigmoid. Setting this to
        False makes the loss compute sigmoid and BCE together, which is
        more numerically stable through the log-sum-exp trick.
    weight : float or None
        Global scalar weight for loss.
    batch_axis : int, default 0
        The axis that represents mini-batch.


    Inputs:
        - **pred**: prediction tensor with arbitrary shape.
        - **label**: target tensor with values in range `[0, 1]`. Must
          have the same size as `pred`; it is reshaped to match `pred`.
        - **w**: weight of the positive (`label`) term; the negative term
          is weighted by `1 - w`. Only used when `from_sigmoid` is True.
        - **sample_weight**: element-wise weighting tensor. Must be
          broadcastable to the same shape as pred. For example, if pred
          has shape (64, 10) and you want to weigh each sample in the
          batch separately, sample_weight should have shape (64, 1).

    Outputs:
        - **loss**: loss tensor with shape (batch_size,). Dimensions other
          than batch_axis are averaged out.
    """

    def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs):
        super(wSigmoidBinaryCrossEntropyLoss, self).__init__(weight, batch_axis, **kwargs)
        self._from_sigmoid = from_sigmoid

    def hybrid_forward(self, F, pred, label, w, sample_weight=None):
        """Compute the per-sample loss; see the class docstring for formulas."""
        label = _reshape_like(F, label, pred)

        if self._from_sigmoid:
            # `pred` is already a probability: weight the positive term by
            # w and the negative term by (1 - w).
            positive = F.log(pred+1e-12)*label*w
            negative = F.log(1.-pred+1e-12)*(1.-label)*(1.-w)
            loss = -(positive + negative)
        else:
            # `pred` is a logit. Stable formula:
            #   max(x, 0) - x * z + log(1 + exp(-abs(x)))
            # `w` does not enter this branch.
            hinge = F.relu(pred)
            cross = pred * label
            softplus = F.Activation(-F.abs(pred), act_type='softrelu')
            loss = hinge - cross + softplus

        loss = _apply_weighting(F, loss, self._weight, sample_weight)
        # Average over every axis except the batch axis.
        return F.mean(loss, axis=self._batch_axis, exclude=True)