diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 80e549d40c8d..1f8aa92f9851 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -312,14 +312,14 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None): Note: mxnet.numpy.ndarray not support NDArray as assigned value. """ if isinstance(value, numeric_types): - value_nd = full(bcast_shape, value, ctx=self.context, dtype=self.dtype) + value_nd = full(bcast_shape, value, ctx=self.ctx, dtype=self.dtype) elif isinstance(value, self.__class__): - value_nd = value.as_in_context(self.context) + value_nd = value.as_in_ctx(self.ctx) if value_nd.dtype != self.dtype: value_nd = value_nd.astype(self.dtype) else: try: - value_nd = array(value, ctx=self.context, dtype=self.dtype) + value_nd = array(value, ctx=self.ctx, dtype=self.dtype) except: raise TypeError('mxnet.np.ndarray does not support assignment with non-array-like ' 'object {} of type {}'.format(value, type(value))) @@ -330,6 +330,19 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None): squeeze_axes = tuple([ax for ax in squeeze_axes if ax < len(value_nd.shape)]) value_nd = value_nd.squeeze(axis=tuple(squeeze_axes)) + # handle the cases like the following + # a = np.zeros((3, 3)), b = np.ones((1, 1, 1, 1, 3)), a[0] = b + # b cannot broadcast directly to a[0].shape unless its leading 1-size axes are trimmed + if value_nd.ndim > len(bcast_shape): + squeeze_axes = [] + for i in range(value_nd.ndim - len(bcast_shape)): + if value_nd.shape[i] == 1: + squeeze_axes.append(i) + else: + break + if squeeze_axes: + value_nd = value_nd.squeeze(squeeze_axes) + if value_nd.shape != bcast_shape: if value_nd.size == 0: value_nd = value_nd.reshape(bcast_shape) @@ -337,7 +350,6 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None): value_nd = value_nd.broadcast_to(bcast_shape) return value_nd - def __add__(self, other): """x.__add__(y) <=> x + y""" return add(self, other) diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index e96f829a8580..62ea38fc0c13 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -19,9 +19,11 @@ from __future__ import absolute_import from __future__ import division +import numpy as _np import mxnet as mx from mxnet import gluon, autograd, np -from mxnet.test_utils import use_np +from mxnet.test_utils import use_np, assert_almost_equal +from common import with_seed def test_create_np_param(): @@ -108,6 +110,40 @@ def hybrid_forward(self, F, pred, label): trainer.step(1) +@with_seed() +@use_np +def test_np_loss_ndarray(): + # Ported from test_loss.test_loss_ndarray + output = np.array([1, 2, 3, 4]) + label = np.array([1, 3, 5, 7]) + weighting = np.array([0.5, 1, 0.5, 1]) + + loss = gluon.loss.L1Loss() + assert np.sum(loss(output, label)) == 6. + loss = gluon.loss.L1Loss(weight=0.5) + assert np.sum(loss(output, label)) == 3. + loss = gluon.loss.L1Loss() + assert np.sum(loss(output, label, weighting)) == 5. + + loss = gluon.loss.L2Loss() + assert np.sum(loss(output, label)) == 7. + loss = gluon.loss.L2Loss(weight=0.25) + assert np.sum(loss(output, label)) == 1.75 + loss = gluon.loss.L2Loss() + assert np.sum(loss(output, label, weighting)) == 6 + + output = np.array([[0, 2], [1, 4]]) + label = np.array([0, 1]) + weighting = np.array([[0.5], [1.0]]) + + loss = gluon.loss.SoftmaxCrossEntropyLoss() + L = loss(output, label).asnumpy() + assert_almost_equal(L, _np.array([2.12692809, 0.04858733]), use_broadcast=False) + + L = loss(output, label, weighting).asnumpy() + assert_almost_equal(L, _np.array([1.06346405, 0.04858733]), use_broadcast=False) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 3b0ddf3e3fc9..bffa7a00dccb 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -26,6 +26,8 @@ from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, assert_exception, use_np from common import with_seed, TemporaryDirectory from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf +from mxnet.ndarray.ndarray import py_slice +from mxnet.base import integer_types import scipy.stats as ss @@ -410,7 +412,7 @@ def test_np_ndarray_copy(): def test_np_ndarray_indexing(): def np_int(index, int_type=np.int32): """ - Helper function for testing indexing that converts slices to slices of ints or None, and tuples to + Helper function for testing indexing that converts slices to slices of ints or None, and tuples to tuples of ints or None. """ def convert(num): @@ -432,7 +434,7 @@ def convert(num): else: assert False - # Copied from test_ndarray.py. Under construction. + # Copied from test_ndarray.py. Under construction. def test_getitem(np_array, index): np_index = index if type(index) == mx.nd.NDArray: # use of NDArray is prohibited @@ -470,6 +472,13 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None) assert same(np_array, mx_array.asnumpy()) + def _is_basic_index(index): + if isinstance(index, (integer_types, py_slice)): + return True + if isinstance(index, tuple) and all(isinstance(i, (integer_types, py_slice)) for i in index): + return True + return False + np_index = index # keep this native numpy type if isinstance(index, np.ndarray): np_index = index.asnumpy() @@ -498,6 +507,13 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None) assert_same(np_array, np_index, mx_array, index, np.array(np_value)) # test native numpy array with broadcast assert_same(np_array, np_index, mx_array, index, np_value) + + # test value shape are expanded to be longer than index array's shape + # this is currently only supported in basic indexing + if _is_basic_index(index): + expanded_value_shape = (1, 1, 1) + np_value.shape + assert_same(np_array, np_index, mx_array, index, np.array(np_value.reshape(expanded_value_shape))) + assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape)) # test list with broadcast assert_same(np_array, np_index, mx_array, index, [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1]) @@ -695,18 +711,18 @@ def test_setitem_autograd(np_array, index): test_setitem(np_array, index) test_getitem_autograd(np_array, index) test_setitem_autograd(np_array, index) - + # Test indexing to zero-size tensors index_list = [ - (slice(0, 0), slice(0, 0), 1, 2), - (slice(0, 0), slice(0, 0), slice(0, 0), slice(0, 0)), + (slice(0, 0), slice(0, 0), 1, 2), + (slice(0, 0), slice(0, 0), slice(0, 0), slice(0, 0)), ] for index in index_list: test_getitem(np_array, index) test_setitem(np_array, index) test_getitem_autograd(np_array, index) test_setitem_autograd(np_array, index) - + # test zero-size tensors get and setitem shapes_indices = [ ((0), [slice(None, None, None)]),