From 8cca23f457fd9c62348477bb17fb3301f9e63b1c Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 30 Aug 2019 11:04:08 -0700 Subject: [PATCH] [OP] Support range as advanced index for ndarrays (#16047) --- python/mxnet/ndarray/ndarray.py | 12 +- python/mxnet/numpy/multiarray.py | 4 +- tests/python/unittest/test_numpy_ndarray.py | 303 ++++++++++---------- 3 files changed, 161 insertions(+), 158 deletions(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 6c2bb8078922..59306e21e6f5 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -1053,17 +1053,15 @@ def _advanced_index_to_array(idx, ax_len, ctx): if idx.dtype != idx_dtype: idx = idx.astype(idx_dtype) return idx.as_in_context(ctx) - elif isinstance(idx, (np.ndarray, list, tuple)): return array(idx, ctx, idx_dtype) - elif isinstance(idx, integer_types): return array([idx], ctx, idx_dtype) - elif isinstance(idx, py_slice): start, stop, step = idx.indices(ax_len) return arange(start, stop, step, ctx=ctx, dtype=idx_dtype) - + elif sys.version_info[0] > 2 and isinstance(idx, range): + return arange(idx.start, idx.stop, idx.step, ctx=ctx, dtype=idx_dtype) else: raise RuntimeError('illegal index type {}'.format(type(idx))) @@ -2888,6 +2886,7 @@ def _scatter_set_nd(self, value_nd, indices): lhs=self, rhs=value_nd, indices=indices, shape=self.shape, out=self ) + def indexing_key_expand_implicit_axes(key, shape): """Make implicit axes explicit by adding ``slice(None)``. Examples @@ -2984,6 +2983,8 @@ def _is_advanced_index(idx): return True elif isinstance(idx, py_slice) or idx is None: return False + elif sys.version_info[0] > 2 and isinstance(idx, range): + return True else: raise RuntimeError('illegal index type {}'.format(type(idx))) @@ -2995,7 +2996,8 @@ def get_indexing_dispatch_code(key): for idx in key: if isinstance(idx, (NDArray, np.ndarray, list, tuple)): return _NDARRAY_ADVANCED_INDEXING - + elif sys.version_info[0] > 2 and isinstance(idx, range): + return _NDARRAY_ADVANCED_INDEXING elif not (isinstance(idx, (py_slice, integer_types)) or idx is None): raise ValueError( 'NDArray does not support slicing with key {} of type {}.' diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index c419411edb28..83688774f069 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -235,8 +235,8 @@ def __getitem__(self, key): key, shape[0])) return self._at(key) elif isinstance(key, py_slice): - if (key.step is None or key.step == 1): - if key.start is not None or key.stop is not None: + if key.step is None or key.step == 1: + if key.start is not None or key.stop is not None: return self._slice(key.start, key.stop) else: return self diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 6dd7b43cd82c..fafa5a827c2f 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -377,9 +377,6 @@ def test_np_ndarray_copy(): @with_seed() @use_np def test_np_ndarray_indexing(): - """ - Test all indexing. - """ def np_int(index, int_type=np.int32): """ Helper function for testing indexing that converts slices to slices of ints or None, and tuples to @@ -507,156 +504,160 @@ def test_setitem_autograd(np_array, index): shape = (8, 16, 9, 9) np_array = _np.arange(_np.prod(_np.array(shape)), dtype='int32').reshape(shape) # native np array - + # Test sliced output being ndarray: index_list = [ - # Basic indexing - # Single int as index - 0, - np.int32(0), - np.int64(0), - 5, - np.int32(5), - np.int64(5), - -1, - np.int32(-1), - np.int64(-1), - # Slicing as index - slice(5), - np_int(slice(5), np.int32), - np_int(slice(5), np.int64), - slice(1, 5), - np_int(slice(1, 5), np.int32), - np_int(slice(1, 5), np.int64), - slice(1, 5, 2), - np_int(slice(1, 5, 2), np.int32), - np_int(slice(1, 5, 2), np.int64), - slice(7, 0, -1), - np_int(slice(7, 0, -1)), - np_int(slice(7, 0, -1), np.int64), - slice(None, 6), - np_int(slice(None, 6)), - np_int(slice(None, 6), np.int64), - slice(None, 6, 3), - np_int(slice(None, 6, 3)), - np_int(slice(None, 6, 3), np.int64), - slice(1, None), - np_int(slice(1, None)), - np_int(slice(1, None), np.int64), - slice(1, None, 3), - np_int(slice(1, None, 3)), - np_int(slice(1, None, 3), np.int64), - slice(None, None, 2), - np_int(slice(None, None, 2)), - np_int(slice(None, None, 2), np.int64), - slice(None, None, -1), - np_int(slice(None, None, -1)), - np_int(slice(None, None, -1), np.int64), - slice(None, None, -2), - np_int(slice(None, None, -2), np.int32), - np_int(slice(None, None, -2), np.int64), - # Multiple ints as indices - (1, 2, 3), - np_int((1, 2, 3)), - np_int((1, 2, 3), np.int64), - (-1, -2, -3), - np_int((-1, -2, -3)), - np_int((-1, -2, -3), np.int64), - (1, 2, 3, 4), - np_int((1, 2, 3, 4)), - np_int((1, 2, 3, 4), np.int64), - (-4, -3, -2, -1), - np_int((-4, -3, -2, -1)), - np_int((-4, -3, -2, -1), np.int64), - # slice(None) as indices - (slice(None), slice(None), 1, 8), - (slice(None), slice(None), -1, 8), - (slice(None), slice(None), 1, -8), - (slice(None), slice(None), -1, -8), - np_int((slice(None), slice(None), 1, 8)), - np_int((slice(None), slice(None), 1, 8), np.int64), - (slice(None), slice(None), 1, 8), - np_int((slice(None), slice(None), -1, -8)), - np_int((slice(None), slice(None), -1, -8), np.int64), - (slice(None), 2, slice(1, 5), 1), - np_int((slice(None), 2, slice(1, 5), 1)), - np_int((slice(None), 2, slice(1, 5), 1), np.int64), - # Mixture of ints and slices as indices - (slice(None, None, -1), 2, slice(1, 5), 1), - np_int((slice(None, None, -1), 2, slice(1, 5), 1)), - np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64), - (slice(None, None, -1), 2, slice(1, 7, 2), 1), - np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)), - np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64), - (slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), - np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))), - np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64), - (slice(1, 8, 2), 1, slice(3, 8), 2), - np_int((slice(1, 8, 2), 1, slice(3, 8), 2)), - np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64), - # Test Ellipsis ('...') - (1, Ellipsis, -1), - (slice(2), Ellipsis, None, 0), - # Test newaxis - None, - (1, None, -2, 3, -4), - (1, slice(2, 5), None), - (slice(None), slice(1, 4), None, slice(2, 3)), - (slice(1, 3), slice(1, 3), slice(1, 3), slice(1, 3), None), - (slice(1, 3), slice(1, 3), None, slice(1, 3), slice(1, 3)), - (None, slice(1, 2), 3, None), - (1, None, 2, 3, None, None, 4), - # Advanced indexing - ([1, 2], slice(3, 5), None, None, [3, 4]), - (slice(None), slice(3, 5), None, None, [2, 3], [3, 4]), - (slice(None), slice(3, 5), None, [2, 3], None, [3, 4]), - (None, slice(None), slice(3, 5), [2, 3], None, [3, 4]), - [1], - [1, 2], - [2, 1, 3], - [7, 5, 0, 3, 6, 2, 1], - np.array([6, 3], dtype=np.int32), - np.array([[3, 4], [0, 6]], dtype=np.int32), - np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), - np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), - np.array([[2], [0], [1]], dtype=np.int32), - np.array([[2], [0], [1]], dtype=np.int64), - np.array([4, 7], dtype=np.int32), - np.array([4, 7], dtype=np.int64), - np.array([[3, 6], [2, 1]], dtype=np.int32), - np.array([[3, 6], [2, 1]], dtype=np.int64), - np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), - np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), - (1, [2, 3]), - (1, [2, 3], np.array([[3], [0]], dtype=np.int32)), - (1, [2, 3]), - (1, [2, 3], np.array([[3], [0]], dtype=np.int64)), - (1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)), - (1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)), - (1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)), - (1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)), - (1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)), - (1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)), - (1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)), - (1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)), - (1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)), - (1, [2], np.array([[4]], dtype=np.int32), np.array([[1, 3], [5, 7]], dtype='int64')), - [0], - [0, 1], - [1, 2, 3], - [2, 0, 5, 6], - ([1, 1], [2, 3]), - ([1], [4], [5]), - ([1], [4], [5], [6]), - ([[1]], [[2]]), - ([[1]], [[2]], [[3]], [[4]]), - (slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)), - ([[[[1]]]], [[1]], slice(0, 3), [1, 5]), - ([[[[1]]]], 3, slice(0, 3), [1, 3]), - ([[[[1]]]], 3, slice(0, 3), 0), - ([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)), - ([1, 2], slice(3, 5), [2, 3], [3, 4]), - ([1, 2], slice(3, 5), (2, 3), [3, 4]), + (), + # Basic indexing + # Single int as index + 0, + np.int32(0), + np.int64(0), + 5, + np.int32(5), + np.int64(5), + -1, + np.int32(-1), + np.int64(-1), + # Slicing as index + slice(5), + np_int(slice(5), np.int32), + np_int(slice(5), np.int64), + slice(1, 5), + np_int(slice(1, 5), np.int32), + np_int(slice(1, 5), np.int64), + slice(1, 5, 2), + np_int(slice(1, 5, 2), np.int32), + np_int(slice(1, 5, 2), np.int64), + slice(7, 0, -1), + np_int(slice(7, 0, -1)), + np_int(slice(7, 0, -1), np.int64), + slice(None, 6), + np_int(slice(None, 6)), + np_int(slice(None, 6), np.int64), + slice(None, 6, 3), + np_int(slice(None, 6, 3)), + np_int(slice(None, 6, 3), np.int64), + slice(1, None), + np_int(slice(1, None)), + np_int(slice(1, None), np.int64), + slice(1, None, 3), + np_int(slice(1, None, 3)), + np_int(slice(1, None, 3), np.int64), + slice(None, None, 2), + np_int(slice(None, None, 2)), + np_int(slice(None, None, 2), np.int64), + slice(None, None, -1), + np_int(slice(None, None, -1)), + np_int(slice(None, None, -1), np.int64), + slice(None, None, -2), + np_int(slice(None, None, -2), np.int32), + np_int(slice(None, None, -2), np.int64), + # Multiple ints as indices + (1, 2, 3), + np_int((1, 2, 3)), + np_int((1, 2, 3), np.int64), + (-1, -2, -3), + np_int((-1, -2, -3)), + np_int((-1, -2, -3), np.int64), + (1, 2, 3, 4), + np_int((1, 2, 3, 4)), + np_int((1, 2, 3, 4), np.int64), + (-4, -3, -2, -1), + np_int((-4, -3, -2, -1)), + np_int((-4, -3, -2, -1), np.int64), + # slice(None) as indices + (slice(None), slice(None), 1, 8), + (slice(None), slice(None), -1, 8), + (slice(None), slice(None), 1, -8), + (slice(None), slice(None), -1, -8), + np_int((slice(None), slice(None), 1, 8)), + np_int((slice(None), slice(None), 1, 8), np.int64), + (slice(None), slice(None), 1, 8), + np_int((slice(None), slice(None), -1, -8)), + np_int((slice(None), slice(None), -1, -8), np.int64), + (slice(None), 2, slice(1, 5), 1), + np_int((slice(None), 2, slice(1, 5), 1)), + np_int((slice(None), 2, slice(1, 5), 1), np.int64), + # Mixture of ints and slices as indices + (slice(None, None, -1), 2, slice(1, 5), 1), + np_int((slice(None, None, -1), 2, slice(1, 5), 1)), + np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64), + (slice(None, None, -1), 2, slice(1, 7, 2), 1), + np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)), + np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64), + (slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), + np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))), + np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64), + (slice(1, 8, 2), 1, slice(3, 8), 2), + np_int((slice(1, 8, 2), 1, slice(3, 8), 2)), + np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64), + # Test Ellipsis ('...') + (1, Ellipsis, -1), + (slice(2), Ellipsis, None, 0), + # Test newaxis + None, + (1, None, -2, 3, -4), + (1, slice(2, 5), None), + (slice(None), slice(1, 4), None, slice(2, 3)), + (slice(1, 3), slice(1, 3), slice(1, 3), slice(1, 3), None), + (slice(1, 3), slice(1, 3), None, slice(1, 3), slice(1, 3)), + (None, slice(1, 2), 3, None), + (1, None, 2, 3, None, None, 4), + # Advanced indexing + ([1, 2], slice(3, 5), None, None, [3, 4]), + (slice(None), slice(3, 5), None, None, [2, 3], [3, 4]), + (slice(None), slice(3, 5), None, [2, 3], None, [3, 4]), + (None, slice(None), slice(3, 5), [2, 3], None, [3, 4]), + [1], + [1, 2], + [2, 1, 3], + [7, 5, 0, 3, 6, 2, 1], + np.array([6, 3], dtype=np.int32), + np.array([[3, 4], [0, 6]], dtype=np.int32), + np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), + np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), + np.array([[2], [0], [1]], dtype=np.int32), + np.array([[2], [0], [1]], dtype=np.int64), + np.array([4, 7], dtype=np.int32), + np.array([4, 7], dtype=np.int64), + np.array([[3, 6], [2, 1]], dtype=np.int32), + np.array([[3, 6], [2, 1]], dtype=np.int64), + np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), + np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), + (1, [2, 3]), + (1, [2, 3], np.array([[3], [0]], dtype=np.int32)), + (1, [2, 3]), + (1, [2, 3], np.array([[3], [0]], dtype=np.int64)), + (1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)), + (1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)), + (1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)), + (1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)), + (1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)), + (1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)), + (1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)), + (1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)), + (1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)), + (1, [2], np.array([[4]], dtype=np.int32), np.array([[1, 3], [5, 7]], dtype='int64')), + [0], + [0, 1], + [1, 2, 3], + [2, 0, 5, 6], + ([1, 1], [2, 3]), + ([1], [4], [5]), + ([1], [4], [5], [6]), + ([[1]], [[2]]), + ([[1]], [[2]], [[3]], [[4]]), + (slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)), + ([[[[1]]]], [[1]], slice(0, 3), [1, 5]), + ([[[[1]]]], 3, slice(0, 3), [1, 3]), + ([[[[1]]]], 3, slice(0, 3), 0), + ([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)), + ([1, 2], slice(3, 5), [2, 3], [3, 4]), + ([1, 2], slice(3, 5), (2, 3), [3, 4]), + range(4), + range(3, 0, -1), + (range(4,), [1]), ] for index in index_list: test_getitem(np_array, index)