Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix indexing bug
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Sep 13, 2019
1 parent f9cb51e commit 1e2e7ac
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 11 deletions.
20 changes: 16 additions & 4 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,14 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None):
Note: mxnet.numpy.ndarray not support NDArray as assigned value.
"""
if isinstance(value, numeric_types):
value_nd = full(bcast_shape, value, ctx=self.context, dtype=self.dtype)
value_nd = full(bcast_shape, value, ctx=self.ctx, dtype=self.dtype)
elif isinstance(value, self.__class__):
value_nd = value.as_in_context(self.context)
value_nd = value.as_in_ctx(self.ctx)
if value_nd.dtype != self.dtype:
value_nd = value_nd.astype(self.dtype)
else:
try:
value_nd = array(value, ctx=self.context, dtype=self.dtype)
value_nd = array(value, ctx=self.ctx, dtype=self.dtype)
except:
raise TypeError('mxnet.np.ndarray does not support assignment with non-array-like '
'object {} of type {}'.format(value, type(value)))
Expand All @@ -330,14 +330,26 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None):
squeeze_axes = tuple([ax for ax in squeeze_axes if ax < len(value_nd.shape)])
value_nd = value_nd.squeeze(axis=tuple(squeeze_axes))

# handle the cases like the following
# a = np.zeros((3, 3)), b = np.ones((1, 1, 1, 1, 3)), a[0] = b
# b cannot broadcast directly to a[0].shape unless its leading 1-size axes are trimmed
if value_nd.ndim > len(bcast_shape):
squeeze_axes = []
for i in range(value_nd.ndim - len(bcast_shape)):
if value_nd.shape[i] == 1:
squeeze_axes.append(i)
else:
break
if squeeze_axes:
value_nd = value_nd.squeeze(squeeze_axes)

if value_nd.shape != bcast_shape:
if value_nd.size == 0:
value_nd = value_nd.reshape(bcast_shape)
else:
value_nd = value_nd.broadcast_to(bcast_shape)
return value_nd


def __add__(self, other):
"""x.__add__(y) <=> x + y"""
return add(self, other)
Expand Down
38 changes: 37 additions & 1 deletion tests/python/unittest/test_numpy_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
from __future__ import absolute_import
from __future__ import division

import numpy as _np
import mxnet as mx
from mxnet import gluon, autograd, np
from mxnet.test_utils import use_np
from mxnet.test_utils import use_np, assert_almost_equal
from common import with_seed


def test_create_np_param():
Expand Down Expand Up @@ -108,6 +110,40 @@ def hybrid_forward(self, F, pred, label):
trainer.step(1)


@with_seed()
@use_np
def test_np_loss_ndarray():
    """Verify gluon losses accept mxnet.numpy ndarrays.

    Ported from test_loss.test_loss_ndarray; checks L1Loss, L2Loss and
    SoftmaxCrossEntropyLoss with and without per-sample weighting.
    """
    pred = np.array([1, 2, 3, 4])
    target = np.array([1, 3, 5, 7])
    sample_weight = np.array([0.5, 1, 0.5, 1])

    # L1: sum(|pred - target|) == 6, halved by weight=0.5, re-weighted per sample.
    assert np.sum(gluon.loss.L1Loss()(pred, target)) == 6.
    assert np.sum(gluon.loss.L1Loss(weight=0.5)(pred, target)) == 3.
    assert np.sum(gluon.loss.L1Loss()(pred, target, sample_weight)) == 5.

    # L2: sum(0.5 * (pred - target)^2) == 7, scaled by weight / sample weights.
    assert np.sum(gluon.loss.L2Loss()(pred, target)) == 7.
    assert np.sum(gluon.loss.L2Loss(weight=0.25)(pred, target)) == 1.75
    assert np.sum(gluon.loss.L2Loss()(pred, target, sample_weight)) == 6

    pred = np.array([[0, 2], [1, 4]])
    target = np.array([0, 1])
    sample_weight = np.array([[0.5], [1.0]])

    ce_loss = gluon.loss.SoftmaxCrossEntropyLoss()
    unweighted = ce_loss(pred, target).asnumpy()
    assert_almost_equal(unweighted, _np.array([2.12692809, 0.04858733]),
                        use_broadcast=False)

    weighted = ce_loss(pred, target, sample_weight).asnumpy()
    assert_almost_equal(weighted, _np.array([1.06346405, 0.04858733]),
                        use_broadcast=False)


if __name__ == '__main__':
    # Allow running this test file directly; nose discovers and runs the
    # test_* functions defined in this module.
    import nose
    nose.runmodule()
28 changes: 22 additions & 6 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, assert_exception, use_np
from common import with_seed, TemporaryDirectory
from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf
from mxnet.ndarray.ndarray import py_slice
from mxnet.base import integer_types
import scipy.stats as ss


Expand Down Expand Up @@ -410,7 +412,7 @@ def test_np_ndarray_copy():
def test_np_ndarray_indexing():
def np_int(index, int_type=np.int32):
"""
Helper function for testing indexing that converts slices to slices of ints or None, and tuples to
Helper function for testing indexing that converts slices to slices of ints or None, and tuples to
tuples of ints or None.
"""
def convert(num):
Expand All @@ -432,7 +434,7 @@ def convert(num):
else:
assert False

# Copied from test_ndarray.py. Under construction.
# Copied from test_ndarray.py. Under construction.
def test_getitem(np_array, index):
np_index = index
if type(index) == mx.nd.NDArray: # use of NDArray is prohibited
Expand Down Expand Up @@ -470,6 +472,13 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None)

assert same(np_array, mx_array.asnumpy())

def _is_basic_index(index):
    """Return True iff *index* performs basic (non-advanced) indexing.

    Basic indexing means the index is an integer, a slice, or a tuple
    consisting solely of integers and slices.
    """
    basic = (integer_types, py_slice)
    if isinstance(index, tuple):
        return all(isinstance(item, basic) for item in index)
    return isinstance(index, basic)

np_index = index # keep this native numpy type
if isinstance(index, np.ndarray):
np_index = index.asnumpy()
Expand Down Expand Up @@ -498,6 +507,13 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None)
assert_same(np_array, np_index, mx_array, index, np.array(np_value))
# test native numpy array with broadcast
assert_same(np_array, np_index, mx_array, index, np_value)

# test value shape are expanded to be longer than index array's shape
# this is currently only supported in basic indexing
if _is_basic_index(index):
expanded_value_shape = (1, 1, 1) + np_value.shape
assert_same(np_array, np_index, mx_array, index, np.array(np_value.reshape(expanded_value_shape)))
assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape))
# test list with broadcast
assert_same(np_array, np_index, mx_array, index,
[_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1])
Expand Down Expand Up @@ -695,18 +711,18 @@ def test_setitem_autograd(np_array, index):
test_setitem(np_array, index)
test_getitem_autograd(np_array, index)
test_setitem_autograd(np_array, index)

# Test indexing to zero-size tensors
index_list = [
(slice(0, 0), slice(0, 0), 1, 2),
(slice(0, 0), slice(0, 0), slice(0, 0), slice(0, 0)),
(slice(0, 0), slice(0, 0), 1, 2),
(slice(0, 0), slice(0, 0), slice(0, 0), slice(0, 0)),
]
for index in index_list:
test_getitem(np_array, index)
test_setitem(np_array, index)
test_getitem_autograd(np_array, index)
test_setitem_autograd(np_array, index)

# test zero-size tensors get and setitem
shapes_indices = [
((0), [slice(None, None, None)]),
Expand Down

0 comments on commit 1e2e7ac

Please sign in to comment.