Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1206] Support NDArray indexing with None and Ellipsis #13143

Merged
merged 48 commits into from
Aug 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
25f0910
Support NDArray indexing with None and Ellipsis
kohr-h Nov 6, 2018
a23078f
Update NDArray.__setitem__ docs with None and Ellipsis
kohr-h Nov 6, 2018
e602b20
Fix boolean flag in NDArray.__getitem__, add doctests
kohr-h Nov 7, 2018
b1876e7
Add setitem test for None and Ellipsis
kohr-h Nov 7, 2018
8da6e00
Fix wrong slice used, add cases to test_indexing
kohr-h Nov 8, 2018
153cd35
Revamp NDArray.__getitem__ and __setitem__
kohr-h Nov 14, 2018
3b9a3b1
Fix typo in error message of SetSliceOpOutputDimSize
kohr-h Nov 14, 2018
6911806
Fix setting of array with integer indices
kohr-h Nov 15, 2018
378a1de
Fix basic __setitem__ for all test cases
kohr-h Nov 17, 2018
8dada74
WIP: fixing advanced indexing
kohr-h Nov 26, 2018
95cfc82
REMOVE: printing in tests
kohr-h Nov 26, 2018
8e748d7
Re-implement advanced indexing with None and Ellipsis
kohr-h Dec 13, 2018
c7e3829
Fix lint errors
kohr-h Dec 13, 2018
57f26b7
WIP: fix basic indexing
kohr-h Dec 23, 2018
99bca07
WIP: fix basic indexing
kohr-h Dec 25, 2018
f9c048a
TEMP: print statements in tests
kohr-h Dec 25, 2018
bb57638
Fix op.slice with step<0 and end==-1
kohr-h Dec 25, 2018
e3d9921
Implement copy-free general contiguous indexing
kohr-h Dec 26, 2018
00ffcd4
Improve doctest of __getitem__
kohr-h Dec 27, 2018
89ea383
Fix missing staticmethod
kohr-h Dec 27, 2018
aba2fec
Remove superfluous _at and _slice
kohr-h Dec 27, 2018
7044242
Fix lint errors
kohr-h Dec 27, 2018
faed019
WIP: basic indexing
kohr-h Feb 8, 2019
84fce03
Remove print statements from tests
kohr-h Mar 12, 2019
b03c301
Fix call into op.slice in basic indexing, add doctest
kohr-h Mar 12, 2019
49c672e
Print failing index in setitem tests
kohr-h Mar 13, 2019
b4aac8a
Simplify implementation of advanced index broadcasting
kohr-h Mar 13, 2019
846ac5a
Better printing for failing setitem tests
kohr-h Mar 14, 2019
10bb038
Remove list indexing restriction, fix value shape broadcasting
kohr-h Mar 14, 2019
0e06ad3
Fix bad string formatting
kohr-h Mar 14, 2019
177cb14
Fix bug in test_uniform
kohr-h Mar 18, 2019
6f51679
Test mutability of sliced array if contiguous
kohr-h Mar 18, 2019
104d4d9
Fix whitespace error in matrix_op-inl.h
kohr-h Apr 9, 2019
8eddf2d
"Fix" pylint complaints
kohr-h Apr 9, 2019
4b0a178
Temporarily disable failing unittests
kohr-h Apr 9, 2019
5733cdd
Silence another pylint complaint
kohr-h Apr 9, 2019
0331f50
Fix size-0 array creation
kohr-h Apr 13, 2019
393c808
Make scalar tensor assignment test check for IndexError
kohr-h Apr 13, 2019
7367bdc
Re-activate operator tests with 0-size arrays
kohr-h Apr 13, 2019
84ba227
Use np.compat in tests with zeros in shape or empty shape
kohr-h May 4, 2019
2f16bb3
Change comment in autograd indexing test
kohr-h Jun 11, 2019
4e82100
Add more None-containing index tuples to indexing test
kohr-h Jun 11, 2019
d312064
Disable None in advanced indexing test since it has not been supported
reminisce Aug 3, 2019
f4d2af0
Fix sanity
reminisce Aug 3, 2019
4234412
Fix ci
reminisce Aug 3, 2019
fe6336d
Fix unit test failure
reminisce Aug 5, 2019
1b80049
Merge branch 'master' into indexing_none_ellipsis
reminisce Aug 8, 2019
2e5ffe2
Fix __getitem__
reminisce Aug 8, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,137 changes: 751 additions & 386 deletions python/mxnet/ndarray/ndarray.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __getitem__(self, key):
if key.step is not None and key.step != 1:
if key.step == 0:
raise ValueError("slice step cannot be zero")
return self.as_nd_ndarray()._get_nd_basic_indexing(key).as_np_ndarray()
return self.as_nd_ndarray().__getitem__(key).as_np_ndarray()
elif key.start is not None or key.stop is not None:
return self._slice(key.start, key.stop)
else:
Expand Down Expand Up @@ -157,7 +157,7 @@ def __setitem__(self, key, value):
value = value.as_nd_ndarray()
# TODO(junwu): Better handling of this situation
if isinstance(key, tuple) and len(key) == 0:
self.as_nd_ndarray().__setitem__(slice(None), value)
self.as_nd_ndarray().__setitem__(key, value)
return

if isinstance(key, ndarray):
Expand Down
6 changes: 5 additions & 1 deletion python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,11 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,

executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states)
for g in executor.grad_arrays:
g[:] = 0
print(g.shape)
if g.ndim == 0:
g[()] = 0
else:
g[:] = 0

executor.forward(is_train=False)

Expand Down
9 changes: 6 additions & 3 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,10 @@ inline void GetIndexRange(const mxnet::TShape& dshape,

// checking upper and lower bounds for end
if (e < 0 && param_end[i].has_value()) {
e += len;
if (!(s < 0 && e == -1)) {
// Keep end=-1 as one-beyond-limits index for negative stride
e += len;
}
CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len
<< " exceeds limit of input dimension[" << i << "]=" << len;
}
Expand Down Expand Up @@ -740,11 +743,11 @@ inline void SetSliceOpOutputDimSize(const index_t i, const int b,
mxnet::TShape* oshape) {
if (e != b) {
if (s > 0) {
CHECK_LT(b, e) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]="
CHECK_LT(b, e) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
<< e << ", and step[" << i << "]=" << s << " is invalid";
(*oshape)[i] = (e - b - 1) / s + 1;
} else {
CHECK_LT(e, b) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]="
CHECK_LT(e, b) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]="
<< e << ", and step[" << i << "]=" << s << " is invalid";
(*oshape)[i] = (b - e - 1) / (-s) + 1;
}
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_dgl_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def check_compact(csr, id_arr, num_nodes):
compact = mx.nd.contrib.dgl_graph_compact(csr, id_arr, graph_sizes=num_nodes, return_mapping=False)
assert compact.shape[0] == num_nodes
assert compact.shape[1] == num_nodes
assert mx.nd.sum(compact.indptr == csr.indptr[0:(num_nodes + 1)]).asnumpy() == num_nodes + 1
assert mx.nd.sum(compact.indptr == csr.indptr[0:int(num_nodes + 1)]).asnumpy() == num_nodes + 1
sub_indices = compact.indices.asnumpy()
indices = csr.indices.asnumpy()
id_arr = id_arr.asnumpy()
Expand Down
122 changes: 100 additions & 22 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
import mxnet as mx
import numpy as np
from distutils.version import LooseVersion
from itertools import permutations, combinations_with_replacement
import os
import pickle as pkl
import unittest
from nose.tools import raises
from common import setup_module, with_seed, assertRaises, TemporaryDirectory, teardown
from nose.tools import assert_raises, raises
from common import with_seed, assertRaises, TemporaryDirectory
from mxnet.test_utils import almost_equal
from mxnet.test_utils import assert_almost_equal, assert_exception
from mxnet.test_utils import default_context
Expand Down Expand Up @@ -101,6 +101,26 @@ def test_ndarray_setitem():
x_np[-1] = 1
assert same(x.asnumpy(), x_np)

# Ellipsis
x = mx.nd.zeros(shape)
x[2, ...] = 1
x_np = np.zeros(shape, dtype=x.dtype)
x_np[2, ...] = 1
assert same(x.asnumpy(), x_np)

x = mx.nd.zeros(shape)
x[..., 1] = 1
x_np = np.zeros(shape, dtype=x.dtype)
x_np[..., 1] = 1
assert same(x.asnumpy(), x_np)

# `None` should be ignored
x = mx.nd.zeros(shape)
x[None, 0, None, None, 0, 0, None] = 1
x_np = np.zeros(shape, dtype=x.dtype)
x_np[None, 0, None, None, 0, 0, None] = 1
assert same(x.asnumpy(), x_np)

# short all-dim indexing
x = mx.nd.zeros(shape)
val = mx.nd.ones((3, 2))
Expand All @@ -121,13 +141,15 @@ def test_ndarray_setitem():
x_np[:, -3:-1, -2:-1] = 1
assert same(x.asnumpy(), x_np)

# numpy assignment for empty axis
for trivial_shape in [(), (1,), (1, 1), (1, 1, 1)]:
if trivial_shape == tuple():
with mx.np_shape():
x = mx.nd.zeros(trivial_shape)
else:
x = mx.nd.zeros(trivial_shape)
# Scalar array, no assignment allowed
with mx.np_shape():
x = mx.nd.zeros(())
with assert_raises(IndexError):
x[:] = 1

# Assignments for empty axes
for trivial_shape in [(1,), (1, 1), (1, 1, 1)]:
x = mx.nd.zeros(trivial_shape)
x[:] = np.ones(trivial_shape)
x_np = np.ones(trivial_shape, dtype=x.dtype)
assert x.shape == trivial_shape
Expand Down Expand Up @@ -1286,6 +1308,42 @@ def test_bool():
assert bool(mx.nd.ones((1,)))


def test_basic_indexing_is_contiguous():
x_np = np.arange(np.prod((6, 7, 8, 9))).reshape((6, 7, 8, 9))
x_mx = mx.nd.array(x_np)

slices = [
slice(None),
slice(2),
slice(20),
slice(1, 4),
slice(None, None, 2),
slice(None, None, 20),
slice(0, 1),
slice(None, None, -1),
slice(3, None, -2),
]

is_contiguous = mx.nd.NDArray._basic_indexing_slice_is_contiguous

for idx in combinations_with_replacement(slices, 4):
for slc in permutations(idx):
# Check helper function
contig_pred = is_contiguous(slc, x_np.shape)
contig_true = x_np[slc].flags.contiguous
assert contig_pred == contig_true, (
"failed with slc={}, pred ({}) != actual ({})"
"".format(slc, contig_pred, contig_true)
)

if contig_pred:
# Check mutation behavior
y_mx = x_mx.copy()
y_mx_slc = y_mx[slc]
y_mx_slc[:] = 0
assert (y_mx[slc].asnumpy() == 0).all()


@with_seed()
def test_ndarray_indexing():
def test_getitem(np_array, index, is_scalar=False):
Expand All @@ -1296,22 +1354,24 @@ def test_getitem(np_array, index, is_scalar=False):
if isinstance(index, mx.nd.NDArray):
np_index = index.asnumpy()
if isinstance(index, tuple):
np_index = []
for idx in index:
if isinstance(idx, mx.nd.NDArray):
np_index.append(idx.asnumpy())
else:
np_index.append(idx)
np_index = tuple(np_index)
np_index = tuple(
idx.asnumpy() if isinstance(idx, mx.nd.NDArray) else idx
for idx in index
)

np_indexed_array = np_array[np_index]
mx_array = mx.nd.array(np_array, dtype=np_array.dtype)
mx_indexed_array = mx_array[index]
try:
mx_indexed_array = mx_array[index]
except Exception as e:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is an exception expected to happen with the given test data? If not, can we remove it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's a debugging leftover.

Copy link
Contributor Author

@kohr-h kohr-h Jun 11, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OTOH, there's more code like this in the test suite, like assert <expr>, "informative error text". It adds information to the plain failure report, so I'd prefer to leave it in.

print('Failed with index = {}'.format(index))
raise e
if is_scalar:
mx_indexed_array = mx_indexed_array.asscalar()
else:
mx_indexed_array = mx_indexed_array.asnumpy()
assert same(np_indexed_array, mx_indexed_array), 'Failed with index=%s' % str(index)

assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)

def test_setitem(np_array, index, is_scalar):
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
Expand All @@ -1321,7 +1381,13 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None)
np_array[np_index] = mx_value.asnumpy()
else:
np_array[np_index] = mx_value
mx_array[mx_index] = mx_value

try:
mx_array[mx_index] = mx_value
except Exception as e:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

print('Failed with index = {}, value.shape = {}'.format(mx_index, mx_value.shape))
raise e

assert same(np_array, mx_array.asnumpy())

np_index = index
Expand Down Expand Up @@ -1380,7 +1446,9 @@ def test_setitem_autograd(np_array, index):
try:
with mx.autograd.record():
x[index] = y
assert False # should not reach here
# `a[None] = v` is equivalent to `a[...] = v` which doesn't raise
if index is not None:
assert False, 'failed with index = {}'.format(index) # should not reach here
except mx.base.MXNetError as err:
assert str(err).find('Inplace operations (+=, -=, x[:]=, etc) are not supported when recording with') != -1

Expand Down Expand Up @@ -1502,7 +1570,17 @@ def convert(num):
(([[[[1]]]], 3, slice(0, 3), 0), False),
(([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)), False),
(([1, 2], slice(3, 5), [2, 3], [3, 4]), False),
(([1, 2], slice(3, 5), (2, 3), [3, 4]), False)]
(([1, 2], slice(3, 5), (2, 3), [3, 4]), False),
((1, Ellipsis, -1), False),
((slice(2), Ellipsis, None, 0), False),
(None, False),
((1, None, -2, 3, -4), False),
# TODO(zoeygxy): Support None in advanced indexing
# (([1, 2], slice(3, 5), None, None, [3, 4]), False),
# ((slice(None), slice(3, 5), None, None, [2, 3], [3, 4]), False),
# ((slice(None), slice(3, 5), None, [2, 3], None, [3, 4]), False),
# ((None, slice(None), slice(3, 5), [2, 3], None, [3, 4]), False),
]
for index in index_list:
test_getitem(np_array, index[0], index[1])
test_setitem(np_array, index[0], index[1])
Expand Down
31 changes: 17 additions & 14 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4793,13 +4793,14 @@ def test_normal_case():

def test_empty_tensor():
shape = (2, 3, 0, 4)
a = np.array([], dtype=np.int32).reshape(shape)
b = mx.nd.array(a, ctx=default_context(), dtype=a.dtype)
reps = (2, 4, 6)
with mx.np_shape():
a = np.array([], dtype=np.int32).reshape(shape)
b = mx.nd.array(a, ctx=default_context(), dtype=a.dtype)

a_tiled = np.tile(a, reps)
b_tiled = mx.nd.tile(b, reps).asnumpy()
assert same(a_tiled, b_tiled)
reps = (2, 4, 6)
a_tiled = np.tile(a, reps)
b_tiled = mx.nd.tile(b, reps).asnumpy()
assert same(a_tiled, b_tiled)

def test_empty_reps():
a = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32)
Expand Down Expand Up @@ -4889,13 +4890,15 @@ def test_normal_case(index_type=np.int32):

def test_empty_indices():
shape = (2, 0, 9, 3)
indices = np.array([]).reshape(shape)
depth = 10
mx_one_hot_array = mx.nd.one_hot(
mx.nd.array(indices, ctx=default_context(), dtype=np.int32),
depth=depth, dtype=np.int32).asnumpy()
expected_array = np.array([], dtype=np.int32).reshape(shape + (depth, ))
assert same(expected_array, mx_one_hot_array)
with mx.np_shape():
indices = np.array([]).reshape(shape)
depth = 10
mx_one_hot_array = mx.nd.one_hot(
mx.nd.array(indices, ctx=default_context(), dtype=np.int32),
depth=depth, dtype=np.int32
).asnumpy()
expected_array = np.array([], dtype=np.int32).reshape(shape + (depth,))
assert same(expected_array, mx_one_hot_array)

def test_zero_depth():
shape = (2, 4, 9, 3)
Expand Down Expand Up @@ -8859,7 +8862,7 @@ def test_index_array_default():

@mx.use_np_shape
def test_index_array_default_zero_dim():
data = mx.symbol.Variable("data")
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data)

input_array = np.ones(())
Expand Down