Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix numpy bugs (#16537)
Browse files Browse the repository at this point in the history
* legacy ndarray setitem supports value as expanded arrays

* Doc for ndarray

* Add doc for ndarray

* Fix infer shape error while loading legacy models in numpy mode

* Add true_divide to support float

* Add true divide scalar

* Add true_divide test

* Fix bug

* Add for loading symbols

* Fix

* Fix sanity

* Fix windows

* Fix python2

* Skip saving is_np_shape attr in json in legacy mode

* Fix compile

* Fix test
  • Loading branch information
reminisce authored and haojin2 committed Oct 20, 2019
1 parent b949716 commit 217ae02
Show file tree
Hide file tree
Showing 15 changed files with 597 additions and 77 deletions.
13 changes: 13 additions & 0 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,19 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None):
squeeze_axes = tuple([ax for ax in squeeze_axes if ax < len(value_nd.shape)])
value_nd = value_nd.squeeze(axis=tuple(squeeze_axes))

# handle the cases like the following
# a = nd.zeros((3, 3)), b = nd.ones((1, 1, 1, 1, 3)), a[0] = b
# b cannot broadcast directly to a[0].shape unless its leading 1-size axes are trimmed
if value_nd.ndim > len(bcast_shape):
squeeze_axes = []
for i in range(value_nd.ndim - len(bcast_shape)):
if value_nd.shape[i] == 1:
squeeze_axes.append(i)
else:
break
if squeeze_axes:
value_nd = value_nd.squeeze(squeeze_axes)

if value_nd.shape != bcast_shape:
if value_nd.size == 0:
value_nd = value_nd.reshape(bcast_shape)
Expand Down
242 changes: 220 additions & 22 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,47 @@ def _np_ndarray_cls(handle, writable=True, stype=0):
@set_module('mxnet.numpy') # pylint: disable=invalid-name
class ndarray(NDArray):
"""
ndarray(handle, writable=True):
An array object represents a multidimensional, homogeneous array of fixed-size items.
An associated data-type object describes the format of each element in the array
(its byte-order, how many bytes it occupies in memory, whether it is an integer, a
floating point number, or something else, etc.). Arrays should be constructed using
`array`, `zeros` or `empty`. Currently, only c-contiguous arrays are supported.
Arrays should be constructed using `array`, `zeros` or `empty` (refer
to the See Also section below). The parameters given here refer to
a low-level method (`ndarray(...)`) for instantiating an array.
For more information, refer to the `mxnet.numpy` module and examine the
methods and attributes of an array.
Parameters
----------
handle: int
The ndarray handle in backend (C++).
writable: bool
Indicates whether inplace-assignment is allowed for the array.
Attributes
----------
T : ndarray
Transpose of the array.
dtype : dtype object
Describes the format of the elements in the array.
size : int
Number of elements in the array.
ndim : int
The array's number of dimensions.
shape : tuple of ints
Shape of the array.
See Also
--------
array : Construct an array.
zeros : Create an array, each element of which is zero.
empty : Create an array, but leave its allocated memory unchanged (i.e.,
it contains "garbage").
"""

@staticmethod
Expand Down Expand Up @@ -286,9 +322,139 @@ def _set_np_advanced_indexing(self, key, value):

# pylint: disable=too-many-return-statements
def __getitem__(self, key):
"""
Overriding the method in NDArray class in a numpy fashion.
Calling numpy ndarray's _get_np_basic_indexing(key) and _get_np_advanced_indexing(key).
"""Return self[key].
Returns a sliced view of this array if the elements fetched are contiguous in memory;
otherwise, returns a newly created NDArray.
This functions supports advanced indexing defined in the following reference with
some restrictions. Boolean indexing is supported only for a single boolean ndarray
as a key. Mixing boolean ndarray with other index types is not supported in ``advanced``
indexing.
For basic indexing, i.e., if ``key`` consists only of integers,
``slice``, ``Ellipsis`` (``...``) and ``None``, a mutable view is
returned that shares memory with this array if the accessed portion is
contiguous in memory.
Otherwise, a newly created ``ndarray`` is returned.
This functions supports advanced indexing as defined in `the NumPy
advanced indexing documentation
<https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing>`_.
Parameters
----------
key : int, slice, list, np.ndarray, mx.np.ndarray, or tuple of all previous types
Indexing key.
Examples
--------
The default is to give explicit indices for all axes:
>>> x = np.arange(6).reshape(2, 3)
>>> x
array([[0., 1., 2.],
[3., 4., 5.]])
>>> x[0, :2]
array([0., 1.])
>>> x[:, :-1]
array([[0., 1.],
[3., 4.]])
If fewer indices are given, they are automatically supplemented by an
appropriate number of ``slice(None)`` ("``:``") to the right. For
instance, a single integer indexes along the first axis:
>>> x[0]
array([0., 1., 2.])
>>> x[1:]
array([[3., 4., 5.]])
To omit a range of axes that should be kept as-is, an `Ellipsis`
("``...``") can be used:
>>> x = np.arange(16).reshape(2, 2, 2, 2)
>>> x[0, ..., 1]
array([[1., 3.],
[5., 7.]])
>>> x[0, :, :, 1] # equivalent
array([[1., 3.],
[5., 7.]])
New axes of length 1 can be created by inserting ``None``
(`numpy.newaxis`) in the index:
>>> x = np.arange(6).reshape(2, 3)
>>> x[None, :, :]
array([[[0., 1., 2.],
[3., 4., 5.]]])
>>> x[None, :, :].shape
(1, 2, 3)
If the indexed portion of the array is contiguous in memory, no data
is copied. Instead, a shared-memory view of the original array is
returned, and changes to that view affect the original array:
>>> x = np.arange(8).reshape(2, 2, 2)
>>> y = x[0] # contiguous
>>> y
array([[0., 1.],
[2., 3.]])
>>> y[:] = -1
>>> x
array([[[-1., -1.],
[-1., -1.]],
[[ 4., 5.],
[ 6., 7.]]])
>>> x = np.arange(8).reshape(2, 2, 2)
>>> y = x[1, :1, :] # contiguous
>>> y
array([[4., 5.]])
>>> y[:] = -1
>>> x
array([[[ 0., 1.],
[ 2., 3.]],
[[-1., -1.],
[ 6., 7.]]])
>>> x = np.arange(0, 8).reshape(2, 2, 2)
>>> y = x[:, :, 1] # not contiguous
>>> y
array([[1., 3.],
[5., 7.]])
>>> y[:] = -1
>>> x
array([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
If the indexing key contains `list`, `numpy.ndarray` or `NDArray`
objects, advanced indexing is triggered, which always returns a
copy:
>>> x = np.arange(8).reshape(2, 2, 2)
>>> x[[0, 1]]
array([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
>>> x[[0, 1], :] # equivalent
array([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
>>> y = np.array([0, 1], dtype='int32')
>>> x[1:, y]
array([[[4., 5.],
[6., 7.]]])
>>> y = np.array([0, 1], dtype='int32')
>>> x[1:, y]
array([[[4., 5.],
[6., 7.]]])
Get negative elements in an ndarray through boolean array indexing
>>> x = np.array([1., -1., -2., 3])
>>> x[x < 0]
array([-1., -2.])
"""
# handling possible boolean indexing first
ndim = self.ndim
Expand Down Expand Up @@ -356,11 +522,51 @@ def __getitem__(self, key):
raise RuntimeError

def __setitem__(self, key, value):
"""
x.__setitem__(i, y) <=> x[i]=y
Sets ``self[key]`` to ``value``.
"""Sets ``self[key]`` to ``value``.
Overriding the method in NDArray class in a numpy fashion.
This functions supports advanced indexing as defined in `the NumPy
advanced indexing documentation
<https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing>`_,
with the restriction that boolean array indexing is not supported.
Parameters
----------
key : int, slice, list, np.ndarray, mx.np.ndarray, or tuple of all previous types
The indexing key.
value : scalar or array-like object that can be broadcast to the shape of self[key]
The value to set.
Examples
--------
>>> x = np.zeros((2, 3))
>>> x[:] = 1
>>> x
array([[ 1., 1., 1.],
[ 1., 1., 1.]])
>>> x[:, 1:2] = 2
>>> x
array([[ 1., 2., 1.],
[ 1., 2., 1.]])
>>> x[1:2, 1:] = 3
>>> x
array([[ 1., 2., 1.],
[ 1., 3., 3.]])
>>> x[1:, 0:2] = np.zeros((1, 2))
>>> x
array([[ 1., 2., 1.],
[ 0., 0., 3.]])
>>> x[1, 2] = 4
>>> x
array([[ 1., 2., 1.],
[ 0., 0., 4.]])
>>> x[[0], [1, 2]] = 5
>>> x
array([[ 1., 5., 5.],
[ 0., 0., 4.]])
>>> x[::-1, 0:2:2] = [6]
>>> x
array([[ 6., 5., 5.],
[ 6., 0., 4.]])
"""
if isinstance(value, NDArray) and not isinstance(value, ndarray):
raise TypeError('Cannot assign mx.nd.NDArray to mxnet.numpy.ndarray')
Expand Down Expand Up @@ -496,25 +702,16 @@ def __rmul__(self, other):
return self.__mul__(other)

def __div__(self, other):
raise AttributeError('ndarray.__div__ is replaced by __truediv__. If you are using'
' Python2, please use the statement from __future__ import division'
' to change the / operator to mean true division throughout the'
' module. If you are using Python3, this error should not have'
' been encountered.')
"""x.__div__(y) <=> x / y"""
return divide(self, other)

def __rdiv__(self, other):
raise AttributeError('ndarray.__rdiv__ is replaced by __rtruediv__. If you are using'
' Python2, please use the statement from __future__ import division'
' to change the / operator to mean true division throughout the'
' module. If you are using Python3, this error should not have'
' been encountered.')
"""x.__rdiv__(y) <=> y / x"""
return divide(other, self)

def __idiv__(self, other):
raise AttributeError('ndarray.__idiv__ is replaced by __irtruediv__. If you are using'
' Python2, please use the statement from __future__ import division'
' to change the / operator to mean true division throughout the'
' module. If you are using Python3, this error should not have'
' been encountered.')
"""x.__idiv__(y) <=> x /= y"""
return divide(self, other, out=self)

def __truediv__(self, other):
"""x.__truediv__(y) <=> x / y"""
Expand All @@ -525,6 +722,7 @@ def __rtruediv__(self, other):
return divide(other, self)

def __itruediv__(self, other):
"""x.__itruediv__(y) <=> x /= y"""
return divide(self, other, out=self)

def __mod__(self, other):
Expand Down
14 changes: 4 additions & 10 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,12 @@ def __rmul__(self, other):
return multiply(other, self)

def __div__(self, other):
raise AttributeError('_Symbol.__div__ is replaced by __truediv__. If you are using'
' Python2, please use the statement from __future__ import division'
' to change the / operator to mean true division throughout the'
' module. If you are using Python3, this error should not have'
' been encountered.')
"""x.__truediv__(y) <=> x / y"""
return divide(self, other)

def __rdiv__(self, other):
raise AttributeError('_Symbol.__rdiv__ is replaced by __rtruediv__. If you are using'
' Python2, please use the statement from __future__ import division'
' to change the / operator to mean true division throughout the'
' module. If you are using Python3, this error should not have'
' been encountered.')
"""x.__rdiv__(y) <=> y / x"""
return divide(other, self)

def __mod__(self, other):
"""x.__mod__(y) <=> x % y"""
Expand Down
7 changes: 5 additions & 2 deletions python/mxnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,9 +627,11 @@ def _wrap_np_unary_func(x, out=None, **kwargs):
if value != _np_ufunc_default_kwargs[key]:
# if the provided value of the argument is a legal option, raise NotImplementedError
if np_ufunc_legal_option(key, value):
raise NotImplementedError("{}={} is not implemented yet".format(key, str(value)))
raise NotImplementedError("{}={} is not implemented yet for operator {}"
.format(key, str(value), func.__name__))
# otherwise raise TypeError with not understood error message
raise TypeError("{} {} not understood".format(key, value))
raise TypeError("{}={} not understood for operator {}"
.format(key, value, func.__name__))
return func(x, out=out)
return _wrap_np_unary_func

Expand Down Expand Up @@ -664,6 +666,7 @@ def _wrap_np_binary_func(x1, x2, out=None, **kwargs):
return func(x1, x2, out=out)
return _wrap_np_binary_func


def _set_np_array(active):
"""Turns on/off NumPy array semantics for the current thread in which `mxnet.numpy.ndarray`
is expected to be created, instead of the legacy `mx.nd.NDArray`.
Expand Down
Loading

0 comments on commit 217ae02

Please sign in to comment.