Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy][Bugfix] Add hybridization test to loss layers (#18876)
Browse files Browse the repository at this point in the history
* Test for hybridization

* fix typo

* fix

* fix test

* update

* Update loss.py

* fix bug of sum
  • Loading branch information
sxjscience committed Aug 8, 2020
1 parent d5fdcbf commit cf908fd
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 27 deletions.
45 changes: 26 additions & 19 deletions python/mxnet/gluon/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,28 @@ def _reshape_like(F, x, y):
def _batch_mean(F, loss, batch_axis):
"""Return mean on the specified batch axis, not keeping the axis"""
if is_np_array():
axes = list(range(loss.ndim))
del axes[batch_axis]
return F.np.mean(loss, axis=axes)
if F is ndarray:
axes = list(range(loss.ndim))
del axes[batch_axis]
return F.np.mean(loss, axis=axes)
else:
assert batch_axis == 0, 'Currently, we have not supported the "exclude" ' \
'flag in mean. So we only support batch_axis=0.'
return F.npx.batch_flatten(loss).mean(axis=1)
else:
return F.mean(loss, axis=batch_axis, exclude=True)

def _batch_sum(F, loss, batch_axis):
"""Return sum on the specified batch axis, not keeping the axis"""
if is_np_array():
axes = list(range(loss.ndim))
del axes[batch_axis]
return F.np.sum(loss, axis=axes)
if F is ndarray:
axes = list(range(loss.ndim))
del axes[batch_axis]
return F.np.sum(loss, axis=axes)
else:
assert batch_axis == 0, 'Currently, we have not supported the "exclude" ' \
'flag in mean. So we only support batch_axis=0.'
return F.npx.batch_flatten(loss).sum(axis=1)
else:
return F.sum(loss, axis=batch_axis, exclude=True)

Expand Down Expand Up @@ -899,8 +909,8 @@ def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
stirling_factor = target * \
log_fn(target) - target + 0.5 * log_fn(2 * target * np.pi)
target_gt_1 = target > 1
stirling_factor *= target_gt_1
loss += stirling_factor
stirling_factor = stirling_factor * target_gt_1
loss = loss + stirling_factor
loss = _apply_weighting(F, loss, self._weight, sample_weight)
return _batch_mean(F, loss, self._batch_axis)

Expand Down Expand Up @@ -1023,7 +1033,8 @@ class SDMLLoss(Loss):
def __init__(self, smoothing_parameter=0.3, weight=1., batch_axis=0, **kwargs):
super(SDMLLoss, self).__init__(weight, batch_axis, **kwargs)
self.kl_loss = KLDivLoss(from_logits=True)
self.smoothing_parameter = smoothing_parameter # Smoothing probability mass
# Smoothing probability mass
self.smoothing_parameter = smoothing_parameter

def _compute_distances(self, F, x1, x2):
"""
Expand All @@ -1032,17 +1043,13 @@ def _compute_distances(self, F, x1, x2):
"""
if is_np_array():
expand_dims_fn = F.np.expand_dims
broadcast_to_fn = F.np.broadcast_to
else:
expand_dims_fn = F.expand_dims
broadcast_to_fn = F.broadcast_to

# extracting sizes expecting [batch_size, dim]
assert x1.shape == x2.shape
batch_size, dim = x1.shape
# expanding both tensor form [batch_size, dim] to [batch_size, batch_size, dim]
x1_ = broadcast_to_fn(expand_dims_fn(x1, 1), [batch_size, batch_size, dim])
x2_ = broadcast_to_fn(expand_dims_fn(x2, 0), [batch_size, batch_size, dim])

# expanding x1 form [batch_size, dim] to [batch_size, 1, dim]
# and x2 to [1, batch_size, dim]
x1_ = expand_dims_fn(x1, 1)
x2_ = expand_dims_fn(x2, 0)
# pointwise squared differences
squared_diffs = (x1_ - x2_)**2
# sum of squared differences distance
Expand Down Expand Up @@ -1073,7 +1080,6 @@ def _compute_labels(self, F, batch_size):
labels = gold * (1 - self.smoothing_parameter) + (1 - gold) * self.smoothing_parameter / (batch_size - 1)
return labels


def hybrid_forward(self, F, x1, x2):
"""
the function computes the kl divergence between the negative distances
Expand All @@ -1092,6 +1098,7 @@ def hybrid_forward(self, F, x1, x2):
learn to predict french president comparing it with all the other
vectors in batch 2
"""
assert F is ndarray, 'SDMLLoss does not support symbolic '
if is_np_array():
log_softmax_fn = F.npx.log_softmax
else:
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8012,7 +8012,7 @@ def diagonal(a, offset=0, axis1=0, axis2=1):

# pylint:disable=redefined-outer-name, too-many-arguments
@set_module('mxnet.symbol.numpy')
def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None):
def sum(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None):
r"""
Sum of array elements over a given axis.
Expand Down
94 changes: 87 additions & 7 deletions tests/python/unittest/test_numpy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,88 +20,143 @@
from mxnet import gluon, autograd
from mxnet.test_utils import assert_almost_equal, default_context, use_np
from common import setup_module, with_seed, teardown_module, xfail_when_nonstandard_decimal_separator
import unittest
import pytest


@xfail_when_nonstandard_decimal_separator
@with_seed()
@use_np
def test_loss_np_ndarray():
@pytest.mark.parametrize("hybridize", [False, True])
def test_loss_np_ndarray(hybridize):
output = mx.np.array([1, 2, 3, 4])
label = mx.np.array([1, 3, 5, 7])
weighting = mx.np.array([0.5, 1, 0.5, 1])

loss = gluon.loss.L1Loss()
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label)) == 6.
loss = gluon.loss.L1Loss(weight=0.5)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label)) == 3.
loss = gluon.loss.L1Loss()
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label, weighting)) == 5.

loss = gluon.loss.L2Loss()
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label)) == 7.
loss = gluon.loss.L2Loss(weight=0.25)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label)) == 1.75
loss = gluon.loss.L2Loss()
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label, weighting)) == 6

loss = gluon.loss.HuberLoss()
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label)) == 4.5
loss = gluon.loss.HuberLoss(weight=0.25)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label)) == 1.125
loss = gluon.loss.HuberLoss()
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label, weighting)) == 3.75

loss = gluon.loss.HingeLoss(margin=10)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label)) == 13.
loss = gluon.loss.HingeLoss(margin=8, weight=0.25)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label)) == 2.25
loss = gluon.loss.HingeLoss(margin=7)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label, weighting)) == 4.

loss = gluon.loss.SquaredHingeLoss(margin=10)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label)) == 97.
loss = gluon.loss.SquaredHingeLoss(margin=8, weight=0.25)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label)) == 13.25
loss = gluon.loss.SquaredHingeLoss(margin=7)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label, weighting)) == 19.

loss = gluon.loss.TripletLoss(margin=10)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label, -label)) == 6.
loss = gluon.loss.TripletLoss(margin=8, weight=0.25)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label, -label)) == 1.
loss = gluon.loss.TripletLoss(margin=7)
if hybridize:
loss.hybridize()
assert mx.np.sum(loss(output, label, -label, weighting)) == 1.5

output = mx.np.array([[0, 2], [1, 4]])
label = mx.np.array([0, 1])
weighting = mx.np.array([[0.5], [1.0]])

loss = gluon.loss.SoftmaxCrossEntropyLoss()
if hybridize:
loss.hybridize()
L = loss(output, label).asnumpy()
assert_almost_equal(L, np.array([ 2.12692809, 0.04858733]), rtol=1e-3, atol=1e-4)

loss = gluon.loss.SoftmaxCrossEntropyLoss()
if hybridize:
loss.hybridize()
L = loss(output, label, weighting).asnumpy()
assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]), rtol=1e-3, atol=1e-4)


@with_seed()
@use_np
def test_bce_equal_ce2():
@pytest.mark.parametrize("hybridize", [False, True])
def test_bce_equal_ce2(hybridize):
N = 100
loss1 = gluon.loss.SigmoidBCELoss(from_sigmoid=True)
if hybridize:
loss1.hybridize()
loss2 = gluon.loss.SoftmaxCELoss(from_logits=True)
if hybridize:
loss2.hybridize()
out1 = mx.np.random.uniform(0.1, 0.9, size=(N, 1))
out2 = mx.np.log(mx.np.concatenate((1-out1, out1), axis=1) + 1e-8)
label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1)))
assert_almost_equal(loss1(out1, label).asnumpy(), loss2(out2, label).asnumpy())


@use_np
def test_logistic_loss_equal_bce():
@pytest.mark.parametrize("hybridize", [False, True])
def test_logistic_loss_equal_bce(hybridize):
N = 100
loss_binary = gluon.loss.LogisticLoss(label_format='binary')
if hybridize:
loss_binary.hybridize()
loss_signed = gluon.loss.LogisticLoss(label_format='signed')
if hybridize:
loss_signed.hybridize()
loss_bce = gluon.loss.SigmoidBCELoss(from_sigmoid=False)
if hybridize:
loss_bce.hybridize()
data = mx.np.random.uniform(-10, 10, size=(N, 1))
label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1)))
assert_almost_equal(loss_binary(data, label), loss_bce(data, label), atol=1e-6)
Expand All @@ -110,28 +165,41 @@ def test_logistic_loss_equal_bce():

@with_seed()
@use_np
def test_ctc_loss():
@pytest.mark.parametrize("hybridize", [False, True])
def test_ctc_loss(hybridize):
loss = gluon.loss.CTCLoss()
if hybridize:
loss.hybridize()
l = loss(mx.np.ones((2,20,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]))
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss(layout='TNC')
if hybridize:
loss.hybridize()
l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]))
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss(layout='TNC', label_layout='TN')
if hybridize:
loss.hybridize()
l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]).T)
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss()
if hybridize:
loss.hybridize()
l = loss(mx.np.ones((2,20,4)), mx.np.array([[2,1,2,2],[3,2,2,2]]), None, mx.np.array([2,3]))
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss()
if hybridize:
loss.hybridize()
l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,-1,-1],[3,2,2,-1]]), mx.np.array([20,20]))
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss()
if hybridize:
loss.hybridize()
l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,3,3],[3,2,2,3]]), mx.np.array([20,20]), mx.np.array([2,3]))
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

Expand Down Expand Up @@ -171,15 +239,19 @@ def test_sdml_loss():
avg_loss = loss.sum()/len(loss)
assert(avg_loss < 0.05)


@with_seed()
@use_np
def test_cosine_loss():
@pytest.mark.parametrize("hybridize", [False, True])
def test_cosine_loss(hybridize):
#Generating samples
input1 = mx.np.random.randn(3, 2)
input2 = mx.np.random.randn(3, 2)
label = mx.np.sign(mx.np.random.randn(input1.shape[0]))
#Calculating loss from cosine embedding loss function in Gluon
Loss = gluon.loss.CosineEmbeddingLoss()
if hybridize:
Loss.hybridize()
loss = Loss(input1, input2, label)

# Calculating the loss Numpy way
Expand All @@ -192,9 +264,11 @@ def test_cosine_loss():
mx.np.where(label == 1, 1-x, mx.npx.relu(x)), (-1,))
assert_almost_equal(loss.asnumpy(), numpy_loss.asnumpy(), rtol=1e-3, atol=1e-5)


@xfail_when_nonstandard_decimal_separator
@use_np
def test_poisson_nllloss():
@pytest.mark.parametrize("hybridize", [False, True])
def test_poisson_nllloss(hybridize):
shape=(3, 4)
not_axis0 = tuple(range(1, len(shape)))
pred = mx.np.random.normal(size=shape)
Expand All @@ -209,7 +283,11 @@ def test_poisson_nllloss():
target[:] += mx.np.abs(min_target)

Loss = gluon.loss.PoissonNLLLoss(from_logits=True)
if hybridize:
Loss.hybridize()
Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False)
if hybridize:
Loss_no_logits.hybridize()
#Calculating by brute formula for default value of from_logits = True

# 1) Testing for flag logits = True
Expand All @@ -230,6 +308,8 @@ def test_poisson_nllloss():
np_compute_full = mx.np.mean((np_pred - np_target * mx.np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\
np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)), axis=1)
Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
if hybridize:
Loss_compute_full.hybridize()
loss_compute_full = Loss_compute_full(np_pred, np_target)
assert_almost_equal(np_compute_full, loss_compute_full)

0 comments on commit cf908fd

Please sign in to comment.