Skip to content

Commit

Permalink
[OP] Support range as advanced index for ndarrays (apache#16047)
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce authored and zixuanweeei committed Sep 2, 2019
1 parent 9b505bf commit 8cca23f
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 158 deletions.
12 changes: 7 additions & 5 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,17 +1053,15 @@ def _advanced_index_to_array(idx, ax_len, ctx):
if idx.dtype != idx_dtype:
idx = idx.astype(idx_dtype)
return idx.as_in_context(ctx)

elif isinstance(idx, (np.ndarray, list, tuple)):
return array(idx, ctx, idx_dtype)

elif isinstance(idx, integer_types):
return array([idx], ctx, idx_dtype)

elif isinstance(idx, py_slice):
start, stop, step = idx.indices(ax_len)
return arange(start, stop, step, ctx=ctx, dtype=idx_dtype)

elif sys.version_info[0] > 2 and isinstance(idx, range):
return arange(idx.start, idx.stop, idx.step, ctx=ctx, dtype=idx_dtype)
else:
raise RuntimeError('illegal index type {}'.format(type(idx)))

Expand Down Expand Up @@ -2888,6 +2886,7 @@ def _scatter_set_nd(self, value_nd, indices):
lhs=self, rhs=value_nd, indices=indices, shape=self.shape, out=self
)


def indexing_key_expand_implicit_axes(key, shape):
"""Make implicit axes explicit by adding ``slice(None)``.
Examples
Expand Down Expand Up @@ -2984,6 +2983,8 @@ def _is_advanced_index(idx):
return True
elif isinstance(idx, py_slice) or idx is None:
return False
elif sys.version_info[0] > 2 and isinstance(idx, range):
return True
else:
raise RuntimeError('illegal index type {}'.format(type(idx)))

Expand All @@ -2995,7 +2996,8 @@ def get_indexing_dispatch_code(key):
for idx in key:
if isinstance(idx, (NDArray, np.ndarray, list, tuple)):
return _NDARRAY_ADVANCED_INDEXING

elif sys.version_info[0] > 2 and isinstance(idx, range):
return _NDARRAY_ADVANCED_INDEXING
elif not (isinstance(idx, (py_slice, integer_types)) or idx is None):
raise ValueError(
'NDArray does not support slicing with key {} of type {}.'
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def __getitem__(self, key):
key, shape[0]))
return self._at(key)
elif isinstance(key, py_slice):
if (key.step is None or key.step == 1):
if key.start is not None or key.stop is not None:
if key.step is None or key.step == 1:
if key.start is not None or key.stop is not None:
return self._slice(key.start, key.stop)
else:
return self
Expand Down
303 changes: 152 additions & 151 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,6 @@ def test_np_ndarray_copy():
@with_seed()
@use_np
def test_np_ndarray_indexing():
"""
Test all indexing.
"""
def np_int(index, int_type=np.int32):
"""
Helper function for testing indexing that converts slices to slices of ints or None, and tuples to
Expand Down Expand Up @@ -507,156 +504,160 @@ def test_setitem_autograd(np_array, index):

shape = (8, 16, 9, 9)
np_array = _np.arange(_np.prod(_np.array(shape)), dtype='int32').reshape(shape) # native np array

# Test sliced output being ndarray:
index_list = [
# Basic indexing
# Single int as index
0,
np.int32(0),
np.int64(0),
5,
np.int32(5),
np.int64(5),
-1,
np.int32(-1),
np.int64(-1),
# Slicing as index
slice(5),
np_int(slice(5), np.int32),
np_int(slice(5), np.int64),
slice(1, 5),
np_int(slice(1, 5), np.int32),
np_int(slice(1, 5), np.int64),
slice(1, 5, 2),
np_int(slice(1, 5, 2), np.int32),
np_int(slice(1, 5, 2), np.int64),
slice(7, 0, -1),
np_int(slice(7, 0, -1)),
np_int(slice(7, 0, -1), np.int64),
slice(None, 6),
np_int(slice(None, 6)),
np_int(slice(None, 6), np.int64),
slice(None, 6, 3),
np_int(slice(None, 6, 3)),
np_int(slice(None, 6, 3), np.int64),
slice(1, None),
np_int(slice(1, None)),
np_int(slice(1, None), np.int64),
slice(1, None, 3),
np_int(slice(1, None, 3)),
np_int(slice(1, None, 3), np.int64),
slice(None, None, 2),
np_int(slice(None, None, 2)),
np_int(slice(None, None, 2), np.int64),
slice(None, None, -1),
np_int(slice(None, None, -1)),
np_int(slice(None, None, -1), np.int64),
slice(None, None, -2),
np_int(slice(None, None, -2), np.int32),
np_int(slice(None, None, -2), np.int64),
# Multiple ints as indices
(1, 2, 3),
np_int((1, 2, 3)),
np_int((1, 2, 3), np.int64),
(-1, -2, -3),
np_int((-1, -2, -3)),
np_int((-1, -2, -3), np.int64),
(1, 2, 3, 4),
np_int((1, 2, 3, 4)),
np_int((1, 2, 3, 4), np.int64),
(-4, -3, -2, -1),
np_int((-4, -3, -2, -1)),
np_int((-4, -3, -2, -1), np.int64),
# slice(None) as indices
(slice(None), slice(None), 1, 8),
(slice(None), slice(None), -1, 8),
(slice(None), slice(None), 1, -8),
(slice(None), slice(None), -1, -8),
np_int((slice(None), slice(None), 1, 8)),
np_int((slice(None), slice(None), 1, 8), np.int64),
(slice(None), slice(None), 1, 8),
np_int((slice(None), slice(None), -1, -8)),
np_int((slice(None), slice(None), -1, -8), np.int64),
(slice(None), 2, slice(1, 5), 1),
np_int((slice(None), 2, slice(1, 5), 1)),
np_int((slice(None), 2, slice(1, 5), 1), np.int64),
# Mixture of ints and slices as indices
(slice(None, None, -1), 2, slice(1, 5), 1),
np_int((slice(None, None, -1), 2, slice(1, 5), 1)),
np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64),
(slice(None, None, -1), 2, slice(1, 7, 2), 1),
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)),
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64),
(slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)),
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))),
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64),
(slice(1, 8, 2), 1, slice(3, 8), 2),
np_int((slice(1, 8, 2), 1, slice(3, 8), 2)),
np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64),
# Test Ellipsis ('...')
(1, Ellipsis, -1),
(slice(2), Ellipsis, None, 0),
# Test newaxis
None,
(1, None, -2, 3, -4),
(1, slice(2, 5), None),
(slice(None), slice(1, 4), None, slice(2, 3)),
(slice(1, 3), slice(1, 3), slice(1, 3), slice(1, 3), None),
(slice(1, 3), slice(1, 3), None, slice(1, 3), slice(1, 3)),
(None, slice(1, 2), 3, None),
(1, None, 2, 3, None, None, 4),
# Advanced indexing
([1, 2], slice(3, 5), None, None, [3, 4]),
(slice(None), slice(3, 5), None, None, [2, 3], [3, 4]),
(slice(None), slice(3, 5), None, [2, 3], None, [3, 4]),
(None, slice(None), slice(3, 5), [2, 3], None, [3, 4]),
[1],
[1, 2],
[2, 1, 3],
[7, 5, 0, 3, 6, 2, 1],
np.array([6, 3], dtype=np.int32),
np.array([[3, 4], [0, 6]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
np.array([[2], [0], [1]], dtype=np.int32),
np.array([[2], [0], [1]], dtype=np.int64),
np.array([4, 7], dtype=np.int32),
np.array([4, 7], dtype=np.int64),
np.array([[3, 6], [2, 1]], dtype=np.int32),
np.array([[3, 6], [2, 1]], dtype=np.int64),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
(1, [2, 3]),
(1, [2, 3], np.array([[3], [0]], dtype=np.int32)),
(1, [2, 3]),
(1, [2, 3], np.array([[3], [0]], dtype=np.int64)),
(1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)),
(1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)),
(1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)),
(1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)),
(1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)),
(1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)),
(1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)),
(1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)),
(1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)),
(1, [2], np.array([[4]], dtype=np.int32), np.array([[1, 3], [5, 7]], dtype='int64')),
[0],
[0, 1],
[1, 2, 3],
[2, 0, 5, 6],
([1, 1], [2, 3]),
([1], [4], [5]),
([1], [4], [5], [6]),
([[1]], [[2]]),
([[1]], [[2]], [[3]], [[4]]),
(slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)),
([[[[1]]]], [[1]], slice(0, 3), [1, 5]),
([[[[1]]]], 3, slice(0, 3), [1, 3]),
([[[[1]]]], 3, slice(0, 3), 0),
([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)),
([1, 2], slice(3, 5), [2, 3], [3, 4]),
([1, 2], slice(3, 5), (2, 3), [3, 4]),
(),
# Basic indexing
# Single int as index
0,
np.int32(0),
np.int64(0),
5,
np.int32(5),
np.int64(5),
-1,
np.int32(-1),
np.int64(-1),
# Slicing as index
slice(5),
np_int(slice(5), np.int32),
np_int(slice(5), np.int64),
slice(1, 5),
np_int(slice(1, 5), np.int32),
np_int(slice(1, 5), np.int64),
slice(1, 5, 2),
np_int(slice(1, 5, 2), np.int32),
np_int(slice(1, 5, 2), np.int64),
slice(7, 0, -1),
np_int(slice(7, 0, -1)),
np_int(slice(7, 0, -1), np.int64),
slice(None, 6),
np_int(slice(None, 6)),
np_int(slice(None, 6), np.int64),
slice(None, 6, 3),
np_int(slice(None, 6, 3)),
np_int(slice(None, 6, 3), np.int64),
slice(1, None),
np_int(slice(1, None)),
np_int(slice(1, None), np.int64),
slice(1, None, 3),
np_int(slice(1, None, 3)),
np_int(slice(1, None, 3), np.int64),
slice(None, None, 2),
np_int(slice(None, None, 2)),
np_int(slice(None, None, 2), np.int64),
slice(None, None, -1),
np_int(slice(None, None, -1)),
np_int(slice(None, None, -1), np.int64),
slice(None, None, -2),
np_int(slice(None, None, -2), np.int32),
np_int(slice(None, None, -2), np.int64),
# Multiple ints as indices
(1, 2, 3),
np_int((1, 2, 3)),
np_int((1, 2, 3), np.int64),
(-1, -2, -3),
np_int((-1, -2, -3)),
np_int((-1, -2, -3), np.int64),
(1, 2, 3, 4),
np_int((1, 2, 3, 4)),
np_int((1, 2, 3, 4), np.int64),
(-4, -3, -2, -1),
np_int((-4, -3, -2, -1)),
np_int((-4, -3, -2, -1), np.int64),
# slice(None) as indices
(slice(None), slice(None), 1, 8),
(slice(None), slice(None), -1, 8),
(slice(None), slice(None), 1, -8),
(slice(None), slice(None), -1, -8),
np_int((slice(None), slice(None), 1, 8)),
np_int((slice(None), slice(None), 1, 8), np.int64),
(slice(None), slice(None), 1, 8),
np_int((slice(None), slice(None), -1, -8)),
np_int((slice(None), slice(None), -1, -8), np.int64),
(slice(None), 2, slice(1, 5), 1),
np_int((slice(None), 2, slice(1, 5), 1)),
np_int((slice(None), 2, slice(1, 5), 1), np.int64),
# Mixture of ints and slices as indices
(slice(None, None, -1), 2, slice(1, 5), 1),
np_int((slice(None, None, -1), 2, slice(1, 5), 1)),
np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64),
(slice(None, None, -1), 2, slice(1, 7, 2), 1),
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)),
np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64),
(slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)),
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))),
np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64),
(slice(1, 8, 2), 1, slice(3, 8), 2),
np_int((slice(1, 8, 2), 1, slice(3, 8), 2)),
np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64),
# Test Ellipsis ('...')
(1, Ellipsis, -1),
(slice(2), Ellipsis, None, 0),
# Test newaxis
None,
(1, None, -2, 3, -4),
(1, slice(2, 5), None),
(slice(None), slice(1, 4), None, slice(2, 3)),
(slice(1, 3), slice(1, 3), slice(1, 3), slice(1, 3), None),
(slice(1, 3), slice(1, 3), None, slice(1, 3), slice(1, 3)),
(None, slice(1, 2), 3, None),
(1, None, 2, 3, None, None, 4),
# Advanced indexing
([1, 2], slice(3, 5), None, None, [3, 4]),
(slice(None), slice(3, 5), None, None, [2, 3], [3, 4]),
(slice(None), slice(3, 5), None, [2, 3], None, [3, 4]),
(None, slice(None), slice(3, 5), [2, 3], None, [3, 4]),
[1],
[1, 2],
[2, 1, 3],
[7, 5, 0, 3, 6, 2, 1],
np.array([6, 3], dtype=np.int32),
np.array([[3, 4], [0, 6]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
np.array([[2], [0], [1]], dtype=np.int32),
np.array([[2], [0], [1]], dtype=np.int64),
np.array([4, 7], dtype=np.int32),
np.array([4, 7], dtype=np.int64),
np.array([[3, 6], [2, 1]], dtype=np.int32),
np.array([[3, 6], [2, 1]], dtype=np.int64),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32),
np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64),
(1, [2, 3]),
(1, [2, 3], np.array([[3], [0]], dtype=np.int32)),
(1, [2, 3]),
(1, [2, 3], np.array([[3], [0]], dtype=np.int64)),
(1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)),
(1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)),
(1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)),
(1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)),
(1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)),
(1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)),
(1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)),
(1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)),
(1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)),
(1, [2], np.array([[4]], dtype=np.int32), np.array([[1, 3], [5, 7]], dtype='int64')),
[0],
[0, 1],
[1, 2, 3],
[2, 0, 5, 6],
([1, 1], [2, 3]),
([1], [4], [5]),
([1], [4], [5], [6]),
([[1]], [[2]]),
([[1]], [[2]], [[3]], [[4]]),
(slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)),
([[[[1]]]], [[1]], slice(0, 3), [1, 5]),
([[[[1]]]], 3, slice(0, 3), [1, 3]),
([[[[1]]]], 3, slice(0, 3), 0),
([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)),
([1, 2], slice(3, 5), [2, 3], [3, 4]),
([1, 2], slice(3, 5), (2, 3), [3, 4]),
range(4),
range(3, 0, -1),
(range(4,), [1]),
]
for index in index_list:
test_getitem(np_array, index)
Expand Down

0 comments on commit 8cca23f

Please sign in to comment.