diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 40fdaa51573e..44aabe5071db 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -75,6 +75,26 @@ def _reshape_like(F, x, y):
     return F.reshape_like(x, y)
 
 
+def _batch_mean(F, loss, batch_axis):
+    """Return mean on the specified batch axis, not keeping the axis"""
+    if is_np_array():
+        axes = list(range(loss.ndim))
+        del axes[batch_axis]
+        return F.np.mean(loss, axis=axes)
+    else:
+        return F.mean(loss, axis=batch_axis, exclude=True)
+
+def _batch_sum(F, loss, batch_axis):
+    """Return sum on the specified batch axis, not keeping the axis"""
+    if is_np_array():
+        axes = list(range(loss.ndim))
+        del axes[batch_axis]
+        return F.np.sum(loss, axis=axes)
+    else:
+        return F.sum(loss, axis=batch_axis, exclude=True)
+
+
+
 class Loss(HybridBlock):
     """Base class for loss.
 
@@ -143,16 +163,11 @@ def __init__(self, weight=1., batch_axis=0, **kwargs):
         super(L2Loss, self).__init__(weight, batch_axis, **kwargs)
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
+        square_fn = F.np.square if is_np_array() else F.square
         label = _reshape_like(F, label, pred)
-        loss = F.np.square(label - pred) if is_np_array() else F.square(label - pred)
+        loss = square_fn(label - pred)
         loss = _apply_weighting(F, loss, self._weight / 2, sample_weight)
-        if is_np_array():
-            if F is ndarray:
-                return F.np.mean(loss, axis=tuple(range(1, loss.ndim)))
-            else:
-                return F.npx.batch_flatten(loss).mean(axis=1)
-        else:
-            return F.mean(loss, axis=self._batch_axis, exclude=True)
+        return _batch_mean(F, loss, self._batch_axis)
 
 
 class L1Loss(Loss):
@@ -188,16 +203,11 @@ def __init__(self, weight=None, batch_axis=0, **kwargs):
         super(L1Loss, self).__init__(weight, batch_axis, **kwargs)
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
+        abs_fn = F.np.abs if is_np_array() else F.abs
         label = _reshape_like(F, label, pred)
-        loss = F.np.abs(label - pred) if is_np_array() else F.abs(label - pred)
+        loss = abs_fn(label - pred)
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        if is_np_array():
-            if F is ndarray:
-                return F.np.mean(loss, axis=tuple(range(1, loss.ndim)))
-            else:
-                return F.npx.batch_flatten(loss).mean(axis=1)
-        else:
-            return F.mean(loss, axis=self._batch_axis, exclude=True)
+        return _batch_mean(F, loss, self._batch_axis)
 
 
 class SigmoidBinaryCrossEntropyLoss(Loss):
@@ -263,7 +273,6 @@ def __init__(self, from_sigmoid=False, weight=None, batch_axis=0, **kwargs):
         self._from_sigmoid = from_sigmoid
 
     def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None):
-        label = _reshape_like(F, label, pred)
         if is_np_array():
             relu_fn = F.npx.relu
             act_fn = F.npx.activation
@@ -276,6 +285,7 @@ def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None):
             abs_fn = F.abs
             mul_fn = F.broadcast_mul
             log_fn = F.log
+        label = _reshape_like(F, label, pred)
         if not self._from_sigmoid:
             if pos_weight is None:
                 # We use the stable formula: max(x, 0) - x * z + log(1 + exp(-abs(x)))
@@ -296,13 +306,7 @@ def hybrid_forward(self, F, pred, label, sample_weight=None, pos_weight=None):
                 loss = -(mul_fn(log_fn(pred + eps) * label, pos_weight)
                          + log_fn(1. - pred + eps) * (1. - label))
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        if is_np_array():
-            if F is ndarray:
-                return F.np.mean(loss, axis=tuple(range(1, loss.ndim)))
-            else:
-                return F.npx.batch_flatten(loss).mean(axis=1)
-        else:
-            return F.mean(loss, axis=self._batch_axis, exclude=True)
+        return _batch_mean(F, loss, self._batch_axis)
 
 
 SigmoidBCELoss = SigmoidBinaryCrossEntropyLoss
@@ -380,26 +384,20 @@ def __init__(self, axis=-1, sparse_label=True, from_logits=False, weight=None,
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
         if is_np_array():
-            log_softmax = F.npx.log_softmax
-            pick = F.npx.pick
+            log_softmax_fn = F.npx.log_softmax
+            pick_fn = F.npx.pick
         else:
-            log_softmax = F.log_softmax
-            pick = F.pick
+            log_softmax_fn = F.log_softmax
+            pick_fn = F.pick
         if not self._from_logits:
-            pred = log_softmax(pred, self._axis)
+            pred = log_softmax_fn(pred, self._axis)
         if self._sparse_label:
-            loss = -pick(pred, label, axis=self._axis, keepdims=True)
+            loss = -pick_fn(pred, label, axis=self._axis, keepdims=True)
         else:
             label = _reshape_like(F, label, pred)
             loss = -(pred * label).sum(axis=self._axis, keepdims=True)
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        if is_np_array():
-            if F is ndarray:
-                return loss.mean(axis=tuple(range(1, loss.ndim)))
-            else:
-                return F.npx.batch_flatten(loss).mean(axis=1)
-        else:
-            return loss.mean(axis=self._batch_axis, exclude=True)
+        return _batch_mean(F, loss, self._batch_axis)
 
 
 SoftmaxCELoss = SoftmaxCrossEntropyLoss
@@ -473,11 +471,17 @@ def __init__(self, from_logits=True, axis=-1, weight=None, batch_axis=0,
         self._axis = axis
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
+        if is_np_array():
+            log_softmax_fn = F.npx.log_softmax
+            log_fn = F.np.log
+        else:
+            log_softmax_fn = F.log_softmax
+            log_fn = F.log
         if not self._from_logits:
-            pred = F.log_softmax(pred, self._axis)
-        loss = label * (F.log(label + 1e-12) - pred)
+            pred = log_softmax_fn(pred, self._axis)
+        loss = label * (log_fn(label + 1e-12) - pred)
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        return F.mean(loss, axis=self._batch_axis, exclude=True)
+        return _batch_mean(F, loss, self._batch_axis)
 
 
 class CTCLoss(Loss):
@@ -603,12 +607,18 @@ def __init__(self, rho=1, weight=None, batch_axis=0, **kwargs):
         self._rho = rho
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
+        if is_np_array():
+            abs_fn = F.np.abs
+            where_fn = F.np.where
+        else:
+            abs_fn = F.abs
+            where_fn = F.where
         label = _reshape_like(F, label, pred)
-        loss = F.abs(label - pred)
-        loss = F.where(loss > self._rho, loss - 0.5 * self._rho,
-                       (0.5 / self._rho) * F.square(loss))
+        loss = abs_fn(label - pred)
+        loss = where_fn(loss > self._rho, loss - 0.5 * self._rho,
+                        (0.5 / self._rho) * F.square(loss))
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        return F.mean(loss, axis=self._batch_axis, exclude=True)
+        return _batch_mean(F, loss, self._batch_axis)
 
 
 class HingeLoss(Loss):
@@ -650,10 +660,11 @@ def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs):
         self._margin = margin
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
+        relu_fn = F.npx.relu if is_np_array() else F.relu
         label = _reshape_like(F, label, pred)
-        loss = F.relu(self._margin - pred * label)
+        loss = relu_fn(self._margin - pred * label)
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        return F.mean(loss, axis=self._batch_axis, exclude=True)
+        return _batch_mean(F, loss, self._batch_axis)
 
 
 class SquaredHingeLoss(Loss):
@@ -695,10 +706,16 @@ def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs):
         self._margin = margin
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
+        if is_np_array():
+            relu_fn = F.npx.relu
+            square_fn = F.np.square
+        else:
+            relu_fn = F.relu
+            square_fn = F.square
         label = _reshape_like(F, label, pred)
-        loss = F.square(F.relu(self._margin - pred * label))
+        loss = square_fn(relu_fn(self._margin - pred * label))
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        return F.mean(loss, axis=self._batch_axis, exclude=True)
+        return _batch_mean(F, loss, self._batch_axis)
 
 
 class LogisticLoss(Loss):
@@ -744,14 +761,22 @@ def __init__(self, weight=None, batch_axis=0, label_format='signed', **kwargs):
                              % label_format)
 
     def hybrid_forward(self, F, pred, label, sample_weight=None):
+        if is_np_array():
+            relu_fn = F.npx.relu
+            act_fn = F.npx.activation
+            abs_fn = F.np.abs
+        else:
+            relu_fn = F.relu
+            act_fn = F.Activation
+            abs_fn = F.abs
         label = _reshape_like(F, label, pred)
         if self._label_format == 'signed':
             label = (label + 1.0) / 2.0  # Transform label to be either 0 or 1
         # Use a stable formula in computation
-        loss = F.relu(pred) - pred * label + \
-            F.Activation(-F.abs(pred), act_type='softrelu')
+        loss = relu_fn(pred) - pred * label + \
+            act_fn(-abs_fn(pred), act_type='softrelu')
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        return F.mean(loss, axis=self._batch_axis, exclude=True)
+        return _batch_mean(F, loss, self._batch_axis)
 
 
 class TripletLoss(Loss):
@@ -792,11 +817,16 @@ def __init__(self, margin=1, weight=None, batch_axis=0, **kwargs):
         self._margin = margin
 
     def hybrid_forward(self, F, pred, positive, negative):
+        if is_np_array():
+            relu_fn = F.npx.relu
+            square_fn = F.np.square
+        else:
+            relu_fn = F.relu
+            square_fn = F.square
         positive = _reshape_like(F, positive, pred)
         negative = _reshape_like(F, negative, pred)
-        loss = F.sum(F.square(positive - pred) - F.square(negative - pred),
-                     axis=self._batch_axis, exclude=True)
-        loss = F.relu(loss + self._margin)
+        loss = _batch_sum(F, square_fn(positive - pred) - square_fn(negative - pred), self._batch_axis)
+        loss = relu_fn(loss + self._margin)
         return _apply_weighting(F, loss, self._weight, None)
 
 
@@ -846,20 +876,26 @@ def __init__(self, weight=None, from_logits=True, batch_axis=0, compute_full=Fal
         self._compute_full = compute_full
 
     def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
+        if is_np_array():
+            exp_fn = F.np.exp
+            log_fn = F.np.log
+        else:
+            exp_fn = F.exp
+            log_fn = F.log
         target = _reshape_like(F, target, pred)
         if self._from_logits:
-            loss = F.exp(pred) - target * pred
+            loss = exp_fn(pred) - target * pred
         else:
-            loss = pred - target * F.log(pred + epsilon)
+            loss = pred - target * log_fn(pred + epsilon)
             if self._compute_full:
                 # Using numpy's pi value
                 stirling_factor = target * \
-                    F.log(target) - target + 0.5 * F.log(2 * target * np.pi)
+                    log_fn(target) - target + 0.5 * log_fn(2 * target * np.pi)
                 target_gt_1 = target > 1
                 stirling_factor *= target_gt_1
                 loss += stirling_factor
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        return F.mean(loss)
+        return _batch_mean(F, loss, self._batch_axis)
 
 
 class CosineEmbeddingLoss(Loss):
@@ -903,30 +939,39 @@ def __init__(self, weight=None, batch_axis=0, margin=0, **kwargs):
         self._margin = margin
 
     def hybrid_forward(self, F, input1, input2, label, sample_weight=None):
+        if is_np_array():
+            where_fn = F.np.where
+            clip_fn = F.np.clip
+        else:
+            where_fn = F.where
+            clip_fn = F.clip
+
         input1 = _reshape_like(F, input1, input2)
-        label = label.reshape((-1, 1))
         cos_sim = self._cosine_similarity(F, input1, input2)
-        y_1 = label == 1
-        y_minus_1 = label == -1
-        cos_sim_a = (1 - cos_sim) * y_1
+        label = _reshape_like(F, label, cos_sim)
+        loss = where_fn(label == 1,
+                        1 - cos_sim,
+                        clip_fn(cos_sim - self._margin, 0, 1 - self._margin))
 
-        if F is ndarray:
-            z_array = F.array([0])
-        else:
-            z_array = F.zeros((1, 1))
-        cos_sim_b = F.broadcast_maximum(
-            z_array, y_minus_1 * (cos_sim - self._margin), axis=1)
-        loss = cos_sim_a + cos_sim_b
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
-        return loss
+        return _batch_mean(F, loss, self._batch_axis)
 
-    def _cosine_similarity(self, F, x, y, axis=-1):
-        # Calculates the cosine similarity between 2 vectors
-        x_norm = F.norm(x, axis=axis).reshape(-1, 1)
-        y_norm = F.norm(y, axis=axis).reshape(-1, 1)
-        x_dot_y = F.sum(x * y, axis=axis).reshape(-1, 1)
-        if F is ndarray:
-            eps_arr = F.array([1e-12])
+    def _cosine_similarity(self, F, x, y):
+        if is_np_array():
+            reshape_fn = F.npx.reshape
+            norm_fn = F.npx.norm
+            sum_fn = F.np.sum
+            full_fn = F.np.full
+            max_fn = F.np.maximum
         else:
-            eps_arr = F.full((1, 1), 1e-12)
-        return (x_dot_y / F.broadcast_maximum(x_norm * y_norm, eps_arr))
+            reshape_fn = F.reshape
+            norm_fn = F.norm
+            sum_fn = F.sum
+            full_fn = F.full
+            max_fn = F.broadcast_maximum
+        # Calculates the cosine similarity between 2 vectors
+        x_norm = reshape_fn(norm_fn(x, axis=-1), (-1, 1))
+        y_norm = reshape_fn(norm_fn(y, axis=-1), (-1, 1))
+        x_dot_y = reshape_fn(sum_fn(x * y, axis=-1), (-1, 1))
+        eps_arr = full_fn((1, 1), 1e-12)
+        return (x_dot_y / max_fn(x_norm * y_norm, eps_arr))
diff --git a/src/operator/tensor/broadcast_reduce_norm_value.cc b/src/operator/tensor/broadcast_reduce_norm_value.cc
index 9acc157f8eca..dfc6317a12d4 100644
--- a/src/operator/tensor/broadcast_reduce_norm_value.cc
+++ b/src/operator/tensor/broadcast_reduce_norm_value.cc
@@ -87,6 +87,7 @@ Examples::
   norm(csr) = [5.47722578]
 
 )code" ADD_FILELINE)
+.add_alias("_npx_norm")
 .set_num_inputs(1)
 .set_num_outputs(1)
 .set_attr_parser(ParamParser<NormParam>)
diff --git a/tests/python/unittest/test_loss.py b/tests/python/unittest/test_loss.py
index af7c4838d5a9..a8cc90ed2171 100644
--- a/tests/python/unittest/test_loss.py
+++ b/tests/python/unittest/test_loss.py
@@ -363,7 +363,7 @@ def test_cosine_loss():
     denominator = mx.nd.sqrt(mx.nd.sum(input1**2, axis=1, keepdims=True)) \
                   * mx.nd.sqrt(mx.nd.sum(input2**2, axis=1, keepdims=True))
     numpy_loss = mx.nd.where(label == 1, 1-numerator/denominator, \
-        mx.nd.broadcast_maximum(mx.nd.array([0]), numerator/denominator, axis=1))
+        mx.nd.broadcast_maximum(mx.nd.array([0]), numerator/denominator, axis=1)).reshape((-1,))
     assert_almost_equal(loss.asnumpy(), numpy_loss.asnumpy(), rtol=1e-3, atol=1e-5)
 
 def test_poisson_nllloss():
@@ -385,27 +385,25 @@ def test_poisson_nllloss():
     #Calculating by brute formula for default value of from_logits = True
 
     # 1) Testing for flag logits = True
-    brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy())
+    brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy(), axis=1)
     loss_withlogits = Loss(pred, target)
-    assert_almost_equal(brute_loss, loss_withlogits.asscalar())
+    assert_almost_equal(brute_loss, loss_withlogits)
 
     #2) Testing for flag logits = False
     loss_no_logits = Loss_no_logits(pred, target)
-    np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08))
-    if np.isnan(loss_no_logits.asscalar()):
-        assert_almost_equal(np.isnan(np_loss_no_logits), np.isnan(loss_no_logits.asscalar()))
-    else:
-        assert_almost_equal(np_loss_no_logits, loss_no_logits.asscalar())
+    np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08),
+                                axis=1)
+    assert_almost_equal(np_loss_no_logits, loss_no_logits.asnumpy())
 
     #3) Testing for Sterling approximation
     shape=(2, 3)
     np_pred = np.random.uniform(1, 5, shape)
     np_target = np.random.uniform(1, 5, shape)
     np_compute_full = np.mean((np_pred - np_target * np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\
-        np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)))
+        np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)), axis=1)
     Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
     loss_compute_full = Loss_compute_full(mx.nd.array(np_pred), mx.nd.array(np_target))
-    assert_almost_equal(np_compute_full, loss_compute_full.asscalar())
+    assert_almost_equal(np_compute_full, loss_compute_full)
 
 @with_seed()
 def test_poisson_nllloss_mod():
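
For illustration only (not part of the patch): with this change every Gluon loss reduces over all non-batch axes through the new `_batch_mean`/`_batch_sum` helpers, so the same loss block returns one value per sample whether it is fed classic NDArrays or numpy-semantics arrays after `npx.set_np()`. The sketch below assumes a build of MXNet that includes this patch; shapes and values are arbitrary.

```python
# Hypothetical sanity-check sketch, not part of the diff above.
import mxnet as mx
from mxnet import gluon, npx

loss_fn = gluon.loss.L2Loss()

# Classic NDArray mode: per-sample loss, shape (4,)
pred = mx.nd.random.uniform(shape=(4, 10))
label = mx.nd.random.uniform(shape=(4, 10))
print(loss_fn(pred, label).shape)  # expected: (4,)

# NumPy-array mode: _batch_mean averages over every axis except the batch axis
npx.set_np()
pred = mx.np.random.uniform(size=(4, 10))
label = mx.np.random.uniform(size=(4, 10))
print(loss_fn(pred, label).shape)  # expected: (4,)
npx.reset_np()
```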