Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[API Standardization]Standardize MXNet NumPy Statistical & Linalg Functions #20592

Merged
merged 27 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
da94ffb
[Website] Fix website publish (#20573)
barry-jin Sep 14, 2021
d738884
Merge remote-tracking branch 'origin/master'
NathanYyc Sep 16, 2021
9e50049
Merge remote-tracking branch 'origin/master'
NathanYyc Sep 16, 2021
d572c71
Merge remote-tracking branch 'origin/master'
NathanYyc Sep 17, 2021
fb5a717
change linalg & statical funcs
NathanYyc Sep 17, 2021
4b5fa73
add vecdot
NathanYyc Sep 17, 2021
22e4ed5
changes made
NathanYyc Sep 22, 2021
3fdcc3c
changes made
NathanYyc Sep 22, 2021
8157c0a
changes made
NathanYyc Sep 24, 2021
5a0dea1
changes made
NathanYyc Sep 24, 2021
4e54707
delete test vecdot
NathanYyc Sep 24, 2021
a798d52
fixed lint
NathanYyc Oct 2, 2021
9e72e62
fixed lint error
NathanYyc Oct 2, 2021
d4138d9
fixed lint error
NathanYyc Oct 2, 2021
74fcf51
fixed problems
NathanYyc Oct 4, 2021
907c09d
delete 'vecdot' in __all__
NathanYyc Oct 4, 2021
801a823
fixed acosh doc
NathanYyc Oct 4, 2021
ead73f8
fixed tensordot bug
NathanYyc Oct 6, 2021
cebcbe2
add line in line 58
NathanYyc Oct 6, 2021
8e053ef
add line in line 4254
NathanYyc Oct 6, 2021
75fc59e
add line in 5423,9080 in multiarray
NathanYyc Oct 6, 2021
b040e9b
Merge remote-tracking branch 'upstream/master'
NathanYyc Oct 6, 2021
70c7e37
Update python/mxnet/numpy/multiarray.py
NathanYyc Oct 9, 2021
3c8eed3
merge conflicts
NathanYyc Oct 11, 2021
4bf327b
solve typo
NathanYyc Oct 11, 2021
21c9fd3
add wrap_data_api_linalg_func in line 1335 & 1205
NathanYyc Oct 12, 2021
5eeb259
add wrap_data_api_linalg_func in line 1335 & 1205
NathanYyc Oct 13, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Namespace for ops used in imperative programming."""

from ..ndarray import numpy as _mx_nd_np
from ..util import wrap_data_api_linalg_func
from .fallback_linalg import * # pylint: disable=wildcard-import,unused-wildcard-import
from . import fallback_linalg

Expand Down Expand Up @@ -838,8 +839,8 @@ def eigvals(a):
"""
return _mx_nd_np.linalg.eigvals(a)


def eigvalsh(a, UPLO='L'):
@wrap_data_api_linalg_func
def eigvalsh(a, upper=False):
NathanYyc marked this conversation as resolved.
Show resolved Hide resolved
r"""Compute the eigenvalues real symmetric matrix.

Main difference from eigh: the eigenvectors are not computed.
Expand Down Expand Up @@ -890,9 +891,13 @@ def eigvalsh(a, UPLO='L'):
>>> a = np.array([[ 5.4119368 , 8.996273 , -5.086096 ],
... [ 0.8866155 , 1.7490431 , -4.6107802 ],
... [-0.08034172, 4.4172044 , 1.4528792 ]])
>>> LA.eigvalsh(a, UPLO='L')
>>> LA.eigvalsh(a, upper=False)
array([-2.87381886, 5.10144682, 6.38623114]) # in ascending order
"""
if(upper==False):
NathanYyc marked this conversation as resolved.
Show resolved Hide resolved
UPLO='L'
else:
UPLO='U'
NathanYyc marked this conversation as resolved.
Show resolved Hide resolved
return _mx_nd_np.linalg.eigvalsh(a, UPLO)


Expand Down Expand Up @@ -962,8 +967,8 @@ def eig(a):
"""
return _mx_nd_np.linalg.eig(a)


def eigh(a, UPLO='L'):
@wrap_data_api_linalg_func
def eigh(a, upper=False):
NathanYyc marked this conversation as resolved.
Show resolved Hide resolved
r"""Return the eigenvalues and eigenvectors real symmetric matrix.

Returns two objects, a 1-D array containing the eigenvalues of `a`, and
Expand Down Expand Up @@ -1018,12 +1023,16 @@ def eigh(a, UPLO='L'):
>>> a = np.array([[ 6.8189726 , -3.926585 , 4.3990498 ],
... [-0.59656644, -1.9166266 , 9.54532 ],
... [ 2.1093285 , 0.19688708, -1.1634291 ]])
>>> w, v = LA.eigh(a, UPLO='L')
>>> w, v = LA.eigh(a, upper=False)
>>> w
array([-2.175445 , -1.4581827, 7.3725457])
>>> v
array([[ 0.1805163 , -0.16569263, 0.9695154 ],
[ 0.8242942 , 0.56326365, -0.05721384],
[-0.53661287, 0.80949366, 0.23825769]])
"""
if(upper == False):
UPLO = 'L'
else:
UPLO = 'U'
return _mx_nd_np.linalg.eigh(a, UPLO)
Loading