Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[API Standardization]Standardize MXNet NumPy Statistical & Linalg Functions #20592

Merged
merged 27 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
da94ffb
[Website] Fix website publish (#20573)
barry-jin Sep 14, 2021
d738884
Merge remote-tracking branch 'origin/master'
NathanYyc Sep 16, 2021
9e50049
Merge remote-tracking branch 'origin/master'
NathanYyc Sep 16, 2021
d572c71
Merge remote-tracking branch 'origin/master'
NathanYyc Sep 17, 2021
fb5a717
change linalg & statical funcs
NathanYyc Sep 17, 2021
4b5fa73
add vecdot
NathanYyc Sep 17, 2021
22e4ed5
changes made
NathanYyc Sep 22, 2021
3fdcc3c
changes made
NathanYyc Sep 22, 2021
8157c0a
changes made
NathanYyc Sep 24, 2021
5a0dea1
changes made
NathanYyc Sep 24, 2021
4e54707
delete test vecdot
NathanYyc Sep 24, 2021
a798d52
fixed lint
NathanYyc Oct 2, 2021
9e72e62
fixed lint error
NathanYyc Oct 2, 2021
d4138d9
fixed lint error
NathanYyc Oct 2, 2021
74fcf51
fixed problems
NathanYyc Oct 4, 2021
907c09d
delete 'vecdot' in __all__
NathanYyc Oct 4, 2021
801a823
fixed acosh doc
NathanYyc Oct 4, 2021
ead73f8
fixed tensordot bug
NathanYyc Oct 6, 2021
cebcbe2
add line in line 58
NathanYyc Oct 6, 2021
8e053ef
add line in line 4254
NathanYyc Oct 6, 2021
75fc59e
add line in 5423,9080 in multiarray
NathanYyc Oct 6, 2021
b040e9b
Merge remote-tracking branch 'upstream/master'
NathanYyc Oct 6, 2021
70c7e37
Update python/mxnet/numpy/multiarray.py
NathanYyc Oct 9, 2021
3c8eed3
merge conflicts
NathanYyc Oct 11, 2021
4bf327b
solve typo
NathanYyc Oct 11, 2021
21c9fd3
add wrap_data_api_linalg_func in line 1335 & 1205
NathanYyc Oct 12, 2021
5eeb259
add wrap_data_api_linalg_func in line 1335 & 1205
NathanYyc Oct 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 68 additions & 4 deletions python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
"""Namespace for ops used in imperative programming."""

from ..ndarray import numpy as _mx_nd_np
from ..util import wrap_data_api_linalg_func
from .fallback_linalg import * # pylint: disable=wildcard-import,unused-wildcard-import
from . import fallback_linalg

__all__ = ['norm', 'svd', 'cholesky', 'qr', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve',
'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh', 'lstsq', 'matrix_rank', 'cross', 'diagonal', 'outer',
'tensordot', 'trace', 'matrix_transpose']
'tensordot', 'trace', 'matrix_transpose', 'vecdot']
__all__ += fallback_linalg.__all__


Expand Down Expand Up @@ -373,6 +374,59 @@ def outer(a, b):
return _mx_nd_np.tensordot(a.flatten(), b.flatten(), 0)


def vecdot(a, b, axis=None):
r"""
Return the dot product of two vectors.
Note that `vecdot` handles multidimensional arrays differently than `dot`:
it does *not* perform a matrix product, but flattens input arguments
to 1-D vectors first. Consequently, it should only be used for vectors.

Notes
----------
`vecdot` is a alias for `vdot`. It is a standard API in
https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#vecdot-x1-x2-axis-1
instead of an official NumPy operator.

Parameters
----------
a : ndarray
First argument to the dot product.
b : ndarray
Second argument to the dot product.
axis : axis over which to compute the dot product. Must be an integer on
the interval [-N, N) , where N is the rank (number of dimensions) of
the shape determined according to Broadcasting . If specified as a
negative integer, the function must determine the axis along which
to compute the dot product by counting backward from the last dimension
(where -1 refers to the last dimension). If None , the function must
compute the dot product over the last axis. Default: None .

Returns
-------
output : ndarray
Dot product of `a` and `b`.

See Also
--------
dot : Return the dot product without using the complex conjugate of the
first argument.

Examples
--------
Note that higher-dimensional arrays are flattened!

>>> a = np.array([[1, 4], [5, 6]])
>>> b = np.array([[4, 1], [2, 2]])
>>> np.linalg.vecdot(a, b)
array(30.)
>>> np.linalg.vecdot(b, a)
array(30.)
>>> 1*4 + 4*1 + 5*2 + 6*2
30
"""
return _mx_nd_np.tensordot(a.flatten(), b.flatten(), axis)


def lstsq(a, b, rcond='warn'):
r"""
Return the least-squares solution to a linear matrix equation.
Expand Down Expand Up @@ -1148,7 +1202,8 @@ def eigvals(a):
return _mx_nd_np.linalg.eigvals(a)


def eigvalsh(a, UPLO='L'):
@wrap_data_api_linalg_func
def eigvalsh(a, upper=False):
NathanYyc marked this conversation as resolved.
Show resolved Hide resolved
r"""
Compute the eigenvalues real symmetric matrix.

Expand Down Expand Up @@ -1203,6 +1258,10 @@ def eigvalsh(a, UPLO='L'):
>>> LA.eigvalsh(a, UPLO='L')
array([-2.87381886, 5.10144682, 6.38623114]) # in ascending order
"""
if not upper:
UPLO = 'L'
else:
UPLO = 'U'
return _mx_nd_np.linalg.eigvalsh(a, UPLO)


Expand Down Expand Up @@ -1273,7 +1332,8 @@ def eig(a):
return _mx_nd_np.linalg.eig(a)


def eigh(a, UPLO='L'):
@wrap_data_api_linalg_func
def eigh(a, upper=False):
NathanYyc marked this conversation as resolved.
Show resolved Hide resolved
r"""
Return the eigenvalues and eigenvectors real symmetric matrix.

Expand Down Expand Up @@ -1329,12 +1389,16 @@ def eigh(a, UPLO='L'):
>>> a = np.array([[ 6.8189726 , -3.926585 , 4.3990498 ],
... [-0.59656644, -1.9166266 , 9.54532 ],
... [ 2.1093285 , 0.19688708, -1.1634291 ]])
>>> w, v = LA.eigh(a, UPLO='L')
>>> w, v = LA.eigh(a, upper=False)
>>> w
array([-2.175445 , -1.4581827, 7.3725457])
>>> v
array([[ 0.1805163 , -0.16569263, 0.9695154 ],
[ 0.8242942 , 0.56326365, -0.05721384],
[-0.53661287, 0.80949366, 0.23825769]])
"""
if not upper:
UPLO = 'L'
else:
UPLO = 'U'
return _mx_nd_np.linalg.eigh(a, UPLO)
Loading