Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[numpy] Cosmetic improvement on mxnet.numpy builtin op signature in documentation #16305

Merged
merged 6 commits into from
Sep 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,5 @@
# fact that kvstore-server module is imported before the __version__ attr is set.
from . import kvstore_server

from .numpy_dispatch_protocol import _register_array_function, _register_array_ufunc
_register_array_function()
_register_array_ufunc()
from . import numpy_op_signature
from . import numpy_dispatch_protocol
30 changes: 3 additions & 27 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@

def _np_ones_like(a):
"""
ones_like(a)

Return an array of ones with the same shape and type as a given array.

Parameters
Expand All @@ -42,8 +40,6 @@ def _np_ones_like(a):

def _np_zeros_like(a):
"""
zeros_like(a)

Return an array of zeros with the same shape and type as a given array.

Parameters
Expand All @@ -62,8 +58,6 @@ def _np_zeros_like(a):

def _np_cumsum(a, axis=None, dtype=None, out=None):
"""
cumsum(a, axis=None, dtype=None, out=None)

Return the cumulative sum of the elements along a given axis.

Parameters
Expand Down Expand Up @@ -115,8 +109,6 @@ def _np_cumsum(a, axis=None, dtype=None, out=None):

def _npx_nonzero(a):
"""
nonzero(a)

Return the indices of the elements that are non-zero.

Returns a ndarray with ndim is 2. Each row contains the indices
Expand Down Expand Up @@ -164,8 +156,6 @@ def _npx_nonzero(a):

def _np_repeat(a, repeats, axis=None):
"""
repeat(a, repeats, axis=None)

Repeat elements of an array.

Parameters
Expand Down Expand Up @@ -213,8 +203,6 @@ def _np_repeat(a, repeats, axis=None):

def _np_transpose(a, axes=None):
"""
transpose(a, axes=None)

Permute the dimensions of an array.

Parameters
Expand Down Expand Up @@ -256,8 +244,7 @@ def _np_transpose(a, axes=None):


def _np_dot(a, b, out=None):
"""dot(a, b, out=None)

"""
Dot product of two arrays. Specifically,

- If both `a` and `b` are 1-D arrays, it is inner product of vectors
Expand Down Expand Up @@ -318,10 +305,8 @@ def _np_dot(a, b, out=None):
pass


def _np_sum(a, axis=0, dtype=None, keepdims=None, initial=None, out=None):
def _np_sum(a, axis=None, dtype=None, keepdims=False, initial=None, out=None):
r"""
sum(a, axis=None, dtype=None, keepdims=_Null, initial=_Null, out=None)

Sum of array elements over a given axis.

Parameters
Expand Down Expand Up @@ -414,8 +399,6 @@ def _np_sum(a, axis=0, dtype=None, keepdims=None, initial=None, out=None):

def _np_copy(a, out=None):
"""
copy(a, out=None)

Return an array copy of the given object.

Parameters
Expand Down Expand Up @@ -463,8 +446,6 @@ def _np_copy(a, out=None):

def _np_reshape(a, newshape, order='C', out=None):
"""
reshape(a, newshape, order='C')

Gives a new shape to an array without changing its data.
This function always returns a copy of the input array if
``out`` is not provided.
Expand Down Expand Up @@ -501,8 +482,6 @@ def _np_reshape(a, newshape, order='C', out=None):

def _np__linalg_svd(a):
r"""
svd(a)

Singular Value Decomposition.

When `a` is a 2D array, it is factorized as ``ut @ np.diag(s) @ v``,
Expand Down Expand Up @@ -568,8 +547,6 @@ def _np__linalg_svd(a):

def _np_roll(a, shift, axis=None):
"""
roll(a, shift, axis=None):

Roll array elements along a given axis.

Elements that roll beyond the last position are re-introduced at
Expand Down Expand Up @@ -633,8 +610,7 @@ def _np_roll(a, shift, axis=None):


def _np_trace(a, offset=0, axis1=0, axis2=1, out=None):
"""trace(a, offset=0, axis1=0, axis2=1, out=None)

"""
Return the sum along diagonals of the array.
If `a` is 2-D, the sum along its diagonal with the given offset
is returned, i.e., the sum of elements ``a[i,i+offset]`` for all i.
Expand Down
4 changes: 4 additions & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,7 @@ def _register_array_ufunc():
_NUMPY_ARRAY_UFUNC_DICT[op_name] = mx_np_op
except AttributeError:
raise AttributeError('mxnet.numpy does not have operator named {}'.format(op_name))


_register_array_function()
_register_array_ufunc()
72 changes: 72 additions & 0 deletions python/mxnet/numpy_op_signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Make builtin ops' signatures compatible with NumPy."""

from __future__ import absolute_import
import sys
import warnings
from . import _numpy_op_doc
from . import numpy as mx_np
from . import numpy_extension as mx_npx
from .base import _NP_OP_SUBMODULE_LIST, _NP_EXT_OP_SUBMODULE_LIST, _get_op_submodule_name


def _get_builtin_op(op_name):
if op_name.startswith('_np_'):
root_module = mx_np
op_name_prefix = '_np_'
submodule_name_list = _NP_OP_SUBMODULE_LIST
elif op_name.startswith('_npx_'):
root_module = mx_npx
op_name_prefix = '_npx_'
submodule_name_list = _NP_EXT_OP_SUBMODULE_LIST
else:
return None

submodule_name = _get_op_submodule_name(op_name, op_name_prefix, submodule_name_list)
if len(submodule_name) > 0:
op_module = getattr(root_module, submodule_name[1:-1], None)
if op_module is None:
raise ValueError('Cannot find submodule {} in module {}'
.format(submodule_name[1:-1], root_module.__name__))
else:
op_module = root_module

op = getattr(op_module, op_name[(len(op_name_prefix)+len(submodule_name)):], None)
if op is None:
raise ValueError('Cannot find operator {} in module {}'
.format(op_name[op_name_prefix:], root_module.__name__))
return op


def _register_op_signatures():
if sys.version_info.major < 3 or sys.version_info.minor < 5:
warnings.warn('Some mxnet.numpy operator signatures may not be displayed consistently with '
'their counterparts in the official NumPy package due to too-low Python '
'version {}. Python >= 3.5 is required to make the signatures display correctly.'
.format(str(sys.version)))
return

import inspect
for op_name in dir(_numpy_op_doc):
op = _get_builtin_op(op_name)
if op is not None:
op.__signature__ = inspect.signature(getattr(_numpy_op_doc, op_name))


_register_op_signatures()
15 changes: 15 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

# pylint: skip-file
from __future__ import absolute_import
import sys
import unittest
import numpy as _np
import mxnet as mx
from mxnet import np, npx
Expand All @@ -29,6 +31,7 @@
import scipy.stats as ss
from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry
from mxnet.runtime import Features
from mxnet.numpy_op_signature import _get_builtin_op
import platform


Expand Down Expand Up @@ -2810,6 +2813,18 @@ def check_output_n_grad(data_shape, idx_shape, axis, mode):
check_output_n_grad(config[0], config[1], config[2], mode)


@unittest.skipUnless(sys.version_info.major >= 3 and sys.version_info.minor >= 5,
'inspect package requires Python >= 3.5 to work properly')
@with_seed()
def test_np_builtin_op_signature():
import inspect
from mxnet import _numpy_op_doc
for op_name in dir(_numpy_op_doc):
op = _get_builtin_op(op_name)
if op is not None:
assert str(op.__signature__) == str(inspect.signature(getattr(_numpy_op_doc, op_name)))


if __name__ == '__main__':
import nose
nose.runmodule()