diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index d67779e9a0f7..d02a3daa5fdc 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -18,9 +18,10 @@ # coding: utf-8 # pylint: disable=too-many-lines, protected-access # pylint: disable=import-error, no-name-in-module, undefined-variable + """NDArray API of MXNet.""" -from __future__ import absolute_import -from __future__ import division + +from __future__ import absolute_import, division try: from __builtin__ import slice as py_slice @@ -406,19 +407,15 @@ def __setstate__(self, state): else: self.handle = None - # pylint: disable=line-too-long def __setitem__(self, key, value): """x.__setitem__(i, y) <=> x[i]=y - Sets value to self[key]. This functions supports advanced indexing defined in the following reference with - some restrictions. + Sets ``self[key]`` to ``value``. - https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing - - - If key is a list type, only a list of integers is supported, e.g. key=[1, 2] is supported, - while not for key=[[1, 2]]. - - Ellipsis (...) and np.newaxis are not supported. - - Boolean array indexing is not supported. + This functions supports advanced indexing as defined in `the NumPy + advanced indexing documentation + `_, + with the restriction that boolean array indexing is not supported. Parameters ---------- @@ -429,27 +426,24 @@ def __setitem__(self, key, value): Examples -------- - >>> x = mx.nd.zeros((2,3)) + >>> x = mx.nd.zeros((2, 3)) >>> x[:] = 1 >>> x.asnumpy() array([[ 1., 1., 1.], [ 1., 1., 1.]], dtype=float32) - >>> x.asnumpy() - array([[ 1., 1., 1.], - [ 1., 1., 1.]], dtype=float32) - >>> x[:,1:2] = 2 + >>> x[:, 1:2] = 2 >>> x.asnumpy() array([[ 1., 2., 1.], [ 1., 2., 1.]], dtype=float32) - >>> x[1:2,1:] = 3 + >>> x[1:2, 1:] = 3 >>> x.asnumpy() array([[ 1., 2., 1.], [ 1., 3., 3.]], dtype=float32) - >>> x[1:,0:2] = mx.nd.zeros((1,2)) + >>> x[1:, 0:2] = mx.nd.zeros((1, 2)) >>> x.asnumpy() array([[ 1., 2., 1.], [ 0., 0., 3.]], dtype=float32) - >>> x[1,2] = 4 + >>> x[1, 2] = 4 >>> x.asnumpy() array([[ 1., 2., 1.], [ 0., 0., 4.]], dtype=float32) @@ -462,31 +456,51 @@ def __setitem__(self, key, value): array([[ 6., 5., 5.], [ 6., 0., 4.]], dtype=float32) """ - indexing_dispatch_code = _get_indexing_dispatch_code(key) + if self.ndim == 0 and key == (): + _internal._full(shape=self.shape, value=float(value), ctx=self.context, + dtype=self.dtype, out=self) + return + key = _indexing_key_expand_implicit_axes(key, self.shape) + slc_key = tuple(idx for idx in key if idx is not None) + + if len(slc_key) < self.ndim: + raise RuntimeError( + 'too few indices after normalization: expected `ndim` ({}) ' + 'but got {}. This is a bug, please report it!' + ''.format(self.ndim, len(slc_key)) + ) + if len(slc_key) > self.ndim: + raise IndexError( + 'too many indices ({}) for array with {} dimensions' + ''.format(len(slc_key), self.ndim) + ) + + indexing_dispatch_code = _get_indexing_dispatch_code(slc_key) if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING: - self._set_nd_basic_indexing(key, value) + self._set_nd_basic_indexing(slc_key, value) elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING: - self._set_nd_advanced_indexing(key, value) + self._set_nd_advanced_indexing(slc_key, value) else: - raise ValueError('Indexing NDArray with index=%s and type=%s is not supported' - % (str(key), str(type(key)))) - # pylint: enable=line-too-long + raise ValueError( + 'Indexing NDArray with index {} of type {} is not supported' + ''.format(key, type(key)) + ) - # pylint: disable=line-too-long def __getitem__(self, key): """x.__getitem__(i) <=> x[i] - Returns a sliced view of this array if the elements fetched are contiguous in memory; - otherwise, returns a newly created NDArray. - This functions supports advanced indexing defined in the following reference with - some restrictions. + Returns the subarray ``self[key]``. - https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing + For basic indexing, i.e., if ``key`` consists only of integers, + ``slice``, ``Ellipsis`` (``...``) and ``None``, a mutable view is + returned that shares memory with this array if the accessed portion is + contiguous in memory. + Otherwise, a newly created ``NDArray`` is returned. - - If key is a list type, only a list of integers is supported, e.g. key=[1, 2] is supported, - while not for key=[[1, 2]]. - - Ellipsis (...) and np.newaxis are not supported. - - Boolean array indexing is not supported. + This functions supports advanced indexing as defined in `the NumPy + advanced indexing documentation + `_, + with the restriction that boolean array indexing is not supported. Parameters ---------- @@ -495,196 +509,142 @@ def __getitem__(self, key): Examples -------- - >>> x = mx.nd.arange(0,6).reshape((2,3)) + The default is to give explicit indices for all axes: + + >>> x = mx.nd.arange(0, 6).reshape((2, 3)) >>> x.asnumpy() array([[ 0., 1., 2.], [ 3., 4., 5.]], dtype=float32) - >>> x[1].asnumpy() - array([ 3., 4., 5.], dtype=float32) - >>> y = x[0:1] - >>> y[:] = 2 + >>> x[0, :].asnumpy() + array([0., 1., 2.], dtype=float32) + >>> x[0, :2].asnumpy() + array([0., 1.], dtype=float32) + >>> x[:, :-1].asnumpy() + array([[0., 1.], + [3., 4.]], dtype=float32) + + If fewer indices are given, they are automatically supplemented by an + appropriate number of ``slice(None)`` ("``:``") to the right. For + instance, a single integer indexes along the first axis: + + >>> x = mx.nd.arange(0, 6).reshape((2, 3)) + >>> x[0].asnumpy() + array([0., 1., 2.], dtype=float32) + >>> x[1:].asnumpy() + array([[3., 4., 5.]], dtype=float32) + + To omit a range of axes that should be kept as-is, an `Ellipsis` + ("``...``") can be used: + + >>> x = mx.nd.arange(0, 16).reshape((2, 2, 2, 2)) + >>> x[0, ..., 1].asnumpy() + array([[1., 3.], + [5., 7.]], dtype=float32) + >>> x[0, :, :, 1].asnumpy() # equivalent + array([[1., 3.], + [5., 7.]], dtype=float32) + + New axes of length 1 can be created by inserting ``None`` + (`numpy.newaxis`) in the index: + + >>> x = mx.nd.arange(0, 6).reshape((2, 3)) + >>> x[None, :, :].asnumpy() + array([[[0., 1., 2.], + [3., 4., 5.]]], dtype=float32) + >>> x[None, :, :].shape + (1, 2, 3) + + If the indexed portion of the array is contiguous in memory, no data + is copied. Instead, a shared-memory view of the original array is + returned, and changes to that view affect the original array: + + >>> x = mx.nd.arange(0, 8).reshape((2, 2, 2)) + >>> y = x[0] # contiguous + >>> y.asnumpy() + array([[0., 1.], + [2., 3.]], dtype=float32) + >>> y[:] = -1 >>> x.asnumpy() - array([[ 2., 2., 2.], - [ 3., 4., 5.]], dtype=float32) - >>> x = mx.nd.arange(0, 8, dtype='int32').reshape((2, 2, 2)) - >>> x[[0, 1]] - [[[0 1] - [2 3]] - [[4 5] - [6 7]]] - >>> x[1:, [0, 1]] - [[[4 5] - [6 7]]] + array([[[-1., -1.], + [-1., -1.]], + + [[ 4., 5.], + [ 6., 7.]]], dtype=float32) + >>> x = mx.nd.arange(0, 8).reshape((2, 2, 2)) + >>> y = x[1, :1, :] # contiguous + >>> y.asnumpy() + array([[4., 5.]], dtype=float32) + >>> y[:] = -1 + >>> x.asnumpy() + array([[[ 0., 1.], + [ 2., 3.]], + + [[-1., -1.], + [ 6., 7.]]], dtype=float32) + >>> x = mx.nd.arange(0, 8).reshape((2, 2, 2)) + >>> y = x[:, :, 1] # not contiguous + >>> y.asnumpy() + array([[1., 3.], + [5., 7.]], dtype=float32) + >>> y[:] = -1 + >>> x.asnumpy() + array([[[0., 1.], + [2., 3.]], + + [[4., 5.], + [6., 7.]]], dtype=float32) + + If the indexing key contains `list`, `numpy.ndarray` or `NDArray` + objects, advanced indexing is triggered, which always returns a + copy: + + >>> x = mx.nd.arange(0, 8).reshape((2, 2, 2)) + >>> x[[0, 1]].asnumpy() + array([[[0., 1.], + [2., 3.]], + + [[4., 5.], + [6., 7.]]], dtype=float32) + >>> x[[0, 1], :].asnumpy() # equivalent + array([[[0., 1.], + [2., 3.]], + + [[4., 5.], + [6., 7.]]], dtype=float32) >>> y = np.array([0, 1], dtype='int32') - >>> x[1:, y] - [[[4 5] - [6 7]]] + >>> x[1:, y].asnumpy() + array([[[4., 5.], + [6., 7.]]], dtype=float32) >>> y = mx.nd.array([0, 1], dtype='int32') - >>> x[1:, y] - [[[4 5] - [6 7]]] + >>> x[1:, y].asnumpy() + array([[[4., 5.], + [6., 7.]]], dtype=float32) """ + if self.ndim == 0 and key == (): + return self + key = _indexing_key_expand_implicit_axes(key, self.shape) + if len(key) == 0: + raise ValueError('indexing key cannot be an empty tuple') + indexing_dispatch_code = _get_indexing_dispatch_code(key) if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING: return self._get_nd_basic_indexing(key) elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING: return self._get_nd_advanced_indexing(key) else: - raise ValueError('Indexing NDArray with index=%s and type=%s is not supported' - % (str(key), str(type(key)))) - # pylint: enable=line-too-long + raise RuntimeError - def _get_index_nd(self, key): - """Returns an index array for use in scatter_nd and gather_nd.""" - def _is_advanced_index(index): - """The definition of advanced index here includes integers as well, while - integers are considered as basic index type when the key contains only - slices and integers.""" - return not isinstance(index, py_slice) - - if isinstance(key, (NDArray, np.ndarray, list, integer_types, py_slice)): - key = (key,) - - assert isinstance(key, tuple),\ - 'index=%s must be a NDArray, or np.ndarray, or list, or tuple ' \ - ' type to use advanced indexing, received type=%s' % (str(key), str(type(key))) - - assert len(key) > 0, "Cannot slice with empty indices" - shape = self.shape - assert len(shape) >= len(key),\ - "Slicing dimensions exceeds array dimensions, %d vs %d" % (len(key), len(shape)) - indices = [] - dtype = 'int32' # index data type passed to gather_nd op - need_broadcast = (len(key) != 1) - advanced_indices = [] # include list, NDArray, np.ndarray, integer - basic_indices = [] # include only slices - advanced_index_bshape = None # final advanced index shape - for i, idx_i in enumerate(key): - is_advanced_index = True - if isinstance(idx_i, (np.ndarray, list, tuple)): - idx_i = array(idx_i, ctx=self.context, dtype=dtype) - advanced_indices.append(i) - elif isinstance(idx_i, py_slice): - start, stop, step = _get_index_range(idx_i.start, idx_i.stop, shape[i], idx_i.step) - idx_i = arange(start, stop, step, ctx=self.context, dtype=dtype) - basic_indices.append(i) - is_advanced_index = False - elif isinstance(idx_i, integer_types): - start, stop, step = _get_index_range(idx_i, idx_i+1, shape[i], 1) - idx_i = arange(start, stop, step, ctx=self.context, dtype=dtype) - advanced_indices.append(i) - elif isinstance(idx_i, NDArray): - if dtype != idx_i.dtype: - idx_i = idx_i.astype(dtype) - advanced_indices.append(i) - else: - raise IndexError('Indexing NDArray with index=%s of type=%s is not supported' - % (str(key), str(type(key)))) - if is_advanced_index: - if advanced_index_bshape is None: - advanced_index_bshape = idx_i.shape - elif advanced_index_bshape != idx_i.shape: - need_broadcast = True - advanced_index_bshape = _get_broadcast_shape(advanced_index_bshape, idx_i.shape) - indices.append(idx_i) - - # Get final index shape for gather_nd. See the following reference - # for determining the output array shape. - # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#combining-advanced-and-basic-indexing # pylint: disable=line-too-long - if len(advanced_indices) == 0: - raise ValueError('Advanced index tuple must contain at least one of the following types:' - ' list, tuple, NDArray, np.ndarray, integer, received index=%s' % key) - # determine the output array's shape by checking whether advanced_indices are all adjacent - # or separated by slices - advanced_indices_adjacent = True - for i in range(0, len(advanced_indices)-1): - if advanced_indices[i] + 1 != advanced_indices[i+1]: - advanced_indices_adjacent = False - break - - index_bshape_list = [] # index broadcasted shape - if advanced_indices_adjacent: - for i in range(0, advanced_indices[0]): - index_bshape_list.extend(indices[i].shape) - if not need_broadcast and indices[i].shape != advanced_index_bshape: - need_broadcast = True - index_bshape_list.extend(advanced_index_bshape) - for i in range(advanced_indices[-1]+1, len(indices)): - if not need_broadcast and indices[i].shape != advanced_index_bshape: - need_broadcast = True - index_bshape_list.extend(indices[i].shape) - else: - index_bshape_list.extend(advanced_index_bshape) - for i in basic_indices: - index_bshape_list.extend(indices[i].shape) - if not need_broadcast and indices[i].shape != advanced_index_bshape: - need_broadcast = True - index_bshape = tuple(index_bshape_list) - - # Need to broadcast all ndarrays in indices to the final shape. - # For example, suppose an array has shape=(5, 6, 7, 8) and - # key=(slice(1, 5), [[1, 2]], slice(2, 5), [1]). - # Since key[1] and key[3] are two advanced indices here and they are - # separated by basic indices key[0] and key[2], the output shape - # is (1, 2, 4, 3), where the first two elements come from the shape - # that key[1] and key[3] should broadcast to, which is (1, 2), and - # the last two elements come from the shape of two basic indices. - # In order to broadcast all basic and advanced indices to the output shape, - # we need to reshape them based on their axis. For example, to broadcast key[0], - # with shape=(4,), we first need to reshape it into (1, 1, 4, 1), and then - # broadcast the reshaped array to (1, 2, 4, 3); to broadcast key[1], we first - # reshape it into (1, 2, 1, 1), then broadcast the reshaped array to (1, 2, 4, 3). - if need_broadcast: - broadcasted_indices = [] - idx_rshape = [1] * len(index_bshape) - if advanced_indices_adjacent: - advanced_index_bshape_start = advanced_indices[0] # start index of advanced_index_bshape in index_shape - advanced_index_bshape_stop = advanced_index_bshape_start + len(advanced_index_bshape) - for i, idx in enumerate(key): - if _is_advanced_index(idx): - k = advanced_index_bshape_stop - # find the reshaped shape for indices[i] - for dim_size in indices[i].shape[::-1]: - k -= 1 - idx_rshape[k] = dim_size - else: - if i < advanced_indices[0]: # slice is on the left side of advanced indices - idx_rshape[i] = indices[i].shape[0] - elif i > advanced_indices[-1]: # slice is on the right side of advanced indices - idx_rshape[i-len(key)] = indices[i].shape[0] - else: - raise ValueError('basic index i=%d cannot be between advanced index i=%d and i=%d' - % (i, advanced_indices[0], advanced_indices[-1])) - # broadcast current index to the final shape - broadcasted_indices.append(indices[i].reshape(tuple(idx_rshape)).broadcast_to(index_bshape)) - # reset idx_rshape to ones - for j, _ in enumerate(idx_rshape): - idx_rshape[j] = 1 - else: - basic_index_offset = len(advanced_index_bshape) - for i, idx in enumerate(key): - if _is_advanced_index(idx): - k = len(advanced_index_bshape) - for dim_size in indices[i].shape[::-1]: - k -= 1 - idx_rshape[k] = dim_size - else: - idx_rshape[basic_index_offset] = indices[i].shape[0] - basic_index_offset += 1 - # broadcast current index to the final shape - broadcasted_indices.append(indices[i].reshape(tuple(idx_rshape)).broadcast_to(index_bshape)) - # reset idx_rshape to ones - for j, _ in enumerate(idx_rshape): - idx_rshape[j] = 1 - - indices = broadcasted_indices - return op.stack(*indices) - - def _prepare_value_nd(self, value, vshape): - """Given value and vshape, create an `NDArray` from value with the same - context and dtype as the current one and broadcast it to vshape.""" + def _prepare_value_nd(self, value, new_axes, bcast_shape): + """Return a broadcast `NDArray` with same context and dtype as ``self``. + + Before broadcasting, ``new_axes`` of length 1 will be added to + ``value``. This is done in contrast to blindly reshaping based on + ``bcast_shape``, since the latter would silently ignore wrongly shaped + ``value`` arrays, e.g. ``nd.zeros((2, 3))[:, :1] = nd.ones(2)``. + """ if isinstance(value, numeric_types): - value_nd = full(shape=vshape, val=value, ctx=self.context, dtype=self.dtype) + value_nd = full(bcast_shape, value, ctx=self.context, dtype=self.dtype) + new_axes = [] # ignore for scalar elif isinstance(value, NDArray): value_nd = value.as_in_context(self.context) if value_nd.dtype != self.dtype: @@ -693,168 +653,461 @@ def _prepare_value_nd(self, value, vshape): try: value_nd = array(value, ctx=self.context, dtype=self.dtype) except: - raise TypeError('NDArray does not support assignment with non-array-like' - ' object %s of type %s' % (str(value), str(type(value)))) - if value_nd.shape != vshape: - value_nd = value_nd.broadcast_to(vshape) + raise TypeError('NDArray does not support assignment with non-array-like ' + 'object {} of type {}'.format(value, type(value))) + + # First reshape `value_nd` to a new shape that incorporates existing + # axes, new axes and broadcasting axes in the right way. + tmp_shape = _shape_for_bcast( + value_nd.shape, target_ndim=len(bcast_shape), new_axes=new_axes + ) + value_nd = value_nd.reshape(tmp_shape) + + if value_nd.shape != bcast_shape: + value_nd = value_nd.broadcast_to(bcast_shape) return value_nd + # pylint: disable=invalid-name + @staticmethod + def _basic_indexing_key_to_begin_end_step(idcs, shape, keep_none=True): + """Map a tuple of ``slice`` and ``None`` (ignored) to begin, end, step tuples.""" + idcs = [idx for idx in idcs if idx is not None] + idcs = [idx if isinstance(idx, py_slice) else _int_to_slice(idx) + for idx in idcs] + + if keep_none: + sss_list = [(slc.start, slc.stop, slc.step) for slc, n in zip(idcs, shape)] + else: + sss_list = [slc.indices(n) for slc, n in zip(idcs, shape)] + return tuple(zip(*sss_list)) + # pylint: enable=invalid-name + + # pylint: disable=invalid-name + @staticmethod + def _basic_indexing_key_int_to_slice(idcs): + """Return the converted indexing tuple and the integer axes.""" + int_axes = [] + conv_idcs = [] + for ax, idx in enumerate(idcs): + if isinstance(idx, integer_types): + conv_idcs.append(_int_to_slice(idx)) + int_axes.append(ax) + else: + conv_idcs.append(idx) + + return tuple(conv_idcs), tuple(int_axes) + # pylint: enable=invalid-name + + @staticmethod + def _new_axes_after_basic_indexing(axes, key_nd): + """Return indices of ``axes`` after slicing with ``key_nd``. + + This function is used to calculate the positions where new axes should + end up after indexing, taking into account the removal of axes by + integer indexing. + + The ``key_nd`` sequence should contain slices and integers only, no + ``None`` entries. + """ + steps = [0] + [0 if isinstance(idx, integer_types) else 1 + for idx in key_nd] + cum_steps = np.cumsum(steps) + axes_in_bounds = [ax for ax in axes if ax < len(cum_steps)] + axes_out_of_bounds = [ax for ax in axes if ax >= len(cum_steps)] + axes_after = tuple(cum_steps[axes_in_bounds]) + oob_offsets = [ax - len(key_nd) for ax in axes_out_of_bounds] + axes_after += tuple(cum_steps[-1] + offset for offset in oob_offsets) + return axes_after + + # pylint: disable=invalid-name + @staticmethod + def _basic_indexing_slice_is_contiguous(slc_key, shape): + """Whether indexing with the given key results in a contiguous array. + + The rule is: From right to left, if in an axis, a slice produces a + proper subset, no later axis can produce a proper subset or use + a step different from 1. + + The ``slc_key`` sequence must have the same length as ``shape`` and + only contain `slice` objects. + """ + assert len(slc_key) == len(shape) + subset = False + for idx, n in zip(reversed(slc_key), reversed(shape)): + start, stop, step = idx.indices(n) + if step > 0: + num = int(np.ceil(max(stop - start, 0) / step)) + else: + num = int(np.ceil(min(stop - start, 0) / step)) + + if num != 1 and (subset or step != 1): + return False + if num != n: + subset = True + + return True + # pylint: enable=invalid-name + + @staticmethod + def _basic_indexing_sliced_shape(slc_key, shape): + """Return the shape after slicing with the given key.""" + assert len(slc_key) == len(shape) + sliced_shape = [] + for idx, n in zip(slc_key, shape): + start, stop, step = idx.indices(n) + if step > 0: + num = int(np.ceil(max(stop - start, 0) / step)) + else: + num = int(np.ceil(min(stop - start, 0) / step)) + sliced_shape.append(num) + + return tuple(sliced_shape) + + # pylint: disable=invalid-name + @staticmethod + def _basic_indexing_contiguous_flat_begin_end(slc_key, shape): + """Return the flat indices of begin and end for contiguous slicing.""" + assert len(slc_key) == len(shape) + begin, end, _ = slc_key[0].indices(shape[0]) + flat_begin, flat_end = begin, end - 1 + for idx, n in zip(slc_key[1:], shape[1:]): + flat_begin *= n + flat_end *= n + begin, end, _ = idx.indices(n) + flat_begin += begin + flat_end += end - 1 + + return flat_begin, flat_end + 1 + # pylint: enable=invalid-name + def _set_nd_basic_indexing(self, key, value): - """This function is called by __setitem__ when key is a basic index, i.e. - an integer, or a slice, or a tuple of integers and slices. No restrictions - on the values of slices' steps.""" - shape = self.shape - if isinstance(key, integer_types): - if key < 0: - key += shape[0] - if key < 0 or key >= shape[0]: - if key < 0: - key -= shape[0] - raise IndexError('index %d is out of bounds for axis 0 with size %d' - % (key, shape[0])) - key = py_slice(key, key+1) # key must be >= 0 here - - if isinstance(key, py_slice): - assign_to_self = key.step is None or key.step == 1 - assign_to_self &= key.start is None or key.start == 0 - assign_to_self &= key.stop is None or key.stop == shape[0] - if assign_to_self: # trivial case, assign value to self - if isinstance(value, NDArray): - if value.handle is not self.handle: - if value.shape != shape: - value = value.broadcast_to(shape) - value.copyto(self) - elif isinstance(value, numeric_types): - _internal._full(shape=shape, ctx=self.context, - dtype=self.dtype, value=float(value), out=self) - elif isinstance(value, (np.ndarray, np.generic)): - if isinstance(value, np.generic) or value.shape != shape: - value = np.broadcast_to(value, shape) - self._sync_copyfrom(value) - else: # value might be a list or a tuple - value_nd = self._prepare_value_nd(value, shape) - value_nd.copyto(self) - return - else: # non-trivial case, use _slice_assign or _slice_assign_scalar - key = (key,) - - assert isinstance(key, tuple), "key=%s must be a tuple of slices and integers" % str(key) - - assert len(key) <= len(shape), "Indexing dimensions exceed array dimensions, %d vs %d"\ - % (len(key), len(shape)) - begin = [] - end = [] - steps = [] - oshape = [] # output shape of slice using key - vshape = [] # value shape of data[key] - for i, slice_i in enumerate(key): - dim_size = 1 - if isinstance(slice_i, py_slice): - begin.append(slice_i.start) - end.append(slice_i.stop) - steps.append(slice_i.step) - start, stop, step = _get_index_range(slice_i.start, slice_i.stop, - shape[i], slice_i.step) - dim_size = _get_dim_size(start, stop, step) - vshape.append(dim_size) - elif isinstance(slice_i, integer_types): - begin.append(slice_i) - end.append(slice_i+1 if slice_i != -1 else self.shape[i]) - steps.append(1) + """This function indexes ``self`` with a tuple of ``slice`` objects only.""" + for idx in key: + if not isinstance(idx, (py_slice, integer_types)): + raise RuntimeError( + '`key` may only contain `slice` or integer objects in the ' + 'basic implementation, got object of type {}. ' + 'This is a bug, please report it!' + ''.format(type(idx))) + int_axes = [ + ax for ax in range(len(key)) if isinstance(key[ax], integer_types) + ] + begin, end, step = self._basic_indexing_key_to_begin_end_step( + key, self.shape, keep_none=False + ) + indexed_shape = tuple( + _get_dim_size(b, e, s) for b, e, s in zip(begin, end, step) + ) + can_assign_directly = ( + (indexed_shape == self.shape) and all(s > 0 for s in step) + ) + begin, end, step = self._basic_indexing_key_to_begin_end_step( + key, self.shape, keep_none=True + ) + + if can_assign_directly: + # Easy case, overwrite whole array. + if isinstance(value, NDArray): + if value.handle is not self.handle: + # Need to do this before `broadcast_to`. + tmp_shape = _shape_for_bcast( + value.shape, target_ndim=self.ndim, new_axes=int_axes + ) + value = value.reshape(tmp_shape) + + if value.shape != self.shape: + value = value.broadcast_to(self.shape) + value.copyto(self) + + elif isinstance(value, numeric_types): + _internal._full( + shape=self.shape, value=float(value), ctx=self.context, + dtype=self.dtype, out=self + ) + + elif isinstance(value, (np.ndarray, np.generic)): + tmp_shape = _shape_for_bcast( + value.shape, target_ndim=self.ndim, new_axes=int_axes + ) + value = value.reshape(tmp_shape) + + if isinstance(value, np.generic) or value.shape != self.shape: + value = np.broadcast_to(value, self.shape) + self._sync_copyfrom(value) + else: - raise ValueError("basic indexing does not support index=%s of type=%s" - % (str(slice_i), str(type(slice_i)))) - oshape.append(dim_size) - - oshape.extend(shape[len(key):]) - vshape.extend(shape[len(key):]) - # if key contains all integers, vshape should be (1,) - if len(vshape) == 0: - vshape.append(1) - oshape = tuple(oshape) - vshape = tuple(vshape) + # Other array-like + value_nd = self._prepare_value_nd( + value, new_axes=int_axes, bcast_shape=self.shape + ) + value_nd.copyto(self) - if isinstance(value, numeric_types): - _internal._slice_assign_scalar(self, out=self, begin=begin, end=end, - step=steps, scalar=float(value)) + elif isinstance(value, numeric_types): + _internal._slice_assign_scalar( + self, float(value), begin, end, step, out=self + ) + + else: + value_nd = self._prepare_value_nd( + value, new_axes=int_axes, bcast_shape=indexed_shape + ) + _internal._slice_assign(self, value_nd, begin, end, step, out=self) + + def _get_nd_basic_indexing(self, key): + """This function indexes ``self`` with a tuple of `slice` objects only.""" + key_nd = tuple(idx for idx in key if idx is not None) + if len(key_nd) < self.ndim: + raise RuntimeError( + 'too few indices after normalization: expected `ndim` ({}) ' + 'but got {}. This is a bug, please report it!' + ''.format(self.ndim, len(key_nd)) + ) + if len(key_nd) > self.ndim: + raise IndexError( + 'too many indices ({}) for array with {} dimensions' + ''.format(len(key_nd), self.ndim) + ) + + none_axes = [ax for ax in range(len(key)) if key[ax] is None] # pylint: disable=invalid-name + slc_key, int_axes = self._basic_indexing_key_int_to_slice(key_nd) + new_axes = self._new_axes_after_basic_indexing(none_axes, key_nd) + + # Check bounds for integer axes + for ax in int_axes: # pylint: disable=invalid-name + if not -self.shape[ax] <= key_nd[ax] < self.shape[ax]: + raise IndexError( + 'index {} is out of bounds for axis {} with size {}' + ''.format(key_nd[ax], ax, self.shape[ax])) + + # Make sure we don't accidentally have advanced indexing or + # unsupported entries. + for idx in slc_key: + if not isinstance(idx, py_slice): + raise RuntimeError( + 'found object of type {} instead of `slice`. ' + 'This is a bug, please report it!' + ''.format(type(idx))) + + # Convert to begin, end and step, and return immediately if the slice + # is empty + begin, end, step = self._basic_indexing_key_to_begin_end_step( + slc_key, self.shape, keep_none=False + ) + # Pylint is wrong about this + # pylint: disable=bad-continuation + if any( + b >= e and s > 0 or b <= e and s < 0 for b, e, s in zip(begin, end, step) + ): + return array([], self.context, self.dtype) + # pylint: enable=bad-continuation + + if self._basic_indexing_slice_is_contiguous(slc_key, self.shape): + # Create a shared-memory view by using low-level flat slicing + flat_begin, flat_end = self._basic_indexing_contiguous_flat_begin_end( + slc_key, self.shape + ) + handle = NDArrayHandle() + flat_self = self.reshape(-1) + check_call( + _LIB.MXNDArraySlice( + flat_self.handle, + mx_uint(flat_begin), + mx_uint(flat_end), + ctypes.byref(handle), + ) + ) + sliced_shape = self._basic_indexing_sliced_shape(slc_key, self.shape) + sliced = NDArray(handle=handle, writable=self.writable).reshape(sliced_shape) + else: + begin, end, step = self._basic_indexing_key_to_begin_end_step( + slc_key, self.shape, keep_none=True + ) + sliced = op.slice(self, begin, end, step) + + # Reshape to final shape due to integer and `None` entries in `key`. + final_shape = [sliced.shape[i] for i in range(sliced.ndim) + if i not in int_axes] + for ax in new_axes: # pylint: disable=invalid-name + final_shape.insert(ax, 1) + + if final_shape == []: + # Override for single element indexing + final_shape = [1] + + return sliced.reshape(final_shape) + + @staticmethod + def _advanced_index_to_array(idx, ax_len, ctx): + """Convert ``idx`` to `NDArray` for advanced indexing. + + The ``ax_len`` is used to convert `slice` objects to integer arrays. + """ + idx_dtype = 'int32' + if isinstance(idx, NDArray): + if idx.dtype != idx_dtype: + idx = idx.astype(idx_dtype) + return idx.as_in_context(ctx) + + elif isinstance(idx, (np.ndarray, list, tuple)): + return array(idx, ctx, idx_dtype) + + elif isinstance(idx, integer_types): + return array([idx], ctx, idx_dtype) + + elif isinstance(idx, py_slice): + start, stop, step = idx.indices(ax_len) + return arange(start, stop, step, ctx=ctx, dtype=idx_dtype) + + else: + raise RuntimeError('illegal index type {}'.format(type(idx))) + + # pylint: disable=invalid-name + @staticmethod + def _broadcast_advanced_indices(arrays, block_axes): + """Broadcast arrays according to position in the sequence. + + Here, "according to position" means that an array of dimension 1 + (which is the case for all except ``block_axes``) will have shape + ``(1, ..., 1, N, 1, ..., 1)``, where ``N`` is the length, and the + position of ``N`` in the shape is the same as the position of the + array in the ``arrays`` sequence, plus extra dimensions of the + advanced block if it is left of the array. + + The arrays at ``block_axes`` are the advanced indices. They are assumed to + be ready for mutual broadcasting to produce the advanced indexing block. + It is further assumed that the numbers in ``block_axes`` are consecutive. + + The return value is a tuple containing the arrays with broadcast shapes. + """ + block_shape = _broadcast_shapes([arrays[ax] for ax in block_axes]) + ndim_blk = len(block_shape) + ndim_blk_delta = ndim_blk - len(block_axes) + ndim_lead = block_axes[0] + ndim_trail = len(arrays) - (block_axes[-1] + 1) + + bcast_shape = ( + tuple(arrays[ax].shape[0] for ax in range(ndim_lead)) + + block_shape + + tuple(arrays[ax].shape[0] for ax in range(block_axes[-1] + 1, len(arrays))) + ) + + bcast_arrays = [None] * len(arrays) + for ax in block_axes: + arr = arrays[ax].broadcast_to(block_shape) + shp = (1,) * ndim_lead + block_shape + (1,) * ndim_trail + bcast_arrays[ax] = arr.reshape(shp).broadcast_to(bcast_shape) + + for ax in set(range(len(arrays))) - set(block_axes): + shp = [1] * len(bcast_shape) + if ax < ndim_lead: + shp[ax] = arrays[ax].shape[0] + else: + shp[ax + ndim_blk_delta] = arrays[ax].shape[0] + bcast_arrays[ax] = arrays[ax].reshape(shp).broadcast_to(bcast_shape) + + return tuple(bcast_arrays) + # pylint: enable=invalid-name + + @staticmethod + def _drop_slice_none_at_end(key): + """Remove ``slice(None)`` at the end of a key. + + This is used for efficiency in advanced indexing, to avoid generating + ``arange(n)`` arrays for these axes. The `gather_nd` and `scatter_nd` + handle implicit full trailing axes automatically. + """ + key = list(key) + while isinstance(key[-1], py_slice) and key[-1] == slice(None): + key.pop() + return tuple(key) + + def _get_index_nd(self, key): + """Return an index array for use in `scatter_nd` and `gather_nd`.""" + key_nd = tuple(idx for idx in key if idx is not None) + if len(key_nd) < self.ndim: + raise RuntimeError( + 'too few indices after normalization: expected `ndim` ({}) ' + 'but got {}. This is a bug, please report it!' + ''.format(self.ndim, len(key_nd)) + ) + if len(key_nd) > self.ndim: + raise IndexError( + 'too many indices ({}) for array with {} dimensions' + ''.format(len(key_nd), self.ndim) + ) + ndim = len(key_nd) + + # --- Preparation --- # + + # - Make lists for bookkeeping of advanced indices & axes + # - Drop trailing `slice(None)` entries in `key` for efficiency + # - Determine whether the advanced indices are adjacent in `key` + # - Depending on that, make index permutations to move around indices + + adv_axs = [ax for ax, idx in enumerate(key) if _is_advanced_index(idx)] + adv_axs_nd = [ax for ax, idx in enumerate(key_nd) if _is_advanced_index(idx)] + adv_idcs_are_adjacent = bool(np.all(np.diff(adv_axs) == 1)) + nonadv_axs_nd = [ax for ax in range(ndim) if ax not in adv_axs_nd] + adv_idcs_nd = [key_nd[ax] for ax in adv_axs_nd] + idcs_short = self._drop_slice_none_at_end(key_nd) + dropped_axs = list(range(len(idcs_short), ndim)) + + if adv_idcs_are_adjacent: + # The easy case: the advanced block can stay at its position, and no + # permutation needs to be done (identity permutation) + axs_nd_permut = axs_nd_permut_inv = tuple(range(ndim)) + idcs_permut_short = idcs_short + block_axs_nd = adv_axs_nd else: - value_nd = self._prepare_value_nd(value, vshape) - if vshape != oshape: - value_nd = value_nd.reshape(oshape) - _internal._slice_assign(self, value_nd, begin, end, steps, out=self) + # The more complicated case: during broadcasting, we need to use the + # indices in the *permuted* order, where the advanced block is + # at the beginning, while the final index for `gather_nd` is stacked + # in the *original* order, so that the association of index with + # array axis remains the same. + + # This order is used for broadcasting: advanced block at the beginning + idcs_permut_short = ( + adv_idcs_nd + + [key_nd[ax] for ax in range(ndim) + if ax not in adv_axs_nd and ax not in dropped_axs] + ) + block_axs_nd = list(range(len(adv_axs_nd))) + axs_nd_permut = adv_axs_nd + nonadv_axs_nd + axs_nd_permut_inv = list(np.argsort(axs_nd_permut)) + + # --- Conversion, broadcasting and index stacking --- # + + # - Convert all indices in `key` to arrays: integers to 1-element arrays, + # `slice` objects to arrays with explicit indices + # - Reshape arrays for broadcasting according to their position in the + # *permuted* key + # - Broadcast and stack the indices in the *original* order + + shape_nd_permut = tuple(self.shape[ax] for ax in axs_nd_permut) + converted_idcs_short = [ + self._advanced_index_to_array(idx, ax_len, self.context) + for idx, ax_len in zip(idcs_permut_short, shape_nd_permut) + ] + bcast_idcs_permut_short = self._broadcast_advanced_indices( + converted_idcs_short, block_axes=block_axs_nd + ) + # Undo the permutation to restore the original order + bcast_idcs_short = [ + bcast_idcs_permut_short[ax] + for ax in axs_nd_permut_inv + if axs_nd_permut[ax] not in dropped_axs + ] + + return op.stack(*bcast_idcs_short) def _set_nd_advanced_indexing(self, key, value): """This function is called by __setitem__ when key is an advanced index.""" indices = self._get_index_nd(key) vshape = _get_oshape_of_gather_nd_op(self.shape, indices.shape) - value_nd = self._prepare_value_nd(value, vshape) - _internal._scatter_set_nd(lhs=self, rhs=value_nd, indices=indices, - shape=self.shape, out=self) - - def _get_nd_basic_indexing(self, key): - """This function is called when key is a slice, or an integer, - or a tuple of slices or integers""" - shape = self.shape - if isinstance(key, integer_types): - if key > shape[0] - 1: - raise IndexError( - 'index {} is out of bounds for axis 0 with size {}'.format( - key, shape[0])) - return self._at(key) - elif isinstance(key, py_slice): - if key.step is not None and key.step != 1: - if key.step == 0: - raise ValueError("slice step cannot be zero") - return op.slice(self, begin=(key.start,), end=(key.stop,), step=(key.step,)) - elif key.start is not None or key.stop is not None: - return self._slice(key.start, key.stop) - else: - return self - - if not isinstance(key, tuple): - raise ValueError('index=%s must be a slice, or an ineger, or a tuple' - ' of slices and integers to use basic indexing, received type=%s' - % (str(key), str(type(key)))) - assert len(key) != 0, 'basic index cannot be an empty tuple' - begin = [] - end = [] - step = [] - kept_axes = [] # axes where slice_i is a slice - i = -1 - for i, slice_i in enumerate(key): - if isinstance(slice_i, integer_types): - begin.append(slice_i) - end.append(slice_i+1 if slice_i != -1 else self.shape[i]) - step.append(1) - elif isinstance(slice_i, py_slice): - if slice_i.step == 0: - raise ValueError('basic index=%s cannot have slice=%s with step = 0' - % (str(key), str(slice_i))) - begin.append(slice_i.start) - end.append(slice_i.stop) - step.append(slice_i.step) - kept_axes.append(i) - else: - raise ValueError('basic_indexing does not support slicing with ' - 'index=%s of type=%s.' % (str(slice_i), str(type(slice_i)))) - kept_axes.extend(range(i+1, len(shape))) - sliced_nd = op.slice(self, begin, end, step) - if len(kept_axes) == len(shape): - return sliced_nd - # squeeze sliced_shape to remove the axes indexed by integers - oshape = [] - sliced_shape = sliced_nd.shape - for axis in kept_axes: - oshape.append(sliced_shape[axis]) - # if key is a tuple of integers, still need to keep 1 dim - # while in Numpy, the output will become an value instead of an ndarray - if len(oshape) == 0: - oshape.append(1) - oshape = tuple(oshape) - assert np.prod(oshape) == np.prod(sliced_shape), 'oshape=%s has different size'\ - ' than sliced_shape=%s'\ - % (oshape, sliced_shape) - return sliced_nd.reshape(oshape) + value_nd = self._prepare_value_nd(value, new_axes=[], bcast_shape=vshape) + _internal._scatter_set_nd( + lhs=self, rhs=value_nd, indices=indices, shape=self.shape, out=self + ) def _get_nd_advanced_indexing(self, key): """Get item when key is a tuple of any objects of the following types: @@ -2294,30 +2547,123 @@ def to_dlpack_for_write(self): """ return to_dlpack_for_write(self) + +def _indexing_key_expand_implicit_axes(key, shape): + """Make implicit axes explicit by adding ``slice(None)``. + + Examples + -------- + >>> shape = (3, 4, 5) + >>> _indexing_key_expand_implicit_axes(np.s_[2, 1, 1], shape) + (2, 1, 1) + >>> _indexing_key_expand_implicit_axes(np.s_[0], shape) + (0, slice(None, None, None), slice(None, None, None)) + >>> _indexing_key_expand_implicit_axes(np.s_[0, ...], shape) # equivalent + (0, slice(None, None, None), slice(None, None, None)) + >>> _indexing_key_expand_implicit_axes(np.s_[:2, None, 0, ...], shape) + (slice(None, 2, None), None, 0, slice(None, None, None)) + """ + if not isinstance(key, tuple): + key = (key,) + + # We need to loop explicitly since tuple functions like `index()` or + # `count()` use `==` internally, which doesn't play well with fancy + # indexing. + ell_idx = None + num_none = 0 + nonell_key = [] + for i, idx in enumerate(key): + if idx is Ellipsis: + if ell_idx is not None: + raise IndexError( + 'Cannot use more than one ellipsis (`...`) for indexing' + ) + ell_idx = i + else: + if idx is None: + num_none += 1 + nonell_key.append(idx) + + nonell_key = tuple(nonell_key) + + if ell_idx is None: + # This handles the case of "too few" indices, e.g., `nd.zeros((2, 3))[0]`, + # where the ellipsis is implicitly after the last entry. + ell_idx = len(nonell_key) + + ell_ndim = len(shape) + num_none - len(nonell_key) + expanded_key = (nonell_key[:ell_idx] + + (slice(None),) * ell_ndim + + nonell_key[ell_idx:]) + + return expanded_key + + +def _int_to_slice(idx): + """Return a slice that indexes the same entries as a single int.""" + if idx == -1: + # Avoid slice(-1, 0) + return slice(-1, None) + else: + return slice(idx, idx + 1) + + +def _shape_for_bcast(shape, target_ndim, new_axes): + """Return shape with added axes for broadcasting in ``target_ndim`` dimensions. + + If ``shape`` is shorter than ``target_ndim``, fixed ``1`` entries are inserted + into the returned shape, in locations indexed by ``new_axes``. The rest is + filled from the back with ``shape`` while possible. + """ + new_shape = [None] * target_ndim + if len(shape) < target_ndim: + for new_ax in new_axes: + new_shape[new_ax] = 1 + + # Replace `None` from the right with `shape` entries from the right as + # long as possible, thereafter with 1. + ax_s = 1 + for ax in range(1, target_ndim + 1): + if new_shape[-ax] is None: + try: + new_shape[-ax] = shape[-ax_s] + ax_s += 1 + except IndexError: + new_shape[-ax] = 1 + + return tuple(new_shape) + + +def _is_advanced_index(idx): + """Return whether ``idx`` is an advanced index (array-like or integer). + + Note that in contrast to basic indexing, integers are considered advanced + indices in the context of advanced indexing as they participate in + broadcasting. + """ + if isinstance(idx, (NDArray, np.ndarray, integer_types, list, tuple)): + return True + elif isinstance(idx, py_slice) or idx is None: + return False + else: + raise RuntimeError('illegal index type {}'.format(type(idx))) + + def _get_indexing_dispatch_code(key): """Returns a dispatch code for calling basic or advanced indexing functions.""" - if isinstance(key, (NDArray, np.ndarray)): - return _NDARRAY_ADVANCED_INDEXING - elif isinstance(key, list): - # TODO(junwu): Add support for nested lists besides integer list - for i in key: - if not isinstance(i, integer_types): - raise TypeError('Indexing NDArray only supports a list of integers as index' - ' when key is of list type, received element=%s of type=%s' - % (str(i), str(type(i)))) - return _NDARRAY_ADVANCED_INDEXING - elif isinstance(key, (integer_types, py_slice)): - return _NDARRAY_BASIC_INDEXING - elif isinstance(key, tuple): - for idx in key: - if isinstance(idx, (NDArray, np.ndarray, list, tuple)): - return _NDARRAY_ADVANCED_INDEXING - elif not isinstance(idx, (py_slice, integer_types)): - raise ValueError("NDArray does not support slicing with key %s of type %s." - % (str(idx), str(type(idx)))) - return _NDARRAY_BASIC_INDEXING - else: - return _NDARRAY_UNSUPPORTED_INDEXING + assert isinstance(key, tuple) + + for idx in key: + if isinstance(idx, (NDArray, np.ndarray, list, tuple)): + return _NDARRAY_ADVANCED_INDEXING + + elif not (isinstance(idx, (py_slice, integer_types)) or idx is None): + raise ValueError( + 'NDArray does not support slicing with key {} of type {}.' + ''.format(idx, type(idx)) + ) + + return _NDARRAY_BASIC_INDEXING def _get_index_range(start, stop, length, step=1): @@ -2377,6 +2723,8 @@ def _get_dim_size(start, stop, step): """Given start, stop, and stop, calculate the number of elements of this slice.""" assert step != 0 + if stop == start: + return 0 if step > 0: assert start < stop dim_size = (stop - start - 1) // step + 1 @@ -2407,6 +2755,14 @@ def _get_broadcast_shape(shape1, shape2): return tuple(shape) +def _broadcast_shapes(seq): + """Return the broadcast shape of all advanced indices in ``seq``. + + All entries are assumed to have a ``shape`` property. + """ + return reduce(_get_broadcast_shape, [x.shape for x in seq], ()) + + def onehot_encode(indices, out): """One-hot encoding indices into matrix out. @@ -2519,9 +2875,18 @@ def array(source_array, ctx=None, dtype=None): source_array = np.array(source_array, dtype=dtype) except: raise TypeError('source_array must be array like object') - arr = empty(source_array.shape, ctx, dtype) - arr[:] = source_array - return arr + + if source_array.shape == (): + # In this case we can't assign, so we need to go through an auxiliary array + arr = empty((1,), ctx, dtype) + arr[:] = source_array + return arr.reshape(()) + elif source_array.size == 0: + return empty(source_array.shape, ctx, dtype) + else: + arr = empty(source_array.shape, ctx, dtype) + arr[:] = source_array + return arr def moveaxis(tensor, source, destination): diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index bfe520b0137a..6663af33a175 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1073,7 +1073,11 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states) for g in executor.grad_arrays: - g[:] = 0 + print(g.shape) + if g.ndim == 0: + g[()] = 0 + else: + g[:] = 0 executor.forward(is_train=False) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index aa6e7bbaf5ee..4daace2fb35c 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -696,7 +696,10 @@ inline void GetIndexRange(const mxnet::TShape& dshape, // checking upper and lower bounds for end if (e < 0 && param_end[i].has_value()) { - e += len; + if (!(s < 0 && e == -1)) { + // Keep end=-1 as one-beyond-limits index for negative stride + e += len; + } CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len << " exceeds limit of input dimension[" << i << "]=" << len; } @@ -725,11 +728,11 @@ inline void SetSliceOpOutputDimSize(const index_t i, const int b, mxnet::TShape* oshape) { if (e != b) { if (s > 0) { - CHECK_LT(b, e) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]=" + CHECK_LT(b, e) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]=" << e << ", and step[" << i << "]=" << s << " is invalid"; (*oshape)[i] = (e - b - 1) / s + 1; } else { - CHECK_LT(e, b) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]=" + CHECK_LT(e, b) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]=" << e << ", and step[" << i << "]=" << s << " is invalid"; (*oshape)[i] = (b - e - 1) / (-s) + 1; } diff --git a/tests/python/unittest/test_dgl_graph.py b/tests/python/unittest/test_dgl_graph.py index e24cf4deb756..805adc2dac6f 100644 --- a/tests/python/unittest/test_dgl_graph.py +++ b/tests/python/unittest/test_dgl_graph.py @@ -61,7 +61,7 @@ def check_compact(csr, id_arr, num_nodes): compact = mx.nd.contrib.dgl_graph_compact(csr, id_arr, graph_sizes=num_nodes, return_mapping=False) assert compact.shape[0] == num_nodes assert compact.shape[1] == num_nodes - assert mx.nd.sum(compact.indptr == csr.indptr[0:(num_nodes + 1)]).asnumpy() == num_nodes + 1 + assert mx.nd.sum(compact.indptr == csr.indptr[0:int(num_nodes + 1)]).asnumpy() == num_nodes + 1 sub_indices = compact.indices.asnumpy() indices = csr.indices.asnumpy() id_arr = id_arr.asnumpy() diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index f40bb3053358..20e2df0130f2 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -18,11 +18,11 @@ import mxnet as mx import numpy as np from distutils.version import LooseVersion +from itertools import permutations, combinations_with_replacement import os import pickle as pkl -import unittest -from nose.tools import raises -from common import setup_module, with_seed, assertRaises, TemporaryDirectory, teardown +from nose.tools import assert_raises, raises +from common import with_seed, assertRaises, TemporaryDirectory from mxnet.test_utils import almost_equal from mxnet.test_utils import assert_almost_equal, assert_exception from mxnet.test_utils import default_context @@ -101,6 +101,26 @@ def test_ndarray_setitem(): x_np[-1] = 1 assert same(x.asnumpy(), x_np) + # Ellipsis + x = mx.nd.zeros(shape) + x[2, ...] = 1 + x_np = np.zeros(shape, dtype=x.dtype) + x_np[2, ...] = 1 + assert same(x.asnumpy(), x_np) + + x = mx.nd.zeros(shape) + x[..., 1] = 1 + x_np = np.zeros(shape, dtype=x.dtype) + x_np[..., 1] = 1 + assert same(x.asnumpy(), x_np) + + # `None` should be ignored + x = mx.nd.zeros(shape) + x[None, 0, None, None, 0, 0, None] = 1 + x_np = np.zeros(shape, dtype=x.dtype) + x_np[None, 0, None, None, 0, 0, None] = 1 + assert same(x.asnumpy(), x_np) + # short all-dim indexing x = mx.nd.zeros(shape) val = mx.nd.ones((3, 2)) @@ -121,13 +141,15 @@ def test_ndarray_setitem(): x_np[:, -3:-1, -2:-1] = 1 assert same(x.asnumpy(), x_np) - # numpy assignment for empty axis - for trivial_shape in [(), (1,), (1, 1), (1, 1, 1)]: - if trivial_shape == tuple(): - with mx.np_shape(): - x = mx.nd.zeros(trivial_shape) - else: - x = mx.nd.zeros(trivial_shape) + # Scalar array, no assignment allowed + with mx.np_shape(): + x = mx.nd.zeros(()) + with assert_raises(IndexError): + x[:] = 1 + + # Assignments for empty axes + for trivial_shape in [(1,), (1, 1), (1, 1, 1)]: + x = mx.nd.zeros(trivial_shape) x[:] = np.ones(trivial_shape) x_np = np.ones(trivial_shape, dtype=x.dtype) assert x.shape == trivial_shape @@ -1218,6 +1240,42 @@ def test_bool(): assert bool(mx.nd.ones((1,))) +def test_basic_indexing_is_contiguous(): + x_np = np.arange(np.prod((6, 7, 8, 9))).reshape((6, 7, 8, 9)) + x_mx = mx.nd.array(x_np) + + slices = [ + slice(None), + slice(2), + slice(20), + slice(1, 4), + slice(None, None, 2), + slice(None, None, 20), + slice(0, 1), + slice(None, None, -1), + slice(3, None, -2), + ] + + is_contiguous = mx.nd.NDArray._basic_indexing_slice_is_contiguous + + for idx in combinations_with_replacement(slices, 4): + for slc in permutations(idx): + # Check helper function + contig_pred = is_contiguous(slc, x_np.shape) + contig_true = x_np[slc].flags.contiguous + assert contig_pred == contig_true, ( + "failed with slc={}, pred ({}) != actual ({})" + "".format(slc, contig_pred, contig_true) + ) + + if contig_pred: + # Check mutation behavior + y_mx = x_mx.copy() + y_mx_slc = y_mx[slc] + y_mx_slc[:] = 0 + assert (y_mx[slc].asnumpy() == 0).all() + + @with_seed() def test_ndarray_indexing(): def test_getitem(np_array, index, is_scalar=False): @@ -1228,22 +1286,24 @@ def test_getitem(np_array, index, is_scalar=False): if isinstance(index, mx.nd.NDArray): np_index = index.asnumpy() if isinstance(index, tuple): - np_index = [] - for idx in index: - if isinstance(idx, mx.nd.NDArray): - np_index.append(idx.asnumpy()) - else: - np_index.append(idx) - np_index = tuple(np_index) + np_index = tuple( + idx.asnumpy() if isinstance(idx, mx.nd.NDArray) else idx + for idx in index + ) np_indexed_array = np_array[np_index] mx_array = mx.nd.array(np_array, dtype=np_array.dtype) - mx_indexed_array = mx_array[index] + try: + mx_indexed_array = mx_array[index] + except Exception as e: + print('Failed with index = {}'.format(index)) + raise e if is_scalar: mx_indexed_array = mx_indexed_array.asscalar() else: mx_indexed_array = mx_indexed_array.asnumpy() - assert same(np_indexed_array, mx_indexed_array), 'Failed with index=%s' % str(index) + + assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index) def test_setitem(np_array, index, is_scalar): def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None): @@ -1253,7 +1313,13 @@ def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None) np_array[np_index] = mx_value.asnumpy() else: np_array[np_index] = mx_value - mx_array[mx_index] = mx_value + + try: + mx_array[mx_index] = mx_value + except Exception as e: + print('Failed with index = {}, value.shape = {}'.format(mx_index, mx_value.shape)) + raise e + assert same(np_array, mx_array.asnumpy()) np_index = index @@ -1312,7 +1378,9 @@ def test_setitem_autograd(np_array, index): try: with mx.autograd.record(): x[index] = y - assert False # should not reach here + # `a[None] = v` is equivalent to `a[...] = v` which doesn't raise + if index is not None: + assert False, 'failed with index = {}'.format(index) # should not reach here except mx.base.MXNetError as err: assert str(err).find('Inplace operations (+=, -=, x[:]=, etc) are not supported when recording with') != -1 @@ -1434,7 +1502,17 @@ def convert(num): (([[[[1]]]], 3, slice(0, 3), 0), False), (([[[[1]]]], [[2], [12]], slice(0, 3), slice(None)), False), (([1, 2], slice(3, 5), [2, 3], [3, 4]), False), - (([1, 2], slice(3, 5), (2, 3), [3, 4]), False)] + (([1, 2], slice(3, 5), (2, 3), [3, 4]), False), + ((1, Ellipsis, -1), False), + ((slice(2), Ellipsis, None, 0), False), + (None, False), + ((1, None, -2, 3, -4), False), + # TODO(zoeygxy): Support None in advanced indexing + # (([1, 2], slice(3, 5), None, None, [3, 4]), False), + # ((slice(None), slice(3, 5), None, None, [2, 3], [3, 4]), False), + # ((slice(None), slice(3, 5), None, [2, 3], None, [3, 4]), False), + # ((None, slice(None), slice(3, 5), [2, 3], None, [3, 4]), False), + ] for index in index_list: test_getitem(np_array, index[0], index[1]) test_setitem(np_array, index[0], index[1]) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 3c55052e68cc..fee5ebbbbc29 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4679,13 +4679,14 @@ def test_normal_case(): def test_empty_tensor(): shape = (2, 3, 0, 4) - a = np.array([], dtype=np.int32).reshape(shape) - b = mx.nd.array(a, ctx=default_context(), dtype=a.dtype) - reps = (2, 4, 6) + with mx.np_shape(): + a = np.array([], dtype=np.int32).reshape(shape) + b = mx.nd.array(a, ctx=default_context(), dtype=a.dtype) - a_tiled = np.tile(a, reps) - b_tiled = mx.nd.tile(b, reps).asnumpy() - assert same(a_tiled, b_tiled) + reps = (2, 4, 6) + a_tiled = np.tile(a, reps) + b_tiled = mx.nd.tile(b, reps).asnumpy() + assert same(a_tiled, b_tiled) def test_empty_reps(): a = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.int32) @@ -4775,13 +4776,15 @@ def test_normal_case(index_type=np.int32): def test_empty_indices(): shape = (2, 0, 9, 3) - indices = np.array([]).reshape(shape) - depth = 10 - mx_one_hot_array = mx.nd.one_hot( - mx.nd.array(indices, ctx=default_context(), dtype=np.int32), - depth=depth, dtype=np.int32).asnumpy() - expected_array = np.array([], dtype=np.int32).reshape(shape + (depth, )) - assert same(expected_array, mx_one_hot_array) + with mx.np_shape(): + indices = np.array([]).reshape(shape) + depth = 10 + mx_one_hot_array = mx.nd.one_hot( + mx.nd.array(indices, ctx=default_context(), dtype=np.int32), + depth=depth, dtype=np.int32 + ).asnumpy() + expected_array = np.array([], dtype=np.int32).reshape(shape + (depth,)) + assert same(expected_array, mx_one_hot_array) def test_zero_depth(): shape = (2, 4, 9, 3) @@ -8608,7 +8611,7 @@ def test_index_array_default(): @mx.use_np_shape def test_index_array_default_zero_dim(): - data = mx.symbol.Variable("data") + data = mx.symbol.Variable("data") index_array = mx.sym.contrib.index_array(data) input_array = np.ones(())