This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy][Bugfix] Add hybridization test to loss layers #18876

Merged: 7 commits, Aug 8, 2020
Changes from 1 commit
22 changes: 16 additions & 6 deletions python/mxnet/gluon/loss.py
@@ -77,18 +77,28 @@ def _reshape_like(F, x, y):
 def _batch_mean(F, loss, batch_axis):
     """Return mean on the specified batch axis, not keeping the axis"""
     if is_np_array():
-        axes = list(range(loss.ndim))
-        del axes[batch_axis]
-        return F.np.mean(loss, axis=axes)
+        if F is ndarray:
+            axes = list(range(loss.ndim))
+            del axes[batch_axis]
+            return F.np.mean(loss, axis=axes)
+        else:
+            assert batch_axis == 0, 'The "exclude" flag of mean is not yet supported ' \
+                                    'in symbolic mode, so only batch_axis=0 is supported.'
+            return F.npx.batch_flatten(loss).mean(axis=1)
     else:
         return F.mean(loss, axis=batch_axis, exclude=True)

 def _batch_sum(F, loss, batch_axis):
     """Return sum on the specified batch axis, not keeping the axis"""
     if is_np_array():
-        axes = list(range(loss.ndim))
-        del axes[batch_axis]
-        return F.np.sum(loss, axis=axes)
+        if F is ndarray:
+            axes = list(range(loss.ndim))
+            del axes[batch_axis]
+            return F.np.sum(loss, axis=axes)
+        else:
+            assert batch_axis == 0, 'The "exclude" flag of sum is not yet supported ' \
+                                    'in symbolic mode, so only batch_axis=0 is supported.'
+            return F.npx.batch_flatten(loss).sum(axis=1)
     else:
         return F.sum(loss, axis=batch_axis, exclude=True)
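
For reference, a minimal plain-NumPy sketch (hypothetical shapes, not part of the patch) of the reduction both branches compute: the imperative branch averages over every axis except the batch axis, and for batch_axis=0 this agrees with flattening all non-batch dimensions first, which is what the symbolic branch does via npx.batch_flatten.

import numpy as np

loss = np.arange(24, dtype=np.float32).reshape(2, 3, 4)   # (batch, ...) toy tensor
batch_axis = 0
axes = [ax for ax in range(loss.ndim) if ax != batch_axis]
per_sample = loss.mean(axis=tuple(axes))                  # shape (2,), one value per sample

# Equivalent for batch_axis == 0: flatten the non-batch dims, then reduce axis 1.
flattened = loss.reshape(loss.shape[0], -1).mean(axis=1)
assert np.allclose(per_sample, flattened)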

94 changes: 87 additions & 7 deletions tests/python/unittest/test_numpy_loss.py
@@ -20,64 +20,103 @@
 from mxnet import gluon, autograd
 from mxnet.test_utils import assert_almost_equal, default_context, use_np
 from common import setup_module, with_seed, teardown_module, xfail_when_nonstandard_decimal_separator
-import unittest
+import pytest
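
For context, a minimal standalone sketch (assumed API usage, not part of the patch) of the pattern repeated throughout these tests: each loss is evaluated both imperatively and after hybridize(), so the symbolic code path in the loss helpers is exercised as well.

import mxnet as mx
from mxnet import gluon

mx.npx.set_np()    # enable the NumPy-style array API, as the @use_np decorator does

loss = gluon.loss.L1Loss()
loss.hybridize()   # compile to a cached symbolic graph on first call
out = loss(mx.np.array([1., 2., 3., 4.]), mx.np.array([1., 3., 5., 7.]))
print(out)         # per-sample losses, shape (4,)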


 @xfail_when_nonstandard_decimal_separator
 @with_seed()
 @use_np
-def test_loss_np_ndarray():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_loss_np_ndarray(hybridize):
     output = mx.np.array([1, 2, 3, 4])
     label = mx.np.array([1, 3, 5, 7])
     weighting = mx.np.array([0.5, 1, 0.5, 1])
 
     loss = gluon.loss.L1Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 6.
     loss = gluon.loss.L1Loss(weight=0.5)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 3.
     loss = gluon.loss.L1Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 5.
 
     loss = gluon.loss.L2Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 7.
     loss = gluon.loss.L2Loss(weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 1.75
     loss = gluon.loss.L2Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 6
 
     loss = gluon.loss.HuberLoss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 4.5
     loss = gluon.loss.HuberLoss(weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 1.125
     loss = gluon.loss.HuberLoss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 3.75
 
     loss = gluon.loss.HingeLoss(margin=10)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 13.
     loss = gluon.loss.HingeLoss(margin=8, weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 2.25
     loss = gluon.loss.HingeLoss(margin=7)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 4.
 
     loss = gluon.loss.SquaredHingeLoss(margin=10)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 97.
     loss = gluon.loss.SquaredHingeLoss(margin=8, weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 13.25
     loss = gluon.loss.SquaredHingeLoss(margin=7)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 19.
 
     loss = gluon.loss.TripletLoss(margin=10)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, -label)) == 6.
     loss = gluon.loss.TripletLoss(margin=8, weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, -label)) == 1.
     loss = gluon.loss.TripletLoss(margin=7)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, -label, weighting)) == 1.5
 
     output = mx.np.array([[0, 2], [1, 4]])
     label = mx.np.array([0, 1])
     weighting = mx.np.array([[0.5], [1.0]])
 
     loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    if hybridize:
+        loss.hybridize()
     L = loss(output, label).asnumpy()
     assert_almost_equal(L, np.array([2.12692809, 0.04858733]), rtol=1e-3, atol=1e-4)

@@ -87,21 +126,34 @@ def test_loss_np_ndarray():

 @with_seed()
 @use_np
-def test_bce_equal_ce2():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_bce_equal_ce2(hybridize):
     N = 100
     loss1 = gluon.loss.SigmoidBCELoss(from_sigmoid=True)
+    if hybridize:
+        loss1.hybridize()
     loss2 = gluon.loss.SoftmaxCELoss(from_logits=True)
+    if hybridize:
+        loss2.hybridize()
     out1 = mx.np.random.uniform(0.1, 0.9, size=(N, 1))
     out2 = mx.np.log(mx.np.concatenate((1-out1, out1), axis=1) + 1e-8)
     label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1)))
     assert_almost_equal(loss1(out1, label).asnumpy(), loss2(out2, label).asnumpy())
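
A minimal plain-NumPy sketch (hypothetical values, not part of the patch) of the identity this test relies on: sigmoid BCE on a probability p equals softmax CE with from_logits=True on the log-probabilities log([1-p, p]).

import numpy as np

p = np.array([0.2, 0.7, 0.9])    # predicted P(y=1), already passed through sigmoid
y = np.array([0, 1, 1])          # binary labels

bce = -(y * np.log(p) + (1 - y) * np.log(1 - p))    # sigmoid BCE, from_sigmoid=True

log_probs = np.log(np.stack([1 - p, p], axis=1))    # analogue of out2 above
ce = -log_probs[np.arange(len(y)), y]               # softmax CE picks -log p(true class)
assert np.allclose(bce, ce)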


 @use_np
-def test_logistic_loss_equal_bce():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_logistic_loss_equal_bce(hybridize):
     N = 100
     loss_binary = gluon.loss.LogisticLoss(label_format='binary')
+    if hybridize:
+        loss_binary.hybridize()
     loss_signed = gluon.loss.LogisticLoss(label_format='signed')
+    if hybridize:
+        loss_signed.hybridize()
     loss_bce = gluon.loss.SigmoidBCELoss(from_sigmoid=False)
+    if hybridize:
+        loss_bce.hybridize()
     data = mx.np.random.uniform(-10, 10, size=(N, 1))
     label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1)))
     assert_almost_equal(loss_binary(data, label), loss_bce(data, label), atol=1e-6)
@@ -110,35 +162,49 @@ def test_logistic_loss_equal_bce():

 @with_seed()
 @use_np
-def test_ctc_loss():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_ctc_loss(hybridize):
     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,20,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))
 
     loss = gluon.loss.CTCLoss(layout='TNC')
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))
 
     loss = gluon.loss.CTCLoss(layout='TNC', label_layout='TN')
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]).T)
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))
 
     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,20,4)), mx.np.array([[2,1,2,2],[3,2,2,2]]), None, mx.np.array([2,3]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))
 
     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,-1,-1],[3,2,2,-1]]), mx.np.array([20,20]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))
 
     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,3,3],[3,2,2,3]]), mx.np.array([20,20]), mx.np.array([2,3]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))


 @xfail_when_nonstandard_decimal_separator
 @with_seed()
 @use_np
-def test_sdml_loss():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_sdml_loss(hybridize):
 
     N = 5  # number of samples
@@ -152,6 +218,8 @@ def test_sdml_loss():

     # Init model and trainer
     sdml_loss = gluon.loss.SDMLLoss()
+    if hybridize:
+        sdml_loss.hybridize()
     model = gluon.nn.Dense(DIM, activation='tanh')  # Simple NN encoder
     model.initialize(mx.init.Xavier(), ctx=mx.current_context())
     trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate' : 0.1})
@@ -171,15 +239,19 @@ def test_sdml_loss():
     avg_loss = loss.sum()/len(loss)
     assert(avg_loss < 0.05)


 @with_seed()
 @use_np
-def test_cosine_loss():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_cosine_loss(hybridize):
     # Generating samples
     input1 = mx.np.random.randn(3, 2)
     input2 = mx.np.random.randn(3, 2)
     label = mx.np.sign(mx.np.random.randn(input1.shape[0]))
     # Calculating loss from the cosine embedding loss function in Gluon
     Loss = gluon.loss.CosineEmbeddingLoss()
+    if hybridize:
+        Loss.hybridize()
     loss = Loss(input1, input2, label)
 
     # Calculating the loss the NumPy way
@@ -192,9 +264,11 @@ def test_cosine_loss():
         mx.np.where(label == 1, 1-x, mx.npx.relu(x)), (-1,))
     assert_almost_equal(loss.asnumpy(), numpy_loss.asnumpy(), rtol=1e-3, atol=1e-5)


 @xfail_when_nonstandard_decimal_separator
 @use_np
-def test_poisson_nllloss():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_poisson_nllloss(hybridize):
     shape = (3, 4)
     not_axis0 = tuple(range(1, len(shape)))
     pred = mx.np.random.normal(size=shape)
@@ -209,7 +283,11 @@ def test_poisson_nllloss():
     target[:] += mx.np.abs(min_target)
 
     Loss = gluon.loss.PoissonNLLLoss(from_logits=True)
+    if hybridize:
+        Loss.hybridize()
     Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False)
+    if hybridize:
+        Loss_no_logits.hybridize()
     # Calculating by the brute-force formula for the default from_logits=True
 
     # 1) Testing for flag logits = True
@@ -230,6 +308,8 @@ def test_poisson_nllloss():
     np_compute_full = mx.np.mean((np_pred - np_target * mx.np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target) -\
                       np_target + 0.5 * np.log(2 * np_target * np.pi)) * (np_target > 1)), axis=1)
     Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
+    if hybridize:
+        Loss_compute_full.hybridize()
     loss_compute_full = Loss_compute_full(np_pred, np_target)
     assert_almost_equal(np_compute_full, loss_compute_full)
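
As a closing cross-check, a minimal plain-NumPy sketch (hypothetical values, not part of the patch) of the brute-force formula used above: with from_logits=False the per-element loss is pred - target*log(pred + eps), and compute_full=True adds a Stirling approximation of log(target!) for target > 1.

import numpy as np

pred = np.array([[1.0, 2.0], [0.5, 3.0]])
target = np.array([[0.0, 2.0], [1.0, 4.0]])
eps = 1e-08

basic = pred - target * np.log(pred + eps)
stirling = (target * np.log(target + eps) - target
            + 0.5 * np.log(2 * np.pi * (target + eps))) * (target > 1)
loss_full = np.mean(basic + stirling, axis=1)   # mean over the non-batch axis
print(loss_full)                                # one loss value per sample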