diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 6b44f6669cbf..5e7129226e34 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -176,9 +176,6 @@ def _get_np_basic_indexing(self, key): for ax in new_axes: # pylint: disable=invalid-name final_shape.insert(ax, 1) - if final_shape == []: - # Override for single element indexing - return sliced.item() if sliced.size == 0: return sliced.reshape(tuple(final_shape)) else: @@ -222,7 +219,6 @@ def __getitem__(self, key): if ndim == 0: if key != (): raise IndexError('scalar tensor can only accept `()` as index') - return self.item() # Handle simple cases for higher speed if isinstance(key, tuple) and len(key) == 0: return self diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index c839c6f7cf40..0d66907ad6cd 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -675,7 +675,7 @@ inline bool GetIndexRange(const mxnet::TShape& dshape, common::StaticArray* end, common::StaticArray* step) { // Function returns false if output is zero-sized, true otherwise. - bool size_non_zero = true; + bool zero_size_shape = false; CHECK_NE(dshape.ndim(), 0U); CHECK_LE(param_begin.ndim(), dshape.ndim()) << "Slicing axis exceeds data dimensions"; @@ -726,7 +726,7 @@ inline bool GetIndexRange(const mxnet::TShape& dshape, (*step)[i] = s; // checking begin==end if (b == e) { - size_non_zero = false; + zero_size_shape = true; } } @@ -736,7 +736,7 @@ inline bool GetIndexRange(const mxnet::TShape& dshape, (*step)[i] = 1; } - return size_non_zero; + return zero_size_shape; } inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape, @@ -981,7 +981,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 1U); const mxnet::TShape& dshape = (*in_attrs)[0]; - if (dshape.ndim() == 0U) return false; + if (!mxnet::ndim_is_known(dshape)) return false; mxnet::TShape vshape = dshape; // vshape is the value shape on the right hand side const SliceParam& param = nnvm::get(attrs.parsed); MXNET_NDIM_SWITCH(dshape.ndim(), ndim, { @@ -1024,9 +1024,9 @@ void SliceAssignOpForward(const nnvm::NodeAttrs& attrs, const SliceParam& param = nnvm::get(attrs.parsed); MXNET_NDIM_SWITCH(data.ndim(), ndim, { common::StaticArray begin, end, step; - bool non_zero_shape = GetIndexRange(data.shape_, param.begin, param.end, param.step, + bool zero_size_shape = GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step); - if (!non_zero_shape) { + if (zero_size_shape) { return; // slice_assign of zero-sized subspace needs no operation. } MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { @@ -1129,9 +1129,9 @@ void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs, const SliceAssignScalarParam& param = nnvm::get(attrs.parsed); MXNET_NDIM_SWITCH(data.ndim(), ndim, { common::StaticArray begin, end, step; - bool non_zero_shape = GetIndexRange(data.shape_, param.begin, param.end, param.step, + bool zero_size_shape = GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step); - if (!non_zero_shape) { + if (zero_size_shape) { return; // slice_assign of zero-sized subspaced needs no operation. } for (index_t i = 0; i < param.begin.ndim(); ++i) { diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 11d55fd68a06..54ba2fe108c7 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -405,13 +405,7 @@ def convert(num): assert False # Copied from test_ndarray.py. Under construction. - def test_getitem(np_array, index, is_scalar=False): - """ - `is_scalar` indicates whether we should expect a scalar for the result. - If so, the indexed array should call asscalar to compare - with numpy's indexed array. - np_array is a native numpy array. - """ + def test_getitem(np_array, index): np_index = index if type(index) == mx.nd.NDArray: # use of NDArray is prohibited assert False @@ -429,17 +423,10 @@ def test_getitem(np_array, index, is_scalar=False): except Exception as e: print('Failed with index = {}'.format(index)) raise e - if not is_scalar: - mx_indexed_array = mx_indexed_array.asnumpy() + mx_indexed_array = mx_indexed_array.asnumpy() assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index) - def test_setitem(np_array, index, is_scalar=False): - """ - `is_scalar` indicates whether we should expect a scalar for the result. - If so, the indexed array should call asscalar to compare - with numpy's indexed array. - np_array is a native numpy array. - """ + def test_setitem(np_array, index): def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None): if np_value is not None: np_array[np_index] = np_value @@ -447,7 +434,6 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None) np_array[np_index] = mx_value.asnumpy() else: np_array[np_index] = mx_value - try: mx_array[mx_index] = mx_value except Exception as e: @@ -470,29 +456,23 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None) mx_array = np.array(np_array, dtype=np_array.dtype) # mxnet.np.ndarray np_array = mx_array.asnumpy() # native numpy array - if is_scalar: - # test value is a numeric type - assert_same(np_array, np_index, mx_array, index, _np.random.randint(low=-10000, high=0)) - value_nd = [_np.random.randint(low=-10000, high=0)] - assert_same(np_array, np_index, mx_array, index, value_nd, value_nd[0]) - else: - indexed_array_shape = np_array[np_index].shape - np_indexed_array = _np.random.randint(low=-10000, high=0, size=indexed_array_shape) - # test value is a native numpy array without broadcast - assert_same(np_array, np_index, mx_array, index, np_indexed_array) - # test value is a mxnet numpy array without broadcast - assert_same(np_array, np_index, mx_array, index, np.array(np_indexed_array)) - # test value is an numeric_type - assert_same(np_array, np_index, mx_array, index, _np.random.randint(low=-10000, high=0)) - if len(indexed_array_shape) > 1: - np_value = _np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],)) - # test mxnet ndarray with broadcast - assert_same(np_array, np_index, mx_array, index, np.array(np_value)) - # test native numpy array with broadcast - assert_same(np_array, np_index, mx_array, index, np_value) - # test list with broadcast - assert_same(np_array, np_index, mx_array, index, - [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1]) + indexed_array_shape = np_array[np_index].shape + np_indexed_array = _np.random.randint(low=-10000, high=0, size=indexed_array_shape) + # test value is a native numpy array without broadcast + assert_same(np_array, np_index, mx_array, index, np_indexed_array) + # test value is a mxnet numpy array without broadcast + assert_same(np_array, np_index, mx_array, index, np.array(np_indexed_array)) + # test value is an numeric_type + assert_same(np_array, np_index, mx_array, index, _np.random.randint(low=-10000, high=0)) + if len(indexed_array_shape) > 1: + np_value = _np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],)) + # test mxnet ndarray with broadcast + assert_same(np_array, np_index, mx_array, index, np.array(np_value)) + # test native numpy array with broadcast + assert_same(np_array, np_index, mx_array, index, np_value) + # test list with broadcast + assert_same(np_array, np_index, mx_array, index, + [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1]) def test_getitem_autograd(np_array, index): """ @@ -502,8 +482,6 @@ def test_getitem_autograd(np_array, index): x.attach_grad() with mx.autograd.record(): y = x[index] - if not isinstance(y, np.ndarray): - return y.backward() value = np.ones_like(y) x_grad = np.zeros_like(x) @@ -515,8 +493,6 @@ def test_setitem_autograd(np_array, index): np_array: native numpy array. """ x = np.array(np_array, dtype=np_array.dtype) - if not isinstance(x[index], np.ndarray): - return # x[index] is scalar out_shape = x[index].shape y = np.array(_np.random.uniform(size=out_shape)) y.attach_grad() @@ -531,142 +507,173 @@ def test_setitem_autograd(np_array, index): shape = (8, 16, 9, 9) np_array = _np.arange(_np.prod(_np.array(shape)), dtype='int32').reshape(shape) # native np array + + # Test sliced output being ndarray: index_list = [ # Basic indexing # Single int as index - (0, False), (np.int32(0), False), (np.int64(0), False), - (5, False), (np.int32(5), False), (np.int64(5), False), - (-1, False), (np.int32(-1), False), (np.int64(-1), False), + 0, + np.int32(0), + np.int64(0), + 5, + np.int32(5), + np.int64(5), + -1, + np.int32(-1), + np.int64(-1), # Slicing as index - (slice(5), False), (np_int(slice(5), np.int32), False), (np_int(slice(5), np.int64), False), - (slice(1, 5), False), (np_int(slice(1, 5), np.int32), False), (np_int(slice(1, 5), np.int64), False), - (slice(1, 5, 2), False), (np_int(slice(1, 5, 2), np.int32), False), - (np_int(slice(1, 5, 2), np.int64), False), - (slice(7, 0, -1), False), (np_int(slice(7, 0, -1)), False), - (np_int(slice(7, 0, -1), np.int64), False), - (slice(None, 6), False), (np_int(slice(None, 6)), False), - (np_int(slice(None, 6), np.int64), False), - (slice(None, 6, 3), False), (np_int(slice(None, 6, 3)), False), - (np_int(slice(None, 6, 3), np.int64), False), - (slice(1, None), False), (np_int(slice(1, None)), False), - (np_int(slice(1, None), np.int64), False), - (slice(1, None, 3), False), (np_int(slice(1, None, 3)), False), - (np_int(slice(1, None, 3), np.int64), False), - (slice(None, None, 2), False), (np_int(slice(None, None, 2)), False), - (np_int(slice(None, None, 2), np.int64), False), - (slice(None, None, -1), False), - (np_int(slice(None, None, -1)), False), (np_int(slice(None, None, -1), np.int64), False), - (slice(None, None, -2), False), - (np_int(slice(None, None, -2), np.int32), False), (np_int(slice(None, None, -2), np.int64), False), + slice(5), + np_int(slice(5), np.int32), + np_int(slice(5), np.int64), + slice(1, 5), + np_int(slice(1, 5), np.int32), + np_int(slice(1, 5), np.int64), + slice(1, 5, 2), + np_int(slice(1, 5, 2), np.int32), + np_int(slice(1, 5, 2), np.int64), + slice(7, 0, -1), + np_int(slice(7, 0, -1)), + np_int(slice(7, 0, -1), np.int64), + slice(None, 6), + np_int(slice(None, 6)), + np_int(slice(None, 6), np.int64), + slice(None, 6, 3), + np_int(slice(None, 6, 3)), + np_int(slice(None, 6, 3), np.int64), + slice(1, None), + np_int(slice(1, None)), + np_int(slice(1, None), np.int64), + slice(1, None, 3), + np_int(slice(1, None, 3)), + np_int(slice(1, None, 3), np.int64), + slice(None, None, 2), + np_int(slice(None, None, 2)), + np_int(slice(None, None, 2), np.int64), + slice(None, None, -1), + np_int(slice(None, None, -1)), + np_int(slice(None, None, -1), np.int64), + slice(None, None, -2), + np_int(slice(None, None, -2), np.int32), + np_int(slice(None, None, -2), np.int64), # Multiple ints as indices - ((1, 2, 3), False), - (np_int((1, 2, 3)), False), - (np_int((1, 2, 3), np.int64), False), - ((-1, -2, -3), False), - (np_int((-1, -2, -3)), False), - (np_int((-1, -2, -3), np.int64), False), - ((1, 2, 3, 4), True), - (np_int((1, 2, 3, 4)), True), - (np_int((1, 2, 3, 4), np.int64), True), - ((-4, -3, -2, -1), True), - (np_int((-4, -3, -2, -1)), True), - (np_int((-4, -3, -2, -1), np.int64), True), + (1, 2, 3), + np_int((1, 2, 3)), + np_int((1, 2, 3), np.int64), + (-1, -2, -3), + np_int((-1, -2, -3)), + np_int((-1, -2, -3), np.int64), + (1, 2, 3, 4), + np_int((1, 2, 3, 4)), + np_int((1, 2, 3, 4), np.int64), + (-4, -3, -2, -1), + np_int((-4, -3, -2, -1)), + np_int((-4, -3, -2, -1), np.int64), # slice(None) as indices - ((slice(None), slice(None), 1, 8), False), - ((slice(None), slice(None), -1, 8), False), - ((slice(None), slice(None), 1, -8), False), - ((slice(None), slice(None), -1, -8), False), - (np_int((slice(None), slice(None), 1, 8)), False), - (np_int((slice(None), slice(None), 1, 8), np.int64), False), - ((slice(None), slice(None), 1, 8), False), - (np_int((slice(None), slice(None), -1, -8)), False), - (np_int((slice(None), slice(None), -1, -8), np.int64), False), - ((slice(None), 2, slice(1, 5), 1), False), - (np_int((slice(None), 2, slice(1, 5), 1)), False), - (np_int((slice(None), 2, slice(1, 5), 1), np.int64), False), + (slice(None), slice(None), 1, 8), + (slice(None), slice(None), -1, 8), + (slice(None), slice(None), 1, -8), + (slice(None), slice(None), -1, -8), + np_int((slice(None), slice(None), 1, 8)), + np_int((slice(None), slice(None), 1, 8), np.int64), + (slice(None), slice(None), 1, 8), + np_int((slice(None), slice(None), -1, -8)), + np_int((slice(None), slice(None), -1, -8), np.int64), + (slice(None), 2, slice(1, 5), 1), + np_int((slice(None), 2, slice(1, 5), 1)), + np_int((slice(None), 2, slice(1, 5), 1), np.int64), # Mixture of ints and slices as indices - ((slice(None, None, -1), 2, slice(1, 5), 1), False), - (np_int((slice(None, None, -1), 2, slice(1, 5), 1)), False), - (np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64), False), - ((slice(None, None, -1), 2, slice(1, 7, 2), 1), False), - (np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)), False), - (np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64), False), - ((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), False), - (np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))), False), - (np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64), False), - ((slice(1, 8, 2), 1, slice(3, 8), 2), False), - (np_int((slice(1, 8, 2), 1, slice(3, 8), 2)), False), - (np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64), False), + (slice(None, None, -1), 2, slice(1, 5), 1), + np_int((slice(None, None, -1), 2, slice(1, 5), 1)), + np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64), + (slice(None, None, -1), 2, slice(1, 7, 2), 1), + np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)), + np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64), + (slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), + np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))), + np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64), + (slice(1, 8, 2), 1, slice(3, 8), 2), + np_int((slice(1, 8, 2), 1, slice(3, 8), 2)), + np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64), # Test Ellipsis ('...') - ((1, Ellipsis, -1), False), - ((slice(2), Ellipsis, None, 0), False), + (1, Ellipsis, -1), + (slice(2), Ellipsis, None, 0), # Test newaxis - (None, False), - ((1, None, -2, 3, -4), False), - ((1, slice(2, 5), None), False), - ((slice(None), slice(1, 4), None, slice(2, 3)), False), - ((slice(1, 3), slice(1, 3), slice(1, 3), slice(1, 3), None), False), - ((slice(1, 3), slice(1, 3), None, slice(1, 3), slice(1, 3)), False), - ((None, slice(1, 2), 3, None), False), - ((1, None, 2, 3, None, None, 4), False), + None, + (1, None, -2, 3, -4), + (1, slice(2, 5), None), + (slice(None), slice(1, 4), None, slice(2, 3)), + (slice(1, 3), slice(1, 3), slice(1, 3), slice(1, 3), None), + (slice(1, 3), slice(1, 3), None, slice(1, 3), slice(1, 3)), + (None, slice(1, 2), 3, None), + (1, None, 2, 3, None, None, 4), # Advanced indexing - (([1, 2], slice(3, 5), None, None, [3, 4]), False), - ((slice(None), slice(3, 5), None, None, [2, 3], [3, 4]), False), - ((slice(None), slice(3, 5), None, [2, 3], None, [3, 4]), False), - ((None, slice(None), slice(3, 5), [2, 3], None, [3, 4]), False), - ([1], False), ([1, 2], False), ([2, 1, 3], False), ([7, 5, 0, 3, 6, 2, 1], False), - (np.array([6, 3], dtype=np.int32), False), - (np.array([[3, 4], [0, 6]], dtype=np.int32), False), - (np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), False), - (np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), False), - (np.array([[2], [0], [1]], dtype=np.int32), False), - (np.array([[2], [0], [1]], dtype=np.int64), False), - (np.array([4, 7], dtype=np.int32), False), - (np.array([4, 7], dtype=np.int64), False), - (np.array([[3, 6], [2, 1]], dtype=np.int32), False), - (np.array([[3, 6], [2, 1]], dtype=np.int64), False), - (np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), False), - (np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), False), - ((1, [2, 3]), False), ((1, [2, 3], np.array([[3], [0]], dtype=np.int32)), False), - ((1, [2, 3]), False), ((1, [2, 3], np.array([[3], [0]], dtype=np.int64)), False), - ((1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)), False), - ((1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)), False), - ((1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)), False), - ((1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)), False), - ((1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)), False), - ((1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)), False), - ((1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)), False), - ((1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)), False), - ((1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)), False), - ((1, [2], np.array([[4]], dtype=np.int32), np.array([[1, 3], [5, 7]], dtype='int64')), - False), - ([0], False), ([0, 1], False), ([1, 2, 3], False), ([2, 0, 5, 6], False), - (([1, 1], [2, 3]), False), (([1], [4], [5]), False), (([1], [4], [5], [6]), False), - (([[1]], [[2]]), False), (([[1]], [[2]], [[3]], [[4]]), False), - ((slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)), False), - (([[[[1]]]], [[1]], slice(0, 3), [1, 5]), False), - (([[[[1]]]], 3, slice(0, 3), [1, 3]), False), - (([[[[1]]]], 3, slice(0, 3), 0), False), - (([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)), False), - (([1, 2], slice(3, 5), [2, 3], [3, 4]), False), - (([1, 2], slice(3, 5), (2, 3), [3, 4]), False), + ([1, 2], slice(3, 5), None, None, [3, 4]), + (slice(None), slice(3, 5), None, None, [2, 3], [3, 4]), + (slice(None), slice(3, 5), None, [2, 3], None, [3, 4]), + (None, slice(None), slice(3, 5), [2, 3], None, [3, 4]), + [1], + [1, 2], + [2, 1, 3], + [7, 5, 0, 3, 6, 2, 1], + np.array([6, 3], dtype=np.int32), + np.array([[3, 4], [0, 6]], dtype=np.int32), + np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), + np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), + np.array([[2], [0], [1]], dtype=np.int32), + np.array([[2], [0], [1]], dtype=np.int64), + np.array([4, 7], dtype=np.int32), + np.array([4, 7], dtype=np.int64), + np.array([[3, 6], [2, 1]], dtype=np.int32), + np.array([[3, 6], [2, 1]], dtype=np.int64), + np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), + np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), + (1, [2, 3]), + (1, [2, 3], np.array([[3], [0]], dtype=np.int32)), + (1, [2, 3]), + (1, [2, 3], np.array([[3], [0]], dtype=np.int64)), + (1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)), + (1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)), + (1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)), + (1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)), + (1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)), + (1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)), + (1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)), + (1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)), + (1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)), + (1, [2], np.array([[4]], dtype=np.int32), np.array([[1, 3], [5, 7]], dtype='int64')), + [0], + [0, 1], + [1, 2, 3], + [2, 0, 5, 6], + ([1, 1], [2, 3]), + ([1], [4], [5]), + ([1], [4], [5], [6]), + ([[1]], [[2]]), + ([[1]], [[2]], [[3]], [[4]]), + (slice(0, 2), [[1], [6]], slice(0, 2), slice(0, 5, 2)), + ([[[[1]]]], [[1]], slice(0, 3), [1, 5]), + ([[[[1]]]], 3, slice(0, 3), [1, 3]), + ([[[[1]]]], 3, slice(0, 3), 0), + ([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)), + ([1, 2], slice(3, 5), [2, 3], [3, 4]), + ([1, 2], slice(3, 5), (2, 3), [3, 4]), ] for index in index_list: - test_getitem(np_array, index[0], index[1]) - test_setitem(np_array, index[0], index[1]) - test_getitem_autograd(np_array, index[0]) - test_setitem_autograd(np_array, index[0]) + test_getitem(np_array, index) + test_setitem(np_array, index) + test_getitem_autograd(np_array, index) + test_setitem_autograd(np_array, index) # Test indexing to zero-size tensors index_list = [ - ((slice(0, 0), slice(0, 0), 1, 2), False), - ((slice(0, 0), slice(0, 0), slice(0, 0), slice(0, 0)), False), + (slice(0, 0), slice(0, 0), 1, 2), + (slice(0, 0), slice(0, 0), slice(0, 0), slice(0, 0)), ] for index in index_list: - test_getitem(np_array, index[0], index[1]) - test_setitem(np_array, index[0]) - test_getitem_autograd(np_array, index[0]) - test_setitem_autograd(np_array, index[0]) + test_getitem(np_array, index) + test_setitem(np_array, index) + test_getitem_autograd(np_array, index) + test_setitem_autograd(np_array, index) # test zero-size tensors get and setitem shapes_indices = [ @@ -676,8 +683,8 @@ def test_setitem_autograd(np_array, index): for shape, indices in shapes_indices: for index in indices: np_array = np.zeros(shape) - test_getitem(np_array, index, False) - test_setitem(np_array, index, False) + test_getitem(np_array, index) + test_setitem(np_array, index) test_getitem_autograd(np_array, index) test_setitem_autograd(np_array, index)