Skip to content

Commit

Permalink
more support for boolean indexing and assign (apache#18351)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alicia1529 authored May 28, 2020
1 parent b523527 commit 0c6785f
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 170 deletions.
100 changes: 74 additions & 26 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ..base import ctypes2buffer
from ..runtime import Features
from ..context import Context, current_context
from ..util import is_np_array
from . import _internal
from . import op
from ._internal import NDArrayBase
Expand Down Expand Up @@ -111,7 +112,11 @@
_NDARRAY_BASIC_INDEXING = 0
_NDARRAY_ADVANCED_INDEXING = 1
_NDARRAY_EMPTY_TUPLE_INDEXING = 2
_NDARRAY_BOOLEAN_INDEXING = 3

# Return code for 0-d boolean array handler
_NDARRAY_NO_ZERO_DIM_BOOL_ARRAY = -1
_NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE = 0
_NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE = 1

# Caching whether MXNet was built with INT64 support or not
_INT64_TENSOR_SIZE_ENABLED = None
Expand Down Expand Up @@ -521,7 +526,7 @@ def __setitem__(self, key, value):
return

else:
key = indexing_key_expand_implicit_axes(key, self.shape)
key, _ = indexing_key_expand_implicit_axes(key, self.shape)
slc_key = tuple(idx for idx in key if idx is not None)

if len(slc_key) < self.ndim:
Expand Down Expand Up @@ -714,7 +719,7 @@ def __getitem__(self, key): # pylint: disable=too-many-return-statements
elif key.step == 0:
raise ValueError("slice step cannot be zero")

key = indexing_key_expand_implicit_axes(key, self.shape)
key, _ = indexing_key_expand_implicit_axes(key, self.shape)
if len(key) == 0:
raise ValueError('indexing key cannot be an empty tuple')

Expand Down Expand Up @@ -2574,9 +2579,12 @@ def asscalar(self):
>>> type(x.asscalar())
<type 'numpy.int32'>
"""
if self.shape != (1,):
if self.size != 1:
raise ValueError("The current array is not a scalar")
return self.asnumpy()[0]
if self.ndim == 1:
return self.asnumpy()[0]
else:
return self.asnumpy()[()]

def astype(self, dtype, copy=True):
"""Returns a copy of the array after casting to a specified type.
Expand Down Expand Up @@ -2943,9 +2951,23 @@ def _scatter_set_nd(self, value_nd, indices):
lhs=self, rhs=value_nd, indices=indices, shape=self.shape, out=self
)

def check_boolean_array_dimension(array_shape, axis, bool_shape):
    """Validate that a boolean index fits the axes it is applied to.

    Advanced boolean indexing is implemented through the use of `nonzero`.
    A size check is necessary before that conversion to make sure the
    boolean array matches, dimension by dimension, the axes of the indexed
    array it is supposed to work with.

    Parameters
    ----------
    array_shape : sequence of int
        Shape of the array being indexed.
    axis : int
        Position in ``array_shape`` that the first dimension of the boolean
        index is matched against.
    bool_shape : sequence of int
        Shape of the boolean index.

    Raises
    ------
    IndexError
        If the boolean index has more dimensions than remain in
        ``array_shape`` starting at ``axis``, or if any of its dimensions
        differs in size from the corresponding axis of the indexed array.
    """
    # Without this guard, an over-long boolean index would fail below with an
    # uninformative "tuple index out of range" from ``array_shape[axis + i]``.
    if axis + len(bool_shape) > len(array_shape):
        raise IndexError('too many indices for array: boolean index with {} dimension(s)'
                         ' applied at axis {} but the indexed array has only {} dimension(s)'
                         .format(len(bool_shape), axis, len(array_shape)))
    for i, val in enumerate(bool_shape):
        if array_shape[axis + i] != val:
            raise IndexError('boolean index did not match indexed array along axis {};'
                             ' size is {} but corresponding boolean size is {}'
                             .format(axis + i, array_shape[axis + i], val))

def indexing_key_expand_implicit_axes(key, shape):
"""Make implicit axes explicit by adding ``slice(None)``.
"""
Make implicit axes explicit by adding ``slice(None)``
and convert boolean array to integer array through `nonzero`.
Examples
--------
>>> shape = (3, 4, 5)
Expand All @@ -2957,6 +2979,11 @@ def indexing_key_expand_implicit_axes(key, shape):
(0, slice(None, None, None), slice(None, None, None))
>>> indexing_key_expand_implicit_axes(np.s_[:2, None, 0, ...], shape)
(slice(None, 2, None), None, 0, slice(None, None, None))
>>> bool_array = np.array([[True, False, True, False],
[False, True, False, True],
[True, False, True, False]], dtype=np.bool)
>>> indexing_key_expand_implicit_axes(np.s_[bool_array, None, 0:2], shape)
(array([0, 0, 1, 1, 2, 2], dtype=int64), array([0, 2, 1, 3, 0, 2], dtype=int64), None, slice(None, 2, None))
"""
if not isinstance(key, tuple):
key = (key,)
Expand All @@ -2966,6 +2993,17 @@ def indexing_key_expand_implicit_axes(key, shape):
ell_idx = None
num_none = 0
nonell_key = []

# For 0-d boolean indices: A new axis is added,
# but at the same time no axis is "used". So if we have True,
# we add a new axis (a bit like with np.newaxis). If it is
# False, we add a new axis, but this axis has 0 entries.
# prepend is defined to handle this case.
# prepend = _NDARRAY_NO_ZERO_DIM_BOOL_ARRAY/-1 means there is no 0-d boolean scalar
# prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE/0 means a zero-size dim must be expanded
# prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE/1 means a new axis must be expanded
prepend = _NDARRAY_NO_ZERO_DIM_BOOL_ARRAY
axis = 0
for i, idx in enumerate(key):
if idx is Ellipsis:
if ell_idx is not None:
Expand All @@ -2974,14 +3012,38 @@ def indexing_key_expand_implicit_axes(key, shape):
)
ell_idx = i
else:
# convert primitive type boolean value to mx.np.bool type
# otherwise will be treated as 1/0
if isinstance(idx, bool):
idx = array(idx, dtype=np.bool_)
if idx is None:
num_none += 1
if isinstance(idx, NDArrayBase) and idx.ndim == 0 and idx.dtype != np.bool_:
if isinstance(idx, NDArrayBase) and idx.ndim == 0 and idx.dtype == np.bool_:
if not idx: # array(False) has priority
prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_FALSE
else:
prepend = _NDARRAY_ZERO_DIM_BOOL_ARRAY_TRUE
elif isinstance(idx, NDArrayBase) and idx.ndim == 0 and idx.dtype != np.bool_:
# This handles ndarray of zero dim. e.g array(1)
# while avoiding conversion of a zero-dim boolean array
nonell_key.append(idx.item())
# float type will be converted to int
nonell_key.append(int(idx.item()))
axis += 1
elif isinstance(idx, NDArrayBase) and idx.dtype == np.bool_:
# Necessary size check before using `nonzero`
check_boolean_array_dimension(shape, axis, idx.shape)
# If the whole array is False and npx.set_np() has not been called,
# the program will throw an infer-shape error
if not is_np_array():
raise ValueError('Cannot perform boolean indexing in legacy mode. Please activate'
' numpy semantics by calling `npx.set_np()` in the global scope'
' before calling this function.')
# Add the arrays from the nonzero result to the index
nonell_key.extend(idx.nonzero())
axis += idx.ndim
else:
nonell_key.append(idx)
axis += 1

nonell_key = tuple(nonell_key)

Expand All @@ -2995,7 +3057,7 @@ def indexing_key_expand_implicit_axes(key, shape):
(slice(None),) * ell_ndim +
nonell_key[ell_idx:])

return expanded_key
return expanded_key, prepend


def _int_to_slice(idx):
Expand Down Expand Up @@ -3053,32 +3115,18 @@ def _is_advanced_index(idx):
def get_indexing_dispatch_code(key):
"""Returns a dispatch code for calling basic or advanced indexing functions."""
assert isinstance(key, tuple)
num_bools = 0
basic_indexing = True

for idx in key:
if isinstance(idx, (NDArray, np.ndarray, list, tuple)):
if isinstance(idx, (NDArray, np.ndarray, list, tuple, range)):
if isinstance(idx, tuple) and len(idx) == 0:
return _NDARRAY_EMPTY_TUPLE_INDEXING
if getattr(idx, 'dtype', None) == np.bool_:
num_bools += 1
basic_indexing = False
elif isinstance(idx, range):
basic_indexing = False
return _NDARRAY_ADVANCED_INDEXING
elif not (isinstance(idx, (py_slice, integer_types)) or idx is None):
raise ValueError(
'NDArray does not support slicing with key {} of type {}.'
''.format(idx, type(idx))
)
if basic_indexing and num_bools == 0:
return _NDARRAY_BASIC_INDEXING
elif not basic_indexing and num_bools == 0:
return _NDARRAY_ADVANCED_INDEXING
elif num_bools == 1:
return _NDARRAY_BOOLEAN_INDEXING
else:
raise TypeError('ndarray indexing does not more than one boolean ndarray'
' in a tuple of complex indices.')
return _NDARRAY_BASIC_INDEXING


def _get_index_range(start, stop, length, step=1):
Expand Down
Loading

0 comments on commit 0c6785f

Please sign in to comment.