This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-1206] Support NDArray indexing with None and Ellipsis #13143
Merged
Merged
Changes from all commits
Commits
Show all changes
48 commits
Select commit
Hold shift + click to select a range
25f0910
Support NDArray indexing with None and Ellipsis
kohr-h a23078f
Update NDArray.__setitem__ docs with None and Ellipsis
kohr-h e602b20
Fix boolean flag in NDArray.__getitem__, add doctests
kohr-h b1876e7
Add setitem test for None and Ellipsis
kohr-h 8da6e00
Fix wrong slice used, add cases to test_indexing
kohr-h 153cd35
Revamp NDArray.__getitem__ and __setitem__
kohr-h 3b9a3b1
Fix typo in error message of SetSliceOpOutputDimSize
kohr-h 6911806
Fix setting of array with integer indices
kohr-h 378a1de
Fix basic __setitem__ for all test cases
kohr-h 8dada74
WIP: fixing advanced indexing
kohr-h 95cfc82
REMOVE: printing in tests
kohr-h 8e748d7
Re-implement advanced indexing with None and Ellipsis
kohr-h c7e3829
Fix lint errors
kohr-h 57f26b7
WIP: fix basic indexing
kohr-h 99bca07
WIP: fix basic indexing
kohr-h f9c048a
TEMP: print statements in tests
kohr-h bb57638
Fix op.slice with step<0 and end==-1
kohr-h e3d9921
Implement copy-free general contiguous indexing
kohr-h 00ffcd4
Improve doctest of __getitem__
kohr-h 89ea383
Fix missing staticmethod
kohr-h aba2fec
Remove superfluous _at and _slice
kohr-h 7044242
Fix lint errors
kohr-h faed019
WIP: basic indexing
kohr-h 84fce03
Remove print statements from tests
kohr-h b03c301
Fix call into op.slice in basic indexing, add doctest
kohr-h 49c672e
Print failing index in setitem tests
kohr-h b4aac8a
Simplify implementation of advanced index broadcasting
kohr-h 846ac5a
Better printing for failing setitem tests
kohr-h 10bb038
Remove list indexing restriction, fix value shape broadcasting
kohr-h 0e06ad3
Fix bad string formatting
kohr-h 177cb14
Fix bug in test_uniform
kohr-h 6f51679
Test mutability of sliced array if contiguous
kohr-h 104d4d9
Fix whitespace error in matrix_op-inl.h
kohr-h 8eddf2d
"Fix" pylint complaints
kohr-h 4b0a178
Temporarily disable failing unittests
kohr-h 5733cdd
Silence another pylint complaint
kohr-h 0331f50
Fix size-0 array creation
kohr-h 393c808
Make scalar tensor assignment test check for IndexError
kohr-h 7367bdc
Re-activate operator tests with 0-size arrays
kohr-h 84ba227
Use np.compat in tests with zeros in shape or empty shape
kohr-h 2f16bb3
Change comment in autograd indexing test
kohr-h 4e82100
Add more None-containing index tuples to indexing test
kohr-h d312064
Disable None in advanced indexing test since it has not been supported
reminisce f4d2af0
Fix sanity
reminisce 4234412
Fix ci
reminisce fe6336d
Fix unit test failure
reminisce 1b80049
Merge branch 'master' into indexing_none_ellipsis
reminisce 2e5ffe2
Fix __getitem__
reminisce File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,11 +18,11 @@ | |
import mxnet as mx | ||
import numpy as np | ||
from distutils.version import LooseVersion | ||
from itertools import permutations, combinations_with_replacement | ||
import os | ||
import pickle as pkl | ||
import unittest | ||
from nose.tools import raises | ||
from common import setup_module, with_seed, assertRaises, TemporaryDirectory, teardown | ||
from nose.tools import assert_raises, raises | ||
from common import with_seed, assertRaises, TemporaryDirectory | ||
from mxnet.test_utils import almost_equal | ||
from mxnet.test_utils import assert_almost_equal, assert_exception | ||
from mxnet.test_utils import default_context | ||
|
@@ -101,6 +101,26 @@ def test_ndarray_setitem(): | |
x_np[-1] = 1 | ||
assert same(x.asnumpy(), x_np) | ||
|
||
# Ellipsis | ||
x = mx.nd.zeros(shape) | ||
x[2, ...] = 1 | ||
x_np = np.zeros(shape, dtype=x.dtype) | ||
x_np[2, ...] = 1 | ||
assert same(x.asnumpy(), x_np) | ||
|
||
x = mx.nd.zeros(shape) | ||
x[..., 1] = 1 | ||
x_np = np.zeros(shape, dtype=x.dtype) | ||
x_np[..., 1] = 1 | ||
assert same(x.asnumpy(), x_np) | ||
|
||
# `None` should be ignored | ||
x = mx.nd.zeros(shape) | ||
x[None, 0, None, None, 0, 0, None] = 1 | ||
x_np = np.zeros(shape, dtype=x.dtype) | ||
x_np[None, 0, None, None, 0, 0, None] = 1 | ||
assert same(x.asnumpy(), x_np) | ||
|
||
# short all-dim indexing | ||
x = mx.nd.zeros(shape) | ||
val = mx.nd.ones((3, 2)) | ||
|
@@ -121,13 +141,15 @@ def test_ndarray_setitem(): | |
x_np[:, -3:-1, -2:-1] = 1 | ||
assert same(x.asnumpy(), x_np) | ||
|
||
# numpy assignment for empty axis | ||
for trivial_shape in [(), (1,), (1, 1), (1, 1, 1)]: | ||
if trivial_shape == tuple(): | ||
with mx.np_shape(): | ||
x = mx.nd.zeros(trivial_shape) | ||
else: | ||
x = mx.nd.zeros(trivial_shape) | ||
# Scalar array, no assignment allowed | ||
with mx.np_shape(): | ||
x = mx.nd.zeros(()) | ||
with assert_raises(IndexError): | ||
x[:] = 1 | ||
|
||
# Assignments for empty axes | ||
for trivial_shape in [(1,), (1, 1), (1, 1, 1)]: | ||
x = mx.nd.zeros(trivial_shape) | ||
x[:] = np.ones(trivial_shape) | ||
x_np = np.ones(trivial_shape, dtype=x.dtype) | ||
assert x.shape == trivial_shape | ||
|
@@ -1286,6 +1308,42 @@ def test_bool(): | |
assert bool(mx.nd.ones((1,))) | ||
|
||
|
||
def test_basic_indexing_is_contiguous(): | ||
x_np = np.arange(np.prod((6, 7, 8, 9))).reshape((6, 7, 8, 9)) | ||
x_mx = mx.nd.array(x_np) | ||
|
||
slices = [ | ||
slice(None), | ||
slice(2), | ||
slice(20), | ||
slice(1, 4), | ||
slice(None, None, 2), | ||
slice(None, None, 20), | ||
slice(0, 1), | ||
slice(None, None, -1), | ||
slice(3, None, -2), | ||
] | ||
|
||
is_contiguous = mx.nd.NDArray._basic_indexing_slice_is_contiguous | ||
|
||
for idx in combinations_with_replacement(slices, 4): | ||
for slc in permutations(idx): | ||
# Check helper function | ||
contig_pred = is_contiguous(slc, x_np.shape) | ||
contig_true = x_np[slc].flags.contiguous | ||
assert contig_pred == contig_true, ( | ||
"failed with slc={}, pred ({}) != actual ({})" | ||
"".format(slc, contig_pred, contig_true) | ||
) | ||
|
||
if contig_pred: | ||
# Check mutation behavior | ||
y_mx = x_mx.copy() | ||
y_mx_slc = y_mx[slc] | ||
y_mx_slc[:] = 0 | ||
assert (y_mx[slc].asnumpy() == 0).all() | ||
|
||
|
||
@with_seed() | ||
def test_ndarray_indexing(): | ||
def test_getitem(np_array, index, is_scalar=False): | ||
|
@@ -1296,22 +1354,24 @@ def test_getitem(np_array, index, is_scalar=False): | |
if isinstance(index, mx.nd.NDArray): | ||
np_index = index.asnumpy() | ||
if isinstance(index, tuple): | ||
np_index = [] | ||
for idx in index: | ||
if isinstance(idx, mx.nd.NDArray): | ||
np_index.append(idx.asnumpy()) | ||
else: | ||
np_index.append(idx) | ||
np_index = tuple(np_index) | ||
np_index = tuple( | ||
idx.asnumpy() if isinstance(idx, mx.nd.NDArray) else idx | ||
for idx in index | ||
) | ||
|
||
np_indexed_array = np_array[np_index] | ||
mx_array = mx.nd.array(np_array, dtype=np_array.dtype) | ||
mx_indexed_array = mx_array[index] | ||
try: | ||
mx_indexed_array = mx_array[index] | ||
except Exception as e: | ||
print('Failed with index = {}'.format(index)) | ||
raise e | ||
if is_scalar: | ||
mx_indexed_array = mx_indexed_array.asscalar() | ||
else: | ||
mx_indexed_array = mx_indexed_array.asnumpy() | ||
assert same(np_indexed_array, mx_indexed_array), 'Failed with index=%s' % str(index) | ||
|
||
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index) | ||
|
||
def test_setitem(np_array, index, is_scalar): | ||
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None): | ||
|
@@ -1321,7 +1381,13 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None) | |
np_array[np_index] = mx_value.asnumpy() | ||
else: | ||
np_array[np_index] = mx_value | ||
mx_array[mx_index] = mx_value | ||
|
||
try: | ||
mx_array[mx_index] = mx_value | ||
except Exception as e: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
print('Failed with index = {}, value.shape = {}'.format(mx_index, mx_value.shape)) | ||
raise e | ||
|
||
assert same(np_array, mx_array.asnumpy()) | ||
|
||
np_index = index | ||
|
@@ -1380,7 +1446,9 @@ def test_setitem_autograd(np_array, index): | |
try: | ||
with mx.autograd.record(): | ||
x[index] = y | ||
assert False # should not reach here | ||
# `a[None] = v` is equivalent to `a[...] = v` which doesn't raise | ||
if index is not None: | ||
assert False, 'failed with index = {}'.format(index) # should not reach here | ||
except mx.base.MXNetError as err: | ||
assert str(err).find('Inplace operations (+=, -=, x[:]=, etc) are not supported when recording with') != -1 | ||
|
||
|
@@ -1502,7 +1570,17 @@ def convert(num): | |
(([[[[1]]]], 3, slice(0, 3), 0), False), | ||
(([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)), False), | ||
(([1, 2], slice(3, 5), [2, 3], [3, 4]), False), | ||
(([1, 2], slice(3, 5), (2, 3), [3, 4]), False)] | ||
(([1, 2], slice(3, 5), (2, 3), [3, 4]), False), | ||
((1, Ellipsis, -1), False), | ||
((slice(2), Ellipsis, None, 0), False), | ||
(None, False), | ||
((1, None, -2, 3, -4), False), | ||
# TODO(zoeygxy): Support None in advanced indexing | ||
# (([1, 2], slice(3, 5), None, None, [3, 4]), False), | ||
# ((slice(None), slice(3, 5), None, None, [2, 3], [3, 4]), False), | ||
# ((slice(None), slice(3, 5), None, [2, 3], None, [3, 4]), False), | ||
# ((None, slice(None), slice(3, 5), [2, 3], None, [3, 4]), False), | ||
] | ||
for index in index_list: | ||
test_getitem(np_array, index[0], index[1]) | ||
test_setitem(np_array, index[0], index[1]) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is an exception expected to happen with the given test data? If not, can we remove it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it's a debugging leftover.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OTOH, there's more code like this in the test suite, like
assert <expr>, "informative error text"
. It adds information to the plain failure report, so I'd prefer to leave it in.