Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[NumPy] loss for np array #17196

Merged
merged 2 commits into from
Jul 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 155 additions & 96 deletions python/mxnet/gluon/loss.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/operator/nn/ctc_loss.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ DMLC_REGISTER_PARAMETER(CTCLossOpParam);

NNVM_REGISTER_OP(CTCLoss)
.add_alias("ctc_loss")
.add_alias("_npx_ctc_loss")
.add_alias("_contrib_CTCLoss")
.add_alias("_contrib_ctc_loss")
.describe(R"code(Connectionist Temporal Classification Loss.
Expand Down
1 change: 1 addition & 0 deletions src/operator/nn/ctc_loss.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ namespace op {

NNVM_REGISTER_OP(CTCLoss)
.add_alias("ctc_loss")
.add_alias("_npx_ctc_loss")
.add_alias("_contrib_ctc_loss")
.add_alias("_contrib_CTCLoss")
.set_attr<FCompute>("FCompute<gpu>", CTCLossOpForward<gpu>);
Expand Down
1 change: 1 addition & 0 deletions src/operator/tensor/broadcast_reduce_norm_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ Examples::
norm(csr) = [5.47722578]
)code" ADD_FILELINE)
.add_alias("_npx_norm")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NormParam>)
Expand Down
1 change: 1 addition & 0 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from common import setup_module, with_seed, teardown_module, assert_raises_cudnn_not_satisfied, run_in_spawned_process
from test_gluon import *
from test_loss import *
from test_numpy_loss import *
from test_gluon_rnn import *

set_default_context(mx.gpu(0))
Expand Down
28 changes: 8 additions & 20 deletions tests/python/unittest/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,6 @@ def test_loss_ndarray():
assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]), rtol=1e-3, atol=1e-4)


def get_net(num_hidden, flatten=True):
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128, flatten=flatten)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64, flatten=flatten)
act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=num_hidden, flatten=flatten)
return fc3


@with_seed()
def test_bce_equal_ce2():
N = 100
Expand Down Expand Up @@ -163,7 +153,7 @@ def test_cosine_loss():
denominator = mx.nd.sqrt(mx.nd.sum(input1**2, axis=1, keepdims=True)) \
* mx.nd.sqrt(mx.nd.sum(input2**2, axis=1, keepdims=True))
numpy_loss = mx.nd.where(label == 1, 1-numerator/denominator, \
mx.nd.broadcast_maximum(mx.nd.array([0]), numerator/denominator, axis=1))
mx.nd.broadcast_maximum(mx.nd.array([0]), numerator/denominator, axis=1)).reshape((-1,))
assert_almost_equal(loss.asnumpy(), numpy_loss.asnumpy(), rtol=1e-3, atol=1e-5)

@xfail_when_nonstandard_decimal_separator
Expand All @@ -186,25 +176,23 @@ def test_poisson_nllloss():
#Calculating by brute formula for default value of from_logits = True

# 1) Testing for flag logits = True
brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy())
brute_loss = np.mean(np.exp(pred.asnumpy()) - target.asnumpy() * pred.asnumpy(), axis=1)
loss_withlogits = Loss(pred, target)
assert_almost_equal(brute_loss, loss_withlogits.asscalar())
assert_almost_equal(brute_loss, loss_withlogits)

#2) Testing for flag logits = False
loss_no_logits = Loss_no_logits(pred, target)
np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08))
if np.isnan(loss_no_logits.asscalar()):
assert_almost_equal(np.isnan(np_loss_no_logits), np.isnan(loss_no_logits.asscalar()))
else:
assert_almost_equal(np_loss_no_logits, loss_no_logits.asscalar())
np_loss_no_logits = np.mean(pred.asnumpy() - target.asnumpy() * np.log(pred.asnumpy() + 1e-08),
axis=1)
assert_almost_equal(np_loss_no_logits, loss_no_logits.asnumpy())

#3) Testing for Sterling approximation
shape=(2, 3)
np_pred = np.random.uniform(1, 5, shape)
np_target = np.random.uniform(1, 5, shape)
np_compute_full = np.mean((np_pred - np_target * np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\
np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)))
np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)), axis=1)
Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
loss_compute_full = Loss_compute_full(mx.nd.array(np_pred), mx.nd.array(np_target))
assert_almost_equal(np_compute_full, loss_compute_full.asscalar())
assert_almost_equal(np_compute_full, loss_compute_full)

235 changes: 235 additions & 0 deletions tests/python/unittest/test_numpy_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
import numpy as np
from mxnet import gluon, autograd
from mxnet.test_utils import assert_almost_equal, default_context, use_np
from common import setup_module, with_seed, teardown_module, xfail_when_nonstandard_decimal_separator
import unittest


@xfail_when_nonstandard_decimal_separator
@with_seed()
@use_np
def test_loss_np_ndarray():
output = mx.np.array([1, 2, 3, 4])
label = mx.np.array([1, 3, 5, 7])
weighting = mx.np.array([0.5, 1, 0.5, 1])

loss = gluon.loss.L1Loss()
assert mx.np.sum(loss(output, label)) == 6.
loss = gluon.loss.L1Loss(weight=0.5)
assert mx.np.sum(loss(output, label)) == 3.
loss = gluon.loss.L1Loss()
assert mx.np.sum(loss(output, label, weighting)) == 5.

loss = gluon.loss.L2Loss()
assert mx.np.sum(loss(output, label)) == 7.
loss = gluon.loss.L2Loss(weight=0.25)
assert mx.np.sum(loss(output, label)) == 1.75
loss = gluon.loss.L2Loss()
assert mx.np.sum(loss(output, label, weighting)) == 6

loss = gluon.loss.HuberLoss()
assert mx.np.sum(loss(output, label)) == 4.5
loss = gluon.loss.HuberLoss(weight=0.25)
assert mx.np.sum(loss(output, label)) == 1.125
loss = gluon.loss.HuberLoss()
assert mx.np.sum(loss(output, label, weighting)) == 3.75

loss = gluon.loss.HingeLoss(margin=10)
assert mx.np.sum(loss(output, label)) == 13.
loss = gluon.loss.HingeLoss(margin=8, weight=0.25)
assert mx.np.sum(loss(output, label)) == 2.25
loss = gluon.loss.HingeLoss(margin=7)
assert mx.np.sum(loss(output, label, weighting)) == 4.

loss = gluon.loss.SquaredHingeLoss(margin=10)
assert mx.np.sum(loss(output, label)) == 97.
loss = gluon.loss.SquaredHingeLoss(margin=8, weight=0.25)
assert mx.np.sum(loss(output, label)) == 13.25
loss = gluon.loss.SquaredHingeLoss(margin=7)
assert mx.np.sum(loss(output, label, weighting)) == 19.

loss = gluon.loss.TripletLoss(margin=10)
assert mx.np.sum(loss(output, label, -label)) == 6.
loss = gluon.loss.TripletLoss(margin=8, weight=0.25)
assert mx.np.sum(loss(output, label, -label)) == 1.
loss = gluon.loss.TripletLoss(margin=7)
assert mx.np.sum(loss(output, label, -label, weighting)) == 1.5

output = mx.np.array([[0, 2], [1, 4]])
label = mx.np.array([0, 1])
weighting = mx.np.array([[0.5], [1.0]])

loss = gluon.loss.SoftmaxCrossEntropyLoss()
L = loss(output, label).asnumpy()
assert_almost_equal(L, np.array([ 2.12692809, 0.04858733]), rtol=1e-3, atol=1e-4)

L = loss(output, label, weighting).asnumpy()
assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]), rtol=1e-3, atol=1e-4)


@with_seed()
@use_np
def test_bce_equal_ce2():
N = 100
loss1 = gluon.loss.SigmoidBCELoss(from_sigmoid=True)
loss2 = gluon.loss.SoftmaxCELoss(from_logits=True)
out1 = mx.np.random.uniform(0.1, 0.9, size=(N, 1))
out2 = mx.np.log(mx.np.concatenate((1-out1, out1), axis=1) + 1e-8)
label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1)))
assert_almost_equal(loss1(out1, label).asnumpy(), loss2(out2, label).asnumpy())

@use_np
def test_logistic_loss_equal_bce():
N = 100
loss_binary = gluon.loss.LogisticLoss(label_format='binary')
loss_signed = gluon.loss.LogisticLoss(label_format='signed')
loss_bce = gluon.loss.SigmoidBCELoss(from_sigmoid=False)
data = mx.np.random.uniform(-10, 10, size=(N, 1))
label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1)))
assert_almost_equal(loss_binary(data, label), loss_bce(data, label), atol=1e-6)
assert_almost_equal(loss_signed(data, 2 * label - 1), loss_bce(data, label), atol=1e-6)


@with_seed()
@use_np
def test_ctc_loss():
loss = gluon.loss.CTCLoss()
l = loss(mx.np.ones((2,20,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]))
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss(layout='TNC')
l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]))
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss(layout='TNC', label_layout='TN')
l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]).T)
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss()
l = loss(mx.np.ones((2,20,4)), mx.np.array([[2,1,2,2],[3,2,2,2]]), None, mx.np.array([2,3]))
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss()
l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,-1,-1],[3,2,2,-1]]), mx.np.array([20,20]))
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

loss = gluon.loss.CTCLoss()
l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,3,3],[3,2,2,3]]), mx.np.array([20,20]), mx.np.array([2,3]))
assert_almost_equal(l, np.array([18.82820702, 16.50581741]))


@xfail_when_nonstandard_decimal_separator
@with_seed()
@use_np
def test_sdml_loss():

N = 5 # number of samples
DIM = 10 # Dimensionality
EPOCHS = 20

# Generate randomized data and 'positive' samples
data = mx.np.random.uniform(-1, 1, size=(N, DIM))
pos = data + mx.np.random.uniform(-0.1, 0.1, size=(N, DIM)) # correlated paired data
data_iter = mx.io.NDArrayIter({'data' : data, 'pos' : pos}, batch_size=N)

# Init model and trainer
sdml_loss = gluon.loss.SDMLLoss()
model = gluon.nn.Dense(DIM, activation='tanh') # Simple NN encoder
model.initialize(mx.init.Xavier(), ctx=mx.current_context())
trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate' : 0.1})

for i in range(EPOCHS): # Training loop
data_iter.reset()
for iter_batch in data_iter:
batch = [datum.as_in_ctx(mx.current_context()).as_np_ndarray() for datum in iter_batch.data]
with autograd.record():
data, pos = batch
z_data, z_pos = model(data), model(pos)
loss = sdml_loss(z_data, z_pos)
loss.backward()
trainer.step(1)

# After training euclidean distance between aligned pairs should be lower than all non-aligned pairs
avg_loss = loss.sum()/len(loss)
assert(avg_loss < 0.05)

@with_seed()
@use_np
def test_cosine_loss():
#Generating samples
input1 = mx.np.random.randn(3, 2)
input2 = mx.np.random.randn(3, 2)
label = mx.np.sign(mx.np.random.randn(input1.shape[0]))
#Calculating loss from cosine embedding loss function in Gluon
Loss = gluon.loss.CosineEmbeddingLoss()
loss = Loss(input1, input2, label)

# Calculating the loss Numpy way
numerator = mx.np.sum(input1 * input2, keepdims=True, axis=1)
denominator = mx.np.sqrt(mx.np.sum(input1**2, axis=1, keepdims=True)) \
* mx.np.sqrt(mx.np.sum(input2**2, axis=1, keepdims=True))
x = numerator/denominator
label = mx.npx.reshape(label, (-1, 1))
numpy_loss = mx.npx.reshape(
mx.np.where(label == 1, 1-x, mx.npx.relu(x)), (-1,))
assert_almost_equal(loss.asnumpy(), numpy_loss.asnumpy(), rtol=1e-3, atol=1e-5)

@xfail_when_nonstandard_decimal_separator
@use_np
def test_poisson_nllloss():
shape=(3, 4)
not_axis0 = tuple(range(1, len(shape)))
pred = mx.np.random.normal(size=shape)
min_pred = mx.np.min(pred)
#This is necessary to ensure only positive random values are generated for prediction,
# to avoid ivalid log calculation
pred[:] = pred + mx.np.abs(min_pred)
target = mx.np.random.normal(size=shape)
min_target = mx.np.min(target)
#This is necessary to ensure only positive random values are generated for prediction,
# to avoid ivalid log calculation
target[:] += mx.np.abs(min_target)

Loss = gluon.loss.PoissonNLLLoss(from_logits=True)
Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False)
#Calculating by brute formula for default value of from_logits = True

# 1) Testing for flag logits = True
brute_loss = mx.np.mean(mx.np.exp(pred) - target * pred, axis=1)
loss_withlogits = Loss(pred, target)
assert_almost_equal(brute_loss, loss_withlogits)

#2) Testing for flag logits = False
loss_no_logits = Loss_no_logits(pred, target)
np_loss_no_logits = mx.np.mean(pred - target * mx.np.log(pred + 1e-08),
axis=1)
assert_almost_equal(np_loss_no_logits, loss_no_logits)

#3) Testing for Sterling approximation
shape=(2, 3)
np_pred = mx.np.random.uniform(1, 5, shape)
np_target = mx.np.random.uniform(1, 5, shape)
np_compute_full = mx.np.mean((np_pred - np_target * mx.np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\
np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)), axis=1)
Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
loss_compute_full = Loss_compute_full(np_pred, np_target)
assert_almost_equal(np_compute_full, loss_compute_full)

Loading