Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix several numpy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Nov 19, 2019
1 parent 5b3f709 commit 0d97eae
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 37 deletions.
44 changes: 40 additions & 4 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from . import _internal as _npi
from ..ndarray import NDArray

__all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power',
__all__ = ['shape', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
Expand All @@ -43,7 +43,41 @@


@set_module('mxnet.ndarray.numpy')
def zeros(shape, dtype=_np.float32, order='C', ctx=None):
def shape(a):
"""
Return the shape of an array.
Parameters
----------
a : array_like
Input array.
Returns
-------
shape : tuple of ints
The elements of the shape tuple give the lengths of the
corresponding array dimensions.
See Also
--------
ndarray.shape : Equivalent array method.
Examples
--------
>>> np.shape(np.eye(3))
(3, 3)
>>> np.shape([[1, 2]])
(1, 2)
>>> np.shape([0])
(1,)
>>> np.shape(0)
()
"""
return a.shape


@set_module('mxnet.ndarray.numpy')
def zeros(shape, dtype=_np.float32, order='C', ctx=None): # pylint: disable=redefined-outer-name
"""Return a new array of given shape and type, filled with zeros.
This function currently only supports storing multi-dimensional data
in row-major (C-style).
Expand Down Expand Up @@ -77,7 +111,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None):


@set_module('mxnet.ndarray.numpy')
def ones(shape, dtype=_np.float32, order='C', ctx=None):
def ones(shape, dtype=_np.float32, order='C', ctx=None): # pylint: disable=redefined-outer-name
"""Return a new array of given shape and type, filled with ones.
This function currently only supports storing multi-dimensional data
in row-major (C-style).
Expand Down Expand Up @@ -110,8 +144,9 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None):
return _npi.ones(shape=shape, ctx=ctx, dtype=dtype)


# pylint: disable=too-many-arguments, redefined-outer-name
@set_module('mxnet.ndarray.numpy')
def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylint: disable=too-many-arguments
def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None):
"""
Return a new array of given shape and type, filled with `fill_value`.
Parameters
Expand Down Expand Up @@ -163,6 +198,7 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
ctx = current_context()
dtype = _np.float32 if dtype is None else dtype
return _npi.full(shape=shape, value=fill_value, ctx=ctx, dtype=dtype, out=out)
# pylint: enable=too-many-arguments, redefined-outer-name


@set_module('mxnet.ndarray.numpy')
Expand Down
58 changes: 47 additions & 11 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from ..ndarray.numpy import _internal as _npi
from ..ndarray.ndarray import _storage_type

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide',
__all__ = ['ndarray', 'empty', 'array', 'shape', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide',
'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10',
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
Expand All @@ -67,7 +67,7 @@

# This function is copied from ndarray.py since pylint
# keeps giving false alarm error of undefined-all-variable
def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t): # pylint: disable=redefined-outer-name
"""Return a new handle with specified shape and context.
Empty handle is only used to hold results.
Expand All @@ -89,7 +89,7 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
return hdl


def _reshape_view(a, *shape):
def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name
"""Returns a **view** of this array with a new shape without altering any data.
Parameters
Expand Down Expand Up @@ -462,7 +462,7 @@ def __getitem__(self, key):
"""
# handling possible boolean indexing first
ndim = self.ndim
shape = self.shape
shape = self.shape # pylint: disable=redefined-outer-name

if isinstance(key, list):
try:
Expand Down Expand Up @@ -804,7 +804,7 @@ def __int__(self):

def __len__(self):
"""Number of elements along the first axis."""
shape = self.shape
shape = self.shape # pylint: disable=redefined-outer-name
if len(shape) == 0:
raise TypeError('len() of unsized object')
return self.shape[0]
Expand Down Expand Up @@ -1165,7 +1165,7 @@ def reshape_like(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute reshape_like')

def reshape_view(self, *shape, **kwargs):
def reshape_view(self, *shape, **kwargs): # pylint: disable=redefined-outer-name
"""Returns a **view** of this array with a new shape without altering any data.
Inheritated from NDArray.reshape.
"""
Expand Down Expand Up @@ -1854,7 +1854,7 @@ def squeeze(self, axis=None): # pylint: disable=arguments-differ
"""Remove single-dimensional entries from the shape of a."""
return _mx_np_op.squeeze(self, axis=axis)

def broadcast_to(self, shape):
def broadcast_to(self, shape): # pylint: disable=redefined-outer-name
return _mx_np_op.broadcast_to(self, shape)

def broadcast_like(self, other):
Expand Down Expand Up @@ -1916,7 +1916,7 @@ def tostype(self, stype):


@set_module('mxnet.numpy')
def empty(shape, dtype=_np.float32, order='C', ctx=None):
def empty(shape, dtype=_np.float32, order='C', ctx=None): # pylint: disable=redefined-outer-name
"""Return a new array of given shape and type, without initializing entries.
Parameters
Expand Down Expand Up @@ -2020,7 +2020,41 @@ def array(object, dtype=None, ctx=None):


@set_module('mxnet.numpy')
def zeros(shape, dtype=_np.float32, order='C', ctx=None):
def shape(a):
"""
Return the shape of an array.
Parameters
----------
a : array_like
Input array.
Returns
-------
shape : tuple of ints
The elements of the shape tuple give the lengths of the
corresponding array dimensions.
See Also
--------
ndarray.shape : Equivalent array method.
Examples
--------
>>> np.shape(np.eye(3))
(3, 3)
>>> np.shape([[1, 2]])
(1, 2)
>>> np.shape([0])
(1,)
>>> np.shape(0)
()
"""
return _mx_nd_np.shape(a)


@set_module('mxnet.numpy')
def zeros(shape, dtype=_np.float32, order='C', ctx=None): # pylint: disable=redefined-outer-name
"""Return a new array of given shape and type, filled with zeros.
This function currently only supports storing multi-dimensional data
in row-major (C-style).
Expand Down Expand Up @@ -2061,7 +2095,7 @@ def zeros(shape, dtype=_np.float32, order='C', ctx=None):


@set_module('mxnet.numpy')
def ones(shape, dtype=_np.float32, order='C', ctx=None):
def ones(shape, dtype=_np.float32, order='C', ctx=None): # pylint: disable=redefined-outer-name
"""Return a new array of given shape and type, filled with ones.
This function currently only supports storing multi-dimensional data
in row-major (C-style).
Expand Down Expand Up @@ -2106,8 +2140,9 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None):
return _mx_nd_np.ones(shape, dtype, order, ctx)


# pylint: disable=too-many-arguments, redefined-outer-name
@set_module('mxnet.numpy')
def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylint: disable=too-many-arguments
def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None):
"""
Return a new array of given shape and type, filled with `fill_value`.
Expand Down Expand Up @@ -2160,6 +2195,7 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
[2, 2]], dtype=int32)
"""
return _mx_nd_np.full(shape, fill_value, order=order, ctx=ctx, dtype=dtype, out=out)
# pylint: enable=too-many-arguments, redefined-outer-name


@set_module('mxnet.numpy')
Expand Down
11 changes: 9 additions & 2 deletions python/mxnet/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
'subok': True,
}

_set_np_shape_logged = False
_set_np_array_logged = False


def makedirs(d):
"""Create directories recursively if they don't exist. os.makedirs(exist_ok=True) is not
Expand Down Expand Up @@ -87,13 +90,15 @@ def set_np_shape(active):
>>> print(mx.is_np_shape())
True
"""
if active:
global _set_np_shape_logged
if active and not _set_np_shape_logged:
import logging
logging.info('NumPy-shape semantics has been activated in your code. '
'This is required for creating and manipulating scalar and zero-size '
'tensors, which were not supported in MXNet before, as in the official '
'NumPy library. Please DO NOT manually deactivate this semantics while '
'using `mxnet.numpy` and `mxnet.numpy_extension` modules.')
_set_np_shape_logged = True
elif is_np_array():
raise ValueError('Deactivating NumPy shape semantics while NumPy array semantics is still'
' active is not allowed. Please consider calling `npx.reset_np()` to'
Expand Down Expand Up @@ -678,11 +683,13 @@ def _set_np_array(active):
-------
A bool value indicating the previous state of NumPy array semantics.
"""
if active:
global _set_np_array_logged
if active and not _set_np_array_logged:
import logging
logging.info('NumPy array semantics has been activated in your code. This allows you'
' to use operators from MXNet NumPy and NumPy Extension modules as well'
' as MXNet NumPy `ndarray`s.')
_set_np_array_logged = True
cur_state = is_np_array()
_NumpyArrayScope._current.value = _NumpyArrayScope(active)
return cur_state
Expand Down
32 changes: 16 additions & 16 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,24 +66,24 @@ def get_mat(n):
A = np.array([[1, 2], [3, 4], [5, 6]])
vals = (100 * np.arange(5)).astype('l')
vals_c = (100 * np.array(get_mat(5)) + 1).astype('l')
vals_f = _np.array((100 * get_mat(5) + 1), order ='F', dtype ='l')
vals_f = _np.array((100 * get_mat(5) + 1), order='F', dtype='l')
vals_f = np.array(vals_f)

OpArgMngr.add_workload('diag', A, k= 2)
OpArgMngr.add_workload('diag', A, k= 1)
OpArgMngr.add_workload('diag', A, k= 0)
OpArgMngr.add_workload('diag', A, k= -1)
OpArgMngr.add_workload('diag', A, k= -2)
OpArgMngr.add_workload('diag', A, k= -3)
OpArgMngr.add_workload('diag', vals, k= 0)
OpArgMngr.add_workload('diag', vals, k= 2)
OpArgMngr.add_workload('diag', vals, k= -2)
OpArgMngr.add_workload('diag', vals_c, k= 0)
OpArgMngr.add_workload('diag', vals_c, k= 2)
OpArgMngr.add_workload('diag', vals_c, k= -2)
OpArgMngr.add_workload('diag', vals_f, k= 0)
OpArgMngr.add_workload('diag', vals_f, k= 2)
OpArgMngr.add_workload('diag', vals_f, k= -2)
OpArgMngr.add_workload('diag', A, k=2)
OpArgMngr.add_workload('diag', A, k=1)
OpArgMngr.add_workload('diag', A, k=0)
OpArgMngr.add_workload('diag', A, k=-1)
OpArgMngr.add_workload('diag', A, k=-2)
OpArgMngr.add_workload('diag', A, k=-3)
OpArgMngr.add_workload('diag', vals, k=0)
OpArgMngr.add_workload('diag', vals, k=2)
OpArgMngr.add_workload('diag', vals, k=-2)
OpArgMngr.add_workload('diag', vals_c, k=0)
OpArgMngr.add_workload('diag', vals_c, k=2)
OpArgMngr.add_workload('diag', vals_c, k=-2)
OpArgMngr.add_workload('diag', vals_f, k=0)
OpArgMngr.add_workload('diag', vals_f, k=2)
OpArgMngr.add_workload('diag', vals_f, k=-2)


def _add_workload_concatenate(array_pool):
Expand Down
27 changes: 23 additions & 4 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,26 @@ def legalize_shape(shape):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol, use_broadcast=False, equal_nan=True)


@with_seed()
@use_np
def test_np_shape():
shapes = [
(),
(0, 1),
(2, 3),
(2, 3, 4),
]

for shape in shapes:
mx_a = np.random.uniform(size=shape)
np_a = _np.random.uniform(size=shape)

mx_shape = np.shape(mx_a)
np_shape = _np.shape(np_a)

assert mx_shape == np_shape


@with_seed()
@use_np
def test_np_linspace():
Expand Down Expand Up @@ -4539,7 +4559,7 @@ def __init__(self, k=0):

def hybrid_forward(self, F, a):
return F.np.diag(a, k=self._k)

shapes = [(), (2,), (1, 5), (2, 2), (2, 5), (3, 3), (4, 3)]
dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float16, np.float32, np.float64]
range_k = 6
Expand All @@ -4559,8 +4579,8 @@ def hybrid_forward(self, F, a):
mx_out = test_diag(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
# check backward function

# check backward function
mx_out.backward()
if len(shape) == 0:
np_backward = np.array(())
Expand Down Expand Up @@ -4593,7 +4613,6 @@ def hybrid_forward(self, F, a):
@with_seed()
@use_np
def test_np_nan_to_num():

def take_ele_grad(ele):
if _np.isinf(ele) or _np.isnan(ele):
return 0
Expand Down

0 comments on commit 0d97eae

Please sign in to comment.