From 03fdeeac8f7e441e107ef9e757b1f6f8356bf2ef Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Fri, 7 Aug 2020 11:27:02 -0700
Subject: [PATCH 1/7] Test for hybridization

---
 python/mxnet/gluon/loss.py               | 22 ++++--
 tests/python/unittest/test_numpy_loss.py | 94 ++++++++++++++++++++++--
 2 files changed, 103 insertions(+), 13 deletions(-)

diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index bc447b0f1c55..8feb85a4324a 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -77,18 +77,28 @@ def _reshape_like(F, x, y):

 def _batch_mean(F, loss, batch_axis):
     """Return mean on the specified batch axis, not keeping the axis"""
     if is_np_array():
-        axes = list(range(loss.ndim))
-        del axes[batch_axis]
-        return F.np.mean(loss, axis=axes)
+        if F is ndarray:
+            axes = list(range(loss.ndim))
+            del axes[batch_axis]
+            return F.np.mean(loss, axis=axes)
+        else:
+            assert batch_axis == 0, 'Currently, the "exclude" flag in mean is not ' \
+                                    'supported, so only batch_axis=0 is supported.'
+            return F.npx.batch_flatten(loss).mean(axis=1)
     else:
         return F.mean(loss, axis=batch_axis, exclude=True)

 def _batch_sum(F, loss, batch_axis):
     """Return sum on the specified batch axis, not keeping the axis"""
     if is_np_array():
-        axes = list(range(loss.ndim))
-        del axes[batch_axis]
-        return F.np.sum(loss, axis=axes)
+        if F is ndarray:
+            axes = list(range(loss.ndim))
+            del axes[batch_axis]
+            return F.np.sum(loss, axis=axes)
+        else:
+            assert batch_axis == 0, 'Currently, the "exclude" flag in sum is not ' \
+                                    'supported, so only batch_axis=0 is supported.'
+            return F.npx.batch_flatten(loss).sum(axis=1)
     else:
         return F.sum(loss, axis=batch_axis, exclude=True)

diff --git a/tests/python/unittest/test_numpy_loss.py b/tests/python/unittest/test_numpy_loss.py
index 6c63546f85b1..a10b54412478 100644
--- a/tests/python/unittest/test_numpy_loss.py
+++ b/tests/python/unittest/test_numpy_loss.py
@@ -20,57 +20,94 @@
 from mxnet import gluon, autograd
 from mxnet.test_utils import assert_almost_equal, default_context, use_np
 from common import setup_module, with_seed, teardown_module, xfail_when_nonstandard_decimal_separator
-import unittest
+import pytest


 @xfail_when_nonstandard_decimal_separator
 @with_seed()
 @use_np
-def test_loss_np_ndarray():
+@pytest.mark.parameterize(hybridize, [False, True])
+def test_loss_np_ndarray(hybridize):
     output = mx.np.array([1, 2, 3, 4])
     label = mx.np.array([1, 3, 5, 7])
     weighting = mx.np.array([0.5, 1, 0.5, 1])

     loss = gluon.loss.L1Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 6.
     loss = gluon.loss.L1Loss(weight=0.5)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 3.
     loss = gluon.loss.L1Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 5.

     loss = gluon.loss.L2Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 7.
     loss = gluon.loss.L2Loss(weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 1.75
     loss = gluon.loss.L2Loss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 6

     loss = gluon.loss.HuberLoss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 4.5
     loss = gluon.loss.HuberLoss(weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 1.125
     loss = gluon.loss.HuberLoss()
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 3.75

     loss = gluon.loss.HingeLoss(margin=10)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 13.
     loss = gluon.loss.HingeLoss(margin=8, weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 2.25
     loss = gluon.loss.HingeLoss(margin=7)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 4.

     loss = gluon.loss.SquaredHingeLoss(margin=10)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 97.
     loss = gluon.loss.SquaredHingeLoss(margin=8, weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label)) == 13.25
     loss = gluon.loss.SquaredHingeLoss(margin=7)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, weighting)) == 19.

     loss = gluon.loss.TripletLoss(margin=10)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, -label)) == 6.
     loss = gluon.loss.TripletLoss(margin=8, weight=0.25)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, -label)) == 1.
     loss = gluon.loss.TripletLoss(margin=7)
+    if hybridize:
+        loss.hybridize()
     assert mx.np.sum(loss(output, label, -label, weighting)) == 1.5

     output = mx.np.array([[0, 2], [1, 4]])
@@ -78,6 +115,8 @@ def test_loss_np_ndarray(hybridize):
     weighting = mx.np.array([[0.5], [1.0]])

     loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    if hybridize:
+        loss.hybridize()
     L = loss(output, label).asnumpy()
     assert_almost_equal(L, np.array([ 2.12692809, 0.04858733]), rtol=1e-3, atol=1e-4)

@@ -87,21 +126,34 @@

 @with_seed()
 @use_np
-def test_bce_equal_ce2():
+@pytest.mark.parameterize(hybridize, [False, True])
+def test_bce_equal_ce2(hybridize):
     N = 100
     loss1 = gluon.loss.SigmoidBCELoss(from_sigmoid=True)
+    if hybridize:
+        loss1.hybridize()
     loss2 = gluon.loss.SoftmaxCELoss(from_logits=True)
+    if hybridize:
+        loss1.hybridize()
     out1 = mx.np.random.uniform(0.1, 0.9, size=(N, 1))
     out2 = mx.np.log(mx.np.concatenate((1-out1, out1), axis=1) + 1e-8)
     label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1)))
     assert_almost_equal(loss1(out1, label).asnumpy(), loss2(out2, label).asnumpy())
+

 @use_np
-def test_logistic_loss_equal_bce():
+@pytest.mark.parameterize(hybridize, [False, True])
+def test_logistic_loss_equal_bce(hybridize):
     N = 100
     loss_binary = gluon.loss.LogisticLoss(label_format='binary')
+    if hybridize:
+        loss_binary.hybridize()
     loss_signed = gluon.loss.LogisticLoss(label_format='signed')
+    if hybridize:
+        loss_signed.hybridize()
     loss_bce = gluon.loss.SigmoidBCELoss(from_sigmoid=False)
+    if hybridize:
+        loss_bce.hybridize()
     data = mx.np.random.uniform(-10, 10, size=(N, 1))
     label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1)))
     assert_almost_equal(loss_binary(data, label), loss_bce(data, label), atol=1e-6)
@@ -110,28 +162,41 @@

 @with_seed()
 @use_np
-def test_ctc_loss():
+@pytest.mark.parameterize(hybridize, [False, True])
+def test_ctc_loss(hybridize):
     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,20,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

     loss = gluon.loss.CTCLoss(layout='TNC')
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

     loss = gluon.loss.CTCLoss(layout='TNC', label_layout='TN')
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((20,2,4)), mx.np.array([[1,0,-1,-1],[2,1,1,-1]]).T)
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,20,4)), mx.np.array([[2,1,2,2],[3,2,2,2]]), None, mx.np.array([2,3]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,-1,-1],[3,2,2,-1]]), mx.np.array([20,20]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

     loss = gluon.loss.CTCLoss()
+    if hybridize:
+        loss.hybridize()
     l = loss(mx.np.ones((2,25,4)), mx.np.array([[2,1,3,3],[3,2,2,3]]), mx.np.array([20,20]), mx.np.array([2,3]))
     assert_almost_equal(l, np.array([18.82820702, 16.50581741]))

@@ -139,6 +204,7 @@
 @xfail_when_nonstandard_decimal_separator
 @with_seed()
 @use_np
+@pytest.mark.parameterize(hybridize, [False, True])
 def test_sdml_loss():
     N = 5 # number of samples
     DIM = 10 # Dimensionality
@@ -152,6 +218,8 @@ def test_sdml_loss():

     # Init model and trainer
     sdml_loss = gluon.loss.SDMLLoss()
+    if hybridize:
+        sdml_loss.hybridize()
     model = gluon.nn.Dense(DIM, activation='tanh') # Simple NN encoder
     model.initialize(mx.init.Xavier(), ctx=mx.current_context())
     trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate' : 0.1})
@@ -171,15 +239,19 @@ def test_sdml_loss():
     avg_loss = loss.sum()/len(loss)
     assert(avg_loss < 0.05)
+

 @with_seed()
 @use_np
-def test_cosine_loss():
+@pytest.mark.parameterize(hybridize, [False, True])
+def test_cosine_loss(hybridize):
     #Generating samples
     input1 = mx.np.random.randn(3, 2)
     input2 = mx.np.random.randn(3, 2)
     label = mx.np.sign(mx.np.random.randn(input1.shape[0]))
     #Calculating loss from cosine embedding loss function in Gluon
     Loss = gluon.loss.CosineEmbeddingLoss()
+    if hybridize:
+        Loss.hybridize()
     loss = Loss(input1, input2, label)

     # Calculating the loss Numpy way
@@ -192,9 +264,11 @@ def test_cosine_loss():
         mx.np.where(label == 1, 1-x, mx.npx.relu(x)), (-1,))
     assert_almost_equal(loss.asnumpy(), numpy_loss.asnumpy(), rtol=1e-3, atol=1e-5)
+

 @xfail_when_nonstandard_decimal_separator
 @use_np
-def test_poisson_nllloss():
+@pytest.mark.parameterize(hybridize, [False, True])
+def test_poisson_nllloss(hybridize):
     shape=(3, 4)
     not_axis0 = tuple(range(1, len(shape)))
     pred = mx.np.random.normal(size=shape)
@@ -209,7 +283,11 @@ def test_poisson_nllloss():
     target[:] += mx.np.abs(min_target)

     Loss = gluon.loss.PoissonNLLLoss(from_logits=True)
+    if hybridize:
+        Loss.hybridize()
     Loss_no_logits = gluon.loss.PoissonNLLLoss(from_logits=False)
+    if hybridize:
+        Loss_no_logits.hybridize()
     #Calculating by brute formula for default value of from_logits = True

     # 1) Testing for flag logits = True
@@ -230,6 +308,8 @@ def test_poisson_nllloss():
     np_compute_full = mx.np.mean((np_pred - np_target * mx.np.log(np_pred + 1e-08)) + ((np_target * np.log(np_target)-\
     np_target + 0.5 * np.log(2 * np_target * np.pi))*(np_target > 1)), axis=1)
     Loss_compute_full = gluon.loss.PoissonNLLLoss(from_logits=False, compute_full=True)
+    if hybridize:
+        Loss_compute_full.hybridize()
     loss_compute_full = Loss_compute_full(np_pred, np_target)
     assert_almost_equal(np_compute_full, loss_compute_full)

From 50ab104a7737e455b80f0691d3907a50fe380586 Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Fri, 7 Aug 2020 14:00:52 -0700
Subject: [PATCH 2/7] fix typo

---
 tests/python/unittest/test_numpy_loss.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/tests/python/unittest/test_numpy_loss.py b/tests/python/unittest/test_numpy_loss.py
index a10b54412478..771639449c85 100644
--- a/tests/python/unittest/test_numpy_loss.py
+++ b/tests/python/unittest/test_numpy_loss.py
@@ -26,7 +26,7 @@
 @xfail_when_nonstandard_decimal_separator
 @with_seed()
 @use_np
-@pytest.mark.parameterize(hybridize, [False, True])
+@pytest.mark.parametrize(hybridize, [False, True])
 def test_loss_np_ndarray(hybridize):
     output = mx.np.array([1, 2, 3, 4])
     label = mx.np.array([1, 3, 5, 7])
     weighting = mx.np.array([0.5, 1, 0.5, 1])
@@ -126,7 +126,7 @@ def test_loss_np_ndarray(hybridize):

 @with_seed()
 @use_np
-@pytest.mark.parameterize(hybridize, [False, True])
+@pytest.mark.parametrize(hybridize, [False, True])
 def test_bce_equal_ce2(hybridize):
     N = 100
     loss1 = gluon.loss.SigmoidBCELoss(from_sigmoid=True)
@@ -142,7 +142,7 @@ def test_bce_equal_ce2(hybridize):

 @use_np
-@pytest.mark.parameterize(hybridize, [False, True])
+@pytest.mark.parametrize(hybridize, [False, True])
 def test_logistic_loss_equal_bce(hybridize):
     N = 100
     loss_binary = gluon.loss.LogisticLoss(label_format='binary')
@@ -162,7 +162,7 @@ def test_logistic_loss_equal_bce(hybridize):

 @with_seed()
 @use_np
-@pytest.mark.parameterize(hybridize, [False, True])
+@pytest.mark.parametrize(hybridize, [False, True])
 def test_ctc_loss(hybridize):
     loss = gluon.loss.CTCLoss()
     if hybridize:
@@ -204,7 +204,7 @@ def test_ctc_loss(hybridize):
 @xfail_when_nonstandard_decimal_separator
 @with_seed()
 @use_np
-@pytest.mark.parameterize(hybridize, [False, True])
+@pytest.mark.parametrize(hybridize, [False, True])
 def test_sdml_loss():
     N = 5 # number of samples
     DIM = 10 # Dimensionality
@@ -242,7 +242,7 @@ def test_sdml_loss():

 @with_seed()
 @use_np
-@pytest.mark.parameterize(hybridize, [False, True])
+@pytest.mark.parametrize(hybridize, [False, True])
 def test_cosine_loss(hybridize):
     #Generating samples
     input1 = mx.np.random.randn(3, 2)
     input2 = mx.np.random.randn(3, 2)
@@ -267,7 +267,7 @@ def test_cosine_loss(hybridize):

 @xfail_when_nonstandard_decimal_separator
 @use_np
-@pytest.mark.parameterize(hybridize, [False, True])
+@pytest.mark.parametrize(hybridize, [False, True])
 def test_poisson_nllloss(hybridize):
     shape=(3, 4)
     not_axis0 = tuple(range(1, len(shape)))

From 5e1e6013993b9db1c5f86195b568d6efe63f8c35 Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Fri, 7 Aug 2020 14:01:50 -0700
Subject: [PATCH 3/7] fix

---
 tests/python/unittest/test_numpy_loss.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/tests/python/unittest/test_numpy_loss.py b/tests/python/unittest/test_numpy_loss.py
index 771639449c85..5545c09c3afa 100644
--- a/tests/python/unittest/test_numpy_loss.py
+++ b/tests/python/unittest/test_numpy_loss.py
@@ -26,7 +26,7 @@
 @xfail_when_nonstandard_decimal_separator
 @with_seed()
 @use_np
-@pytest.mark.parametrize(hybridize, [False, True])
+@pytest.mark.parametrize("hybridize", [False, True])
 def test_loss_np_ndarray(hybridize):
     output = mx.np.array([1, 2, 3, 4])
     label = mx.np.array([1, 3, 5, 7])
     weighting = mx.np.array([0.5, 1, 0.5, 1])
@@ -126,7 +126,7 @@ def test_loss_np_ndarray(hybridize):

 @with_seed()
 @use_np
-@pytest.mark.parametrize(hybridize, [False, True])
+@pytest.mark.parametrize("hybridize", [False, True])
 def test_bce_equal_ce2(hybridize):
     N = 100
     loss1 = gluon.loss.SigmoidBCELoss(from_sigmoid=True)
@@ -142,7 +142,7 @@ def test_bce_equal_ce2(hybridize):

 @use_np
-@pytest.mark.parametrize(hybridize, [False, True])
+@pytest.mark.parametrize("hybridize", [False, True])
 def test_logistic_loss_equal_bce(hybridize):
     N = 100
     loss_binary = gluon.loss.LogisticLoss(label_format='binary')
@@ -162,7 +162,7 @@ def test_logistic_loss_equal_bce(hybridize):

 @with_seed()
 @use_np
-@pytest.mark.parametrize(hybridize, [False, True])
+@pytest.mark.parametrize("hybridize", [False, True])
 def test_ctc_loss(hybridize):
     loss = gluon.loss.CTCLoss()
     if hybridize:
@@ -204,8 +204,8 @@ def test_ctc_loss(hybridize):
 @xfail_when_nonstandard_decimal_separator
 @with_seed()
 @use_np
-@pytest.mark.parametrize(hybridize, [False, True])
-def test_sdml_loss():
+@pytest.mark.parametrize("hybridize", [False, True])
+def test_sdml_loss(hybridize):
     N = 5 # number of samples
     DIM = 10 # Dimensionality
@@ -242,7 +242,7 @@ def test_sdml_loss():

 @with_seed()
 @use_np
-@pytest.mark.parametrize(hybridize, [False, True])
+@pytest.mark.parametrize("hybridize", [False, True])
 def test_cosine_loss(hybridize):
     #Generating samples
     input1 = mx.np.random.randn(3, 2)
@@ -267,7 +267,7 @@ def test_cosine_loss(hybridize):

 @xfail_when_nonstandard_decimal_separator
 @use_np
-@pytest.mark.parametrize(hybridize, [False, True])
+@pytest.mark.parametrize("hybridize", [False, True])
 def test_poisson_nllloss(hybridize):
     shape=(3, 4)
     not_axis0 = tuple(range(1, len(shape)))

From 7f3371c389d7e5728dbe5579c6f4e2014132524e Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Fri, 7 Aug 2020 14:07:32 -0700
Subject: [PATCH 4/7] fix test

---
 tests/python/unittest/test_numpy_loss.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/python/unittest/test_numpy_loss.py b/tests/python/unittest/test_numpy_loss.py
index 5545c09c3afa..4f33f344c121 100644
--- a/tests/python/unittest/test_numpy_loss.py
+++ b/tests/python/unittest/test_numpy_loss.py
@@ -120,6 +120,9 @@ def test_loss_np_ndarray(hybridize):
     L = loss(output, label).asnumpy()
     assert_almost_equal(L, np.array([ 2.12692809, 0.04858733]), rtol=1e-3, atol=1e-4)

+    loss = gluon.loss.SoftmaxCrossEntropyLoss()
+    if hybridize:
+        loss.hybridize()
     L = loss(output, label, weighting).asnumpy()
     assert_almost_equal(L, np.array([ 1.06346405, 0.04858733]), rtol=1e-3, atol=1e-4)

From 643bab5668beb5d087c75881586d39db23348e2f Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Fri, 7 Aug 2020 15:27:30 -0700
Subject: [PATCH 5/7] update

---
 python/mxnet/gluon/loss.py               | 17 +++++++----------
 tests/python/unittest/test_numpy_loss.py |  7 ++-----
 2 files changed, 9 insertions(+), 15 deletions(-)

diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 8feb85a4324a..767f7c3b2570 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -1033,7 +1033,8 @@ class SDMLLoss(Loss):
     def __init__(self, smoothing_parameter=0.3, weight=1., batch_axis=0, **kwargs):
         super(SDMLLoss, self).__init__(weight, batch_axis, **kwargs)
         self.kl_loss = KLDivLoss(from_logits=True)
-        self.smoothing_parameter = smoothing_parameter # Smoothing probability mass
+        # Smoothing probability mass
+        self.smoothing_parameter = smoothing_parameter

     def _compute_distances(self, F, x1, x2):
         """
@@ -1042,17 +1043,13 @@ def _compute_distances(self, F, x1, x2):
         """
         if is_np_array():
             expand_dims_fn = F.np.expand_dims
-            broadcast_to_fn = F.np.broadcast_to
         else:
             expand_dims_fn = F.expand_dims
-            broadcast_to_fn = F.broadcast_to
-
-        # extracting sizes expecting [batch_size, dim]
-        assert x1.shape == x2.shape
-        batch_size, dim = x1.shape
-        # expanding both tensor form [batch_size, dim] to [batch_size, batch_size, dim]
-        x1_ = broadcast_to_fn(expand_dims_fn(x1, 1), [batch_size, batch_size, dim])
-        x2_ = broadcast_to_fn(expand_dims_fn(x2, 0), [batch_size, batch_size, dim])
+
+        # expanding x1 from [batch_size, dim] to [batch_size, 1, dim]
+        # and x2 to [1, batch_size, dim]
+        x1_ = expand_dims_fn(x1, 1)
+        x2_ = expand_dims_fn(x2, 0)
         # pointwise squared differences
         squared_diffs = (x1_ - x2_)**2
         # sum of squared differences distance

diff --git a/tests/python/unittest/test_numpy_loss.py b/tests/python/unittest/test_numpy_loss.py
index 4f33f344c121..14f46f0b4a76 100644
--- a/tests/python/unittest/test_numpy_loss.py
+++ b/tests/python/unittest/test_numpy_loss.py
@@ -137,7 +137,7 @@ def test_bce_equal_ce2(hybridize):
         loss1.hybridize()
     loss2 = gluon.loss.SoftmaxCELoss(from_logits=True)
     if hybridize:
-        loss1.hybridize()
+        loss2.hybridize()
     out1 = mx.np.random.uniform(0.1, 0.9, size=(N, 1))
     out2 = mx.np.log(mx.np.concatenate((1-out1, out1), axis=1) + 1e-8)
     label = mx.np.round(mx.np.random.uniform(0, 1, size=(N, 1)))
@@ -207,8 +207,7 @@ def test_ctc_loss(hybridize):
 @xfail_when_nonstandard_decimal_separator
 @with_seed()
 @use_np
-@pytest.mark.parametrize("hybridize", [False, True])
-def test_sdml_loss(hybridize):
+def test_sdml_loss():
     N = 5 # number of samples
     DIM = 10 # Dimensionality
@@ -221,8 +220,6 @@ def test_sdml_loss():

     # Init model and trainer
     sdml_loss = gluon.loss.SDMLLoss()
-    if hybridize:
-        sdml_loss.hybridize()
     model = gluon.nn.Dense(DIM, activation='tanh') # Simple NN encoder
     model.initialize(mx.init.Xavier(), ctx=mx.current_context())
     trainer = gluon.Trainer(model.collect_params(), 'adam', {'learning_rate' : 0.1})

From dd0f39e1cc250000cef2aa39a655844033627e7f Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Fri, 7 Aug 2020 15:41:10 -0700
Subject: [PATCH 6/7] Update loss.py

---
 python/mxnet/gluon/loss.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index 767f7c3b2570..e707cd5fdfeb 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -909,8 +909,8 @@ def hybrid_forward(self, F, pred, target, sample_weight=None, epsilon=1e-08):
             stirling_factor = target * \
                 log_fn(target) - target + 0.5 * log_fn(2 * target * np.pi)
             target_gt_1 = target > 1
-            stirling_factor *= target_gt_1
-            loss += stirling_factor
+            stirling_factor = stirling_factor * target_gt_1
+            loss = loss + stirling_factor
         loss = _apply_weighting(F, loss, self._weight, sample_weight)
         return _batch_mean(F, loss, self._batch_axis)

From 48a0d369e33b3fe1bf3f8e2e88ec0f7f39c01866 Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Fri, 7 Aug 2020 16:02:57 -0700
Subject: [PATCH 7/7] fix bug of sum

---
 python/mxnet/gluon/loss.py           | 2 +-
 python/mxnet/symbol/numpy/_symbol.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/loss.py b/python/mxnet/gluon/loss.py
index e707cd5fdfeb..75d8981bd02a 100644
--- a/python/mxnet/gluon/loss.py
+++ b/python/mxnet/gluon/loss.py
@@ -1080,7 +1080,6 @@ def _compute_labels(self, F, batch_size):
         labels = gold * (1 - self.smoothing_parameter) + (1 - gold) * self.smoothing_parameter / (batch_size - 1)
         return labels

-
     def hybrid_forward(self, F, x1, x2):
         """
         the function computes the kl divergence between the negative distances
@@ -1099,6 +1098,7 @@ def hybrid_forward(self, F, x1, x2):
         learn to predict french president comparing it with all the other
         vectors in batch 2
         """
+        assert F is ndarray, 'SDMLLoss does not support symbolic mode'
         if is_np_array():
             log_softmax_fn = F.npx.log_softmax
         else:

diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 2df7357e0c03..ee465442d159 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -8012,7 +8012,7 @@ def diagonal(a, offset=0, axis1=0, axis2=1):

 # pylint:disable=redefined-outer-name, too-many-arguments
 @set_module('mxnet.symbol.numpy')
-def sum(a, axis=None, dtype=None, out=None, keepdims=None, initial=None, where=None):
+def sum(a, axis=None, dtype=None, out=None, keepdims=False, initial=None, where=None):
     r"""
     Sum of array elements over a given axis.
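
Note on the symbolic branch added in PATCH 1/7: once a loss is hybridized, hybrid_forward receives Symbols instead of ndarrays, and a Symbol's rank (loss.ndim) is unknown before shape inference, so the per-axis reduction used on the imperative path cannot be built. The sketch below is not part of the series (the block and variable names are illustrative); it shows the npx.batch_flatten workaround in isolation: collapse every non-batch axis into one, then reduce along axis 1.

import mxnet as mx
from mxnet.gluon import HybridBlock

mx.npx.set_np()  # enable the NumPy-compatible array interface

class MeanOverNonBatchAxes(HybridBlock):
    def hybrid_forward(self, F, loss):
        # Valid for both F = mx.ndarray and F = mx.symbol, because it never
        # inspects loss.ndim: (N, ...) -> (N, prod of remaining axes) -> (N,)
        return F.npx.batch_flatten(loss).mean(axis=1)

block = MeanOverNonBatchAxes()
block.hybridize()  # exercise the symbolic code path
x = mx.np.arange(24).reshape((2, 3, 4))
print(block(x))  # same values as x.reshape((2, -1)).mean(axis=1)

This is also why the asserts in _batch_mean/_batch_sum restrict the symbolic path to batch_axis=0: batch_flatten always treats axis 0 as the batch axis.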
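
Note on PATCH 6/7: rewriting `stirling_factor *= target_gt_1` and `loss += stirling_factor` as plain rebinding is a hybridization fix, not a style change. MXNet Symbols are immutable graph nodes and reject in-place arithmetic, which only surfaces once PoissonNLLLoss is hybridized. A minimal reproduction, assuming the legacy symbol API behaves as in MXNet 1.x:

import mxnet as mx

a = mx.sym.var('a')
b = mx.sym.var('b')
try:
    a *= b  # Symbol implements the in-place operators by raising an error
except Exception as err:  # MXNet raises its own NotImplemented-style error here
    print('in-place op rejected:', type(err).__name__)
a = a * b  # rebinding builds a new graph node instead -- the pattern PATCH 6/7 adopts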